Repository: Yanfeng-Zhou/XNet
Branch: main
Commit: 58181894053f
Files: 108
Total size: 1.1 MB

Directory structure:
gitextract_k1uwhl9g/
├── .idea/
│   ├── XNet.iml
│   ├── deployment.xml
│   ├── inspectionProfiles/
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── LICENSE
├── README.md
├── config/
│   ├── __init__.py
│   ├── augmentation/
│   │   ├── __init__.py
│   │   └── online_aug.py
│   ├── dataset_config/
│   │   ├── __init__.py
│   │   └── dataset_cfg.py
│   ├── eval_config/
│   │   ├── __init__.py
│   │   └── eval.py
│   ├── ramps/
│   │   ├── __init__.py
│   │   └── ramps.py
│   ├── train_test_config/
│   │   ├── __init__.py
│   │   └── train_test_config.py
│   ├── visdom_config/
│   │   ├── __init__.py
│   │   └── visual_visdom.py
│   └── warmup_config/
│       ├── __init__.py
│       └── warmup.py
├── dataload/
│   ├── __init__.py
│   ├── dataset_2d.py
│   └── dataset_3d.py
├── loss/
│   ├── __init__.py
│   └── loss_function.py
├── models/
│   ├── __init__.py
│   ├── getnetwork.py
│   ├── networks_2d/
│   │   ├── __init__.py
│   │   ├── aerial_lanenet.py
│   │   ├── hrnet.py
│   │   ├── mwcnn.py
│   │   ├── resunet.py
│   │   ├── resunet_plusplus.py
│   │   ├── swinunet.py
│   │   ├── u2net.py
│   │   ├── unet.py
│   │   ├── unet_3plus.py
│   │   ├── unet_cct.py
│   │   ├── unet_plusplus.py
│   │   ├── unet_urpc.py
│   │   ├── wavesnet.py
│   │   ├── wds.py
│   │   └── xnet.py
│   └── networks_3d/
│       ├── __init__.py
│       ├── conresnet.py
│       ├── cotr.py
│       ├── dmfnet.py
│       ├── espnet3d.py
│       ├── res_unet3d.py
│       ├── transbts.py
│       ├── unet3d.py
│       ├── unet3d_cct.py
│       ├── unet3d_dtc.py
│       ├── unet3d_urpc.py
│       ├── unetr.py
│       ├── vnet.py
│       ├── vnet_cct.py
│       ├── vnet_dtc.py
│       └── xnet3d.py
├── requirements.txt
├── test.py
├── test_3d.py
├── test_ConResNet.py
├── test_DTC.py
├── test_xnet.py
├── test_xnet3d.py
├── tools/
│   ├── Atrial/
│   │   ├── __init__.py
│   │   ├── postprocess.py
│   │   └── preprocess.py
│   ├── LiTS/
│   │   ├── __init__.py
│   │   ├── postprocess.py
│   │   ├── preprocess.py
│   │   └── split_train_val.py
│   ├── __init__.py
│   ├── eval.py
│   ├── mask2sdf.py
│   ├── res_image_mask.py
│   ├── wavelet2D.py
│   └── wavelet3D.py
├── train_semi_CCT.py
├── train_semi_CCT_3d.py
├── train_semi_CPS.py
├── train_semi_CPS_3d.py
├── train_semi_CT.py
├── train_semi_CT_3d.py
├── train_semi_DTC.py
├── train_semi_EM.py
├── train_semi_EM_3d.py
├── train_semi_MT.py
├── train_semi_MT_3d.py
├── train_semi_UAMT.py
├── train_semi_UAMT_3d.py
├── train_semi_URPC.py
├── train_semi_URPC_3d.py
├── train_semi_XNet.py
├── train_semi_XNet3d.py
├── train_sup.py
├── train_sup_3d.py
├── train_sup_ConResNet.py
├── train_sup_XNet.py
├── train_sup_XNet3d.py
├── train_sup_XNet_sb.py
├── train_sup_alnet.py
└── train_sup_wds.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .idea/XNet.iml
================================================


================================================
FILE: .idea/deployment.xml
================================================


================================================
FILE: .idea/inspectionProfiles/profiles_settings.xml
================================================


================================================
FILE: .idea/misc.xml
================================================


================================================
FILE: .idea/modules.xml
================================================


================================================
FILE: .idea/vcs.xml
================================================


================================================
FILE: .idea/workspace.xml
================================================


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2024 Yanfeng Zhou

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

================================================
FILE: README.md
================================================
# XNet: Wavelet-Based Low and High Frequency Merging Networks for Semi- and Supervised Semantic Segmentation of Biomedical Images

This is the official code of [XNet: Wavelet-Based Low and High Frequency Merging Networks for Semi- and Supervised Semantic Segmentation of Biomedical Images](https://openaccess.thecvf.com/content/ICCV2023/html/Zhou_XNet_Wavelet-Based_Low_and_High_Frequency_Fusion_Networks_for_Fully-_ICCV_2023_paper.html) (ICCV 2023).

## Overview


Architecture of XNet.


Visualization of the dual-branch inputs. (a) Raw image. (b) Wavelet transform results. (c) Low-frequency image. (d) High-frequency image.


Architecture of the LF and HF fusion module.
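The dual-branch inputs above are produced by a wavelet decomposition; the README points to tools/wavelet2D.py and tools/wavelet3D.py for the real **L** and **H** generation, presumably built on PyWavelets, which is listed in the requirements. As a rough, self-contained illustration (not the repository's script), the sketch below runs a one-level 2D Haar transform in NumPy and, as an assumption, takes L from the LL band and H from the sum LH + HL + HH, each min-max scaled to 8 bits. The helper names `haar_dwt2`, `to_uint8` and `split_low_high` are hypothetical, and the actual tools may use other wavelets (the dataset config also stores db2, bior, coif and dmey statistics), other band combinations or full-resolution outputs.

```python
import numpy as np


def haar_dwt2(img):
    """One-level orthonormal 2D Haar transform of an even-sized 2D array.

    Returns the approximation band LL and three detail bands, each at half
    the input resolution.
    """
    x = np.asarray(img, dtype=np.float64)
    if x.ndim != 2 or x.shape[0] % 2 or x.shape[1] % 2:
        raise ValueError("expected a 2D array with even height and width")
    a = x[0::2, 0::2]  # top-left pixel of every 2x2 block
    b = x[0::2, 1::2]  # top-right
    c = x[1::2, 0::2]  # bottom-left
    d = x[1::2, 1::2]  # bottom-right
    ll = (a + b + c + d) / 2.0  # scaled block sum (low frequency)
    lh = (a + b - c - d) / 2.0  # top row minus bottom row
    hl = (a - b + c - d) / 2.0  # left column minus right column
    hh = (a - b - c + d) / 2.0  # diagonal difference
    return ll, lh, hl, hh


def to_uint8(band):
    """Min-max scale a band to [0, 255]; a constant band becomes all zeros."""
    lo, hi = float(band.min()), float(band.max())
    if hi - lo < 1e-12:
        return np.zeros(band.shape, dtype=np.uint8)
    return np.round((band - lo) / (hi - lo) * 255.0).astype(np.uint8)


def split_low_high(img):
    """Hypothetical L/H split: L from LL, H from LH + HL + HH (assumption)."""
    ll, lh, hl, hh = haar_dwt2(img)
    return to_uint8(ll), to_uint8(lh + hl + hh)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    image = rng.integers(0, 256, size=(128, 128))
    low, high = split_low_high(image)
    print(low.shape, high.shape)  # (64, 64) (64, 64)
```

Because the transform is orthonormal, the squared energy of the four bands equals that of the input, so the split loses no information before the bands are rescaled for storage.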

## Quantitative Comparison

Comparison with fully- and semi-supervised state-of-the-art models on the GlaS and CREMI test sets. Semi-supervised models are based on UNet. DS indicates deep supervision. * indicates lightweight models. ‡ indicates training for 1000 epochs. - indicates training failed. **Red** and **bold** indicate the best and second-best performance.

Comparison with fully- and semi-supervised state-of-the-art models on the LA and LiTS test sets. Due to GPU memory limitations, some semi-supervised models use smaller architectures: ✝ and * indicate models based on a lightweight 3D UNet (half the channels) and on VNet, respectively. ‡ indicates training for 1000 epochs. - indicates training failed. **Red** and **bold** indicate the best and second-best performance.
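For binary tasks, the Jc (Jaccard) and Dc (Dice) values that the training scripts log come from `evaluate()` in config/eval_config/eval.py, which sweeps foreground-probability thresholds from 0 up to (but not including) 0.9 in steps of 0.02 and keeps the threshold with the highest Jaccard. The NumPy-only sketch below mirrors that logic on plain arrays so it runs without PyTorch. `best_threshold_scores` is a hypothetical name, and unlike the original it returns 1.0 instead of dividing by zero when both prediction and mask are empty. The README does not say whether the published tables use this exact threshold search.

```python
import numpy as np


def best_threshold_scores(prob_fg, target, interval=0.02):
    """Sweep foreground-probability thresholds and keep the best Jaccard.

    prob_fg: foreground probabilities (any shape); target: 0/1 mask of the
    same shape. Returns (threshold, jaccard, dice) at the best threshold.
    """
    p = np.asarray(prob_fg, dtype=np.float64).ravel()
    t = np.asarray(target).ravel().astype(np.int8)
    best = None
    for thr in np.arange(0, 0.9, interval):
        pred = (p > thr).astype(np.int8)
        tp = float(np.sum((pred == 1) & (t == 1)))
        mismatch = float(np.sum(pred != t))  # false positives + false negatives
        if tp + mismatch == 0:  # empty prediction and empty mask
            jc = dc = 1.0
        else:
            jc = tp / (tp + mismatch)
            dc = 2 * tp / (2 * tp + mismatch)
        if best is None or jc > best[1]:  # first maximum wins, like np.argmax
            best = (float(thr), jc, dc)
    return best


if __name__ == "__main__":
    prob = np.array([0.11, 0.31, 0.71, 0.91])
    mask = np.array([0, 0, 1, 1])
    thr, jc, dc = best_threshold_scores(prob, mask)
    print(round(thr, 2), jc, dc)  # 0.32 1.0 1.0
```

Note that choosing the threshold on the same data that is scored makes the reported value an optimistic, best-case number for that split.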

## Qualitative Comparison


Qualitative results on GlaS, CREMI, LA and LiTS. (a) Raw images. (b) Ground truth. (c) MT. (d) Semi-supervised XNet (3D XNet). (e) UNet (3D UNet). (f) Fully-supervised XNet (3D XNet). The orange arrows highlight the differences among the results.

## Reimplemented Architecture

We have reimplemented several 2D and 3D models for semi- and fully-supervised semantic segmentation.
| Method | Dimension | Model | Code |
| --- | --- | --- | --- |
| Supervised | 2D | UNet | models/networks_2d/unet.py |
| | | UNet++ | models/networks_2d/unet_plusplus.py |
| | | Att-UNet | models/networks_2d/unet.py |
| | | Aerial LaneNet | models/networks_2d/aerial_lanenet.py |
| | | MWCNN | models/networks_2d/mwcnn.py |
| | | HRNet | models/networks_2d/hrnet.py |
| | | Res-UNet | models/networks_2d/resunet.py |
| | | WDS | models/networks_2d/wds.py |
| | | U2-Net | models/networks_2d/u2net.py |
| | | UNet 3+ | models/networks_2d/unet_3plus.py |
| | | SwinUNet | models/networks_2d/swinunet.py |
| | | WaveSNet | models/networks_2d/wavesnet.py |
| | | XNet (Ours) | models/networks_2d/xnet.py |
| | 3D | VNet | models/networks_3d/vnet.py |
| | | UNet 3D | models/networks_3d/unet3d.py |
| | | Res-UNet 3D | models/networks_3d/res_unet3d.py |
| | | ESPNet 3D | models/networks_3d/espnet3d.py |
| | | DMFNet 3D | models/networks_3d/dmfnet.py |
| | | ConResNet | models/networks_3d/conresnet.py |
| | | CoTr | models/networks_3d/cotr.py |
| | | TransBTS | models/networks_3d/transbts.py |
| | | UNETR | models/networks_3d/unetr.py |
| | | XNet 3D (Ours) | models/networks_3d/xnet3d.py |
| Semi-Supervised | 2D | MT | train_semi_MT.py |
| | | EM | train_semi_EM.py |
| | | UAMT | train_semi_UAMT.py |
| | | CCT | train_semi_CCT.py |
| | | CPS | train_semi_CPS.py |
| | | URPC | train_semi_URPC.py |
| | | CT | train_semi_CT.py |
| | | XNet (Ours) | train_semi_XNet.py |
| | 3D | MT | train_semi_MT_3d.py |
| | | EM | train_semi_EM_3d.py |
| | | UAMT | train_semi_UAMT_3d.py |
| | | CCT | train_semi_CCT_3d.py |
| | | CPS | train_semi_CPS_3d.py |
| | | URPC | train_semi_URPC_3d.py |
| | | CT | train_semi_CT_3d.py |
| | | DTC | train_semi_DTC.py |
| | | XNet 3D (Ours) | train_semi_XNet3d.py |
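The semi-supervised XNet trainer logs two supervised losses (one per branch), an unsupervised loss and their total (see `print_train_loss_XNet` in config/train_test_config/train_test_config.py), and config/ramps/ramps.py provides a sigmoid ramp-up, exp(-5(1 - t/T)^2), of the kind commonly used to phase in unsupervised terms. The NumPy sketch below shows one plausible way such pieces could fit together: a cross-pseudo-supervision term in which each branch is trained on the other branch's hard predictions (the idea behind CPS, listed in the table above), weighted by that ramp. This is an assumption for illustration, not the code in train_semi_XNet.py; `cross_pseudo_loss`, `total_loss` and the `max_weight` parameter are hypothetical.

```python
import numpy as np


def sigmoid_rampup(current, rampup_length):
    """exp(-5 * (1 - t/T)^2), mirroring config/ramps/ramps.py."""
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))


def log_softmax(logits, axis=1):
    """Numerically stable log-softmax over the class axis."""
    z = logits - logits.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))


def cross_entropy(logits, labels):
    """Mean pixel-wise cross-entropy; logits (N, C, H, W), labels (N, H, W)."""
    logp = log_softmax(logits, axis=1)
    picked = np.take_along_axis(logp, labels[:, None, ...], axis=1)
    return float(-picked.mean())


def cross_pseudo_loss(logits_a, logits_b):
    """Each branch is supervised by the other branch's argmax prediction."""
    pseudo_a = logits_a.argmax(axis=1)
    pseudo_b = logits_b.argmax(axis=1)
    return cross_entropy(logits_a, pseudo_b) + cross_entropy(logits_b, pseudo_a)


def total_loss(sup_1, sup_2, logits_a, logits_b, epoch, rampup_epochs, max_weight=1.0):
    """Hypothetical combination: two supervised terms plus a ramped unsup term."""
    weight = max_weight * sigmoid_rampup(epoch, rampup_epochs)
    return sup_1 + sup_2 + weight * cross_pseudo_loss(logits_a, logits_b)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    a = rng.normal(size=(2, 2, 4, 4))  # e.g. logits of the LF branch (hypothetical)
    b = rng.normal(size=(2, 2, 4, 4))  # e.g. logits of the HF branch (hypothetical)
    print(sigmoid_rampup(0, 40), sigmoid_rampup(40, 40))  # about 0.0067 and 1.0
    print(total_loss(0.3, 0.4, a, b, epoch=10, rampup_epochs=40))
```

The ramp keeps the unsupervised weight near zero early on, when both branches' pseudo-labels are still unreliable, and lets it reach its full value once `epoch` reaches `rampup_epochs`.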
## Requirements
```
albumentations==0.5.2
einops==0.4.1
MedPy==0.4.0
numpy==1.20.2
opencv_python==4.2.0.34
opencv_python_headless==4.5.1.48
Pillow==8.0.0
PyWavelets==1.1.1
scikit_image==0.18.1
scikit_learn==1.0.1
scipy==1.4.1
SimpleITK==2.1.0
timm==0.6.7
torch==1.8.0+cu111
torchio==0.18.53
torchvision==0.9.0+cu111
tqdm==4.65.0
visdom==0.1.8.9
```

## Usage

**Data preparation**

Your dataset directory tree should look like this:

> See [tools/wavelet2D.py](https://github.com/Yanfeng-Zhou/XNet/blob/main/tools/wavelet2D.py) and [tools/wavelet3D.py](https://github.com/Yanfeng-Zhou/XNet/blob/main/tools/wavelet3D.py) for how the **L** and **H** inputs are generated.

```
dataset
├── train_sup_100
│   ├── L
│   │   ├── 1.tif
│   │   ├── 2.tif
│   │   └── ...
│   ├── H
│   │   ├── 1.tif
│   │   ├── 2.tif
│   │   └── ...
│   └── mask
│       ├── 1.tif
│       ├── 2.tif
│       └── ...
├── train_sup_20
│   ├── L
│   ├── H
│   └── mask
├── train_unsup_80
│   ├── L
│   └── H
└── val
    ├── L
    ├── H
    └── mask
```

**Supervised training**
```
python -m torch.distributed.launch --nproc_per_node=4 train_sup_XNet.py
```

**Semi-supervised training**
```
python -m torch.distributed.launch --nproc_per_node=4 train_semi_XNet.py
```

**Testing**
```
python -m torch.distributed.launch --nproc_per_node=4 test.py
```

## Citation

If our work is useful for your research, please cite our paper:
```
@InProceedings{Zhou_2023_ICCV,
    author    = {Zhou, Yanfeng and Huang, Jiaxing and Wang, Chenlong and Song, Le and Yang, Ge},
    title     = {XNet: Wavelet-Based Low and High Frequency Fusion Networks for Fully- and Semi-Supervised Semantic Segmentation of Biomedical Images},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2023},
    pages     = {21085-21096}
}
```

================================================
FILE: config/__init__.py
================================================


================================================
FILE: config/augmentation/__init__.py
================================================


================================================
FILE: config/augmentation/online_aug.py
================================================
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchio import transforms as T
import torchio as tio


def data_transform_2d():
    data_transforms = {
        'train': A.Compose([
            A.Resize(128, 128, p=1),
            A.Flip(p=0.75),
            A.Transpose(p=0.5),
            A.RandomRotate90(p=1),
        ],
            additional_targets={'image2': 'image', 'mask2': 'mask'}
        ),
        'val': A.Compose([
            A.Resize(128, 128, p=1),
        ],
            additional_targets={'image2': 'image', 'mask2': 'mask'}
        ),
        'test': A.Compose([
            A.Resize(128, 128, p=1),
        ],
            additional_targets={'image2': 'image', 'mask2': 'mask'}
        )
    }
    return data_transforms


def data_normalize_2d(mean, std):
    data_normalize = A.Compose([
        A.Normalize(mean, std),
        ToTensorV2()
    ],
        additional_targets={'image2': 'image', 'mask2': 'mask'}
    )
    return data_normalize


def data_transform_aerial_lanenet(H, W):
    data_transforms = A.Compose([
        A.Resize(H, W, p=1),
        ToTensorV2()
    ])
    return data_transforms


def data_transform_3d(normalization):
    data_transform = {
        'train': T.Compose([
            T.RandomFlip(),
            T.RandomBiasField(coefficients=(0.12, 0.15), order=2, p=0.2),
            T.OneOf({
                T.RandomNoise(): 0.5,
                T.RandomBlur(std=1): 0.5,
            }, p=0.2),
            T.ZNormalization(masking_method=normalization),
        ]),
        'val': T.Compose([
            # T.CropOrPad(pad_size),
            T.ZNormalization(masking_method=normalization),
            # T.Resize(target_shape=(512, 512, 512), p=1)
        ]),
        'test': T.Compose([
            # T.CropOrPad(pad_size),
            T.ZNormalization(masking_method=normalization),
            # T.Resize(target_shape=(512, 512, 512), p=1)
        ])
    }
    return data_transform

================================================
FILE: config/dataset_config/__init__.py
================================================


================================================
FILE: config/dataset_config/dataset_cfg.py
================================================
import numpy as np
import torchio as tio


def dataset_cfg(dataet_name):
    config = {
        'CREMI': {
            'IN_CHANNELS': 1,
            'NUM_CLASSES': 2,
            'MEAN': [0.503902],
            'STD': [0.110739],
            'MEAN_DB2_H': [0.505787],
            'STD_DB2_H': [0.115504],
            'PALETTE': list(np.array([
                [255, 255, 255],
                [0, 0, 0],
            ]).flatten())
        },
        'GlaS': {
            'IN_CHANNELS': 3,
            'NUM_CLASSES': 2,
            'MEAN': [0.787803, 0.512017, 0.784938],
            'STD': [0.428206, 0.507778, 0.426366],
            'MEAN_HAAR_H': [0.528318],
            'STD_HAAR_H': [0.076766],
            'MEAN_HAAR_L': [0.579144],
            'STD_HAAR_L': [0.227451],
            'MEAN_HAAR_HHL': [0.542428],
            'STD_HAAR_HHL': [0.142663],
            'MEAN_HAAR_HLL': [0.569150],
            'STD_HAAR_HLL': [0.220854],
            'MEAN_BIOR1.5_H': [0.525711],
            'STD_BIOR1.5_H': [0.076606],
            'MEAN_BIOR2.4_H': [0.516579],
            'STD_BIOR2.4_H': [0.078798],
            'MEAN_COIF1_H': [0.523858],
            'STD_COIF1_H': [0.081001],
            'MEAN_DB2_H': [0.505234],
            'STD_DB2_H': [0.080919],
            'MEAN_DMEY_H': [0.502698],
            'STD_DMEY_H': [0.078861],
            'PALETTE': list(np.array([
                [0, 0, 0],
                [255, 255, 255],
            ]).flatten())
        },
        'ISIC-2017': {
            'IN_CHANNELS': 3,
            'NUM_CLASSES': 2,
            'MEAN': [0.699002, 0.556046, 0.512134],
            'STD': [0.365650, 0.317347, 0.339400],
            'MEAN_DB2_H': [0.489676],
            'STD_DB2_H': [0.081749],
            'PALETTE': list(np.array([
                [0, 0, 0],
                [255, 255, 255],
            ]).flatten())
        },
        'LiTS': {
            'IN_CHANNELS': 1,
            'NUM_CLASSES': 3,
            'NORMALIZE': tio.ZNormalization.mean,
            'PATCH_SIZE': (112, 112, 32),
            'FORMAT': '.nii',
            'NUM_SAMPLE_TRAIN': 8,
            'NUM_SAMPLE_VAL': 12
        },
        'Atrial': {
            'IN_CHANNELS': 1,
            'NUM_CLASSES': 2,
            'NORMALIZE': tio.ZNormalization.mean,
            'PATCH_SIZE': (96, 96, 80),
            'FORMAT': '.nrrd',
            'NUM_SAMPLE_TRAIN': 4,
            'NUM_SAMPLE_VAL': 8
        },
    }

    return config[dataet_name]

================================================
FILE: config/eval_config/__init__.py
================================================


================================================
FILE: config/eval_config/eval.py
================================================
import numpy as np
from sklearn.metrics import confusion_matrix
from scipy.spatial.distance import directed_hausdorff
import torch


def evaluate(y_scores, y_true, interval=0.02):
    y_scores = torch.softmax(y_scores, dim=1)
    y_scores = y_scores[:, 1, ...].cpu().detach().numpy().flatten()
    y_true = y_true.data.cpu().numpy().flatten()

    thresholds = np.arange(0, 0.9, interval)
    jaccard = np.zeros(len(thresholds))
    dice = np.zeros(len(thresholds))
    # labels as int8 so that y_pred + y_true takes values in {0, 1, 2}
    y_true = y_true.astype(np.int8)

    for indy in range(len(thresholds)):
        threshold = thresholds[indy]
        y_pred = (y_scores > threshold).astype(np.int8)

        sum_area = (y_pred + y_true)
        tp = float(np.sum(sum_area == 2))
        # pixels where prediction and label disagree, i.e. FP + FN
        union = np.sum(sum_area == 1)
        jaccard[indy] = tp / float(union + tp)
        dice[indy] = 2 * tp / float(union + 2 * tp)

    thred_indx = np.argmax(jaccard)
    m_jaccard = jaccard[thred_indx]
    m_dice = dice[thred_indx]

    return thresholds[thred_indx], m_jaccard, m_dice


def evaluate_multi(y_scores, y_true):
    y_scores = torch.softmax(y_scores, dim=1)
    y_pred = torch.max(y_scores, 1)[1]
    y_pred = y_pred.data.cpu().numpy().flatten()
    y_true = y_true.data.cpu().numpy().flatten()

    hist = confusion_matrix(y_true, y_pred)

    hist_diag = np.diag(hist)
    hist_sum_0 = hist.sum(axis=0)
    hist_sum_1 = hist.sum(axis=1)

    jaccard = hist_diag / (hist_sum_1 + hist_sum_0 - hist_diag)
    m_jaccard = np.nanmean(jaccard)
    dice = 2 * hist_diag / (hist_sum_1 + hist_sum_0)
    m_dice = np.nanmean(dice)

    return jaccard, m_jaccard, dice, m_dice

================================================
FILE: config/ramps/__init__.py
================================================


================================================
FILE: config/ramps/ramps.py
================================================
import numpy as np


def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242"""
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase))


def linear_rampup(current, rampup_length):
    """Linear rampup"""
    assert current >= 0 and rampup_length >= 0
    if current >= rampup_length:
        return 1.0
    else:
        return current / rampup_length


def cosine_rampdown(current, rampdown_length):
    """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
    assert 0 <= current <= rampdown_length
    return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))

================================================
FILE: config/train_test_config/__init__.py
================================================


================================================
FILE: config/train_test_config/train_test_config.py
================================================
import numpy as np
from config.eval_config.eval import evaluate, evaluate_multi
import torch
import os
from PIL import Image
import torchio as tio


def print_train_loss_sup(train_loss, num_batches, print_num, print_num_minus):
    train_epoch_loss = train_loss / num_batches['train_sup']
    print('-' * print_num)
    print('| Train Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_minus, ' '), '|')
    print('-' * print_num)
    return train_epoch_loss


def print_train_loss_MT(train_loss_sup_1, train_loss_cps, train_loss, num_batches, print_num, print_num_half, print_num_minus):
    train_epoch_loss_sup1 = train_loss_sup_1 / num_batches['train_sup']
    train_epoch_loss_cps = train_loss_cps / num_batches['train_sup']
    train_epoch_loss = train_loss / num_batches['train_sup']
    print('-' * print_num)
    print('| Train Sup Loss: {:.4f}'.format(train_epoch_loss_sup1).ljust(print_num_half, ' '), '| Train Unsup Loss: {:.4f}'.format(train_epoch_loss_cps).ljust(print_num_half, ' '), '|')
    print('| Train Total Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_minus, ' '), '|')
    print('-' * print_num)
    return train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss


def print_train_loss_ConResNet(train_loss_seg, train_loss_res, train_loss, num_batches, print_num, print_num_half, print_num_minus):
    train_epoch_loss_seg = train_loss_seg / num_batches['train_sup']
    train_epoch_loss_res = train_loss_res / num_batches['train_sup']
    train_epoch_loss = train_loss / num_batches['train_sup']
    print('-' * print_num)
    print('| Train Seg Loss: {:.4f}'.format(train_epoch_loss_seg).ljust(print_num_half, ' '), '| Train Res Loss: {:.4f}'.format(train_epoch_loss_res).ljust(print_num_half, ' '), '|')
    print('| Train Total Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_minus, ' '), '|')
    print('-' * print_num)
    return train_epoch_loss_seg, train_epoch_loss_res, train_epoch_loss


def print_train_loss_EM(train_loss_sup_1, train_loss_cps, train_loss, num_batches, print_num, print_num_minus):
    train_epoch_loss_sup1 = train_loss_sup_1 / num_batches['train_sup']
    train_epoch_loss_cps = train_loss_cps / num_batches['train_sup']
    train_epoch_loss = train_loss / num_batches['train_sup']
    print('-' * print_num)
    print('| Train Sup Loss: {:.4f}'.format(train_epoch_loss_sup1).ljust(print_num_minus, ' '), '|')
    print('| Train Unsup Loss: {:.4f}'.format(train_epoch_loss_cps).ljust(print_num_minus, ' '), '|')
    print('| Train Total Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_minus, ' '), '|')
    print('-' * print_num)
    return train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss


def print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_cps, train_loss, num_batches, print_num, print_num_half):
    train_epoch_loss_sup1 = train_loss_sup_1 / num_batches['train_sup']
    train_epoch_loss_sup2 = train_loss_sup_2 / num_batches['train_sup']
    train_epoch_loss_cps = train_loss_cps / num_batches['train_sup']
    train_epoch_loss = train_loss / num_batches['train_sup']
    print('-' * print_num)
    print('| Train Sup Loss 1: {:.4f}'.format(train_epoch_loss_sup1).ljust(print_num_half, ' '), '| Train Sup Loss 2: {:.4f}'.format(train_epoch_loss_sup2).ljust(print_num_half, ' '), '|')
    print('| Train Unsup Loss: {:.4f}'.format(train_epoch_loss_cps).ljust(print_num_half, ' '), '| Train Total Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_half, ' '), '|')
    print('-' * print_num)
    return train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss


def print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus):
    val_epoch_loss = val_loss / num_batches['val']
    print('-' * print_num)
    print('| Val Loss: {:.4f}'.format(val_epoch_loss).ljust(print_num_minus, ' '), '|')
    print('-' * print_num)
    return val_epoch_loss


def print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half):
    val_epoch_loss_sup1 = val_loss_sup_1 / num_batches['val']
    val_epoch_loss_sup2 = val_loss_sup_2 / num_batches['val']
    print('-' * print_num)
    print('| Val Sup Loss 1: {:.4f}'.format(val_epoch_loss_sup1).ljust(print_num_half, ' '), '| Val Sup Loss 2: {:.4f}'.format(val_epoch_loss_sup2).ljust(print_num_half, ' '), '|')
    print('-' * print_num)
    return val_epoch_loss_sup1, val_epoch_loss_sup2


def print_val_loss_ConResNet(val_loss_seg, val_loss_res, num_batches, print_num, print_num_half):
    val_epoch_loss_seg = val_loss_seg / num_batches['val']
    val_epoch_loss_res = val_loss_res / num_batches['val']
    print('-' * print_num)
    print('| Val Seg Loss: {:.4f}'.format(val_epoch_loss_seg).ljust(print_num_half, ' '), '| Val Res Loss: {:.4f}'.format(val_epoch_loss_res).ljust(print_num_half, ' '), '|')
    print('-' * print_num)
    return val_epoch_loss_seg, val_epoch_loss_res


def print_train_eval_sup(num_classes, score_list_train, mask_list_train, print_num):
    if num_classes == 2:
        eval_list = evaluate(score_list_train, mask_list_train)
        print('| Train Thr: {:.4f}'.format(eval_list[0]).ljust(print_num, ' '), '|')
        print('| Train Jc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
        print('| Train Dc: {:.4f}'.format(eval_list[2]).ljust(print_num, ' '), '|')
        train_m_jc = eval_list[1]
    else:
        eval_list = evaluate_multi(score_list_train, mask_list_train)
        np.set_printoptions(precision=4, suppress=True)
        print('| Train Jc: {}'.format(eval_list[0]).ljust(print_num, ' '), '|')
        print('| Train Dc: {}'.format(eval_list[2]).ljust(print_num, ' '), '|')
        print('| Train mJc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
        print('| Train mDc: {:.4f}'.format(eval_list[3]).ljust(print_num, ' '), '|')
        train_m_jc = eval_list[1]

    return eval_list, train_m_jc


def print_train_eval_XNet(num_classes, score_list_train1, score_list_train2, mask_list_train, print_num):
    if num_classes == 2:
        eval_list1 = evaluate(score_list_train1, mask_list_train)
        eval_list2 = evaluate(score_list_train2, mask_list_train)
        print('| Train Thr 1: {:.4f}'.format(eval_list1[0]).ljust(print_num, ' '), '| Train Thr 2: {:.4f}'.format(eval_list2[0]).ljust(print_num, ' '), '|')
        print('| Train Jc 1: {:.4f}'.format(eval_list1[1]).ljust(print_num, ' '), '| Train Jc 2: {:.4f}'.format(eval_list2[1]).ljust(print_num, ' '), '|')
        print('| Train Dc 1: {:.4f}'.format(eval_list1[2]).ljust(print_num, ' '), '| Train Dc 2: {:.4f}'.format(eval_list2[2]).ljust(print_num, ' '), '|')
        train_m_jc1 = eval_list1[1]
        train_m_jc2 = eval_list2[1]
    else:
        eval_list1 = evaluate_multi(score_list_train1, mask_list_train)
        eval_list2 = evaluate_multi(score_list_train2, mask_list_train)
        np.set_printoptions(precision=4, suppress=True)
        print('| Train Jc 1: {}'.format(eval_list1[0]).ljust(print_num, ' '), '| Train Jc 2: {}'.format(eval_list2[0]).ljust(print_num, ' '), '|')
        print('| Train Dc 1: {}'.format(eval_list1[2]).ljust(print_num, ' '), '| Train Dc 2: {}'.format(eval_list2[2]).ljust(print_num, ' '), '|')
        print('| Train mJc 1: {:.4f}'.format(eval_list1[1]).ljust(print_num, ' '), '| Train mJc 2: {:.4f}'.format(eval_list2[1]).ljust(print_num, ' '), '|')
        print('| Train mDc 1: {:.4f}'.format(eval_list1[3]).ljust(print_num, ' '), '| Train mDc 2: {:.4f}'.format(eval_list2[3]).ljust(print_num, ' '), '|')
        train_m_jc1 = eval_list1[1]
        train_m_jc2 = eval_list2[1]

    return eval_list1, eval_list2, train_m_jc1, train_m_jc2


def print_val_eval_sup(num_classes, score_list_val, mask_list_val, print_num):
    if num_classes == 2:
        eval_list = evaluate(score_list_val, mask_list_val)
        print('| Val Thr: {:.4f}'.format(eval_list[0]).ljust(print_num, ' '), '|')
        print('| Val Jc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
        print('| Val Dc: {:.4f}'.format(eval_list[2]).ljust(print_num, ' '), '|')
        val_m_jc = eval_list[1]
    else:
        eval_list = evaluate_multi(score_list_val, mask_list_val)
        np.set_printoptions(precision=4, suppress=True)
        print('| Val Jc: {} '.format(eval_list[0]).ljust(print_num, ' '), '|')
        print('| Val Dc: {} '.format(eval_list[2]).ljust(print_num, ' '), '|')
        print('| Val mJc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
        print('| Val mDc: {:.4f}'.format(eval_list[3]).ljust(print_num, ' '), '|')
        val_m_jc = eval_list[1]
    return eval_list, val_m_jc


def print_val_eval(num_classes, score_list_val1, score_list_val2, mask_list_val, print_num):
    if num_classes == 2:
        eval_list1 = evaluate(score_list_val1, mask_list_val)
        eval_list2 = evaluate(score_list_val2, mask_list_val)
        print('| Val Thr 1: {:.4f}'.format(eval_list1[0]).ljust(print_num, ' '), '| Val Thr 2: {:.4f}'.format(eval_list2[0]).ljust(print_num, ' '), '|')
        print('| Val Jc 1: {:.4f}'.format(eval_list1[1]).ljust(print_num, ' '), '| Val Jc 2: {:.4f}'.format(eval_list2[1]).ljust(print_num, ' '), '|')
        print('| Val Dc 1: {:.4f}'.format(eval_list1[2]).ljust(print_num, ' '), '| Val Dc 2: {:.4f}'.format(eval_list2[2]).ljust(print_num, ' '), '|')
        val_m_jc1 = eval_list1[1]
        val_m_jc2 = eval_list2[1]
    else:
        eval_list1 = evaluate_multi(score_list_val1, mask_list_val)
        eval_list2 = evaluate_multi(score_list_val2, mask_list_val)
        np.set_printoptions(precision=4, suppress=True)
        print('| Val Jc 1: {} '.format(eval_list1[0]).ljust(print_num, ' '), '| Val Jc 2: {}'.format(eval_list2[0]).ljust(print_num, ' '), '|')
        print('| Val Dc 1: {} '.format(eval_list1[2]).ljust(print_num, ' '), '| Val Dc 2: {}'.format(eval_list2[2]).ljust(print_num, ' '), '|')
        print('| Val mJc 1: {:.4f}'.format(eval_list1[1]).ljust(print_num, ' '), '| Val mJc 2: {:.4f}'.format(eval_list2[1]).ljust(print_num, ' '), '|')
        print('| Val mDc 1: {:.4f}'.format(eval_list1[3]).ljust(print_num, ' '), '| Val mDc 2: {:.4f}'.format(eval_list2[3]).ljust(print_num, ' '), '|')
        val_m_jc1 = eval_list1[1]
        val_m_jc2 = eval_list2[1]
    return eval_list1, eval_list2, val_m_jc1, val_m_jc2


def save_val_best_sup_2d(num_classes, best_list, model, score_list_val, name_list_val, eval_list, path_trained_model, path_seg_results, palette, model_name):
    if num_classes == 2:
        if best_list[1] < eval_list[1]:
            best_list = eval_list
            torch.save(model.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format(model_name, best_list[1])))

            score_list_val = torch.softmax(score_list_val, dim=1)
            pred_results = score_list_val[:, 1, :, :].cpu().numpy()
            pred_results[pred_results > eval_list[0]] = 1
            pred_results[pred_results <= eval_list[0]] = 0

            assert len(name_list_val) == pred_results.shape[0]
            for i in range(len(name_list_val)):
                color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
                color_results.putpalette(palette)
                color_results.save(os.path.join(path_seg_results, name_list_val[i]))
    else:
        if best_list[1] < eval_list[1]:
            best_list = eval_list
            torch.save(model.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format(model_name, best_list[1])))

            pred_results = torch.max(score_list_val, 1)[1]
            pred_results = pred_results.cpu().numpy()

            assert len(name_list_val) == pred_results.shape[0]
            for i in range(len(name_list_val)):
                color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
                color_results.putpalette(palette)
                color_results.save(os.path.join(path_seg_results, name_list_val[i]))

    return best_list


def save_val_best_sup_3d(num_classes, best_list, model, score_list_val, mask_list_val, eval_list, path_trained_model, path_seg_results, path_mask_results, model_name, format):
    if num_classes == 2:
        if best_list[1] < eval_list[1]:
            best_list = eval_list
            torch.save(model.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format(model_name, best_list[1])))
    else:
        if best_list[1] < eval_list[1]:
            best_list = eval_list
            torch.save(model.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format(model_name, best_list[1])))
    return best_list


def save_val_best_2d(num_classes, best_model, best_list, best_result, model1, model2, score_list_val_1, score_list_val_2, name_list_val, eval_list_1, eval_list_2, path_trained_model, path_seg_results, palette):
    if eval_list_1[1] < eval_list_2[1]:
        if best_list[1] < eval_list_2[1]:
            best_model = model2
            best_list = eval_list_2
            best_result = 'Result2'
            torch.save(model2.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format('result2', best_list[1])))

            if num_classes == 2:
                score_list_val_2 = torch.softmax(score_list_val_2, dim=1)
                pred_results = score_list_val_2[:, 1, ...].cpu().numpy()
                pred_results[pred_results > eval_list_2[0]] = 1
                pred_results[pred_results <= eval_list_2[0]] = 0
            else:
                pred_results = torch.max(score_list_val_2, 1)[1]
                pred_results = pred_results.cpu().numpy()

            assert len(name_list_val) == pred_results.shape[0]
            for i in range(len(name_list_val)):
                color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
                color_results.putpalette(palette)
                color_results.save(os.path.join(path_seg_results, name_list_val[i]))
        else:
            best_model = best_model
            best_list = best_list
            best_result = best_result
    else:
        if best_list[1] < eval_list_1[1]:
            best_model = model1
            best_list = eval_list_1
            best_result = 'Result1'
            torch.save(model1.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format('result1', best_list[1])))

            if num_classes == 2:
                score_list_val_1 = torch.softmax(score_list_val_1, dim=1)
                pred_results = score_list_val_1[:, 1, ...].cpu().numpy()
                pred_results[pred_results > eval_list_1[0]] = 1
                pred_results[pred_results <= eval_list_1[0]] = 0
            else:
                pred_results = torch.max(score_list_val_1, 1)[1]
                pred_results = pred_results.cpu().numpy()

            assert len(name_list_val) == pred_results.shape[0]
            for i in range(len(name_list_val)):
                color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
                color_results.putpalette(palette)
                color_results.save(os.path.join(path_seg_results, name_list_val[i]))
        else:
            best_model = best_model
            best_list = best_list
            best_result = best_result

    return best_list, best_model, best_result


def save_val_best_3d(num_classes, best_model, best_list, best_result, model1, model2, score_list_val_1, score_list_val_2, mask_list_val, eval_list_1, eval_list_2, path_trained_model, path_seg_results, path_mask_results, format):
    if eval_list_1[1] < eval_list_2[1]:
        if best_list[1] < eval_list_2[1]:
            best_model = model2
            best_list = eval_list_2
            best_result = 'Result2'
            torch.save(model2.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format('result2', best_list[1])))
        else:
            best_model = best_model
            best_list = best_list
            best_result = best_result
    else:
        if best_list[1] < eval_list_1[1]:
            best_model = model1
            best_list = eval_list_1
            best_result = 'Result1'
            torch.save(model1.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format('result1', best_list[1])))
        else:
            best_model = best_model
            best_list = best_list
            best_result = best_result

    return best_list, best_model, best_result


def draw_pred_sup(num_classes, mask_train_sup, mask_val, pred_train_sup, outputs_val, train_eval_list, val_eval_list):
    mask_image_train_sup = mask_train_sup[0, :, :].data.cpu().numpy()
    mask_image_val = mask_val[0, :, :].data.cpu().numpy()

    if num_classes == 2:
        pred_image_train_sup = pred_train_sup[0, 1, :, :].data.cpu().numpy()
        pred_image_train_sup[pred_image_train_sup > train_eval_list[0]] = 1
        pred_image_train_sup[pred_image_train_sup <= train_eval_list[0]] = 0

        pred_image_val = outputs_val[0, 1, :, :].data.cpu().numpy()
        pred_image_val[pred_image_val > val_eval_list[0]] = 1
        pred_image_val[pred_image_val <= val_eval_list[0]] = 0
    else:
        pred_image_train_sup = torch.max(pred_train_sup, 1)[1]
        pred_image_train_sup = pred_image_train_sup[0, :, :].cpu().numpy()

        pred_image_val = torch.max(outputs_val, 1)[1]
        pred_image_val = pred_image_val[0, :, :].cpu().numpy()

    return mask_image_train_sup, pred_image_train_sup, mask_image_val, pred_image_val


def draw_pred_XNet(num_classes, mask_train, mask_val, pred_train_sup1, pred_train_sup2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2):
    mask_image_train_sup = mask_train[0, :, :].data.cpu().numpy()
    mask_image_val = mask_val[0, :, :].data.cpu().numpy()

    if num_classes == 2:
        pred_image_train_sup1 = pred_train_sup1[0, 1, :, :].data.cpu().numpy()
        pred_image_train_sup1[pred_image_train_sup1 > train_eval_list1[0]] = 1
        pred_image_train_sup1[pred_image_train_sup1 <= train_eval_list1[0]] = 0

        pred_image_train_sup2 = pred_train_sup2[0, 1, :, :].data.cpu().numpy()
        pred_image_train_sup2[pred_image_train_sup2 > train_eval_list2[0]] = 1
        pred_image_train_sup2[pred_image_train_sup2 <= train_eval_list2[0]] = 0

        pred_image_val1 = outputs_val1[0, 1, :, :].data.cpu().numpy()
        pred_image_val1[pred_image_val1 > val_eval_list1[0]] = 1
        pred_image_val1[pred_image_val1 <= val_eval_list1[0]] = 0

        pred_image_val2 = outputs_val2[0, 1, :, :].data.cpu().numpy()
        pred_image_val2[pred_image_val2 > val_eval_list2[0]] = 1
        pred_image_val2[pred_image_val2 <= val_eval_list2[0]] = 0
    else:
        pred_image_train_sup1 = torch.max(pred_train_sup1, 1)[1]
        pred_image_train_sup1 = pred_image_train_sup1[0, :, :].cpu().numpy()

        pred_image_train_sup2 = torch.max(pred_train_sup2, 1)[1]
        pred_image_train_sup2 = pred_image_train_sup2[0, :, :].cpu().numpy()

        pred_image_val1 = torch.max(outputs_val1, 1)[1]
        pred_image_val1 = pred_image_val1[0, :, :].cpu().numpy()

        pred_image_val2 = torch.max(outputs_val2, 1)[1]
        pred_image_val2 = pred_image_val2[0, :, :].cpu().numpy()

    return mask_image_train_sup, pred_image_train_sup1, pred_image_train_sup2, mask_image_val, pred_image_val1, pred_image_val2


def draw_pred_MT(num_classes, mask_train, mask_val, pred_train_sup1, outputs_val1, outputs_val2, train_eval_list1, val_eval_list1, val_eval_list2):
    mask_image_train_sup = mask_train[0, :, :].data.cpu().numpy()
    mask_image_val = mask_val[0, :, :].data.cpu().numpy()

    if num_classes == 2:
        pred_image_train_sup1 = pred_train_sup1[0, 1, :, :].data.cpu().numpy()
        pred_image_train_sup1[pred_image_train_sup1 > train_eval_list1[0]] = 1
        pred_image_train_sup1[pred_image_train_sup1 <= train_eval_list1[0]] = 0

        pred_image_val1 = outputs_val1[0, 1, :, :].data.cpu().numpy()
        pred_image_val1[pred_image_val1 > val_eval_list1[0]] = 1
        pred_image_val1[pred_image_val1 <= val_eval_list1[0]] = 0

        pred_image_val2 = outputs_val2[0, 1, :, :].data.cpu().numpy()
        pred_image_val2[pred_image_val2 > val_eval_list2[0]] = 1
        pred_image_val2[pred_image_val2 <= val_eval_list2[0]] = 0
    else:
        pred_image_train_sup1 = torch.max(pred_train_sup1, 1)[1]
        pred_image_train_sup1 = pred_image_train_sup1[0, :, :].cpu().numpy()

        pred_image_val1 = torch.max(outputs_val1, 1)[1]
        pred_image_val1 = pred_image_val1[0, :, :].cpu().numpy()

        pred_image_val2 = torch.max(outputs_val2, 1)[1]
        pred_image_val2 = pred_image_val2[0, :, :].cpu().numpy()

    return mask_image_train_sup, pred_image_train_sup1, mask_image_val, pred_image_val1, pred_image_val2


def print_best_sup(num_classes, best_val_list, print_num):
    if num_classes == 2:
        print('| Best Val Thr: {:.4f}'.format(best_val_list[0]).ljust(print_num, ' '), '|')
        print('| Best Val Jc: {:.4f}'.format(best_val_list[1]).ljust(print_num, ' '), '|')
        print('| Best Val Dc: {:.4f}'.format(best_val_list[2]).ljust(print_num, ' '), '|')
    else:
        np.set_printoptions(precision=4, suppress=True)
        print('| Best Val Jc: {}'.format(best_val_list[0]).ljust(print_num, ' '), '|')
        print('| Best Val Dc: {}'.format(best_val_list[2]).ljust(print_num, ' '), '|')
        print('| Best Val mJc: {:.4f}'.format(best_val_list[1]).ljust(print_num, ' '), '|')
        print('| Best Val mDc: {:.4f}'.format(best_val_list[3]).ljust(print_num, ' '), '|')


def print_best(num_classes, best_val_list, best_model, best_result, path_trained_model, print_num):
    if num_classes == 2:
        torch.save(best_model.state_dict(), os.path.join(path_trained_model, 'best_Jc_{:.4f}.pth'.format(best_val_list[1])))
        print('| Best Result: {}'.format(best_result).ljust(print_num, ' '), '|')
        print('| Best Val Thr: {:.4f}'.format(best_val_list[0]).ljust(print_num, ' '), '|')
        print('| Best Val Jc: {:.4f}'.format(best_val_list[1]).ljust(print_num, ' '), '|')
        print('| Best Val Dc: {:.4f}'.format(best_val_list[2]).ljust(print_num, ' '), '|')
    else:
        torch.save(best_model.state_dict(), os.path.join(path_trained_model, 'best_Jc_{:.4f}.pth'.format(best_val_list[1])))
        np.set_printoptions(precision=4, suppress=True)
        print('| Best Result: {}'.format(best_result).ljust(print_num, ' '), '|')
        print('| Best Val Jc: {}'.format(best_val_list[0]).ljust(print_num, ' '), '|')
        print('| Best Val Dc: {}'.format(best_val_list[2]).ljust(print_num, ' '), '|')
        print('| Best Val mJc: {:.4f}'.format(best_val_list[1]).ljust(print_num, ' '), '|')
        print('| Best Val mDc: {:.4f}'.format(best_val_list[3]).ljust(print_num, ' '), '|')


def print_test_eval(num_classes, score_list_test, mask_list_test, print_num):
    if num_classes == 2:
        eval_list = evaluate(score_list_test, mask_list_test)
        print('| Test Thr: {:.4f}'.format(eval_list[0]).ljust(print_num, ' '), '|')
        print('| Test Jc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
        print('| Test Dc: {:.4f}'.format(eval_list[2]).ljust(print_num, ' '), '|')
    else:
        eval_list = evaluate_multi(score_list_test, mask_list_test)
        np.set_printoptions(precision=4, suppress=True)
        print('| Test Jc: {} '.format(eval_list[0]).ljust(print_num, ' '), '|')
        print('| Test Dc: {} '.format(eval_list[2]).ljust(print_num, ' '), '|')
        print('| Test mJc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
        print('| Test mDc: {:.4f}'.format(eval_list[3]).ljust(print_num, ' '), '|')
    return eval_list


def save_test_2d(num_classes, score_list_test, name_list_test, threshold, path_seg_results, palette):
    if num_classes == 2:
        score_list_test = torch.softmax(score_list_test, dim=1)
        pred_results = score_list_test[:, 1, ...].cpu().numpy()
        pred_results[pred_results > threshold] = 1
        pred_results[pred_results <= threshold] = 0

        assert len(name_list_test) == pred_results.shape[0]

        for i in range(len(name_list_test)):
            color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
            color_results.putpalette(palette)
            color_results.save(os.path.join(path_seg_results, name_list_test[i]))
    else:
        pred_results = torch.max(score_list_test, 1)[1]
        pred_results = pred_results.cpu().numpy()

        assert len(name_list_test) == pred_results.shape[0]

        for i in range(len(name_list_test)):
            color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
            color_results.putpalette(palette)
            color_results.save(os.path.join(path_seg_results, name_list_test[i]))


def save_test_3d(num_classes, score_test, name_test, threshold, path_seg_results, affine):

    if num_classes == 2:
        score_list_test = torch.softmax(score_test, dim=0)
        pred_results = score_list_test[1, ...].cpu()
        pred_results[pred_results > threshold] = 1
        pred_results[pred_results <= threshold] = 0
        pred_results = pred_results.type(torch.uint8)
        output_image = tio.ScalarImage(tensor=pred_results.unsqueeze(0), affine=affine)
        output_image.save(os.path.join(path_seg_results, name_test))
    else:
        pred_results = torch.max(score_test, 0)[1]
        pred_results = pred_results.cpu()
        pred_results = pred_results.type(torch.uint8)
        output_image = tio.ScalarImage(tensor=pred_results.unsqueeze(0), affine=affine)
        output_image.save(os.path.join(path_seg_results, name_test))

================================================
FILE: config/visdom_config/__init__.py
================================================

================================================
FILE: config/visdom_config/visual_visdom.py
================================================
from visdom import Visdom
import os


def visdom_initialization_sup(env, port):
    visdom = Visdom(env=env, port=port)
    visdom.line([0.], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss'], width=550, height=350))
    visdom.line([0.], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc'], width=550, height=350))
    visdom.line([0.], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Loss'], width=550, height=350))
    visdom.line([0.], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc'], width=550, height=350))
    return visdom


def visualization_sup(vis, epoch, train_loss, train_m_jc, val_loss, val_m_jc):
    vis.line([train_loss], [epoch], win='train_loss', update='append')
    vis.line([train_m_jc], [epoch], win='train_jc', update='append')
    vis.line([val_loss], [epoch], win='val_loss', update='append')
    vis.line([val_m_jc], [epoch], win='val_jc', update='append')


def visual_image_sup(vis, mask_train, pred_train, mask_val, pred_val):
    vis.heatmap(mask_train, win='train_mask', opts=dict(title='Train Mask', colormap='Viridis'))
    vis.heatmap(pred_train, win='train_pred1', opts=dict(title='Train Pred', colormap='Viridis'))
    vis.heatmap(mask_val, win='val_mask', opts=dict(title='Val Mask', colormap='Viridis'))
    vis.heatmap(pred_val, win='val_pred1', opts=dict(title='Val Pred', colormap='Viridis'))


def visdom_initialization_XNet(env, port):
    visdom = Visdom(env=env, port=port)
    visdom.line([[0., 0., 0., 0.]], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss', 'Train Sup1', 'Train Sup2', 'Train Unsup'], width=550, height=350))
    visdom.line([[0., 0.]], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc1', 'Train Jc2'], width=550, height=350))
    visdom.line([[0., 0.]], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Sup1', 'Val Sup2'], width=550, height=350))
    visdom.line([[0., 0.]], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc1', 'Val Jc2'], width=550, height=350))
    return visdom


def visualization_XNet(vis, epoch, train_loss, train_loss_sup1, train_loss_sup2, train_loss_cps, train_m_jc1, train_m_jc2, val_loss_sup1, val_loss_sup2, val_m_jc1, val_m_jc2):
    vis.line([[train_loss, train_loss_sup1, train_loss_sup2, train_loss_cps]], [epoch], win='train_loss', update='append')
    vis.line([[train_m_jc1, train_m_jc2]], [epoch], win='train_jc', update='append')
    vis.line([[val_loss_sup1, val_loss_sup2]], [epoch], win='val_loss', update='append')
    vis.line([[val_m_jc1, val_m_jc2]], [epoch], win='val_jc', update='append')


def visual_image_XNet(vis, mask_train, pred_train1, pred_train2, mask_val, pred_val1, pred_val2):
    vis.heatmap(mask_train, win='train_mask', opts=dict(title='Train Mask', colormap='Viridis'))
    vis.heatmap(pred_train1, win='train_pred1', opts=dict(title='Train Pred1', colormap='Viridis'))
    vis.heatmap(pred_train2, win='train_pred2', opts=dict(title='Train Pred2', colormap='Viridis'))
    vis.heatmap(mask_val, win='val_mask', opts=dict(title='Val Mask', colormap='Viridis'))
    vis.heatmap(pred_val1, win='val_pred1', opts=dict(title='Val Pred1', colormap='Viridis'))
    vis.heatmap(pred_val2, win='val_pred2', opts=dict(title='Val Pred2', colormap='Viridis'))


def visdom_initialization_MT(env, port):
    visdom = Visdom(env=env, port=port)
    visdom.line([[0., 0., 0.]], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss', 'Train Sup', 'Train Unsup'], width=550, height=350))
    visdom.line([0.], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc'], width=550, height=350))
    visdom.line([[0., 0.]], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Sup1', 'Val Sup2'], width=550, height=350))
    visdom.line([[0., 0.]], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc1', 'Val Jc2'], width=550, height=350))
    return visdom


def visualization_MT(vis, epoch, train_loss, train_loss_sup1, train_loss_cps, train_m_jc1, val_loss_sup1, val_loss_sup2, val_m_jc1, val_m_jc2):
    vis.line([[train_loss, train_loss_sup1, train_loss_cps]], [epoch], win='train_loss', update='append')
    vis.line([train_m_jc1], [epoch], win='train_jc', update='append')
    vis.line([[val_loss_sup1, val_loss_sup2]], [epoch], win='val_loss', update='append')
    vis.line([[val_m_jc1, val_m_jc2]], [epoch], win='val_jc', update='append')


def visual_image_MT(vis, mask_train, pred_train1, mask_val, pred_val1, pred_val2):
    vis.heatmap(mask_train, win='train_mask', opts=dict(title='Train Mask', colormap='Viridis'))
    vis.heatmap(pred_train1, win='train_pred1', opts=dict(title='Train Pred', colormap='Viridis'))
    vis.heatmap(mask_val, win='val_mask', opts=dict(title='Val Mask', colormap='Viridis'))
    vis.heatmap(pred_val1, win='val_pred1', opts=dict(title='Val Pred1', colormap='Viridis'))
    vis.heatmap(pred_val2, win='val_pred2', opts=dict(title='Val Pred2', colormap='Viridis'))


def visdom_initialization_EM(env, port):
    visdom = Visdom(env=env, port=port)
    visdom.line([[0., 0., 0.]], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss', 'Train Sup', 'Train Unsup'], width=550, height=350))
    visdom.line([0.], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc'], width=550, height=350))
    visdom.line([0.], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Sup'], width=550, height=350))
    visdom.line([0.], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc'], width=550, height=350))
    return visdom


def visualization_EM(vis, epoch, train_loss, train_loss_sup1, train_loss_cps, train_m_jc1, val_loss_sup1, val_m_jc1):
    vis.line([[train_loss, train_loss_sup1, train_loss_cps]], [epoch], win='train_loss', update='append')
    vis.line([train_m_jc1], [epoch], win='train_jc', update='append')
    vis.line([val_loss_sup1], [epoch], win='val_loss', update='append')
    vis.line([val_m_jc1], [epoch], win='val_jc', update='append')


def visdom_initialization_ConResNet(env, port):
    visdom = Visdom(env=env, port=port)
    visdom.line([[0., 0., 0.]], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss', 'Train Seg', 'Train Res'], width=550, height=350))
    visdom.line([0.], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc'], width=550, height=350))
    visdom.line([[0., 0.]], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Seg', 'Val Res'], width=550, height=350))
    visdom.line([0.], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc'], width=550, height=350))
    return visdom


def visualization_ConResNet(vis, epoch, train_loss, train_loss_seg, train_loss_res, train_m_jc1, val_loss_seg, val_loss_res, val_m_jc1):
    vis.line([[train_loss, train_loss_seg, train_loss_res]], [epoch], win='train_loss', update='append')
    vis.line([train_m_jc1], [epoch], win='train_jc', update='append')
    vis.line([[val_loss_seg, val_loss_res]], [epoch], win='val_loss', update='append')
    vis.line([val_m_jc1], [epoch], win='val_jc', update='append')

================================================
FILE: config/warmup_config/__init__.py
================================================

================================================
FILE: config/warmup_config/warmup.py
================================================
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm up (increase) the learning rate in the optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base_lr * multiplier if multiplier > 1.0.
            If multiplier = 1.0, the lr starts from 0 and ends at base_lr.
        total_epoch: the target learning rate is reached gradually at total_epoch.
        after_scheduler: scheduler to use after total_epoch (e.g. ReduceLROnPlateau).
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        # ReduceLROnPlateau is called at the end of an epoch, whereas other schedulers are called at the beginning
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)

================================================
FILE: dataload/__init__.py
================================================

================================================
FILE: dataload/dataset_2d.py
================================================
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import cv2
import numpy as np
import pywt


class dataset_itn(Dataset):
    def __init__(self, data_dir, input1, augmentation_1, normalize_1, sup=True, num_images=None, **kwargs):
        super(dataset_itn, self).__init__()

        img_paths_1 = []
        mask_paths = []

        image_dir_1 = data_dir + '/' + input1
        if sup:
            mask_dir = data_dir + '/mask'

        for image in os.listdir(image_dir_1):
            image_path_1 = os.path.join(image_dir_1, image)
            img_paths_1.append(image_path_1)
            if sup:
                mask_path = os.path.join(mask_dir, image)
                mask_paths.append(mask_path)

        if sup:
            assert len(img_paths_1) == len(mask_paths)

        if num_images is not None:
            len_img_paths = len(img_paths_1)
            quotient = num_images // len_img_paths
            remainder = num_images % len_img_paths

            if num_images <= len_img_paths:
                img_paths_1 = img_paths_1[:num_images]
            else:
                rand_indices = torch.randperm(len_img_paths).tolist()
                new_indices = rand_indices[:remainder]

                img_paths_1 = img_paths_1 * quotient
                img_paths_1 += [img_paths_1[i] for i in new_indices]

                if sup:
                    mask_paths = mask_paths * quotient
                    mask_paths += [mask_paths[i] for i in new_indices]

        self.img_paths_1 = img_paths_1
        self.mask_paths = mask_paths
        self.augmentation_1 = augmentation_1
        self.normalize_1 = normalize_1
        self.sup = sup
        self.kwargs = kwargs

    def __getitem__(self, index):

        img_path_1 = self.img_paths_1[index]
        img_1 = Image.open(img_path_1)
        img_1 = np.array(img_1)

        if self.sup:
            mask_path = self.mask_paths[index]
            mask = Image.open(mask_path)
            mask = np.array(mask)

            augment_1 = self.augmentation_1(image=img_1, mask=mask)
            img_1 = augment_1['image']
            mask_1 = augment_1['mask']
            normalize_1 = self.normalize_1(image=img_1, mask=mask_1)
            img_1 = normalize_1['image']
            mask_1 = normalize_1['mask']
            mask_1 = mask_1.long()

            sample = {'image': img_1, 'mask': mask_1, 'ID': os.path.split(mask_path)[1]}
        else:
            augment_1 = self.augmentation_1(image=img_1)
            img_1 = augment_1['image']
            normalize_1 = self.normalize_1(image=img_1)
            img_1 = normalize_1['image']

            sample = {'image': img_1, 'ID': os.path.split(img_path_1)[1]}

        return sample

    def __len__(self):
        return len(self.img_paths_1)


def imagefloder_itn(data_dir, input1, data_transform_1, data_normalize_1, sup=True, num_images=None, **kwargs):
    dataset = dataset_itn(data_dir=data_dir,
                          input1=input1,
                          augmentation_1=data_transform_1,
                          normalize_1=data_normalize_1,
                          sup=sup,
                          num_images=num_images,
                          **kwargs
                          )
    return dataset


class dataset_iitnn(Dataset):
    def __init__(self, data_dir, input1, input2, augmentation1, normalize_1, normalize_2, sup=True, num_images=None, **kwargs):
        super(dataset_iitnn, self).__init__()

        img_paths_1 = []
        img_paths_2 = []
        mask_paths = []

        image_dir_1 = data_dir + '/' + input1
        image_dir_2 = data_dir + '/' + input2
        if sup:
            mask_dir = data_dir + '/mask'

        for image in os.listdir(image_dir_1):
            image_path_1 = os.path.join(image_dir_1, image)
            img_paths_1.append(image_path_1)

            image_path_2 = os.path.join(image_dir_2, image)
            img_paths_2.append(image_path_2)

            if sup:
                mask_path = os.path.join(mask_dir, image)
                mask_paths.append(mask_path)

        assert len(img_paths_1) == len(img_paths_2)
        if sup:
            assert len(img_paths_1) == len(mask_paths)

        if num_images is not None:
            len_img_paths = len(img_paths_1)
            quotient = num_images // len_img_paths
            remainder = num_images % len_img_paths

            if num_images <= len_img_paths:
                img_paths_1 = img_paths_1[:num_images]
                img_paths_2 = img_paths_2[:num_images]
            else:
                rand_indices = torch.randperm(len_img_paths).tolist()
                new_indices = rand_indices[:remainder]

                img_paths_1 = img_paths_1 * quotient
                img_paths_1 += [img_paths_1[i] for i in new_indices]

                img_paths_2 = img_paths_2 * quotient
                img_paths_2 += [img_paths_2[i] for i in new_indices]

                if sup:
                    mask_paths = mask_paths * quotient
                    mask_paths += [mask_paths[i] for i in new_indices]

        self.img_paths_1 = img_paths_1
        self.img_paths_2 = img_paths_2
        self.mask_paths = mask_paths
        self.augmentation_1 = augmentation1
        self.normalize_1 = normalize_1
        self.normalize_2 = normalize_2
        self.sup = sup
        self.kwargs = kwargs

    def __getitem__(self, index):

        img_path_1 = self.img_paths_1[index]
        img_1 = Image.open(img_path_1)
        img_1 = np.array(img_1)

        img_path_2 = self.img_paths_2[index]
        img_2 = Image.open(img_path_2)
        img_2 = np.array(img_2)

        if self.sup:
            mask_path = self.mask_paths[index]
            mask = Image.open(mask_path)
            mask = np.array(mask)

            augment_1 = self.augmentation_1(image=img_1, image2=img_2, mask=mask)
            img_1 = augment_1['image']
            img_2 = augment_1['image2']
            mask = augment_1['mask']

            normalize_1 = self.normalize_1(image=img_1, mask=mask)
            img_1 = normalize_1['image']
            mask = normalize_1['mask']
            mask = mask.long()

            normalize_2 = self.normalize_2(image=img_2)
            img_2 = normalize_2['image']

            sample = {'image': img_1, 'image_2': img_2, 'mask': mask, 'ID': os.path.split(mask_path)[1]}
        else:
            augment_1 = self.augmentation_1(image=img_1, image2=img_2)
            img_1 = augment_1['image']
            img_2 = augment_1['image2']

            normalize_1 = self.normalize_1(image=img_1)
            img_1 = normalize_1['image']

            normalize_2 = self.normalize_2(image=img_2)
            img_2 = normalize_2['image']

            sample = {'image': img_1, 'image_2': img_2, 'ID': os.path.split(img_path_1)[1]}

        return sample

    def __len__(self):
        return len(self.img_paths_1)


def imagefloder_iitnn(data_dir, input1, input2, data_transform_1, data_normalize_1, data_normalize_2, sup=True, num_images=None, **kwargs):
    dataset = dataset_iitnn(data_dir=data_dir,
                            input1=input1,
                            input2=input2,
                            augmentation1=data_transform_1,
                            normalize_1=data_normalize_1,
                            normalize_2=data_normalize_2,
                            sup=sup,
                            num_images=num_images,
                            **kwargs
                            )
    return dataset


class dataset_wds(Dataset):
    def __init__(self, data_dir, augmentation1, normalize_LL, normalize_LH, normalize_HL, normalize_HH, **kwargs):
        super(dataset_wds, self).__init__()

        img_paths_LL = []
        img_paths_LH = []
        img_paths_HL = []
        img_paths_HH = []
        mask_paths = []

        image_dir_LL = data_dir + '/LL'
        image_dir_LH = data_dir + '/LH'
        image_dir_HL = data_dir + '/HL'
        image_dir_HH = data_dir + '/HH'
        mask_dir = data_dir + '/mask'

        for image in os.listdir(image_dir_LL):
            image_path_LL = os.path.join(image_dir_LL, image)
            img_paths_LL.append(image_path_LL)
            image_path_LH = os.path.join(image_dir_LH, image)
            img_paths_LH.append(image_path_LH)
            image_path_HL = os.path.join(image_dir_HL, image)
            img_paths_HL.append(image_path_HL)
            image_path_HH = os.path.join(image_dir_HH, image)
            img_paths_HH.append(image_path_HH)
            mask_path = os.path.join(mask_dir, image)
            mask_paths.append(mask_path)

        self.img_paths_LL = img_paths_LL
        self.img_paths_LH = img_paths_LH
        self.img_paths_HL = img_paths_HL
        self.img_paths_HH = img_paths_HH
        self.mask_paths = mask_paths
        self.augmentation_1 = augmentation1
        self.normalize_LL = normalize_LL
        self.normalize_LH = normalize_LH
        self.normalize_HL = normalize_HL
        self.normalize_HH = normalize_HH
        self.kwargs = kwargs

    def __getitem__(self, index):

        img_path_LL = self.img_paths_LL[index]
        img_LL = Image.open(img_path_LL)
        img_LL = np.array(img_LL)

        img_path_LH = self.img_paths_LH[index]
        img_LH = Image.open(img_path_LH)
        img_LH = np.array(img_LH)

        img_path_HL = self.img_paths_HL[index]
        img_HL = Image.open(img_path_HL)
        img_HL = np.array(img_HL)

        img_path_HH = self.img_paths_HH[index]
        img_HH = Image.open(img_path_HH)
        img_HH = np.array(img_HH)

        mask_path = self.mask_paths[index]
        mask = Image.open(mask_path)
        mask = np.array(mask)

        augment_1 = self.augmentation_1(image=img_LL, mask=mask, imageLH=img_LH, imageHL=img_HL, imageHH=img_HH)
        img_LL = augment_1['image']
        img_LH = augment_1['imageLH']
        img_HL = augment_1['imageHL']
        img_HH = augment_1['imageHH']
        mask_1 = augment_1['mask']

        normalize_LL = self.normalize_LL(image=img_LL, mask=mask_1)
        img_LL = normalize_LL['image']
        mask_1 = normalize_LL['mask']
        mask_1 = mask_1.long()

        normalize_LH = self.normalize_LH(image=img_LH)
        img_LH = normalize_LH['image']
        normalize_HL = self.normalize_HL(image=img_HL)
        img_HL = normalize_HL['image']
        normalize_HH = self.normalize_HH(image=img_HH)
        img_HH = normalize_HH['image']

        sample = {'image_LL': img_LL, 'image_LH': img_LH, 'image_HL': img_HL, 'image_HH': img_HH, 'mask': mask_1, 'ID': os.path.split(mask_path)[1]}
        return sample

    def __len__(self):
        return len(self.img_paths_LL)


def imagefloder_wds(data_dir, data_transform_1, data_normalize_LL, data_normalize_LH, data_normalize_HL, data_normalize_HH, **kwargs):
    dataset = dataset_wds(data_dir=data_dir,
                          augmentation1=data_transform_1,
                          normalize_LL=data_normalize_LL,
                          normalize_LH=data_normalize_LH,
                          normalize_HL=data_normalize_HL,
                          normalize_HH=data_normalize_HH,
                          **kwargs
                          )
    return dataset


class dataset_aerial_lanenet(Dataset):
    def __init__(self, data_dir, augmentation1, normalize_1, normalize_l1, normalize_l2, normalize_l3, normalize_l4, **kwargs):
        super(dataset_aerial_lanenet, self).__init__()

        img_paths = []
        mask_paths = []

        image_dir = data_dir + '/image'
        mask_dir = data_dir + '/mask'

        for image in os.listdir(image_dir):
            image_path = os.path.join(image_dir, image)
            img_paths.append(image_path)
            mask_path = os.path.join(mask_dir, image)
            mask_paths.append(mask_path)

        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.augmentation_1 = augmentation1
        self.normalize_1 = normalize_1
        self.normalize_l4 = normalize_l4
        self.normalize_l3 = normalize_l3
        self.normalize_l2 = normalize_l2
        self.normalize_l1 = normalize_l1
        self.kwargs = kwargs

    def __getitem__(self, index):

        img_path = self.img_paths[index]
        img = Image.open(img_path)
        img = np.array(img)

        mask_path = self.mask_paths[index]
        mask = Image.open(mask_path)
        mask = np.array(mask)

        augment_1 = self.augmentation_1(image=img, mask=mask)
        img = augment_1['image']
        mask = augment_1['mask']

        img_ = np.array(Image.fromarray(img).convert('L'))
        _, l4, l3, l2, l1 = pywt.wavedec2(img_, 'db2', level=4)
        l4 = np.array(l4).transpose(1, 2, 0)
        l3 = np.array(l3).transpose(1, 2, 0)
        l2 = np.array(l2).transpose(1, 2, 0)
        l1 = np.array(l1).transpose(1, 2, 0)

        normalize_l4 = self.normalize_l4(image=l4)
        l4 = normalize_l4['image'].float()
        normalize_l3 = self.normalize_l3(image=l3)
        l3 = normalize_l3['image'].float()
        normalize_l2 = self.normalize_l2(image=l2)
        l2 = normalize_l2['image'].float()
        normalize_l1 = self.normalize_l1(image=l1)
        l1 = normalize_l1['image'].float()

        normalize_1 = self.normalize_1(image=img, mask=mask)
        img = normalize_1['image']
        mask = normalize_1['mask'].long()

        sample = {'image': img, 'image_l1': l1, 'image_l2': l2, 'image_l3': l3, 'image_l4': l4, 'mask': mask, 'ID': os.path.split(mask_path)[1]}
        return sample

    def __len__(self):
        return len(self.img_paths)


def imagefloder_aerial_lanenet(data_dir, data_transform, data_normalize, data_normalize_l1, data_normalize_l2, data_normalize_l3, data_normalize_l4, **kwargs):
    dataset = dataset_aerial_lanenet(data_dir=data_dir,
                                     augmentation1=data_transform,
                                     normalize_1=data_normalize,
                                     normalize_l1=data_normalize_l1,
                                     normalize_l2=data_normalize_l2,
                                     normalize_l3=data_normalize_l3,
                                     normalize_l4=data_normalize_l4,
                                     **kwargs
                                     )
    return dataset

================================================
FILE: dataload/dataset_3d.py
================================================
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import cv2
import numpy as np
import torchio as tio
import SimpleITK as sitk
from torchio.data import UniformSampler, LabelSampler


class dataset_it(Dataset):
    def __init__(self, data_dir, input1, transform_1, queue_length=20, samples_per_volume=5, patch_size=128, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=None):
        super(dataset_it, self).__init__()

        self.subjects_1 = []

        image_dir_1 = data_dir + '/' + input1
        if sup:
            mask_dir = data_dir + '/mask'

        for i in os.listdir(image_dir_1):
            image_path_1 = os.path.join(image_dir_1, i)
            if sup:
                mask_path = os.path.join(mask_dir, i)
                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), mask=tio.LabelMap(mask_path), ID=i)
            else:
                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), ID=i)
            self.subjects_1.append(subject_1)

        if num_images is not None:
            len_img_paths = len(self.subjects_1)
            quotient = num_images // len_img_paths
            remainder = num_images % len_img_paths

            if num_images <= len_img_paths:
                self.subjects_1 = self.subjects_1[:num_images]
            else:
                rand_indices = torch.randperm(len_img_paths).tolist()
                new_indices = rand_indices[:remainder]

                self.subjects_1 = self.subjects_1 * quotient
                self.subjects_1 += [self.subjects_1[i] for i in new_indices]

        self.dataset_1 = tio.SubjectsDataset(self.subjects_1, transform=transform_1)

        self.queue_train_set_1 = tio.Queue(
            subjects_dataset=self.dataset_1,
            max_length=queue_length,
            samples_per_volume=samples_per_volume,
            sampler=UniformSampler(patch_size),
            # sampler=LabelSampler(patch_size),
            num_workers=num_workers,
            shuffle_subjects=shuffle_subjects,
            shuffle_patches=shuffle_patches
        )


class dataset_it_dtc(Dataset):
    def __init__(self, data_dir, input1, num_classes, transform_1, queue_length=20, samples_per_volume=5, patch_size=128, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=None):
        super(dataset_it_dtc, self).__init__()

        self.subjects_1 = []

        image_dir_1 = data_dir + '/' + input1
        if sup:
            mask_dir_1 = data_dir + '/mask'
            mask_dir_2 = data_dir + '/mask_sdf1'
            if num_classes == 3:
                mask_dir_3 = data_dir + '/mask_sdf2'

        for i in os.listdir(image_dir_1):
            image_path_1 = os.path.join(image_dir_1, i)
            if sup:
                mask_path_1 = os.path.join(mask_dir_1, i)
                mask_path_2 = os.path.join(mask_dir_2, i)
                if num_classes == 3:
                    mask_path_3 = os.path.join(mask_dir_3, i)
                    subject_1 = tio.Subject(
                        image=tio.ScalarImage(image_path_1),
                        mask=tio.LabelMap(mask_path_1),
                        mask2=tio.LabelMap(mask_path_2),
                        mask3=tio.LabelMap(mask_path_3),
                        ID=i)
                else:
                    subject_1 = tio.Subject(
                        image=tio.ScalarImage(image_path_1),
                        mask=tio.LabelMap(mask_path_1),
                        mask2=tio.LabelMap(mask_path_2),
                        ID=i)
            else:
                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), ID=i)
            self.subjects_1.append(subject_1)

        if num_images is not None:
            len_img_paths = len(self.subjects_1)
            quotient = num_images // len_img_paths
            remainder = num_images % len_img_paths

            if num_images <= len_img_paths:
                self.subjects_1 = self.subjects_1[:num_images]
            else:
                rand_indices = torch.randperm(len_img_paths).tolist()
                new_indices = rand_indices[:remainder]

                self.subjects_1 = self.subjects_1 * quotient
                self.subjects_1 += [self.subjects_1[i] for i in new_indices]

        self.dataset_1 = tio.SubjectsDataset(self.subjects_1, transform=transform_1)

        self.queue_train_set_1 = tio.Queue(
            subjects_dataset=self.dataset_1,
            max_length=queue_length,
            samples_per_volume=samples_per_volume,
            sampler=UniformSampler(patch_size),
            # sampler=LabelSampler(patch_size),
            num_workers=num_workers,
            shuffle_subjects=shuffle_subjects,
            shuffle_patches=shuffle_patches
        )


class dataset_iit(Dataset):
    def __init__(self, data_dir, input1, input2, transform_1, queue_length=20, samples_per_volume=5, patch_size=128, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=None):
        super(dataset_iit, self).__init__()

        self.subjects_1 = []

        image_dir_1 = data_dir + '/' + input1
        image_dir_2 = data_dir + '/' + input2
        if sup:
            mask_dir_1 = data_dir + '/mask'

        for i in os.listdir(image_dir_1):
            image_path_1 = os.path.join(image_dir_1, i)
image_path_2 = os.path.join(image_dir_2, i) if sup: mask_path_1 = os.path.join(mask_dir_1, i) subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), image2=tio.ScalarImage(image_path_2), mask=tio.LabelMap(mask_path_1), ID=i) else: subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), image2=tio.ScalarImage(image_path_2), ID=i) self.subjects_1.append(subject_1) if num_images is not None: len_img_paths = len(self.subjects_1) quotient = num_images // len_img_paths remainder = num_images % len_img_paths if num_images <= len_img_paths: self.subjects_1 = self.subjects_1[:num_images] else: rand_indices = torch.randperm(len_img_paths).tolist() new_indices = rand_indices[:remainder] self.subjects_1 = self.subjects_1 * quotient self.subjects_1 += [self.subjects_1[i] for i in new_indices] self.dataset_1 = tio.SubjectsDataset(self.subjects_1, transform=transform_1) self.queue_train_set_1 = tio.Queue( subjects_dataset=self.dataset_1, max_length=queue_length, samples_per_volume=samples_per_volume, sampler=UniformSampler(patch_size), # sampler=LabelSampler(patch_size), num_workers=num_workers, shuffle_subjects=shuffle_subjects, shuffle_patches=shuffle_patches ) class dataset_iit_conresnet(Dataset): def __init__(self, data_dir, input1, input2, transform_1, queue_length=20, samples_per_volume=5, patch_size=128, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=None): super(dataset_iit_conresnet, self).__init__() self.subjects_1 = [] image_dir_1 = data_dir + '/' + input1 image_dir_2 = data_dir + '/' + input2 if sup: mask_dir_1 = data_dir + '/mask' mask_dir_2 = data_dir + '/mask_res' for i in os.listdir(image_dir_1): image_path_1 = os.path.join(image_dir_1, i) image_path_2 = os.path.join(image_dir_2, i) if sup: mask_path_1 = os.path.join(mask_dir_1, i) mask_path_2 = os.path.join(mask_dir_2, i) subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), image2=tio.ScalarImage(image_path_2), mask=tio.LabelMap(mask_path_1), 
mask2=tio.LabelMap(mask_path_2), ID=i) else: subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), image2=tio.ScalarImage(image_path_2), ID=i) self.subjects_1.append(subject_1) if num_images is not None: len_img_paths = len(self.subjects_1) quotient = num_images // len_img_paths remainder = num_images % len_img_paths if num_images <= len_img_paths: self.subjects_1 = self.subjects_1[:num_images] else: rand_indices = torch.randperm(len_img_paths).tolist() new_indices = rand_indices[:remainder] self.subjects_1 = self.subjects_1 * quotient self.subjects_1 += [self.subjects_1[i] for i in new_indices] self.dataset_1 = tio.SubjectsDataset(self.subjects_1, transform=transform_1) self.queue_train_set_1 = tio.Queue( subjects_dataset=self.dataset_1, max_length=queue_length, samples_per_volume=samples_per_volume, sampler=UniformSampler(patch_size), # sampler=LabelSampler(patch_size), num_workers=num_workers, shuffle_subjects=shuffle_subjects, shuffle_patches=shuffle_patches ) ================================================ FILE: loss/__init__.py ================================================ ================================================ FILE: loss/loss_function.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from torch.autograd import Variable import sys from torch.nn.modules.loss import _Loss class MixSoftmaxCrossEntropyLoss(nn.CrossEntropyLoss): def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs): super(MixSoftmaxCrossEntropyLoss, self).__init__(ignore_index=ignore_index) self.aux = aux self.aux_weight = aux_weight def _aux_forward(self, output, target, **kwargs): # *preds, target = tuple(inputs) loss = super(MixSoftmaxCrossEntropyLoss, self).forward(output[0], target) for i in range(1, len(output)): aux_loss = super(MixSoftmaxCrossEntropyLoss, self).forward(output[i], target) loss += self.aux_weight * aux_loss return loss def forward(self, output, 
                target):
        # preds, target = tuple(inputs)
        # inputs = tuple(list(preds) + [target])
        if self.aux:
            return self._aux_forward(output, target)
        else:
            return super(MixSoftmaxCrossEntropyLoss, self).forward(output, target)


class BinaryDiceLoss(nn.Module):
    r"""Dice loss of binary class
    Args:
        smooth: A float number to smooth loss, and avoid NaN error, default: 1
        p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
        predict: A tensor of shape [N, *]
        target: A tensor of shape same with predict
        valid_mask: A tensor of shape same with predict; positions equal to 0 are ignored
        reduction: Reduction method to apply, return mean over batch if 'mean',
            return sum if 'sum', return a tensor of shape [N,] if 'none'
    Returns:
        Loss tensor according to arg reduction
    Raise:
        Exception if unexpected reduction
    """

    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target, valid_mask):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        predict = predict.contiguous().view(predict.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1).float()
        valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1).float()

        num = torch.sum(torch.mul(predict, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum((predict.pow(self.p) + target.pow(self.p)) * valid_mask, dim=1) + self.smooth

        loss = 1 - num / den

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss
        else:
            raise Exception('Unexpected reduction {}'.format(self.reduction))


class DiceLoss(nn.Module):
    """Dice loss, need one hot encode input"""

    def __init__(self, weight=None, aux=False, aux_weight=0.4, ignore_index=-1, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index
        self.aux = aux
        self.aux_weight = aux_weight

    def _base_forward(self, predict, target, valid_mask):
        dice = BinaryDiceLoss(**self.kwargs)
        total_loss = 0
        predict = F.softmax(predict, dim=1)

        # target is one-hot encoded with the class axis last: [N, *, C]
        for i in range(target.shape[-1]):
            if i != self.ignore_index:
                dice_loss = dice(predict[:, i], target[..., i], valid_mask)
                if self.weight is not None:
                    assert self.weight.shape[0] == target.shape[-1], \
                        'Expect weight shape [{}], get[{}]'.format(target.shape[-1], self.weight.shape[0])
                    dice_loss *= self.weight[i]
                total_loss += dice_loss

        return total_loss / target.shape[-1]

    def _aux_forward(self, output, target, **kwargs):
        # *preds, target = tuple(inputs)
        valid_mask = (target != self.ignore_index).long()
        # take num_classes from the logits so classes absent from the batch still get a channel
        target_one_hot = F.one_hot(torch.clamp_min(target, 0), num_classes=output[0].shape[1])
        loss = self._base_forward(output[0], target_one_hot, valid_mask)
        for i in range(1, len(output)):
            aux_loss = self._base_forward(output[i], target_one_hot, valid_mask)
            loss += self.aux_weight * aux_loss
        return loss

    def forward(self, output, target):
        # preds, target = tuple(inputs)
        # inputs = tuple(list(preds) + [target])
        if self.aux:
            return self._aux_forward(output, target)
        else:
            valid_mask = (target != self.ignore_index).long()
            target_one_hot = F.one_hot(torch.clamp_min(target, 0), num_classes=output.shape[1])
            return self._base_forward(output, target_one_hot, valid_mask)


def softmax_mse_loss(input_logits, target_logits, sigmoid=False):
    """Takes softmax (or sigmoid) on both sides and returns the element-wise squared error

    Note:
    - Returns an unreduced tensor with the same shape as the inputs; sum or average it
      afterwards (e.g. with .mean()) to obtain a scalar loss.
    - Gradients are not blocked here: pass target_logits that are already detached
      (e.g. teacher predictions computed under torch.no_grad()) if only the inputs
      should be updated.
    """
    assert input_logits.size() == target_logits.size()
    if sigmoid:
        input_softmax = torch.sigmoid(input_logits)
        target_softmax = torch.sigmoid(target_logits)
    else:
        input_softmax = F.softmax(input_logits, dim=1)
        target_softmax = F.softmax(target_logits, dim=1)

    mse_loss = (input_softmax - target_softmax) ** 2
    return mse_loss


def entropy_loss(p, C=2):
    # p: N*C*W*H*D class probabilities; the entropy is normalized by log(C)
    # dividing by a Python/NumPy scalar keeps the result on p's device (no hard-coded .cuda())
    y1 = -1 * torch.sum(p * torch.log(p + 1e-6), dim=1) / np.log(C)
    ent = torch.mean(y1)

    return ent


class BCELossBoud(nn.Module):
    def __init__(self, num_classes, weight=None, ignore_index=None, **kwargs):
        super(BCELossBoud, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index
        self.num_classes = num_classes
        self.criterion = nn.BCEWithLogitsLoss()

    def weighted_BCE_cross_entropy(self, output, target, weights=None):
        if weights is not None:
            assert len(weights) == 2
            output = torch.clamp(output, min=1e-3, max=1 - 1e-3)
            bce = weights[1] * (target * torch.log(output)) + weights[0] * ((1 - target) * torch.log((1 - output)))
        else:
            output = torch.clamp(output, min=1e-3, max=1 - 1e-3)
            bce = target * torch.log(output) + (1 - target) * torch.log((1 - output))
        return torch.neg(torch.mean(bce))

    def forward(self, predict, target):
        # target: [N, D, W, H] class indices -> one-hot [N, C, D, W, H]
        target_one_hot = F.one_hot(torch.clamp_min(target, 0), num_classes=self.num_classes).permute(0, 4, 1, 2, 3)
        predict = torch.softmax(predict, 1)

        bs, category, depth, width, height = target_one_hot.shape
        bce_loss = []
        for i in range(predict.shape[1]):
            pred_i = predict[:, i]
            targ_i = target_one_hot[:, i]
            # positive-class weight derived from how rare class i is in the batch
            tt = np.log(depth * width * height / (target_one_hot[:, i].cpu().data.numpy().sum() + 1))
            bce_i = self.weighted_BCE_cross_entropy(pred_i, targ_i, weights=[1, tt])
            bce_loss.append(bce_i)

        bce_loss = torch.stack(bce_loss)
        total_loss = bce_loss.mean()
        return total_loss


class CustomKLLoss(_Loss):
    '''
    KL_Loss = mean(mean * mean) + mean(std * std) - mean(log(std * std)) - 1
    where each mean() averages over all N image voxels
    '''

    def __init__(self, *args, **kwargs):
        super(CustomKLLoss,
              self).__init__()

    def forward(self, mean, std):
        return torch.mean(torch.mul(mean, mean)) + torch.mean(torch.mul(std, std)) - torch.mean(
            torch.log(torch.mul(std, std))) - 1


def segmentation_loss(loss='CE', aux=False, **kwargs):
    if loss == 'dice' or loss == 'DICE':
        seg_loss = DiceLoss(aux=aux)
    elif loss == 'crossentropy' or loss == 'CE':
        seg_loss = MixSoftmaxCrossEntropyLoss(aux=aux)
    elif loss == 'bce':
        # size_average=True is deprecated; reduction='mean' is the equivalent setting
        seg_loss = nn.BCELoss(reduction='mean')
    elif loss == 'bcebound':
        seg_loss = BCELossBoud(num_classes=kwargs['num_classes'])
    else:
        print('sorry, the loss you input is not supported yet')
        sys.exit()
    return seg_loss

# if __name__ == '__main__':
#     from models import *
#     criterion = segmentation_loss(loss='LOVASZ')
#     # criterion = nn.CrossEntropyLoss()
#
#     model = unet(1, 2)
#     model.eval()
#     input = torch.rand(3, 1, 128, 128)
#     mask = torch.zeros(3, 128, 128).long()
#     # mask[:, 40:100, 30:60] = 1
#     output = model(input)
#
#     loss = criterion(output, mask)
#     print(loss)
#     # loss.requires_grad_(True)
#     # loss.backward()


================================================
FILE: models/__init__.py
================================================
# 2d
from .networks_2d.xnet import XNet, XNet_1_1_m, XNet_1_2_m, XNet_2_1_m, XNet_3_2_m, XNet_2_3_m, XNet_3_3_m, XNet_sb
from .networks_2d.unet import unet, r2_unet, attention_unet
from .networks_2d.unet_plusplus import unet_plusplus
from .networks_2d.hrnet import hrnet18, hrnet32, hrnet48, hrnet64
from .networks_2d.swinunet import swinunet
from .networks_2d.unet_urpc import unet_urpc
from .networks_2d.unet_cct import unet_cct
from .networks_2d.resunet import res_unet
from .networks_2d.resunet_plusplus import res_unet_plusplus
from .networks_2d.u2net import u2net, u2net_small
from .networks_2d.unet_3plus import unet_3plus, unet_3plus_ds, unet_3plus_ds_cgm
from .networks_2d.wavesnet import wsegnet_vgg16_bn
from .networks_2d.mwcnn import mwcnn
from .networks_2d.aerial_lanenet import Aerial_LaneNet
from .networks_2d.wds import WDS
# 3d
from .networks_3d.unet3d import unet3d, unet3d_min
from .networks_3d.vnet import vnet
from .networks_3d.res_unet3d import res_unet3d
from .networks_3d.transbts import transbts
from .networks_3d.cotr import cotr
from .networks_3d.dmfnet import dmfnet
from .networks_3d.conresnet import conresnet
from .networks_3d.espnet3d import espnet3d
from .networks_3d.unetr import unertr
from .networks_3d.unet3d_urpc import unet3d_urpc
from .networks_3d.unet3d_cct import unet3d_cct, unet3d_cct_min
from .networks_3d.unet3d_dtc import unet3d_dtc
from .networks_3d.xnet3d import xnet3d
from .networks_3d.vnet_cct import vnet_cct
from .networks_3d.vnet_dtc import vnet_dtc


================================================
FILE: models/getnetwork.py
================================================
import sys
from models import *
import torch.nn as nn


def get_network(network, in_channels, num_classes, **kwargs):

    # 2d networks
    if network == 'xnet':
        net = XNet(in_channels, num_classes)
    elif network == 'xnet_sb':
        net = XNet_sb(in_channels, num_classes)
    elif network == 'xnet_1_1_m':
        net = XNet_1_1_m(in_channels, num_classes)
    elif network == 'xnet_1_2_m':
        net = XNet_1_2_m(in_channels, num_classes)
    elif network == 'xnet_2_1_m':
        net = XNet_2_1_m(in_channels, num_classes)
    elif network == 'xnet_3_2_m':
        net = XNet_3_2_m(in_channels, num_classes)
    elif network == 'xnet_2_3_m':
        net = XNet_2_3_m(in_channels, num_classes)
    elif network == 'xnet_3_3_m':
        net = XNet_3_3_m(in_channels, num_classes)
    elif network == 'unet':
        net = unet(in_channels, num_classes)
    elif network == 'unet_plusplus' or network == 'unet++':
        net = unet_plusplus(in_channels, num_classes)
    elif network == 'r2unet':
        net = r2_unet(in_channels, num_classes)
    elif network == 'attunet':
        net = attention_unet(in_channels, num_classes)
    elif network == 'hrnet18':
        net = hrnet18(in_channels, num_classes)
    elif network == 'hrnet48':
        net = hrnet48(in_channels, num_classes)
    elif network == 'resunet':
        net = res_unet(in_channels, num_classes)
    elif network == 'resunet++':
        net = res_unet_plusplus(in_channels, num_classes)
    elif network == 'u2net':
        net = u2net(in_channels, num_classes)
    elif network == 'u2net_s':
        net = u2net_small(in_channels, num_classes)
    elif network == 'unet3+':
        net = unet_3plus(in_channels, num_classes)
    elif network == 'unet3+_ds':
        net = unet_3plus_ds(in_channels, num_classes)
    elif network == 'unet3+_ds_cgm':
        net = unet_3plus_ds_cgm(in_channels, num_classes)
    elif network == 'swinunet':
        net = swinunet(num_classes, 224)  # img_size = 224
    elif network == 'unet_urpc':
        net = unet_urpc(in_channels, num_classes)
    elif network == 'unet_cct':
        net = unet_cct(in_channels, num_classes)
    elif network == 'wavesnet':
        net = wsegnet_vgg16_bn(in_channels, num_classes)
    elif network == 'mwcnn':
        net = mwcnn(in_channels, num_classes)
    elif network == 'alnet':
        net = Aerial_LaneNet(in_channels, num_classes)
    elif network == 'wds':
        net = WDS(in_channels, num_classes)

    # 3d networks
    elif network == 'xnet3d':
        net = xnet3d(in_channels, num_classes)
    elif network == 'unet3d':
        net = unet3d(in_channels, num_classes)
    elif network == 'unet3d_min':
        net = unet3d_min(in_channels, num_classes)
    elif network == 'unet3d_urpc':
        net = unet3d_urpc(in_channels, num_classes)
    elif network == 'unet3d_cct':
        net = unet3d_cct(in_channels, num_classes)
    elif network == 'unet3d_cct_min':
        net = unet3d_cct_min(in_channels, num_classes)
    elif network == 'unet3d_dtc':
        net = unet3d_dtc(in_channels, num_classes)
    elif network == 'vnet':
        net = vnet(in_channels, num_classes)
    elif network == 'vnet_cct':
        net = vnet_cct(in_channels, num_classes)
    elif network == 'vnet_dtc':
        net = vnet_dtc(in_channels, num_classes)
    elif network == 'resunet3d':
        net = res_unet3d(in_channels, num_classes)
    elif network == 'conresnet':
        net = conresnet(in_channels, num_classes, img_shape=kwargs['img_shape'])
    elif network == 'espnet3d':
        net = espnet3d(in_channels, num_classes)
    elif network == 'dmfnet':
        net = dmfnet(in_channels, num_classes)
    elif network == 'transbts':
        net = transbts(in_channels, num_classes, img_shape=kwargs['img_shape'])
    elif network == 'cotr':
        net = cotr(in_channels, num_classes)
    elif network == 'unertr':
        net = unertr(in_channels, num_classes, img_shape=kwargs['img_shape'])
    else:
        print('the network you have entered is not supported yet')
        sys.exit()
    return net


================================================
FILE: models/networks_2d/__init__.py
================================================


================================================
FILE: models/networks_2d/aerial_lanenet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import functools
from torch.distributions.uniform import Uniform
import numpy as np


class basic_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(basic_block, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True))

    def forward(self, x):
        x = self.block(x)
        return x


class Aerial_LaneNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(Aerial_LaneNet, self).__init__()

        l1, l2, l3, l4, l5 = 64, 128, 256, 512, 512
        dropout = 0.2

        # e1
        self.conv1_1 = basic_block(in_channels, l1)
        self.conv1_2 = basic_block(l1, l1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        # e2
        self.conv2_1 = basic_block(l1+3, l2)
        self.conv2_2 = basic_block(l2, l2)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        # e3
        self.conv3_1 = basic_block(l2+3, l3)
        self.conv3_2 = basic_block(l3, l3)
        self.conv3_3 = basic_block(l3, l3)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        # e4
        self.conv4_1 = basic_block(l3+3, l4)
        self.conv4_2 = basic_block(l4, l4)
        self.conv4_3 = basic_block(l4, l4)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        # e5
        self.conv5_1 = basic_block(l4+3, l5)
        self.conv5_2 = basic_block(l5, l5)
        self.conv5_3 = basic_block(l5, l5)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        # e6
        self.conv6_1 = basic_block(l5,
                                   4096)
        self.drop6_1 = nn.Dropout2d(dropout)
        self.conv6_2 = basic_block(4096, 4096)
        self.drop6_2 = nn.Dropout2d(dropout)
        self.conv6_3 = nn.ConvTranspose2d(4096, l5, kernel_size=4, stride=2, padding=1, bias=False)

        # d4
        self.conv4_4 = basic_block(2*l5, l5)
        self.drop4_4 = nn.Dropout2d(dropout)
        self.conv4_5 = nn.ConvTranspose2d(l5, l3, kernel_size=4, stride=2, padding=1, bias=False)
        # d3
        self.conv3_4 = basic_block(2*l3, l3)
        self.drop3_4 = nn.Dropout2d(dropout)
        self.conv3_5 = nn.ConvTranspose2d(l3, l2, kernel_size=4, stride=2, padding=1, bias=False)
        # d2
        self.conv2_4 = basic_block(2*l2, l2)
        self.drop2_4 = nn.Dropout2d(dropout)
        self.conv2_5 = nn.ConvTranspose2d(l2, l1, kernel_size=4, stride=2, padding=1, bias=False)
        # d1
        self.conv1_3 = basic_block(2*l1, l1)
        self.drop1_3 = nn.Dropout2d(dropout)
        self.conv1_4 = nn.ConvTranspose2d(l1, num_classes, kernel_size=4, stride=2, padding=1, bias=False)

        # initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, x_wavelet_1, x_wavelet_2, x_wavelet_3, x_wavelet_4):

        x1 = self.conv1_1(x)
        x1 = self.conv1_2(x1)
        x1 = self.pool1(x1)

        x2 = torch.cat((x1, x_wavelet_1), dim=1)
        x2 = self.conv2_1(x2)
        x2 = self.conv2_2(x2)
        x2 = self.pool2(x2)

        x3 = torch.cat((x2, x_wavelet_2), dim=1)
        x3 = self.conv3_1(x3)
        x3 = self.conv3_2(x3)
        x3 = self.conv3_3(x3)
        x3 = self.pool3(x3)

        x4 = torch.cat((x3, x_wavelet_3), dim=1)
        x4 = self.conv4_1(x4)
        x4 = self.conv4_2(x4)
        x4 = self.conv4_3(x4)
        x4 = self.pool4(x4)

        x5 = torch.cat((x4, x_wavelet_4), dim=1)
        x5 = self.conv5_1(x5)
        x5 = self.conv5_2(x5)
        x5 = self.conv5_3(x5)
        x5 = self.pool5(x5)

        x6 = self.conv6_1(x5)
        x6 = self.drop6_1(x6)
        x6 = self.conv6_2(x6)
        x6 = self.drop6_2(x6)
        x6 = self.conv6_3(x6)

        x5 = torch.cat((x6, x4), dim=1)
        x5 = self.conv4_4(x5)
        x5 = self.drop4_4(x5)
        x5 = self.conv4_5(x5)

        x4 = torch.cat((x5, x3), dim=1)
        x4 = self.conv3_4(x4)
        x4 = self.drop3_4(x4)
        x4 = self.conv3_5(x4)

        x3 = torch.cat((x4, x2), dim=1)
        x3 = self.conv2_4(x3)
        x3 = self.drop2_4(x3)
        x3 = self.conv2_5(x3)

        x2 = torch.cat((x3, x1), dim=1)
        x2 = self.conv1_3(x2)
        x2 = self.drop1_3(x2)
        x2 = self.conv1_4(x2)

        return x2

# if __name__ == '__main__':
#     from loss.loss_function import segmentation_loss
#     criterion = segmentation_loss('dice', False)
#     mask = torch.ones(2, 128, 128).long()
#     model = Aerial_LaneNet(1, 5)
#     model.train()
#     input1 = torch.rand(2, 1, 128, 128)
#     input2 = torch.rand(2, 3, 64, 64)
#     input3 = torch.rand(2, 3, 32, 32)
#     input4 = torch.rand(2, 3, 16, 16)
#     input5 = torch.rand(2, 3, 8, 8)
#
#     y = model(input1, input2, input3, input4, input5)
#     loss_train = criterion(y, mask)
#     loss_train.backward()
#     # print(output)
#     print(y.data.cpu().numpy().shape)
#     print(loss_train)


================================================
FILE: models/networks_2d/hrnet.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import functools

import numpy as np

import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F
from torch.nn import init

try:
    from .sync_bn.inplace_abn.bn import InPlaceABNSync
    BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
except:
    BatchNorm2d = nn.BatchNorm2d

BN_MOMENTUM = 0.01
logger = logging.getLogger(__name__)

model_urls = {
    'hrnetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/hrnetv2_w48-imagenet.pth',
}

import sys
try:
    from urllib import urlretrieve
except ImportError:
    from urllib.request import urlretrieve


def load_url(url, model_dir='./pretrained', map_location=None):
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        urlretrieve(url, cached_file)
    return torch.load(cached_file, map_location=map_location)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=False)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = BatchNorm2d(planes * self.expansion,
                               momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=False)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual
        out = self.relu(out)

        return out


class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=False)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(num_channels[branch_index] * block.expansion,
                            momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride, downsample))
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
                        BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
                                BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
                                BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM),
                                nn.ReLU(inplace=False)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                elif j > i:
                    width_output = x[i].shape[-1]
                    height_output = x[i].shape[-2]
                    y = y + F.interpolate(
                        self.fuse_layers[i][j](x[j]),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=False)
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse


blocks_dict = {
    'BASIC': BasicBlock,
    'BOTTLENECK': Bottleneck
}


class HighResolutionNet(nn.Module):

    def __init__(self, in_channels, extra, num_classes, **kwargs):
        super(HighResolutionNet, self).__init__()

        # stem net
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=False)

        self.stage1_cfg = extra['STAGE1']
        num_channels = self.stage1_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage1_cfg['BLOCK']]
        num_blocks = self.stage1_cfg['NUM_BLOCKS']
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
        stage1_out_channel = block.expansion * num_channels

        self.stage2_cfg = extra['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition1 = self._make_transition_layer(
            [stage1_out_channel], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        self.stage3_cfg = extra['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        self.stage4_cfg = extra['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True)

        last_inp_channels = int(np.sum(pre_stage_channels))

        self.last_layer = nn.Sequential(
            nn.Conv2d(
                in_channels=last_inp_channels,
                out_channels=last_inp_channels,
                kernel_size=1,
                stride=1,
                padding=0),
            BatchNorm2d(last_inp_channels, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=False),
            nn.Conv2d(
                in_channels=last_inp_channels,
                out_channels=num_classes,
                kernel_size=extra['FINAL_CONV_KERNEL'],
                stride=1,
                padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0)
        )

    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre_layer[i],
                                  num_channels_cur_layer[i],
                                  3,
                                  1,
                                  1,
                                  bias=False),
                        BatchNorm2d(
                            num_channels_cur_layer[i], momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=False)))
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(
                            inchannels, outchannels, 3, 2, 1, bias=False),
                        BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=False)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels,
                    multi_scale_output=True):
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        fuse_method = layer_config['FUSE_METHOD']

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used by the last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True
            modules.append(
                HighResolutionModule(num_branches,
                                     block,
                                     num_blocks,
                                     num_inchannels,
                                     num_channels,
                                     fuse_method,
                                     reset_multi_scale_output)
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        size = x.shape[2:]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                if i < self.stage2_cfg['NUM_BRANCHES']:
                    x_list.append(self.transition2[i](y_list[i]))
                else:
                    x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                if i < self.stage3_cfg['NUM_BRANCHES']:
                    x_list.append(self.transition3[i](y_list[i]))
                else:
                    x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        x = self.stage4(x_list)

        # Upsampling
        x0_h, x0_w = x[0].size(2), x[0].size(3)
        x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=False)
        x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=False)
        x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=False)

        x = torch.cat([x[0], x1, x2, x3], 1)

        x = self.last_layer(x)
        x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)

        # outputs = []
        # outputs.append(x)
        # return outputs
        return x

    def init_weights(self, pretrained='', ):
        logger.info('=> init weights from normal distribution')
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.001)
            elif isinstance(m, nn.BatchNorm2d) or (
                    'InPlaceABNSync' in globals() and isinstance(m, InPlaceABNSync)):
                # InPlaceABNSync only exists when the optional sync_bn import succeeded;
                # otherwise the plain nn.BatchNorm2d fallback is initialized instead
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if os.path.isfile(pretrained):
            pretrained_dict = torch.load(pretrained)
            logger.info('=> loading pretrained model {}'.format(pretrained))
            model_dict = self.state_dict()
            pretrained_dict = {k: v for k, v in pretrained_dict.items()
                               if k in model_dict.keys()}
            for k, _ in pretrained_dict.items():
                logger.info(
                    '=> loading {} pretrained model {}'.format(k, pretrained))
            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict)

# class HRNet(nn.Module):
#     def __init__(self, in_channels, extra, num_classes, **kwargs):
#         super(HRNet, self).__init__()
#         self.branch1 = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra)
#         self.branch2 = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra)
#
#     def forward(self, data, step=1):
#         if not self.training:
#             pred1 = self.branch1(data)
#             return pred1
#
#         if step == 1:
#             return self.branch1(data)
#         elif step == 2:
#             return self.branch2(data)


# STAGE1 'NUM_BLOCKS' and 'NUM_CHANNELS' are plain ints, not tuples:
# stage 1 is built by _make_layer rather than as a multi-branch stage
extra_18 = {
    'STAGE1': {'NUM_MODULES': 1, 'NUM_BRANCHES': 1, 'BLOCK': 'BOTTLENECK', 'NUM_BLOCKS': (4), 'NUM_CHANNELS': (64), 'FUSE_METHOD': 'SUM'},
    'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (18, 36), 'FUSE_METHOD': 'SUM'},
    'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (18, 36, 72), 'FUSE_METHOD': 'SUM'},
    'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (18, 36, 72, 144), 'FUSE_METHOD': 'SUM'},
    'FINAL_CONV_KERNEL': 1
}

extra_32 = {
    'STAGE1': {'NUM_MODULES': 1, 'NUM_BRANCHES': 1, 'BLOCK': 'BOTTLENECK', 'NUM_BLOCKS': (4), 'NUM_CHANNELS': (64), 'FUSE_METHOD': 'SUM'},
    'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (32, 64), 'FUSE_METHOD': 'SUM'},
    'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (32, 64, 128), 'FUSE_METHOD': 'SUM'},
    'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (32, 64, 128, 256), 'FUSE_METHOD': 'SUM'},
    'FINAL_CONV_KERNEL': 1
}

extra_48 = {
    'STAGE1': {'NUM_MODULES': 1, 'NUM_BRANCHES': 1, 'BLOCK': 'BOTTLENECK', 'NUM_BLOCKS': (4), 'NUM_CHANNELS': (64), 'FUSE_METHOD': 'SUM'},
    'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (48, 96), 'FUSE_METHOD': 'SUM'},
    'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (48, 96, 192), 'FUSE_METHOD': 'SUM'},
    'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (48, 96, 192, 384), 'FUSE_METHOD': 'SUM'},
    'FINAL_CONV_KERNEL': 1
}

extra_64 = {
    'STAGE1': {'NUM_MODULES': 1, 'NUM_BRANCHES': 1, 'BLOCK': 'BOTTLENECK', 'NUM_BLOCKS': (4), 'NUM_CHANNELS': (64), 'FUSE_METHOD': 'SUM'},
    'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (64, 128), 'FUSE_METHOD': 'SUM'},
    'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (64, 128, 256), 'FUSE_METHOD': 'SUM'},
    'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (64, 128, 256, 512), 'FUSE_METHOD': 'SUM'},
    'FINAL_CONV_KERNEL': 1
}


def init_weights(net,
                 init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


# def hrnet18(in_channels, num_classes):
#     model = HRNet(in_channels=in_channels, num_classes=num_classes, extra=extra_18)
#     return model
#
# def hrnet32(in_channels, num_classes):
#     model = HRNet(in_channels=in_channels, num_classes=num_classes, extra=extra_32)
#     return model
#
# def hrnet48(in_channels, num_classes):
#     model = HRNet(in_channels=in_channels, num_classes=num_classes, extra=extra_48)
#     return model
#
# def hrnet64(in_channels, num_classes):
#     model = HRNet(in_channels=in_channels, num_classes=num_classes, extra=extra_64)
#     return model

def hrnet18(in_channels, num_classes):
    model = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra_18)
    init_weights(model, 'kaiming')
    return model


def hrnet32(in_channels, num_classes):
    model = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra_32)
    init_weights(model, 'kaiming')
    return model


def hrnet48(in_channels, num_classes):
    model = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra_48)
    init_weights(model, 'kaiming')
    return model


def hrnet64(in_channels, num_classes):
    model = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra_64)
    init_weights(model, 'kaiming')
    return model


# if __name__ == '__main__':
#     model = hrnet48(1,10)
#     total = sum([param.nelement() for param in model.parameters()])
#     from thop import profile,clever_format
#
#     input = torch.randn(1, 1, 128, 128)
#     flops, params = profile(model, inputs=(input, ))
#     macs, params = clever_format([flops, params], "%.3f")
#     print(macs)
#     print(params)
#     print(total)
#     model.eval()
#     input = torch.rand(1,1,256,256)
#     output = model(input)
#     output = output[0].data.cpu().numpy()
#     print(output)
#     print(output.shape)

================================================
FILE: models/networks_2d/mwcnn.py
================================================
import torch
import torch.nn as nn
import scipy.io as sio
import math
import torch.nn.functional as F
from torch.autograd import Variable


def default_conv(in_channels, out_channels, kernel_size, bias=True, dilation=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2) + dilation - 1, bias=bias, dilation=dilation)


def default_conv1(in_channels, out_channels, kernel_size, bias=True, groups=3):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias, groups=groups)


# def shuffle_channel()

def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.size()

    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups,
               channels_per_group, height, width)

    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x


def pixel_down_shuffle(x, downscale_factor):
    batchsize, num_channels, height, width = x.size()

    out_height = height // downscale_factor
    out_width = width // downscale_factor
    input_view = x.contiguous().view(batchsize, num_channels, out_height, downscale_factor, out_width,
                                     downscale_factor)

    num_channels *= downscale_factor ** 2
    unshuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()

    return unshuffle_out.view(batchsize, num_channels, out_height, out_width)


def sp_init(x):
    x01 = x[:, :, 0::2, :]
    x02 = x[:, :, 1::2, :]
    x_LL = x01[:, :, :, 0::2]
    x_HL = x02[:, :, :, 0::2]
    x_LH = x01[:, :, :, 1::2]
    x_HH = x02[:, :, :, 1::2]

    return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)


def dwt_init(x):
    x01 = x[:, :, 0::2, :] / 2
    x02 = x[:, :, 1::2, :] / 2
    x1 = x01[:, :, :, 0::2]
    x2 = x02[:, :, :, 0::2]
    x3 = x01[:, :, :, 1::2]
    x4 = x02[:, :, :, 1::2]
    x_LL = x1 + x2 + x3 + x4
    x_HL = -x1 - x2 + x3 + x4
    x_LH = -x1 + x2 - x3 + x4
    x_HH = x1 - x2 - x3 + x4

    return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)


def iwt_init(x):
    r = 2
    in_batch, in_channel, in_height, in_width = x.size()
    # print([in_batch, in_channel, in_height, in_width])
    out_batch, out_channel, out_height, out_width = in_batch, int(
        in_channel / (r ** 2)), r * in_height, r * in_width
    x1 = x[:, 0:out_channel, :, :] / 2
    x2 = x[:, out_channel:out_channel * 2, :, :] / 2
    x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
    x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2

    # allocate the output on the input's device and dtype so IWT also runs on CPU
    h = torch.zeros([out_batch, out_channel, out_height, out_width], dtype=x.dtype, device=x.device)

    h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
    h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
    h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
    h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4

    return h


class Channel_Shuffle(nn.Module):
    def __init__(self, conv_groups):
        super(Channel_Shuffle, self).__init__()
        self.conv_groups = conv_groups
        self.requires_grad = False

    def forward(self, x):
        return channel_shuffle(x, self.conv_groups)


class SP(nn.Module):
    def __init__(self):
        super(SP, self).__init__()
        self.requires_grad = False

    def forward(self, x):
        return sp_init(x)


class Pixel_Down_Shuffle(nn.Module):
    def __init__(self):
        super(Pixel_Down_Shuffle, self).__init__()
        self.requires_grad = False

    def forward(self, x):
        return pixel_down_shuffle(x, 2)


class DWT(nn.Module):
    def __init__(self):
        super(DWT, self).__init__()
        self.requires_grad = False

    def forward(self, x):
        return dwt_init(x)


class IWT(nn.Module):
    def __init__(self):
        super(IWT, self).__init__()
        self.requires_grad = False

    def forward(self, x):
        return iwt_init(x)


class MeanShift(nn.Conv2d):
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        self.requires_grad = False
        if sign == -1:
            self.create_graph = False
            self.volatile = True


class MeanShift2(nn.Conv2d):
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift2, self).__init__(4, 4, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(4).view(4, 4, 1, 1)
        self.weight.data.div_(std.view(4, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        self.requires_grad = False
        if sign == -1:
            self.volatile = True


class BasicBlock(nn.Sequential):
    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, bias=False,
            bn=False, act=nn.ReLU(True)):

        m = [nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=(kernel_size // 2), stride=stride, bias=bias)
        ]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)
        super(BasicBlock, self).__init__(*m)


class BBlock(nn.Module):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(BBlock, self).__init__()
        m = []
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias))
        if bn:
            m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x).mul(self.res_scale)
        return x


class DBlock_com(nn.Module):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(DBlock_com, self).__init__()
        m = []

        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
        if bn:
            m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=3))
        if bn:
            m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x)
        return x


class DBlock_inv(nn.Module):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(DBlock_inv, self).__init__()
        m = []

        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=3))
        if bn:
            m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
        if bn:
            m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x)
        return x


class DBlock_com1(nn.Module):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(DBlock_com1, self).__init__()
        m = []

        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
        if bn:
            m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=1))
        if bn:
            m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x)
        return x


class DBlock_inv1(nn.Module):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(DBlock_inv1, self).__init__()
        m = []

        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
        if bn:
            m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=1))
        if bn:
            m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x)
        return x


class DBlock_com2(nn.Module):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(DBlock_com2, self).__init__()
        m = []

        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
        if bn:
            m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
        if bn:
            m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x)
        return x


class DBlock_inv2(nn.Module):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(DBlock_inv2, self).__init__()
        m = []

        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
        if bn:
            m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
        if bn:
            m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x)
        return x


class ShuffleBlock(nn.Module):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1, conv_groups=1):

        super(ShuffleBlock, self).__init__()
        m = []
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias))
        m.append(Channel_Shuffle(conv_groups))
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x).mul(self.res_scale)
        return x


class DWBlock(nn.Module):
    def __init__(
            self, conv, conv1, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(DWBlock, self).__init__()
        m = []
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias))
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        m.append(act)

        m.append(conv1(in_channels, out_channels, 1, bias=bias))
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x).mul(self.res_scale)
        return x


class ResBlock(nn.Module):
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feat))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res


class Block(nn.Module):
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(Block, self).__init__()
        m = []
        for i in range(4):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feat))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        # res += x

        return res


class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn:
                    m.append(nn.BatchNorm2d(n_feat))
                if act:
                    m.append(act())
        elif scale == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn:
                m.append(nn.BatchNorm2d(n_feat))
            if act:
                m.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)


class MWCNN(nn.Module):
    def __init__(self, in_channels, num_classes, conv=default_conv):
        super(MWCNN, self).__init__()
        kernel_size = 3
        self.scale_idx = 0
        n_feats = 64

        act = nn.ReLU(True)

        self.DWT = DWT()
        self.IWT = IWT()

        n = 1
        m_head = [BBlock(conv, in_channels, n_feats, kernel_size, act=act)]
        d_l0 = []
        d_l0.append(DBlock_com1(conv, n_feats, n_feats, kernel_size, act=act, bn=False))

        d_l1 = [BBlock(conv, n_feats * 4, n_feats * 2, kernel_size, act=act, bn=False)]
        d_l1.append(DBlock_com1(conv, n_feats * 2, n_feats * 2, kernel_size, act=act, bn=False))

        d_l2 = []
        d_l2.append(BBlock(conv, n_feats * 8, n_feats * 4, kernel_size, act=act, bn=False))
        d_l2.append(DBlock_com1(conv, n_feats * 4, n_feats * 4, kernel_size, act=act, bn=False))
        pro_l3 = []
        pro_l3.append(BBlock(conv, n_feats * 16, n_feats * 8, kernel_size, act=act, bn=False))
        pro_l3.append(DBlock_com(conv, n_feats * 8, n_feats * 8, kernel_size, act=act, bn=False))
        pro_l3.append(DBlock_inv(conv, n_feats * 8, n_feats * 8, kernel_size, act=act, bn=False))
        pro_l3.append(BBlock(conv, n_feats * 8, n_feats * 16, kernel_size, act=act, bn=False))

        i_l2 = [DBlock_inv1(conv, n_feats * 4, n_feats * 4, kernel_size, act=act, bn=False)]
        i_l2.append(BBlock(conv, n_feats * 4, n_feats * 8, kernel_size, act=act, bn=False))

        i_l1 = [DBlock_inv1(conv, n_feats * 2, n_feats * 2, kernel_size, act=act, bn=False)]
        i_l1.append(BBlock(conv, n_feats * 2, n_feats * 4, kernel_size, act=act, bn=False))

        i_l0 = [DBlock_inv1(conv, n_feats, n_feats, kernel_size, act=act, bn=False)]

        m_tail = [conv(n_feats, num_classes, kernel_size)]

        self.head = nn.Sequential(*m_head)
        self.d_l2 = nn.Sequential(*d_l2)
        self.d_l1 = nn.Sequential(*d_l1)
        self.d_l0 = nn.Sequential(*d_l0)
        self.pro_l3 = nn.Sequential(*pro_l3)
        self.i_l2 = nn.Sequential(*i_l2)
        self.i_l1 = nn.Sequential(*i_l1)
        self.i_l0 = nn.Sequential(*i_l0)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        x0 = self.d_l0(self.head(x))
        x1 = self.d_l1(self.DWT(x0))
        x2 = self.d_l2(self.DWT(x1))
        x_ = self.IWT(self.pro_l3(self.DWT(x2))) + x2
        x_ = self.IWT(self.i_l2(x_)) + x1
        x_ = self.IWT(self.i_l1(x_)) + x0
        x_ = self.tail(self.i_l0(x_))

        return x_

    def set_scale(self, scale_idx):
        self.scale_idx = scale_idx


def mwcnn(in_channels, num_classes):
    model = MWCNN(in_channels, num_classes)
    return model


# if __name__ == '__main__':
#     from loss.loss_function import segmentation_loss
#     criterion = segmentation_loss('dice', False)
#     mask = torch.ones(2, 128, 128).long()
#     model = mwcnn(1, 2)
#     model.train()
#     input1 = torch.rand(2, 1, 128, 128)
#     y = model(input1)
#     loss_train = criterion(y, mask)
#     loss_train.backward()
#     # print(output)
#     print(y.data.cpu().numpy().shape)
#     print(loss_train)

================================================
FILE: models/networks_2d/resunet.py
================================================
import torch
import torch.nn as nn
from torch.nn import init


def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


class ResidualConv(nn.Module):
    def __init__(self, input_dim, output_dim, stride, padding):
        super(ResidualConv, self).__init__()

        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        )

    def forward(self, x):

        return self.conv_block(x) + self.conv_skip(x)


class Upsample(nn.Module):
    def __init__(self, input_dim, output_dim, kernel, stride):
        super(Upsample, self).__init__()

        self.upsample = nn.ConvTranspose2d(
            input_dim, output_dim, kernel_size=kernel, stride=stride
        )

    def forward(self, x):
        return self.upsample(x)


class ResUnet(nn.Module):
    def __init__(self, in_channels, num_classes, filters=[64, 128, 256, 512]):
        super(ResUnet, self).__init__()

        self.input_layer = nn.Sequential(
            nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1)
        )

        self.residual_conv_1 = ResidualConv(filters[0], filters[1], 2, 1)
        self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1)

        self.bridge = ResidualConv(filters[2], filters[3], 2, 1)

        self.upsample_1 = Upsample(filters[3], filters[3], 2, 2)
        self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1)

        self.upsample_2 = Upsample(filters[2], filters[2], 2, 2)
        self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1)

        self.upsample_3 = Upsample(filters[1], filters[1], 2, 2)
        self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], 1, 1)

        self.output_layer = nn.Conv2d(filters[0], num_classes, 1, 1)

    def forward(self, x):
        # Encode
        x1 = self.input_layer(x) + self.input_skip(x)
        x2 = self.residual_conv_1(x1)
        x3 = self.residual_conv_2(x2)
        # Bridge
        x4 = self.bridge(x3)
        # Decode
        x4 = self.upsample_1(x4)
        x5 = torch.cat([x4, x3], dim=1)

        x6 = self.up_residual_conv1(x5)

        x6 = self.upsample_2(x6)
        x7 = torch.cat([x6, x2], dim=1)

        x8 = self.up_residual_conv2(x7)

        x8 = self.upsample_3(x8)
        x9 = torch.cat([x8, x1], dim=1)

        x10 = self.up_residual_conv3(x9)

        output = self.output_layer(x10)

        return output


def res_unet(in_channels, num_classes):
    model = ResUnet(in_channels, num_classes)
    init_weights(model, 'kaiming')
    return model


# if __name__ == '__main__':
#     model = res_unet(1,10)
#     model.eval()
#     input = torch.rand(2,1,128,128)
#     output = model(input)
#     output = output.data.cpu().numpy()
#     # print(output)
#     print(output.shape)

================================================
FILE: models/networks_2d/resunet_plusplus.py
================================================
import torch.nn as nn
import torch
from torch.nn import init


def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


class ResidualConv(nn.Module):
    def __init__(self, input_dim, output_dim, stride, padding):
        super(ResidualConv, self).__init__()

        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        )

    def forward(self, x):

        return self.conv_block(x) + self.conv_skip(x)


class Upsample(nn.Module):
    def __init__(self, input_dim, output_dim, kernel, stride):
        super(Upsample, self).__init__()

        self.upsample = nn.ConvTranspose2d(
            input_dim, output_dim, kernel_size=kernel, stride=stride
        )

    def forward(self, x):
        return self.upsample(x)


class Squeeze_Excite_Block(nn.Module):
    def __init__(self, channel, reduction=16):
        super(Squeeze_Excite_Block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class ASPP(nn.Module):
    def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
        super(ASPP, self).__init__()

        self.aspp_block1 = nn.Sequential(
            nn.Conv2d(
                in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_dims),
        )
        self.aspp_block2 = nn.Sequential(
            nn.Conv2d(
                in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_dims),
        )
        self.aspp_block3 = nn.Sequential(
            nn.Conv2d(
                in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_dims),
        )

        self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
        self._init_weights()

    def forward(self, x):
        x1 = self.aspp_block1(x)
        x2 = self.aspp_block2(x)
        x3 = self.aspp_block3(x)
        out = torch.cat([x1, x2, x3], dim=1)
        return self.output(out)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class Upsample_(nn.Module):
    def __init__(self, scale=2):
        super(Upsample_, self).__init__()

        self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale, align_corners=True)

    def forward(self, x):
        return self.upsample(x)


class AttentionBlock(nn.Module):
    def __init__(self, input_encoder, input_decoder, output_dim):
        super(AttentionBlock, self).__init__()

        self.conv_encoder = nn.Sequential(
            nn.BatchNorm2d(input_encoder),
            nn.ReLU(),
            nn.Conv2d(input_encoder, output_dim, 3, padding=1),
            nn.MaxPool2d(2, 2),
        )

        self.conv_decoder = nn.Sequential(
            nn.BatchNorm2d(input_decoder),
            nn.ReLU(),
            nn.Conv2d(input_decoder, output_dim, 3, padding=1),
        )

        self.conv_attn = nn.Sequential(
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, 1, 1),
        )

    def forward(self, x1, x2):
        out = self.conv_encoder(x1) + self.conv_decoder(x2)
        out = self.conv_attn(out)
        return out * x2


class ResUnetPlusPlus(nn.Module):
    def __init__(self, in_channels, num_classes, filters=[32, 64, 128, 256, 512]):
        super(ResUnetPlusPlus, self).__init__()

        self.input_layer = nn.Sequential(
            nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1)
        )

        self.squeeze_excite1 = Squeeze_Excite_Block(filters[0])

        self.residual_conv1 = ResidualConv(filters[0], filters[1], 2, 1)

        self.squeeze_excite2 = Squeeze_Excite_Block(filters[1])

        self.residual_conv2 = ResidualConv(filters[1], filters[2], 2, 1)

        self.squeeze_excite3 = Squeeze_Excite_Block(filters[2])

        self.residual_conv3 = ResidualConv(filters[2], filters[3], 2, 1)

        self.aspp_bridge = ASPP(filters[3], filters[4])

        self.attn1 = AttentionBlock(filters[2], filters[4], filters[4])
        self.upsample1 = Upsample_(2)
        self.up_residual_conv1 = ResidualConv(filters[4] + filters[2], filters[3], 1, 1)

        self.attn2 = AttentionBlock(filters[1], filters[3], filters[3])
        self.upsample2 = Upsample_(2)
        self.up_residual_conv2 = ResidualConv(filters[3] + filters[1], filters[2], 1, 1)

        self.attn3 = AttentionBlock(filters[0], filters[2], filters[2])
        self.upsample3 = Upsample_(2)
        self.up_residual_conv3 = ResidualConv(filters[2] + filters[0], filters[1], 1, 1)

        self.aspp_out = ASPP(filters[1], filters[0])

        self.output_layer = nn.Conv2d(filters[0], num_classes, 1)

    def forward(self, x):
        x1 = self.input_layer(x) + self.input_skip(x)

        x2 = self.squeeze_excite1(x1)
        x2 = self.residual_conv1(x2)

        x3 = self.squeeze_excite2(x2)
        x3 = self.residual_conv2(x3)

        x4 = self.squeeze_excite3(x3)
        x4 = self.residual_conv3(x4)

        x5 = self.aspp_bridge(x4)

        x6 = self.attn1(x3, x5)
        x6 = self.upsample1(x6)
        x6 = torch.cat([x6, x3], dim=1)
        x6 = self.up_residual_conv1(x6)

        x7 = self.attn2(x2, x6)
        x7 = self.upsample2(x7)
        x7 = torch.cat([x7, x2], dim=1)
        x7 = self.up_residual_conv2(x7)

        x8 = self.attn3(x1, x7)
        x8 = self.upsample3(x8)
        x8 = torch.cat([x8, x1], dim=1)
        x8 = self.up_residual_conv3(x8)

        x9 = self.aspp_out(x8)
        out = self.output_layer(x9)

        return out


def res_unet_plusplus(in_channels, num_classes):
    model = ResUnetPlusPlus(in_channels, num_classes)
    init_weights(model, 'kaiming')
    return model


# if __name__ == '__main__':
#     model = res_unet_plusplus(1,10)
#     model.eval()
#     input = torch.rand(2,1,128,128)
#     output = model(input)
#     output = output.data.cpu().numpy()
#     # print(output)
#     print(output.shape)

================================================
FILE: models/networks_2d/swinunet.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange
import torch.utils.checkpoint as checkpoint

logger = logging.getLogger(__name__)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. Default: 0.0 """ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - \ coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute( 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - \ 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) trunc_normal_(self.relative_position_bias_table, std=.02) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): """ Args: x: input features with shape of (num_windows*B, N, C) mask: 
(0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() # make torchscript happy (cannot use tensor as tuple) q, k, v = qkv[0], qkv[1], qkv[2] q = q * self.scale attn = (q @ k.transpose(-2, -1).contiguous()) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute( 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).contiguous().reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x def extra_repr(self) -> str: return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' def flops(self, N): # calculate flops for 1 window with token length of N flops = 0 # qkv = self.qkv(x) flops += N * self.dim * 3 * self.dim # attn = (q @ k.transpose(-2, -1)) flops += self.num_heads * N * (self.dim // self.num_heads) * N # x = (attn @ v) flops += self.num_heads * N * N * (self.dim // self.num_heads) # x = self.proj(x) flops += N * self.dim * self.dim return flops class SwinTransformerBlock(nn.Module): r""" Swin Transformer Block. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value.
Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.input_resolution = input_resolution self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio if min(self.input_resolution) <= self.window_size: # if the window size is not smaller than the input resolution, don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)" self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath( drop_path) if drop_path > 0.
else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) if self.shift_size > 0: # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 # nW, window_size, window_size, 1 mask_windows = window_partition(img_mask, self.window_size) mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill( attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None self.register_buffer("attn_mask", attn_mask) def forward(self, x): H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # cyclic shift if self.shift_size > 0: shifted_x = torch.roll( x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x # partition windows # nW*B, window_size, window_size, C x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size*window_size, C x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # W-MSA/SW-MSA # nW*B, window_size*window_size, C attn_windows = self.attn(x_windows, mask=self.attn_mask) # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse( attn_windows, self.window_size, H, W) # B H' W' C # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=( self.shift_size, self.shift_size), dims=(1, 2)) else: x = 
shifted_x x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) x = x + self.drop_path(self.mlp(self.norm2(x))) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" def flops(self): flops = 0 H, W = self.input_resolution # norm1 flops += self.dim * H * W # W-MSA/SW-MSA nW = H * W / self.window_size / self.window_size flops += nW * self.attn.flops(self.window_size * self.window_size) # mlp flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # norm2 flops += self.dim * H * W return flops class PatchMerging(nn.Module): r""" Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): super().__init__() self.input_resolution = input_resolution self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) def forward(self, x): """ x: B, H*W, C """ H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
x = x.view(B, H, W, C) x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C x = self.norm(x) x = self.reduction(x) return x def extra_repr(self) -> str: return f"input_resolution={self.input_resolution}, dim={self.dim}" def flops(self): H, W = self.input_resolution flops = H * W * self.dim flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim return flops class PatchExpand(nn.Module): def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): super().__init__() self.input_resolution = input_resolution self.dim = dim self.expand = nn.Linear( dim, 2*dim, bias=False) if dim_scale == 2 else nn.Identity() self.norm = norm_layer(dim // dim_scale) def forward(self, x): """ x: B, H*W, C """ H, W = self.input_resolution x = self.expand(x) B, L, C = x.shape assert L == H * W, "input feature has wrong size" x = x.view(B, H, W, C) x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) x = x.view(B, -1, C//4) x = self.norm(x) return x class FinalPatchExpand_X4(nn.Module): def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): super().__init__() self.input_resolution = input_resolution self.dim = dim self.dim_scale = dim_scale self.expand = nn.Linear(dim, 16*dim, bias=False) self.output_dim = dim self.norm = norm_layer(self.output_dim) def forward(self, x): """ x: B, H*W, C """ H, W = self.input_resolution x = self.expand(x) B, L, C = x.shape assert L == H * W, "input feature has wrong size" x = x.view(B, H, W, C) x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2)) x = x.view(B, -1, self.output_dim) x = self.norm(x) return x class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. 
input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ SwinTransformerBlock(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if ( i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance( drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample( input_resolution, dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x): for blk in self.blocks: if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) if self.downsample is not None: x = self.downsample(x) return x def extra_repr(self) -> 
str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" def flops(self): flops = 0 for blk in self.blocks: flops += blk.flops() if self.downsample is not None: flops += self.downsample.flops() return flops class BasicLayer_up(nn.Module): """ A basic Swin Transformer decoder layer for one stage. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm upsample (nn.Module | None, optional): Upsample layer at the end of the layer; any non-None value adds a PatchExpand layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
""" def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ SwinTransformerBlock(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if ( i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance( drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth)]) # patch expanding layer if upsample is not None: self.upsample = PatchExpand( input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) else: self.upsample = None def forward(self, x): for blk in self.blocks: if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) if self.upsample is not None: x = self.upsample(x) return x class PatchEmbed(nn.Module): r""" Image to Patch Embedding Args: img_size (int): Image size. Default: 224. patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer.
Default: None """ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x).flatten(2).transpose(1, 2).contiguous() # B Ph*Pw C if self.norm is not None: x = self.norm(x) return x def flops(self): Ho, Wo = self.patches_resolution flops = Ho * Wo * self.embed_dim * self.in_chans * \ (self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops class SwinTransformerSys(nn.Module): r""" Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 Args: img_size (int | tuple(int)): Input image size. Default 224 patch_size (int | tuple(int)): Patch size. Default: 4 in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 embed_dim (int): Patch embedding dimension. Default: 96 depths (tuple(int)): Depth of each Swin Transformer layer. num_heads (tuple(int)): Number of attention heads in different layers. window_size (int): Window size. Default: 7 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
Default: 4 qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None drop_rate (float): Dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False patch_norm (bool): If True, add normalization after patch embedding. Default: True use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False """ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, final_upsample="expand_first", **kwargs): super().__init__() print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths, depths_decoder, drop_path_rate, num_classes)) self.num_classes = num_classes self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) self.num_features_up = int(embed_dim * 2) self.mlp_ratio = mlp_ratio self.final_upsample = final_upsample # split image into non-overlapping patches self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution # absolute position embedding if self.ape: self.absolute_pos_embed = nn.Parameter( 
torch.zeros(1, num_patches, embed_dim)) trunc_normal_(self.absolute_pos_embed, std=.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build encoder and bottleneck layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), input_resolution=(patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)), depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum( depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if ( i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint) self.layers.append(layer) # build decoder layers self.layers_up = nn.ModuleList() self.concat_back_dim = nn.ModuleList() for i_layer in range(self.num_layers): concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)), int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity() if i_layer == 0: layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer) else: layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), depth=depths[( self.num_layers-1-i_layer)], num_heads=num_heads[( self.num_layers-1-i_layer)], window_size=window_size, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:( 
self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], norm_layer=norm_layer, upsample=PatchExpand if ( i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint) self.layers_up.append(layer_up) self.concat_back_dim.append(concat_linear) self.norm = norm_layer(self.num_features) self.norm_up = norm_layer(self.embed_dim) if self.final_upsample == "expand_first": print("---final upsample expand_first---") self.up = FinalPatchExpand_X4(input_resolution=( img_size//patch_size, img_size//patch_size), dim_scale=4, dim=embed_dim) self.output = nn.Conv2d( in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) @torch.jit.ignore def no_weight_decay(self): return {'absolute_pos_embed'} @torch.jit.ignore def no_weight_decay_keywords(self): return {'relative_position_bias_table'} #Encoder and Bottleneck def forward_features(self, x): x = self.patch_embed(x) if self.ape: x = x + self.absolute_pos_embed x = self.pos_drop(x) x_downsample = [] for layer in self.layers: x_downsample.append(x) x = layer(x) x = self.norm(x) # B L C return x, x_downsample # Decoder and skip connections def forward_up_features(self, x, x_downsample): for inx, layer_up in enumerate(self.layers_up): if inx == 0: x = layer_up(x) else: x = torch.cat([x, x_downsample[self.num_layers-1-inx]], -1) x = self.concat_back_dim[inx](x) x = layer_up(x) x = self.norm_up(x) # B L C return x def up_x4(self, x): H, W = self.patches_resolution B, L, C = x.shape assert L == H*W, "input feature has wrong size" if self.final_upsample == "expand_first": x = self.up(x) x = x.view(B, 4*H, 4*W, -1) x = x.permute(0, 3, 1, 2).contiguous() # B,C,H,W x = self.output(x) return x def
forward(self, x): x, x_downsample = self.forward_features(x) x = self.forward_up_features(x, x_downsample) x = self.up_x4(x) return x def flops(self): flops = 0 flops += self.patch_embed.flops() for i, layer in enumerate(self.layers): flops += layer.flops() flops += self.num_features * \ self.patches_resolution[0] * \ self.patches_resolution[1] // (2 ** self.num_layers) flops += self.num_features * self.num_classes return flops class SwinUnet(nn.Module): def __init__(self, num_classes, img_size, zero_head=False, vis=False): super(SwinUnet, self).__init__() self.num_classes = num_classes self.zero_head = zero_head self.swin_unet = SwinTransformerSys(img_size=img_size, patch_size=4, num_classes=num_classes, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4, qkv_bias=True, qk_scale=None, drop_rate=0.0, drop_path_rate=0.1, ape=False, patch_norm=True, use_checkpoint=False) def forward(self, x): if x.size()[1] == 1: x = x.repeat(1,3,1,1) logits = self.swin_unet(x) return logits def load_from(self, config): pretrained_path = config.MODEL.PRETRAIN_CKPT if pretrained_path is not None: print("pretrained_path:{}".format(pretrained_path)) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') pretrained_dict = torch.load(pretrained_path, map_location=device) if "model" not in pretrained_dict: print("---start load pretrained model by splitting---") pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()} for k in list(pretrained_dict.keys()): if "output" in k: print("delete key:{}".format(k)) del pretrained_dict[k] msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False) # print(msg) return pretrained_dict = pretrained_dict['model'] print("---start load pretrained model of swin encoder---") model_dict = self.swin_unet.state_dict() full_dict = copy.deepcopy(pretrained_dict) for k, v in pretrained_dict.items(): if "layers." in k: current_layer_num = 3-int(k[7:8]) current_k = "layers_up."
+ str(current_layer_num) + k[8:] full_dict.update({current_k:v}) for k in list(full_dict.keys()): if k in model_dict: if full_dict[k].shape != model_dict[k].shape: print("delete:{};shape pretrain:{};shape model:{}".format(k,full_dict[k].shape,model_dict[k].shape)) del full_dict[k] msg = self.swin_unet.load_state_dict(full_dict, strict=False) # print(msg) else: print("no pretrained checkpoint given") def swinunet(num_classes, img_size): model = SwinUnet(num_classes, img_size=img_size) return model if __name__ == '__main__': model = swinunet(10, 224) model.eval() input = torch.rand(2,1,224,224) output = model(input) output = output.data.cpu().numpy() # print(output) print(output.shape) ================================================ FILE: models/networks_2d/u2net.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import init def init_weights(net, init_type='normal', gain=0.02): def init_func(m): classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': init.normal_(m.weight.data, 0.0, gain) elif init_type == 'xavier': init.xavier_normal_(m.weight.data, gain=gain) elif init_type == 'kaiming': init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif init_type == 'orthogonal': init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) elif classname.find('BatchNorm2d') != -1: init.normal_(m.weight.data, 1.0, gain) init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) class REBNCONV(nn.Module): def __init__(self,in_ch=3,out_ch=3,dirate=1): super(REBNCONV,self).__init__() self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate) self.bn_s1 = nn.BatchNorm2d(out_ch) self.relu_s1 = 
nn.ReLU(inplace=True)
def forward(self,x): hx = x xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) return xout ## upsample tensor 'src' to have the same spatial size with tensor 'tar' def _upsample_like(src,tar): src = F.interpolate(src,size=tar.shape[2:],mode='bilinear', align_corners=True) return src ### RSU-7 ### class RSU7(nn.Module):#UNet07DRES(nn.Module): def __init__(self, in_ch=3, mid_ch=12, out_ch=3): super(RSU7,self).__init__() self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True) self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1) self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True) self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1) self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2) self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) def forward(self,x): hx = x hxin = self.rebnconvin(hx) hx1 = self.rebnconv1(hxin) hx = self.pool1(hx1) hx2 = self.rebnconv2(hx) hx = self.pool2(hx2) hx3 = self.rebnconv3(hx) hx = self.pool3(hx3) hx4 = self.rebnconv4(hx) hx = self.pool4(hx4) hx5 = self.rebnconv5(hx) hx = self.pool5(hx5) hx6 = self.rebnconv6(hx) hx7 = self.rebnconv7(hx6) hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1)) hx6dup = _upsample_like(hx6d,hx5) hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1)) hx5dup = _upsample_like(hx5d,hx4) hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) hx4dup = _upsample_like(hx4d,hx3) hx3d = 
self.rebnconv3d(torch.cat((hx4dup,hx3),1)) hx3dup = _upsample_like(hx3d,hx2) hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) hx2dup = _upsample_like(hx2d,hx1) hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) return hx1d + hxin ### RSU-6 ### class RSU6(nn.Module):#UNet06DRES(nn.Module): def __init__(self, in_ch=3, mid_ch=12, out_ch=3): super(RSU6,self).__init__() self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True) self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1) self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2) self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) def forward(self,x): hx = x hxin = self.rebnconvin(hx) hx1 = self.rebnconv1(hxin) hx = self.pool1(hx1) hx2 = self.rebnconv2(hx) hx = self.pool2(hx2) hx3 = self.rebnconv3(hx) hx = self.pool3(hx3) hx4 = self.rebnconv4(hx) hx = self.pool4(hx4) hx5 = self.rebnconv5(hx) hx6 = self.rebnconv6(hx5) hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1)) hx5dup = _upsample_like(hx5d,hx4) hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) hx4dup = _upsample_like(hx4d,hx3) hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) hx3dup = _upsample_like(hx3d,hx2) hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) hx2dup = _upsample_like(hx2d,hx1) hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) return hx1d + hxin ### RSU-5 ### class RSU5(nn.Module):#UNet05DRES(nn.Module): def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 
super(RSU5,self).__init__() self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2) self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) def forward(self,x): hx = x hxin = self.rebnconvin(hx) hx1 = self.rebnconv1(hxin) hx = self.pool1(hx1) hx2 = self.rebnconv2(hx) hx = self.pool2(hx2) hx3 = self.rebnconv3(hx) hx = self.pool3(hx3) hx4 = self.rebnconv4(hx) hx5 = self.rebnconv5(hx4) hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1)) hx4dup = _upsample_like(hx4d,hx3) hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) hx3dup = _upsample_like(hx3d,hx2) hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) hx2dup = _upsample_like(hx2d,hx1) hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) return hx1d + hxin ### RSU-4 ### class RSU4(nn.Module):#UNet04DRES(nn.Module): def __init__(self, in_ch=3, mid_ch=12, out_ch=3): super(RSU4,self).__init__() self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2) self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) def forward(self,x): hx = x hxin = self.rebnconvin(hx) hx1 = self.rebnconv1(hxin) 
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin


### RSU-4F ###
class RSU4F(nn.Module):#UNet04FRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
        hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))

        return hx1d + hxin


##### U^2-Net ####
class U2NET(nn.Module):

    def __init__(self,in_ch=3,out_ch=1):
        super(U2NET,self).__init__()

        self.stage1 = RSU7(in_ch,32,64)
        self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage2 = RSU6(64,32,128)
        self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage3 = RSU5(128,64,256)
        self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage4 = RSU4(256,128,512)
        self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage5 = RSU4F(512,256,512)
        self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage6 = RSU4F(512,256,512)

        # decoder
        self.stage5d = RSU4F(1024,256,512)
        self.stage4d = RSU4(1024,128,256)
        self.stage3d = RSU5(512,64,128)
        self.stage2d = RSU6(256,32,64)
        self.stage1d = RSU7(128,16,64)

        self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
        self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
        self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
        self.side6 = nn.Conv2d(512,out_ch,3,padding=1)

        self.outconv = nn.Conv2d(6*out_ch,out_ch,1)

    def forward(self,x):
        hx = x

        #stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        #stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        #stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        #stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        #stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        #stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6,hx5)

        #-------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))

        #side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2,d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3,d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4,d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5,d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6,d1)

        d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))

        return d0, d1, d2, d3, d4, d5, d6


### U^2-Net small ###
class U2NETP(nn.Module):

    def __init__(self,in_ch=3,out_ch=1):
        super(U2NETP,self).__init__()

        self.stage1 = RSU7(in_ch,16,64)
        self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage2 = RSU6(64,16,64)
        self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage3 = RSU5(64,16,64)
        self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage4 = RSU4(64,16,64)
        self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage5 = RSU4F(64,16,64)
        self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage6 = RSU4F(64,16,64)

        # decoder
        self.stage5d = RSU4F(128,16,64)
        self.stage4d = RSU4(128,16,64)
        self.stage3d = RSU5(128,16,64)
        self.stage2d = RSU6(128,16,64)
        self.stage1d = RSU7(128,16,64)

        self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side3 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side4 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side5 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side6 = nn.Conv2d(64,out_ch,3,padding=1)

        self.outconv = nn.Conv2d(6*out_ch,out_ch,1)

    def forward(self,x):
        hx = x

        #stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        #stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        #stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        #stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        #stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        #stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6,hx5)

        #decoder
        hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))

        #side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2,d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3,d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4,d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5,d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6,d1)

        d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))

        return d0, d1, d2, d3, d4, d5, d6


def u2net(in_channels, num_classes):
    model = U2NET(in_channels, num_classes)
    init_weights(model, 'kaiming')
    return model


def u2net_small(in_channels, num_classes):
    model = U2NETP(in_channels, num_classes)
    init_weights(model, 'kaiming')
    return model


# if __name__ == '__main__':
#     model = u2net(1,10)
#     model.eval()
#     input = torch.rand(2,1,128,128)
#     output = model(input)
#     output = output[1].data.cpu().numpy()
#     # print(output)
#     print(output.shape)

================================================
FILE: models/networks_2d/unet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


class conv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class up_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x


class Recurrent_block(nn.Module):
    def __init__(self, ch_out, t=2):
        super(Recurrent_block, self).__init__()
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        for i in range(self.t):

            if i == 0:
                x1 = self.conv(x)

            x1 = self.conv(x + x1)
        return x1


class RRCNN_block(nn.Module):
    def __init__(self, ch_in, ch_out, t=2):
        super(RRCNN_block, self).__init__()
        self.RCNN = nn.Sequential(
            Recurrent_block(ch_out, t=t),
            Recurrent_block(ch_out, t=t)
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.Conv_1x1(x)
        x1 = self.RCNN(x)
        return x + x1


class single_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(single_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class Attention_block(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)

        return x * psi


class U_Net(nn.Module):
    def __init__(self, in_channels=3, num_classes=1):
        super(U_Net, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(ch_in=in_channels, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)

        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)
        # outputs = []
        # outputs.append(d1)
        # return outputs
        return d1


class R2U_Net(nn.Module):
    def __init__(self, in_channels=3, num_classes=1, t=2):
        super(R2U_Net, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Upsample = nn.Upsample(scale_factor=2)

        self.RRCNN1 = RRCNN_block(ch_in=in_channels, ch_out=64, t=t)

        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)

        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)

        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)

        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)

        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)

        self.Conv_1x1 = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoding path
        x1 = self.RRCNN1(x)

        x2 = self.Maxpool(x1)
        x2 = self.RRCNN2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.RRCNN3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.RRCNN4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.RRCNN5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_RRCNN5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_RRCNN4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_RRCNN3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_RRCNN2(d2)

        d1 = self.Conv_1x1(d2)
        # outputs = []
        # outputs.append(d1)
        # return outputs
        return d1


class AttU_Net(nn.Module):
    def __init__(self, in_channels=3, num_classes=1):
        super(AttU_Net, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(ch_in=in_channels, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5, x=x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)
        # outputs = []
        # outputs.append(d1)
        # return outputs
        return d1


class R2AttU_Net(nn.Module):
    def __init__(self, in_channels=3, num_classes=1, t=2):
        super(R2AttU_Net, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Upsample = nn.Upsample(scale_factor=2)

        self.RRCNN1 = RRCNN_block(ch_in=in_channels, ch_out=64, t=t)

        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)

        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)

        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)

        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)

        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)

        self.Conv_1x1 = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoding path
        x1 = self.RRCNN1(x)

        x2 = self.Maxpool(x1)
        x2 = self.RRCNN2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.RRCNN3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.RRCNN4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.RRCNN5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5, x=x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_RRCNN5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_RRCNN4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_RRCNN3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_RRCNN2(d2)

        d1 = self.Conv_1x1(d2)
        # outputs = []
        # outputs.append(d1)
        # return outputs
        return d1


def unet(in_channels, num_classes):
    model = U_Net(in_channels, num_classes)
    init_weights(model, 'kaiming')
    return model


def r2_unet(in_channels, num_classes):
    model = R2U_Net(in_channels, num_classes)
    init_weights(model, 'kaiming')
    return model


def attention_unet(in_channels, num_classes):
    model = AttU_Net(in_channels, num_classes)
    init_weights(model, 'kaiming')
    return model


def r2_attention_unet(in_channels, num_classes):
    model = R2AttU_Net(in_channels, num_classes)
    init_weights(model, 'kaiming')
    return model


# if __name__ == '__main__':
#     model = U_Net(1,10)
#     model.eval()
#     input = torch.rand(2,1,128,128)
#     output = model(input)
#     output = output[0].data.cpu().numpy()
#     # print(output)
#     print(output.shape)

================================================
FILE: models/networks_2d/unet_3plus.py
================================================
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import numpy as np


def weights_init_normal(m):
    classname = m.__class__.__name__
    #print(classname)
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_xavier(m):
    classname = m.__class__.__name__
    #print(classname)
    if classname.find('Conv') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    #print(classname)
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_orthogonal(m):
    classname = m.__class__.__name__
    #print(classname)
    if classname.find('Conv') != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def init_weights(net, init_type='normal'):
    #print('initialization method [%s]' % init_type)
    if init_type == 'normal':
        net.apply(weights_init_normal)
    elif init_type == 'xavier':
        net.apply(weights_init_xavier)
    elif init_type == 'kaiming':
        net.apply(weights_init_kaiming)
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)


class unetConv2(nn.Module):
    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
        super(unetConv2, self).__init__()
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding
        s = stride
        p = padding
        if is_batchnorm:
            for i in range(1, n + 1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     nn.BatchNorm2d(out_size),
                                     nn.ReLU(inplace=True), )
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size

        else:
            for i in range(1, n + 1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     nn.ReLU(inplace=True), )
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        x = inputs
        for i in range(1, self.n + 1):
            conv = getattr(self, 'conv%d' % i)
            x = conv(x)

        return x


'''
    UNet 3+
'''
class UNet_3Plus(nn.Module):

    def __init__(self, in_channels, num_classes):
        super(UNet_3Plus, self).__init__()

        feature_scale = 4
        is_deconv = True
        is_batchnorm = True

        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [16, 32, 64, 128, 256]

        ## -------------Encoder--------------
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)

        ## -------------Decoder--------------
        self.CatChannels = filters[0]
        self.CatBlocks = 5
        self.UpChannels = self.CatChannels * self.CatBlocks

        '''stage 4d'''
        # h1->320*320, hd4->40*40, Pooling 8 times
        self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)
        self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_PT_hd4_relu = nn.ReLU(inplace=True)

        # h2->160*160, hd4->40*40, Pooling 4 times
        self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True)
        self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
        self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h2_PT_hd4_relu = nn.ReLU(inplace=True)

        # h3->80*80, hd4->40*40, Pooling 2 times
        self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
        self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h3_PT_hd4_relu = nn.ReLU(inplace=True)

        # h4->40*40, hd4->40*40, Concatenation
        self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1)
        self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd4->40*40, Upsample 2 times
        self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
        self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
        self.bn4d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu4d_1 = nn.ReLU(inplace=True)

        '''stage 3d'''
        # h1->320*320, hd3->80*80, Pooling 4 times
        self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)
        self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_PT_hd3_relu = nn.ReLU(inplace=True)

        # h2->160*160, hd3->80*80, Pooling 2 times
        self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
        self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.h2_PT_hd3_relu = nn.ReLU(inplace=True)

        # h3->80*80, hd3->80*80, Concatenation
        self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
        self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)

        # hd4->40*40, hd3->80*80, Upsample 2 times
        self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd3->80*80, Upsample 4 times
        self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
        self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
        self.bn3d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu3d_1 = nn.ReLU(inplace=True)

        '''stage 2d '''
        # h1->320*320, hd2->160*160, Pooling 2 times
        self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_PT_hd2_relu = nn.ReLU(inplace=True)

        # h2->160*160, hd2->160*160, Concatenation
        self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
        self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)

        # hd3->80*80, hd2->160*160, Upsample 2 times
        self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)

        # hd4->40*40, hd2->160*160, Upsample 4 times
        self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd2->160*160, Upsample 8 times
        self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
        self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
        self.bn2d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu2d_1 = nn.ReLU(inplace=True)

        '''stage 1d'''
        # h1->320*320, hd1->320*320, Concatenation
        self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)

        # hd2->160*160, hd1->320*320, Upsample 2 times
        self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)

        # hd3->80*80, hd1->320*320, Upsample 4 times
        self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)

        # hd4->40*40, hd1->320*320, Upsample 8 times
        self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd1->320*320, Upsample 16 times
        self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
        self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)

        # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
        self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
        self.bn1d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu1d_1 = nn.ReLU(inplace=True)

        # output
        self.outconv1 = nn.Conv2d(self.UpChannels, num_classes, 3, padding=1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm2d):
                init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        ## -------------Encoder-------------
        h1 = self.conv1(inputs)  # h1->320*320*16

        h2 = self.maxpool1(h1)
        h2 = self.conv2(h2)  # h2->160*160*32

        h3 = self.maxpool2(h2)
        h3 = self.conv3(h3)  # h3->80*80*64

        h4 = self.maxpool3(h3)
        h4 = self.conv4(h4)  # h4->40*40*128

        h5 = self.maxpool4(h4)
        hd5 = self.conv5(h5)  # hd5->20*20*256

        ## -------------Decoder-------------
        h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))
        h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))
        h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))
        h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))
        hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))
        hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(
            torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1))))  # hd4->40*40*UpChannels

        h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))
        h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))
        h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))
        hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))
        hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))
        hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(
            torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1))))  # hd3->80*80*UpChannels

        h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))
        h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))
        hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))
        hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))
        hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))
        hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(
            torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1))))  # hd2->160*160*UpChannels

        h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))
        hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))
        hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))
        hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))
        hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))
        hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(
            torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1))))  # hd1->320*320*UpChannels

        d1 = self.outconv1(hd1)  # d1->320*320*num_classes
        return d1


'''
    UNet 3+ with deep supervision
'''
class UNet_3Plus_DeepSup(nn.Module):
    def __init__(self, in_channels=3, num_classes=1, feature_scale=4, is_deconv=True, is_batchnorm=True):
        super(UNet_3Plus_DeepSup, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [32, 64, 128, 256, 512]

        ## -------------Encoder--------------
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)

        ## -------------Decoder--------------
        self.CatChannels = filters[0]
        self.CatBlocks = 5
        self.UpChannels = self.CatChannels * self.CatBlocks

        '''stage 4d'''
        # h1->320*320, hd4->40*40, Pooling 8 times
        self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)
        self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_PT_hd4_relu = nn.ReLU(inplace=True)

        # h2->160*160, hd4->40*40, Pooling 4 times
        self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True)
        self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
        self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h2_PT_hd4_relu = nn.ReLU(inplace=True)

        # h3->80*80, hd4->40*40, Pooling 2 times
        self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
        self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h3_PT_hd4_relu = nn.ReLU(inplace=True)

        # h4->40*40, hd4->40*40, Concatenation
        self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1)
        self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd4->40*40, Upsample 2 times
        self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
        self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
        self.bn4d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu4d_1 = nn.ReLU(inplace=True)

        '''stage 3d'''
        # h1->320*320, hd3->80*80, Pooling 4 times
        self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)
        self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_PT_hd3_relu = nn.ReLU(inplace=True)

        # h2->160*160, hd3->80*80, Pooling 2 times
        self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3,
                                        padding=1)
        self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.h2_PT_hd3_relu = nn.ReLU(inplace=True)

        # h3->80*80, hd3->80*80, Concatenation
        self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
        self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)

        # hd4->40*40, hd3->80*80, Upsample 2 times
        self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd3->80*80, Upsample 4 times
        self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
        self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
        self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)
        self.bn3d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu3d_1 = nn.ReLU(inplace=True)

        '''stage 2d '''
        # h1->320*320, hd2->160*160, Pooling 2 times
        self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
        self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_PT_hd2_relu = nn.ReLU(inplace=True)

        # h2->160*160, hd2->160*160, Concatenation
        self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
        self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)

        # hd3->80*80, hd2->160*160, Upsample 2 times
        self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
        self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)
# hd4->40*40, hd2->160*160, Upsample 4 times self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) # 14*14 self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels) self.hd4_UT_hd2_relu = nn.ReLU(inplace=True) # hd5->20*20, hd2->160*160, Upsample 8 times self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) # 14*14 self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels) self.hd5_UT_hd2_relu = nn.ReLU(inplace=True) # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2) self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 self.bn2d_1 = nn.BatchNorm2d(self.UpChannels) self.relu2d_1 = nn.ReLU(inplace=True) '''stage 1d''' # h1->320*320, hd1->320*320, Concatenation self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels) self.h1_Cat_hd1_relu = nn.ReLU(inplace=True) # hd2->160*160, hd1->320*320, Upsample 2 times self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # 14*14 self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) self.hd2_UT_hd1_relu = nn.ReLU(inplace=True) # hd3->80*80, hd1->320*320, Upsample 4 times self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) # 14*14 self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) self.hd3_UT_hd1_relu = nn.ReLU(inplace=True) # hd4->40*40, hd1->320*320, Upsample 8 times self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) # 14*14 self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) self.hd4_UT_hd1_bn = 
nn.BatchNorm2d(self.CatChannels) self.hd4_UT_hd1_relu = nn.ReLU(inplace=True) # hd5->20*20, hd1->320*320, Upsample 16 times self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True) # 14*14 self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) self.hd5_UT_hd1_relu = nn.ReLU(inplace=True) # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1) self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 self.bn1d_1 = nn.BatchNorm2d(self.UpChannels) self.relu1d_1 = nn.ReLU(inplace=True) # -------------Bilinear Upsampling-------------- self.upscore6 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True) ### self.upscore5 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True) self.upscore4 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) self.upscore3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # DeepSup self.outconv1 = nn.Conv2d(self.UpChannels, num_classes, 3, padding=1) self.outconv2 = nn.Conv2d(self.UpChannels, num_classes, 3, padding=1) self.outconv3 = nn.Conv2d(self.UpChannels, num_classes, 3, padding=1) self.outconv4 = nn.Conv2d(self.UpChannels, num_classes, 3, padding=1) self.outconv5 = nn.Conv2d(filters[4], num_classes, 3, padding=1) # initialise weights for m in self.modules(): if isinstance(m, nn.Conv2d): init_weights(m, init_type='kaiming') elif isinstance(m, nn.BatchNorm2d): init_weights(m, init_type='kaiming') def forward(self, inputs): ## -------------Encoder------------- h1 = self.conv1(inputs) # h1->320*320*64 h2 = self.maxpool1(h1) h2 = self.conv2(h2) # h2->160*160*128 h3 = self.maxpool2(h2) h3 = self.conv3(h3) # h3->80*80*256 h4 = self.maxpool3(h3) h4 = self.conv4(h4) # h4->40*40*512 h5 = self.maxpool4(h4) hd5 = self.conv5(h5) # h5->20*20*1024 ## 
-------------Decoder------------- h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1)))) h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2)))) h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3)))) h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4))) hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5)))) hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1( torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1)))) # hd4->40*40*UpChannels h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1)))) h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2)))) h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3))) hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4)))) hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5)))) hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1( torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1)))) # hd3->80*80*UpChannels h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1)))) h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2))) hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3)))) hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4)))) hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5)))) hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1( torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1)))) # hd2->160*160*UpChannels h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1))) hd2_UT_hd1 = 
self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2)))) hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3)))) hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4)))) hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5)))) hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1( torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1)))) # hd1->320*320*UpChannels d5 = self.outconv5(hd5) d5 = self.upscore5(d5) # 16->256 d4 = self.outconv4(hd4) d4 = self.upscore4(d4) # 32->256 d3 = self.outconv3(hd3) d3 = self.upscore3(d3) # 64->256 d2 = self.outconv2(hd2) d2 = self.upscore2(d2) # 128->256 d1 = self.outconv1(hd1) # 256 return d1, d2, d3, d4, d5 ''' UNet 3+ with deep supervision and class-guided module ''' class UNet_3Plus_DeepSup_CGM(nn.Module): def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True, is_batchnorm=True): super(UNet_3Plus_DeepSup_CGM, self).__init__() self.is_deconv = is_deconv self.in_channels = in_channels self.is_batchnorm = is_batchnorm self.feature_scale = feature_scale filters = [64, 128, 256, 512, 1024] ## -------------Encoder-------------- self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm) self.maxpool1 = nn.MaxPool2d(kernel_size=2) self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) self.maxpool2 = nn.MaxPool2d(kernel_size=2) self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) self.maxpool3 = nn.MaxPool2d(kernel_size=2) self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) self.maxpool4 = nn.MaxPool2d(kernel_size=2) self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm) ## -------------Decoder-------------- self.CatChannels = filters[0] self.CatBlocks = 5 self.UpChannels = self.CatChannels * self.CatBlocks '''stage 4d''' # h1->320*320, hd4->40*40, Pooling 8 times self.h1_PT_hd4 = 
nn.MaxPool2d(8, 8, ceil_mode=True) self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels) self.h1_PT_hd4_relu = nn.ReLU(inplace=True) # h2->160*160, hd4->40*40, Pooling 4 times self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True) self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1) self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels) self.h2_PT_hd4_relu = nn.ReLU(inplace=True) # h3->80*80, hd4->40*40, Pooling 2 times self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True) self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1) self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels) self.h3_PT_hd4_relu = nn.ReLU(inplace=True) # h4->40*40, hd4->40*40, Concatenation self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1) self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels) self.h4_Cat_hd4_relu = nn.ReLU(inplace=True) # hd5->20*20, hd4->40*40, Upsample 2 times self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # 14*14 self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels) self.hd5_UT_hd4_relu = nn.ReLU(inplace=True) # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4) self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 self.bn4d_1 = nn.BatchNorm2d(self.UpChannels) self.relu4d_1 = nn.ReLU(inplace=True) '''stage 3d''' # h1->320*320, hd3->80*80, Pooling 4 times self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True) self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels) self.h1_PT_hd3_relu = nn.ReLU(inplace=True) # h2->160*160, hd3->80*80, Pooling 2 times self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True) self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1) self.h2_PT_hd3_bn = 
nn.BatchNorm2d(self.CatChannels) self.h2_PT_hd3_relu = nn.ReLU(inplace=True) # h3->80*80, hd3->80*80, Concatenation self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1) self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels) self.h3_Cat_hd3_relu = nn.ReLU(inplace=True) # hd4->40*40, hd3->80*80, Upsample 2 times self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # 14*14 self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels) self.hd4_UT_hd3_relu = nn.ReLU(inplace=True) # hd5->20*20, hd3->80*80, Upsample 4 times self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) # 14*14 self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels) self.hd5_UT_hd3_relu = nn.ReLU(inplace=True) # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3) self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 self.bn3d_1 = nn.BatchNorm2d(self.UpChannels) self.relu3d_1 = nn.ReLU(inplace=True) '''stage 2d''' # h1->320*320, hd2->160*160, Pooling 2 times self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True) self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels) self.h1_PT_hd2_relu = nn.ReLU(inplace=True) # h2->160*160, hd2->160*160, Concatenation self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1) self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels) self.h2_Cat_hd2_relu = nn.ReLU(inplace=True) # hd3->80*80, hd2->160*160, Upsample 2 times self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # 14*14 self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels) self.hd3_UT_hd2_relu = nn.ReLU(inplace=True) # hd4->40*40, hd2->160*160,
Upsample 4 times self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) # 14*14 self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels) self.hd4_UT_hd2_relu = nn.ReLU(inplace=True) # hd5->20*20, hd2->160*160, Upsample 8 times self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) # 14*14 self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels) self.hd5_UT_hd2_relu = nn.ReLU(inplace=True) # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2) self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 self.bn2d_1 = nn.BatchNorm2d(self.UpChannels) self.relu2d_1 = nn.ReLU(inplace=True) '''stage 1d''' # h1->320*320, hd1->320*320, Concatenation self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1) self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels) self.h1_Cat_hd1_relu = nn.ReLU(inplace=True) # hd2->160*160, hd1->320*320, Upsample 2 times self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # 14*14 self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) self.hd2_UT_hd1_relu = nn.ReLU(inplace=True) # hd3->80*80, hd1->320*320, Upsample 4 times self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) # 14*14 self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) self.hd3_UT_hd1_relu = nn.ReLU(inplace=True) # hd4->40*40, hd1->320*320, Upsample 8 times self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) # 14*14 self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1) self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) self.hd4_UT_hd1_relu = 
nn.ReLU(inplace=True) # hd5->20*20, hd1->320*320, Upsample 16 times self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True) # 14*14 self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1) self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels) self.hd5_UT_hd1_relu = nn.ReLU(inplace=True) # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1) self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16 self.bn1d_1 = nn.BatchNorm2d(self.UpChannels) self.relu1d_1 = nn.ReLU(inplace=True) # -------------Bilinear Upsampling-------------- self.upscore6 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True) ### self.upscore5 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True) self.upscore4 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) self.upscore3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # DeepSup self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1) self.outconv2 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1) self.outconv3 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1) self.outconv4 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1) self.outconv5 = nn.Conv2d(filters[4], n_classes, 3, padding=1) self.cls = nn.Sequential( nn.Dropout(p=0.5), nn.Conv2d(filters[4], 2, 1), nn.AdaptiveMaxPool2d(1), nn.Sigmoid()) # initialise weights for m in self.modules(): if isinstance(m, nn.Conv2d): init_weights(m, init_type='kaiming') elif isinstance(m, nn.BatchNorm2d): init_weights(m, init_type='kaiming') def dotProduct(self, seg, cls): B, N, H, W = seg.size() seg = seg.view(B, N, H * W) final = torch.einsum("ijk,ij->ijk", [seg, cls]) final = final.view(B, N, H, W) return final def forward(self, inputs): ## -------------Encoder------------- h1 = self.conv1(inputs) # h1->320*320*64 h2 = self.maxpool1(h1) h2 = self.conv2(h2) # 
h2->160*160*128 h3 = self.maxpool2(h2) h3 = self.conv3(h3) # h3->80*80*256 h4 = self.maxpool3(h3) h4 = self.conv4(h4) # h4->40*40*512 h5 = self.maxpool4(h4) hd5 = self.conv5(h5) # h5->20*20*1024 # -------------Classification------------- cls_branch = self.cls(hd5).squeeze(3).squeeze(2) # (B,N,1,1)->(B,N) cls_branch_max = cls_branch.argmax(dim=1) cls_branch_max = cls_branch_max[:, np.newaxis].float() ## -------------Decoder------------- h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1)))) h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2)))) h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3)))) h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4))) hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5)))) hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1)))) # hd4->40*40*UpChannels h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1)))) h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2)))) h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3))) hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4)))) hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5)))) hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1)))) # hd3->80*80*UpChannels h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1)))) h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2))) hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3)))) hd4_UT_hd2 = 
self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4)))) hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5)))) hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1)))) # hd2->160*160*UpChannels h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1))) hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2)))) hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3)))) hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4)))) hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5)))) hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1)))) # hd1->320*320*UpChannels d5 = self.outconv5(hd5) d5 = self.upscore5(d5) # 16->256 d4 = self.outconv4(hd4) d4 = self.upscore4(d4) # 32->256 d3 = self.outconv3(hd3) d3 = self.upscore3(d3) # 64->256 d2 = self.outconv2(hd2) d2 = self.upscore2(d2) # 128->256 d1 = self.outconv1(hd1) # 256 d1 = self.dotProduct(d1, cls_branch_max) d2 = self.dotProduct(d2, cls_branch_max) d3 = self.dotProduct(d3, cls_branch_max) d4 = self.dotProduct(d4, cls_branch_max) d5 = self.dotProduct(d5, cls_branch_max) return d1, d2, d3, d4, d5, def unet_3plus(in_channels, num_classes): model = UNet_3Plus(in_channels, num_classes) return model def unet_3plus_ds(in_channels, num_classes): model = UNet_3Plus_DeepSup(in_channels, num_classes) return model def unet_3plus_ds_cgm(in_channels, num_classes): model = UNet_3Plus_DeepSup_CGM(in_channels, num_classes) return model if __name__ == '__main__': model = unet_3plus_ds_cgm(1,10) model.eval() input = torch.rand(2,1,128,128) output = model(input) output = output[0].data.cpu().numpy() # print(output) print(output.shape) 
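Every decoder stage of UNet_3Plus above follows one routing rule, which the per-stage comments spell out case by case. The following is a minimal pure-Python sketch, not part of the repository (`full_scale_plan`, `spatial_size` and `branch_output_size` are hypothetical helper names), of that rule for decoder stage hd<i>: shallower encoder maps h1..h(i-1) are max-pooled by 2^(i-j), h<i> is used as-is, and deeper maps hd(i+1)..hd5 are bilinearly upsampled by 2^(j-i). Each branch is projected to CatChannels = 64 channels, so the fused input of conv<i>d_1 has UpChannels = 5 * 64 = 320 channels. It assumes a square input divisible by 16, so the model's `ceil_mode=True` pooling behaves like floor division.

```python
from typing import List, Tuple

# Encoder widths and decoder constants, mirroring UNet_3Plus.__init__
FILTERS = [64, 128, 256, 512, 1024]
CAT_CHANNELS = FILTERS[0]                # each branch is projected to 64 channels
CAT_BLOCKS = 5                           # five branches feed every decoder stage
UP_CHANNELS = CAT_CHANNELS * CAT_BLOCKS  # 320 channels after concatenation


def full_scale_plan(stage: int) -> List[Tuple[str, str, int, int]]:
    """Hypothetical helper: (source, op, factor, in_channels) for each branch of hd<stage>."""
    if not 1 <= stage <= 4:
        raise ValueError("decoder stages are hd1..hd4; hd5 is the conv5 bottleneck")
    plan = []
    for j in range(1, 6):
        if j < stage:
            # shallower encoder map: max-pool down to the target resolution
            plan.append((f"h{j}", "maxpool", 2 ** (stage - j), FILTERS[j - 1]))
        elif j == stage:
            # same-scale encoder map: used directly ("Concatenation" in the comments)
            plan.append((f"h{j}", "concat", 1, FILTERS[j - 1]))
        else:
            # deeper decoder map: bilinear upsampling; hd5 keeps 1024 channels,
            # hd2..hd4 already carry UP_CHANNELS after their own fusion
            in_ch = FILTERS[4] if j == 5 else UP_CHANNELS
            plan.append((f"hd{j}", "upsample", 2 ** (j - stage), in_ch))
    return plan


def spatial_size(input_size: int, level: int) -> int:
    """Side length of h<level> (or hd<level>) for a square input."""
    return input_size // 2 ** (level - 1)


def branch_output_size(input_size: int, source: str, op: str, factor: int) -> int:
    """Side length of one branch after its pooling or upsampling step."""
    size = spatial_size(input_size, int(source.lstrip("hd")))
    return size // factor if op == "maxpool" else size * factor


if __name__ == "__main__":
    for src, op, factor, ch in full_scale_plan(3):
        out = branch_output_size(320, src, op, factor)
        print(f"{src:>4} {op:<8} x{factor:<2} {ch:>4} ch -> {out} px")
```

For example, `full_scale_plan(3)` gives pooling factors 4 and 2 for h1 and h2, a direct use of h3, and upsampling factors 2 and 4 for hd4 and hd5, matching the '''stage 3d''' block; with a 320*320 input every branch ends at 80*80 before concatenation.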
================================================ FILE: models/networks_2d/unet_cct.py ================================================ import numpy as np import torch import torch.nn as nn from torch.distributions.uniform import Uniform from torch.nn import init def init_weights(net, init_type='normal', gain=0.02): def init_func(m): classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': init.normal_(m.weight.data, 0.0, gain) elif init_type == 'xavier': init.xavier_normal_(m.weight.data, gain=gain) elif init_type == 'kaiming': init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif init_type == 'orthogonal': init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) elif classname.find('BatchNorm2d') != -1: init.normal_(m.weight.data, 1.0, gain) init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) class ConvBlock(nn.Module): """two convolution layers with batch norm and leaky relu""" def __init__(self, in_channels, out_channels, dropout_p): super(ConvBlock, self).__init__() self.conv_conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), nn.LeakyReLU(), nn.Dropout(dropout_p), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), nn.LeakyReLU() ) def forward(self, x): return self.conv_conv(x) class DownBlock(nn.Module): """Downsampling followed by ConvBlock""" def __init__(self, in_channels, out_channels, dropout_p): super(DownBlock, self).__init__() self.maxpool_conv = nn.Sequential( nn.MaxPool2d(2), ConvBlock(in_channels, out_channels, dropout_p) ) def forward(self, x): return self.maxpool_conv(x) class UpBlock(nn.Module): """Upsampling followed by ConvBlock"""
def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, bilinear=True): super(UpBlock, self).__init__() self.bilinear = bilinear if bilinear: self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1) self.up = nn.Upsample( scale_factor=2, mode='bilinear', align_corners=True) else: self.up = nn.ConvTranspose2d( in_channels1, in_channels2, kernel_size=2, stride=2) self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) def forward(self, x1, x2): if self.bilinear: x1 = self.conv1x1(x1) x1 = self.up(x1) x = torch.cat([x2, x1], dim=1) return self.conv(x) class Encoder(nn.Module): def __init__(self, params): super(Encoder, self).__init__() self.params = params self.in_chns = self.params['in_chns'] self.ft_chns = self.params['feature_chns'] self.n_class = self.params['class_num'] self.bilinear = self.params['bilinear'] self.dropout = self.params['dropout'] assert (len(self.ft_chns) == 5) self.in_conv = ConvBlock( self.in_chns, self.ft_chns[0], self.dropout[0]) self.down1 = DownBlock( self.ft_chns[0], self.ft_chns[1], self.dropout[1]) self.down2 = DownBlock( self.ft_chns[1], self.ft_chns[2], self.dropout[2]) self.down3 = DownBlock( self.ft_chns[2], self.ft_chns[3], self.dropout[3]) self.down4 = DownBlock( self.ft_chns[3], self.ft_chns[4], self.dropout[4]) def forward(self, x): x0 = self.in_conv(x) x1 = self.down1(x0) x2 = self.down2(x1) x3 = self.down3(x2) x4 = self.down4(x3) return [x0, x1, x2, x3, x4] class Decoder(nn.Module): def __init__(self, params): super(Decoder, self).__init__() self.params = params self.in_chns = self.params['in_chns'] self.ft_chns = self.params['feature_chns'] self.n_class = self.params['class_num'] self.bilinear = self.params['bilinear'] assert (len(self.ft_chns) == 5) self.up1 = UpBlock( self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0) self.up2 = UpBlock( self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0) self.up3 = UpBlock( self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 
dropout_p=0.0) self.up4 = UpBlock( self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0) self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size=3, padding=1) def forward(self, feature): x0 = feature[0] x1 = feature[1] x2 = feature[2] x3 = feature[3] x4 = feature[4] x = self.up1(x4, x3) x = self.up2(x, x2) x = self.up3(x, x1) x = self.up4(x, x0) output = self.out_conv(x) return output def Dropout(x, p=0.3): x = torch.nn.functional.dropout(x, p) return x def FeatureDropout(x): attention = torch.mean(x, dim=1, keepdim=True) max_val, _ = torch.max(attention.view( x.size(0), -1), dim=1, keepdim=True) threshold = max_val * np.random.uniform(0.7, 0.9) threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention) drop_mask = (attention < threshold).float() x = x.mul(drop_mask) return x class FeatureNoise(nn.Module): def __init__(self, uniform_range=0.3): super(FeatureNoise, self).__init__() self.uni_dist = Uniform(-uniform_range, uniform_range) def feature_based_noise(self, x): noise_vector = self.uni_dist.sample( x.shape[1:]).to(x.device).unsqueeze(0) x_noise = x.mul(noise_vector) + x return x_noise def forward(self, x): x = self.feature_based_noise(x) return x class UNet_CCT(nn.Module): def __init__(self, in_chns, class_num): super(UNet_CCT, self).__init__() params = {'in_chns': in_chns, 'feature_chns': [16, 32, 64, 128, 256], 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 'class_num': class_num, 'bilinear': False, 'acti_func': 'relu'} self.encoder = Encoder(params) self.main_decoder = Decoder(params) self.aux_decoder1 = Decoder(params) self.aux_decoder2 = Decoder(params) self.aux_decoder3 = Decoder(params) def forward(self, x): feature = self.encoder(x) main_seg = self.main_decoder(feature) aux1_feature = [FeatureNoise()(i) for i in feature] aux_seg1 = self.aux_decoder1(aux1_feature) aux2_feature = [Dropout(i) for i in feature] aux_seg2 = self.aux_decoder2(aux2_feature) aux3_feature = [FeatureDropout(i) for i in feature] aux_seg3 = 
self.aux_decoder3(aux3_feature) return main_seg, aux_seg1, aux_seg2, aux_seg3 def unet_cct(in_channels, num_classes): model = UNet_CCT(in_channels, num_classes) init_weights(model, 'kaiming') return model # if __name__ == '__main__': # model = unet_cct(1,10) # model.eval() # input = torch.rand(2,1,128,128) # output, output1, output2, output3 = model(input) # output = output.data.cpu().numpy() # # print(output) # print(output.shape) ================================================ FILE: models/networks_2d/unet_plusplus.py ================================================ import torch from torch import nn from torch.nn import init def init_weights(net, init_type='normal', gain=0.02): def init_func(m): classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': init.normal_(m.weight.data, 0.0, gain) elif init_type == 'xavier': init.xavier_normal_(m.weight.data, gain=gain) elif init_type == 'kaiming': init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif init_type == 'orthogonal': init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) elif classname.find('BatchNorm2d') != -1: init.normal_(m.weight.data, 1.0, gain) init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) class VGGBlock(nn.Module): def __init__(self, in_channels, middle_channels, out_channels): super().__init__() self.relu = nn.ReLU(inplace=True) self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1) self.bn1 = nn.BatchNorm2d(middle_channels) self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1) self.bn2 = nn.BatchNorm2d(out_channels) def forward(self, x): out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) 
return out class NestedUNet(nn.Module): def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs): super().__init__() nb_filter = [32, 64, 128, 256, 512] self.deep_supervision = deep_supervision self.pool = nn.MaxPool2d(2, 2) self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0]) self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1]) self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2]) self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3]) self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4]) self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0]) self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1]) self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2]) self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3]) self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0]) self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1]) self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2]) self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0]) self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1]) self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0]) if self.deep_supervision: self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) else: self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) def forward(self, input): x0_0 = self.conv0_0(input) x1_0 = self.conv1_0(self.pool(x0_0)) x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1)) x2_0 = 
self.conv2_0(self.pool(x1_0)) x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1)) x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1)) x3_0 = self.conv3_0(self.pool(x2_0)) x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1)) x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1)) x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1)) x4_0 = self.conv4_0(self.pool(x3_0)) x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1)) x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1)) x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1)) x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1)) outputs = [] if self.deep_supervision: output1 = self.final1(x0_1) output2 = self.final2(x0_2) output3 = self.final3(x0_3) output4 = self.final4(x0_4) outputs.append(output1) outputs.append(output2) outputs.append(output3) outputs.append(output4) return outputs else: output = self.final(x0_4) # outputs.append(output) # return outputs return output def unet_plusplus(in_channels, num_classes): model = NestedUNet(num_classes=num_classes, input_channels=in_channels) init_weights(model, 'kaiming') return model # if __name__ == '__main__': # model = unet_plusplus(3,10, True) # model.eval() # input = torch.rand(1,3,128,128) # output = model(input) # output = output.data.cpu().numpy() # print(output) # print(output.shape) ================================================ FILE: models/networks_2d/unet_urpc.py ================================================ from __future__ import division, print_function import numpy as np import torch import torch.nn as nn from torch.distributions.uniform import Uniform from torch.nn import init def init_weights(net, init_type='normal', gain=0.02): def init_func(m): classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': init.normal_(m.weight.data, 0.0, gain) elif init_type == 
'xavier': init.xavier_normal_(m.weight.data, gain=gain) elif init_type == 'kaiming': init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif init_type == 'orthogonal': init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) elif classname.find('BatchNorm2d') != -1: init.normal_(m.weight.data, 1.0, gain) init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) class ConvBlock(nn.Module): """two convolution layers with batch norm and leaky relu""" def __init__(self, in_channels, out_channels, dropout_p): super(ConvBlock, self).__init__() self.conv_conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), nn.LeakyReLU(), nn.Dropout(dropout_p), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), nn.LeakyReLU() ) def forward(self, x): return self.conv_conv(x) class Encoder(nn.Module): def __init__(self, params): super(Encoder, self).__init__() self.params = params self.in_chns = self.params['in_chns'] self.ft_chns = self.params['feature_chns'] self.n_class = self.params['class_num'] self.bilinear = self.params['bilinear'] self.dropout = self.params['dropout'] assert (len(self.ft_chns) == 5) self.in_conv = ConvBlock( self.in_chns, self.ft_chns[0], self.dropout[0]) self.down1 = DownBlock( self.ft_chns[0], self.ft_chns[1], self.dropout[1]) self.down2 = DownBlock( self.ft_chns[1], self.ft_chns[2], self.dropout[2]) self.down3 = DownBlock( self.ft_chns[2], self.ft_chns[3], self.dropout[3]) self.down4 = DownBlock( self.ft_chns[3], self.ft_chns[4], self.dropout[4]) def forward(self, x): x0 = self.in_conv(x) x1 = self.down1(x0) x2 = self.down2(x1) x3 = self.down3(x2) x4 = self.down4(x3) return [x0, x1, x2, x3, x4] class DownBlock(nn.Module): """Downsampling followed by 
ConvBlock""" def __init__(self, in_channels, out_channels, dropout_p): super(DownBlock, self).__init__() self.maxpool_conv = nn.Sequential( nn.MaxPool2d(2), ConvBlock(in_channels, out_channels, dropout_p) ) def forward(self, x): return self.maxpool_conv(x) class UpBlock(nn.Module): """Upsampling followed by ConvBlock""" def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, bilinear=True): super(UpBlock, self).__init__() self.bilinear = bilinear if bilinear: self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1) self.up = nn.Upsample( scale_factor=2, mode='bilinear', align_corners=True) else: self.up = nn.ConvTranspose2d( in_channels1, in_channels2, kernel_size=2, stride=2) self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) def forward(self, x1, x2): if self.bilinear: x1 = self.conv1x1(x1) x1 = self.up(x1) x = torch.cat([x2, x1], dim=1) return self.conv(x) class FeatureNoise(nn.Module): def __init__(self, uniform_range=0.3): super(FeatureNoise, self).__init__() self.uni_dist = Uniform(-uniform_range, uniform_range) def feature_based_noise(self, x): noise_vector = self.uni_dist.sample( x.shape[1:]).to(x.device).unsqueeze(0) x_noise = x.mul(noise_vector) + x return x_noise def forward(self, x): x = self.feature_based_noise(x) return x def Dropout(x, p=0.3): x = torch.nn.functional.dropout(x, p) return x def FeatureDropout(x): attention = torch.mean(x, dim=1, keepdim=True) max_val, _ = torch.max(attention.view( x.size(0), -1), dim=1, keepdim=True) threshold = max_val * np.random.uniform(0.7, 0.9) threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention) drop_mask = (attention < threshold).float() x = x.mul(drop_mask) return x class Decoder_URPC(nn.Module): def __init__(self, params): super(Decoder_URPC, self).__init__() self.params = params self.in_chns = self.params['in_chns'] self.ft_chns = self.params['feature_chns'] self.n_class = self.params['class_num'] self.bilinear = self.params['bilinear'] assert
(len(self.ft_chns) == 5) self.up1 = UpBlock( self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0) self.up2 = UpBlock( self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0) self.up3 = UpBlock( self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0) self.up4 = UpBlock( self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0) self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size=3, padding=1) # self.out_conv_dp4 = nn.Conv2d(self.ft_chns[4], self.n_class, # kernel_size=3, padding=1) self.out_conv_dp3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size=3, padding=1) self.out_conv_dp2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size=3, padding=1) self.out_conv_dp1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size=3, padding=1) # self.feature_noise = FeatureNoise() def forward(self, feature, shape): x0 = feature[0] x1 = feature[1] x2 = feature[2] x3 = feature[3] x4 = feature[4] x = self.up1(x4, x3) if self.training: # dp3_out_seg = self.out_conv_dp3(Dropout(x, p=0.5)) dp3_out_seg = self.out_conv_dp3(x) else: dp3_out_seg = self.out_conv_dp3(x) dp3_out_seg = torch.nn.functional.interpolate(dp3_out_seg, shape) x = self.up2(x, x2) if self.training: # dp2_out_seg = self.out_conv_dp2(FeatureDropout(x)) dp2_out_seg = self.out_conv_dp2(x) else: dp2_out_seg = self.out_conv_dp2(x) dp2_out_seg = torch.nn.functional.interpolate(dp2_out_seg, shape) x = self.up3(x, x1) if self.training: # dp1_out_seg = self.out_conv_dp1(self.feature_noise(x)) dp1_out_seg = self.out_conv_dp1(x) else: dp1_out_seg = self.out_conv_dp1(x) dp1_out_seg = torch.nn.functional.interpolate(dp1_out_seg, shape) x = self.up4(x, x0) dp0_out_seg = self.out_conv(x) return dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg class UNet_URPC(nn.Module): def __init__(self, in_chns, class_num): super(UNet_URPC, self).__init__() params = {'in_chns': in_chns, 'feature_chns': [16, 32, 64, 128, 256], 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 'class_num': 
class_num, 'bilinear': False, 'acti_func': 'relu'} self.encoder = Encoder(params) self.decoder = Decoder_URPC(params) def forward(self, x): shape = x.shape[2:] feature = self.encoder(x) dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg = self.decoder( feature, shape) return dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg def unet_urpc(in_channels, num_classes): model = UNet_URPC(in_channels, num_classes) init_weights(model, 'kaiming') return model # if __name__ == '__main__': # model = unet_urpc(1,10) # model.eval() # input = torch.rand(2,1,128,128) # output, output1, output2, output3 = model(input) # output = output1.data.cpu().numpy() # # print(output) # print(output.shape) ================================================ FILE: models/networks_2d/wavesnet.py ================================================ import numpy as np import math, pywt import torch import torch.nn as nn from torch.nn import Module from torch.autograd import Function from collections import OrderedDict from itertools import islice import operator class My_DownSampling_SC(nn.Module): def __init__(self, in_channel, out_channel, kernel_size = (1,1), stride = 2, padding = (0,0)): super(My_DownSampling_SC, self).__init__() self.conv = nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = stride, padding = padding) def forward(self, input): return self.conv(input), input class My_DownSampling_MP(nn.Module): def __init__(self, stride = 2, kernel_size = 2): super(My_DownSampling_MP, self).__init__() self.maxp = nn.MaxPool2d(kernel_size = kernel_size, stride = stride, return_indices = False) def forward(self, input): return self.maxp(input), input class My_UpSampling_SC(nn.Module): def __init__(self, in_channel, out_channel, kernel_size = (1,1), stride = 2, padding = (0,0)): super(My_UpSampling_SC, self).__init__() self.conv = nn.ConvTranspose2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = stride, padding
= padding) def forward(self, input, feature_map): return torch.cat((self.conv(input), feature_map), dim = 1) class My_DownSampling_DWT(nn.Module): def __init__(self, wavename = 'haar'): super(My_DownSampling_DWT, self).__init__() self.dwt = DWT_2D(wavename = wavename) def forward(self, input): LL, LH, HL, HH = self.dwt(input) return LL, LH, HL, HH, input class My_UpSampling_IDWT(nn.Module): def __init__(self, wavename = 'haar'): super(My_UpSampling_IDWT, self).__init__() self.idwt = IDWT_2D(wavename = wavename) def forward(self, LL, LH, HL, HH, feature_map): return torch.cat((self.idwt(LL, LH, HL, HH), feature_map), dim = 1) class My_Sequential(Module): r"""A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in. If a module returns multiple outputs, only the first one is passed on to the next module. """ def __init__(self, *args): super(My_Sequential, self).__init__() if len(args) == 1 and isinstance(args[0], OrderedDict): for key, module in args[0].items(): self.add_module(key, module) else: for idx, module in enumerate(args): self.add_module(str(idx), module) def _get_item_by_idx(self, iterator, idx): """Get the idx-th item of the iterator""" size = len(self) idx = operator.index(idx) if not -size <= idx < size: raise IndexError('index {} is out of range'.format(idx)) idx %= size return next(islice(iterator, idx, None)) def __getitem__(self, idx): if isinstance(idx, slice): return self.__class__(OrderedDict(list(self._modules.items())[idx])) else: return self._get_item_by_idx(self._modules.values(), idx) def __setitem__(self, idx, module): key = self._get_item_by_idx(self._modules.keys(), idx) return setattr(self, key, module) def __delitem__(self, idx): if isinstance(idx, slice): for key in list(self._modules.keys())[idx]: delattr(self, key) else: key = self._get_item_by_idx(self._modules.keys(), idx) delattr(self, key) def __len__(self): return 
len(self._modules) def __dir__(self): keys = super(My_Sequential,
self).__dir__() keys = [key for key in keys if not key.isdigit()] return keys def forward(self, input): self.output = [] for module in self._modules.values(): input = module(input) if isinstance(input, tuple): assert len(input) == 4 or len(input) == 2 or len(input) == 5 self.output.append(input[1:]) input = input[0] if self.output != []: return input, self.output else: return input class My_Sequential_re(Module): r"""A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in. If a module returns multiple outputs, only the first one is passed on to the next module. """ def __init__(self, *args): super(My_Sequential_re, self).__init__() if len(args) == 1 and isinstance(args[0], OrderedDict): for key, module in args[0].items(): self.add_module(key, module) else: for idx, module in enumerate(args): self.add_module(str(idx), module) self.output = [] def _get_item_by_idx(self, iterator, idx): """Get the idx-th item of the iterator""" size = len(self) idx = operator.index(idx) if not -size <= idx < size: raise IndexError('index {} is out of range'.format(idx)) idx %= size return next(islice(iterator, idx, None)) def __getitem__(self, idx): if isinstance(idx, slice): return self.__class__(OrderedDict(list(self._modules.items())[idx])) else: return self._get_item_by_idx(self._modules.values(), idx) def __setitem__(self, idx, module): key = self._get_item_by_idx(self._modules.keys(), idx) return setattr(self, key, module) def __delitem__(self, idx): if isinstance(idx, slice): for key in list(self._modules.keys())[idx]: delattr(self, key) else: key = self._get_item_by_idx(self._modules.keys(), idx) delattr(self, key) def __len__(self): return len(self._modules) def __dir__(self): keys = super(My_Sequential_re, self).__dir__() keys = [key for key in keys if not key.isdigit()] return keys def forward(self, *input): LL = input[0] index = 1 for module in self._modules.values(): if 
isinstance(module, My_UpSampling_IDWT): LH = input[index] HL
= input[index + 1] HH = input[index + 2] feature_map = input[index + 3] LL = module(LL, LH, HL, HH, feature_map = feature_map) index += 4 elif isinstance(module, IDWT_2D) or 'idwt' in dir(module): LH = input[index] HL = input[index + 1] HH = input[index + 2] LL = module(LL, LH, HL, HH) index += 3 elif isinstance(module, nn.MaxUnpool2d): indices = input[index] LL = module(input = LL, indices = indices) #_, _, h, w = LL.size() #LL = F.interpolate(LL, size = (2*h, 2*w), mode = 'bilinear', align_corners = True) index += 1 elif isinstance(module, My_UpSampling_SC): feature_map = input[index] LL = module(input = LL, feature_map = feature_map) index += 1 else: LL = module(LL) return LL class DWTFunction_1D(Function): @staticmethod def forward(ctx, input, matrix_Low, matrix_High): ctx.save_for_backward(matrix_Low, matrix_High) L = torch.matmul(input, matrix_Low.t()) H = torch.matmul(input, matrix_High.t()) return L, H @staticmethod def backward(ctx, grad_L, grad_H): matrix_L, matrix_H = ctx.saved_tensors grad_input = torch.add(torch.matmul(grad_L, matrix_L), torch.matmul(grad_H, matrix_H)) return grad_input, None, None class IDWTFunction_1D(Function): @staticmethod def forward(ctx, input_L, input_H, matrix_L, matrix_H): ctx.save_for_backward(matrix_L, matrix_H) output = torch.add(torch.matmul(input_L, matrix_L), torch.matmul(input_H, matrix_H)) return output @staticmethod def backward(ctx, grad_output): matrix_L, matrix_H = ctx.saved_tensors grad_L = torch.matmul(grad_output, matrix_L.t()) grad_H = torch.matmul(grad_output, matrix_H.t()) return grad_L, grad_H, None, None class DWTFunction_2D(Function): @staticmethod def forward(ctx, input, matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1): ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1) L = torch.matmul(matrix_Low_0, input) H = torch.matmul(matrix_High_0, input) LL = torch.matmul(L, matrix_Low_1) LH = torch.matmul(L, matrix_High_1) HL = torch.matmul(H, matrix_Low_1) HH =
torch.matmul(H, matrix_High_1) return LL, LH, HL, HH @staticmethod def backward(ctx, grad_LL, grad_LH, grad_HL, grad_HH): matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1 = ctx.saved_tensors grad_L = torch.add(torch.matmul(grad_LL, matrix_Low_1.t()), torch.matmul(grad_LH, matrix_High_1.t())) grad_H = torch.add(torch.matmul(grad_HL, matrix_Low_1.t()), torch.matmul(grad_HH, matrix_High_1.t())) grad_input = torch.add(torch.matmul(matrix_Low_0.t(), grad_L), torch.matmul(matrix_High_0.t(), grad_H)) return grad_input, None, None, None, None class DWTFunction_2D_tiny(Function): @staticmethod def forward(ctx, input, matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1): ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1) L = torch.matmul(matrix_Low_0, input) LL = torch.matmul(L, matrix_Low_1) return LL @staticmethod def backward(ctx, grad_LL): matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1 = ctx.saved_tensors grad_L = torch.matmul(grad_LL, matrix_Low_1.t()) grad_input = torch.matmul(matrix_Low_0.t(), grad_L) return grad_input, None, None, None, None class IDWTFunction_2D(Function): @staticmethod def forward(ctx, input_LL, input_LH, input_HL, input_HH, matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1): ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1) L = torch.add(torch.matmul(input_LL, matrix_Low_1.t()), torch.matmul(input_LH, matrix_High_1.t())) H = torch.add(torch.matmul(input_HL, matrix_Low_1.t()), torch.matmul(input_HH, matrix_High_1.t())) output = torch.add(torch.matmul(matrix_Low_0.t(), L), torch.matmul(matrix_High_0.t(), H)) return output @staticmethod def backward(ctx, grad_output): matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1 = ctx.saved_tensors grad_L = torch.matmul(matrix_Low_0, grad_output) grad_H = torch.matmul(matrix_High_0, grad_output) grad_LL = torch.matmul(grad_L, matrix_Low_1) grad_LH = torch.matmul(grad_L, matrix_High_1) grad_HL = torch.matmul(grad_H,
matrix_Low_1) grad_HH = torch.matmul(grad_H, matrix_High_1) return grad_LL, grad_LH, grad_HL, grad_HH, None, None, None, None class DWTFunction_3D(Function): @staticmethod def forward(ctx, input, matrix_Low_0, matrix_Low_1, matrix_Low_2, matrix_High_0, matrix_High_1, matrix_High_2): ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_Low_2, matrix_High_0, matrix_High_1, matrix_High_2) L = torch.matmul(matrix_Low_0, input) H = torch.matmul(matrix_High_0, input) LL = torch.matmul(L, matrix_Low_1).transpose(dim0 = 2, dim1 = 3) LH = torch.matmul(L, matrix_High_1).transpose(dim0 = 2, dim1 = 3) HL = torch.matmul(H, matrix_Low_1).transpose(dim0 = 2, dim1 = 3) HH = torch.matmul(H, matrix_High_1).transpose(dim0 = 2, dim1 = 3) LLL = torch.matmul(matrix_Low_2, LL).transpose(dim0 = 2, dim1 = 3) LLH = torch.matmul(matrix_Low_2, LH).transpose(dim0 = 2, dim1 = 3) LHL = torch.matmul(matrix_Low_2, HL).transpose(dim0 = 2, dim1 = 3) LHH = torch.matmul(matrix_Low_2, HH).transpose(dim0 = 2, dim1 = 3) HLL = torch.matmul(matrix_High_2, LL).transpose(dim0 = 2, dim1 = 3) HLH = torch.matmul(matrix_High_2, LH).transpose(dim0 = 2, dim1 = 3) HHL = torch.matmul(matrix_High_2, HL).transpose(dim0 = 2, dim1 = 3) HHH = torch.matmul(matrix_High_2, HH).transpose(dim0 = 2, dim1 = 3) return LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH @staticmethod def backward(ctx, grad_LLL, grad_LLH, grad_LHL, grad_LHH, grad_HLL, grad_HLH, grad_HHL, grad_HHH): matrix_Low_0, matrix_Low_1, matrix_Low_2, matrix_High_0, matrix_High_1, matrix_High_2 = ctx.saved_tensors grad_LL = torch.add(torch.matmul(matrix_Low_2.t(), grad_LLL.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), grad_HLL.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3) grad_LH = torch.add(torch.matmul(matrix_Low_2.t(), grad_LLH.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), grad_HLH.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3) grad_HL = torch.add(torch.matmul(matrix_Low_2.t(),
grad_LHL.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), grad_HHL.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3) grad_HH = torch.add(torch.matmul(matrix_Low_2.t(), grad_LHH.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), grad_HHH.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3) grad_L = torch.add(torch.matmul(grad_LL, matrix_Low_1.t()), torch.matmul(grad_LH, matrix_High_1.t())) grad_H = torch.add(torch.matmul(grad_HL, matrix_Low_1.t()), torch.matmul(grad_HH, matrix_High_1.t())) grad_input = torch.add(torch.matmul(matrix_Low_0.t(), grad_L), torch.matmul(matrix_High_0.t(), grad_H)) return grad_input, None, None, None, None, None, None class IDWTFunction_3D(Function): @staticmethod def forward(ctx, input_LLL, input_LLH, input_LHL, input_LHH, input_HLL, input_HLH, input_HHL, input_HHH, matrix_Low_0, matrix_Low_1, matrix_Low_2, matrix_High_0, matrix_High_1, matrix_High_2): ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_Low_2, matrix_High_0, matrix_High_1, matrix_High_2) input_LL = torch.add(torch.matmul(matrix_Low_2.t(), input_LLL.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), input_HLL.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3) input_LH = torch.add(torch.matmul(matrix_Low_2.t(), input_LLH.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), input_HLH.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3) input_HL = torch.add(torch.matmul(matrix_Low_2.t(), input_LHL.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), input_HHL.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3) input_HH = torch.add(torch.matmul(matrix_Low_2.t(), input_LHH.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), input_HHH.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3) input_L = torch.add(torch.matmul(input_LL, matrix_Low_1.t()), torch.matmul(input_LH, matrix_High_1.t())) input_H =
torch.add(torch.matmul(input_HL, matrix_Low_1.t()), torch.matmul(input_HH, matrix_High_1.t())) output = torch.add(torch.matmul(matrix_Low_0.t(), input_L), torch.matmul(matrix_High_0.t(), input_H)) return output @staticmethod def backward(ctx, grad_output): matrix_Low_0, matrix_Low_1, matrix_Low_2, matrix_High_0, matrix_High_1, matrix_High_2 = ctx.saved_tensors grad_L = torch.matmul(matrix_Low_0, grad_output) grad_H = torch.matmul(matrix_High_0, grad_output) grad_LL = torch.matmul(grad_L, matrix_Low_1).transpose(dim0 = 2, dim1 = 3) grad_LH = torch.matmul(grad_L, matrix_High_1).transpose(dim0 = 2, dim1 = 3) grad_HL = torch.matmul(grad_H, matrix_Low_1).transpose(dim0 = 2, dim1 = 3) grad_HH = torch.matmul(grad_H, matrix_High_1).transpose(dim0 = 2, dim1 = 3) grad_LLL = torch.matmul(matrix_Low_2, grad_LL).transpose(dim0 = 2, dim1 = 3) grad_LLH = torch.matmul(matrix_Low_2, grad_LH).transpose(dim0 = 2, dim1 = 3) grad_LHL = torch.matmul(matrix_Low_2, grad_HL).transpose(dim0 = 2, dim1 = 3) grad_LHH = torch.matmul(matrix_Low_2, grad_HH).transpose(dim0 = 2, dim1 = 3) grad_HLL = torch.matmul(matrix_High_2, grad_LL).transpose(dim0 = 2, dim1 = 3) grad_HLH = torch.matmul(matrix_High_2, grad_LH).transpose(dim0 = 2, dim1 = 3) grad_HHL = torch.matmul(matrix_High_2, grad_HL).transpose(dim0 = 2, dim1 = 3) grad_HHH = torch.matmul(matrix_High_2, grad_HH).transpose(dim0 = 2, dim1 = 3) return grad_LLL, grad_LLH, grad_LHL, grad_LHH, grad_HLL, grad_HLH, grad_HHL, grad_HHH, None, None, None, None, None, None class DWT_1D(Module): """ input: (N, C, L) output: L -- (N, C, L/2) H -- (N, C, L/2) """ def __init__(self, wavename): """ :param band_low: low-pass filter bank used for wavelet decomposition :param band_high: high-pass filter bank used for wavelet decomposition """ super(DWT_1D, self).__init__() wavelet = pywt.Wavelet(wavename) self.band_low = wavelet.rec_lo self.band_high = wavelet.rec_hi assert len(self.band_low) == len(self.band_high) self.band_length = len(self.band_low) assert self.band_length % 2 == 
0 self.band_length_half = math.floor(self.band_length / 2)
def get_matrix(self): """ Generate the transform matrices. :return: """ L1 = self.input_height L = math.floor(L1 / 2) matrix_h = np.zeros( ( L, L1 + self.band_length - 2 ) ) matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) ) end = None if self.band_length_half == 1 else (-self.band_length_half+1) index = 0 for i in range(L): for j in range(self.band_length): matrix_h[i, index+j] = self.band_low[j] index += 2 index = 0 for i in range(L1 - L): for j in range(self.band_length): matrix_g[i, index+j] = self.band_high[j] index += 2 matrix_h = matrix_h[:,(self.band_length_half-1):end] matrix_g = matrix_g[:,(self.band_length_half-1):end] if torch.cuda.is_available(): self.matrix_low = torch.tensor(matrix_h, dtype=torch.float32).cuda() self.matrix_high = torch.tensor(matrix_g, dtype=torch.float32).cuda() else: self.matrix_low = torch.tensor(matrix_h, dtype=torch.float32) self.matrix_high = torch.tensor(matrix_g, dtype=torch.float32) def forward(self, input): assert len(input.size()) == 3 self.input_height = input.size()[-1] #assert self.input_height > self.band_length self.get_matrix() return DWTFunction_1D.apply(input, self.matrix_low, self.matrix_high) class IDWT_1D(Module): """ input: L -- (N, C, L/2) H -- (N, C, L/2) output: (N, C, L) """ def __init__(self, wavename): """ :param band_low: low-pass filter bank used for wavelet reconstruction :param band_high: high-pass filter bank used for wavelet reconstruction """ super(IDWT_1D, self).__init__() wavelet = pywt.Wavelet(wavename) self.band_low = wavelet.dec_lo self.band_high = wavelet.dec_hi self.band_low.reverse() self.band_high.reverse() assert len(self.band_low) == len(self.band_high) self.band_length = len(self.band_low) assert self.band_length % 2 == 0 self.band_length_half = math.floor(self.band_length / 2) def get_matrix(self): """ Generate the transform matrices. :return: """ L1 = self.input_height L = math.floor(L1 / 2) matrix_h = np.zeros( ( L, L1 + self.band_length - 2 ) ) matrix_g = np.zeros( ( L1 - L, L1 + 
self.band_length - 2 ) ) end = None if self.band_length_half == 1 else (-self.band_length_half+1) index = 0 for i in range(L): for j in range(self.band_length): matrix_h[i, index+j] = self.band_low[j] index
+= 2 index = 0 for i in range(L1 - L): for j in range(self.band_length): matrix_g[i, index+j] = self.band_high[j] index += 2 matrix_h = matrix_h[:,(self.band_length_half-1):end] matrix_g = matrix_g[:,(self.band_length_half-1):end] if torch.cuda.is_available(): self.matrix_low = torch.tensor(matrix_h, dtype=torch.float32).cuda() self.matrix_high = torch.tensor(matrix_g, dtype=torch.float32).cuda() else: self.matrix_low = torch.tensor(matrix_h, dtype=torch.float32) self.matrix_high = torch.tensor(matrix_g, dtype=torch.float32) def forward(self, L, H): assert len(L.size()) == len(H.size()) == 3 self.input_height = L.size()[-1] + H.size()[-1] #assert self.input_height > self.band_length self.get_matrix() return IDWTFunction_1D.apply(L, H, self.matrix_low, self.matrix_high) class DWT_2D(Module): """ input: (N, C, H, W) output -- LL: (N, C, H/2, W/2) LH: (N, C, H/2, W/2) HL: (N, C, H/2, W/2) HH: (N, C, H/2, W/2) """ def __init__(self, wavename): """ :param band_low: low-pass filter bank used for wavelet decomposition :param band_high: high-pass filter bank used for wavelet decomposition """ super(DWT_2D, self).__init__() wavelet = pywt.Wavelet(wavename) self.band_low = wavelet.rec_lo self.band_high = wavelet.rec_hi assert len(self.band_low) == len(self.band_high) self.band_length = len(self.band_low) assert self.band_length % 2 == 0 self.band_length_half = math.floor(self.band_length / 2) def get_matrix(self): """ Generate the transform matrices. :return: """ L1 = np.max((self.input_height, self.input_width)) L = math.floor(L1 / 2) matrix_h = np.zeros( ( L, L1 + self.band_length - 2 ) ) matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) ) end = None if self.band_length_half == 1 else (-self.band_length_half+1) index = 0 for i in range(L): for j in range(self.band_length): matrix_h[i, index+j] = self.band_low[j] index += 2 matrix_h_0 = matrix_h[0:(math.floor(self.input_height / 2)), 0:(self.input_height + self.band_length - 2)] matrix_h_1 = matrix_h[0:(math.floor(self.input_width / 2)), 
0:(self.input_width + self.band_length - 2)] index = 0 for i in range(L1 - L): for j in range(self.band_length): matrix_g[i, index+j] = self.band_high[j] index += 2
matrix_g_0 = matrix_g[0:(self.input_height - math.floor(self.input_height / 2)),0:(self.input_height + self.band_length - 2)] matrix_g_1 = matrix_g[0:(self.input_width - math.floor(self.input_width / 2)),0:(self.input_width + self.band_length - 2)] matrix_h_0 = matrix_h_0[:,(self.band_length_half-1):end] matrix_h_1 = matrix_h_1[:,(self.band_length_half-1):end] matrix_h_1 = np.transpose(matrix_h_1) matrix_g_0 = matrix_g_0[:,(self.band_length_half-1):end] matrix_g_1 = matrix_g_1[:,(self.band_length_half-1):end] matrix_g_1 = np.transpose(matrix_g_1) if torch.cuda.is_available(): self.matrix_low_0 = torch.Tensor(matrix_h_0).cuda() self.matrix_low_1 = torch.Tensor(matrix_h_1).cuda() self.matrix_high_0 = torch.Tensor(matrix_g_0).cuda() self.matrix_high_1 = torch.Tensor(matrix_g_1).cuda() else: self.matrix_low_0 = torch.Tensor(matrix_h_0) self.matrix_low_1 = torch.Tensor(matrix_h_1) self.matrix_high_0 = torch.Tensor(matrix_g_0) self.matrix_high_1 = torch.Tensor(matrix_g_1) def forward(self, input): assert isinstance(input, torch.Tensor) assert len(input.size()) == 4 self.input_height = input.size()[-2] self.input_width = input.size()[-1] #assert self.input_height > self.band_length and self.input_width > self.band_length self.get_matrix() return DWTFunction_2D.apply(input, self.matrix_low_0, self.matrix_low_1, self.matrix_high_0, self.matrix_high_1) class DWT_2D_tiny(Module): """ input: (N, C, H, W) output -- LL: (N, C, H/2, W/2) """ def __init__(self, wavename): """ :param band_low: low-pass filter bank used for wavelet decomposition :param band_high: high-pass filter bank used for wavelet decomposition """ super(DWT_2D_tiny, self).__init__() wavelet = pywt.Wavelet(wavename) self.band_low = wavelet.rec_lo self.band_high = wavelet.rec_hi assert len(self.band_low) == len(self.band_high) self.band_length = len(self.band_low) assert self.band_length % 2 == 0 self.band_length_half = math.floor(self.band_length / 2) def get_matrix(self): """ Generate the transform matrices. :return: """ L1 = 
np.max((self.input_height, self.input_width)) L = math.floor(L1 / 2) matrix_h = np.zeros(
( L, L1 + self.band_length - 2 ) ) matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) ) end = None if self.band_length_half == 1 else (-self.band_length_half+1) index = 0 for i in range(L): for j in range(self.band_length): matrix_h[i, index+j] = self.band_low[j] index += 2 matrix_h_0 = matrix_h[0:(math.floor(self.input_height / 2)), 0:(self.input_height + self.band_length - 2)] matrix_h_1 = matrix_h[0:(math.floor(self.input_width / 2)), 0:(self.input_width + self.band_length - 2)] index = 0 for i in range(L1 - L): for j in range(self.band_length): matrix_g[i, index+j] = self.band_high[j] index += 2 matrix_g_0 = matrix_g[0:(self.input_height - math.floor(self.input_height / 2)),0:(self.input_height + self.band_length - 2)] matrix_g_1 = matrix_g[0:(self.input_width - math.floor(self.input_width / 2)),0:(self.input_width + self.band_length - 2)] matrix_h_0 = matrix_h_0[:,(self.band_length_half-1):end] matrix_h_1 = matrix_h_1[:,(self.band_length_half-1):end] matrix_h_1 = np.transpose(matrix_h_1) matrix_g_0 = matrix_g_0[:,(self.band_length_half-1):end] matrix_g_1 = matrix_g_1[:,(self.band_length_half-1):end] matrix_g_1 = np.transpose(matrix_g_1) if torch.cuda.is_available(): self.matrix_low_0 = torch.Tensor(matrix_h_0).cuda() self.matrix_low_1 = torch.Tensor(matrix_h_1).cuda() self.matrix_high_0 = torch.Tensor(matrix_g_0).cuda() self.matrix_high_1 = torch.Tensor(matrix_g_1).cuda() else: self.matrix_low_0 = torch.Tensor(matrix_h_0) self.matrix_low_1 = torch.Tensor(matrix_h_1) self.matrix_high_0 = torch.Tensor(matrix_g_0) self.matrix_high_1 = torch.Tensor(matrix_g_1) def forward(self, input): assert isinstance(input, torch.Tensor) assert len(input.size()) == 4 self.input_height = input.size()[-2] self.input_width = input.size()[-1] self.get_matrix() return DWTFunction_2D_tiny.apply(input, self.matrix_low_0, self.matrix_low_1, self.matrix_high_0, self.matrix_high_1) class IDWT_2D(Module): """ input -- LL: (N, C, H/2, W/2) LH: (N, C, H/2, W/2) HL: (N, C, H/2, W/2) 
HH: (N, C, H/2, W/2) output: (N, C, H, W) """ def __init__(self, wavename): """ :param band_low: low-pass filter bank used for wavelet reconstruction :param band_high: high-pass filter bank used for wavelet reconstruction """ super(IDWT_2D, self).__init__() wavelet = pywt.Wavelet(wavename) self.band_low = wavelet.dec_lo self.band_low.reverse() self.band_high = wavelet.dec_hi self.band_high.reverse() assert len(self.band_low) == len(self.band_high) self.band_length = len(self.band_low) assert self.band_length % 2 == 0 self.band_length_half = math.floor(self.band_length / 2) def get_matrix(self): """ Generate the transform matrices. :return: """ L1 = np.max((self.input_height, self.input_width)) L = math.floor(L1 / 2) matrix_h = np.zeros( ( L, L1 + self.band_length - 2 ) ) matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) ) end = None if self.band_length_half == 1 else (-self.band_length_half+1) index = 0 for i in range(L): for j in range(self.band_length): matrix_h[i, index+j] = self.band_low[j] index += 2 matrix_h_0 = matrix_h[0:(math.floor(self.input_height / 2)), 0:(self.input_height + self.band_length - 2)] matrix_h_1 = matrix_h[0:(math.floor(self.input_width / 2)), 0:(self.input_width + self.band_length - 2)] index = 0 for i in range(L1 - L): for j in range(self.band_length): matrix_g[i, index+j] = self.band_high[j] index += 2 matrix_g_0 = matrix_g[0:(self.input_height - math.floor(self.input_height / 2)),0:(self.input_height + self.band_length - 2)] matrix_g_1 = matrix_g[0:(self.input_width - math.floor(self.input_width / 2)),0:(self.input_width + self.band_length - 2)] matrix_h_0 = matrix_h_0[:,(self.band_length_half-1):end] matrix_h_1 = matrix_h_1[:,(self.band_length_half-1):end] matrix_h_1 = np.transpose(matrix_h_1) matrix_g_0 = matrix_g_0[:,(self.band_length_half-1):end] matrix_g_1 = matrix_g_1[:,(self.band_length_half-1):end] matrix_g_1 = np.transpose(matrix_g_1) if torch.cuda.is_available(): self.matrix_low_0 = torch.Tensor(matrix_h_0).cuda() self.matrix_low_1 = 
torch.Tensor(matrix_h_1).cuda() self.matrix_high_0 = torch.Tensor(matrix_g_0).cuda()
self.matrix_high_1 = torch.Tensor(matrix_g_1).cuda() else: self.matrix_low_0 = torch.Tensor(matrix_h_0) self.matrix_low_1 = torch.Tensor(matrix_h_1) self.matrix_high_0 = torch.Tensor(matrix_g_0) self.matrix_high_1 = torch.Tensor(matrix_g_1) def forward(self, LL, LH, HL, HH): assert len(LL.size()) == len(LH.size()) == len(HL.size()) == len(HH.size()) == 4 self.input_height = LL.size()[-2] + HH.size()[-2] self.input_width = LL.size()[-1] + HH.size()[-1] #assert self.input_height > self.band_length and self.input_width > self.band_length self.get_matrix() return IDWTFunction_2D.apply(LL, LH, HL, HH, self.matrix_low_0, self.matrix_low_1, self.matrix_high_0, self.matrix_high_1) class DWT_3D(Module): """ input: (N, C, D, H, W) output: -- LLL (N, C, D/2, H/2, W/2) -- LLH (N, C, D/2, H/2, W/2) -- LHL (N, C, D/2, H/2, W/2) -- LHH (N, C, D/2, H/2, W/2) -- HLL (N, C, D/2, H/2, W/2) -- HLH (N, C, D/2, H/2, W/2) -- HHL (N, C, D/2, H/2, W/2) -- HHH (N, C, D/2, H/2, W/2) """ def __init__(self, wavename): """ :param band_low: low-pass filter bank used for wavelet decomposition :param band_high: high-pass filter bank used for wavelet decomposition """ super(DWT_3D, self).__init__() wavelet = pywt.Wavelet(wavename) self.band_low = wavelet.rec_lo self.band_high = wavelet.rec_hi assert len(self.band_low) == len(self.band_high) self.band_length = len(self.band_low) assert self.band_length % 2 == 0 self.band_length_half = math.floor(self.band_length / 2) def get_matrix(self): """ Generate the transform matrices. :return: """ L1 = np.max((self.input_height, self.input_width, self.input_depth)) L = math.floor(L1 / 2) matrix_h = np.zeros( ( L, L1 + self.band_length - 2 ) ) matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) ) end = None if self.band_length_half == 1 else (-self.band_length_half+1) index = 0 for i in range(L): for j in range(self.band_length): matrix_h[i, index+j] = self.band_low[j] index += 2 matrix_h_0 = matrix_h[0:(math.floor(self.input_height / 2)), 0:(self.input_height + self.band_length - 
2)] matrix_h_1 = matrix_h[0:(math.floor(self.input_width / 2)), 0:(self.input_width +
self.band_length - 2)] matrix_h_2 = matrix_h[0:(math.floor(self.input_depth / 2)), 0:(self.input_depth + self.band_length - 2)] index = 0 for i in range(L1 - L): for j in range(self.band_length): matrix_g[i, index+j] = self.band_high[j] index += 2 matrix_g_0 = matrix_g[0:(self.input_height - math.floor(self.input_height / 2)),0:(self.input_height + self.band_length - 2)] matrix_g_1 = matrix_g[0:(self.input_width - math.floor(self.input_width / 2)),0:(self.input_width + self.band_length - 2)] matrix_g_2 = matrix_g[0:(self.input_depth - math.floor(self.input_depth / 2)),0:(self.input_depth + self.band_length - 2)] matrix_h_0 = matrix_h_0[:,(self.band_length_half-1):end] matrix_h_1 = matrix_h_1[:,(self.band_length_half-1):end] matrix_h_1 = np.transpose(matrix_h_1) matrix_h_2 = matrix_h_2[:,(self.band_length_half-1):end] matrix_g_0 = matrix_g_0[:,(self.band_length_half-1):end] matrix_g_1 = matrix_g_1[:,(self.band_length_half-1):end] matrix_g_1 = np.transpose(matrix_g_1) matrix_g_2 = matrix_g_2[:,(self.band_length_half-1):end] if torch.cuda.is_available(): self.matrix_low_0 = torch.tensor(matrix_h_0, dtype=torch.float32).cuda() self.matrix_low_1 = torch.tensor(matrix_h_1, dtype=torch.float32).cuda() self.matrix_low_2 = torch.tensor(matrix_h_2, dtype=torch.float32).cuda() self.matrix_high_0 = torch.tensor(matrix_g_0, dtype=torch.float32).cuda() self.matrix_high_1 = torch.tensor(matrix_g_1, dtype=torch.float32).cuda() self.matrix_high_2 = torch.tensor(matrix_g_2, dtype=torch.float32).cuda() else: self.matrix_low_0 = torch.tensor(matrix_h_0, dtype=torch.float32) self.matrix_low_1 = torch.tensor(matrix_h_1, dtype=torch.float32) self.matrix_low_2 = torch.tensor(matrix_h_2, dtype=torch.float32) self.matrix_high_0 = torch.tensor(matrix_g_0, dtype=torch.float32) self.matrix_high_1 = torch.tensor(matrix_g_1, dtype=torch.float32) self.matrix_high_2 = torch.tensor(matrix_g_2, dtype=torch.float32) def forward(self, input): assert len(input.size()) == 5 self.input_depth = input.size()[-3] self.input_height = 
input.size()[-2] self.input_width = input.size()[-1] #assert self.input_height > self.band_length and self.input_width > self.band_length and self.input_depth > self.band_length self.get_matrix() return
DWTFunction_3D.apply(input, self.matrix_low_0, self.matrix_low_1, self.matrix_low_2, self.matrix_high_0, self.matrix_high_1, self.matrix_high_2) class IDWT_3D(Module): """ input: -- LLL (N, C, D/2, H/2, W/2) -- LLH (N, C, D/2, H/2, W/2) -- LHL (N, C, D/2, H/2, W/2) -- LHH (N, C, D/2, H/2, W/2) -- HLL (N, C, D/2, H/2, W/2) -- HLH (N, C, D/2, H/2, W/2) -- HHL (N, C, D/2, H/2, W/2) -- HHH (N, C, D/2, H/2, W/2) output: (N, C, D, H, W) """ def __init__(self, wavename): """ :param wavename: name of a pywt wavelet; band_low and band_high are its low-pass and high-pass filter banks used for the wavelet reconstruction """ super(IDWT_3D, self).__init__() wavelet = pywt.Wavelet(wavename) self.band_low = wavelet.dec_lo self.band_high = wavelet.dec_hi self.band_low.reverse() self.band_high.reverse() assert len(self.band_low) == len(self.band_high) self.band_length = len(self.band_low) assert self.band_length % 2 == 0 self.band_length_half = math.floor(self.band_length / 2) def get_matrix(self): """ Build the transform matrices for the depth, height and width axes :return: """ L1 = np.max((self.input_depth, self.input_height, self.input_width)) # include the depth axis so matrix_h_2 / matrix_g_2 are not truncated L = math.floor(L1 / 2) matrix_h = np.zeros( ( L, L1 + self.band_length - 2 ) ) matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) ) end = None if self.band_length_half == 1 else (-self.band_length_half+1) index = 0 for i in range(L): for j in range(self.band_length): matrix_h[i, index+j] = self.band_low[j] index += 2 matrix_h_0 = matrix_h[0:(math.floor(self.input_height / 2)), 0:(self.input_height + self.band_length - 2)] matrix_h_1 = matrix_h[0:(math.floor(self.input_width / 2)), 0:(self.input_width + self.band_length - 2)] matrix_h_2 = matrix_h[0:(math.floor(self.input_depth / 2)), 0:(self.input_depth + self.band_length - 2)] index = 0 for i in range(L1 - L): for j in range(self.band_length): matrix_g[i, index+j] = self.band_high[j] index += 2 matrix_g_0 = matrix_g[0:(self.input_height - math.floor(self.input_height / 2)),0:(self.input_height + self.band_length - 2)] matrix_g_1 = matrix_g[0:(self.input_width - math.floor(self.input_width / 2)),0:(self.input_width + self.band_length -
2)] matrix_g_2 = matrix_g[0:(self.input_depth - math.floor(self.input_depth / 2)),0:(self.input_depth + self.band_length - 2)] matrix_h_0 = matrix_h_0[:,(self.band_length_half-1):end] matrix_h_1 = matrix_h_1[:,(self.band_length_half-1):end] matrix_h_1 = np.transpose(matrix_h_1) matrix_h_2 = matrix_h_2[:,(self.band_length_half-1):end] matrix_g_0 = matrix_g_0[:,(self.band_length_half-1):end] matrix_g_1 = matrix_g_1[:,(self.band_length_half-1):end] matrix_g_1 = np.transpose(matrix_g_1) matrix_g_2 = matrix_g_2[:,(self.band_length_half-1):end] if torch.cuda.is_available(): self.matrix_low_0 = torch.tensor(matrix_h_0).cuda() self.matrix_low_1 = torch.tensor(matrix_h_1).cuda() self.matrix_low_2 = torch.tensor(matrix_h_2).cuda() self.matrix_high_0 = torch.tensor(matrix_g_0).cuda() self.matrix_high_1 = torch.tensor(matrix_g_1).cuda() self.matrix_high_2 = torch.tensor(matrix_g_2).cuda() else: self.matrix_low_0 = torch.tensor(matrix_h_0) self.matrix_low_1 = torch.tensor(matrix_h_1) self.matrix_low_2 = torch.tensor(matrix_h_2) self.matrix_high_0 = torch.tensor(matrix_g_0) self.matrix_high_1 = torch.tensor(matrix_g_1) self.matrix_high_2 = torch.tensor(matrix_g_2) def forward(self, LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH): assert len(LLL.size()) == len(LLH.size()) == len(LHL.size()) == len(LHH.size()) == 5 assert len(HLL.size()) == len(HLH.size()) == len(HHL.size()) == len(HHH.size()) == 5 self.input_depth = LLL.size()[-3] + HHH.size()[-3] self.input_height = LLL.size()[-2] + HHH.size()[-2] self.input_width = LLL.size()[-1] + HHH.size()[-1] #assert self.input_height > self.band_length and self.input_width > self.band_length and self.input_depth > self.band_length self.get_matrix() return IDWTFunction_3D.apply(LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH, self.matrix_low_0, self.matrix_low_1, self.matrix_low_2, self.matrix_high_0, self.matrix_high_1, self.matrix_high_2) # if __name__ == '__main__': # import pywt, cv2 # from datetime import datetime # # wavelet = pywt.Wavelet('bior1.1') 
# h = wavelet.rec_lo # g = wavelet.rec_hi # h_ = wavelet.dec_lo # g_ = wavelet.dec_hi # h_.reverse() # g_.reverse() # # #""" # image_full_name = '/home/liqiufu/Pictures/standard_test_images/lena_color_512.tif' # image = cv2.imread(image_full_name, flags = 1) # image = image[0:512,0:512,:] # print(image.shape) # height, width, channel = image.shape # #image = image.reshape((1,height,width)) # t0 = datetime.now() # for index in range(1): # m0 = DWT_2D(wavename = 'haar') # # m1 = IDWT_2D(wavename = 'haar') # print(isinstance(m1, IDWT_2D)) # t1 = datetime.now() class SegNet_VGG(nn.Module): def __init__(self, features, num_classes = 21, init_weights = True, wavename = None): super(SegNet_VGG, self).__init__() self.features = features[0] self.decoders = features[1] self.classifier_seg = nn.Sequential( #nn.Conv2d(64, 64, kernel_size = 3, padding = 1), #nn.ReLU(True), nn.Conv2d(64, num_classes, kernel_size = 1, padding = 0), ) if init_weights: self._initialize_weights() def forward(self, x): xx = self.features(x) x, [(indices_1,), (indices_2,), (indices_3,), (indices_4,), (indices_5,)] = xx x = self.decoders(x, indices_5, indices_4, indices_3, indices_2, indices_1) x = self.classifier_seg(x) return x def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): if(m.in_channels != m.out_channels or m.out_channels != m.groups or m.bias is not None): # don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics nn.init.kaiming_normal_(m.weight, mode = 'fan_out', nonlinearity = 'relu') if m.bias is not None: nn.init.constant_(m.bias, 0) else: print('Not initializing') elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) nn.init.constant_(m.bias, 0) def __str__(self): return 'SegNet_VGG' class WSegNet_VGG(nn.Module): def __init__(self, features, num_classes, init_weights = True, wavename = 
None): super(WSegNet_VGG, self).__init__() self.features = features[0] self.decoders = features[1] self.classifier_seg = nn.Sequential( #nn.Conv2d(64, 64, kernel_size = 3, padding = 1), #nn.ReLU(True), nn.Conv2d(64, num_classes, kernel_size = 1, padding = 0), ) if init_weights: self._initialize_weights() def forward(self, x): xx = self.features(x) x, [(LH1,HL1,HH1), (LH2,HL2,HH2,), (LH3,HL3,HH3,), (LH4,HL4,HH4,), (LH5,HL5,HH5,)] = xx x = self.decoders(x, LH5,HL5,HH5, LH4,HL4,HH4, LH3,HL3,HH3, LH2,HL2,HH2, LH1,HL1,HH1) x = self.classifier_seg(x) return x def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): if(m.in_channels != m.out_channels or m.out_channels != m.groups or m.bias is not None): # don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics nn.init.kaiming_normal_(m.weight, mode = 'fan_out', nonlinearity = 'relu') if m.bias is not None: nn.init.constant_(m.bias, 0) else: print('Not initializing') elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) nn.init.constant_(m.bias, 0) def __str__(self): return 'WSegNet_VGG' def make_layers(cfg, batch_norm = False): encoder = [] in_channels = 3 for v in cfg: if v != 'M': conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) if batch_norm: encoder += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] else: encoder += [conv2d, nn.ReLU(inplace=True)] in_channels = v elif v == 'M': encoder += [nn.MaxPool2d(kernel_size = 2, stride = 2, return_indices = True)] encoder = My_Sequential(*encoder) decoder = [] cfg = cfg[::-1] # reverse a copy so the shared cfg lists are not mutated between calls out_channels_final = 64 for index, v in enumerate(cfg): if index != len(cfg) - 1: out_channels = cfg[index + 1] else: out_channels = out_channels_final if out_channels == 'M': out_channels = cfg[index + 2] if v == 'M': decoder += [nn.MaxUnpool2d(kernel_size = 2, stride = 2)] else: conv2d = nn.Conv2d(v,
out_channels, kernel_size = 3, padding = 1) if batch_norm: decoder += [conv2d, nn.BatchNorm2d(out_channels), nn.ReLU(inplace = True)] else: decoder += [conv2d, nn.ReLU(inplace = True)] decoder = My_Sequential_re(*decoder) return encoder, decoder def make_w_layers(cfg, in_channels, batch_norm = False, wavename = 'haar'): encoder = [] for v in cfg: if v != 'M': conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) if batch_norm: encoder += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] else: encoder += [conv2d, nn.ReLU(inplace=True)] in_channels = v elif v == 'M': encoder += [DWT_2D(wavename = wavename)] encoder = My_Sequential(*encoder) decoder = [] cfg = cfg[::-1] # reverse a copy so the shared cfg lists are not mutated between calls out_channels_final = 64 for index, v in enumerate(cfg): if index != len(cfg) - 1: out_channels = cfg[index + 1] else: out_channels = out_channels_final if out_channels == 'M': out_channels = cfg[index + 2] if v == 'M': decoder += [IDWT_2D(wavename = wavename)] else: conv2d = nn.Conv2d(v, out_channels, kernel_size = 3, padding = 1) if batch_norm: decoder += [conv2d, nn.BatchNorm2d(out_channels), nn.ReLU(inplace = True)] else: decoder += [conv2d, nn.ReLU(inplace = True)] decoder = My_Sequential_re(*decoder) return encoder, decoder cfg = { 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], # 11 layers 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], # 13 layers 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], # 16 layers; entries are encoder output channels and decoder input channels 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], # 19 layers } def segnet_vgg11(pretrained = False, **kwargs): """SegNet with a VGG 11-layer encoder (configuration "A") Args: pretrained (bool): If True, skip the default weight initialization (no pretrained weights are loaded here) """ if pretrained: kwargs['init_weights'] = False model = SegNet_VGG(make_layers(cfg['A']), **kwargs) return model def
segnet_vgg11_bn(pretrained=False, **kwargs): """SegNet with a VGG 11-layer encoder (configuration "A") and batch normalization Args: pretrained (bool): If True, skip the default weight initialization (no pretrained weights are loaded here) """ if pretrained: kwargs['init_weights'] = False model = SegNet_VGG(make_layers(cfg['A'], batch_norm = True), **kwargs) return model def segnet_vgg13(pretrained=False, **kwargs): """SegNet with a VGG 13-layer encoder (configuration "B") Args: pretrained (bool): If True, skip the default weight initialization (no pretrained weights are loaded here) """ if pretrained: kwargs['init_weights'] = False model = SegNet_VGG(make_layers(cfg['B']), **kwargs) return model def segnet_vgg13_bn(pretrained=False, **kwargs): """SegNet with a VGG 13-layer encoder (configuration "B") and batch normalization Args: pretrained (bool): If True, skip the default weight initialization (no pretrained weights are loaded here) """ if pretrained: kwargs['init_weights'] = False model = SegNet_VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) return model def segnet_vgg16(pretrained=False, **kwargs): """SegNet with a VGG 16-layer encoder (configuration "D") Args: pretrained (bool): If True, skip the default weight initialization (no pretrained weights are loaded here) """ if pretrained: kwargs['init_weights'] = False model = SegNet_VGG(make_layers(cfg['D']), **kwargs) return model def segnet_vgg16_bn(pretrained=False, **kwargs): """SegNet with a VGG 16-layer encoder (configuration "D") and batch normalization Args: pretrained (bool): If True, skip the default weight initialization (no pretrained weights are loaded here) """ if pretrained: kwargs['init_weights'] = False model = SegNet_VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) return model def segnet_vgg19(pretrained=False, **kwargs): """SegNet with a VGG 19-layer encoder (configuration "E") Args: pretrained (bool): If True, skip the default weight initialization (no pretrained weights are loaded here) """ if pretrained: kwargs['init_weights'] = False model = SegNet_VGG(make_layers(cfg['E']), **kwargs) return model def segnet_vgg19_bn(pretrained=False, **kwargs): """SegNet with a VGG 19-layer encoder (configuration "E") and batch normalization Args: pretrained (bool): If True, skip the default weight initialization (no pretrained weights are loaded here) """ if
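pretrained: kwargs['init_weights'] = False model = SegNet_VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) return model """=================================================================================""" def wsegnet_vgg11(in_channels, num_classes, pretrained = False, wavename = 'haar', **kwargs): """WSegNet with a VGG 11-layer encoder (configuration "A") Args: in_channels (int): number of input channels num_classes (int): number of segmentation classes pretrained (bool): If True, skip the default weight initialization (no pretrained weights are loaded here) wavename (str): pywt wavelet name used by the DWT/IDWT layers """ if pretrained: kwargs['init_weights'] = False model = WSegNet_VGG(make_w_layers(cfg['A'], in_channels, wavename = wavename), num_classes, **kwargs) return model def wsegnet_vgg11_bn(in_channels, num_classes, pretrained=False, wavename = 'haar', **kwargs): """WSegNet with a VGG 11-layer encoder (configuration "A") and batch normalization Args: in_channels (int): number of input channels num_classes (int): number of segmentation classes pretrained (bool): If True, skip the default weight initialization (no pretrained weights are loaded here) wavename (str): pywt wavelet name used by the DWT/IDWT layers """ if pretrained: kwargs['init_weights'] = False model = WSegNet_VGG(make_w_layers(cfg['A'], in_channels, batch_norm = True, wavename = wavename), num_classes, **kwargs) return model def wsegnet_vgg13(in_channels, num_classes, pretrained=False, wavename = 'haar', **kwargs): """WSegNet with a VGG 13-layer encoder (configuration "B") Args: in_channels (int): number of input channels num_classes (int): number of segmentation classes pretrained (bool): If True, skip the default weight initialization (no pretrained weights are loaded here) wavename (str): pywt wavelet name used by the DWT/IDWT layers """ if pretrained: kwargs['init_weights'] = False model = WSegNet_VGG(make_w_layers(cfg['B'], in_channels, wavename = wavename), num_classes, **kwargs) return model def wsegnet_vgg13_bn(in_channels, num_classes, pretrained=False, wavename = 'haar', **kwargs): """WSegNet with a VGG 13-layer encoder (configuration "B") and batch normalization Args: in_channels (int): number of input channels num_classes (int): number of segmentation classes pretrained (bool): If True, skip the default weight initialization (no pretrained weights are loaded here) wavename (str): pywt wavelet name used by the DWT/IDWT layers """ if pretrained: kwargs['init_weights'] = False model = WSegNet_VGG(make_w_layers(cfg['B'], in_channels, batch_norm=True, wavename = wavename), num_classes, **kwargs) return model def wsegnet_vgg16(in_channels, num_classes, pretrained=False, wavename = 'haar', **kwargs): """WSegNet with a VGG 16-layer encoder (configuration "D") Args: in_channels (int): number of input channels num_classes (int): number of segmentation classes pretrained (bool): If True, skip the default weight initialization (no pretrained weights are loaded here) wavename (str): pywt wavelet name used by the DWT/IDWT layers """ if pretrained: kwargs['init_weights'] = False model = WSegNet_VGG(make_w_layers(cfg['D'], in_channels, wavename = wavename), num_classes, **kwargs) return model def wsegnet_vgg16_bn(in_channels, num_classes, pretrained=False, wavename = 'haar',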
**kwargs): """WSegNet with a VGG 16-layer encoder (configuration "D") and batch normalization Args: in_channels (int): number of input channels num_classes (int): number of segmentation classes pretrained (bool): If True, skip the default weight initialization (no pretrained weights are loaded here) wavename (str): pywt wavelet name used by the DWT/IDWT layers """ if pretrained: kwargs['init_weights'] = False model = WSegNet_VGG(make_w_layers(cfg['D'], in_channels, batch_norm=True, wavename = wavename), num_classes, **kwargs) return model def wsegnet_vgg19(in_channels, num_classes, pretrained=False, wavename = 'haar', **kwargs): """WSegNet with a VGG 19-layer encoder (configuration "E") Args: in_channels (int): number of input channels num_classes (int): number of segmentation classes pretrained (bool): If True, skip the default weight initialization (no pretrained weights are loaded here) wavename (str): pywt wavelet name used by the DWT/IDWT layers """ if pretrained: kwargs['init_weights'] = False model = WSegNet_VGG(make_w_layers(cfg['E'], in_channels, wavename = wavename), num_classes, **kwargs) return model def wsegnet_vgg19_bn(in_channels, num_classes, pretrained=False, wavename = 'haar', **kwargs): """WSegNet with a VGG 19-layer encoder (configuration "E") and batch normalization Args: in_channels (int): number of input channels num_classes (int): number of segmentation classes pretrained (bool): If True, skip the default weight initialization (no pretrained weights are loaded here) wavename (str): pywt wavelet name used by the DWT/IDWT layers """ if pretrained: kwargs['init_weights'] = False model = WSegNet_VGG(make_w_layers(cfg['E'], in_channels, batch_norm=True, wavename = wavename), num_classes, **kwargs) return model # if __name__ == '__main__': # from loss.loss_function import segmentation_loss # criterion = segmentation_loss('dice', False) # mask = torch.ones(2, 128, 128).long() # model = wsegnet_vgg16_bn(1, 5) # model.train() # input1 = torch.rand(2, 1, 128, 128) # y = model(input1) # loss_train = criterion(y, mask) # loss_train.backward() # # print(output) # print(y.data.cpu().numpy().shape) # print(loss_train) ================================================ FILE: models/networks_2d/wds.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import init import functools from torch.distributions.uniform import Uniform import numpy as np class basic_block(nn.Module): def __init__(self, ch_in, ch_out): super(basic_block, self).__init__() self.block = nn.Sequential( nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False), nn.ReLU(inplace=True)) def forward(self, x): x = self.block(x)
return x class WDS(nn.Module): def __init__(self, in_channels, num_classes): super(WDS, self).__init__() # branch1 self.b1_1 = basic_block(in_channels, 64) self.b1_2 = basic_block(64, 64) self.b1_3 = basic_block(64, 64) self.b1_4 = basic_block(64, 64) self.b1_5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) self.b1_6 = basic_block(64, 128) self.b1_7 = basic_block(128, 128) self.b1_8 = basic_block(128, 128) self.b1_9 = basic_block(128, 128) self.b1_10 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # branch2 self.b2_1 = basic_block(in_channels, 64) self.b2_2 = basic_block(64, 64) self.b2_3 = basic_block(64, 64) self.b2_4 = basic_block(64, 64) self.b2_5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) self.b2_6 = basic_block(64, 128) self.b2_7 = basic_block(128, 128) self.b2_8 = basic_block(128, 128) self.b2_9 = basic_block(128, 128) self.b2_10 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # branch3 self.b3_1 = basic_block(in_channels, 64) self.b3_2 = basic_block(64, 64) self.b3_3 = basic_block(64, 64) self.b3_4 = basic_block(64, 64) self.b3_5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) self.b3_6 = basic_block(64, 128) self.b3_7 = basic_block(128, 128) self.b3_8 = basic_block(128, 128) self.b3_9 = basic_block(128, 128) self.b3_10 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # branch4 self.b4_1 = basic_block(in_channels, 64) self.b4_2 = basic_block(64, 64) self.b4_3 = basic_block(64, 64) self.b4_4 = basic_block(64, 64) self.b4_5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) self.b4_6 = basic_block(64, 128) self.b4_7 = basic_block(128, 128) self.b4_8 = basic_block(128, 128) self.b4_9 = basic_block(128, 128) self.b4_10 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # output self.output_layer = nn.Sequential( nn.Conv2d(128*4, 128, kernel_size=3, stride=1, padding=1, bias=False), nn.ReLU(inplace=True), nn.Conv2d(128, num_classes, kernel_size=3, stride=1, padding=1, bias=False), ) # initialization for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, 
mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=0.001) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, LL, LH, HL, HH): # H, W = 2*LL.shape[2], 2*LL.shape[3] H, W = LL.shape[2], LL.shape[3] LL = self.b1_1(LL) LL = self.b1_2(LL) LL = self.b1_3(LL) LL = self.b1_4(LL) LL = self.b1_5(LL) LL = self.b1_6(LL) LL = self.b1_7(LL) LL = self.b1_8(LL) LL = self.b1_9(LL) LL = self.b1_10(LL) LH = self.b2_1(LH) LH = self.b2_2(LH) LH = self.b2_3(LH) LH = self.b2_4(LH) LH = self.b2_5(LH) LH = self.b2_6(LH) LH = self.b2_7(LH) LH = self.b2_8(LH) LH = self.b2_9(LH) LH = self.b2_10(LH) HL = self.b3_1(HL) HL = self.b3_2(HL) HL = self.b3_3(HL) HL = self.b3_4(HL) HL = self.b3_5(HL) HL = self.b3_6(HL) HL = self.b3_7(HL) HL = self.b3_8(HL) HL = self.b3_9(HL) HL = self.b3_10(HL) HH = self.b4_1(HH) HH = self.b4_2(HH) HH = self.b4_3(HH) HH = self.b4_4(HH) HH = self.b4_5(HH) HH = self.b4_6(HH) HH = self.b4_7(HH) HH = self.b4_8(HH) HH = self.b4_9(HH) HH = self.b4_10(HH) x = torch.cat((LL, LH, HL, HH), dim=1) x = self.output_layer(x) x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) return x if __name__ == '__main__': from loss.loss_function import segmentation_loss criterion = segmentation_loss('dice', False) mask = torch.ones(2, 128, 128).long() model = WDS(1, 5) model.train() input1 = torch.rand(2, 1, 128, 128) y = model(input1, input1, input1, input1) loss_train = criterion(y, mask) loss_train.backward() # print(output) print(y.data.cpu().numpy().shape) print(loss_train) ================================================ FILE: models/networks_2d/xnet.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import init import functools from torch.distributions.uniform import Uniform 
import numpy as np BatchNorm2d = nn.BatchNorm2d relu_inplace = True BN_MOMENTUM = 0.1 # BN_MOMENTUM = 0.01 def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation) class up_conv(nn.Module): def __init__(self, ch_in, ch_out): super(up_conv, self).__init__() self.up = nn.Sequential( nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), BatchNorm2d(ch_out, momentum=BN_MOMENTUM), nn.ReLU(inplace=relu_inplace) ) def forward(self, x): x = self.up(x) return x class down_conv(nn.Module): def __init__(self, ch_in, ch_out): super(down_conv, self).__init__() self.down = nn.Sequential( nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=2, padding=1, bias=False), BatchNorm2d(ch_out, momentum=BN_MOMENTUM), nn.ReLU(inplace=relu_inplace) ) def forward(self, x): x = self.down(x) return x class same_conv(nn.Module): def __init__(self, ch_in, ch_out): super(same_conv, self).__init__() self.same = nn.Sequential( nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False), BatchNorm2d(ch_out, momentum=BN_MOMENTUM), nn.ReLU(inplace=relu_inplace)) def forward(self, x): x = self.same(x) return x class transition_conv(nn.Module): def __init__(self, ch_in, ch_out): super(transition_conv, self).__init__() self.transition = nn.Sequential( nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0, bias=False), BatchNorm2d(ch_out, momentum=BN_MOMENTUM), nn.ReLU(inplace=relu_inplace)) def forward(self, x): x = self.transition(x) return x class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, 
dilation=1, norm_layer=None): super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = BatchNorm2d if groups != 1 or base_width != 64: raise ValueError('BasicBlock only supports groups=1 and base_width=64') if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes, momentum=BN_MOMENTUM) self.relu = nn.ReLU(inplace=relu_inplace) self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes, momentum=BN_MOMENTUM) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) # out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out = self.bn2(out) + identity out = self.relu(out) return out class DoubleBasicBlock(nn.Module): def __init__(self, inplanes, planes, downsample=None): super(DoubleBasicBlock, self).__init__() self.DBB = nn.Sequential( BasicBlock(inplanes=inplanes, planes=planes, downsample=downsample), BasicBlock(inplanes=planes, planes=planes) ) def forward(self, x): out = self.DBB(x) return out class XNet(nn.Module): def __init__(self, in_channels, num_classes): super(XNet, self).__init__() l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024 # branch1 # branch1_layer1 self.b1_1_1 = nn.Sequential( conv3x3(in_channels, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c) ) self.b1_1_2_down = down_conv(l1c, l2c) self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM))) self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0) # branch1_layer2 self.b1_2_1 = DoubleBasicBlock(l2c, l2c) self.b1_2_2_down = down_conv(l2c, l3c) self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), 
BatchNorm2d(l2c, momentum=BN_MOMENTUM))) self.b1_2_4_up = up_conv(l2c, l1c) # branch1_layer3 self.b1_3_1 = DoubleBasicBlock(l3c, l3c) self.b1_3_2_down = down_conv(l3c, l4c) self.b1_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM))) self.b1_3_4_up = up_conv(l3c, l2c) # branch1_layer4 self.b1_4_1 = DoubleBasicBlock(l4c, l4c) self.b1_4_2_down = down_conv(l4c, l5c) self.b1_4_2 = DoubleBasicBlock(l4c, l4c) self.b1_4_3_down = down_conv(l4c, l4c) self.b1_4_3_same = same_conv(l4c, l4c) self.b1_4_4_transition = transition_conv(l4c+l5c+l4c, l4c) self.b1_4_5 = DoubleBasicBlock(l4c, l4c) self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM))) self.b1_4_7_up = up_conv(l4c, l3c) # branch1_layer5 self.b1_5_1 = DoubleBasicBlock(l5c, l5c) self.b1_5_2_up = up_conv(l5c, l5c) self.b1_5_2_same = same_conv(l5c, l5c) self.b1_5_3_transition = transition_conv(l5c+l5c+l4c, l5c) self.b1_5_4 = DoubleBasicBlock(l5c, l5c) self.b1_5_5_up = up_conv(l5c, l4c) # branch2 # branch2_layer1 self.b2_1_1 = nn.Sequential( conv3x3(1, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c) ) self.b2_1_2_down = down_conv(l1c, l2c) self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM))) self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0) # branch2_layer2 self.b2_2_1 = DoubleBasicBlock(l2c, l2c) self.b2_2_2_down = down_conv(l2c, l3c) self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM))) self.b2_2_4_up = up_conv(l2c, l1c) # branch2_layer3 self.b2_3_1 = DoubleBasicBlock(l3c, l3c) self.b2_3_2_down = down_conv(l3c, l4c) self.b2_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, 
momentum=BN_MOMENTUM))) self.b2_3_4_up = up_conv(l3c, l2c) # branch2_layer4 self.b2_4_1 = DoubleBasicBlock(l4c, l4c) self.b2_4_2_down = down_conv(l4c, l5c) self.b2_4_2 = DoubleBasicBlock(l4c, l4c) self.b2_4_3_down = down_conv(l4c, l4c) self.b2_4_3_same = same_conv(l4c, l4c) self.b2_4_4_transition = transition_conv(l4c+l5c+l4c, l4c) self.b2_4_5 = DoubleBasicBlock(l4c, l4c) self.b2_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM))) self.b2_4_7_up = up_conv(l4c, l3c) # branch2_layer5 self.b2_5_1 = DoubleBasicBlock(l5c, l5c) self.b2_5_2_up = up_conv(l5c, l5c) self.b2_5_2_same = same_conv(l5c, l5c) self.b2_5_3_transition = transition_conv(l5c+l5c+l4c, l5c) self.b2_5_4 = DoubleBasicBlock(l5c, l5c) self.b2_5_5_up = up_conv(l5c, l4c) # initialization for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # elif isinstance(m, InPlaceABNSync): # nn.init.constant_(m.weight, 1) # nn.init.constant_(m.bias, 0) # elif isinstance(m, InPlaceABN): # nn.init.constant_(m.weight, 1) # nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=0.001) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, input1, input2): # encode # branch1 x1_1 = self.b1_1_1(input1) x1_2 = self.b1_1_2_down(x1_1) x1_2 = self.b1_2_1(x1_2) x1_3 = self.b1_2_2_down(x1_2) x1_3 = self.b1_3_1(x1_3) x1_4_1 = self.b1_3_2_down(x1_3) x1_4_1 = self.b1_4_1(x1_4_1) x1_4_2 = self.b1_4_2(x1_4_1) x1_4_3_down = self.b1_4_3_down(x1_4_2) x1_4_3_same = self.b1_4_3_same(x1_4_2) x1_5_1 = self.b1_4_2_down(x1_4_1) x1_5_1 = self.b1_5_1(x1_5_1) x1_5_2_up = self.b1_5_2_up(x1_5_1) x1_5_2_same = self.b1_5_2_same(x1_5_1) # branch2 x2_1 = self.b2_1_1(input2) x2_2 = self.b2_1_2_down(x2_1) x2_2 =
self.b2_2_1(x2_2) x2_3 = self.b2_2_2_down(x2_2) x2_3 = self.b2_3_1(x2_3) x2_4_1 = self.b2_3_2_down(x2_3) x2_4_1 = self.b2_4_1(x2_4_1) x2_4_2 = self.b2_4_2(x2_4_1) x2_4_3_down = self.b2_4_3_down(x2_4_2) x2_4_3_same = self.b2_4_3_same(x2_4_2) x2_5_1 = self.b2_4_2_down(x2_4_1) x2_5_1 = self.b2_5_1(x2_5_1) x2_5_2_up = self.b2_5_2_up(x2_5_1) x2_5_2_same = self.b2_5_2_same(x2_5_1) # merge # branch1 x1_5_3 = torch.cat((x1_5_2_same, x2_5_2_same, x2_4_3_down), dim=1) x1_5_3 = self.b1_5_3_transition(x1_5_3) x1_5_3 = self.b1_5_4(x1_5_3) x1_5_3 = self.b1_5_5_up(x1_5_3) x1_4_4 = torch.cat((x1_4_3_same, x2_4_3_same, x2_5_2_up), dim=1) x1_4_4 = self.b1_4_4_transition(x1_4_4) x1_4_4 = self.b1_4_5(x1_4_4) x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1) x1_4_4 = self.b1_4_6(x1_4_4) x1_4_4 = self.b1_4_7_up(x1_4_4) # branch2 x2_5_3 = torch.cat((x2_5_2_same, x1_5_2_same, x1_4_3_down), dim=1) x2_5_3 = self.b2_5_3_transition(x2_5_3) x2_5_3 = self.b2_5_4(x2_5_3) x2_5_3 = self.b2_5_5_up(x2_5_3) x2_4_4 = torch.cat((x2_4_3_same, x1_4_3_same, x1_5_2_up), dim=1) x2_4_4 = self.b2_4_4_transition(x2_4_4) x2_4_4 = self.b2_4_5(x2_4_4) x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1) x2_4_4 = self.b2_4_6(x2_4_4) x2_4_4 = self.b2_4_7_up(x2_4_4) # decode # branch1 x1_3 = torch.cat((x1_3, x1_4_4), dim=1) x1_3 = self.b1_3_3(x1_3) x1_3 = self.b1_3_4_up(x1_3) x1_2 = torch.cat((x1_2, x1_3), dim=1) x1_2 = self.b1_2_3(x1_2) x1_2 = self.b1_2_4_up(x1_2) x1_1 = torch.cat((x1_1, x1_2), dim=1) x1_1 = self.b1_1_3(x1_1) x1_1 = self.b1_1_4(x1_1) # branch2 x2_3 = torch.cat((x2_3, x2_4_4), dim=1) x2_3 = self.b2_3_3(x2_3) x2_3 = self.b2_3_4_up(x2_3) x2_2 = torch.cat((x2_2, x2_3), dim=1) x2_2 = self.b2_2_3(x2_2) x2_2 = self.b2_2_4_up(x2_2) x2_1 = torch.cat((x2_1, x2_2), dim=1) x2_1 = self.b2_1_3(x2_1) x2_1 = self.b2_1_4(x2_1) return x1_1, x2_1 class XNet_1_1_m(nn.Module): def __init__(self, in_channels, num_classes): super(XNet_1_1_m, self).__init__() l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024 # branch1 # branch1_layer1 
self.b1_1_1 = nn.Sequential( conv3x3(in_channels, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c) ) self.b1_1_2_down = down_conv(l1c, l2c) self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM))) self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0) # branch1_layer2 self.b1_2_1 = DoubleBasicBlock(l2c, l2c) self.b1_2_2_down = down_conv(l2c, l3c) self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM))) self.b1_2_4_up = up_conv(l2c, l1c) # branch1_layer3 self.b1_3_1 = DoubleBasicBlock(l3c, l3c) self.b1_3_2_down = down_conv(l3c, l4c) self.b1_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM))) self.b1_3_4_up = up_conv(l3c, l2c) # branch1_layer4 self.b1_4_1 = DoubleBasicBlock(l4c, l4c) self.b1_4_2_down = down_conv(l4c, l5c) self.b1_4_3 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM))) self.b1_4_4_up = up_conv(l4c, l3c) # branch1_layer5 self.b1_5_1 = DoubleBasicBlock(l5c, l5c) self.b1_5_2_same = same_conv(l5c, l5c) self.b1_5_3_transition = transition_conv(l5c+l5c, l5c) self.b1_5_4 = DoubleBasicBlock(l5c, l5c) self.b1_5_5_up = up_conv(l5c, l4c) # branch2 # branch2_layer1 self.b2_1_1 = nn.Sequential( conv3x3(1, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c) ) self.b2_1_2_down = down_conv(l1c, l2c) self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM))) self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0) # branch2_layer2 self.b2_2_1 = DoubleBasicBlock(l2c, l2c) self.b2_2_2_down = down_conv(l2c, l3c) self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), 
BatchNorm2d(l2c, momentum=BN_MOMENTUM))) self.b2_2_4_up = up_conv(l2c, l1c) # branch2_layer3 self.b2_3_1 = DoubleBasicBlock(l3c, l3c) self.b2_3_2_down = down_conv(l3c, l4c) self.b2_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM))) self.b2_3_4_up = up_conv(l3c, l2c) # branch2_layer4 self.b2_4_1 = DoubleBasicBlock(l4c, l4c) self.b2_4_2_down = down_conv(l4c, l5c) self.b2_4_3 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM))) self.b2_4_4_up = up_conv(l4c, l3c) # branch2_layer5 self.b2_5_1 = DoubleBasicBlock(l5c, l5c) self.b2_5_2_same = same_conv(l5c, l5c) self.b2_5_3_transition = transition_conv(l5c+l5c, l5c) self.b2_5_4 = DoubleBasicBlock(l5c, l5c) self.b2_5_5_up = up_conv(l5c, l4c) # initialization for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # elif isinstance(m, InPlaceABNSync): # nn.init.constant_(m.weight, 1) # nn.init.constant_(m.bias, 0) # elif isinstance(m, InPlaceABN): # nn.init.constant_(m.weight, 1) # nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=0.001) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, input1, input2): # encode # branch1 x1_1 = self.b1_1_1(input1) x1_2 = self.b1_1_2_down(x1_1) x1_2 = self.b1_2_1(x1_2) x1_3 = self.b1_2_2_down(x1_2) x1_3 = self.b1_3_1(x1_3) x1_4 = self.b1_3_2_down(x1_3) x1_4 = self.b1_4_1(x1_4) x1_5_1 = self.b1_4_2_down(x1_4) x1_5_1 = self.b1_5_1(x1_5_1) x1_5_2_same = self.b1_5_2_same(x1_5_1) # branch2 x2_1 = self.b2_1_1(input2) x2_2 = self.b2_1_2_down(x2_1) x2_2 = self.b2_2_1(x2_2) x2_3 = self.b2_2_2_down(x2_2) x2_3 = self.b2_3_1(x2_3) x2_4 = self.b2_3_2_down(x2_3) x2_4 =
self.b2_4_1(x2_4) x2_5_1 = self.b2_4_2_down(x2_4) x2_5_1 = self.b2_5_1(x2_5_1) x2_5_2_same = self.b2_5_2_same(x2_5_1) # merge # branch1 x1_5_3 = torch.cat((x1_5_2_same, x2_5_2_same), dim=1) x1_5_3 = self.b1_5_3_transition(x1_5_3) x1_5_3 = self.b1_5_4(x1_5_3) x1_5_3 = self.b1_5_5_up(x1_5_3) # branch2 x2_5_3 = torch.cat((x2_5_2_same, x1_5_2_same), dim=1) x2_5_3 = self.b2_5_3_transition(x2_5_3) x2_5_3 = self.b2_5_4(x2_5_3) x2_5_3 = self.b2_5_5_up(x2_5_3) # decode # branch1 x1_4 = torch.cat((x1_4, x1_5_3), dim=1) x1_4 = self.b1_4_3(x1_4) x1_4 = self.b1_4_4_up(x1_4) x1_3 = torch.cat((x1_3, x1_4), dim=1) x1_3 = self.b1_3_3(x1_3) x1_3 = self.b1_3_4_up(x1_3) x1_2 = torch.cat((x1_2, x1_3), dim=1) x1_2 = self.b1_2_3(x1_2) x1_2 = self.b1_2_4_up(x1_2) x1_1 = torch.cat((x1_1, x1_2), dim=1) x1_1 = self.b1_1_3(x1_1) x1_1 = self.b1_1_4(x1_1) # branch2 x2_4 = torch.cat((x2_4, x2_5_3), dim=1) x2_4 = self.b2_4_3(x2_4) x2_4 = self.b2_4_4_up(x2_4) x2_3 = torch.cat((x2_3, x2_4), dim=1) x2_3 = self.b2_3_3(x2_3) x2_3 = self.b2_3_4_up(x2_3) x2_2 = torch.cat((x2_2, x2_3), dim=1) x2_2 = self.b2_2_3(x2_2) x2_2 = self.b2_2_4_up(x2_2) x2_1 = torch.cat((x2_1, x2_2), dim=1) x2_1 = self.b2_1_3(x2_1) x2_1 = self.b2_1_4(x2_1) return x1_1, x2_1 class XNet_1_2_m(nn.Module): def __init__(self, in_channels, num_classes): super(XNet_1_2_m, self).__init__() l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024 # branch1 # branch1_layer1 self.b1_1_1 = nn.Sequential( conv3x3(in_channels, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c) ) self.b1_1_2_down = down_conv(l1c, l2c) self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM))) self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0) # branch1_layer2 self.b1_2_1 = DoubleBasicBlock(l2c, l2c) self.b1_2_2_down = down_conv(l2c, l3c) self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, 
momentum=BN_MOMENTUM))) self.b1_2_4_up = up_conv(l2c, l1c) # branch1_layer3 self.b1_3_1 = DoubleBasicBlock(l3c, l3c) self.b1_3_2_down = down_conv(l3c, l4c) self.b1_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM))) self.b1_3_4_up = up_conv(l3c, l2c) # branch1_layer4 self.b1_4_1 = DoubleBasicBlock(l4c, l4c) self.b1_4_2_down = down_conv(l4c, l5c) self.b1_4_3 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM))) self.b1_4_4_up = up_conv(l4c, l3c) # branch1_layer5 self.b1_5_1 = DoubleBasicBlock(l5c, l5c) self.b1_5_2_up = up_conv(l5c, l5c) self.b1_5_2_same = same_conv(l5c, l5c) self.b1_5_3_transition = transition_conv(l5c+l5c+l4c, l5c) self.b1_5_4 = DoubleBasicBlock(l5c, l5c) self.b1_5_5_up = up_conv(l5c, l4c) # branch2 # branch2_layer1 self.b2_1_1 = nn.Sequential( conv3x3(1, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c) ) self.b2_1_2_down = down_conv(l1c, l2c) self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM))) self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0) # branch2_layer2 self.b2_2_1 = DoubleBasicBlock(l2c, l2c) self.b2_2_2_down = down_conv(l2c, l3c) self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM))) self.b2_2_4_up = up_conv(l2c, l1c) # branch2_layer3 self.b2_3_1 = DoubleBasicBlock(l3c, l3c) self.b2_3_2_down = down_conv(l3c, l4c) self.b2_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM))) self.b2_3_4_up = up_conv(l3c, l2c) # branch2_layer4 self.b2_4_1 = DoubleBasicBlock(l4c, l4c) self.b2_4_2_down = down_conv(l4c, l5c) self.b2_4_2 = DoubleBasicBlock(l4c, l4c) self.b2_4_3_down = down_conv(l4c, l4c) self.b2_4_3_same 
= same_conv(l4c, l4c) self.b2_4_4_transition = transition_conv(l4c+l5c, l4c) self.b2_4_5 = DoubleBasicBlock(l4c, l4c) self.b2_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM))) self.b2_4_7_up = up_conv(l4c, l3c) # branch2_layer5 self.b2_5_1 = DoubleBasicBlock(l5c, l5c) self.b2_5_2_same = same_conv(l5c, l5c) self.b2_5_3_transition = transition_conv(l5c+l5c, l5c) self.b2_5_4 = DoubleBasicBlock(l5c, l5c) self.b2_5_5_up = up_conv(l5c, l4c) # initialization for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # elif isinstance(m, InPlaceABNSync): # nn.init.constant_(m.weight, 1) # nn.init.constant_(m.bias, 0) # elif isinstance(m, InPlaceABN): # nn.init.constant_(m.weight, 1) # nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=0.001) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, input1, input2): # code # branch1 x1_1 = self.b1_1_1(input1) x1_2 = self.b1_1_2_down(x1_1) x1_2 = self.b1_2_1(x1_2) x1_3 = self.b1_2_2_down(x1_2) x1_3 = self.b1_3_1(x1_3) x1_4 = self.b1_3_2_down(x1_3) x1_4 = self.b1_4_1(x1_4) x1_5_1 = self.b1_4_2_down(x1_4) x1_5_1 = self.b1_5_1(x1_5_1) x1_5_2_up = self.b1_5_2_up(x1_5_1) x1_5_2_same = self.b1_5_2_same(x1_5_1) # branch2 x2_1 = self.b2_1_1(input2) x2_2 = self.b2_1_2_down(x2_1) x2_2 = self.b2_2_1(x2_2) x2_3 = self.b2_2_2_down(x2_2) x2_3 = self.b2_3_1(x2_3) x2_4_1 = self.b2_3_2_down(x2_3) x2_4_1 = self.b2_4_1(x2_4_1) x2_4_2 = self.b2_4_2(x2_4_1) x2_4_3_down = self.b2_4_3_down(x2_4_2) x2_4_3_same = self.b2_4_3_same(x2_4_2) x2_5_1 = self.b2_4_2_down(x2_4_1) x2_5_1 = self.b2_5_1(x2_5_1) x2_5_2_same = self.b2_5_2_same(x2_5_1) # merge # branch1 x1_5_3 = torch.cat((x1_5_2_same, x2_5_2_same, 
x2_4_3_down), dim=1) x1_5_3 = self.b1_5_3_transition(x1_5_3) x1_5_3 = self.b1_5_4(x1_5_3) x1_5_3 = self.b1_5_5_up(x1_5_3) # branch2 x2_5_3 = torch.cat((x2_5_2_same, x1_5_2_same), dim=1) x2_5_3 = self.b2_5_3_transition(x2_5_3) x2_5_3 = self.b2_5_4(x2_5_3) x2_5_3 = self.b2_5_5_up(x2_5_3) x2_4_4 = torch.cat((x2_4_3_same, x1_5_2_up), dim=1) x2_4_4 = self.b2_4_4_transition(x2_4_4) x2_4_4 = self.b2_4_5(x2_4_4) x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1) x2_4_4 = self.b2_4_6(x2_4_4) x2_4_4 = self.b2_4_7_up(x2_4_4) # decode # branch1 x1_4 = torch.cat((x1_4, x1_5_3), dim=1) x1_4 = self.b1_4_3(x1_4) x1_4 = self.b1_4_4_up(x1_4) x1_3 = torch.cat((x1_3, x1_4), dim=1) x1_3 = self.b1_3_3(x1_3) x1_3 = self.b1_3_4_up(x1_3) x1_2 = torch.cat((x1_2, x1_3), dim=1) x1_2 = self.b1_2_3(x1_2) x1_2 = self.b1_2_4_up(x1_2) x1_1 = torch.cat((x1_1, x1_2), dim=1) x1_1 = self.b1_1_3(x1_1) x1_1 = self.b1_1_4(x1_1) # branch2 x2_3 = torch.cat((x2_3, x2_4_4), dim=1) x2_3 = self.b2_3_3(x2_3) x2_3 = self.b2_3_4_up(x2_3) x2_2 = torch.cat((x2_2, x2_3), dim=1) x2_2 = self.b2_2_3(x2_2) x2_2 = self.b2_2_4_up(x2_2) x2_1 = torch.cat((x2_1, x2_2), dim=1) x2_1 = self.b2_1_3(x2_1) x2_1 = self.b2_1_4(x2_1) return x1_1, x2_1 class XNet_2_1_m(nn.Module): def __init__(self, in_channels, num_classes): super(XNet_2_1_m, self).__init__() l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024 # branch1 # branch1_layer1 self.b1_1_1 = nn.Sequential( conv3x3(in_channels, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c) ) self.b1_1_2_down = down_conv(l1c, l2c) self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM))) self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0) # branch1_layer2 self.b1_2_1 = DoubleBasicBlock(l2c, l2c) self.b1_2_2_down = down_conv(l2c, l3c) self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM))) self.b1_2_4_up = 
up_conv(l2c, l1c) # branch1_layer3 self.b1_3_1 = DoubleBasicBlock(l3c, l3c) self.b1_3_2_down = down_conv(l3c, l4c) self.b1_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM))) self.b1_3_4_up = up_conv(l3c, l2c) # branch1_layer4 self.b1_4_1 = DoubleBasicBlock(l4c, l4c) self.b1_4_2_down = down_conv(l4c, l5c) self.b1_4_2 = DoubleBasicBlock(l4c, l4c) self.b1_4_3_down = down_conv(l4c, l4c) self.b1_4_3_same = same_conv(l4c, l4c) self.b1_4_4_transition = transition_conv(l4c+l5c, l4c) self.b1_4_5 = DoubleBasicBlock(l4c, l4c) self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM))) self.b1_4_7_up = up_conv(l4c, l3c) # branch1_layer5 self.b1_5_1 = DoubleBasicBlock(l5c, l5c) self.b1_5_2_same = same_conv(l5c, l5c) self.b1_5_3_transition = transition_conv(l5c+l5c, l5c) self.b1_5_4 = DoubleBasicBlock(l5c, l5c) self.b1_5_5_up = up_conv(l5c, l4c) # branch2 # branch2_layer1 self.b2_1_1 = nn.Sequential( conv3x3(1, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c) ) self.b2_1_2_down = down_conv(l1c, l2c) self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM))) self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0) # branch2_layer2 self.b2_2_1 = DoubleBasicBlock(l2c, l2c) self.b2_2_2_down = down_conv(l2c, l3c) self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM))) self.b2_2_4_up = up_conv(l2c, l1c) # branch2_layer3 self.b2_3_1 = DoubleBasicBlock(l3c, l3c) self.b2_3_2_down = down_conv(l3c, l4c) self.b2_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM))) self.b2_3_4_up = up_conv(l3c, l2c) # branch2_layer4 self.b2_4_1 = DoubleBasicBlock(l4c, l4c) 
self.b2_4_2_down = down_conv(l4c, l5c) self.b2_4_3 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM))) self.b2_4_4_up = up_conv(l4c, l3c) # branch2_layer5 self.b2_5_1 = DoubleBasicBlock(l5c, l5c) self.b2_5_2_up = up_conv(l5c, l5c) self.b2_5_2_same = same_conv(l5c, l5c) self.b2_5_3_transition = transition_conv(l5c+l5c+l4c, l5c) self.b2_5_4 = DoubleBasicBlock(l5c, l5c) self.b2_5_5_up = up_conv(l5c, l4c) # initialization for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # elif isinstance(m, InPlaceABNSync): # nn.init.constant_(m.weight, 1) # nn.init.constant_(m.bias, 0) # elif isinstance(m, InPlaceABN): # nn.init.constant_(m.weight, 1) # nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=0.001) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, input1, input2): # code # branch1 x1_1 = self.b1_1_1(input1) x1_2 = self.b1_1_2_down(x1_1) x1_2 = self.b1_2_1(x1_2) x1_3 = self.b1_2_2_down(x1_2) x1_3 = self.b1_3_1(x1_3) x1_4_1 = self.b1_3_2_down(x1_3) x1_4_1 = self.b1_4_1(x1_4_1) x1_4_2 = self.b1_4_2(x1_4_1) x1_4_3_down = self.b1_4_3_down(x1_4_2) x1_4_3_same = self.b1_4_3_same(x1_4_2) x1_5_1 = self.b1_4_2_down(x1_4_1) x1_5_1 = self.b1_5_1(x1_5_1) x1_5_2_same = self.b1_5_2_same(x1_5_1) # branch2 x2_1 = self.b2_1_1(input2) x2_2 = self.b2_1_2_down(x2_1) x2_2 = self.b2_2_1(x2_2) x2_3 = self.b2_2_2_down(x2_2) x2_3 = self.b2_3_1(x2_3) x2_4 = self.b2_3_2_down(x2_3) x2_4 = self.b2_4_1(x2_4) x2_5_1 = self.b2_4_2_down(x2_4) x2_5_1 = self.b2_5_1(x2_5_1) x2_5_2_up = self.b2_5_2_up(x2_5_1) x2_5_2_same = self.b2_5_2_same(x2_5_1) # merge # branch1 x1_5_3 = torch.cat((x1_5_2_same, x2_5_2_same), dim=1) x1_5_3 = 
self.b1_5_3_transition(x1_5_3) x1_5_3 = self.b1_5_4(x1_5_3) x1_5_3 = self.b1_5_5_up(x1_5_3) x1_4_4 = torch.cat((x1_4_3_same, x2_5_2_up), dim=1) x1_4_4 = self.b1_4_4_transition(x1_4_4) x1_4_4 = self.b1_4_5(x1_4_4) x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1) x1_4_4 = self.b1_4_6(x1_4_4) x1_4_4 = self.b1_4_7_up(x1_4_4) # branch2 x2_5_3 = torch.cat((x2_5_2_same, x1_5_2_same, x1_4_3_down), dim=1) x2_5_3 = self.b2_5_3_transition(x2_5_3) x2_5_3 = self.b2_5_4(x2_5_3) x2_5_3 = self.b2_5_5_up(x2_5_3) # decode # branch1 x1_3 = torch.cat((x1_3, x1_4_4), dim=1) x1_3 = self.b1_3_3(x1_3) x1_3 = self.b1_3_4_up(x1_3) x1_2 = torch.cat((x1_2, x1_3), dim=1) x1_2 = self.b1_2_3(x1_2) x1_2 = self.b1_2_4_up(x1_2) x1_1 = torch.cat((x1_1, x1_2), dim=1) x1_1 = self.b1_1_3(x1_1) x1_1 = self.b1_1_4(x1_1) # branch2 x2_4 = torch.cat((x2_4, x2_5_3), dim=1) x2_4 = self.b2_4_3(x2_4) x2_4 = self.b2_4_4_up(x2_4) x2_3 = torch.cat((x2_3, x2_4), dim=1) x2_3 = self.b2_3_3(x2_3) x2_3 = self.b2_3_4_up(x2_3) x2_2 = torch.cat((x2_2, x2_3), dim=1) x2_2 = self.b2_2_3(x2_2) x2_2 = self.b2_2_4_up(x2_2) x2_1 = torch.cat((x2_1, x2_2), dim=1) x2_1 = self.b2_1_3(x2_1) x2_1 = self.b2_1_4(x2_1) return x1_1, x2_1 class XNet_2_3_m(nn.Module): def __init__(self, in_channels, num_classes): super(XNet_2_3_m, self).__init__() l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024 # branch1 # branch1_layer1 self.b1_1_1 = nn.Sequential( conv3x3(in_channels, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c) ) self.b1_1_2_down = down_conv(l1c, l2c) self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM))) self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0) # branch1_layer2 self.b1_2_1 = DoubleBasicBlock(l2c, l2c) self.b1_2_2_down = down_conv(l2c, l3c) self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM))) self.b1_2_4_up = up_conv(l2c, l1c) 
# branch1_layer3 self.b1_3_1 = DoubleBasicBlock(l3c, l3c) self.b1_3_2_down = down_conv(l3c, l4c) self.b1_3_3 = DoubleBasicBlock(l3c + l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c + l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM))) self.b1_3_4_up = up_conv(l3c, l2c) # branch1_layer4 self.b1_4_1 = DoubleBasicBlock(l4c, l4c) self.b1_4_2_down = down_conv(l4c, l5c) self.b1_4_2 = DoubleBasicBlock(l4c, l4c) self.b1_4_3_down = down_conv(l4c, l4c) self.b1_4_3_same = same_conv(l4c, l4c) self.b1_4_3_up = up_conv(l4c, l4c) self.b1_4_4_transition = transition_conv(l4c+l5c+l4c+l3c, l4c) self.b1_4_5 = DoubleBasicBlock(l4c, l4c) self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM))) self.b1_4_7_up = up_conv(l4c, l3c) # branch1_layer5 self.b1_5_1 = DoubleBasicBlock(l5c, l5c) self.b1_5_2_up = up_conv(l5c, l5c) self.b1_5_2_up_up = up_conv(l5c, l5c) self.b1_5_2_same = same_conv(l5c, l5c) self.b1_5_3_transition = transition_conv(l5c+l5c+l4c+l3c, l5c) self.b1_5_4 = DoubleBasicBlock(l5c, l5c) self.b1_5_5_up = up_conv(l5c, l4c) # branch2 # branch2_layer1 self.b2_1_1 = nn.Sequential( conv3x3(1, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c) ) self.b2_1_2_down = down_conv(l1c, l2c) self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM))) self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0) # branch2_layer2 self.b2_2_1 = DoubleBasicBlock(l2c, l2c) self.b2_2_2_down = down_conv(l2c, l3c) self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM))) self.b2_2_4_up = up_conv(l2c, l1c) # branch2_layer3 self.b2_3_1 = DoubleBasicBlock(l3c, l3c) self.b2_3_2_down = down_conv(l3c, l4c) self.b2_3_2 = DoubleBasicBlock(l3c, l3c) self.b2_3_3_down = down_conv(l3c, l3c) self.b2_3_3_down_down = down_conv(l3c, l3c) 
self.b2_3_3_same = same_conv(l3c, l3c) self.b2_3_4_transition = transition_conv(l3c+l5c+l4c, l3c) self.b2_3_5 = DoubleBasicBlock(l3c, l3c) self.b2_3_6 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM))) self.b2_3_7_up = up_conv(l3c, l2c) # branch2_layer4 self.b2_4_1 = DoubleBasicBlock(l4c, l4c) self.b2_4_2_down = down_conv(l4c, l5c) self.b2_4_2 = DoubleBasicBlock(l4c, l4c) self.b2_4_3_down = down_conv(l4c, l4c) self.b2_4_3_same = same_conv(l4c, l4c) self.b2_4_4_transition = transition_conv(l4c+l5c+l4c, l4c) self.b2_4_5 = DoubleBasicBlock(l4c, l4c) self.b2_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM))) self.b2_4_7_up = up_conv(l4c, l3c) # branch2_layer5 self.b2_5_1 = DoubleBasicBlock(l5c, l5c) self.b2_5_2_up = up_conv(l5c, l5c) self.b2_5_2_same = same_conv(l5c, l5c) self.b2_5_3_transition = transition_conv(l5c+l5c+l4c, l5c) self.b2_5_4 = DoubleBasicBlock(l5c, l5c) self.b2_5_5_up = up_conv(l5c, l4c) # initialization for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # elif isinstance(m, InPlaceABNSync): # nn.init.constant_(m.weight, 1) # nn.init.constant_(m.bias, 0) # elif isinstance(m, InPlaceABN): # nn.init.constant_(m.weight, 1) # nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=0.001) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, input1, input2): # code # branch1 x1_1 = self.b1_1_1(input1) x1_2 = self.b1_1_2_down(x1_1) x1_2 = self.b1_2_1(x1_2) x1_3 = self.b1_2_2_down(x1_2) x1_3 = self.b1_3_1(x1_3) x1_4_1 = self.b1_3_2_down(x1_3) x1_4_1 = self.b1_4_1(x1_4_1) x1_4_2 = self.b1_4_2(x1_4_1) x1_4_3_down = 
self.b1_4_3_down(x1_4_2) x1_4_3_same = self.b1_4_3_same(x1_4_2) x1_4_3_up = self.b1_4_3_up(x1_4_2) x1_5_1 = self.b1_4_2_down(x1_4_1) x1_5_1 = self.b1_5_1(x1_5_1) x1_5_2_up = self.b1_5_2_up(x1_5_1) x1_5_2_up_up = self.b1_5_2_up_up(x1_5_2_up) x1_5_2_same = self.b1_5_2_same(x1_5_1) # branch2 x2_1 = self.b2_1_1(input2) x2_2 = self.b2_1_2_down(x2_1) x2_2 = self.b2_2_1(x2_2) x2_3_1 = self.b2_2_2_down(x2_2) x2_3_1 = self.b2_3_1(x2_3_1) x2_3_2 = self.b2_3_2(x2_3_1) x2_3_3_down = self.b2_3_3_down(x2_3_2) x2_3_3_down_down = self.b2_3_3_down_down(x2_3_3_down) x2_3_3_same = self.b2_3_3_same(x2_3_2) x2_4_1 = self.b2_3_2_down(x2_3_1) x2_4_1 = self.b2_4_1(x2_4_1) x2_4_2 = self.b2_4_2(x2_4_1) x2_4_3_down = self.b2_4_3_down(x2_4_2) x2_4_3_same = self.b2_4_3_same(x2_4_2) x2_5_1 = self.b2_4_2_down(x2_4_1) x2_5_1 = self.b2_5_1(x2_5_1) x2_5_2_up = self.b2_5_2_up(x2_5_1) x2_5_2_same = self.b2_5_2_same(x2_5_1) # merge # branch1 x1_5_3 = torch.cat((x1_5_2_same, x2_3_3_down_down, x2_4_3_down, x2_5_2_same), dim=1) x1_5_3 = self.b1_5_3_transition(x1_5_3) x1_5_3 = self.b1_5_4(x1_5_3) x1_5_3 = self.b1_5_5_up(x1_5_3) x1_4_4 = torch.cat((x1_4_3_same, x2_3_3_down, x2_4_3_same, x2_5_2_up), dim=1) x1_4_4 = self.b1_4_4_transition(x1_4_4) x1_4_4 = self.b1_4_5(x1_4_4) x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1) x1_4_4 = self.b1_4_6(x1_4_4) x1_4_4 = self.b1_4_7_up(x1_4_4) # branch2 x2_5_3 = torch.cat((x2_5_2_same, x1_4_3_down, x1_5_2_same), dim=1) x2_5_3 = self.b2_5_3_transition(x2_5_3) x2_5_3 = self.b2_5_4(x2_5_3) x2_5_3 = self.b2_5_5_up(x2_5_3) x2_4_4 = torch.cat((x2_4_3_same, x1_4_3_same, x1_5_2_up), dim=1) x2_4_4 = self.b2_4_4_transition(x2_4_4) x2_4_4 = self.b2_4_5(x2_4_4) x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1) x2_4_4 = self.b2_4_6(x2_4_4) x2_4_4 = self.b2_4_7_up(x2_4_4) x2_3_4 = torch.cat((x2_3_3_same, x1_4_3_up, x1_5_2_up_up), dim=1) x2_3_4 = self.b2_3_4_transition(x2_3_4) x2_3_4 = self.b2_3_5(x2_3_4) x2_3_4 = torch.cat((x2_3_4, x2_4_4), dim=1) x2_3_4 = self.b2_3_6(x2_3_4) x2_3_4 = 
self.b2_3_7_up(x2_3_4) # decode # branch1 x1_3 = torch.cat((x1_3, x1_4_4), dim=1) x1_3 = self.b1_3_3(x1_3) x1_3 = self.b1_3_4_up(x1_3) x1_2 = torch.cat((x1_2, x1_3), dim=1) x1_2 = self.b1_2_3(x1_2) x1_2 = self.b1_2_4_up(x1_2) x1_1 = torch.cat((x1_1, x1_2), dim=1) x1_1 = self.b1_1_3(x1_1) x1_1 = self.b1_1_4(x1_1) # branch2 x2_2 = torch.cat((x2_2, x2_3_4), dim=1) x2_2 = self.b2_2_3(x2_2) x2_2 = self.b2_2_4_up(x2_2) x2_1 = torch.cat((x2_1, x2_2), dim=1) x2_1 = self.b2_1_3(x2_1) x2_1 = self.b2_1_4(x2_1) return x1_1, x2_1 class XNet_3_2_m(nn.Module): def __init__(self, in_channels, num_classes): super(XNet_3_2_m, self).__init__() l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024 # branch1 # branch1_layer1 self.b1_1_1 = nn.Sequential( conv3x3(in_channels, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c) ) self.b1_1_2_down = down_conv(l1c, l2c) self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM))) self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0) # branch1_layer2 self.b1_2_1 = DoubleBasicBlock(l2c, l2c) self.b1_2_2_down = down_conv(l2c, l3c) self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM))) self.b1_2_4_up = up_conv(l2c, l1c) # branch1_layer3 self.b1_3_1 = DoubleBasicBlock(l3c, l3c) self.b1_3_2_down = down_conv(l3c, l4c) self.b1_3_2 = DoubleBasicBlock(l3c, l3c) self.b1_3_3_down = down_conv(l3c, l3c) self.b1_3_3_down_down = down_conv(l3c, l3c) self.b1_3_3_same = same_conv(l3c, l3c) self.b1_3_4_transition = transition_conv(l3c+l5c+l4c, l3c) self.b1_3_5 = DoubleBasicBlock(l3c, l3c) self.b1_3_6 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM))) self.b1_3_7_up = up_conv(l3c, l2c) # branch1_layer4 self.b1_4_1 = DoubleBasicBlock(l4c, l4c) self.b1_4_2_down = down_conv(l4c, l5c) self.b1_4_2 = 
DoubleBasicBlock(l4c, l4c) self.b1_4_3_down = down_conv(l4c, l4c) self.b1_4_3_same = same_conv(l4c, l4c) self.b1_4_4_transition = transition_conv(l4c+l5c+l4c, l4c) self.b1_4_5 = DoubleBasicBlock(l4c, l4c) self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM))) self.b1_4_7_up = up_conv(l4c, l3c) # branch1_layer5 self.b1_5_1 = DoubleBasicBlock(l5c, l5c) self.b1_5_2_up = up_conv(l5c, l5c) self.b1_5_2_same = same_conv(l5c, l5c) self.b1_5_3_transition = transition_conv(l5c+l5c+l4c, l5c) self.b1_5_4 = DoubleBasicBlock(l5c, l5c) self.b1_5_5_up = up_conv(l5c, l4c) # branch2 # branch2_layer1 self.b2_1_1 = nn.Sequential( conv3x3(1, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c) ) self.b2_1_2_down = down_conv(l1c, l2c) self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM))) self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0) # branch2_layer2 self.b2_2_1 = DoubleBasicBlock(l2c, l2c) self.b2_2_2_down = down_conv(l2c, l3c) self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM))) self.b2_2_4_up = up_conv(l2c, l1c) # branch2_layer3 self.b2_3_1 = DoubleBasicBlock(l3c, l3c) self.b2_3_2_down = down_conv(l3c, l4c) self.b2_3_3 = DoubleBasicBlock(l3c + l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c + l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM))) self.b2_3_4_up = up_conv(l3c, l2c) # branch2_layer4 self.b2_4_1 = DoubleBasicBlock(l4c, l4c) self.b2_4_2_down = down_conv(l4c, l5c) self.b2_4_2 = DoubleBasicBlock(l4c, l4c) self.b2_4_3_down = down_conv(l4c, l4c) self.b2_4_3_same = same_conv(l4c, l4c) self.b2_4_3_up = up_conv(l4c, l4c) self.b2_4_4_transition = transition_conv(l4c+l5c+l4c+l3c, l4c) self.b2_4_5 = DoubleBasicBlock(l4c, l4c) self.b2_4_6 = DoubleBasicBlock(l4c+l4c, l4c, 
nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM))) self.b2_4_7_up = up_conv(l4c, l3c) # branch2_layer5 self.b2_5_1 = DoubleBasicBlock(l5c, l5c) self.b2_5_2_up = up_conv(l5c, l5c) self.b2_5_2_up_up = up_conv(l5c, l5c) self.b2_5_2_same = same_conv(l5c, l5c) self.b2_5_3_transition = transition_conv(l5c+l5c+l4c+l3c, l5c) self.b2_5_4 = DoubleBasicBlock(l5c, l5c) self.b2_5_5_up = up_conv(l5c, l4c) # initialization for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # elif isinstance(m, InPlaceABNSync): # nn.init.constant_(m.weight, 1) # nn.init.constant_(m.bias, 0) # elif isinstance(m, InPlaceABN): # nn.init.constant_(m.weight, 1) # nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=0.001) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, input1, input2): # code # branch1 x1_1 = self.b1_1_1(input1) x1_2 = self.b1_1_2_down(x1_1) x1_2 = self.b1_2_1(x1_2) x1_3_1 = self.b1_2_2_down(x1_2) x1_3_1 = self.b1_3_1(x1_3_1) x1_3_2 = self.b1_3_2(x1_3_1) x1_3_3_down = self.b1_3_3_down(x1_3_2) x1_3_3_down_down = self.b1_3_3_down_down(x1_3_3_down) x1_3_3_same = self.b1_3_3_same(x1_3_2) x1_4_1 = self.b1_3_2_down(x1_3_1) x1_4_1 = self.b1_4_1(x1_4_1) x1_4_2 = self.b1_4_2(x1_4_1) x1_4_3_down = self.b1_4_3_down(x1_4_2) x1_4_3_same = self.b1_4_3_same(x1_4_2) x1_5_1 = self.b1_4_2_down(x1_4_1) x1_5_1 = self.b1_5_1(x1_5_1) x1_5_2_up = self.b1_5_2_up(x1_5_1) x1_5_2_same = self.b1_5_2_same(x1_5_1) # branch2 x2_1 = self.b2_1_1(input2) x2_2 = self.b2_1_2_down(x2_1) x2_2 = self.b2_2_1(x2_2) x2_3 = self.b2_2_2_down(x2_2) x2_3 = self.b2_3_1(x2_3) x2_4_1 = self.b2_3_2_down(x2_3) x2_4_1 = self.b2_4_1(x2_4_1) x2_4_2 = self.b2_4_2(x2_4_1) x2_4_3_down = self.b2_4_3_down(x2_4_2) 
x2_4_3_same = self.b2_4_3_same(x2_4_2) x2_4_3_up = self.b2_4_3_up(x2_4_2) x2_5_1 = self.b2_4_2_down(x2_4_1) x2_5_1 = self.b2_5_1(x2_5_1) x2_5_2_up = self.b2_5_2_up(x2_5_1) x2_5_2_up_up = self.b2_5_2_up_up(x2_5_2_up) x2_5_2_same = self.b2_5_2_same(x2_5_1) # merge # branch1 x1_5_3 = torch.cat((x1_5_2_same, x2_4_3_down, x2_5_2_same), dim=1) x1_5_3 = self.b1_5_3_transition(x1_5_3) x1_5_3 = self.b1_5_4(x1_5_3) x1_5_3 = self.b1_5_5_up(x1_5_3) x1_4_4 = torch.cat((x1_4_3_same, x2_4_3_same, x2_5_2_up), dim=1) x1_4_4 = self.b1_4_4_transition(x1_4_4) x1_4_4 = self.b1_4_5(x1_4_4) x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1) x1_4_4 = self.b1_4_6(x1_4_4) x1_4_4 = self.b1_4_7_up(x1_4_4) x1_3_4 = torch.cat((x1_3_3_same, x2_4_3_up, x2_5_2_up_up), dim=1) x1_3_4 = self.b1_3_4_transition(x1_3_4) x1_3_4 = self.b1_3_5(x1_3_4) x1_3_4 = torch.cat((x1_3_4, x1_4_4), dim=1) x1_3_4 = self.b1_3_6(x1_3_4) x1_3_4 = self.b1_3_7_up(x1_3_4) # branch2 x2_5_3 = torch.cat((x2_5_2_same, x1_3_3_down_down, x1_4_3_down, x1_5_2_same), dim=1) x2_5_3 = self.b2_5_3_transition(x2_5_3) x2_5_3 = self.b2_5_4(x2_5_3) x2_5_3 = self.b2_5_5_up(x2_5_3) x2_4_4 = torch.cat((x2_4_3_same, x1_3_3_down, x1_4_3_same, x1_5_2_up), dim=1) x2_4_4 = self.b2_4_4_transition(x2_4_4) x2_4_4 = self.b2_4_5(x2_4_4) x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1) x2_4_4 = self.b2_4_6(x2_4_4) x2_4_4 = self.b2_4_7_up(x2_4_4) # decode # branch1 x1_2 = torch.cat((x1_2, x1_3_4), dim=1) x1_2 = self.b1_2_3(x1_2) x1_2 = self.b1_2_4_up(x1_2) x1_1 = torch.cat((x1_1, x1_2), dim=1) x1_1 = self.b1_1_3(x1_1) x1_1 = self.b1_1_4(x1_1) # branch2 x2_3 = torch.cat((x2_3, x2_4_4), dim=1) x2_3 = self.b2_3_3(x2_3) x2_3 = self.b2_3_4_up(x2_3) x2_2 = torch.cat((x2_2, x2_3), dim=1) x2_2 = self.b2_2_3(x2_2) x2_2 = self.b2_2_4_up(x2_2) x2_1 = torch.cat((x2_1, x2_2), dim=1) x2_1 = self.b2_1_3(x2_1) x2_1 = self.b2_1_4(x2_1) return x1_1, x2_1 class XNet_3_3_m(nn.Module): def __init__(self, in_channels, num_classes): super(XNet_3_3_m, self).__init__() l1c, l2c, l3c, l4c, 
l5c = 64, 128, 256, 512, 1024 # branch1 # branch1_layer1 self.b1_1_1 = nn.Sequential( conv3x3(in_channels, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c) ) self.b1_1_2_down = down_conv(l1c, l2c) self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM))) self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0) # branch1_layer2 self.b1_2_1 = DoubleBasicBlock(l2c, l2c) self.b1_2_2_down = down_conv(l2c, l3c) self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM))) self.b1_2_4_up = up_conv(l2c, l1c) # branch1_layer3 self.b1_3_1 = DoubleBasicBlock(l3c, l3c) self.b1_3_2_down = down_conv(l3c, l4c) self.b1_3_2 = DoubleBasicBlock(l3c, l3c) self.b1_3_3_down = down_conv(l3c, l3c) self.b1_3_3_down_down = down_conv(l3c, l3c) self.b1_3_3_same = same_conv(l3c, l3c) self.b1_3_4_transition = transition_conv(l3c+l5c+l4c+l3c, l3c) self.b1_3_5 = DoubleBasicBlock(l3c, l3c) self.b1_3_6 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM))) self.b1_3_7_up = up_conv(l3c, l2c) # branch1_layer4 self.b1_4_1 = DoubleBasicBlock(l4c, l4c) self.b1_4_2_down = down_conv(l4c, l5c) self.b1_4_2 = DoubleBasicBlock(l4c, l4c) self.b1_4_3_down = down_conv(l4c, l4c) self.b1_4_3_same = same_conv(l4c, l4c) self.b1_4_3_up = up_conv(l4c, l4c) self.b1_4_4_transition = transition_conv(l4c+l5c+l4c+l3c, l4c) self.b1_4_5 = DoubleBasicBlock(l4c, l4c) self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM))) self.b1_4_7_up = up_conv(l4c, l3c) # branch1_layer5 self.b1_5_1 = DoubleBasicBlock(l5c, l5c) self.b1_5_2_up = up_conv(l5c, l5c) self.b1_5_2_up_up = up_conv(l5c, l5c) self.b1_5_2_same = same_conv(l5c, l5c) self.b1_5_3_transition = transition_conv(l5c+l5c+l4c+l3c, 
l5c) self.b1_5_4 = DoubleBasicBlock(l5c, l5c) self.b1_5_5_up = up_conv(l5c, l4c) # branch2 # branch2_layer1 self.b2_1_1 = nn.Sequential( conv3x3(1, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c) ) self.b2_1_2_down = down_conv(l1c, l2c) self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM))) self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0) # branch2_layer2 self.b2_2_1 = DoubleBasicBlock(l2c, l2c) self.b2_2_2_down = down_conv(l2c, l3c) self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM))) self.b2_2_4_up = up_conv(l2c, l1c) # branch2_layer3 self.b2_3_1 = DoubleBasicBlock(l3c, l3c) self.b2_3_2_down = down_conv(l3c, l4c) self.b2_3_2 = DoubleBasicBlock(l3c, l3c) self.b2_3_3_down = down_conv(l3c, l3c) self.b2_3_3_down_down = down_conv(l3c, l3c) self.b2_3_3_same = same_conv(l3c, l3c) self.b2_3_4_transition = transition_conv(l3c+l5c+l4c+l3c, l3c) self.b2_3_5 = DoubleBasicBlock(l3c, l3c) self.b2_3_6 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM))) self.b2_3_7_up = up_conv(l3c, l2c) # branch2_layer4 self.b2_4_1 = DoubleBasicBlock(l4c, l4c) self.b2_4_2_down = down_conv(l4c, l5c) self.b2_4_2 = DoubleBasicBlock(l4c, l4c) self.b2_4_3_down = down_conv(l4c, l4c) self.b2_4_3_same = same_conv(l4c, l4c) self.b2_4_3_up = up_conv(l4c, l4c) self.b2_4_4_transition = transition_conv(l4c+l5c+l4c+l3c, l4c) self.b2_4_5 = DoubleBasicBlock(l4c, l4c) self.b2_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM))) self.b2_4_7_up = up_conv(l4c, l3c) # branch2_layer5 self.b2_5_1 = DoubleBasicBlock(l5c, l5c) self.b2_5_2_up = up_conv(l5c, l5c) self.b2_5_2_up_up = up_conv(l5c, l5c) self.b2_5_2_same = same_conv(l5c, l5c) 
        self.b2_5_3_transition = transition_conv(l5c+l5c+l4c+l3c, l5c)
        self.b2_5_4 = DoubleBasicBlock(l5c, l5c)
        self.b2_5_5_up = up_conv(l5c, l4c)

        # initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            # elif isinstance(m, InPlaceABNSync):
            #     nn.init.constant_(m.weight, 1)
            #     nn.init.constant_(m.bias, 0)
            # elif isinstance(m, InPlaceABN):
            #     nn.init.constant_(m.weight, 1)
            #     nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, input1, input2):
        # encode
        # branch1
        x1_1 = self.b1_1_1(input1)

        x1_2 = self.b1_1_2_down(x1_1)
        x1_2 = self.b1_2_1(x1_2)

        x1_3_1 = self.b1_2_2_down(x1_2)
        x1_3_1 = self.b1_3_1(x1_3_1)
        x1_3_2 = self.b1_3_2(x1_3_1)
        x1_3_3_down = self.b1_3_3_down(x1_3_2)
        x1_3_3_down_down = self.b1_3_3_down_down(x1_3_3_down)
        x1_3_3_same = self.b1_3_3_same(x1_3_2)

        x1_4_1 = self.b1_3_2_down(x1_3_1)
        x1_4_1 = self.b1_4_1(x1_4_1)
        x1_4_2 = self.b1_4_2(x1_4_1)
        x1_4_3_down = self.b1_4_3_down(x1_4_2)
        x1_4_3_same = self.b1_4_3_same(x1_4_2)
        x1_4_3_up = self.b1_4_3_up(x1_4_2)

        x1_5_1 = self.b1_4_2_down(x1_4_1)
        x1_5_1 = self.b1_5_1(x1_5_1)
        x1_5_2_up = self.b1_5_2_up(x1_5_1)
        x1_5_2_up_up = self.b1_5_2_up_up(x1_5_2_up)
        x1_5_2_same = self.b1_5_2_same(x1_5_1)

        # branch2
        x2_1 = self.b2_1_1(input2)

        x2_2 = self.b2_1_2_down(x2_1)
        x2_2 = self.b2_2_1(x2_2)

        x2_3_1 = self.b2_2_2_down(x2_2)
        x2_3_1 = self.b2_3_1(x2_3_1)
        x2_3_2 = self.b2_3_2(x2_3_1)
        x2_3_3_down = self.b2_3_3_down(x2_3_2)
        x2_3_3_down_down = self.b2_3_3_down_down(x2_3_3_down)
        x2_3_3_same = self.b2_3_3_same(x2_3_2)

        x2_4_1 = self.b2_3_2_down(x2_3_1)
        x2_4_1 = self.b2_4_1(x2_4_1)
        x2_4_2 = self.b2_4_2(x2_4_1)
        x2_4_3_down = self.b2_4_3_down(x2_4_2)
        x2_4_3_same = self.b2_4_3_same(x2_4_2)
        x2_4_3_up = self.b2_4_3_up(x2_4_2)

        x2_5_1 = self.b2_4_2_down(x2_4_1)
        x2_5_1 = self.b2_5_1(x2_5_1)
        x2_5_2_up = self.b2_5_2_up(x2_5_1)
        x2_5_2_up_up = self.b2_5_2_up_up(x2_5_2_up)
        x2_5_2_same = self.b2_5_2_same(x2_5_1)

        # merge
        # branch1
        x1_5_3 = torch.cat((x1_5_2_same, x2_3_3_down_down, x2_4_3_down, x2_5_2_same), dim=1)
        x1_5_3 = self.b1_5_3_transition(x1_5_3)
        x1_5_3 = self.b1_5_4(x1_5_3)
        x1_5_3 = self.b1_5_5_up(x1_5_3)

        x1_4_4 = torch.cat((x1_4_3_same, x2_3_3_down, x2_4_3_same, x2_5_2_up), dim=1)
        x1_4_4 = self.b1_4_4_transition(x1_4_4)
        x1_4_4 = self.b1_4_5(x1_4_4)
        x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1)
        x1_4_4 = self.b1_4_6(x1_4_4)
        x1_4_4 = self.b1_4_7_up(x1_4_4)

        x1_3_4 = torch.cat((x1_3_3_same, x2_3_3_same, x2_4_3_up, x2_5_2_up_up), dim=1)
        x1_3_4 = self.b1_3_4_transition(x1_3_4)
        x1_3_4 = self.b1_3_5(x1_3_4)
        x1_3_4 = torch.cat((x1_3_4, x1_4_4), dim=1)
        x1_3_4 = self.b1_3_6(x1_3_4)
        x1_3_4 = self.b1_3_7_up(x1_3_4)

        # branch2
        x2_5_3 = torch.cat((x2_5_2_same, x1_3_3_down_down, x1_4_3_down, x1_5_2_same), dim=1)
        x2_5_3 = self.b2_5_3_transition(x2_5_3)
        x2_5_3 = self.b2_5_4(x2_5_3)
        x2_5_3 = self.b2_5_5_up(x2_5_3)

        x2_4_4 = torch.cat((x2_4_3_same, x1_3_3_down, x1_4_3_same, x1_5_2_up), dim=1)
        x2_4_4 = self.b2_4_4_transition(x2_4_4)
        x2_4_4 = self.b2_4_5(x2_4_4)
        x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1)
        x2_4_4 = self.b2_4_6(x2_4_4)
        x2_4_4 = self.b2_4_7_up(x2_4_4)

        x2_3_4 = torch.cat((x2_3_3_same, x1_3_3_same, x1_4_3_up, x1_5_2_up_up), dim=1)
        x2_3_4 = self.b2_3_4_transition(x2_3_4)
        x2_3_4 = self.b2_3_5(x2_3_4)
        x2_3_4 = torch.cat((x2_3_4, x2_4_4), dim=1)
        x2_3_4 = self.b2_3_6(x2_3_4)
        x2_3_4 = self.b2_3_7_up(x2_3_4)

        # decode
        # branch1
        x1_2 = torch.cat((x1_2, x1_3_4), dim=1)
        x1_2 = self.b1_2_3(x1_2)
        x1_2 = self.b1_2_4_up(x1_2)

        x1_1 = torch.cat((x1_1, x1_2), dim=1)
        x1_1 = self.b1_1_3(x1_1)
        x1_1 = self.b1_1_4(x1_1)

        # branch2
        x2_2 = torch.cat((x2_2, x2_3_4), dim=1)
        x2_2 = self.b2_2_3(x2_2)
        x2_2 = self.b2_2_4_up(x2_2)

        x2_1 = torch.cat((x2_1, x2_2), dim=1)
        x2_1 = self.b2_1_3(x2_1)
        x2_1 = self.b2_1_4(x2_1)

        return x1_1, x2_1


class XNet_sb(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(XNet_sb, self).__init__()

        l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024

        # branch1
        # branch1_layer1
        self.b1_1_1 = nn.Sequential(
            conv3x3(in_channels, l1c),
            conv3x3(l1c, l1c),
            BasicBlock(l1c, l1c)
        )
        self.b1_1_2_down = down_conv(l1c, l2c)
        self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))
        self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)
        # branch1_layer2
        self.b1_2_1 = DoubleBasicBlock(l2c, l2c)
        self.b1_2_2_down = down_conv(l2c, l3c)
        self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))
        self.b1_2_4_up = up_conv(l2c, l1c)
        # branch1_layer3
        self.b1_3_1 = DoubleBasicBlock(l3c, l3c)
        self.b1_3_2_down = down_conv(l3c, l4c)
        self.b1_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))
        self.b1_3_4_up = up_conv(l3c, l2c)
        # branch1_layer4
        self.b1_4_1 = DoubleBasicBlock(l4c, l4c)
        self.b1_4_2_down = down_conv(l4c, l5c)
        self.b1_4_2 = DoubleBasicBlock(l4c, l4c)
        # self.b1_4_3_down = down_conv(l4c, l4c)
        # self.b1_4_3_same = same_conv(l4c, l4c)
        # self.b1_4_4_transition = transition_conv(l4c, l4c)
        self.b1_4_5 = DoubleBasicBlock(l4c, l4c)
        self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))
        self.b1_4_7_up = up_conv(l4c, l3c)
        # branch1_layer5
        self.b1_5_1 = DoubleBasicBlock(l5c, l5c)
        # self.b1_5_2_up = up_conv(l5c, l5c)
        # self.b1_5_2_same = same_conv(l5c, l5c)
        # self.b1_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)
        self.b1_5_4 = DoubleBasicBlock(l5c, l5c)
        self.b1_5_5_up = up_conv(l5c, l4c)

        # initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            # elif isinstance(m, InPlaceABNSync):
            #     nn.init.constant_(m.weight, 1)
            #     nn.init.constant_(m.bias, 0)
            # elif isinstance(m, InPlaceABN):
            #     nn.init.constant_(m.weight, 1)
            #     nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, input1):
        # encode
        # branch1
        x1_1 = self.b1_1_1(input1)

        x1_2 = self.b1_1_2_down(x1_1)
        x1_2 = self.b1_2_1(x1_2)

        x1_3 = self.b1_2_2_down(x1_2)
        x1_3 = self.b1_3_1(x1_3)

        x1_4_1 = self.b1_3_2_down(x1_3)
        x1_4_1 = self.b1_4_1(x1_4_1)
        x1_4_2 = self.b1_4_2(x1_4_1)
        x1_4_2 = self.b1_4_5(x1_4_2)
        # x1_4_3_down = self.b1_4_3_down(x1_4_2)
        # x1_4_3_same = self.b1_4_3_same(x1_4_2)

        x1_5_1 = self.b1_4_2_down(x1_4_1)
        x1_5_1 = self.b1_5_1(x1_5_1)
        x1_5_1 = self.b1_5_4(x1_5_1)
        x1_5_1 = self.b1_5_5_up(x1_5_1)
        # x1_5_2_up = self.b1_5_2_up(x1_5_1)
        # x1_5_2_same = self.b1_5_2_same(x1_5_1)

        # decode
        # branch1
        x1_4_2 = torch.cat((x1_4_2, x1_5_1), dim=1)
        x1_4_2 = self.b1_4_6(x1_4_2)
        x1_4_2 = self.b1_4_7_up(x1_4_2)

        x1_3 = torch.cat((x1_3, x1_4_2), dim=1)
        x1_3 = self.b1_3_3(x1_3)
        x1_3 = self.b1_3_4_up(x1_3)

        x1_2 = torch.cat((x1_2, x1_3), dim=1)
        x1_2 = self.b1_2_3(x1_2)
        x1_2 = self.b1_2_4_up(x1_2)

        x1_1 = torch.cat((x1_1, x1_2), dim=1)
        x1_1 = self.b1_1_3(x1_1)
        x1_1 = self.b1_1_4(x1_1)

        return x1_1


# if __name__ == '__main__':
#     model = XNet(1, 10)
#     total = sum([param.nelement() for param in model.parameters()])
#     from thop import profile, clever_format
#
#     input = torch.randn(1, 1, 128, 128)
#     flops, params = profile(model, inputs=(input, input, ))
#     macs, params = clever_format([flops, params], "%.3f")
#     print(macs)
#     print(params)
#     print(total)
#     model.eval()
#     input1 = torch.rand(2,3,256,256)
#     input2 = torch.rand(2,1,256,256)
#     x1_1, x2_1 = model(input1, input2)
#     output1 = x1_1.data.cpu().numpy()
#     output2 = x2_1.data.cpu().numpy()
#     # print(output)
#     print(output1.shape)
#     print(output2.shape)

================================================
FILE: models/networks_3d/__init__.py
================================================

================================================
FILE: models/networks_3d/conresnet.py
================================================
import torch.nn as nn
from torch.nn import functional as F
import torch
import numpy as np
from torch.nn import init
# from loss.loss_function import segmentation_loss


def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


class Conv3d(nn.Conv3d):

    def __init__(self, in_channels, out_channels, kernel_size, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), groups=1, bias=False):
        super(Conv3d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)

    def forward(self, x):
        # weight standardization: each output filter is shifted to zero mean and scaled to unit variance
        weight = self.weight
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=4, keepdim=True)
        weight = weight - weight_mean
        std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1, 1)
        weight = weight / std.expand_as(weight)
        return F.conv3d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


def conv3x3x3(in_planes, out_planes, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1,1,1), dilation=(1,1,1), bias=False, weight_std=False):
    "3x3x3 convolution with padding"
    if weight_std:
        return Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias)
    else:
        return nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias)


class ConResAtt(nn.Module):
    def __init__(self, in_channels, in_planes, out_planes, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1), bias=False, weight_std=False, first_layer=False):
        super(ConResAtt, self).__init__()
        self.weight_std = weight_std
        self.stride = stride
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.first_layer = first_layer

        self.relu = nn.ReLU(inplace=True)

        self.gn_seg = nn.GroupNorm(8, in_planes)
        self.conv_seg = conv3x3x3(in_planes, out_planes, kernel_size=(kernel_size[0], kernel_size[1], kernel_size[2]),
                                  stride=(stride[0], stride[1], stride[2]), padding=(padding[0], padding[1], padding[2]),
                                  dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)

        self.gn_res = nn.GroupNorm(8, out_planes)
        self.conv_res = conv3x3x3(out_planes, out_planes, kernel_size=(1,1,1), stride=(1, 1, 1), padding=(0,0,0),
                                  dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)

        self.gn_res1 = nn.GroupNorm(8, out_planes)
        self.conv_res1 = conv3x3x3(out_planes, out_planes, kernel_size=(kernel_size[0], kernel_size[1], kernel_size[2]),
                                   stride=(1, 1, 1), padding=(padding[0], padding[1], padding[2]),
                                   dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)
        self.gn_res2 = nn.GroupNorm(8, out_planes)
        self.conv_res2 = conv3x3x3(out_planes, out_planes, kernel_size=(kernel_size[0], kernel_size[1],
                                   kernel_size[2]), stride=(1, 1, 1), padding=(padding[0], padding[1], padding[2]),
                                   dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)

        self.gn_mp = nn.GroupNorm(8, in_planes)
        self.conv_mp_first = conv3x3x3(in_channels, out_planes, kernel_size=(kernel_size[0], kernel_size[1], kernel_size[2]),
                                       stride=(stride[0], stride[1], stride[2]), padding=(padding[0], padding[1], padding[2]),
                                       dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)
        self.conv_mp = conv3x3x3(in_planes, out_planes, kernel_size=(kernel_size[0], kernel_size[1], kernel_size[2]),
                                 stride=(stride[0], stride[1], stride[2]), padding=(padding[0], padding[1], padding[2]),
                                 dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)

    def _res(self, x):
        # x: (bs, channel, D, H, W)
        bs, channel, depth, height, width = x.shape
        # Absolute difference between each depth slice and the slice before it;
        # slice 0 has no predecessor, so its residual is set to 0.
        # x_copy = torch.zeros_like(x).cuda()
        x_copy = torch.zeros_like(x)
        x_copy[:, :, 1:, :, :] = x[:, :, 0: depth - 1, :, :]
        res = x - x_copy
        res[:, :, 0, :, :] = 0
        res = torch.abs(res)
        return res

    def forward(self, input):
        x1, x2 = input
        if self.first_layer:
            x1 = self.gn_seg(x1)
            x1 = self.relu(x1)
            x1 = self.conv_seg(x1)

            res = torch.sigmoid(x1)
            res = self._res(res)
            res = self.conv_res(res)

            x2 = self.conv_mp_first(x2)
            x2 = x2 + res

        else:
            x1 = self.gn_seg(x1)
            x1 = self.relu(x1)
            x1 = self.conv_seg(x1)

            res = torch.sigmoid(x1)
            res = self._res(res)
            res = self.conv_res(res)

            if self.in_planes != self.out_planes:
                x2 = self.gn_mp(x2)
                x2 = self.relu(x2)
                x2 = self.conv_mp(x2)

            x2 = x2 + res

        x2 = self.gn_res1(x2)
        x2 = self.relu(x2)
        x2 = self.conv_res1(x2)

        x1 = x1*(1 + torch.sigmoid(x2))

        return [x1, x2]


class NoBottleneck(nn.Module):
    def __init__(self, inplanes, planes, stride=(1, 1, 1), dilation=(1, 1, 1), downsample=None, fist_dilation=1,
                 multi_grid=1, weight_std=False):
        super(NoBottleneck, self).__init__()
        self.weight_std = weight_std
        self.relu = nn.ReLU(inplace=True)

        self.gn1 = nn.GroupNorm(8, inplanes)
        self.conv1 = conv3x3x3(inplanes, planes, kernel_size=(3, 3, 3), stride=stride, padding=dilation * multi_grid,
                               dilation=dilation * multi_grid, bias=False, weight_std=self.weight_std)

        self.gn2 = nn.GroupNorm(8, planes)
        self.conv2 = conv3x3x3(planes, planes, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=dilation * multi_grid,
                               dilation=dilation * multi_grid, bias=False, weight_std=self.weight_std)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        skip = x

        seg = self.gn1(x)
        seg = self.relu(seg)
        seg = self.conv1(seg)

        seg = self.gn2(seg)
        seg = self.relu(seg)
        seg = self.conv2(seg)

        if self.downsample is not None:
            skip = self.downsample(x)

        seg = seg + skip
        return seg


class ConResNet(nn.Module):
    def __init__(self, in_channels, num_classes, shape, block, layers, weight_std=False):
        super(ConResNet, self).__init__()
        self.shape = shape
        self.weight_std = weight_std

        self.conv_4_32 = nn.Sequential(
            conv3x3x3(in_channels, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), weight_std=self.weight_std))

        self.conv_32_64 = nn.Sequential(
            nn.GroupNorm(8, 32),
            nn.ReLU(inplace=True),
            conv3x3x3(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), weight_std=self.weight_std))

        self.conv_64_128 = nn.Sequential(
            nn.GroupNorm(8, 64),
            nn.ReLU(inplace=True),
            conv3x3x3(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), weight_std=self.weight_std))

        self.conv_128_256 = nn.Sequential(
            nn.GroupNorm(8, 128),
            nn.ReLU(inplace=True),
            conv3x3x3(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), weight_std=self.weight_std))

        self.layer0 = self._make_layer(block, 32, 32, layers[0], stride=(1, 1, 1))
        self.layer1 = self._make_layer(block, 64, 64, layers[1], stride=(1, 1, 1))
        self.layer2 = self._make_layer(block, 128, 128, layers[2], stride=(1, 1, 1))
        self.layer3 = self._make_layer(block, 256, 256, layers[3], stride=(1, 1, 1))
        self.layer4 = self._make_layer(block, 256, 256, layers[4], stride=(1, 1, 1), dilation=(2,2,2))

        self.fusionConv = nn.Sequential(
            nn.GroupNorm(8, 256),
            nn.ReLU(inplace=True),
            nn.Dropout3d(0.1),
            conv3x3x3(256, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1), weight_std=self.weight_std)
        )

        self.seg_x4 = nn.Sequential(
            ConResAtt(in_channels, 128, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1), weight_std=self.weight_std, first_layer=True))
        self.seg_x2 = nn.Sequential(
            ConResAtt(in_channels, 64, 32, kernel_size=(3, 3, 3), padding=(1, 1, 1), weight_std=self.weight_std))
        self.seg_x1 = nn.Sequential(
            ConResAtt(in_channels, 32, 32, kernel_size=(3, 3, 3), padding=(1, 1, 1), weight_std=self.weight_std))

        self.seg_cls = nn.Sequential(
            nn.Conv3d(32, num_classes, kernel_size=1)
        )
        self.res_cls = nn.Sequential(
            nn.Conv3d(32, num_classes, kernel_size=1)
        )
        self.resx2_cls = nn.Sequential(
            nn.Conv3d(32, num_classes, kernel_size=1)
        )
        self.resx4_cls = nn.Sequential(
            nn.Conv3d(64, num_classes, kernel_size=1)
        )

    def _make_layer(self, block, inplanes, outplanes, blocks, stride=(1, 1, 1), dilation=(1, 1, 1), multi_grid=1):
        downsample = None
        if stride[0] != 1 or stride[1] != 1 or stride[2] != 1 or inplanes != outplanes:
            downsample = nn.Sequential(
                nn.GroupNorm(8, inplanes),
                nn.ReLU(inplace=True),
                conv3x3x3(inplanes, outplanes, kernel_size=(1, 1, 1), stride=stride, padding=(0, 0, 0),
                          weight_std=self.weight_std)
            )

        layers = []
        generate_multi_grid = lambda index, grids: grids[index % len(grids)] if isinstance(grids, tuple) else 1
        layers.append(block(inplanes, outplanes, stride, dilation=dilation, downsample=downsample,
                            multi_grid=generate_multi_grid(0, multi_grid), weight_std=self.weight_std))
        for i in range(1, blocks):
            layers.append(
                block(inplanes, outplanes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid),
                      weight_std=self.weight_std))
        return nn.Sequential(*layers)

    def forward(self, x, x_res):
        ## encoder
        x = self.conv_4_32(x)
        x = self.layer0(x)
        skip1 = x

        x = self.conv_32_64(x)
        x = self.layer1(x)
        skip2 = x

        x = self.conv_64_128(x)
        x = self.layer2(x)
        skip3 = x

        x = self.conv_128_256(x)
        x = self.layer3(x)

        x = self.layer4(x)

        x = self.fusionConv(x)

        ## decoder
        res_x4 = F.interpolate(x_res, size=(int(self.shape[0] / 4), int(self.shape[1] / 4), int(self.shape[2] / 4)),
                               mode='trilinear', align_corners=True)
        seg_x4 = F.interpolate(x, size=(int(self.shape[0] / 4), int(self.shape[1] / 4), int(self.shape[2] / 4)),
                               mode='trilinear', align_corners=True)
        seg_x4 = seg_x4 + skip3
        seg_x4, res_x4 = self.seg_x4([seg_x4, res_x4])

        res_x2 = F.interpolate(res_x4, size=(int(self.shape[0] / 2), int(self.shape[1] / 2), int(self.shape[2] / 2)),
                               mode='trilinear', align_corners=True)
        seg_x2 = F.interpolate(seg_x4, size=(int(self.shape[0] / 2), int(self.shape[1] / 2), int(self.shape[2] / 2)),
                               mode='trilinear', align_corners=True)
        seg_x2 = seg_x2 + skip2
        seg_x2, res_x2 = self.seg_x2([seg_x2, res_x2])

        res_x1 = F.interpolate(res_x2, size=(int(self.shape[0] / 1), int(self.shape[1] / 1), int(self.shape[2] / 1)),
                               mode='trilinear', align_corners=True)
        seg_x1 = F.interpolate(seg_x2, size=(int(self.shape[0] / 1), int(self.shape[1] / 1), int(self.shape[2] / 1)),
                               mode='trilinear', align_corners=True)
        seg_x1 = seg_x1 + skip1
        seg_x1, res_x1 = self.seg_x1([seg_x1, res_x1])

        seg = self.seg_cls(seg_x1)
        res = self.res_cls(res_x1)
        resx2 = self.resx2_cls(res_x2)
        resx4 = self.resx4_cls(res_x4)

        resx2 = F.interpolate(resx2, size=(int(self.shape[0] / 1), int(self.shape[1] / 1), int(self.shape[2] / 1)),
                              mode='trilinear', align_corners=True)
        resx4 = F.interpolate(resx4, size=(int(self.shape[0] / 1), int(self.shape[1] / 1), int(self.shape[2] / 1)),
                              mode='trilinear', align_corners=True)

        return [seg, res, resx2, resx4]


def conresnet(in_channels, num_classes, **kwargs):
    model = ConResNet(in_channels, num_classes, kwargs['img_shape'], NoBottleneck, [1, 2, 2, 2, 2])
    init_weights(model, 'kaiming')
    return model


# if __name__ == '__main__':
#
#     criterion = segmentation_loss('dice', False)
#     mask = torch.ones(5, 64, 64, 64).long()
#     model = conresnet(1, 10, img_shape=(64, 64, 64))
#     model.train()
#     output = model(torch.rand(5, 1, 64, 64, 64), torch.rand(5, 1, 64, 64, 64))
#
#     loss_train_1 = criterion(output[0], mask)
#     loss_train_2 = criterion(output[1], mask)
#     loss_train_3 = criterion(output[2], mask)
#     loss_train_4 = criterion(output[3], mask)
#
#     loss_train_1.backward()
#
#     print(output[0].data.cpu().numpy().shape)
#     print(output[1].data.cpu().numpy().shape)
#     print(output[2].data.cpu().numpy().shape)
#     print(output[3].data.cpu().numpy().shape)
#     print(loss_train_1)
#     print(loss_train_2)
#     print(loss_train_3)
#     print(loss_train_4)

================================================
FILE: models/networks_3d/cotr.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn.init import xavier_uniform_, constant_, normal_
import copy
import math
# from loss.loss_function import segmentation_loss


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the "Attention Is All You Need" paper, generalized here to 3D volumes (D, H, W).
    """
    def __init__(self, num_pos_feats=[64, 64, 64], temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x):
        bs, c, d, h, w = x.shape
        # No padding is used, so the mask is all False (every voxel is valid); it is
        # allocated on the input's device so the module runs on both CPU and GPU.
        mask = torch.zeros(bs, d, h, w, dtype=torch.bool, device=x.device)
        assert mask is not None
        not_mask = ~mask
        d_embed = not_mask.cumsum(1, dtype=torch.float32)
        y_embed = not_mask.cumsum(2, dtype=torch.float32)
        x_embed = not_mask.cumsum(3, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            d_embed = (d_embed - 0.5) / (d_embed[:, -1:, :, :] + eps) * self.scale
            y_embed = (y_embed - 0.5) / (y_embed[:, :, -1:, :] + eps) * self.scale
            x_embed = (x_embed - 0.5) / (x_embed[:, :, :, -1:] + eps) * self.scale

        dim_tx = torch.arange(self.num_pos_feats[0], dtype=torch.float32, device=x.device)
        dim_tx = self.temperature ** (3 * (dim_tx // 3) / self.num_pos_feats[0])

        dim_ty = torch.arange(self.num_pos_feats[1], dtype=torch.float32, device=x.device)
        dim_ty = self.temperature ** (3 * (dim_ty // 3) / self.num_pos_feats[1])

        dim_td = torch.arange(self.num_pos_feats[2], dtype=torch.float32, device=x.device)
        dim_td = self.temperature ** (3 * (dim_td // 3) / self.num_pos_feats[2])

        pos_x = x_embed[:, :, :, :, None] / dim_tx
        pos_y = y_embed[:, :, :, :, None] / dim_ty
        pos_d = d_embed[:, :, :, :, None] / dim_td

        pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
        pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
        pos_d = torch.stack((pos_d[:, :, :, :, 0::2].sin(), pos_d[:, :, :, :, 1::2].cos()), dim=5).flatten(4)

        pos = torch.cat((pos_d, pos_y, pos_x), dim=4).permute(0, 4, 1, 2, 3)
        return pos


def build_position_encoding(mode, hidden_dim):
    N_steps = hidden_dim // 3
    if (hidden_dim % 3) != 0:
        N_steps = [N_steps, N_steps, N_steps + hidden_dim % 3]
    else:
        N_steps = [N_steps, N_steps, N_steps]

    if mode in ('v2', 'sine'):
        position_embedding = PositionEmbeddingSine(num_pos_feats=N_steps, normalize=True)
    else:
        raise ValueError(f"not supported {mode}")

    return position_embedding


def ms_deform_attn_core_pytorch_3D(value, value_spatial_shapes, sampling_locations, attention_weights):
    N_, S_, M_, D_ = value.shape
    _, Lq_, M_, L_, P_, _ = sampling_locations.shape
    value_list = value.split([T_ * H_ * W_ for T_, H_, W_ in value_spatial_shapes], dim=1)
    # map normalized [0, 1] sampling locations to the [-1, 1] coordinate range expected by grid_sample
    sampling_grids = 2 * sampling_locations - 1
    # sampling_grids = 3 * sampling_locations - 1
    sampling_value_list = []
    for lid_, (T_, H_, W_) in enumerate(value_spatial_shapes):
        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, T_, H_, W_)
        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)[:,None,:,:,:]
        # for 5-D inputs, grid_sample's mode='bilinear' performs trilinear interpolation
        sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_.to(dtype=value_l_.dtype), mode='bilinear', padding_mode='zeros', align_corners=False)[:,:,0]
        sampling_value_list.append(sampling_value_l_)
    attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
    return output.transpose(1, 2).contiguous()


class MSDeformAttn(nn.Module):
    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
        """
        Multi-Scale Deformable Attention Module
        :param d_model      hidden dimension
        :param n_levels     number of feature levels
        :param n_heads      number of attention heads
        :param n_points     number of sampling points per attention head per feature level
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
        _d_per_head = d_model // n_heads

        self.im2col_step = 64

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 3)
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self._reset_parameters()

    def _reset_parameters(self):
        constant_(self.sampling_offsets.weight.data, 0.)
        # Offset biases: each head starts sampling along its own direction (from angle theta_h),
        # with point i placed (i + 1) times further out along that direction.
        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()*thetas.cos(), thetas.sin()*thetas.sin()], -1)
        grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 3).repeat(1, self.n_levels, self.n_points, 1)
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.)
        constant_(self.attention_weights.bias.data, 0.)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)
    def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
        r"""
        :param query                       (N, Length_{query}, C)
        :param reference_points            (N, Length_{query}, n_levels, 3)
        :param input_flatten               (N, \sum_{l=0}^{L-1} D_l \cdot H_l \cdot W_l, C)
        :param input_spatial_shapes        (n_levels, 3), [(D_0, H_0, W_0), (D_1, H_1, W_1), ..., (D_{L-1}, H_{L-1}, W_{L-1})]
        :param input_level_start_index     (n_levels, ), [0, D_0*H_0*W_0, D_0*H_0*W_0+D_1*H_1*W_1, D_0*H_0*W_0+D_1*H_1*W_1+D_2*H_2*W_2, ..., D_0*H_0*W_0+D_1*H_1*W_1+...+D_{L-2}*H_{L-2}*W_{L-2}]
        :param input_padding_mask          (N, \sum_{l=0}^{L-1} D_l \cdot H_l \cdot W_l), True for padding elements, False for non-padding elements

        :return output                     (N, Length_{query}, C)
        """
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape
        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1] * input_spatial_shapes[:, 2]).sum() == Len_in

        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            value = value.masked_fill(input_padding_mask[..., None], float(0))
        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3)
        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)

        if reference_points.shape[-1] == 3:
            offset_normalizer = torch.stack([input_spatial_shapes[..., 0], input_spatial_shapes[..., 2], input_spatial_shapes[..., 1]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
        else:
            raise ValueError('Last dim of reference_points must be 3, but got {} instead.'.format(reference_points.shape[-1]))

        output = ms_deform_attn_core_pytorch_3D(value, input_spatial_shapes, sampling_locations, attention_weights)
        output = self.output_proj(output)
        return output


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


class DeformableTransformerEncoderLayer(nn.Module):
    def __init__(self,
                 d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4):
        super().__init__()

        # self attention
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src

    def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
        # self attention
        src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # ffn
        src = self.forward_ffn(src)

        return src


class DeformableTransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        reference_points_list = []
        for lvl, (D_, H_, W_) in enumerate(spatial_shapes):

            ref_d, ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, D_ - 0.5, D_, dtype=torch.float32, device=device),
                                                 torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                                                 torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))

            ref_d = ref_d.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * D_)
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 2] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * W_)

            ref = torch.stack((ref_d, ref_x, ref_y), -1)   # D W H
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
        output = src
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
        for _, layer in enumerate(self.layers):
            output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)

        return output


class DeformableTransformer(nn.Module):
    def __init__(self, d_model=256, nhead=8,
                 num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,
                 activation="relu", num_feature_levels=4, enc_n_points=4):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead

        encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
                                                          dropout, activation,
                                                          num_feature_levels, nhead, enc_n_points)
        self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)

        self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        normal_(self.level_embed)

    def get_valid_ratio(self, mask):
        _, D, H, W = mask.shape
        valid_D = torch.sum(~mask[:, :, 0, 0], 1)
        valid_H = torch.sum(~mask[:, 0, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, 0, :], 1)
        valid_ratio_d = valid_D.float() / D
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_d, valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def forward(self, srcs, masks, pos_embeds):

        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            bs, c, d, h, w = src.shape
            spatial_shape = (d, h, w)
            spatial_shapes.append(spatial_shape)
            src = src.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

        # encoder
        memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)

        return memory


class Conv3d_wd(nn.Conv3d):

    def __init__(self, in_channels, out_channels, kernel_size, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), groups=1, bias=False):
        super(Conv3d_wd, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)

    def forward(self, x):
        # weight standardization: each output filter is shifted to zero mean and scaled to unit variance
        weight = self.weight
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=4, keepdim=True)
        weight = weight - weight_mean
        # std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1, 1) + 1e-5
        std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1, 1)
        weight = weight / std.expand_as(weight)
        return F.conv3d(x, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)


def conv3x3x3(in_planes, out_planes, kernel_size, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), groups=1, bias=False, weight_std=False):
    "3x3x3 convolution with padding"
    if weight_std:
        return Conv3d_wd(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
    else:
        return nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)


def Norm_layer(norm_cfg, inplanes):
    if norm_cfg == 'BN':
        out = nn.BatchNorm3d(inplanes)
    elif norm_cfg == 'SyncBN':
        out = nn.SyncBatchNorm(inplanes)
    elif norm_cfg == 'GN':
        out = nn.GroupNorm(16, inplanes)
    elif norm_cfg == 'IN':
        out = nn.InstanceNorm3d(inplanes, affine=True)
    else:
        raise ValueError('normalization type {} is not supported'.format(norm_cfg))
    return out


def Activation_layer(activation_cfg, inplace=True):
    if activation_cfg == 'ReLU':
        out = nn.ReLU(inplace=inplace)
    elif activation_cfg == 'LeakyReLU':
        out = nn.LeakyReLU(negative_slope=1e-2, inplace=inplace)
    else:
        raise ValueError('activation type {} is not supported'.format(activation_cfg))
    return out


class ResBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, norm_cfg, activation_cfg, stride=(1, 1, 1), downsample=None, weight_std=False):
        super(ResBlock, self).__init__()
        self.conv1 = conv3x3x3(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False, weight_std=weight_std)
        self.norm1 = Norm_layer(norm_cfg, planes)
        self.nonlin = Activation_layer(activation_cfg, inplace=True)
        self.downsample = downsample

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.norm1(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.nonlin(out)

        return out


class Backbone(nn.Module):

    def __init__(self, depth, in_channels=1, norm_cfg='BN', activation_cfg='ReLU', weight_std=False):
        super(Backbone, self).__init__()

        self.arch_settings = {
            9: (ResBlock, (3, 3, 2))
        }

        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        block, layers = self.arch_settings[depth]
        self.inplanes = 64
        self.conv1 = conv3x3x3(in_channels, 64, kernel_size=7, stride=(1, 2, 2), padding=3, bias=False, weight_std=weight_std)
        self.norm1 = Norm_layer(norm_cfg, 64)
        self.nonlin = Activation_layer(activation_cfg, inplace=True)
        self.layer1 = self._make_layer(block, 192, layers[0], stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std)
        self.layer2 = self._make_layer(block, 384, layers[1], stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std)
        self.layer3 = self._make_layer(block, 384, layers[2], stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std)
        self.layers = []

        for m in self.modules():
            if isinstance(m, (nn.Conv3d, Conv3d_wd)):
                m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm, nn.InstanceNorm3d, nn.SyncBatchNorm)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=(1, 1, 1), norm_cfg='BN', activation_cfg='ReLU', weight_std=False):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv3x3x3(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False, weight_std=weight_std),
                Norm_layer(norm_cfg, planes * block.expansion))

        layers = []
        layers.append(block(self.inplanes, planes, norm_cfg, activation_cfg, stride=stride, downsample=downsample, weight_std=weight_std))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, norm_cfg, activation_cfg, weight_std=weight_std))

        return nn.Sequential(*layers)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, Conv3d_wd)):
                m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm, nn.InstanceNorm3d, nn.SyncBatchNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        out = []
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.nonlin(x)
        out.append(x)

        x = self.layer1(x)
        out.append(x)
        x = self.layer2(x)
        out.append(x)
        x = self.layer3(x)
        out.append(x)

        return out


class Conv3dBlock(nn.Module):
    def __init__(self, in_channels, out_channels, norm_cfg, activation_cfg, kernel_size, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), bias=False, weight_std=False):
        super(Conv3dBlock, self).__init__()
        self.conv = conv3x3x3(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, weight_std=weight_std)
        self.norm = Norm_layer(norm_cfg, out_channels)
        self.nonlin = Activation_layer(activation_cfg, inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.nonlin(x)
        return x


class ResBlock_(nn.Module):
    def __init__(self, inplanes, planes, norm_cfg, activation_cfg, weight_std=False):
        super(ResBlock_, self).__init__()
        self.resconv1 = Conv3dBlock(inplanes, planes, norm_cfg, activation_cfg, kernel_size=3, stride=1, padding=1, bias=False, weight_std=weight_std)
        self.resconv2 = Conv3dBlock(planes, planes, norm_cfg, activation_cfg, kernel_size=3, stride=1, padding=1, bias=False, weight_std=weight_std)

    def forward(self, x):
        residual = x

        out = self.resconv1(x)
        out = self.resconv2(out)
        out = out + residual

        return out


class U_ResTran3D(nn.Module):
    def __init__(self, in_channels, num_classes, norm_cfg='BN', activation_cfg='ReLU', weight_std=False):
        super(U_ResTran3D, self).__init__()

        self.MODEL_NUM_CLASSES = num_classes

        self.upsamplex2 = nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear', align_corners=True)

        self.transposeconv_stage2 = nn.ConvTranspose3d(384, 384, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
        self.transposeconv_stage1 = nn.ConvTranspose3d(384, 192, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
        self.transposeconv_stage0 = nn.ConvTranspose3d(192, 64, kernel_size=(2, 2, 2),
                                                       stride=(2, 2, 2), bias=False)
        self.stage2_de = ResBlock_(384, 384, norm_cfg, activation_cfg, weight_std=weight_std)
        self.stage1_de = ResBlock_(192, 192, norm_cfg, activation_cfg, weight_std=weight_std)
        self.stage0_de = ResBlock_(64, 64, norm_cfg, activation_cfg, weight_std=weight_std)

        # self.ds2_cls_conv = nn.Conv3d(384, self.MODEL_NUM_CLASSES, kernel_size=1)
        # self.ds1_cls_conv = nn.Conv3d(192, self.MODEL_NUM_CLASSES, kernel_size=1)
        # self.ds0_cls_conv = nn.Conv3d(64, self.MODEL_NUM_CLASSES, kernel_size=1)

        self.cls_conv = nn.Conv3d(64, self.MODEL_NUM_CLASSES, kernel_size=1)

        for m in self.modules():
            if isinstance(m, (nn.Conv3d, Conv3d_wd, nn.ConvTranspose3d)):
                m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, (nn.BatchNorm3d, nn.SyncBatchNorm, nn.InstanceNorm3d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        self.backbone = Backbone(depth=9, in_channels=in_channels, norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std)
        # total = sum([param.nelement() for param in self.backbone.parameters()])
        # print(' + Number of Backbone Params: %.2f(e6)' % (total / 1e6))

        self.position_embed = build_position_encoding(mode='v2', hidden_dim=384)
        self.encoder_Detrans = DeformableTransformer(d_model=384, dim_feedforward=1536, dropout=0.1, activation='gelu', num_feature_levels=2, nhead=6, num_encoder_layers=6, enc_n_points=4)
        # total = sum([param.nelement() for param in self.encoder_Detrans.parameters()])
        # print(' + Number of Transformer Params: %.2f(e6)' % (total / 1e6))

    def posi_mask(self, x):

        x_fea = []
        x_posemb = []
        masks = []
        for lvl, fea in enumerate(x):
            if lvl > 1:
                x_fea.append(fea)
                x_posemb.append(self.position_embed(fea))
                # all-False padding mask, created on the same device as the feature map (works on CPU and GPU)
                masks.append(torch.zeros((fea.shape[0], fea.shape[2], fea.shape[3], fea.shape[4]), dtype=torch.bool, device=fea.device))

        return x_fea, masks, x_posemb

    def forward(self, inputs):
        # # %%%%%%%%%%%%% CoTr
        x_convs = self.backbone(inputs)
        x_fea, masks, x_posemb = self.posi_mask(x_convs)
        x_trans = self.encoder_Detrans(x_fea, masks, x_posemb)

        # # Single_scale
        # # x = self.transposeconv_stage2(x_trans.transpose(-1, -2).view(x_convs[-1].shape))
        # # skip2 = x_convs[-2]
        # Multi-scale
        x = self.transposeconv_stage2(x_trans[:, x_fea[0].shape[-3] * x_fea[0].shape[-2] * x_fea[0].shape[-1]::].transpose(-1, -2).view(x_convs[-1].shape))  # x_trans length: 12*24*24+6*12*12=7776
        skip2 = x_trans[:, 0:x_fea[0].shape[-3] * x_fea[0].shape[-2] * x_fea[0].shape[-1]].transpose(-1, -2).view(x_convs[-2].shape)

        x = x + skip2
        x = self.stage2_de(x)
        # ds2 = self.ds2_cls_conv(x)

        x = self.transposeconv_stage1(x)
        skip1 = x_convs[-3]
        x = x + skip1
        x = self.stage1_de(x)
        # ds1 = self.ds1_cls_conv(x)

        x = self.transposeconv_stage0(x)
        skip0 = x_convs[-4]
        x = x + skip0
        x = self.stage0_de(x)
        # ds0 = self.ds0_cls_conv(x)

        result = self.upsamplex2(x)
        result = self.cls_conv(result)

        return result


def cotr(in_channels, num_classes):
    model = U_ResTran3D(in_channels, num_classes)
    return model


# if __name__ == '__main__':
#
#     criterion = segmentation_loss('dice', False)
#
#     mask = torch.ones(2, 64, 64, 64).long()
#     model = cotr(1, 10)
#     model.train()
#     input = torch.rand(2, 1, 64, 64, 64)
#     output = model(input)
#     loss_train = criterion(output, mask)
#     output = output.data.cpu().numpy()
#     loss_train.backward()
#     print(output.shape)
#     print(loss_train)

================================================
FILE: models/networks_3d/dmfnet.py
================================================
import torch.nn as nn
import torch.nn.functional as F
import torch
# from loss.loss_function import segmentation_loss


def normalization(planes, norm='bn'):
    if norm == 'bn':
        m = nn.BatchNorm3d(planes)
    elif norm == 'gn':
        m = nn.GroupNorm(4, planes)
    elif norm == 'in':
        m = nn.InstanceNorm3d(planes)
    else:
        raise ValueError('normalization type {} is not supported'.format(norm))
    return m
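
# Illustrative helper (hypothetical, not part of the original DMFNet code): a sketch of
# the spatial sizes seen by the MFNet/DMFNet encoder defined below. The encoder halves
# the spatial size four times (encoder_block1 to encoder_block4). In a stride-2
# MFunit/DMFUnit the 3x3x3 branch produces ceil(s / 2) while the 2x2x2 shortcut
# produces floor(s / 2), so `x4 + shortcut` only lines up when s is even. Assuming this
# analysis, requiring every input dimension to be a multiple of 16 is a simple
# sufficient condition for all stages, and the decoder's skip concatenations, to align.
def _mfnet_encoder_sizes(spatial_size, num_downsamplings=4):
    """Return the spatial size before and after each stride-2 stage.

    Raises ValueError as soon as a stage would receive an odd size.
    """
    sizes = [tuple(spatial_size)]
    for _ in range(num_downsamplings):
        current = sizes[-1]
        if any(s % 2 for s in current):
            raise ValueError('spatial size {} is not divisible by 2'.format(current))
        sizes.append(tuple(s // 2 for s in current))
    return sizes

# Example: _mfnet_encoder_sizes((64, 64, 64))
# -> [(64, 64, 64), (32, 32, 32), (16, 16, 16), (8, 8, 8), (4, 4, 4)]
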
class Conv3d_Block(nn.Module):
    def __init__(self, num_in, num_out, kernel_size=1, stride=1, g=1, padding=None, norm=None):
        super(Conv3d_Block, self).__init__()
        if padding is None:
            padding = (kernel_size - 1) // 2
        self.bn = normalization(num_in, norm=norm)
        self.act_fn = nn.ReLU(inplace=True)
        self.conv = nn.Conv3d(num_in, num_out, kernel_size=kernel_size, padding=padding, stride=stride, groups=g, bias=False)

    def forward(self, x):  # BN + Relu + Conv
        h = self.act_fn(self.bn(x))
        h = self.conv(h)
        return h


class DilatedConv3DBlock(nn.Module):
    def __init__(self, num_in, num_out, kernel_size=(1, 1, 1), stride=1, g=1, d=(1, 1, 1), norm=None):
        super(DilatedConv3DBlock, self).__init__()
        assert isinstance(kernel_size, tuple) and isinstance(d, tuple)

        padding = tuple(
            [(ks - 1) // 2 * dd for ks, dd in zip(kernel_size, d)]
        )

        self.bn = normalization(num_in, norm=norm)
        self.act_fn = nn.ReLU(inplace=True)
        self.conv = nn.Conv3d(num_in, num_out, kernel_size=kernel_size, padding=padding, stride=stride, groups=g, dilation=d, bias=False)

    def forward(self, x):
        h = self.act_fn(self.bn(x))
        h = self.conv(h)
        return h


class MFunit(nn.Module):
    def __init__(self, num_in, num_out, g=1, stride=1, d=(1, 1), norm=None):
        """
        The second 3x3x1 group conv is replaced by 3x3x3.
        :param num_in: number of input channels
        :param num_out: number of output channels
        :param g: groups of group conv.
        :param stride: 1 or 2
        :param d: tuple, d[0] for the first 3x3x3 conv while d[1] for the 3x3x1 conv
        :param norm: normalization type passed to normalization() ('bn', 'gn' or 'in')
        """
        super(MFunit, self).__init__()
        num_mid = num_in if num_in <= num_out else num_out
        self.conv1x1x1_in1 = Conv3d_Block(num_in, num_in // 4, kernel_size=1, stride=1, norm=norm)
        self.conv1x1x1_in2 = Conv3d_Block(num_in // 4, num_mid, kernel_size=1, stride=1, norm=norm)
        self.conv3x3x3_m1 = DilatedConv3DBlock(num_mid, num_out, kernel_size=(3, 3, 3), stride=stride, g=g, d=(d[0], d[0], d[0]), norm=norm)  # dilated
        self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(3, 3, 1), stride=1, g=g, d=(d[1], d[1], 1), norm=norm)
        # self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(1, 3, 3), stride=1, g=g, d=(1, d[1], d[1]), norm=norm)

        # skip connection
        if num_in != num_out or stride != 1:
            if stride == 1:
                self.conv1x1x1_shortcut = Conv3d_Block(num_in, num_out, kernel_size=1, stride=1, padding=0, norm=norm)
            if stride == 2:
                # if MF block with stride=2, 2x2x2
                self.conv2x2x2_shortcut = Conv3d_Block(num_in, num_out, kernel_size=2, stride=2, padding=0, norm=norm)  # params

    def forward(self, x):
        x1 = self.conv1x1x1_in1(x)
        x2 = self.conv1x1x1_in2(x1)
        x3 = self.conv3x3x3_m1(x2)
        x4 = self.conv3x3x3_m2(x3)

        shortcut = x

        if hasattr(self, 'conv1x1x1_shortcut'):
            shortcut = self.conv1x1x1_shortcut(shortcut)
        if hasattr(self, 'conv2x2x2_shortcut'):
            shortcut = self.conv2x2x2_shortcut(shortcut)

        return x4 + shortcut


class DMFUnit(nn.Module):
    # weighted add
    def __init__(self, num_in, num_out, g=1, stride=1, norm=None, dilation=None):
        super(DMFUnit, self).__init__()
        self.weight1 = nn.Parameter(torch.ones(1))
        self.weight2 = nn.Parameter(torch.ones(1))
        self.weight3 = nn.Parameter(torch.ones(1))

        num_mid = num_in if num_in <= num_out else num_out

        self.conv1x1x1_in1 = Conv3d_Block(num_in, num_in // 4, kernel_size=1, stride=1, norm=norm)
        self.conv1x1x1_in2 = Conv3d_Block(num_in // 4, num_mid, kernel_size=1, stride=1, norm=norm)

        self.conv3x3x3_m1 = nn.ModuleList()
        if dilation is None:
            dilation = [1, 2, 3]
        for i in range(3):
            self.conv3x3x3_m1.append(
                DilatedConv3DBlock(num_mid, num_out, kernel_size=(3, 3, 3), stride=stride, g=g, d=(dilation[i], dilation[i], dilation[i]), norm=norm)
            )

        # no dilation in this conv
        self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(3, 3, 1), stride=(1, 1, 1), g=g, d=(1, 1, 1), norm=norm)
        # self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(1, 3, 3), stride=(1, 1, 1), g=g, d=(1, 1, 1), norm=norm)

        # skip connection
        if num_in != num_out or stride != 1:
            if stride == 1:
                self.conv1x1x1_shortcut = Conv3d_Block(num_in, num_out, kernel_size=1, stride=1, padding=0, norm=norm)
            if stride == 2:
                self.conv2x2x2_shortcut = Conv3d_Block(num_in, num_out, kernel_size=2, stride=2, padding=0, norm=norm)

    def forward(self, x):
        x1 = self.conv1x1x1_in1(x)
        x2 = self.conv1x1x1_in2(x1)
        x3 = self.weight1 * self.conv3x3x3_m1[0](x2) + self.weight2 * self.conv3x3x3_m1[1](x2) + self.weight3 * self.conv3x3x3_m1[2](x2)
        x4 = self.conv3x3x3_m2(x3)
        shortcut = x
        if hasattr(self, 'conv1x1x1_shortcut'):
            shortcut = self.conv1x1x1_shortcut(shortcut)
        if hasattr(self, 'conv2x2x2_shortcut'):
            shortcut = self.conv2x2x2_shortcut(shortcut)
        return x4 + shortcut


class MFNet(nn.Module):
    # [96]  Flops: 13.361G & Params: 1.81M
    # [112] Flops: 16.759G & Params: 2.46M
    # [128] Flops: 20.611G & Params: 3.19M
    def __init__(self, in_channels, num_classes, n=32, channels=128, groups=16, norm='bn'):
        super(MFNet, self).__init__()

        # Entry flow
        self.encoder_block1 = nn.Conv3d(in_channels, n, kernel_size=3, padding=1, stride=2, bias=False)  # H//2
        self.encoder_block2 = nn.Sequential(
            MFunit(n, channels, g=groups, stride=2, norm=norm),  # H//4 down
            MFunit(channels, channels, g=groups, stride=1, norm=norm),
            MFunit(channels, channels, g=groups, stride=1, norm=norm)
        )

        self.encoder_block3 = nn.Sequential(
            MFunit(channels, channels * 2, g=groups, stride=2, norm=norm),  # H//8
            MFunit(channels * 2, channels * 2, g=groups, stride=1, norm=norm),
            MFunit(channels * 2, channels * 2, g=groups,
                   stride=1, norm=norm)
        )

        self.encoder_block4 = nn.Sequential(  # H//8, channels*4
            MFunit(channels * 2, channels * 3, g=groups, stride=2, norm=norm),  # H//16
            MFunit(channels * 3, channels * 3, g=groups, stride=1, norm=norm),
            MFunit(channels * 3, channels * 2, g=groups, stride=1, norm=norm),
        )

        self.upsample1 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)  # H//8
        self.decoder_block1 = MFunit(channels * 2 + channels * 2, channels * 2, g=groups, stride=1, norm=norm)

        self.upsample2 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)  # H//4
        self.decoder_block2 = MFunit(channels * 2 + channels, channels, g=groups, stride=1, norm=norm)

        self.upsample3 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)  # H//2
        self.decoder_block3 = MFunit(channels + n, n, g=groups, stride=1, norm=norm)
        self.upsample4 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)  # H
        self.seg = nn.Conv3d(n, num_classes, kernel_size=1, padding=0, stride=1, bias=False)

        # Initialization
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Encoder
        x1 = self.encoder_block1(x)  # H//2 down
        x2 = self.encoder_block2(x1)  # H//4 down
        x3 = self.encoder_block3(x2)  # H//8 down
        x4 = self.encoder_block4(x3)  # H//16
        # Decoder
        y1 = self.upsample1(x4)  # H//8
        y1 = torch.cat([x3, y1], dim=1)
        y1 = self.decoder_block1(y1)

        y2 = self.upsample2(y1)  # H//4
        y2 = torch.cat([x2, y2], dim=1)
        y2 = self.decoder_block2(y2)

        y3 = self.upsample3(y2)  # H//2
        y3 = torch.cat([x1, y3], dim=1)
        y3 = self.decoder_block3(y3)
        y4 = self.upsample4(y3)
        y4 = self.seg(y4)
        return y4


class DMFNet(MFNet):  # softmax
    # [128] Flops: 27.045G & Params: 3.88M
    def __init__(self, in_channels, num_classes, n=32, channels=128, groups=16, norm='bn'):
        super(DMFNet, self).__init__(in_channels, num_classes, n, channels, groups, norm)

        self.encoder_block2 = nn.Sequential(
            DMFUnit(n, channels, g=groups, stride=2, norm=norm, dilation=[1, 2, 3]),  # H//4 down
            DMFUnit(channels, channels, g=groups, stride=1, norm=norm, dilation=[1, 2, 3]),  # Dilated Conv 3
            DMFUnit(channels, channels, g=groups, stride=1, norm=norm, dilation=[1, 2, 3])
        )

        self.encoder_block3 = nn.Sequential(
            DMFUnit(channels, channels * 2, g=groups, stride=2, norm=norm, dilation=[1, 2, 3]),  # H//8
            DMFUnit(channels * 2, channels * 2, g=groups, stride=1, norm=norm, dilation=[1, 2, 3]),  # Dilated Conv 3
            DMFUnit(channels * 2, channels * 2, g=groups, stride=1, norm=norm, dilation=[1, 2, 3])
        )


def dmfnet(in_channels, num_classes):
    model = DMFNet(in_channels, num_classes)
    return model


# if __name__ == '__main__':
#
#     criterion = segmentation_loss('dice', False)
#     mask = torch.ones(2, 64, 64, 64).long()
#
#     model = dmfnet(1, 10)
#     model.train()
#     input = torch.rand(2, 1, 64, 64, 64)
#     output = model(input)
#
#     loss_train = criterion(output, mask)
#     loss_train.backward()
#
#     output = output.data.cpu().numpy()
#     print(output.shape)
#     print(loss_train)

================================================
FILE: models/networks_3d/espnet3d.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# from loss.loss_function import segmentation_loss
from warnings import simplefilter
simplefilter(action='ignore', category=UserWarning)


class CBR(nn.Module):
    def __init__(self, nIn, nOut, kSize, stride=1):
        super().__init__()
        padding = int((kSize - 1) / 2)
        self.conv = nn.Conv3d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm3d(nOut, momentum=0.95, eps=1e-03)
        self.act = nn.ReLU(inplace=True)

    def forward(self, input):
        output = self.conv(input)
        output = self.bn(output)
        output = self.act(output)
        return output


class CB(nn.Module):
    def __init__(self, nIn, nOut, kSize, stride=1):
        super().__init__()
        padding = int((kSize - 1) / 2)
        self.conv = nn.Conv3d(nIn, nOut, kSize, stride=stride,
                              padding=padding, bias=False)
        self.bn = nn.BatchNorm3d(nOut, momentum=0.95, eps=1e-03)

    def forward(self, input):
        output = self.conv(input)
        output = self.bn(output)
        return output


class C(nn.Module):
    def __init__(self, nIn, nOut, kSize, stride=1, groups=1):
        super().__init__()
        padding = int((kSize - 1) / 2)
        self.conv = nn.Conv3d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, groups=groups)

    def forward(self, input):
        output = self.conv(input)
        return output


class DownSamplerA(nn.Module):
    def __init__(self, nIn, nOut):
        super().__init__()
        self.conv = CBR(nIn, nOut, 3, 2)

    def forward(self, input):
        output = self.conv(input)
        return output


class DownSamplerB(nn.Module):
    def __init__(self, nIn, nOut):
        super().__init__()
        k = 4
        n = int(nOut / k)
        n1 = nOut - (k - 1) * n
        self.c1 = nn.Sequential(CBR(nIn, n, 1, 1), C(n, n, 3, 2))
        self.d1 = CDilated(n, n1, 3, 1, 1)
        self.d2 = CDilated(n, n, 3, 1, 2)
        self.d4 = CDilated(n, n, 3, 1, 3)
        self.d8 = CDilated(n, n, 3, 1, 4)
        self.bn = BR(nOut)

    def forward(self, input):
        output1 = self.c1(input)
        d1 = self.d1(output1)
        d2 = self.d2(output1)
        d4 = self.d4(output1)
        d8 = self.d8(output1)

        add1 = d2
        add2 = add1 + d4
        add3 = add2 + d8

        combine = torch.cat([d1, add1, add2, add3], 1)
        if input.size() == combine.size():
            combine = input + combine
        output = self.bn(combine)
        return output


class BR(nn.Module):
    def __init__(self, nOut):
        super().__init__()
        self.bn = nn.BatchNorm3d(nOut, momentum=0.95, eps=1e-03)
        self.act = nn.ReLU(inplace=True)  # nn.PReLU(nOut)

    def forward(self, input):
        output = self.bn(input)
        output = self.act(output)
        return output


class CDilated(nn.Module):
    def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1):
        super().__init__()
        padding = int((kSize - 1) / 2) * d
        self.conv = nn.Conv3d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, dilation=d, groups=groups)
        # self.bn = nn.BatchNorm3d(nOut, momentum=0.95, eps=1e-03)

    def forward(self, input):
        return self.conv(input)
        # return self.bn(output)


class InputProjectionA(nn.Module):
    '''
    This class projects the input image to the same spatial dimensions as the feature map.
    For example, if the input image is 512x512x3 and the feature map is 56x56xF, then this
    class will generate an output of 56x56x3.
    '''

    def __init__(self, samplingTimes):
        '''
        :param samplingTimes: number of times the input is down-sampled by a factor of 2
        '''
        super().__init__()
        self.pool = nn.ModuleList()
        for i in range(0, samplingTimes):
            # pyramid-based approach for down-sampling
            self.pool.append(nn.AvgPool3d(3, stride=2, padding=1))

    def forward(self, input):
        '''
        :param input: input image or volume
        :return: down-sampled image (pyramid-based approach)
        '''
        for pool in self.pool:
            input = pool(input)
        return input


class DilatedParllelResidualBlockB1(nn.Module):  # with k=4
    def __init__(self, nIn, nOut, stride=1):
        super().__init__()
        k = 4
        n = int(nOut / k)
        n1 = nOut - (k - 1) * n
        self.c1 = CBR(nIn, n, 1, 1)
        self.d1 = CDilated(n, n1, 3, stride, 1)
        self.d2 = CDilated(n, n, 3, stride, 1)
        self.d4 = CDilated(n, n, 3, stride, 2)
        self.d8 = CDilated(n, n, 3, stride, 2)
        self.bn = nn.BatchNorm3d(nOut)

    def forward(self, input):
        output1 = self.c1(input)
        d1 = self.d1(output1)
        d2 = self.d2(output1)
        d4 = self.d4(output1)
        d8 = self.d8(output1)

        add1 = d2
        add2 = add1 + d4
        add3 = add2 + d8

        combine = self.bn(torch.cat([d1, add1, add2, add3], 1))
        if input.size() == combine.size():
            combine = input + combine
        output = F.relu(combine, inplace=True)
        return output


class ASPBlock(nn.Module):  # with k=4
    def __init__(self, nIn, nOut, stride=1):
        super().__init__()
        self.d1 = CB(nIn, nOut, 3, 1)
        self.d2 = CB(nIn, nOut, 5, 1)
        self.d4 = CB(nIn, nOut, 7, 1)
        self.d8 = CB(nIn, nOut, 9, 1)
        self.act = nn.ReLU(inplace=True)

    def forward(self, input):
        d1 = self.d1(input)
        d2 = self.d2(input)
        d3 = self.d4(input)
        d4 = self.d8(input)

        combine = d1 + d2 + d3 + d4
        if input.size() == combine.size():
            combine = input + combine
        output = self.act(combine)
        return output


class UpSampler(nn.Module):
    '''
    Up-sample the feature maps by 2
    '''
    def __init__(self, nIn, nOut):
        super().__init__()
        self.up = CBR(nIn, nOut, 3, 1)

    def forward(self, inp):
        return F.interpolate(self.up(inp), mode='trilinear', scale_factor=2, align_corners=True)


class PSPDec(nn.Module):
    '''
    Adapted from the Pyramid Scene Parsing Network (PSPNet) paper
    '''

    def __init__(self, nIn, nOut, downSize):
        super().__init__()
        self.scale = downSize
        self.features = CBR(nIn, nOut, 3, 1)

    def forward(self, x):
        assert x.dim() == 5
        inp_size = x.size()
        out_dim1, out_dim2, out_dim3 = int(inp_size[2] * self.scale), int(inp_size[3] * self.scale), int(inp_size[4] * self.scale)
        x_down = F.adaptive_avg_pool3d(x, output_size=(out_dim1, out_dim2, out_dim3))
        return F.interpolate(self.features(x_down), size=(inp_size[2], inp_size[3], inp_size[4]), mode='trilinear', align_corners=True)


class ESPNet(nn.Module):

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.input1 = InputProjectionA(1)
        self.input2 = InputProjectionA(1)

        initial = 16  # feature maps at level 1
        config = [32, 128, 256, 256]  # feature maps at level 2 and onwards
        reps = [2, 2, 3]

        ### ENCODER

        # all dimensions are listed with respect to an input of size 4 x 128 x 128 x 128
        self.level0 = CBR(in_channels, initial, 7, 2)  # initial x 64 x 64 x 64
        self.level1 = nn.ModuleList()
        for i in range(reps[0]):
            if i == 0:
                self.level1.append(DilatedParllelResidualBlockB1(initial, config[0]))  # config[0] x 64 x 64 x 64
            else:
                self.level1.append(DilatedParllelResidualBlockB1(config[0], config[0]))  # config[0] x 64 x 64 x 64

        # downsample the feature maps
        self.level2 = DilatedParllelResidualBlockB1(config[0], config[1], stride=2)  # config[1] x 32 x 32 x 32
        self.level_2 = nn.ModuleList()
        for i in range(0, reps[1]):
            self.level_2.append(DilatedParllelResidualBlockB1(config[1], config[1]))  # config[1] x 32 x 32 x 32

        # downsample the feature maps
        self.level3_0 = DilatedParllelResidualBlockB1(config[1], config[2], stride=2)  # config[2] x 16 x 16 x 16
        self.level_3 = nn.ModuleList()
        for i in range(0, reps[2]):
            self.level_3.append(DilatedParllelResidualBlockB1(config[2], config[2]))  # config[2] x 16 x 16 x 16

        ### DECODER

        # upsample the feature maps
        self.up_l3_l2 = UpSampler(config[2], config[1])  # config[1] x 32 x 32 x 32
        # Note the 2 in the line below: it is needed because the encoder feature maps are
        # concatenated with the upsampled feature maps
        self.merge_l2 = DilatedParllelResidualBlockB1(2 * config[1], config[1])  # config[1] x 32 x 32 x 32
        self.dec_l2 = nn.ModuleList()
        for i in range(0, reps[0]):
            self.dec_l2.append(DilatedParllelResidualBlockB1(config[1], config[1]))  # config[1] x 32 x 32 x 32

        self.up_l2_l1 = UpSampler(config[1], config[0])  # config[0] x 64 x 64 x 64
        # Note the 2 in the line below: it is needed because the encoder feature maps are
        # concatenated with the upsampled feature maps
        self.merge_l1 = DilatedParllelResidualBlockB1(2 * config[0], config[0])  # config[0] x 64 x 64 x 64
        self.dec_l1 = nn.ModuleList()
        for i in range(0, reps[0]):
            self.dec_l1.append(DilatedParllelResidualBlockB1(config[0], config[0]))  # config[0] x 64 x 64 x 64

        self.dec_l1.append(CBR(config[0], num_classes, 3, 1))  # classes x 64 x 64 x 64
        # We use the ESP block without the reduction step because the number of input feature maps is very small
        # (i.e. 4 in our case)
        self.dec_l1.append(ASPBlock(num_classes, num_classes))

        # Using PSP module to learn the representations at different scales
        self.pspModules = nn.ModuleList()
        scales = [0.2, 0.4, 0.6, 0.8]
        for sc in scales:
            self.pspModules.append(PSPDec(num_classes, num_classes, sc))

        # Classifier
        self.classifier = nn.Sequential(
            CBR((len(scales) + 1) * num_classes, num_classes, 3, 1),
            ASPBlock(num_classes, num_classes),  # classes x 64 x 64 x 64
            nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True),  # classes x 128 x 128 x 128
            CBR(num_classes, num_classes, 7, 1),  # classes x 128 x 128 x 128
            C(num_classes, num_classes, 1, 1)  # classes x 128 x 128 x 128
        )

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            if isinstance(m, nn.ConvTranspose3d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, input1, inp_res=(128, 128, 128), inpSt2=False):
        dim0 = input1.size(2)
        dim1 = input1.size(3)
        dim2 = input1.size(4)

        if self.training or inp_res is None:
            # input resolution should be divisible by 8
            inp_res = (math.ceil(dim0 / 8) * 8, math.ceil(dim1 / 8) * 8, math.ceil(dim2 / 8) * 8)
        if inp_res:
            input1 = F.adaptive_avg_pool3d(input1, output_size=inp_res)

        out_l0 = self.level0(input1)

        for i, layer in enumerate(self.level1):  # 64
            if i == 0:
                out_l1 = layer(out_l0)
            else:
                out_l1 = layer(out_l1)

        out_l2_down = self.level2(out_l1)  # 32
        for i, layer in enumerate(self.level_2):
            if i == 0:
                out_l2 = layer(out_l2_down)
            else:
                out_l2 = layer(out_l2)
        del out_l2_down

        out_l3_down = self.level3_0(out_l2)  # 16
        for i, layer in enumerate(self.level_3):
            if i == 0:
                out_l3 = layer(out_l3_down)
            else:
                out_l3 = layer(out_l3)
        del out_l3_down

        dec_l3_l2 = self.up_l3_l2(out_l3)
        merge_l2 = self.merge_l2(torch.cat([dec_l3_l2, out_l2], 1))
        for i, layer in enumerate(self.dec_l2):
            if i == 0:
                dec_l2 = layer(merge_l2)
            else:
                dec_l2 = layer(dec_l2)

        dec_l2_l1 = self.up_l2_l1(dec_l2)
        merge_l1 = self.merge_l1(torch.cat([dec_l2_l1, out_l1], 1))
        for i, layer in enumerate(self.dec_l1):
            if i == 0:
                dec_l1 = layer(merge_l1)
            else:
                dec_l1 = layer(dec_l1)

        psp_outs = dec_l1.clone()
        for layer in self.pspModules:
            out_psp = layer(dec_l1)
            psp_outs = torch.cat([psp_outs, out_psp], 1)

        decoded = self.classifier(psp_outs)
        return F.interpolate(decoded, size=(dim0, dim1, dim2), mode='trilinear', align_corners=True)


def espnet3d(in_channels, num_classes):
    model = ESPNet(in_channels, num_classes)
    return model


# if __name__ == '__main__':
#
#     criterion = segmentation_loss('dice', False)
#
#     mask = torch.ones(2, 96, 48, 96).long()
#     model = espnet3d(1, 10)
#     model.train()
#     input = torch.rand(2, 1, 96, 48, 96)
#     output = model(input)
#     loss_train = criterion(output, mask)
#     output = output.data.cpu().numpy()
#     loss_train.backward()
#     print(output.shape)
#     print(loss_train)

================================================
FILE: models/networks_3d/res_unet3d.py
================================================
import torch
import torch.nn as nn
import os
from torch.nn import init


def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


class UNet(nn.Module):
    """
    Implementations based on the Unet3D paper: https://arxiv.org/pdf/1706.00120.pdf
    """

    def __init__(self, in_channels, n_classes, base_n_filter=8):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.base_n_filter = base_n_filter

        self.lrelu = nn.LeakyReLU()
        self.dropout3d = nn.Dropout3d(p=0.6)
        self.upsacle = nn.Upsample(scale_factor=2, mode='nearest')

        self.conv3d_c1_1 = nn.Conv3d(self.in_channels, self.base_n_filter, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv3d_c1_2 = nn.Conv3d(self.base_n_filter, self.base_n_filter, kernel_size=3, stride=1, padding=1, bias=False)
        self.lrelu_conv_c1 = self.lrelu_conv(self.base_n_filter, self.base_n_filter)
        self.inorm3d_c1 = nn.InstanceNorm3d(self.base_n_filter)

        self.conv3d_c2 = nn.Conv3d(self.base_n_filter, self.base_n_filter * 2, kernel_size=3, stride=2, padding=1,
                                   bias=False)
        self.norm_lrelu_conv_c2 = self.norm_lrelu_conv(self.base_n_filter * 2, self.base_n_filter * 2)
        self.inorm3d_c2 = nn.InstanceNorm3d(self.base_n_filter * 2)

        self.conv3d_c3 = nn.Conv3d(self.base_n_filter * 2, self.base_n_filter * 4, kernel_size=3, stride=2, padding=1, bias=False)
        self.norm_lrelu_conv_c3 = self.norm_lrelu_conv(self.base_n_filter * 4, self.base_n_filter * 4)
        self.inorm3d_c3 = nn.InstanceNorm3d(self.base_n_filter * 4)

        self.conv3d_c4 = nn.Conv3d(self.base_n_filter * 4, self.base_n_filter * 8, kernel_size=3, stride=2, padding=1, bias=False)
        self.norm_lrelu_conv_c4 = self.norm_lrelu_conv(self.base_n_filter * 8, self.base_n_filter * 8)
        self.inorm3d_c4 = nn.InstanceNorm3d(self.base_n_filter * 8)

        self.conv3d_c5 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.norm_lrelu_conv_c5 = self.norm_lrelu_conv(self.base_n_filter * 16, self.base_n_filter * 16)
        self.norm_lrelu_upscale_conv_norm_lrelu_l0 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 16, self.base_n_filter * 8)

        self.conv3d_l0 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 8, kernel_size=1, stride=1, padding=0, bias=False)
        self.inorm3d_l0 = nn.InstanceNorm3d(self.base_n_filter * 8)

        self.conv_norm_lrelu_l1 = self.conv_norm_lrelu(self.base_n_filter * 16, self.base_n_filter * 16)
        self.conv3d_l1 = nn.Conv3d(self.base_n_filter * 16, self.base_n_filter * 8, kernel_size=1, stride=1, padding=0, bias=False)
        self.norm_lrelu_upscale_conv_norm_lrelu_l1 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 8, self.base_n_filter * 4)

        self.conv_norm_lrelu_l2 = self.conv_norm_lrelu(self.base_n_filter * 8, self.base_n_filter * 8)
        self.conv3d_l2 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 4, kernel_size=1, stride=1, padding=0, bias=False)
        self.norm_lrelu_upscale_conv_norm_lrelu_l2 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 4, self.base_n_filter * 2)

        self.conv_norm_lrelu_l3 = self.conv_norm_lrelu(self.base_n_filter * 4, self.base_n_filter * 4)
        self.conv3d_l3 = nn.Conv3d(self.base_n_filter * 4, self.base_n_filter * 2, kernel_size=1, stride=1, padding=0, bias=False)
        self.norm_lrelu_upscale_conv_norm_lrelu_l3 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 2, self.base_n_filter)

        self.conv_norm_lrelu_l4 = self.conv_norm_lrelu(self.base_n_filter * 2, self.base_n_filter * 2)
        self.conv3d_l4 = nn.Conv3d(self.base_n_filter * 2, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)

        self.ds2_1x1_conv3d = nn.Conv3d(self.base_n_filter * 8, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)
        self.ds3_1x1_conv3d = nn.Conv3d(self.base_n_filter * 4, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)

    def conv_norm_lrelu(self, feat_in, feat_out):
        return nn.Sequential(
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm3d(feat_out),
            nn.LeakyReLU())

    def norm_lrelu_conv(self, feat_in, feat_out):
        return nn.Sequential(
            nn.InstanceNorm3d(feat_in),
            nn.LeakyReLU(),
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))

    def lrelu_conv(self, feat_in, feat_out):
        return nn.Sequential(
            nn.LeakyReLU(),
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))

    def norm_lrelu_upscale_conv_norm_lrelu(self, feat_in, feat_out):
        return nn.Sequential(
            nn.InstanceNorm3d(feat_in),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            # should be feat_in*2 or feat_in
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm3d(feat_out),
            nn.LeakyReLU())

    def forward(self, x):
        # Level 1 context pathway
        out = self.conv3d_c1_1(x)
        residual_1 = out
        out = self.lrelu(out)
        out = self.conv3d_c1_2(out)
        out = self.dropout3d(out)
        out = self.lrelu_conv_c1(out)
        # Element Wise Summation
        out += residual_1
        context_1 = self.lrelu(out)
        out = self.inorm3d_c1(out)
        out = self.lrelu(out)

        # Level 2 context pathway
        out = self.conv3d_c2(out)
        residual_2 = out
        out = self.norm_lrelu_conv_c2(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c2(out)
        out += residual_2
        out = self.inorm3d_c2(out)
        out = self.lrelu(out)
        context_2 = out

        # Level 3 context pathway
        out = self.conv3d_c3(out)
        residual_3 = out
        out = self.norm_lrelu_conv_c3(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c3(out)
        out += residual_3
        out = self.inorm3d_c3(out)
        out = self.lrelu(out)
        context_3 = out

        # Level 4 context pathway
        out = self.conv3d_c4(out)
        residual_4 = out
        out = self.norm_lrelu_conv_c4(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c4(out)
        out += residual_4
        out = self.inorm3d_c4(out)
        out = self.lrelu(out)
        context_4 = out

        # Level 5
        out = self.conv3d_c5(out)
        residual_5 = out
        out = self.norm_lrelu_conv_c5(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c5(out)
        out += residual_5
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l0(out)

        out = self.conv3d_l0(out)
        out = self.inorm3d_l0(out)
        out = self.lrelu(out)

        # Level 1 localization pathway
        out = torch.cat([out, context_4], dim=1)
        out = self.conv_norm_lrelu_l1(out)
        out = self.conv3d_l1(out)
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l1(out)

        # Level 2 localization pathway
        # print(out.shape)
        # print(context_3.shape)
        out = torch.cat([out, context_3], dim=1)
        out = self.conv_norm_lrelu_l2(out)
        ds2 = out
        out = self.conv3d_l2(out)
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l2(out)

        # Level 3 localization pathway
        out = torch.cat([out, context_2], dim=1)
        out = self.conv_norm_lrelu_l3(out)
        ds3 = out
        out = self.conv3d_l3(out)
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l3(out)

        # Level 4 localization pathway
        out = torch.cat([out, context_1], dim=1)
        out = self.conv_norm_lrelu_l4(out)
        out_pred = self.conv3d_l4(out)

        ds2_1x1_conv = self.ds2_1x1_conv3d(ds2)
        ds1_ds2_sum_upscale = self.upsacle(ds2_1x1_conv)
        ds3_1x1_conv = self.ds3_1x1_conv3d(ds3)
        ds1_ds2_sum_upscale_ds3_sum = ds1_ds2_sum_upscale + ds3_1x1_conv
ds1_ds2_sum_upscale_ds3_sum_upscale = self.upsacle(ds1_ds2_sum_upscale_ds3_sum) out = out_pred + ds1_ds2_sum_upscale_ds3_sum_upscale seg_layer = out return seg_layer def res_unet3d(in_channels, num_classes): model = UNet(in_channels, num_classes) init_weights(model, 'kaiming') return model # if __name__ == '__main__': # model = res_unet3d(1,10) # model.eval() # input = torch.rand(2, 1, 128, 128, 128) # output = model(input) # output = output.data.cpu().numpy() # # print(output) # print(output.shape) ================================================ FILE: models/networks_3d/transbts.py ================================================ import torch import torch.nn as nn from torch.nn import init import torch.nn.functional as F from loss.loss_function import segmentation_loss def init_weights(net, init_type='normal', gain=0.02): def init_func(m): classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': init.normal_(m.weight.data, 0.0, gain) elif init_type == 'xavier': init.xavier_normal_(m.weight.data, gain=gain) elif init_type == 'kaiming': init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif init_type == 'orthogonal': init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) elif classname.find('BatchNorm2d') != -1: init.normal_(m.weight.data, 1.0, gain) init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) def normalization(planes, norm='gn'): if norm == 'bn': m = nn.BatchNorm3d(planes) elif norm == 'gn': m = nn.GroupNorm(8, planes) elif norm == 'in': m = nn.InstanceNorm3d(planes) else: raise ValueError('normalization type {} is not supported'.format(norm)) return m class InitConv(nn.Module): def __init__(self, in_channels=4, out_channels=16, dropout=0.2): 
super(InitConv, self).__init__() self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1) self.dropout = dropout def forward(self, x): y = self.conv(x) y = F.dropout3d(y, self.dropout) return y class EnBlock(nn.Module): def __init__(self, in_channels, norm='gn'): super(EnBlock, self).__init__() self.bn1 = normalization(in_channels, norm=norm) self.relu1 = nn.ReLU(inplace=True) self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) self.bn2 = normalization(in_channels, norm=norm) self.relu2 = nn.ReLU(inplace=True) self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) def forward(self, x): x1 = self.bn1(x) x1 = self.relu1(x1) x1 = self.conv1(x1) y = self.bn2(x1) y = self.relu2(y) y = self.conv2(y) y = y + x return y class EnDown(nn.Module): def __init__(self, in_channels, out_channels): super(EnDown, self).__init__() self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=2, padding=1) def forward(self, x): y = self.conv(x) return y class Unet(nn.Module): def __init__(self, in_channels=4, base_channels=16): super(Unet, self).__init__() self.InitConv = InitConv(in_channels=in_channels, out_channels=base_channels, dropout=0.2) self.EnBlock1 = EnBlock(in_channels=base_channels) self.EnDown1 = EnDown(in_channels=base_channels, out_channels=base_channels*2) self.EnBlock2_1 = EnBlock(in_channels=base_channels*2) self.EnBlock2_2 = EnBlock(in_channels=base_channels*2) self.EnDown2 = EnDown(in_channels=base_channels*2, out_channels=base_channels*4) self.EnBlock3_1 = EnBlock(in_channels=base_channels * 4) self.EnBlock3_2 = EnBlock(in_channels=base_channels * 4) self.EnDown3 = EnDown(in_channels=base_channels*4, out_channels=base_channels*8) self.EnBlock4_1 = EnBlock(in_channels=base_channels * 8) self.EnBlock4_2 = EnBlock(in_channels=base_channels * 8) self.EnBlock4_3 = EnBlock(in_channels=base_channels * 8) self.EnBlock4_4 = EnBlock(in_channels=base_channels * 8) def forward(self, x): x = 
self.InitConv(x) # (1, 16, 128, 128, 128) x1_1 = self.EnBlock1(x) x1_2 = self.EnDown1(x1_1) # (1, 32, 64, 64, 64) x2_1 = self.EnBlock2_1(x1_2) x2_1 = self.EnBlock2_2(x2_1) x2_2 = self.EnDown2(x2_1) # (1, 64, 32, 32, 32) x3_1 = self.EnBlock3_1(x2_2) x3_1 = self.EnBlock3_2(x3_1) x3_2 = self.EnDown3(x3_1) # (1, 128, 16, 16, 16) x4_1 = self.EnBlock4_1(x3_2) x4_2 = self.EnBlock4_2(x4_1) x4_3 = self.EnBlock4_3(x4_2) output = self.EnBlock4_4(x4_3) # (1, 128, 16, 16, 16) return x1_1, x2_1, x3_1, output class FixedPositionalEncoding(nn.Module): def __init__(self, embedding_dim, max_length=512): super(FixedPositionalEncoding, self).__init__() pe = torch.zeros(max_length, embedding_dim) position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / embedding_dim)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) self.register_buffer('pe', pe) def forward(self, x): x = x + self.pe[: x.size(0), :] return x class LearnedPositionalEncoding(nn.Module): def __init__(self, max_position_embeddings, embedding_dim): super(LearnedPositionalEncoding, self).__init__() self.position_embeddings = nn.Parameter(torch.zeros(1, max_position_embeddings, embedding_dim)) #8x def forward(self, x): position_embeddings = self.position_embeddings return x + position_embeddings class IntermediateSequential(nn.Sequential): def __init__(self, *args, return_intermediate=True): super().__init__(*args) self.return_intermediate = return_intermediate def forward(self, input): if not self.return_intermediate: return super().forward(input) intermediate_outputs = {} output = input for name, module in self.named_children(): output = intermediate_outputs[name] = module(output) return output, intermediate_outputs class SelfAttention(nn.Module): def __init__( self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0 ): 
super().__init__() self.num_heads = heads head_dim = dim // heads self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(dropout_rate) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(dropout_rate) def forward(self, x): B, N, C = x.shape qkv = (self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)) q, k, v = (qkv[0], qkv[1], qkv[2]) # make torchscript happy (cannot use tensor as tuple) attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x): return self.fn(x) + x class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.norm = nn.LayerNorm(dim) self.fn = fn def forward(self, x): return self.fn(self.norm(x)) class PreNormDrop(nn.Module): def __init__(self, dim, dropout_rate, fn): super().__init__() self.norm = nn.LayerNorm(dim) self.dropout = nn.Dropout(p=dropout_rate) self.fn = fn def forward(self, x): return self.dropout(self.fn(self.norm(x))) class FeedForward(nn.Module): def __init__(self, dim, hidden_dim, dropout_rate): super().__init__() self.net = nn.Sequential( nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(p=dropout_rate), nn.Linear(hidden_dim, dim), nn.Dropout(p=dropout_rate), ) def forward(self, x): return self.net(x) class TransformerModel(nn.Module): def __init__(self,dim,depth,heads,mlp_dim,dropout_rate=0.1,attn_dropout_rate=0.1): super().__init__() layers = [] for _ in range(depth): layers.extend([ Residual(PreNormDrop(dim,dropout_rate,SelfAttention(dim, heads=heads, dropout_rate=attn_dropout_rate))), Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate)))]) # dim = dim / 2 self.net = IntermediateSequential(*layers) def forward(self, x): return self.net(x) class 
TransformerBTS(nn.Module): def __init__( self, img_dim, patch_dim, num_channels, embedding_dim, num_heads, num_layers, hidden_dim, dropout_rate=0.0, attn_dropout_rate=0.0, conv_patch_representation=True, positional_encoding_type="learned", ): super(TransformerBTS, self).__init__() assert embedding_dim % num_heads == 0 assert img_dim[0] % patch_dim == 0 assert img_dim[1] % patch_dim == 0 assert img_dim[2] % patch_dim == 0 self.img_dim = img_dim self.embedding_dim = embedding_dim self.num_heads = num_heads self.patch_dim = patch_dim self.num_channels = num_channels self.dropout_rate = dropout_rate self.attn_dropout_rate = attn_dropout_rate self.conv_patch_representation = conv_patch_representation self.num_patches = int((img_dim[0] // patch_dim) * (img_dim[1] // patch_dim) * (img_dim[2] // patch_dim)) self.seq_length = self.num_patches self.flatten_dim = 128 * num_channels self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim) if positional_encoding_type == "learned": self.position_encoding = LearnedPositionalEncoding(self.seq_length, self.embedding_dim) elif positional_encoding_type == "fixed": self.position_encoding = FixedPositionalEncoding(self.embedding_dim) self.pe_dropout = nn.Dropout(p=self.dropout_rate) self.transformer = TransformerModel(embedding_dim,num_layers,num_heads,hidden_dim,self.dropout_rate,self.attn_dropout_rate) self.pre_head_ln = nn.LayerNorm(embedding_dim) if self.conv_patch_representation: self.conv_x = nn.Conv3d(128, self.embedding_dim, kernel_size=3, stride=1, padding=1) self.Unet = Unet(in_channels=num_channels, base_channels=16) self.bn = nn.BatchNorm3d(128) self.relu = nn.ReLU(inplace=True) def encode(self, x): if self.conv_patch_representation: # combine embedding with conv patch distribution x1_1, x2_1, x3_1, x = self.Unet(x) x = self.bn(x) x = self.relu(x) x = self.conv_x(x) x = x.permute(0, 2, 3, 4, 1).contiguous() x = x.view(x.size(0), -1, self.embedding_dim) else: x = self.Unet(x) x = self.bn(x) x = self.relu(x) x = 
( x.unfold(2, 2, 2) .unfold(3, 2, 2) .unfold(4, 2, 2) .contiguous() ) x = x.view(x.size(0), x.size(1), -1, 8) x = x.permute(0, 2, 3, 1).contiguous() x = x.view(x.size(0), -1, self.flatten_dim) x = self.linear_encoding(x) x = self.position_encoding(x) x = self.pe_dropout(x) # apply transformer x, intmd_x = self.transformer(x) x = self.pre_head_ln(x) return x1_1, x2_1, x3_1, x, intmd_x def forward(self, x, auxillary_output_layers=[1, 2, 3, 4]): x1_1, x2_1, x3_1, encoder_output, intmd_encoder_outputs = self.encode(x) decoder_output = self.decode(x1_1, x2_1, x3_1, encoder_output, intmd_encoder_outputs, auxillary_output_layers) if auxillary_output_layers is not None: auxillary_outputs = {} for i in auxillary_output_layers: val = str(2 * i - 1) _key = 'Z' + str(i) auxillary_outputs[_key] = intmd_encoder_outputs[val] return decoder_output return decoder_output # def _get_padding(self, padding_type, kernel_size): # assert padding_type in ['SAME', 'VALID'] # if padding_type == 'SAME': # _list = [(k - 1) // 2 for k in kernel_size] # return tuple(_list) # return tuple(0 for _ in kernel_size) def _reshape_output(self, x): x = x.view( x.size(0), int(self.img_dim[0] / self.patch_dim), int(self.img_dim[1] / self.patch_dim), int(self.img_dim[2] / self.patch_dim), self.embedding_dim, ) x = x.permute(0, 4, 1, 2, 3).contiguous() return x class BTS(TransformerBTS): def __init__(self, in_channels, num_classes, img_shape=(128, 128, 128), patch_dim=8, embedding_dim=512, num_heads=8, num_layers=4, hidden_dim=4096, dropout_rate=0.1, attn_dropout_rate=0.1, conv_patch_representation=True, positional_encoding_type="learned"): super(BTS, self).__init__( img_dim=img_shape, patch_dim=patch_dim, num_channels=in_channels, embedding_dim=embedding_dim, num_heads=num_heads, num_layers=num_layers, hidden_dim=hidden_dim, dropout_rate=dropout_rate, attn_dropout_rate=attn_dropout_rate, conv_patch_representation=conv_patch_representation, positional_encoding_type=positional_encoding_type, ) 
self.Enblock8_1 = EnBlock1(in_channels=self.embedding_dim) self.Enblock8_2 = EnBlock2(in_channels=self.embedding_dim // 4) self.DeUp4 = DeUp_Cat(in_channels=self.embedding_dim//4, out_channels=self.embedding_dim//8) self.DeBlock4 = DeBlock(in_channels=self.embedding_dim//8) self.DeUp3 = DeUp_Cat(in_channels=self.embedding_dim//8, out_channels=self.embedding_dim//16) self.DeBlock3 = DeBlock(in_channels=self.embedding_dim//16) self.DeUp2 = DeUp_Cat(in_channels=self.embedding_dim//16, out_channels=self.embedding_dim//32) self.DeBlock2 = DeBlock(in_channels=self.embedding_dim//32) self.endconv = nn.Conv3d(self.embedding_dim // 32, num_classes, kernel_size=1) def decode(self, x1_1, x2_1, x3_1, x, intmd_x, intmd_layers=[1, 2, 3, 4]): assert intmd_layers is not None, "pass the intermediate layers for MLA" encoder_outputs = {} all_keys = [] for i in intmd_layers: val = str(2 * i - 1) _key = 'Z' + str(i) all_keys.append(_key) encoder_outputs[_key] = intmd_x[val] all_keys.reverse() x8 = encoder_outputs[all_keys[0]] x8 = self._reshape_output(x8) x8 = self.Enblock8_1(x8) x8 = self.Enblock8_2(x8) y4 = self.DeUp4(x8, x3_1) # (1, 64, 32, 32, 32) y4 = self.DeBlock4(y4) y3 = self.DeUp3(y4, x2_1) # (1, 32, 64, 64, 64) y3 = self.DeBlock3(y3) y2 = self.DeUp2(y3, x1_1) # (1, 16, 128, 128, 128) y2 = self.DeBlock2(y2) y = self.endconv(y2) # (1, 4, 128, 128, 128) return y class EnBlock1(nn.Module): def __init__(self, in_channels, ): super(EnBlock1, self).__init__() self.bn1 = nn.BatchNorm3d(in_channels // 4) self.relu1 = nn.ReLU(inplace=True) self.bn2 = nn.BatchNorm3d(in_channels // 4) self.relu2 = nn.ReLU(inplace=True) self.conv1 = nn.Conv3d(in_channels, in_channels // 4, kernel_size=3, padding=1) self.conv2 = nn.Conv3d(in_channels // 4, in_channels // 4, kernel_size=3, padding=1) def forward(self, x): x1 = self.conv1(x) x1 = self.bn1(x1) x1 = self.relu1(x1) x1 = self.conv2(x1) x1 = self.bn2(x1) x1 = self.relu2(x1) return x1 class EnBlock2(nn.Module): def __init__(self, in_channels): 
super(EnBlock2, self).__init__() self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm3d(in_channels) self.relu1 = nn.ReLU(inplace=True) self.bn2 = nn.BatchNorm3d(in_channels) self.relu2 = nn.ReLU(inplace=True) self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) def forward(self, x): x1 = self.conv1(x) x1 = self.bn1(x1) x1 = self.relu1(x1) x1 = self.conv2(x1) x1 = self.bn2(x1) x1 = self.relu2(x1) x1 = x1 + x return x1 class DeUp_Cat(nn.Module): def __init__(self, in_channels, out_channels): super(DeUp_Cat, self).__init__() self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=1) self.conv2 = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2) self.conv3 = nn.Conv3d(out_channels*2, out_channels, kernel_size=1) def forward(self, x, prev): x1 = self.conv1(x) y = self.conv2(x1) # y = y + prev y = torch.cat((prev, y), dim=1) y = self.conv3(y) return y class DeBlock(nn.Module): def __init__(self, in_channels): super(DeBlock, self).__init__() self.bn1 = nn.BatchNorm3d(in_channels) self.relu1 = nn.ReLU(inplace=True) self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm3d(in_channels) self.relu2 = nn.ReLU(inplace=True) def forward(self, x): x1 = self.conv1(x) x1 = self.bn1(x1) x1 = self.relu1(x1) x1 = self.conv2(x1) x1 = self.bn2(x1) x1 = self.relu2(x1) x1 = x1 + x return x1 def transbts(in_channels, num_classes, **kwargs): model = BTS(in_channels, num_classes, img_shape=kwargs['img_shape']) init_weights(model, 'kaiming') return model # if __name__ == '__main__': # # criterion = segmentation_loss('dice', False) # mask = torch.ones(2, 64, 96, 64).long() # model = transbts(1, 10, img_shape=(64, 96, 64)) # model.train() # input = torch.rand(2, 1, 64, 96, 64) # output = model(input) # loss_train = criterion(output, mask) # loss_train.backward() # output = 
output.data.cpu().numpy() # print(output.shape) # print(loss_train) ================================================ FILE: models/networks_3d/unet3d.py ================================================ import numpy as np from collections import OrderedDict import torch import torch.nn as nn from torch.nn import init def init_weights(net, init_type='normal', gain=0.02): def init_func(m): classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': init.normal_(m.weight.data, 0.0, gain) elif init_type == 'xavier': init.xavier_normal_(m.weight.data, gain=gain) elif init_type == 'kaiming': init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif init_type == 'orthogonal': init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) elif classname.find('BatchNorm2d') != -1: init.normal_(m.weight.data, 1.0, gain) init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) class UNet3D(nn.Module): def __init__(self, in_channels=1, out_channels=3, init_features=64): """ Implementations based on the Unet3D paper: https://arxiv.org/abs/1606.06650 """ super(UNet3D, self).__init__() features = init_features self.encoder1 = UNet3D._block(in_channels, features, name="enc1") self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2) self.encoder2 = UNet3D._block(features, features * 2, name="enc2") self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2) self.encoder3 = UNet3D._block(features * 2, features * 4, name="enc3") self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2) self.encoder4 = UNet3D._block(features * 4, features * 8, name="enc4") self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2) self.bottleneck = UNet3D._block(features * 8, features * 16, name="bottleneck") self.upconv4 = nn.ConvTranspose3d( 
features * 16, features * 8, kernel_size=2, stride=2 ) self.decoder4 = UNet3D._block((features * 8) * 2 , features * 8, name="dec4") self.upconv3 = nn.ConvTranspose3d( features * 8, features * 4, kernel_size=2, stride=2 ) self.decoder3 = UNet3D._block((features * 4) * 2, features * 4, name="dec3") self.upconv2 = nn.ConvTranspose3d( features * 4, features * 2, kernel_size=2, stride=2 ) self.decoder2 = UNet3D._block((features * 2) * 2, features * 2, name="dec2") self.upconv1 = nn.ConvTranspose3d( features * 2, features, kernel_size=2, stride=2 ) self.decoder1 = UNet3D._block(features * 2, features, name="dec1") self.conv = nn.Conv3d( in_channels=features, out_channels=out_channels, kernel_size=1 ) def forward(self, x): enc1 = self.encoder1(x) enc2 = self.encoder2(self.pool1(enc1)) enc3 = self.encoder3(self.pool2(enc2)) enc4 = self.encoder4(self.pool3(enc3)) bottleneck = self.bottleneck(self.pool4(enc4)) dec4 = self.upconv4(bottleneck) dec4 = torch.cat((dec4, enc4), dim=1) dec4 = self.decoder4(dec4) dec3 = self.upconv3(dec4) dec3 = torch.cat((dec3, enc3), dim=1) dec3 = self.decoder3(dec3) dec2 = self.upconv2(dec3) dec2 = torch.cat((dec2, enc2), dim=1) dec2 = self.decoder2(dec2) dec1 = self.upconv1(dec2) dec1 = torch.cat((dec1, enc1), dim=1) dec1 = self.decoder1(dec1) outputs = self.conv(dec1) return outputs @staticmethod def _block(in_channels, features, name): return nn.Sequential( OrderedDict( [ ( name + "conv1", nn.Conv3d( in_channels=in_channels, out_channels=features, kernel_size=3, padding=1, bias=True, ), ), (name + "norm1", nn.BatchNorm3d(num_features=features)), (name + "relu1", nn.ReLU(inplace=True)), ( name + "conv2", nn.Conv3d( in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=True, ), ), (name + "norm2", nn.BatchNorm3d(num_features=features)), (name + "relu2", nn.ReLU(inplace=True)), ] ) ) class UNet3D_min(nn.Module): def __init__(self, in_channels=1, out_channels=3, init_features=32): """ Implementations based on the Unet3D 
paper: https://arxiv.org/abs/1606.06650 """ super(UNet3D_min, self).__init__() features = init_features self.encoder1 = UNet3D._block(in_channels, features, name="enc1") self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2) self.encoder2 = UNet3D._block(features, features * 2, name="enc2") self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2) self.encoder3 = UNet3D._block(features * 2, features * 4, name="enc3") self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2) self.encoder4 = UNet3D._block(features * 4, features * 8, name="enc4") self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2) self.bottleneck = UNet3D._block(features * 8, features * 16, name="bottleneck") self.upconv4 = nn.ConvTranspose3d( features * 16, features * 8, kernel_size=2, stride=2 ) self.decoder4 = UNet3D._block((features * 8) * 2 , features * 8, name="dec4") self.upconv3 = nn.ConvTranspose3d( features * 8, features * 4, kernel_size=2, stride=2 ) self.decoder3 = UNet3D._block((features * 4) * 2, features * 4, name="dec3") self.upconv2 = nn.ConvTranspose3d( features * 4, features * 2, kernel_size=2, stride=2 ) self.decoder2 = UNet3D._block((features * 2) * 2, features * 2, name="dec2") self.upconv1 = nn.ConvTranspose3d( features * 2, features, kernel_size=2, stride=2 ) self.decoder1 = UNet3D._block(features * 2, features, name="dec1") self.conv = nn.Conv3d( in_channels=features, out_channels=out_channels, kernel_size=1 ) def forward(self, x): enc1 = self.encoder1(x) enc2 = self.encoder2(self.pool1(enc1)) enc3 = self.encoder3(self.pool2(enc2)) enc4 = self.encoder4(self.pool3(enc3)) bottleneck = self.bottleneck(self.pool4(enc4)) dec4 = self.upconv4(bottleneck) dec4 = torch.cat((dec4, enc4), dim=1) dec4 = self.decoder4(dec4) dec3 = self.upconv3(dec4) dec3 = torch.cat((dec3, enc3), dim=1) dec3 = self.decoder3(dec3) dec2 = self.upconv2(dec3) dec2 = torch.cat((dec2, enc2), dim=1) dec2 = self.decoder2(dec2) dec1 = self.upconv1(dec2) dec1 = torch.cat((dec1, enc1), dim=1) dec1 = self.decoder1(dec1) outputs = 
self.conv(dec1) return outputs @staticmethod def _block(in_channels, features, name): return nn.Sequential( OrderedDict( [ ( name + "conv1", nn.Conv3d( in_channels=in_channels, out_channels=features, kernel_size=3, padding=1, bias=True, ), ), (name + "norm1", nn.BatchNorm3d(num_features=features)), (name + "relu1", nn.ReLU(inplace=True)), ( name + "conv2", nn.Conv3d( in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=True, ), ), (name + "norm2", nn.BatchNorm3d(num_features=features)), (name + "relu2", nn.ReLU(inplace=True)), ] ) ) def unet3d(in_channels, num_classes): model = UNet3D(in_channels, num_classes) init_weights(model, 'kaiming') return model def unet3d_min(in_channels, num_classes): model = UNet3D_min(in_channels, num_classes) init_weights(model, 'kaiming') return model # if __name__ == '__main__': # model = unet3d(1,10) # model.eval() # input = torch.rand(2, 1, 128, 128, 128) # output = model(input) # output = output.data.cpu().numpy() # # print(output) # print(output.shape) ================================================ FILE: models/networks_3d/unet3d_cct.py ================================================ import numpy as np from collections import OrderedDict import torch import torch.nn as nn from torch.nn import init from torch.distributions.uniform import Uniform def init_weights(net, init_type='normal', gain=0.02): def init_func(m): classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': init.normal_(m.weight.data, 0.0, gain) elif init_type == 'xavier': init.xavier_normal_(m.weight.data, gain=gain) elif init_type == 'kaiming': init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif init_type == 'orthogonal': init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) 
elif classname.find('BatchNorm2d') != -1: init.normal_(m.weight.data, 1.0, gain) init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) class FeatureNoise(nn.Module): def __init__(self, uniform_range=0.3): super(FeatureNoise, self).__init__() self.uni_dist = Uniform(-uniform_range, uniform_range) def feature_based_noise(self, x): noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0) x_noise = x.mul(noise_vector) + x return x_noise def forward(self, x): x = self.feature_based_noise(x) return x def Dropout(x, p=0.3): x = torch.nn.functional.dropout(x, p) return x def FeatureDropout(x): attention = torch.mean(x, dim=1, keepdim=True) max_val, _ = torch.max(attention.view(x.size(0), -1), dim=1, keepdim=True) threshold = max_val * np.random.uniform(0.7, 0.9) threshold = threshold.view(x.size(0), 1, 1, 1, 1).expand_as(attention) drop_mask = (attention < threshold).float() x = x.mul(drop_mask) return x class Decoder(nn.Module): def __init__(self, features, out_channels): super(Decoder, self).__init__() self.upconv4 = nn.ConvTranspose3d(features * 16, features * 8, kernel_size=2, stride=2) self.decoder4 = Decoder._block((features * 8) * 2, features * 8, name="dec4") self.upconv3 = nn.ConvTranspose3d(features * 8, features * 4, kernel_size=2, stride=2) self.decoder3 = Decoder._block((features * 4) * 2, features * 4, name="dec3") self.upconv2 = nn.ConvTranspose3d(features * 4, features * 2, kernel_size=2, stride=2) self.decoder2 = Decoder._block((features * 2) * 2, features * 2, name="dec2") self.upconv1 = nn.ConvTranspose3d(features * 2, features, kernel_size=2, stride=2) self.decoder1 = Decoder._block(features * 2, features, name="dec1") self.conv = nn.Conv3d(in_channels=features, out_channels=out_channels, kernel_size=1) def forward(self, x5, x4, x3, x2, x1): dec4 = self.upconv4(x5) dec4 = torch.cat((dec4, x4), dim=1) dec4 = self.decoder4(dec4) dec3 = self.upconv3(dec4) dec3 = torch.cat((dec3, x3), dim=1) 
dec3 = self.decoder3(dec3) dec2 = self.upconv2(dec3) dec2 = torch.cat((dec2, x2), dim=1) dec2 = self.decoder2(dec2) dec1 = self.upconv1(dec2) dec1 = torch.cat((dec1, x1), dim=1) dec1 = self.decoder1(dec1) outputs = self.conv(dec1) return outputs @staticmethod def _block(in_channels, features, name): return nn.Sequential( OrderedDict( [ ( name + "conv1", nn.Conv3d( in_channels=in_channels, out_channels=features, kernel_size=3, padding=1, bias=True, ), ), (name + "norm1", nn.BatchNorm3d(num_features=features)), (name + "relu1", nn.ReLU(inplace=True)), ( name + "conv2", nn.Conv3d( in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=True, ), ), (name + "norm2", nn.BatchNorm3d(num_features=features)), (name + "relu2", nn.ReLU(inplace=True)), ] ) ) class UNet3D_CCT(nn.Module): def __init__(self, in_channels=1, out_channels=3, init_features=64): """ Implementations based on the Unet3D paper: https://arxiv.org/abs/1606.06650 """ super(UNet3D_CCT, self).__init__() features = init_features self.encoder1 = UNet3D_CCT._block(in_channels, features, name="enc1") self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2) self.encoder2 = UNet3D_CCT._block(features, features * 2, name="enc2") self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2) self.encoder3 = UNet3D_CCT._block(features * 2, features * 4, name="enc3") self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2) self.encoder4 = UNet3D_CCT._block(features * 4, features * 8, name="enc4") self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2) self.bottleneck = UNet3D_CCT._block(features * 8, features * 16, name="bottleneck") self.main_decoder = Decoder(features, out_channels) self.aux_decoder1 = Decoder(features, out_channels) self.aux_decoder2 = Decoder(features, out_channels) self.aux_decoder3 = Decoder(features, out_channels) def forward(self, x): enc1 = self.encoder1(x) enc2 = self.encoder2(self.pool1(enc1)) enc3 = self.encoder3(self.pool2(enc2)) enc4 = self.encoder4(self.pool3(enc3)) bottleneck = 
self.bottleneck(self.pool4(enc4)) main_seg = self.main_decoder(bottleneck, enc4, enc3, enc2, enc1) # CCT: each perturbed copy of the encoder features is decoded by its own auxiliary decoder aux_seg1 = self.aux_decoder1(FeatureNoise()(bottleneck), FeatureNoise()(enc4), FeatureNoise()(enc3), FeatureNoise()(enc2), FeatureNoise()(enc1)) aux_seg2 = self.aux_decoder2(Dropout(bottleneck), Dropout(enc4), Dropout(enc3), Dropout(enc2), Dropout(enc1)) aux_seg3 = self.aux_decoder3(FeatureDropout(bottleneck), FeatureDropout(enc4), FeatureDropout(enc3), FeatureDropout(enc2), FeatureDropout(enc1)) return main_seg, aux_seg1, aux_seg2, aux_seg3 @staticmethod def _block(in_channels, features, name): return nn.Sequential( OrderedDict( [ ( name + "conv1", nn.Conv3d( in_channels=in_channels, out_channels=features, kernel_size=3, padding=1, bias=True, ), ), (name + "norm1", nn.BatchNorm3d(num_features=features)), (name + "relu1", nn.ReLU(inplace=True)), ( name + "conv2", nn.Conv3d( in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=True, ), ), (name + "norm2", nn.BatchNorm3d(num_features=features)), (name + "relu2", nn.ReLU(inplace=True)), ] ) ) class UNet3D_CCT_min(nn.Module): def __init__(self, in_channels=1, out_channels=3, init_features=32): """ Implementations based on the Unet3D paper: https://arxiv.org/abs/1606.06650 """ super(UNet3D_CCT_min, self).__init__() features = init_features self.encoder1 = UNet3D_CCT._block(in_channels, features, name="enc1") self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2) self.encoder2 = UNet3D_CCT._block(features, features * 2, name="enc2") self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2) self.encoder3 = UNet3D_CCT._block(features * 2, features * 4, name="enc3") self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2) self.encoder4 = UNet3D_CCT._block(features * 4, features * 8, name="enc4") self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2) self.bottleneck = UNet3D_CCT._block(features * 8, features * 16, name="bottleneck") self.main_decoder = Decoder(features, out_channels) self.aux_decoder1 = Decoder(features,
out_channels) self.aux_decoder2 = Decoder(features, out_channels) self.aux_decoder3 = Decoder(features, out_channels) def forward(self, x): enc1 = self.encoder1(x) enc2 = self.encoder2(self.pool1(enc1)) enc3 = self.encoder3(self.pool2(enc2)) enc4 = self.encoder4(self.pool3(enc3)) bottleneck = self.bottleneck(self.pool4(enc4)) main_seg = self.main_decoder(bottleneck, enc4, enc3, enc2, enc1) aux_seg1 = self.aux_decoder1(FeatureNoise()(bottleneck), FeatureNoise()(enc4), FeatureNoise()(enc3), FeatureNoise()(enc2), FeatureNoise()(enc1)) aux_seg2 = self.aux_decoder2(Dropout(bottleneck), Dropout(enc4), Dropout(enc3), Dropout(enc2), Dropout(enc1)) aux_seg3 = self.aux_decoder3(FeatureDropout(bottleneck), FeatureDropout(enc4), FeatureDropout(enc3), FeatureDropout(enc2), FeatureDropout(enc1)) return main_seg, aux_seg1, aux_seg2, aux_seg3 @staticmethod def _block(in_channels, features, name): return nn.Sequential( OrderedDict( [ ( name + "conv1", nn.Conv3d( in_channels=in_channels, out_channels=features, kernel_size=3, padding=1, bias=True, ), ), (name + "norm1", nn.BatchNorm3d(num_features=features)), (name + "relu1", nn.ReLU(inplace=True)), ( name + "conv2", nn.Conv3d( in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=True, ), ), (name + "norm2", nn.BatchNorm3d(num_features=features)), (name + "relu2", nn.ReLU(inplace=True)), ] ) ) def unet3d_cct(in_channels, num_classes): model = UNet3D_CCT(in_channels, num_classes) init_weights(model, 'kaiming') return model def unet3d_cct_min(in_channels, num_classes): model = UNet3D_CCT_min(in_channels, num_classes) init_weights(model, 'kaiming') return model # if __name__ == '__main__': # model = unet3d_cct(1,10) # model.eval() # input = torch.rand(2, 1, 128, 128, 128) # output, aux_output1, aux_output2, aux_output3 = model(input) # output = output.data.cpu().numpy() # # print(output) # print(output.shape) ================================================ FILE: models/networks_3d/unet3d_dtc.py
================================================ import numpy as np from collections import OrderedDict import torch import torch.nn as nn from torch.nn import init # from loss.loss_function import segmentation_loss def init_weights(net, init_type='normal', gain=0.02): def init_func(m): classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': init.normal_(m.weight.data, 0.0, gain) elif init_type == 'xavier': init.xavier_normal_(m.weight.data, gain=gain) elif init_type == 'kaiming': init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif init_type == 'orthogonal': init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) elif classname.find('BatchNorm2d') != -1: init.normal_(m.weight.data, 1.0, gain) init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) class UNet3D_DTC(nn.Module): def __init__(self, in_channels=1, out_channels=3, init_features=64): """ Implementations based on the Unet3D paper: https://arxiv.org/abs/1606.06650 """ super(UNet3D_DTC, self).__init__() features = init_features self.encoder1 = UNet3D_DTC._block(in_channels, features, name="enc1") self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2) self.encoder2 = UNet3D_DTC._block(features, features * 2, name="enc2") self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2) self.encoder3 = UNet3D_DTC._block(features * 2, features * 4, name="enc3") self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2) self.encoder4 = UNet3D_DTC._block(features * 4, features * 8, name="enc4") self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2) self.bottleneck = UNet3D_DTC._block(features * 8, features * 16, name="bottleneck") self.upconv4 = nn.ConvTranspose3d(features * 16, features * 8, kernel_size=2, stride=2) self.decoder4 = 
UNet3D_DTC._block((features * 8) * 2 , features * 8, name="dec4") self.upconv3 = nn.ConvTranspose3d(features * 8, features * 4, kernel_size=2, stride=2) self.decoder3 = UNet3D_DTC._block((features * 4) * 2, features * 4, name="dec3") self.upconv2 = nn.ConvTranspose3d(features * 4, features * 2, kernel_size=2, stride=2) self.decoder2 = UNet3D_DTC._block((features * 2) * 2, features * 2, name="dec2") self.upconv1 = nn.ConvTranspose3d(features * 2, features, kernel_size=2, stride=2) self.decoder1 = UNet3D_DTC._block(features * 2, features, name="dec1") self.out_sdf = nn.Sequential( nn.Conv3d(in_channels=features, out_channels=out_channels, kernel_size=1), nn.Tanh() ) self.out_seg = nn.Conv3d(in_channels=features, out_channels=out_channels, kernel_size=1) def forward(self, x): enc1 = self.encoder1(x) enc2 = self.encoder2(self.pool1(enc1)) enc3 = self.encoder3(self.pool2(enc2)) enc4 = self.encoder4(self.pool3(enc3)) bottleneck = self.bottleneck(self.pool4(enc4)) dec4 = self.upconv4(bottleneck) dec4 = torch.cat((dec4, enc4), dim=1) dec4 = self.decoder4(dec4) dec3 = self.upconv3(dec4) dec3 = torch.cat((dec3, enc3), dim=1) dec3 = self.decoder3(dec3) dec2 = self.upconv2(dec3) dec2 = torch.cat((dec2, enc2), dim=1) dec2 = self.decoder2(dec2) dec1 = self.upconv1(dec2) dec1 = torch.cat((dec1, enc1), dim=1) dec1 = self.decoder1(dec1) out_sdf = self.out_sdf(dec1) out_seg = self.out_seg(dec1) return out_sdf, out_seg @staticmethod def _block(in_channels, features, name): return nn.Sequential( OrderedDict( [ ( name + "conv1", nn.Conv3d( in_channels=in_channels, out_channels=features, kernel_size=3, padding=1, bias=True, ), ), (name + "norm1", nn.BatchNorm3d(num_features=features)), (name + "relu1", nn.ReLU(inplace=True)), ( name + "conv2", nn.Conv3d( in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=True, ), ), (name + "norm2", nn.BatchNorm3d(num_features=features)), (name + "relu2", nn.ReLU(inplace=True)), ] ) ) def unet3d_dtc(in_channels, num_classes): 
model = UNet3D_DTC(in_channels, num_classes) init_weights(model, 'kaiming') return model # if __name__ == '__main__': # # criterion = segmentation_loss('dice', False) # mask = torch.ones(2, 96, 96, 96).long() # model = unet3d_dtc(1, 10) # model.train() # input1 = torch.rand(2,1,96,96,96) # out_sdf, out_seg = model(input1) # loss_train = criterion(out_sdf, mask) # loss_train.backward() # # print(output) # print(out_sdf.data.cpu().numpy().shape) # print(out_seg.data.cpu().numpy().shape) # print(loss_train) ================================================ FILE: models/networks_3d/unet3d_urpc.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import init def init_weights(net, init_type='normal', gain=0.02): def init_func(m): classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': init.normal_(m.weight.data, 0.0, gain) elif init_type == 'xavier': init.xavier_normal_(m.weight.data, gain=gain) elif init_type == 'kaiming': init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif init_type == 'orthogonal': init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) elif classname.find('BatchNorm2d') != -1: init.normal_(m.weight.data, 1.0, gain) init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) class UnetConv3(nn.Module): def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3,3,1), padding_size=(1,1,0), init_stride=(1,1,1)): super(UnetConv3, self).__init__() if is_batchnorm: self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), nn.InstanceNorm3d(out_size), nn.ReLU(inplace=True),) self.conv2 = nn.Sequential(nn.Conv3d(out_size, 
out_size, kernel_size, 1, padding_size), nn.InstanceNorm3d(out_size), nn.ReLU(inplace=True),) else: self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), nn.ReLU(inplace=True),) self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), nn.ReLU(inplace=True),) # initialise the blocks # for m in self.children(): # init_weights(m, init_type='kaiming') def forward(self, inputs): outputs = self.conv1(inputs) outputs = self.conv2(outputs) return outputs class UnetUp3(nn.Module): def __init__(self, in_size, out_size, is_deconv, is_batchnorm=True): super(UnetUp3, self).__init__() if is_deconv: self.conv = UnetConv3(in_size, out_size, is_batchnorm) self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=(4,4,1), stride=(2,2,1), padding=(1,1,0)) else: self.conv = UnetConv3(in_size+out_size, out_size, is_batchnorm) self.up = nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear', align_corners=True) # initialise the blocks # for m in self.children(): # if m.__class__.__name__.find('UnetConv3') != -1: continue # init_weights(m, init_type='kaiming') def forward(self, inputs1, inputs2): outputs2 = self.up(inputs2) offset = outputs2.size()[2] - inputs1.size()[2] padding = 2 * [offset // 2, offset // 2, 0] outputs1 = F.pad(inputs1, padding) return self.conv(torch.cat([outputs1, outputs2], 1)) class UnetUp3_CT(nn.Module): def __init__(self, in_size, out_size, is_batchnorm=True): super(UnetUp3_CT, self).__init__() self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear', align_corners=True) # initialise the blocks # for m in self.children(): # if m.__class__.__name__.find('UnetConv3') != -1: continue # init_weights(m, init_type='kaiming') def forward(self, inputs1, inputs2): outputs2 = self.up(inputs2) offset = outputs2.size()[2] - inputs1.size()[2] padding = 2 * [offset // 2, offset // 2, 0] 
outputs1 = F.pad(inputs1, padding) return self.conv(torch.cat([outputs1, outputs2], 1)) class UnetDsv3(nn.Module): def __init__(self, in_size, out_size, scale_factor): super(UnetDsv3, self).__init__() self.dsv = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size=1, stride=1, padding=0), nn.Upsample(scale_factor=scale_factor, mode='trilinear', align_corners=True), ) def forward(self, input): return self.dsv(input) class unet_3D_dv_semi(nn.Module): def __init__(self, in_channels=3, n_classes=21, feature_scale=4, is_deconv=True, is_batchnorm=True): super(unet_3D_dv_semi, self).__init__() self.is_deconv = is_deconv self.in_channels = in_channels self.is_batchnorm = is_batchnorm self.feature_scale = feature_scale filters = [64, 128, 256, 512, 1024] filters = [int(x / self.feature_scale) for x in filters] # downsampling self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=( 3, 3, 3), padding_size=(1, 1, 1)) self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=( 3, 3, 3), padding_size=(1, 1, 1)) self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=( 3, 3, 3), padding_size=(1, 1, 1)) self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=( 3, 3, 3), padding_size=(1, 1, 1)) self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=( 3, 3, 3), padding_size=(1, 1, 1)) # upsampling self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) # deep supervision self.dsv4 = UnetDsv3( in_size=filters[3], out_size=n_classes, scale_factor=8) 
self.dsv3 = UnetDsv3( in_size=filters[2], out_size=n_classes, scale_factor=4) self.dsv2 = UnetDsv3( in_size=filters[1], out_size=n_classes, scale_factor=2) self.dsv1 = nn.Conv3d( in_channels=filters[0], out_channels=n_classes, kernel_size=1) self.dropout1 = nn.Dropout3d(p=0.5) self.dropout2 = nn.Dropout3d(p=0.3) self.dropout3 = nn.Dropout3d(p=0.2) self.dropout4 = nn.Dropout3d(p=0.1) # initialise weights # for m in self.modules(): # if isinstance(m, nn.Conv3d): # init_weights(m, init_type='kaiming') # elif isinstance(m, nn.BatchNorm3d): # init_weights(m, init_type='kaiming') def forward(self, inputs): conv1 = self.conv1(inputs) maxpool1 = self.maxpool1(conv1) conv2 = self.conv2(maxpool1) maxpool2 = self.maxpool2(conv2) conv3 = self.conv3(maxpool2) maxpool3 = self.maxpool3(conv3) conv4 = self.conv4(maxpool3) maxpool4 = self.maxpool4(conv4) center = self.center(maxpool4) up4 = self.up_concat4(conv4, center) up4 = self.dropout1(up4) up3 = self.up_concat3(conv3, up4) up3 = self.dropout2(up3) up2 = self.up_concat2(conv2, up3) up2 = self.dropout3(up2) up1 = self.up_concat1(conv1, up2) up1 = self.dropout4(up1) # Deep Supervision dsv4 = self.dsv4(up4) dsv3 = self.dsv3(up3) dsv2 = self.dsv2(up2) dsv1 = self.dsv1(up1) return dsv1, dsv2, dsv3, dsv4 @staticmethod def apply_argmax_softmax(pred): log_p = F.softmax(pred, dim=1) return log_p def unet3d_urpc(in_channels, num_classes): model = unet_3D_dv_semi(in_channels, num_classes) init_weights(model, 'kaiming') return model # if __name__ == '__main__': # model = unet3d_urpc(1,10) # model.eval() # input = torch.rand(2, 1, 128, 128, 128) # output, output2, output3, output4 = model(input) # output = output.data.cpu().numpy() # # print(output) # print(output.shape) ================================================ FILE: models/networks_3d/unetr.py ================================================ import copy import torch import torch.nn as nn import torch.nn.functional as F import math from torch.nn import init def init_weights(net, 
init_type='normal', gain=0.02): def init_func(m): classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': init.normal_(m.weight.data, 0.0, gain) elif init_type == 'xavier': init.xavier_normal_(m.weight.data, gain=gain) elif init_type == 'kaiming': init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif init_type == 'orthogonal': init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) elif classname.find('BatchNorm2d') != -1: init.normal_(m.weight.data, 1.0, gain) init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) class SingleDeconv3DBlock(nn.Module): def __init__(self, in_planes, out_planes): super().__init__() self.block = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=2, stride=2, padding=0, output_padding=0) def forward(self, x): return self.block(x) class SingleConv3DBlock(nn.Module): def __init__(self, in_planes, out_planes, kernel_size): super().__init__() self.block = nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)) def forward(self, x): return self.block(x) class Conv3DBlock(nn.Module): def __init__(self, in_planes, out_planes, kernel_size=3): super().__init__() self.block = nn.Sequential( SingleConv3DBlock(in_planes, out_planes, kernel_size), nn.BatchNorm3d(out_planes), nn.ReLU(True) ) def forward(self, x): return self.block(x) class Deconv3DBlock(nn.Module): def __init__(self, in_planes, out_planes, kernel_size=3): super().__init__() self.block = nn.Sequential( SingleDeconv3DBlock(in_planes, out_planes), SingleConv3DBlock(out_planes, out_planes, kernel_size), nn.BatchNorm3d(out_planes), nn.ReLU(True) ) def forward(self, x): return self.block(x) class SelfAttention(nn.Module): def 
__init__(self, num_heads, embed_dim, dropout): super().__init__() self.num_attention_heads = num_heads self.attention_head_size = int(embed_dim / num_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(embed_dim, self.all_head_size) self.key = nn.Linear(embed_dim, self.all_head_size) self.value = nn.Linear(embed_dim, self.all_head_size) self.out = nn.Linear(embed_dim, embed_dim) self.attn_dropout = nn.Dropout(dropout) self.proj_dropout = nn.Dropout(dropout) self.softmax = nn.Softmax(dim=-1) self.vis = False def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) def forward(self, hidden_states): mixed_query_layer = self.query(hidden_states) mixed_key_layer = self.key(hidden_states) mixed_value_layer = self.value(hidden_states) query_layer = self.transpose_for_scores(mixed_query_layer) key_layer = self.transpose_for_scores(mixed_key_layer) value_layer = self.transpose_for_scores(mixed_value_layer) attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) attention_probs = self.softmax(attention_scores) # weights = attention_probs if self.vis else None attention_probs = self.attn_dropout(attention_probs) context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) attention_output = self.out(context_layer) attention_output = self.proj_dropout(attention_output) # return attention_output, weights return attention_output # class Mlp(nn.Module): # def __init__(self, in_features, act_layer=nn.GELU, drop=0.): # super().__init__() # self.fc1 = nn.Linear(in_features, in_features) # self.act = act_layer() # self.drop = 
nn.Dropout(drop) # # def forward(self, x): # x = self.fc1(x) # x = self.act(x) # x = self.drop(x) # return x class PositionwiseFeedForward(nn.Module): def __init__(self, d_model=768, d_ff=2048, dropout=0.1): super().__init__() # Torch linears have a bias term by default. self.w_1 = nn.Linear(d_model, d_ff) self.w_2 = nn.Linear(d_ff, d_model) self.dropout = nn.Dropout(dropout) def forward(self, x): return self.w_2(self.dropout(F.relu(self.w_1(x)))) class Embeddings(nn.Module): def __init__(self, input_dim, embed_dim, cube_size, patch_size, dropout): super().__init__() self.n_patches = int((cube_size[0] * cube_size[1] * cube_size[2]) / (patch_size * patch_size * patch_size)) self.patch_size = patch_size self.embed_dim = embed_dim self.patch_embeddings = nn.Conv3d(in_channels=input_dim, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size) self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, embed_dim)) self.dropout = nn.Dropout(dropout) def forward(self, x): x = self.patch_embeddings(x) x = x.flatten(2) x = x.transpose(-1, -2) embeddings = x + self.position_embeddings embeddings = self.dropout(embeddings) return embeddings class TransformerBlock(nn.Module): def __init__(self, embed_dim, num_heads, dropout, cube_size, patch_size): super().__init__() self.attention_norm = nn.LayerNorm(embed_dim, eps=1e-6) self.mlp_norm = nn.LayerNorm(embed_dim, eps=1e-6) self.mlp_dim = int((cube_size[0] * cube_size[1] * cube_size[2]) / (patch_size * patch_size * patch_size)) self.mlp = PositionwiseFeedForward(embed_dim, 2048) self.attn = SelfAttention(num_heads, embed_dim, dropout) def forward(self, x): h = x x = self.attention_norm(x) # x, weights = self.attn(x) x = self.attn(x) x = x + h h = x x = self.mlp_norm(x) x = self.mlp(x) x = x + h # return x, weights return x class Transformer(nn.Module): def __init__(self, input_dim, embed_dim, cube_size, patch_size, num_heads, num_layers, dropout, extract_layers): super().__init__() self.embeddings = 
Embeddings(input_dim, embed_dim, cube_size, patch_size, dropout) self.layer = nn.ModuleList() self.encoder_norm = nn.LayerNorm(embed_dim, eps=1e-6) self.extract_layers = extract_layers for _ in range(num_layers): layer = TransformerBlock(embed_dim, num_heads, dropout, cube_size, patch_size) self.layer.append(copy.deepcopy(layer)) def forward(self, x): extract_layers = [] hidden_states = self.embeddings(x) for depth, layer_block in enumerate(self.layer): # hidden_states, _ = layer_block(hidden_states) hidden_states = layer_block(hidden_states) if depth + 1 in self.extract_layers: extract_layers.append(hidden_states) return extract_layers class UNETR(nn.Module): def __init__(self, input_dim=4, output_dim=3, img_shape=(128, 128, 128), embed_dim=768, patch_size=16, num_heads=12, dropout=0.1): super().__init__() self.input_dim = input_dim self.output_dim = output_dim self.embed_dim = embed_dim self.img_shape = img_shape self.patch_size = patch_size self.num_heads = num_heads self.dropout = dropout self.num_layers = 12 self.ext_layers = [3, 6, 9, 12] self.patch_dim = [int(x / patch_size) for x in img_shape] # Transformer Encoder self.transformer = \ Transformer( input_dim, embed_dim, img_shape, patch_size, num_heads, self.num_layers, dropout, self.ext_layers ) # U-Net Decoder self.decoder0 = \ nn.Sequential( Conv3DBlock(input_dim, 32, 3), Conv3DBlock(32, 64, 3) ) self.decoder3 = \ nn.Sequential( Deconv3DBlock(embed_dim, 512), Deconv3DBlock(512, 256), Deconv3DBlock(256, 128) ) self.decoder6 = \ nn.Sequential( Deconv3DBlock(embed_dim, 512), Deconv3DBlock(512, 256), ) self.decoder9 = \ Deconv3DBlock(embed_dim, 512) self.decoder12_upsampler = \ SingleDeconv3DBlock(embed_dim, 512) self.decoder9_upsampler = \ nn.Sequential( Conv3DBlock(1024, 512), Conv3DBlock(512, 512), Conv3DBlock(512, 512), SingleDeconv3DBlock(512, 256) ) self.decoder6_upsampler = \ nn.Sequential( Conv3DBlock(512, 256), Conv3DBlock(256, 256), SingleDeconv3DBlock(256, 128) ) self.decoder3_upsampler = \ 
nn.Sequential( Conv3DBlock(256, 128), Conv3DBlock(128, 128), SingleDeconv3DBlock(128, 64) ) self.decoder0_header = \ nn.Sequential( Conv3DBlock(128, 64), Conv3DBlock(64, 64), SingleConv3DBlock(64, output_dim, 1) ) def forward(self, x): z = self.transformer(x) z0, z3, z6, z9, z12 = x, *z z3 = z3.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim) z6 = z6.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim) z9 = z9.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim) z12 = z12.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim) z12 = self.decoder12_upsampler(z12) z9 = self.decoder9(z9) z9 = self.decoder9_upsampler(torch.cat([z9, z12], dim=1)) z6 = self.decoder6(z6) z6 = self.decoder6_upsampler(torch.cat([z6, z9], dim=1)) z3 = self.decoder3(z3) z3 = self.decoder3_upsampler(torch.cat([z3, z6], dim=1)) z0 = self.decoder0(z0) output = self.decoder0_header(torch.cat([z0, z3], dim=1)) return output def unertr(in_channels, num_classes, **kwargs): model = UNETR(in_channels, num_classes, img_shape=kwargs['img_shape']) init_weights(model, 'kaiming') return model # if __name__ == '__main__': # model = unertr(1,10, img_shape=(96, 96, 96)) # model.eval() # input = torch.rand(2, 1, 96, 96, 96) # output = model(input) # output = output.data.cpu().numpy() # # print(output) # print(output.shape) ================================================ FILE: models/networks_3d/vnet.py ================================================ import torch import torch.nn as nn import os import numpy as np from collections import OrderedDict from torch.nn import init def init_weights(net, init_type='normal', gain=0.02): def init_func(m): classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': init.normal_(m.weight.data, 0.0, gain) elif init_type == 'xavier': init.xavier_normal_(m.weight.data, gain=gain) elif init_type == 'kaiming': init.kaiming_normal_(m.weight.data, a=0, 
mode='fan_in') elif init_type == 'orthogonal': init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) elif classname.find('BatchNorm2d') != -1: init.normal_(m.weight.data, 1.0, gain) init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) def passthrough(x, **kwargs): return x def ELUCons(elu, nchan): if elu: return nn.ELU(inplace=True) else: return nn.PReLU(nchan) class LUConv(nn.Module): def __init__(self, nchan, elu): super(LUConv, self).__init__() self.relu1 = ELUCons(elu, nchan) self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2) self.bn1 = torch.nn.BatchNorm3d(nchan) def forward(self, x): out = self.relu1(self.bn1(self.conv1(x))) return out def _make_nConv(nchan, depth, elu): layers = [] for _ in range(depth): layers.append(LUConv(nchan, elu)) return nn.Sequential(*layers) class InputTransition(nn.Module): def __init__(self, in_channels, elu): super(InputTransition, self).__init__() self.num_features = 16 self.in_channels = in_channels self.conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=5, padding=2) self.bn1 = torch.nn.BatchNorm3d(self.num_features) self.relu1 = ELUCons(elu, self.num_features) def forward(self, x): out = self.conv1(x) repeat_rate = int(self.num_features / self.in_channels) out = self.bn1(out) x16 = x.repeat(1, repeat_rate, 1, 1, 1) return self.relu1(torch.add(out, x16)) class DownTransition(nn.Module): def __init__(self, inChans, nConvs, elu, dropout=False): super(DownTransition, self).__init__() outChans = 2 * inChans self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2) self.bn1 = torch.nn.BatchNorm3d(outChans) self.do1 = passthrough self.relu1 = ELUCons(elu, outChans) self.relu2 = ELUCons(elu, outChans) if dropout: self.do1 = nn.Dropout3d() self.ops = _make_nConv(outChans, nConvs, elu) def 
forward(self, x): down = self.relu1(self.bn1(self.down_conv(x))) out = self.do1(down) out = self.ops(out) out = self.relu2(torch.add(out, down)) return out class UpTransition(nn.Module): def __init__(self, inChans, outChans, nConvs, elu, dropout=False): super(UpTransition, self).__init__() self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2) self.bn1 = torch.nn.BatchNorm3d(outChans // 2) self.do1 = passthrough self.do2 = nn.Dropout3d() self.relu1 = ELUCons(elu, outChans // 2) self.relu2 = ELUCons(elu, outChans) if dropout: self.do1 = nn.Dropout3d() self.ops = _make_nConv(outChans, nConvs, elu) def forward(self, x, skipx): out = self.do1(x) skipxdo = self.do2(skipx) out = self.relu1(self.bn1(self.up_conv(out))) xcat = torch.cat((out, skipxdo), 1) out = self.ops(xcat) out = self.relu2(torch.add(out, xcat)) return out class OutputTransition(nn.Module): def __init__(self, in_channels, classes, elu): super(OutputTransition, self).__init__() self.classes = classes self.conv1 = nn.Conv3d(in_channels, classes, kernel_size=5, padding=2) self.bn1 = torch.nn.BatchNorm3d(classes) self.conv2 = nn.Conv3d(classes, classes, kernel_size=1) self.relu1 = ELUCons(elu, classes) def forward(self, x): # convolve 32 down to channels as the desired classes out = self.relu1(self.bn1(self.conv1(x))) out = self.conv2(out) return out class VNet(nn.Module): """ Implementations based on the Vnet paper: https://arxiv.org/abs/1606.04797 """ def __init__(self, in_channels=1, classes=1, elu=True): super(VNet, self).__init__() self.classes = classes self.in_channels = in_channels self.in_tr = InputTransition(in_channels, elu=elu) self.down_tr32 = DownTransition(16, 1, elu) self.down_tr64 = DownTransition(32, 2, elu) self.down_tr128 = DownTransition(64, 3, elu, dropout=False) self.down_tr256 = DownTransition(128, 2, elu, dropout=False) self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False) self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False) self.up_tr64 = 
UpTransition(128, 64, 1, elu) self.up_tr32 = UpTransition(64, 32, 1, elu) self.out_tr = OutputTransition(32, classes, elu) def forward(self, x): out16 = self.in_tr(x) out32 = self.down_tr32(out16) out64 = self.down_tr64(out32) out128 = self.down_tr128(out64) out256 = self.down_tr256(out128) out = self.up_tr256(out256, out128) out = self.up_tr128(out, out64) out = self.up_tr64(out, out32) out = self.up_tr32(out, out16) out = self.out_tr(out) return out def vnet(in_channels, num_classes): model = VNet(in_channels, num_classes) init_weights(model, 'kaiming') return model # if __name__ == '__main__': # model = vnet(1,10) # model.eval() # input = torch.rand(2, 1, 128, 128, 128) # output = model(input) # output = output.data.cpu().numpy() # # print(output) # print(output.shape) ================================================ FILE: models/networks_3d/vnet_cct.py ================================================ import torch import torch.nn as nn import os import numpy as np from collections import OrderedDict from torch.nn import init from torch.distributions.uniform import Uniform # from loss.loss_function import segmentation_loss def init_weights(net, init_type='normal', gain=0.02): def init_func(m): classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': init.normal_(m.weight.data, 0.0, gain) elif init_type == 'xavier': init.xavier_normal_(m.weight.data, gain=gain) elif init_type == 'kaiming': init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif init_type == 'orthogonal': init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) elif classname.find('BatchNorm2d') != -1: init.normal_(m.weight.data, 1.0, gain) init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) 
class FeatureNoise(nn.Module): def __init__(self, uniform_range=0.3): super(FeatureNoise, self).__init__() self.uni_dist = Uniform(-uniform_range, uniform_range) def feature_based_noise(self, x): noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0) x_noise = x.mul(noise_vector) + x return x_noise def forward(self, x): x = self.feature_based_noise(x) return x def Dropout(x, p=0.3): x = torch.nn.functional.dropout(x, p) return x def FeatureDropout(x): attention = torch.mean(x, dim=1, keepdim=True) max_val, _ = torch.max(attention.view(x.size(0), -1), dim=1, keepdim=True) threshold = max_val * np.random.uniform(0.7, 0.9) threshold = threshold.view(x.size(0), 1, 1, 1, 1).expand_as(attention) drop_mask = (attention < threshold).float() x = x.mul(drop_mask) return x def passthrough(x, **kwargs): return x def ELUCons(elu, nchan): if elu: return nn.ELU(inplace=True) else: return nn.PReLU(nchan) class LUConv(nn.Module): def __init__(self, nchan, elu): super(LUConv, self).__init__() self.relu1 = ELUCons(elu, nchan) self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2) self.bn1 = torch.nn.BatchNorm3d(nchan) def forward(self, x): out = self.relu1(self.bn1(self.conv1(x))) return out def _make_nConv(nchan, depth, elu): layers = [] for _ in range(depth): layers.append(LUConv(nchan, elu)) return nn.Sequential(*layers) class InputTransition(nn.Module): def __init__(self, in_channels, elu): super(InputTransition, self).__init__() self.num_features = 16 self.in_channels = in_channels self.conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=5, padding=2) self.bn1 = torch.nn.BatchNorm3d(self.num_features) self.relu1 = ELUCons(elu, self.num_features) def forward(self, x): out = self.conv1(x) repeat_rate = int(self.num_features / self.in_channels) out = self.bn1(out) x16 = x.repeat(1, repeat_rate, 1, 1, 1) return self.relu1(torch.add(out, x16)) class DownTransition(nn.Module): def __init__(self, inChans, nConvs, elu, dropout=False): 
super(DownTransition, self).__init__() outChans = 2 * inChans self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2) self.bn1 = torch.nn.BatchNorm3d(outChans) self.do1 = passthrough self.relu1 = ELUCons(elu, outChans) self.relu2 = ELUCons(elu, outChans) if dropout: self.do1 = nn.Dropout3d() self.ops = _make_nConv(outChans, nConvs, elu) def forward(self, x): down = self.relu1(self.bn1(self.down_conv(x))) out = self.do1(down) out = self.ops(out) out = self.relu2(torch.add(out, down)) return out class UpTransition(nn.Module): def __init__(self, inChans, outChans, nConvs, elu, dropout=False): super(UpTransition, self).__init__() self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2) self.bn1 = torch.nn.BatchNorm3d(outChans // 2) self.do1 = passthrough self.do2 = nn.Dropout3d() self.relu1 = ELUCons(elu, outChans // 2) self.relu2 = ELUCons(elu, outChans) if dropout: self.do1 = nn.Dropout3d() self.ops = _make_nConv(outChans, nConvs, elu) def forward(self, x, skipx): out = self.do1(x) skipxdo = self.do2(skipx) out = self.relu1(self.bn1(self.up_conv(out))) xcat = torch.cat((out, skipxdo), 1) out = self.ops(xcat) out = self.relu2(torch.add(out, xcat)) return out class OutputTransition(nn.Module): def __init__(self, in_channels, classes, elu): super(OutputTransition, self).__init__() self.classes = classes self.conv1 = nn.Conv3d(in_channels, classes, kernel_size=5, padding=2) self.bn1 = torch.nn.BatchNorm3d(classes) self.conv2 = nn.Conv3d(classes, classes, kernel_size=1) self.relu1 = ELUCons(elu, classes) def forward(self, x): # convolve 32 down to channels as the desired classes out = self.relu1(self.bn1(self.conv1(x))) out = self.conv2(out) return out class Decoder(nn.Module): def __init__(self, out_channels, elu): super(Decoder, self).__init__() self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False) self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False) self.up_tr64 = UpTransition(128, 64, 1, elu) self.up_tr32 = 
UpTransition(64, 32, 1, elu) self.out_tr = OutputTransition(32, out_channels, elu) def forward(self, out256, out128, out64, out32, out16): out = self.up_tr256(out256, out128) out = self.up_tr128(out, out64) out = self.up_tr64(out, out32) out = self.up_tr32(out, out16) out = self.out_tr(out) return out class VNet_CCT(nn.Module): """ Implementations based on the Vnet paper: https://arxiv.org/abs/1606.04797 """ def __init__(self, in_channels=1, classes=1, elu=True): super(VNet_CCT, self).__init__() self.classes = classes self.in_channels = in_channels self.in_tr = InputTransition(in_channels, elu=elu) self.down_tr32 = DownTransition(16, 1, elu) self.down_tr64 = DownTransition(32, 2, elu) self.down_tr128 = DownTransition(64, 3, elu, dropout=False) self.down_tr256 = DownTransition(128, 2, elu, dropout=False) self.main_decoder = Decoder(classes, elu) self.aux_decoder1 = Decoder(classes, elu) self.aux_decoder2 = Decoder(classes, elu) self.aux_decoder3 = Decoder(classes, elu) def forward(self, x): out16 = self.in_tr(x) out32 = self.down_tr32(out16) out64 = self.down_tr64(out32) out128 = self.down_tr128(out64) out256 = self.down_tr256(out128) main_seg = self.main_decoder(out256, out128, out64, out32, out16) aux_seg1 = self.aux_decoder1(FeatureNoise()(out256), FeatureNoise()(out128), FeatureNoise()(out64), FeatureNoise()(out32), FeatureNoise()(out16)) aux_seg2 = self.aux_decoder2(Dropout(out256), Dropout(out128), Dropout(out64), Dropout(out32), Dropout(out16)) aux_seg3 = self.aux_decoder3(FeatureDropout(out256), FeatureDropout(out128), FeatureDropout(out64), FeatureDropout(out32), FeatureDropout(out16)) return main_seg, aux_seg1, aux_seg2, aux_seg3 def vnet_cct(in_channels, num_classes): model = VNet_CCT(in_channels, num_classes) init_weights(model, 'kaiming') return model # if __name__ == '__main__': # # criterion = segmentation_loss('dice', False) # mask = torch.ones(2, 64, 96, 64).long() # model = vnet_cct(1, 10) # model.train() # input = torch.rand(2, 1, 64, 96, 64) #
output, output1, output2, output3 = model(input) # loss_train = criterion(output, mask) # loss_train.backward() # output = output.data.cpu().numpy() # print(output.shape) # print(loss_train) ================================================ FILE: models/networks_3d/vnet_dtc.py ================================================ import torch import torch.nn as nn import os import numpy as np from collections import OrderedDict from torch.nn import init # from loss.loss_function import segmentation_loss def init_weights(net, init_type='normal', gain=0.02): def init_func(m): classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': init.normal_(m.weight.data, 0.0, gain) elif init_type == 'xavier': init.xavier_normal_(m.weight.data, gain=gain) elif init_type == 'kaiming': init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif init_type == 'orthogonal': init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) elif classname.find('BatchNorm2d') != -1: init.normal_(m.weight.data, 1.0, gain) init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) def passthrough(x, **kwargs): return x def ELUCons(elu, nchan): if elu: return nn.ELU(inplace=True) else: return nn.PReLU(nchan) class LUConv(nn.Module): def __init__(self, nchan, elu): super(LUConv, self).__init__() self.relu1 = ELUCons(elu, nchan) self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2) self.bn1 = torch.nn.BatchNorm3d(nchan) def forward(self, x): out = self.relu1(self.bn1(self.conv1(x))) return out def _make_nConv(nchan, depth, elu): layers = [] for _ in range(depth): layers.append(LUConv(nchan, elu)) return nn.Sequential(*layers) class InputTransition(nn.Module): def __init__(self, in_channels, elu): 
super(InputTransition, self).__init__() self.num_features = 16 self.in_channels = in_channels self.conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=5, padding=2) self.bn1 = torch.nn.BatchNorm3d(self.num_features) self.relu1 = ELUCons(elu, self.num_features) def forward(self, x): out = self.conv1(x) repeat_rate = int(self.num_features / self.in_channels) out = self.bn1(out) x16 = x.repeat(1, repeat_rate, 1, 1, 1) return self.relu1(torch.add(out, x16)) class DownTransition(nn.Module): def __init__(self, inChans, nConvs, elu, dropout=False): super(DownTransition, self).__init__() outChans = 2 * inChans self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2) self.bn1 = torch.nn.BatchNorm3d(outChans) self.do1 = passthrough self.relu1 = ELUCons(elu, outChans) self.relu2 = ELUCons(elu, outChans) if dropout: self.do1 = nn.Dropout3d() self.ops = _make_nConv(outChans, nConvs, elu) def forward(self, x): down = self.relu1(self.bn1(self.down_conv(x))) out = self.do1(down) out = self.ops(out) out = self.relu2(torch.add(out, down)) return out class UpTransition(nn.Module): def __init__(self, inChans, outChans, nConvs, elu, dropout=False): super(UpTransition, self).__init__() self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2) self.bn1 = torch.nn.BatchNorm3d(outChans // 2) self.do1 = passthrough self.do2 = nn.Dropout3d() self.relu1 = ELUCons(elu, outChans // 2) self.relu2 = ELUCons(elu, outChans) if dropout: self.do1 = nn.Dropout3d() self.ops = _make_nConv(outChans, nConvs, elu) def forward(self, x, skipx): out = self.do1(x) skipxdo = self.do2(skipx) out = self.relu1(self.bn1(self.up_conv(out))) xcat = torch.cat((out, skipxdo), 1) out = self.ops(xcat) out = self.relu2(torch.add(out, xcat)) return out class OutputTransition(nn.Module): def __init__(self, in_channels, classes, elu): super(OutputTransition, self).__init__() self.classes = classes self.conv1 = nn.Conv3d(in_channels, classes, kernel_size=5, padding=2) 
self.bn1 = torch.nn.BatchNorm3d(classes) self.conv2 = nn.Conv3d(classes, classes, kernel_size=1) self.relu1 = ELUCons(elu, classes) def forward(self, x): # convolve 32 down to channels as the desired classes out = self.relu1(self.bn1(self.conv1(x))) out = self.conv2(out) return out class VNet_DTC(nn.Module): """ Implementations based on the Vnet paper: https://arxiv.org/abs/1606.04797 """ def __init__(self, in_channels=1, classes=1, elu=True): super(VNet_DTC, self).__init__() self.classes = classes self.in_channels = in_channels self.in_tr = InputTransition(in_channels, elu=elu) self.down_tr32 = DownTransition(16, 1, elu) self.down_tr64 = DownTransition(32, 2, elu) self.down_tr128 = DownTransition(64, 3, elu, dropout=False) self.down_tr256 = DownTransition(128, 2, elu, dropout=False) self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False) self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False) self.up_tr64 = UpTransition(128, 64, 1, elu) self.up_tr32 = UpTransition(64, 32, 1, elu) self.out_tr = OutputTransition(32, 16, elu) self.out_sdf = nn.Sequential( nn.Conv3d(16, classes, 1, padding=0), nn.Tanh() ) self.out_seg = nn.Conv3d(16, classes, 1, padding=0) def forward(self, x): out16 = self.in_tr(x) out32 = self.down_tr32(out16) out64 = self.down_tr64(out32) out128 = self.down_tr128(out64) out256 = self.down_tr256(out128) out = self.up_tr256(out256, out128) out = self.up_tr128(out, out64) out = self.up_tr64(out, out32) out = self.up_tr32(out, out16) out = self.out_tr(out) out_sdf = self.out_sdf(out) out_seg = self.out_seg(out) return out_sdf, out_seg def vnet_dtc(in_channels, num_classes): model = VNet_DTC(in_channels, num_classes) init_weights(model, 'kaiming') return model # if __name__ == '__main__': # # criterion = segmentation_loss('dice', False) # mask = torch.ones(2, 96, 96, 96).long() # model = vnet_dtc(1, 10) # model.train() # input1 = torch.rand(2,1,96,96,96) # out_sdf, out_seg = model(input1) # loss_train = criterion(out_sdf, mask) # 
loss_train.backward() # # print(output) # print(out_sdf.data.cpu().numpy().shape) # print(out_seg.data.cpu().numpy().shape) # print(loss_train) ================================================ FILE: models/networks_3d/xnet3d.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import init import functools from torch.distributions.uniform import Uniform import numpy as np # from loss.loss_function import segmentation_loss # BN BatchNorm3d = nn.InstanceNorm3d BN_MOMENTUM = 0.1 # BN_MOMENTUM = 0.01 # AF relu_inplace = True ActivationFunction = nn.ReLU def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution""" return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding""" return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation) class up_conv(nn.Module): def __init__(self, ch_in, ch_out): super(up_conv, self).__init__() self.up = nn.Sequential( nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True), nn.Conv3d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), BatchNorm3d(ch_out, momentum=BN_MOMENTUM), ActivationFunction(inplace=relu_inplace) ) def forward(self, x): x = self.up(x) return x class down_conv(nn.Module): def __init__(self, ch_in, ch_out): super(down_conv, self).__init__() self.down = nn.Sequential( nn.Conv3d(ch_in, ch_out, kernel_size=3, stride=2, padding=1, bias=False), BatchNorm3d(ch_out, momentum=BN_MOMENTUM), ActivationFunction(inplace=relu_inplace) ) def forward(self, x): x = self.down(x) return x class same_conv(nn.Module): def __init__(self, ch_in, ch_out): super(same_conv, self).__init__() self.same = nn.Sequential( nn.Conv3d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False), BatchNorm3d(ch_out, momentum=BN_MOMENTUM), 
ActivationFunction(inplace=relu_inplace) ) def forward(self, x): x = self.same(x) return x class transition_conv(nn.Module): def __init__(self, ch_in, ch_out): super(transition_conv, self).__init__() self.transition = nn.Sequential( nn.Conv3d(ch_in, ch_out, kernel_size=1, stride=1, padding=0, bias=False), BatchNorm3d(ch_out, momentum=BN_MOMENTUM), ActivationFunction(inplace=relu_inplace) ) def forward(self, x): x = self.transition(x) return x class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None): super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = BatchNorm3d if groups != 1 or base_width != 64: raise ValueError('BasicBlock only supports groups=1 and base_width=64') if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes, momentum=BN_MOMENTUM) self.relu = ActivationFunction(inplace=relu_inplace) self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes, momentum=BN_MOMENTUM) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) # out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out = self.bn2(out) + identity out = self.relu(out) return out class DoubleBasicBlock(nn.Module): def __init__(self, inplanes, planes, downsample=None): super(DoubleBasicBlock, self).__init__() self.DBB = nn.Sequential( BasicBlock(inplanes=inplanes, planes=planes, downsample=downsample), BasicBlock(inplanes=planes, planes=planes) ) def forward(self, x): out = self.DBB(x) return out class XNet3D(nn.Module): def __init__(self, in_channels, num_classes): super(XNet3D, self).__init__() # l1c, l2c, l3c, l4c = 64, 128, 256, 512 # l1c, 
l2c, l3c, l4c, l5c = 8, 16, 32, 64, 128 # l1c, l2c, l3c, l4c, l5c = 16, 32, 64, 128, 256 l1c, l2c, l3c, l4c, l5c = 32, 64, 128, 256, 512 # branch1 # branch1_layer1 self.b1_1_1 = nn.Sequential( conv3x3(in_channels, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c) ) self.b1_1_2_down = down_conv(l1c, l2c) self.b1_1_3 = BasicBlock(l1c+l1c, l1c, downsample=nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm3d(l1c, momentum=BN_MOMENTUM))) self.b1_1_4 = nn.Conv3d(l1c, num_classes, kernel_size=1, stride=1, padding=0) # branch1_layer2 self.b1_2_1 = BasicBlock(l2c, l2c) self.b1_2_2_down = down_conv(l2c, l3c) self.b1_2_3 = BasicBlock(l2c+l2c, l2c, downsample=nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm3d(l2c, momentum=BN_MOMENTUM))) self.b1_2_4_up = up_conv(l2c, l1c) # branch1_layer3 self.b1_3_1 = BasicBlock(l3c, l3c) self.b1_3_2_down = down_conv(l3c, l4c) self.b1_3_3 = BasicBlock(l3c+l3c, l3c, downsample=nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm3d(l3c, momentum=BN_MOMENTUM))) self.b1_3_4_up = up_conv(l3c, l2c) # branch1_layer4 self.b1_4_1 = BasicBlock(l4c, l4c) self.b1_4_2_down = down_conv(l4c, l5c) self.b1_4_2 = BasicBlock(l4c, l4c) self.b1_4_3_down = down_conv(l4c, l4c) self.b1_4_3_same = same_conv(l4c, l4c) self.b1_4_4_transition = transition_conv(l4c+l5c+l4c, l4c) self.b1_4_5 = BasicBlock(l4c, l4c) self.b1_4_6 = BasicBlock(l4c+l4c, l4c, downsample=nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm3d(l4c, momentum=BN_MOMENTUM))) self.b1_4_7_up = up_conv(l4c, l3c) # branch1_layer5 self.b1_5_1 = BasicBlock(l5c, l5c) self.b1_5_2_up = up_conv(l5c, l5c) self.b1_5_2_same = same_conv(l5c, l5c) self.b1_5_3_transition = transition_conv(l5c+l5c+l4c, l5c) self.b1_5_4 = BasicBlock(l5c, l5c) self.b1_5_5_up = up_conv(l5c, l4c) # branch2 # branch2_layer1 self.b2_1_1 = nn.Sequential( conv3x3(1, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c) ) self.b2_1_2_down = down_conv(l1c, l2c) self.b2_1_3 = 
BasicBlock(l1c+l1c, l1c, downsample=nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm3d(l1c, momentum=BN_MOMENTUM))) self.b2_1_4 = nn.Conv3d(l1c, num_classes, kernel_size=1, stride=1, padding=0) # branch2_layer2 self.b2_2_1 = BasicBlock(l2c, l2c) self.b2_2_2_down = down_conv(l2c, l3c) self.b2_2_3 = BasicBlock(l2c+l2c, l2c, downsample=nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm3d(l2c, momentum=BN_MOMENTUM))) self.b2_2_4_up = up_conv(l2c, l1c) # branch2_layer3 self.b2_3_1 = BasicBlock(l3c, l3c) self.b2_3_2_down = down_conv(l3c, l4c) self.b2_3_3 = BasicBlock(l3c+l3c, l3c, downsample=nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm3d(l3c, momentum=BN_MOMENTUM))) self.b2_3_4_up = up_conv(l3c, l2c) # branch2_layer4 self.b2_4_1 = BasicBlock(l4c, l4c) self.b2_4_2_down = down_conv(l4c, l5c) self.b2_4_2 = BasicBlock(l4c, l4c) self.b2_4_3_down = down_conv(l4c, l4c) self.b2_4_3_same = same_conv(l4c, l4c) self.b2_4_4_transition = transition_conv(l4c+l5c+l4c, l4c) self.b2_4_5 = BasicBlock(l4c, l4c) self.b2_4_6 = BasicBlock(l4c+l4c, l4c, downsample=nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm3d(l4c, momentum=BN_MOMENTUM))) self.b2_4_7_up = up_conv(l4c, l3c) # branch2_layer5 self.b2_5_1 = BasicBlock(l5c, l5c) self.b2_5_2_up = up_conv(l5c, l5c) self.b2_5_2_same = same_conv(l5c, l5c) self.b2_5_3_transition = transition_conv(l5c+l5c+l4c, l5c) self.b2_5_4 = BasicBlock(l5c, l5c) self.b2_5_5_up = up_conv(l5c, l4c) # initialization for m in self.modules(): if isinstance(m, nn.Conv3d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) # elif isinstance(m, nn.BatchNorm3d): # nn.init.constant_(m.weight, 1) # nn.init.constant_(m.bias, 0) # elif isinstance(m, InPlaceABNSync): # nn.init.constant_(m.weight, 1) # nn.init.constant_(m.bias, 0) # elif isinstance(m, InPlaceABN): # nn.init.constant_(m.weight, 1) # nn.init.constant_(m.bias, 0) 
elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=0.001) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, input1, input2): # code # branch1 x1_1 = self.b1_1_1(input1) x1_2 = self.b1_1_2_down(x1_1) x1_2 = self.b1_2_1(x1_2) x1_3 = self.b1_2_2_down(x1_2) x1_3 = self.b1_3_1(x1_3) x1_4_1 = self.b1_3_2_down(x1_3) x1_4_1 = self.b1_4_1(x1_4_1) x1_4_2 = self.b1_4_2(x1_4_1) x1_4_3_down = self.b1_4_3_down(x1_4_2) x1_4_3_same = self.b1_4_3_same(x1_4_2) x1_5_1 = self.b1_4_2_down(x1_4_1) x1_5_1 = self.b1_5_1(x1_5_1) x1_5_2_up = self.b1_5_2_up(x1_5_1) x1_5_2_same = self.b1_5_2_same(x1_5_1) # branch2 x2_1 = self.b2_1_1(input2) x2_2 = self.b2_1_2_down(x2_1) x2_2 = self.b2_2_1(x2_2) x2_3 = self.b2_2_2_down(x2_2) x2_3 = self.b2_3_1(x2_3) x2_4_1 = self.b2_3_2_down(x2_3) x2_4_1 = self.b2_4_1(x2_4_1) x2_4_2 = self.b2_4_2(x2_4_1) x2_4_3_down = self.b2_4_3_down(x2_4_2) x2_4_3_same = self.b2_4_3_same(x2_4_2) x2_5_1 = self.b2_4_2_down(x2_4_1) x2_5_1 = self.b2_5_1(x2_5_1) x2_5_2_up = self.b2_5_2_up(x2_5_1) x2_5_2_same = self.b2_5_2_same(x2_5_1) # merge # branch1 x1_5_3 = torch.cat((x1_5_2_same, x2_5_2_same, x2_4_3_down), dim=1) x1_5_3 = self.b1_5_3_transition(x1_5_3) x1_5_3 = self.b1_5_4(x1_5_3) x1_5_3 = self.b1_5_5_up(x1_5_3) x1_4_4 = torch.cat((x1_4_3_same, x2_4_3_same, x2_5_2_up), dim=1) x1_4_4 = self.b1_4_4_transition(x1_4_4) x1_4_4 = self.b1_4_5(x1_4_4) x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1) x1_4_4 = self.b1_4_6(x1_4_4) x1_4_4 = self.b1_4_7_up(x1_4_4) # branch2 x2_5_3 = torch.cat((x2_5_2_same, x1_5_2_same, x1_4_3_down), dim=1) x2_5_3 = self.b2_5_3_transition(x2_5_3) x2_5_3 = self.b2_5_4(x2_5_3) x2_5_3 = self.b2_5_5_up(x2_5_3) x2_4_4 = torch.cat((x2_4_3_same, x1_4_3_same, x1_5_2_up), dim=1) x2_4_4 = self.b2_4_4_transition(x2_4_4) x2_4_4 = self.b2_4_5(x2_4_4) x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1) x2_4_4 = self.b2_4_6(x2_4_4) x2_4_4 = self.b2_4_7_up(x2_4_4) # decode # branch1 x1_3 = torch.cat((x1_3, x1_4_4), dim=1) x1_3 = self.b1_3_3(x1_3) x1_3 = 
self.b1_3_4_up(x1_3) x1_2 = torch.cat((x1_2, x1_3), dim=1) x1_2 = self.b1_2_3(x1_2) x1_2 = self.b1_2_4_up(x1_2) x1_1 = torch.cat((x1_1, x1_2), dim=1) x1_1 = self.b1_1_3(x1_1) x1_1 = self.b1_1_4(x1_1) # branch2 x2_3 = torch.cat((x2_3, x2_4_4), dim=1) x2_3 = self.b2_3_3(x2_3) x2_3 = self.b2_3_4_up(x2_3) x2_2 = torch.cat((x2_2, x2_3), dim=1) x2_2 = self.b2_2_3(x2_2) x2_2 = self.b2_2_4_up(x2_2) x2_1 = torch.cat((x2_1, x2_2), dim=1) x2_1 = self.b2_1_3(x2_1) x2_1 = self.b2_1_4(x2_1) return x1_1, x2_1 def xnet3d(in_channels, num_classes): model = XNet3D(in_channels, num_classes) return model # if __name__ == '__main__': # # criterion = segmentation_loss('dice', False) # mask = torch.ones(2, 96, 96, 96).long() # model = XNet3D(1, 10) # model.train() # input1 = torch.rand(2,1,96,96,96) # input2 = torch.rand(2,1,96,96,96) # x1_1_main, x2_1_main = model(input1, input2) # loss_train = criterion(x1_1_main, mask) # loss_train.backward() # # print(output) # print(x1_1_main.data.cpu().numpy().shape) # print(x2_1_main.data.cpu().numpy().shape) # print(loss_train) ================================================ FILE: requirements.txt ================================================ albumentations==0.5.2 einops==0.4.1 MedPy==0.4.0 numpy==1.20.2 opencv_python==4.2.0.34 opencv_python_headless==4.5.1.48 Pillow==8.0.0 PyWavelets==1.1.1 scikit_image==0.18.1 scikit_learn==1.0.1 scipy==1.4.1 SimpleITK==2.1.0 timm==0.6.7 torch==1.8.0+cu111 torchio==0.18.53 torchvision==0.9.0+cu111 tqdm==4.65.0 visdom==0.1.8.9 ================================================ FILE: test.py ================================================ from torchvision import transforms, datasets import torch from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time import os import numpy as np from torch.backends import cudnn import random import torch.distributed as dist from torch.nn.parallel import
DistributedDataParallel from config.dataset_config.dataset_cfg import dataset_cfg from config.augmentation.online_aug import data_transform_2d, data_normalize_2d from models.getnetwork import get_network from dataload.dataset_2d import imagefloder_itn from config.train_test_config.train_test_config import print_test_eval, save_test_2d from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/GlaS') parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/sup/GlaS/best_kiunet_Jc_0.7779.pth') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test') parser.add_argument('--dataset_name', default='GlaS', help='CREMI') parser.add_argument('--input1', default='image') parser.add_argument('--if_mask', default=True) parser.add_argument('--threshold', default=0.5400, type=float, help='0.5600, 0.5400') parser.add_argument('-ds', '--deep_supervision', default=False) parser.add_argument('-b', '--batch_size', default=4, type=int) parser.add_argument('-n', '--network', default='kiunet') parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.cuda.device_count() init_seeds(rank + 1) # Config dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7 print_num_minus = print_num - 2 # Results Save if not
os.path.exists(args.path_seg_results) and rank == args.rank_index: os.mkdir(args.path_seg_results) path_seg_results = args.path_seg_results + '/' + str(dataset_name) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) # print(path_seg_results) if args.input1 == 'image': input1_mean = 'MEAN' input1_std = 'STD' else: input1_mean = 'MEAN_' + args.input1 input1_std = 'STD_' + args.input1 # Dataset data_transforms = data_transform_2d() data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std]) dataset_val = imagefloder_itn( data_dir=args.path_dataset + '/val', input1=args.input1, data_transform_1=data_transforms['val'], data_normalize_1=data_normalize, sup=True, num_images=None ) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False) dataloaders = dict() dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler) num_batches = {'val': len(dataloaders['val'])} # Model model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model = model.cuda() # if rank == args.rank_index: # state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank)) # model.load_state_dict(state_dict=state_dict) # model = DistributedDataParallel(model, device_ids=[args.local_rank]) model = DistributedDataParallel(model, device_ids=[args.local_rank]) state_dict = torch.load(args.path_model) model.load_state_dict(state_dict=state_dict) dist.barrier() # Test since = time.time() with torch.no_grad(): model.eval() for i, data in enumerate(dataloaders['val']): inputs_test = data['image'] inputs_test = Variable(inputs_test.cuda(non_blocking=True)) name_test = data['ID'] if args.if_mask: mask_test = 
data['mask'] mask_test = Variable(mask_test.cuda(non_blocking=True)) outputs_test = model(inputs_test) if args.deep_supervision: outputs_test = outputs_test[0] if args.if_mask: if i == 0: score_list_test = outputs_test name_list_test = name_test mask_list_test = mask_test else: # elif 0 < i <= num_batches['val'] / 16: score_list_test = torch.cat((score_list_test, outputs_test), dim=0) name_list_test = np.append(name_list_test, name_test, axis=0) mask_list_test = torch.cat((mask_list_test, mask_test), dim=0) torch.cuda.empty_cache() else: save_test_2d(cfg['NUM_CLASSES'], outputs_test, name_test, args.threshold, path_seg_results, cfg['PALETTE']) torch.cuda.empty_cache() if args.if_mask: score_gather_list_test = [torch.zeros_like(score_list_test) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_test, score_list_test) score_list_test = torch.cat(score_gather_list_test, dim=0) mask_gather_list_test = [torch.zeros_like(mask_list_test) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_test, mask_list_test) mask_list_test = torch.cat(mask_gather_list_test, dim=0) name_gather_list_test = [None for _ in range(ngpus_per_node)] torch.distributed.all_gather_object(name_gather_list_test, name_list_test) name_list_test = np.concatenate(name_gather_list_test, axis=0) if args.if_mask and rank == args.rank_index: print('=' * print_num) test_eval_list = print_test_eval(cfg['NUM_CLASSES'], score_list_test, mask_list_test, print_num_minus) save_test_2d(cfg['NUM_CLASSES'], score_list_test, name_list_test, test_eval_list[0], path_seg_results, cfg['PALETTE']) torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('-' * print_num) print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('=' * print_num) ================================================ FILE: test_3d.py 
================================================ from torchvision import transforms, datasets import torch from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time import os import numpy as np from torch.backends import cudnn import random import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel import torchio as tio from config.dataset_config.dataset_cfg import dataset_cfg from config.augmentation.online_aug import data_transform_3d from models.getnetwork import get_network from dataload.dataset_3d import dataset_it from config.train_test_config.train_test_config import save_test_3d from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/LiTS') parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/semi/LiTS/best_result1_Jc_0.7677.pth') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test') parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial') parser.add_argument('--input1', default='image') parser.add_argument('--threshold', default=None) parser.add_argument('--patch_size', default=(112, 112, 32)) parser.add_argument('--patch_overlap', default=(56, 56, 16)) parser.add_argument('-b', '--batch_size', default=1, type=int) parser.add_argument('-n', '--network', default='unet3d_min') parser.add_argument('-ds', '--deep_supervision', default=False) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') args =
parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.cuda.device_count() init_seeds(rank + 1) # Config dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7 print_num_minus = print_num - 2 # Results Save if not os.path.exists(args.path_seg_results) and rank == args.rank_index: os.mkdir(args.path_seg_results) path_seg_results = args.path_seg_results + '/' + str(dataset_name) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) data_transform = data_transform_3d(cfg['NORMALIZE']) dataset_val = dataset_it( data_dir=args.path_dataset + '/val', input1=args.input1, transform_1=data_transform['test'], ) # Model model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size) model = model.cuda() # if rank == args.rank_index: # state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank)) # model.load_state_dict(state_dict=state_dict) # model = DistributedDataParallel(model, device_ids=[args.local_rank]) model = DistributedDataParallel(model, device_ids=[args.local_rank]) state_dict = torch.load(args.path_model) model.load_state_dict(state_dict=state_dict) dist.barrier() # Test since = time.time() for i, subject in enumerate(dataset_val.dataset_1): grid_sampler = tio.inference.GridSampler( subject=subject, patch_size=args.patch_size, patch_overlap=args.patch_overlap ) # val_sampler = torch.utils.data.distributed.DistributedSampler(grid_sampler, shuffle=False) dataloaders = dict() dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, 
num_workers=16) # dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler) aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode='average') with torch.no_grad(): model.eval() for data in dataloaders['test']: inputs_test = Variable(data['image'][tio.DATA].cuda()) location_test = data[tio.LOCATION] outputs_test = model(inputs_test) if args.deep_supervision: outputs_test = outputs_test[0] aggregator.add_batch(outputs_test, location_test) outputs_tensor = aggregator.get_output_tensor() save_test_3d(cfg['NUM_CLASSES'], outputs_tensor, subject['ID'], args.threshold, path_seg_results, subject['image']['affine']) if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('-' * print_num) print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('=' * print_num) ================================================ FILE: test_ConResNet.py ================================================ from torchvision import transforms, datasets import torch from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time import os import numpy as np from torch.backends import cudnn import random import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel import torchio as tio from config.dataset_config.dataset_cfg import dataset_cfg from config.augmentation.online_aug import data_transform_3d from models.getnetwork import get_network from dataload.dataset_3d import dataset_iit_conresnet from config.train_test_config.train_test_config import save_test_3d from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) 
os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/LiTS') parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/sup/LiTS/best_conresnet_Jc_0.8545.pth') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test') parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial') parser.add_argument('--input1', default='image') parser.add_argument('--input2', default='image_res') parser.add_argument('--threshold', default=None) parser.add_argument('--patch_size', default=(112, 112, 32)) parser.add_argument('--patch_overlap', default=(56, 56, 16)) parser.add_argument('-b', '--batch_size', default=1, type=int) parser.add_argument('-n', '--network', default='conresnet') parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.cuda.device_count() init_seeds(rank + 1) # Config dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7 print_num_minus = print_num - 2 # Results Save if not os.path.exists(args.path_seg_results) and rank == args.rank_index: os.mkdir(args.path_seg_results) path_seg_results = args.path_seg_results + '/' + str(dataset_name) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) data_transform = data_transform_3d(cfg['NORMALIZE']) dataset_val =
dataset_iit_conresnet( data_dir=args.path_dataset + '/val', input1=args.input1, input2=args.input2, transform_1=data_transform['test'], ) # Model model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size) model = model.cuda() # if rank == args.rank_index: # state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank)) # model.load_state_dict(state_dict=state_dict) # model = DistributedDataParallel(model, device_ids=[args.local_rank]) model = DistributedDataParallel(model, device_ids=[args.local_rank]) state_dict = torch.load(args.path_model) model.load_state_dict(state_dict=state_dict) dist.barrier() # Test since = time.time() for i, subject in enumerate(dataset_val.dataset_1): grid_sampler = tio.inference.GridSampler( subject=subject, patch_size=args.patch_size, patch_overlap=args.patch_overlap ) # val_sampler = torch.utils.data.distributed.DistributedSampler(grid_sampler, shuffle=False) dataloaders = dict() dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16) # dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler) aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode='average') with torch.no_grad(): model.eval() for data in dataloaders['test']: inputs_test = Variable(data['image'][tio.DATA].cuda()) inputs_test_2 = Variable(data['image2'][tio.DATA].cuda()) location_test = data[tio.LOCATION] outputs_test = model(inputs_test, inputs_test_2) aggregator.add_batch(outputs_test[0], location_test) outputs_tensor = aggregator.get_output_tensor() save_test_3d(cfg['NUM_CLASSES'], outputs_tensor, subject['ID'], args.threshold, path_seg_results, subject['image']['affine']) if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('-' * print_num) print('| Testing Completed In 
{:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('=' * print_num) ================================================ FILE: test_DTC.py ================================================ from torchvision import transforms, datasets import torch from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time import os import numpy as np from torch.backends import cudnn import random import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel import torchio as tio from config.dataset_config.dataset_cfg import dataset_cfg from config.augmentation.online_aug import data_transform_3d from models.getnetwork import get_network from dataload.dataset_3d import dataset_it_dtc from config.train_test_config.train_test_config import save_test_3d from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/LiTS') parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/semi/LiTS/best_DTC_Jc_0.7594.pth') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test') parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial') parser.add_argument('--input1', default='image') parser.add_argument('--threshold', default=None) parser.add_argument('--patch_size', default=(112, 112, 32)) parser.add_argument('--patch_overlap', default=(56, 56, 16)) parser.add_argument('-b', '--batch_size', default=1, type=int) parser.add_argument('-n', '--network', default='vnet_dtc') parser.add_argument('--local_rank', 
default=-1, type=int) parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3') args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.cuda.device_count() init_seeds(rank + 1) # Config dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7 print_num_minus = print_num - 2 # Results Save if not os.path.exists(args.path_seg_results) and rank == args.rank_index: os.mkdir(args.path_seg_results) path_seg_results = args.path_seg_results + '/' + str(dataset_name) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) data_transform = data_transform_3d(cfg['NORMALIZE']) dataset_val = dataset_it_dtc( data_dir=args.path_dataset + '/val', input1=args.input1, num_classes=cfg['NUM_CLASSES'], transform_1=data_transform['test'], ) # Model model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size) model = model.cuda() # if rank == args.rank_index: # state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank)) # model.load_state_dict(state_dict=state_dict) # model = DistributedDataParallel(model, device_ids=[args.local_rank]) model = DistributedDataParallel(model, device_ids=[args.local_rank]) state_dict = torch.load(args.path_model) model.load_state_dict(state_dict=state_dict) dist.barrier() # Test since = time.time() for i, subject in enumerate(dataset_val.dataset_1): grid_sampler = tio.inference.GridSampler( subject=subject, patch_size=args.patch_size, patch_overlap=args.patch_overlap ) # val_sampler = torch.utils.data.distributed.DistributedSampler(grid_sampler, shuffle=False) 
dataloaders = dict() dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16) # dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler) aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode='average') with torch.no_grad(): model.eval() for data in dataloaders['test']: inputs_test = Variable(data['image'][tio.DATA].cuda()) location_test = data[tio.LOCATION] outputs_test_sdf, outputs_test_seg = model(inputs_test) aggregator.add_batch(outputs_test_seg, location_test) outputs_tensor = aggregator.get_output_tensor() save_test_3d(cfg['NUM_CLASSES'], outputs_tensor, subject['ID'], args.threshold, path_seg_results, subject['image']['affine']) if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('-' * print_num) print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('=' * print_num) ================================================ FILE: test_xnet.py ================================================ from torchvision import transforms, datasets import torch from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time import os import numpy as np from torch.backends import cudnn import random import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel from config.dataset_config.dataset_cfg import dataset_cfg from config.augmentation.online_aug import data_transform_2d, data_normalize_2d from models.getnetwork import get_network from dataload.dataset_2d import imagefloder_iitnn from config.train_test_config.train_test_config import print_test_eval, save_test_2d from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) 
torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/GlaS') parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/semi_xnet/GlaS/best_result2_Jc_0.7898.pth') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test') parser.add_argument('--dataset_name', default='GlaS', help='CREMI, ISIC-2017, GlaS') parser.add_argument('--input1', default='L') parser.add_argument('--input2', default='H') parser.add_argument('--if_mask', default=True) parser.add_argument('--threshold', default=0.5400, type=float, help='0.5600, 0.5400') parser.add_argument('--if_cct', default=False) parser.add_argument('--result', default='result2', help='result1, result2') parser.add_argument('-n', '--network', default='xnet') parser.add_argument('-b', '--batch_size', default=8, type=int) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3') args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.cuda.device_count() init_seeds(rank + 1) # Config dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7 print_num_minus = print_num - 2 # Results Save if not os.path.exists(args.path_seg_results) and rank == args.rank_index: os.mkdir(args.path_seg_results) path_seg_results = args.path_seg_results + '/' + str(dataset_name) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' +
str(os.path.splitext(os.path.split(args.path_model)[1])[0]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) # Dataset if args.input1 == 'image': input1_mean = 'MEAN' input1_std = 'STD' else: input1_mean = 'MEAN_' + args.input1 input1_std = 'STD_' + args.input1 if args.input2 == 'image': input2_mean = 'MEAN' input2_std = 'STD' else: input2_mean = 'MEAN_' + args.input2 input2_std = 'STD_' + args.input2 data_transforms = data_transform_2d() data_normalize_1 = data_normalize_2d(cfg[input1_mean], cfg[input1_std]) data_normalize_2 = data_normalize_2d(cfg[input2_mean], cfg[input2_std]) dataset_val = imagefloder_iitnn( data_dir=args.path_dataset + '/val', input1=args.input1, input2=args.input2, data_transform_1=data_transforms['val'], data_normalize_1=data_normalize_1, data_normalize_2=data_normalize_2, sup=True, num_images=None, ) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False) dataloaders = dict() dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler) num_batches = {'val': len(dataloaders['val'])} # Model model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model = model.cuda() # if rank == args.rank_index: # state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank)) # model.load_state_dict(state_dict=state_dict) # model = DistributedDataParallel(model, device_ids=[args.local_rank]) model = DistributedDataParallel(model, device_ids=[args.local_rank]) state_dict = torch.load(args.path_model) model.load_state_dict(state_dict=state_dict) dist.barrier() # Test since = time.time() with torch.no_grad(): model.eval() for i, data in enumerate(dataloaders['val']): inputs_test = Variable(data['image'].cuda(non_blocking=True)) inputs_wavelet_test = Variable(data['image_2'].cuda(non_blocking=True)) name_test = data['ID'] if args.if_mask: mask_test = 
Variable(data['mask'].cuda(non_blocking=True)) if args.if_cct: outputs_test1, outputs_test1_aux1, outputs_test1_aux2, outputs_test1_aux3, outputs_test2, outputs_test2_aux1, outputs_test2_aux2, outputs_test2_aux3 = model(inputs_test, inputs_wavelet_test) else: outputs_test1, outputs_test2 = model(inputs_test, inputs_wavelet_test) if args.result == 'result1': outputs_test = outputs_test1 else: outputs_test = outputs_test2 if args.if_mask: if i == 0: score_list_test = outputs_test name_list_test = name_test mask_list_test = mask_test else: # elif 0 < i <= num_batches['val'] / 16: score_list_test = torch.cat((score_list_test, outputs_test), dim=0) name_list_test = np.append(name_list_test, name_test, axis=0) mask_list_test = torch.cat((mask_list_test, mask_test), dim=0) torch.cuda.empty_cache() else: save_test_2d(cfg['NUM_CLASSES'], outputs_test, name_test, args.threshold, path_seg_results, cfg['PALETTE']) torch.cuda.empty_cache() if args.if_mask: score_gather_list_test = [torch.zeros_like(score_list_test) for _ in range(dist.get_world_size())] torch.distributed.all_gather(score_gather_list_test, score_list_test) score_list_test = torch.cat(score_gather_list_test, dim=0) mask_gather_list_test = [torch.zeros_like(mask_list_test) for _ in range(dist.get_world_size())] torch.distributed.all_gather(mask_gather_list_test, mask_list_test) mask_list_test = torch.cat(mask_gather_list_test, dim=0) name_gather_list_test = [None for _ in range(dist.get_world_size())] torch.distributed.all_gather_object(name_gather_list_test, name_list_test) name_list_test = np.concatenate(name_gather_list_test, axis=0) if args.if_mask and rank == args.rank_index: print('=' * print_num) test_eval_list = print_test_eval(cfg['NUM_CLASSES'], score_list_test, mask_list_test, print_num_minus) save_test_2d(cfg['NUM_CLASSES'], score_list_test, name_list_test, test_eval_list[0], path_seg_results, cfg['PALETTE']) torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed,
60) h, m = divmod(m, 60) print('-' * print_num) print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('=' * print_num) ================================================ FILE: test_xnet3d.py ================================================ from torchvision import transforms, datasets import torch from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time import os import numpy as np from torch.backends import cudnn import random import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel import torchio as tio from config.dataset_config.dataset_cfg import dataset_cfg from config.augmentation.online_aug import data_transform_3d from models.getnetwork import get_network from dataload.dataset_3d import dataset_iit from config.train_test_config.train_test_config import save_test_3d from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/LiTS') parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/semi_xnet/LiTS/best_result1_Jc_0.7794.pth') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test') parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial') parser.add_argument('--input1', default='L') parser.add_argument('--input2', default='H') parser.add_argument('--threshold', default=None) parser.add_argument('--result', default='result1', help='result1, result2') parser.add_argument('--patch_size', default=(112, 112, 32)) 
parser.add_argument('--patch_overlap', default=(56, 56, 16)) parser.add_argument('-b', '--batch_size', default=1, type=int) parser.add_argument('-n', '--network', default='xnet3d') parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3') args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.cuda.device_count() init_seeds(rank + 1) # Config dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7 print_num_minus = print_num - 2 # Results Save if not os.path.exists(args.path_seg_results) and rank == args.rank_index: os.mkdir(args.path_seg_results) path_seg_results = args.path_seg_results + '/' + str(dataset_name) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) data_transform = data_transform_3d(cfg['NORMALIZE']) dataset_val = dataset_iit( data_dir=args.path_dataset + '/val', input1=args.input1, input2=args.input2, transform_1=data_transform['test'], ) # Model model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model = model.cuda() # if rank == args.rank_index: # state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank)) # model.load_state_dict(state_dict=state_dict) # model = DistributedDataParallel(model, device_ids=[args.local_rank]) model = DistributedDataParallel(model, device_ids=[args.local_rank]) state_dict = torch.load(args.path_model) model.load_state_dict(state_dict=state_dict) dist.barrier() # Test since = time.time() for i, subject in enumerate(dataset_val.dataset_1): grid_sampler = tio.inference.GridSampler( 
subject=subject, patch_size=args.patch_size, patch_overlap=args.patch_overlap ) # val_sampler = torch.utils.data.distributed.DistributedSampler(grid_sampler, shuffle=False) dataloaders = dict() dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16) # dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler) aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode='average') with torch.no_grad(): model.eval() for data in dataloaders['test']: inputs_test_1 = Variable(data['image'][tio.DATA].cuda()) inputs_test_2 = Variable(data['image2'][tio.DATA].cuda()) location_test = data[tio.LOCATION] outputs_test_1, outputs_test_2 = model(inputs_test_1, inputs_test_2) if args.result == 'result1': outputs_test = outputs_test_1 else: outputs_test = outputs_test_2 aggregator.add_batch(outputs_test, location_test) outputs_tensor = aggregator.get_output_tensor() save_test_3d(cfg['NUM_CLASSES'], outputs_tensor, subject['ID'], args.threshold, path_seg_results, subject['image']['affine']) if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('-' * print_num) print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('=' * print_num) ================================================ FILE: tools/Atrial/__init__.py ================================================ ================================================ FILE: tools/Atrial/postprocess.py ================================================ import numpy as np import argparse import os import SimpleITK as sitk from skimage.morphology import remove_small_objects, remove_small_holes import skimage def save_max_objects(image): labeled_image = skimage.measure.label(image) labeled_list = skimage.measure.regionprops(labeled_image) box = [] for i in 
range(len(labeled_list)): box.append(labeled_list[i].area) if not box: return labeled_image label_num = box.index(max(box)) + 1 labeled_image[labeled_image != label_num] = 0 labeled_image[labeled_image == label_num] = 1 return labeled_image if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--pred_path', default='//10.0.5.233/shared_data/XNet/seg_pred/test/Atrial/best_DTC_Jc_0.8730') parser.add_argument('--save_path', default='//10.0.5.233/shared_data/XNet/seg_pred/test/Atrial/best_DTC_Jc_0.8730_mor') parser.add_argument('--fill_hole_thr', default=500, type=int, help='hole area threshold in voxels, 300-500') args = parser.parse_args() if not os.path.exists(args.save_path): os.mkdir(args.save_path) for i in os.listdir(args.pred_path): pred_path = os.path.join(args.pred_path, i) save_path = os.path.join(args.save_path, i) pred_image = sitk.ReadImage(pred_path) pred = sitk.GetArrayFromImage(pred_image) pred = pred.astype(bool) pred = remove_small_holes(pred, args.fill_hole_thr) pred = pred.astype(np.uint8) pred = save_max_objects(pred) pred = pred.astype(np.uint8) pred = sitk.GetImageFromArray(pred) pred.CopyInformation(pred_image) sitk.WriteImage(pred, save_path) ================================================ FILE: tools/Atrial/preprocess.py ================================================ import numpy as np import torchio as tio import os import argparse from tqdm import tqdm import SimpleITK as sitk if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--data_path', default='E:/Biomedical datasets/2018 Atrial Segmentation Challenge/Training Set') parser.add_argument('--save_path', default='E:/Biomedical datasets/2018 Atrial Segmentation Challenge/dataset') args = parser.parse_args() if not os.path.exists(args.save_path): os.mkdir(args.save_path) save_image_path = args.save_path + '/image' save_mask_path = args.save_path + '/mask' if not os.path.exists(save_image_path): os.mkdir(save_image_path) if not os.path.exists(save_mask_path): os.mkdir(save_mask_path) for i in os.listdir(args.data_path): save_name = i + '.nrrd'
image_path = args.data_path + '/' + i + '/' + 'lgemri.nrrd' mask_path = args.data_path + '/' + i + '/' + 'laendo.nrrd' image = tio.ScalarImage(image_path) mask = tio.LabelMap(mask_path) _, w, h, d = mask.data.shape tempL = np.nonzero(np.array(mask.data)) minx, maxx = np.min(tempL[1]), np.max(tempL[1]) miny, maxy = np.min(tempL[2]), np.max(tempL[2]) # minz, maxz = np.min(tempL[3]), np.max(tempL[3]) px = max(112 - (maxx - minx), 0) // 2 py = max(112 - (maxy - miny), 0) // 2 # pz = max(80 - (maxz - minz), 0) // 2 minx = max(minx - np.random.randint(10, 20) - px, 0) maxx = min(maxx + np.random.randint(10, 20) + px, w) miny = max(miny - np.random.randint(10, 20) - py, 0) maxy = min(maxy + np.random.randint(10, 20) + py, h) # minz = max(minz - np.random.randint(5, 10) - pz, 0) # maxz = min(maxz + np.random.randint(5, 10) + pz, d) image_np = image.data[:, minx:maxx, miny:maxy, :] image.set_data(image_np) mask_np = mask.data[:, minx:maxx, miny:maxy, :] mask.set_data(mask_np) print(image_np.shape) image.save(os.path.join(save_image_path, save_name)) mask.save(os.path.join(save_mask_path, save_name)) ================================================ FILE: tools/LiTS/__init__.py ================================================ ================================================ FILE: tools/LiTS/postprocess.py ================================================ import numpy as np import argparse import os import SimpleITK as sitk from skimage.morphology import remove_small_objects, remove_small_holes import skimage def save_max_objects(image): labeled_image = skimage.measure.label(image) labeled_list = skimage.measure.regionprops(labeled_image) box = [] for i in range(len(labeled_list)): box.append(labeled_list[i].area) if not box: return labeled_image label_num = box.index(max(box)) + 1 labeled_image[labeled_image != label_num] = 0 labeled_image[labeled_image == label_num] = 1 return labeled_image if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--pred_path',
default='//10.0.5.233/shared_data/XNet/seg_pred/test/LiTS/best_DTC_Jc_0.7594') parser.add_argument('--save_path', default='//10.0.5.233/shared_data/XNet/seg_pred/test/LiTS/best_DTC_Jc_0.7594_mor') parser.add_argument('--fill_hole_thr', default=100, type=int) args = parser.parse_args() if not os.path.exists(args.save_path): os.mkdir(args.save_path) for i in os.listdir(args.pred_path): pred_path = os.path.join(args.pred_path, i) save_path = os.path.join(args.save_path, i) pred_image = sitk.ReadImage(pred_path) pred = sitk.GetArrayFromImage(pred_image) pred_ = pred.copy() pred_[pred != 0] = 1 pred_ = pred_.astype(bool) pred_ = remove_small_holes(pred_, args.fill_hole_thr) pred_ = pred_.astype(np.uint8) pred_ = save_max_objects(pred_) pred_[(pred_ == 1) & (pred == 2)] = 2 pred_ = pred_.astype(np.uint8) pred_ = sitk.GetImageFromArray(pred_) pred_.CopyInformation(pred_image) sitk.WriteImage(pred_, save_path) ================================================ FILE: tools/LiTS/preprocess.py ================================================ import numpy as np import os import argparse from tqdm import tqdm import SimpleITK as sitk if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--data_path', default='E:/Biomedical datasets/LiTS') parser.add_argument('--save_path', default='E:/Biomedical datasets/LiTS/dataset') parser.add_argument('--min_hu', default=-100, type=int) parser.add_argument('--max_hu', default=250, type=int) parser.add_argument('--target_spacing', default=[1.00, 1.00, 1.00], type=float, nargs=3) parser.add_argument('--crop_pixel', default=25, type=int) args = parser.parse_args() if not os.path.exists(args.save_path): os.mkdir(args.save_path) save_image_path = args.save_path + '/image' save_mask_path = args.save_path + '/mask' if not os.path.exists(save_image_path): os.mkdir(save_image_path) if not os.path.exists(save_mask_path): os.mkdir(save_mask_path) image_path = args.data_path + '/image' mask_path = args.data_path + '/mask' for i in os.listdir(image_path): image_dir = os.path.join(image_path, i) mask_dir = os.path.join(mask_path, i)
image = sitk.ReadImage(image_dir) mask = sitk.ReadImage(mask_dir) size = np.array(image.GetSize()) spacing = np.array(image.GetSpacing()) new_size = size * spacing / args.target_spacing new_size = [int(s) for s in new_size] print(new_size, size) resample_image = sitk.ResampleImageFilter() resample_image.SetOutputDirection(image.GetDirection()) resample_image.SetOutputOrigin(image.GetOrigin()) resample_image.SetSize(new_size) resample_image.SetOutputSpacing(args.target_spacing) resample_image.SetInterpolator(sitk.sitkLinear) image = resample_image.Execute(image) resample_mask = sitk.ResampleImageFilter() resample_mask.SetOutputDirection(mask.GetDirection()) resample_mask.SetOutputOrigin(mask.GetOrigin()) resample_mask.SetSize(new_size) resample_mask.SetOutputSpacing(args.target_spacing) resample_mask.SetInterpolator(sitk.sitkNearestNeighbor) mask = resample_mask.Execute(mask) image_np = sitk.GetArrayFromImage(image) mask_np = sitk.GetArrayFromImage(mask) w, h, d = mask_np.shape templ = np.nonzero(mask_np) w_min = max(np.min(templ[0]) - args.crop_pixel, 0) w_max = min(np.max(templ[0]) + args.crop_pixel, w) h_min = max(np.min(templ[1]) - args.crop_pixel, 0) h_max = min(np.max(templ[1]) + args.crop_pixel, h) d_min = max(np.min(templ[2]) - args.crop_pixel, 0) d_max = min(np.max(templ[2]) + args.crop_pixel, d) image_np = image_np[w_min:w_max, h_min:h_max, d_min:d_max] # image_np = image.data image_np[image_np < args.min_hu] = args.min_hu image_np[image_np > args.max_hu] = args.max_hu mask_np = mask_np[w_min:w_max, h_min:h_max, d_min:d_max] image_save = sitk.GetImageFromArray(image_np) image_save.SetSpacing(args.target_spacing) image_save.SetDirection(image.GetDirection()) image_save.SetOrigin(image.GetOrigin()) mask_save = sitk.GetImageFromArray(mask_np) mask_save.SetSpacing(args.target_spacing) mask_save.SetDirection(image.GetDirection()) mask_save.SetOrigin(image.GetOrigin()) sitk.WriteImage(image_save, os.path.join(save_image_path, i)) sitk.WriteImage(mask_save, 
os.path.join(save_mask_path, i)) # image_save.save(os.path.join(save_image_path, save_name)) # mask_save.save(os.path.join(save_mask_path, save_name)) ================================================ FILE: tools/LiTS/split_train_val.py ================================================ import numpy as np import os import argparse import shutil import random if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--image_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/train_sup_100/image') parser.add_argument('--mask_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/train_sup_100/mask') parser.add_argument('--save_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/val') parser.add_argument('--amount', default=31, type=int) parser.add_argument('--random_seed', default=10, type=int) args = parser.parse_args() random.seed(args.random_seed) if not os.path.exists(args.save_path): os.mkdir(args.save_path) save_image_path = args.save_path + '/image' save_mask_path = args.save_path + '/mask' if not os.path.exists(save_image_path): os.mkdir(save_image_path) if not os.path.exists(save_mask_path): os.mkdir(save_mask_path) image_path_list = os.listdir(args.image_path) image_path_list = random.sample(image_path_list, args.amount) for i in image_path_list: shutil.move(os.path.join(args.image_path, i), save_image_path) shutil.move(os.path.join(args.mask_path, i), save_mask_path) ================================================ FILE: tools/__init__.py ================================================ ================================================ FILE: tools/eval.py ================================================ from sklearn.metrics import confusion_matrix import numpy as np import argparse import os from PIL import Image from medpy.metric.binary import hd95, assd import albumentations as A import SimpleITK as sitk def eval_distance(mask_list, seg_result_list, num_classes): print_num = 42 + (num_classes - 3) * 7 print_num_minus = print_num - 2
assert len(mask_list) == len(seg_result_list) if num_classes == 2: hd_list = [] sd_list = [] for i in range(len(mask_list)): if np.any(seg_result_list[i]) and np.any(mask_list[i]): hd_ = hd95(seg_result_list[i], mask_list[i]) sd_ = assd(seg_result_list[i], mask_list[i]) hd_list.append(hd_) sd_list.append(sd_) hd = np.mean(hd_list) sd = np.mean(sd_list) print('| Hd: {:.4f}'.format(hd).ljust(print_num_minus, ' '), '|') print('| Sd: {:.4f}'.format(sd).ljust(print_num_minus, ' '), '|') else: hd_list = [] sd_list = [] for cls in range(num_classes-1): hd_list_ = [] sd_list_ = [] for i in range(len(mask_list)): mask_list_ = mask_list[i].copy() seg_result_list_ = seg_result_list[i].copy() mask_list_[mask_list[i] != (cls + 1)] = 0 seg_result_list_[seg_result_list[i] != (cls + 1)] = 0 if np.any(seg_result_list_) and np.any(mask_list_): hd_ = hd95(seg_result_list_, mask_list_) sd_ = assd(seg_result_list_, mask_list_) hd_list_.append(hd_) sd_list_.append(sd_) hd = np.mean(hd_list_) sd = np.mean(sd_list_) hd_list.append(hd) sd_list.append(sd) hd_list = np.array(hd_list) sd_list = np.array(sd_list) m_hd = np.mean(hd_list) m_sd = np.mean(sd_list) np.set_printoptions(precision=4, suppress=True) print('| Hd: {}'.format(hd_list).ljust(print_num_minus, ' '), '|') print('| Sd: {}'.format(sd_list).ljust(print_num_minus, ' '), '|') print('| mHd: {:.4f}'.format(m_hd).ljust(print_num_minus, ' '), '|') print('| mSd: {:.4f}'.format(m_sd).ljust(print_num_minus, ' '), '|') print('-' * print_num) def eval_pixel(mask_list, seg_result_list, num_classes): c = confusion_matrix(mask_list, seg_result_list, labels=list(range(num_classes))) hist_diag = np.diag(c) hist_sum_0 = c.sum(axis=0) hist_sum_1 = c.sum(axis=1) jaccard = hist_diag / (hist_sum_1 + hist_sum_0 - hist_diag) dice = 2 * hist_diag / (hist_sum_1 + hist_sum_0) print_num = 42 + (num_classes - 3) * 7 print_num_minus = print_num - 2 print('-' * print_num) if num_classes > 2: m_jaccard = np.nanmean(jaccard) m_dice = np.nanmean(dice) np.set_printoptions(precision=4,
suppress=True) print('| Jc: {}'.format(jaccard).ljust(print_num_minus, ' '), '|') print('| Dc: {}'.format(dice).ljust(print_num_minus, ' '), '|') print('| mJc: {:.4f}'.format(m_jaccard).ljust(print_num_minus, ' '), '|') print('| mDc: {:.4f}'.format(m_dice).ljust(print_num_minus, ' '), '|') else: print('| Jc: {:.4f}'.format(jaccard[1]).ljust(print_num_minus, ' '), '|') print('| Dc: {:.4f}'.format(dice[1]).ljust(print_num_minus, ' '), '|') if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--pred_path', default='/mnt/data1/XNet/seg_pred/test/LiTS/best_result1_Jc_0.7677_mor') parser.add_argument('--mask_path', default='/mnt/data1/XNet/dataset/LiTS/val/mask') parser.add_argument('--if_3D', default=True) parser.add_argument('--resize_shape', default=(128, 128)) parser.add_argument('--num_classes', default=3, type=int) args = parser.parse_args() pred_list = [] mask_list = [] pred_flatten_list = [] mask_flatten_list = [] num = 0 for i in os.listdir(args.pred_path): pred_path = os.path.join(args.pred_path, i) mask_path = os.path.join(args.mask_path, i) if args.if_3D: pred = sitk.ReadImage(pred_path) pred = sitk.GetArrayFromImage(pred) mask = sitk.ReadImage(mask_path) mask = sitk.GetArrayFromImage(mask) else: pred = Image.open(pred_path) # pred = pred.resize((args.resize_shape[1], args.resize_shape[0])) pred = np.array(pred) mask = Image.open(mask_path) # mask = mask.resize((args.resize_shape[1], args.resize_shape[0])) mask = np.array(mask) # nearest-neighbour interpolation (0 == cv2.INTER_NEAREST) keeps predicted labels discrete. resize = A.Resize(args.resize_shape[1], args.resize_shape[0], interpolation=0, p=1)(image=pred, mask=mask) mask = resize['mask'] pred = resize['image'] pred_list.append(pred) mask_list.append(mask) if num == 0: pred_flatten_list = pred.flatten() mask_flatten_list = mask.flatten() else: pred_flatten_list = np.append(pred_flatten_list, pred.flatten()) mask_flatten_list = np.append(mask_flatten_list, mask.flatten()) num += 1 eval_pixel(mask_flatten_list, pred_flatten_list, args.num_classes) eval_distance(mask_list, pred_list,
args.num_classes) ================================================ FILE: tools/mask2sdf.py ================================================ import numpy as np import os import argparse import SimpleITK as sitk from scipy.ndimage import distance_transform_edt from skimage import segmentation if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--data_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/val') parser.add_argument('--num_classes', default=3, type=int) args = parser.parse_args() mask_path = args.data_path + '/mask' for i in range(args.num_classes-1): save_sdf_mask_path = args.data_path + '/mask_sdf' + str(i+1) if not os.path.exists(save_sdf_mask_path): os.mkdir(save_sdf_mask_path) for j in os.listdir(mask_path): mask = sitk.ReadImage(os.path.join(mask_path, j)) mask_np = sitk.GetArrayFromImage(mask) mask_np[mask_np != (i+1)] = 0 mask_np = mask_np.astype(bool) if mask_np.any(): mask_neg = ~mask_np posdis = distance_transform_edt(mask_np) negdis = distance_transform_edt(mask_neg) boundary = segmentation.find_boundaries(mask_np, mode='inner').astype(np.uint8) sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis)) sdf[boundary == 1] = 0 # sdf = ((sdf - np.min(sdf)) / (np.max(sdf) - np.min(sdf))) * 255 else: sdf = np.zeros(mask_np.shape) sdf = sitk.GetImageFromArray(sdf) sdf.SetSpacing(mask.GetSpacing()) sdf.SetDirection(mask.GetDirection()) sdf.SetOrigin(mask.GetOrigin()) sitk.WriteImage(sdf, os.path.join(save_sdf_mask_path, j)) ================================================ FILE: tools/res_image_mask.py ================================================ import numpy as np import os import argparse import SimpleITK as sitk if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--data_path', default='//10.0.5.233/shared_data/XNet/dataset/Atrial/train_sup_100') args = parser.parse_args() image_path = args.data_path + '/image' 
mask_path =
args.data_path + '/mask' save_res_path = args.data_path + '/image_res' save_res_mask_path = args.data_path + '/mask_res' if not os.path.exists(save_res_path): os.mkdir(save_res_path) if not os.path.exists(save_res_mask_path): os.mkdir(save_res_mask_path) for i in os.listdir(image_path): image = sitk.ReadImage(os.path.join(image_path, i)) image_np = sitk.GetArrayFromImage(image) mask = sitk.ReadImage(os.path.join(mask_path, i)) mask_np = sitk.GetArrayFromImage(mask) image_copy = np.zeros(image_np.shape) image_copy[1:, :, :] = image_np[0:image_np.shape[0] - 1, :, :] image_res = image_np - image_copy image_res[0, :, :] = 0 image_res = np.abs(image_res) image_res = sitk.GetImageFromArray(image_res) image_res.SetSpacing(image.GetSpacing()) image_res.SetDirection(image.GetDirection()) image_res.SetOrigin(image.GetOrigin()) mask_copy = np.zeros(mask_np.shape) mask_copy[1:, :, :] = mask_np[0:mask_np.shape[0] - 1, :, :] mask_res = mask_np - mask_copy mask_res[0, :, :] = 0 mask_res = np.abs(mask_res) mask_res = sitk.GetImageFromArray(mask_res) mask_res.SetSpacing(image.GetSpacing()) mask_res.SetDirection(image.GetDirection()) mask_res.SetOrigin(image.GetOrigin()) sitk.WriteImage(image_res, os.path.join(save_res_path, i)) sitk.WriteImage(mask_res, os.path.join(save_res_mask_path, i)) ================================================ FILE: tools/wavelet2D.py ================================================ import numpy as np from PIL import Image import pywt import argparse import os if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--image_path', default='//10.0.5.233/shared_data/XNet/dataset/CREMI/train_unsup_80/image') parser.add_argument('--L_path', default='//10.0.5.233/shared_data/XNet/dataset/CREMI/train_unsup_80/L') parser.add_argument('--H_path', default='//10.0.5.233/shared_data/XNet/dataset/CREMI/train_unsup_80/H') parser.add_argument('--wavelet_type', default='db2', help='haar, db2, bior1.5, bior2.4, coif1, dmey') 
parser.add_argument('--if_RGB', default=False) args = parser.parse_args() if not os.path.exists(args.L_path): os.mkdir(args.L_path) if not os.path.exists(args.H_path): os.mkdir(args.H_path) for i in os.listdir(args.image_path): image_path = os.path.join(args.image_path, i) L_path = os.path.join(args.L_path, i) H_path = os.path.join(args.H_path, i) if args.if_RGB: image = Image.open(image_path).convert('L') else: image = Image.open(image_path) image = np.array(image) LL, (LH, HL, HH) = pywt.dwt2(image, args.wavelet_type) LL = (LL - LL.min()) / (LL.max() - LL.min()) * 255 LL = Image.fromarray(LL.astype(np.uint8)) LL.save(L_path) LH = (LH - LH.min()) / (LH.max() - LH.min()) * 255 HL = (HL - HL.min()) / (HL.max() - HL.min()) * 255 HH = (HH - HH.min()) / (HH.max() - HH.min()) * 255 merge1 = HH + HL + LH merge1 = (merge1-merge1.min()) / (merge1.max()-merge1.min()) * 255 merge1 = Image.fromarray(merge1.astype(np.uint8)) merge1.save(H_path) ================================================ FILE: tools/wavelet3D.py ================================================ import numpy as np from PIL import Image import pywt import argparse import os import SimpleITK as sitk import torchio as tio if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--image_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/val/image') parser.add_argument('--L_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/train_sup_100/L') parser.add_argument('--H_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/train_sup_100/H') parser.add_argument('--wavelet_type', default='db2', help='haar, db2, bior1.5, bior2.4, coif1, dmey') args = parser.parse_args() if not os.path.exists(args.L_path): os.mkdir(args.L_path) if not os.path.exists(args.H_path): os.mkdir(args.H_path) for i in os.listdir(args.image_path): image_path = os.path.join(args.image_path, i) L_path = os.path.join(args.L_path, i) H_path = os.path.join(args.H_path, i) image = 
sitk.ReadImage(image_path) image_np = sitk.GetArrayFromImage(image) image_wave = pywt.dwtn(image_np, args.wavelet_type) LLL = image_wave['aaa'] LLH = image_wave['aad'] LHL = image_wave['ada'] LHH = image_wave['add'] HLL = image_wave['daa'] HLH = image_wave['dad'] HHL = image_wave['dda'] HHH = image_wave['ddd'] LLL = (LLL - LLL.min()) / (LLL.max() - LLL.min()) * 255 LLL = sitk.GetImageFromArray(LLL) resample_image = sitk.ResampleImageFilter() resample_image.SetSize(image.GetSize()) resample_image.SetOutputSpacing([0.5, 0.5, 0.5]) resample_image.SetInterpolator(sitk.sitkLinear) LLL = resample_image.Execute(LLL) LLL.SetSpacing(image.GetSpacing()) LLL.SetDirection(image.GetDirection()) LLL.SetOrigin(image.GetOrigin()) sitk.WriteImage(LLL, L_path) LLH = (LLH - LLH.min()) / (LLH.max() - LLH.min()) * 255 LHL = (LHL - LHL.min()) / (LHL.max() - LHL.min()) * 255 LHH = (LHH - LHH.min()) / (LHH.max() - LHH.min()) * 255 HLL = (HLL - HLL.min()) / (HLL.max() - HLL.min()) * 255 HLH = (HLH - HLH.min()) / (HLH.max() - HLH.min()) * 255 HHL = (HHL - HHL.min()) / (HHL.max() - HHL.min()) * 255 HHH = (HHH - HHH.min()) / (HHH.max() - HHH.min()) * 255 merge1 = LLH + LHL + LHH + HLL + HLH + HHL + HHH merge1 = (merge1 - merge1.min()) / (merge1.max() - merge1.min()) * 255 merge1 = sitk.GetImageFromArray(merge1) resample_image = sitk.ResampleImageFilter() resample_image.SetSize(image.GetSize()) resample_image.SetOutputSpacing([0.5, 0.5, 0.5]) resample_image.SetInterpolator(sitk.sitkLinear) merge1 = resample_image.Execute(merge1) merge1.SetSpacing(image.GetSpacing()) merge1.SetDirection(image.GetDirection()) merge1.SetOrigin(image.GetOrigin()) sitk.WriteImage(merge1, H_path) ================================================ FILE: train_semi_CCT.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import 
Variable from torch.utils.data import DataLoader from
models.getnetwork import get_network import argparse import time import os import numpy as np from torch.backends import cudnn import random from PIL import Image import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel import sys from config.dataset_config.dataset_cfg import dataset_cfg from config.augmentation.online_aug import data_transform_2d, data_normalize_2d from loss.loss_function import segmentation_loss from models.getnetwork import get_network from dataload.dataset_2d import imagefloder_itn from config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM, visual_image_sup from config.warmup_config.warmup import GradualWarmupScheduler from config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi') parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI') parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS') parser.add_argument('--input1', default='image') parser.add_argument('--sup_mark', default='20') parser.add_argument('--unsup_mark', default='80') parser.add_argument('-b', '--batch_size', default=2, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) 
parser.add_argument('-l', '--lr', default=0.5, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('-u', '--unsup_weight', default=1, type=float) parser.add_argument('--loss', default='dice') parser.add_argument('-w', '--warm_up_duration', default=20, type=int) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('-i', '--display_iter', default=5, type=int) parser.add_argument('-n', '--network', default='unet_cct', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') parser.add_argument('-v', '--vis', default=True, help='need visualization or not') parser.add_argument('--visdom_port', default=16672, type=int) args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.cuda.device_count() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14 print_num_minus = print_num - 2 print_num_half = int(print_num / 2 - 1) # trained model save path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_trained_models = path_trained_models + '/' + 'CCT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) # seg results save path_seg_results = 
args.path_seg_results + '/' +
str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + 'CCT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) # vis if args.vis and rank == args.rank_index: visdom_env = str('Semi-CCT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)) visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port) if args.input1 == 'image': input1_mean = 'MEAN' input1_std = 'STD' else: input1_mean = 'MEAN_' + args.input1 input1_std = 'STD_' + args.input1 data_transforms = data_transform_2d() data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std]) dataset_train_unsup = imagefloder_itn( data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark, input1=args.input1, data_transform_1=data_transforms['train'], data_normalize_1=data_normalize, sup=False, num_images=None, ) num_images_unsup = len(dataset_train_unsup) dataset_train_sup = imagefloder_itn( data_dir=args.path_dataset + '/train_sup_' + args.sup_mark, input1=args.input1, data_transform_1=data_transforms['train'], data_normalize_1=data_normalize, sup=True, num_images=num_images_unsup, ) dataset_val = imagefloder_itn( data_dir=args.path_dataset + '/val', input1=args.input1, data_transform_1=data_transforms['val'], 
data_normalize_1=data_normalize, sup=True, num_images=None, ) train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True) train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False) dataloaders = dict() dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup) dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup) dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler) num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])} model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model1 = model1.cuda() model1 = DistributedDataParallel(model1, device_ids=[args.local_rank]) dist.barrier() criterion = segmentation_loss(args.loss, False).cuda() kl_distance = nn.KLDivLoss(reduction='none') optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd) exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma) scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1) since = time.time() count_iter = 0 best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter - 1) % args.display_iter == 0: begin_time = time.time() dataloaders['train_sup'].sampler.set_epoch(epoch) dataloaders['train_unsup'].sampler.set_epoch(epoch) model1.train() train_loss_sup_1 = 0.0 train_loss_unsup = 0.0 train_loss = 0.0 
val_loss_sup_1 = 0.0 unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs dist.barrier() dataset_train_sup = iter(dataloaders['train_sup']) dataset_train_unsup = iter(dataloaders['train_unsup']) for i in range(num_batches['train_sup']): unsup_index = next(dataset_train_unsup) img_train_unsup_1 = unsup_index['image'] img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True)) optimizer1.zero_grad() pred_train_unsup1, pred_train_unsup2, pred_train_unsup3, pred_train_unsup4 = model1(img_train_unsup_1) pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1) pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1) pred_train_unsup3 = torch.softmax(pred_train_unsup3, 1) pred_train_unsup4 = torch.softmax(pred_train_unsup4, 1) consistency_loss_aux1 = torch.mean((pred_train_unsup1 - pred_train_unsup2) ** 2) consistency_loss_aux2 = torch.mean((pred_train_unsup1 - pred_train_unsup3) ** 2) consistency_loss_aux3 = torch.mean((pred_train_unsup1 - pred_train_unsup4) ** 2) loss_train_unsup = (consistency_loss_aux1 + consistency_loss_aux2 + consistency_loss_aux3) / 3 loss_train_unsup = loss_train_unsup * unsup_weight loss_train_unsup.backward(retain_graph=True) torch.cuda.empty_cache() sup_index = next(dataset_train_sup) img_train_sup = sup_index['image'] img_train_sup = Variable(img_train_sup.cuda(non_blocking=True)) mask_train_sup = sup_index['mask'] mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True)) pred_train_sup1, pred_train_sup2, pred_train_sup3, pred_train_sup4 = model1(img_train_sup) if count_iter % args.display_iter == 0: if i == 0: score_list_train1 = pred_train_sup1 mask_list_train = mask_train_sup # else: elif 0 < i <= num_batches['train_sup'] / 32: score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0) loss_train_sup1 = (criterion(pred_train_sup1, mask_train_sup) + criterion(pred_train_sup2, mask_train_sup) + criterion(pred_train_sup3, 
mask_train_sup) + criterion(pred_train_sup4, mask_train_sup)) / 4 loss_train_sup = loss_train_sup1 loss_train_sup.backward() optimizer1.step() torch.cuda.empty_cache() loss_train = loss_train_unsup + loss_train_sup train_loss_unsup += loss_train_unsup.item() train_loss_sup_1 += loss_train_sup1.item() train_loss += loss_train.item() scheduler_warmup1.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train1, score_list_train1) score_list_train1 = torch.cat(score_gather_list_train1, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: torch.cuda.empty_cache() print('=' * print_num) print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus) train_eval_list1, train_m_jc1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus) torch.cuda.empty_cache() with torch.no_grad(): model1.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val'] / 16: inputs_val = Variable(data['image'].cuda(non_blocking=True)) mask_val = Variable(data['mask'].cuda(non_blocking=True)) name_val = data['ID'] optimizer1.zero_grad() outputs_val1, outputs_val2, outputs_val3, outputs_val4 = model1(inputs_val) torch.cuda.empty_cache() if i == 0: score_list_val1 = outputs_val1 mask_list_val = mask_val name_list_val = name_val else: score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0) mask_list_val = torch.cat((mask_list_val, mask_val), dim=0) name_list_val = 
np.append(name_list_val, name_val, axis=0) loss_val_sup1 = criterion(outputs_val1, mask_val) val_loss_sup_1 += loss_val_sup1.item() torch.cuda.empty_cache() score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val1, score_list_val1) score_list_val1 = torch.cat(score_gather_list_val1, dim=0) mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_val, mask_list_val) mask_list_val = torch.cat(mask_gather_list_val, dim=0) name_gather_list_val = [None for _ in range(ngpus_per_node)] torch.distributed.all_gather_object(name_gather_list_val, name_list_val) name_list_val = np.concatenate(name_gather_list_val, axis=0) if rank == args.rank_index: val_epoch_loss_sup1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus) val_eval_list1, val_m_jc1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val1, mask_list_val, print_num_minus) best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val1, name_list_val, val_eval_list1, path_trained_models, path_seg_results, cfg['PALETTE'], 'CCT') torch.cuda.empty_cache() if args.vis: draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, outputs_val1, train_eval_list1, val_eval_list1) visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_cps, train_m_jc1, val_epoch_loss_sup1, val_m_jc1) visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3]) print('-' * print_num) print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|') torch.cuda.empty_cache() torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('=' * print_num) print('| Training Completed In {:.0f}h {:.0f}mins 
{:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('-' * print_num) print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus) print('=' * print_num) ================================================ FILE: train_semi_CCT_3d.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time import os import numpy as np import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel from torch.backends import cudnn import random import torchio as tio from config.dataset_config.dataset_cfg import dataset_cfg from config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup from config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM from config.warmup_config.warmup import GradualWarmupScheduler from config.augmentation.online_aug import data_transform_3d from loss.loss_function import segmentation_loss from models.getnetwork import get_network from dataload.dataset_3d import dataset_it from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi') parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/LiTS') 
parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial') parser.add_argument('--input1', default='image') parser.add_argument('--sup_mark', default='20') parser.add_argument('--unsup_mark', default='80') parser.add_argument('-b', '--batch_size', default=1, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) parser.add_argument('-l', '--lr', default=0.1, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('-c', '--unsup_weight', default=1, type=float) parser.add_argument('--loss', default='dice', type=str) parser.add_argument('--patch_size', default=(112, 112, 32), nargs=3, type=int) parser.add_argument('-w', '--warm_up_duration', default=20, type=int) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('--queue_length', default=48, type=int) parser.add_argument('--samples_per_volume_train', default=8, type=int) parser.add_argument('--samples_per_volume_val', default=12, type=int) parser.add_argument('-i', '--display_iter', default=5, type=int) parser.add_argument('-n', '--network', default='unet3d_cct', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') parser.add_argument('-v', '--vis', default=True, help='need visualization or not') parser.add_argument('--visdom_port', default=16672, type=int, help='16672') args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.cuda.device_count() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14 print_num_minus = print_num - 2 print_num_half = int(print_num / 2 - 1) path_trained_models = 
args.path_trained_models + '/' +
str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_trained_models = path_trained_models + '/' + 'CCT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + 'CCT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_mask_results = path_seg_results + '/mask' if not os.path.exists(path_mask_results) and rank == args.rank_index: os.mkdir(path_mask_results) path_seg_results = path_seg_results + '/pred' if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) if args.vis and rank == args.rank_index: visdom_env = str('Semi-CCT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight)+ '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)) visdom = visdom_initialization_EM(env=visdom_env, 
port=args.visdom_port) # Dataset data_transform = data_transform_3d(cfg['NORMALIZE']) dataset_train_unsup = dataset_it( data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark, input1=args.input1, transform_1=data_transform['train'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_train, patch_size=args.patch_size, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=False, num_images=None ) num_images_unsup = len(dataset_train_unsup.dataset_1) dataset_train_sup = dataset_it( data_dir=args.path_dataset + '/train_sup_' + args.sup_mark, input1=args.input1, transform_1=data_transform['train'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_train, patch_size=args.patch_size, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=num_images_unsup ) dataset_val = dataset_it( data_dir=args.path_dataset + '/val', input1=args.input1, transform_1=data_transform['val'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_val, patch_size=args.patch_size, num_workers=8, shuffle_subjects=False, shuffle_patches=False, sup=True, num_images=None ) train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True) train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False) dataloaders = dict() dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup) dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup) dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, 
shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler) num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])} # Model model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model1 = model1.cuda() model1 = DistributedDataParallel(model1, device_ids=[args.local_rank], find_unused_parameters=True) dist.barrier() # Training Strategy criterion = segmentation_loss(args.loss, False).cuda() kl_distance = nn.KLDivLoss(reduction='none') optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd) exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma) scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1) # Train & Val since = time.time() count_iter = 0 best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter - 1) % args.display_iter == 0: begin_time = time.time() dataloaders['train_sup'].sampler.set_epoch(epoch) dataloaders['train_unsup'].sampler.set_epoch(epoch) model1.train() train_loss_sup_1 = 0.0 train_loss_unsup = 0.0 train_loss = 0.0 val_loss_sup_1 = 0.0 unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs dist.barrier() dataset_train_sup = iter(dataloaders['train_sup']) dataset_train_unsup = iter(dataloaders['train_unsup']) for i in range(num_batches['train_sup']): unsup_index = next(dataset_train_unsup) img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda()) optimizer1.zero_grad() pred_train_unsup1, pred_train_unsup2, pred_train_unsup3, pred_train_unsup4 = model1(img_train_unsup_1) pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1) pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1) pred_train_unsup3 = torch.softmax(pred_train_unsup3, 1) pred_train_unsup4 = torch.softmax(pred_train_unsup4, 1) 
consistency_loss_aux1 = torch.mean((pred_train_unsup1 - pred_train_unsup2) ** 2) consistency_loss_aux2 = torch.mean((pred_train_unsup1 - pred_train_unsup3) ** 2) consistency_loss_aux3 = torch.mean((pred_train_unsup1 - pred_train_unsup4) ** 2) loss_train_unsup = (consistency_loss_aux1 + consistency_loss_aux2 + consistency_loss_aux3) / 3 loss_train_unsup = loss_train_unsup * unsup_weight loss_train_unsup.backward(retain_graph=True) torch.cuda.empty_cache() sup_index = next(dataset_train_sup) img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda()) mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda()) pred_train_sup1, pred_train_sup2, pred_train_sup3, pred_train_sup4 = model1(img_train_sup_1) if count_iter % args.display_iter == 0: if i == 0: score_list_train1 = pred_train_sup1 mask_list_train = mask_train_sup # else: elif 0 < i <= num_batches['train_sup'] / 32: score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0) loss_train_sup1 = (criterion(pred_train_sup1, mask_train_sup)+criterion(pred_train_sup2, mask_train_sup)+criterion(pred_train_sup3, mask_train_sup)+criterion(pred_train_sup4, mask_train_sup)) / 4 loss_train_sup = loss_train_sup1 loss_train_sup.backward() optimizer1.step() torch.cuda.empty_cache() loss_train = loss_train_unsup + loss_train_sup train_loss_unsup += loss_train_unsup.item() train_loss_sup_1 += loss_train_sup1.item() train_loss += loss_train.item() scheduler_warmup1.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train1, score_list_train1) score_list_train1 = torch.cat(score_gather_list_train1, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) 
mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: torch.cuda.empty_cache() print('=' * print_num) print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus) train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus) torch.cuda.empty_cache() with torch.no_grad(): model1.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val']: inputs_val_1 = Variable(data['image'][tio.DATA].cuda()) mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda()) optimizer1.zero_grad() outputs_val_1, outputs_val_2, outputs_val_3, outputs_val_4 = model1(inputs_val_1) torch.cuda.empty_cache() if i == 0: score_list_val_1 = outputs_val_1 mask_list_val = mask_val else: score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0) mask_list_val = torch.cat((mask_list_val, mask_val), dim=0) loss_val_sup_1 = criterion(outputs_val_1, mask_val) val_loss_sup_1 += loss_val_sup_1.item() torch.cuda.empty_cache() score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1) score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0) mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_val, mask_list_val) mask_list_val = torch.cat(mask_gather_list_val, dim=0) torch.cuda.empty_cache() if rank == args.rank_index: val_epoch_loss_sup_1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus) val_eval_list_1, val_m_jc_1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val_1, mask_list_val, print_num_minus) best_val_eval_list = 
save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val_1, mask_list_val, val_eval_list_1, path_trained_models, path_seg_results, path_mask_results, 'CCT', cfg['FORMAT']) torch.cuda.empty_cache() if args.vis: visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_m_jc_1) print('-' * print_num) print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|') torch.cuda.empty_cache() torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('=' * print_num) print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('-' * print_num) print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus) print('=' * print_num) ================================================ FILE: train_semi_CPS.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable from torch.utils.data import DataLoader from models.getnetwork import get_network import argparse import time import os import numpy as np from torch.backends import cudnn import random from PIL import Image import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel import sys from config.dataset_config.dataset_cfg import dataset_cfg from config.augmentation.online_aug import data_transform_2d, data_normalize_2d from loss.loss_function import segmentation_loss from models.getnetwork import get_network from dataload.dataset_2d import imagefloder_itn from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet from config.warmup_config.warmup import 
GradualWarmupScheduler from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi') parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/GlaS') parser.add_argument('--dataset_name', default='GlaS', help='CREMI, ISIC-2017, GlaS') parser.add_argument('--input1', default='image') parser.add_argument('--sup_mark', default='20') parser.add_argument('--unsup_mark', default='80') parser.add_argument('-b', '--batch_size', default=2, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) parser.add_argument('-l', '--lr', default=0.5, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('-u', '--unsup_weight', default=5, type=float) parser.add_argument('--loss', default='dice') parser.add_argument('-w', '--warm_up_duration', default=20, type=int) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('-i', '--display_iter', default=5, type=int) parser.add_argument('-n', '--network', default='xnet_sb', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
parser.add_argument('-v', '--vis', default=True, help='need visualization or not') parser.add_argument('--visdom_port', default=16672, type=int) args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.distributed.get_world_size() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14 print_num_minus = print_num - 2 print_num_half = int(print_num / 2 - 1) # trained model save path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_trained_models = path_trained_models+'/'+'CPS'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) # seg results save path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results+'/'+'CPS'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) # vis if args.vis and rank == args.rank_index: visdom_env =
str('Semi-CPS-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)) visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port) if args.input1 == 'image': input1_mean = 'MEAN' input1_std = 'STD' else: input1_mean = 'MEAN_' + args.input1 input1_std = 'STD_' + args.input1 data_transforms = data_transform_2d() data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std]) dataset_train_unsup = imagefloder_itn( data_dir=args.path_dataset + '/train_unsup_'+args.unsup_mark, input1=args.input1, data_transform_1=data_transforms['train'], data_normalize_1=data_normalize, sup=False, num_images=None, ) num_images_unsup = len(dataset_train_unsup) dataset_train_sup = imagefloder_itn( data_dir=args.path_dataset + '/train_sup_'+args.sup_mark, input1=args.input1, data_transform_1=data_transforms['train'], data_normalize_1=data_normalize, sup=True, num_images=num_images_unsup, ) dataset_val = imagefloder_itn( data_dir=args.path_dataset + '/val', input1=args.input1, data_transform_1=data_transforms['val'], data_normalize_1=data_normalize, sup=True, num_images=None, ) train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True) train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False) dataloaders = dict() dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup) dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup) 
dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler) num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])} model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model1 = model1.cuda() model2 = model2.cuda() model1 = DistributedDataParallel(model1, device_ids=[args.local_rank]) model2 = DistributedDataParallel(model2, device_ids=[args.local_rank]) dist.barrier() criterion = segmentation_loss(args.loss, False).cuda() optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd) exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma) scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1) optimizer2 = optim.SGD(model2.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd) exp_lr_scheduler2 = lr_scheduler.StepLR(optimizer2, step_size=args.step_size, gamma=args.gamma) scheduler_warmup2 = GradualWarmupScheduler(optimizer2, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler2) since = time.time() count_iter = 0 best_model = model1 best_result = 'Result1' best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter-1) % args.display_iter == 0: begin_time = time.time() dataloaders['train_sup'].sampler.set_epoch(epoch) dataloaders['train_unsup'].sampler.set_epoch(epoch) model1.train() model2.train() train_loss_sup_1 = 0.0 train_loss_sup_2 = 0.0 train_loss_unsup = 0.0 train_loss = 0.0 val_loss_sup_1 = 0.0 val_loss_sup_2 = 0.0 unsup_weight = args.unsup_weight * (epoch+1) / args.num_epochs dist.barrier() dataset_train_sup = 
iter(dataloaders['train_sup']) dataset_train_unsup = iter(dataloaders['train_unsup']) for i in range(num_batches['train_sup']): unsup_index = next(dataset_train_unsup) img_train_unsup = unsup_index['image'] img_train_unsup = Variable(img_train_unsup.cuda(non_blocking=True)) optimizer1.zero_grad() optimizer2.zero_grad() pred_train_unsup1 = model1(img_train_unsup) pred_train_unsup2 = model2(img_train_unsup) max_train1 = torch.max(pred_train_unsup1, dim=1)[1] max_train2 = torch.max(pred_train_unsup2, dim=1)[1] max_train1 = max_train1.long() max_train2 = max_train2.long() loss_train_unsup = criterion(pred_train_unsup1, max_train2) + criterion(pred_train_unsup2, max_train1) loss_train_unsup = loss_train_unsup * unsup_weight loss_train_unsup.backward(retain_graph=True) torch.cuda.empty_cache() sup_index = next(dataset_train_sup) img_train_sup = sup_index['image'] img_train_sup = Variable(img_train_sup.cuda(non_blocking=True)) mask_train_sup = sup_index['mask'] mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True)) pred_train_sup1 = model1(img_train_sup) pred_train_sup2 = model2(img_train_sup) if count_iter % args.display_iter == 0: if i == 0: score_list_train1 = pred_train_sup1 score_list_train2 = pred_train_sup2 mask_list_train = mask_train_sup # else: elif 0 < i <= num_batches['train_sup'] / 32: score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0) score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0) loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup) loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup) loss_train_sup = loss_train_sup1 + loss_train_sup2 loss_train_sup.backward() optimizer1.step() optimizer2.step() torch.cuda.empty_cache() loss_train = loss_train_unsup + loss_train_sup train_loss_unsup += loss_train_unsup.item() train_loss_sup_1 += loss_train_sup1.item() train_loss_sup_2 += loss_train_sup2.item() train_loss += 
loss_train.item() scheduler_warmup1.step() scheduler_warmup2.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train1, score_list_train1) score_list_train1 = torch.cat(score_gather_list_train1, dim=0) score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train2, score_list_train2) score_list_train2 = torch.cat(score_gather_list_train2, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: torch.cuda.empty_cache() print('=' * print_num) print('| Epoch {}/{}'.format(epoch+1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half) train_eval_list1, train_eval_list2, train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half) torch.cuda.empty_cache() with torch.no_grad(): model1.eval() model2.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val'] / 16: inputs_val = Variable(data['image'].cuda(non_blocking=True)) mask_val = Variable(data['mask'].cuda(non_blocking=True)) name_val = data['ID'] optimizer1.zero_grad() optimizer2.zero_grad() outputs_val1 = model1(inputs_val) outputs_val2 = model2(inputs_val) torch.cuda.empty_cache() if i == 0: score_list_val1 = outputs_val1 score_list_val2 = outputs_val2 mask_list_val = mask_val name_list_val = name_val else: score_list_val1 = torch.cat((score_list_val1, 
outputs_val1), dim=0) score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0) mask_list_val = torch.cat((mask_list_val, mask_val), dim=0) name_list_val = np.append(name_list_val, name_val, axis=0) loss_val_sup1 = criterion(outputs_val1, mask_val) loss_val_sup2 = criterion(outputs_val2, mask_val) val_loss_sup_1 += loss_val_sup1.item() val_loss_sup_2 += loss_val_sup2.item() torch.cuda.empty_cache() score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val1, score_list_val1) score_list_val1 = torch.cat(score_gather_list_val1, dim=0) score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val2, score_list_val2) score_list_val2 = torch.cat(score_gather_list_val2, dim=0) mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_val, mask_list_val) mask_list_val = torch.cat(mask_gather_list_val, dim=0) name_gather_list_val = [None for _ in range(ngpus_per_node)] torch.distributed.all_gather_object(name_gather_list_val, name_list_val) name_list_val = np.concatenate(name_gather_list_val, axis=0) if rank == args.rank_index: val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half) val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half) best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE']) torch.cuda.empty_cache() if args.vis: draw_img = draw_pred_XNet(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, pred_train_sup2, 
outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2) visualization_XNet(visdom, epoch+1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_m_jc1, train_m_jc2, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2) visual_image_XNet(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4], draw_img[5]) print('-' * print_num) print('| Epoch Time: {:.4f}s'.format((time.time()-begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|') torch.cuda.empty_cache() torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('=' * print_num) print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('-' * print_num) print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus) print('=' * print_num) ================================================ FILE: train_semi_CPS_3d.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time import os import numpy as np import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel from torch.backends import cudnn import random import torchio as tio from config.dataset_config.dataset_cfg import dataset_cfg from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_3d, print_best from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet from config.warmup_config.warmup import GradualWarmupScheduler from config.augmentation.online_aug import 
data_transform_3d from loss.loss_function import segmentation_loss from models.getnetwork import get_network from dataload.dataset_3d import dataset_it from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi') parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/LiTS') parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial') parser.add_argument('--input1', default='image') parser.add_argument('--sup_mark', default='20') parser.add_argument('--unsup_mark', default='80') parser.add_argument('-b', '--batch_size', default=1, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) parser.add_argument('-l', '--lr', default=0.1, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('-c', '--unsup_weight', default=1, type=float) parser.add_argument('--loss', default='dice', type=str) parser.add_argument('--patch_size', default=(112, 112, 32), type=int, nargs=3) parser.add_argument('-w', '--warm_up_duration', default=20, type=int) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('--queue_length', default=48, type=int) parser.add_argument('--samples_per_volume_train', default=8, type=int) parser.add_argument('--samples_per_volume_val', default=12, type=int) parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n', '--network', default='unet3d_min', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') parser.add_argument('-v', '--vis', default=True, help='need visualization or not') parser.add_argument('--visdom_port', default=16672, type=int, help='16672') args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.distributed.get_world_size() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14 print_num_minus = print_num - 2 print_num_half = int(print_num / 2 - 1) path_trained_models = args.path_trained_models+'/'+str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_trained_models = path_trained_models+'/'+'CPS'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s=' + str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w=' + str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_seg_results = args.path_seg_results+'/'+str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results+'/'+'CPS'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_mask_results = path_seg_results + '/mask' if not
os.path.exists(path_mask_results) and rank == args.rank_index: os.mkdir(path_mask_results) path_seg_results = path_seg_results + '/pred' if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) if args.vis and rank == args.rank_index: visdom_env = str('Semi-CPS-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-w=' + str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)) visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port) # Dataset data_transform = data_transform_3d(cfg['NORMALIZE']) dataset_train_unsup = dataset_it( data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark, input1=args.input1, transform_1=data_transform['train'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_train, patch_size=args.patch_size, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=False, num_images=None ) num_images_unsup = len(dataset_train_unsup.dataset_1) dataset_train_sup = dataset_it( data_dir=args.path_dataset + '/train_sup_' + args.sup_mark, input1=args.input1, transform_1=data_transform['train'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_train, patch_size=args.patch_size, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=num_images_unsup ) dataset_val = dataset_it( data_dir=args.path_dataset + '/val', input1=args.input1, transform_1=data_transform['val'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_val, patch_size=args.patch_size, num_workers=8, shuffle_subjects=False, shuffle_patches=False, sup=True, num_images=None ) train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True) train_sampler_sup = 
torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False) dataloaders = dict() dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup) dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup) dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler) num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])} # Model model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model1 = model1.cuda() model2 = model2.cuda() model1 = DistributedDataParallel(model1, device_ids=[args.local_rank]) model2 = DistributedDataParallel(model2, device_ids=[args.local_rank]) dist.barrier() # Training Strategy criterion = segmentation_loss(args.loss, False).cuda() optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd) exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma) scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1) optimizer2 = optim.SGD(model2.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd) exp_lr_scheduler2 = lr_scheduler.StepLR(optimizer2, step_size=args.step_size, gamma=args.gamma) scheduler_warmup2 = GradualWarmupScheduler(optimizer2, multiplier=1.0, total_epoch=args.warm_up_duration, 
after_scheduler=exp_lr_scheduler2) # Train & Val since = time.time() count_iter = 0 best_model = model1 best_result = 'Result1' best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter - 1) % args.display_iter == 0: begin_time = time.time() dataloaders['train_sup'].sampler.set_epoch(epoch) dataloaders['train_unsup'].sampler.set_epoch(epoch) model1.train() model2.train() train_loss_sup_1 = 0.0 train_loss_sup_2 = 0.0 train_loss_unsup = 0.0 train_loss = 0.0 val_loss_sup_1 = 0.0 val_loss_sup_2 = 0.0 unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs dist.barrier() dataset_train_sup = iter(dataloaders['train_sup']) dataset_train_unsup = iter(dataloaders['train_unsup']) for i in range(num_batches['train_sup']): unsup_index = next(dataset_train_unsup) img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda()) optimizer1.zero_grad() optimizer2.zero_grad() pred_train_unsup1 = model1(img_train_unsup_1) pred_train_unsup2 = model2(img_train_unsup_1) max_train_unsup1 = torch.max(pred_train_unsup1, dim=1)[1] max_train_unsup2 = torch.max(pred_train_unsup2, dim=1)[1] max_train_unsup1 = max_train_unsup1.long() max_train_unsup2 = max_train_unsup2.long() loss_train_unsup = criterion(pred_train_unsup1, max_train_unsup2) + criterion(pred_train_unsup2, max_train_unsup1) loss_train_unsup = loss_train_unsup * unsup_weight loss_train_unsup.backward(retain_graph=True) torch.cuda.empty_cache() sup_index = next(dataset_train_sup) img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda()) mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda()) pred_train_sup1 = model1(img_train_sup_1) pred_train_sup2 = model2(img_train_sup_1) if count_iter % args.display_iter == 0: if i == 0: score_list_train1 = pred_train_sup1 score_list_train2 = pred_train_sup2 mask_list_train = mask_train_sup # else: elif 0 < i <= num_batches['train_sup'] / 32: score_list_train1 = torch.cat((score_list_train1, 
pred_train_sup1), dim=0) score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0) loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup) loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup) loss_train_sup = loss_train_sup1 + loss_train_sup2 loss_train_sup.backward() optimizer1.step() optimizer2.step() torch.cuda.empty_cache() loss_train = loss_train_unsup + loss_train_sup train_loss_unsup += loss_train_unsup.item() train_loss_sup_1 += loss_train_sup1.item() train_loss_sup_2 += loss_train_sup2.item() train_loss += loss_train.item() scheduler_warmup1.step() scheduler_warmup2.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train1, score_list_train1) score_list_train1 = torch.cat(score_gather_list_train1, dim=0) score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train2, score_list_train2) score_list_train2 = torch.cat(score_gather_list_train2, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: torch.cuda.empty_cache() print('=' * print_num) print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half) train_eval_list_1, train_eval_list_2, train_m_jc_1, train_m_jc_2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, 
print_num_half) torch.cuda.empty_cache() with torch.no_grad(): model1.eval() model2.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val']: inputs_val_1 = Variable(data['image'][tio.DATA].cuda()) mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda()) optimizer1.zero_grad() optimizer2.zero_grad() outputs_val_1 = model1(inputs_val_1) outputs_val_2 = model2(inputs_val_1) torch.cuda.empty_cache() if i == 0: score_list_val_1 = outputs_val_1 score_list_val_2 = outputs_val_2 mask_list_val = mask_val else: score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0) score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0) mask_list_val = torch.cat((mask_list_val, mask_val), dim=0) loss_val_sup_1 = criterion(outputs_val_1, mask_val) loss_val_sup_2 = criterion(outputs_val_2, mask_val) val_loss_sup_1 += loss_val_sup_1.item() val_loss_sup_2 += loss_val_sup_2.item() torch.cuda.empty_cache() score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1) score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0) score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2) score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0) mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_val, mask_list_val) mask_list_val = torch.cat(mask_gather_list_val, dim=0) torch.cuda.empty_cache() if rank == args.rank_index: val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half) val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half) best_val_eval_list,
best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val_1, score_list_val_2, mask_list_val, val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT']) torch.cuda.empty_cache() if args.vis: visualization_XNet(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_m_jc_1, train_m_jc_2, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2) print('-' * print_num) print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|') torch.cuda.empty_cache() torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('=' * print_num) print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('-' * print_num) print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus) print('=' * print_num) ================================================ FILE: train_semi_CT.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable from torch.utils.data import DataLoader from models.getnetwork import get_network import argparse import time import os import numpy as np from torch.backends import cudnn import random from PIL import Image import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel import sys from config.dataset_config.dataset_cfg import dataset_cfg from config.augmentation.online_aug import data_transform_2d, data_normalize_2d from loss.loss_function import segmentation_loss from models.getnetwork import get_network from 
dataload.dataset_2d import imagefloder_itn from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet from config.warmup_config.warmup import GradualWarmupScheduler from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi') parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI') parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS') parser.add_argument('--input1', default='image') parser.add_argument('--sup_mark', default='20') parser.add_argument('--unsup_mark', default='80') parser.add_argument('-b', '--batch_size', default=2, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) parser.add_argument('-l', '--lr', default=0.5, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('-u', '--unsup_weight', default=1, type=float) parser.add_argument('--loss', default='dice') parser.add_argument('-w', '--warm_up_duration', default=20, type=int) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('-i', '--display_iter', default=5, type=int)
parser.add_argument('-n1', '--network1', default='unet', type=str) parser.add_argument('-n2', '--network2', default='swinunet', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') parser.add_argument('-v', '--vis', default=True, help='need visualization or not') parser.add_argument('--visdom_port', default=16672, type=int) args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.cuda.device_count() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14 print_num_minus = print_num - 2 print_num_half = int(print_num / 2 - 1) # trained model save path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_trained_models = path_trained_models + '/' + 'CT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) # seg results save path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + 'CT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark)
+ '-' + str(args.input1) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) # vis if args.vis and rank == args.rank_index: visdom_env = str('Semi-CT-' + str(os.path.split(args.path_dataset)[1]) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)) visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port) if args.input1 == 'image': input1_mean = 'MEAN' input1_std = 'STD' else: input1_mean = 'MEAN_' + args.input1 input1_std = 'STD_' + args.input1 data_transforms = data_transform_2d() data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std]) dataset_train_unsup = imagefloder_itn( data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark, input1=args.input1, data_transform_1=data_transforms['train'], data_normalize_1=data_normalize, sup=False, num_images=None, ) num_images_unsup = len(dataset_train_unsup) dataset_train_sup = imagefloder_itn( data_dir=args.path_dataset + '/train_sup_' + args.sup_mark, input1=args.input1, data_transform_1=data_transforms['train'], data_normalize_1=data_normalize, sup=True, num_images=num_images_unsup, ) dataset_val = imagefloder_itn( data_dir=args.path_dataset + '/val', input1=args.input1, data_transform_1=data_transforms['val'], data_normalize_1=data_normalize, sup=True, num_images=None, ) train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True) train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False) dataloaders = dict() dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, 
num_workers=8, sampler=train_sampler_sup) dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup) dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler) num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])} model1 = get_network(args.network1, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model2 = get_network(args.network2, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model1 = model1.cuda() model2 = model2.cuda() model1 = DistributedDataParallel(model1, device_ids=[args.local_rank]) model2 = DistributedDataParallel(model2, device_ids=[args.local_rank]) dist.barrier() criterion = segmentation_loss(args.loss, False).cuda() optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd) exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma) scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1) optimizer2 = optim.SGD(model2.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd) exp_lr_scheduler2 = lr_scheduler.StepLR(optimizer2, step_size=args.step_size, gamma=args.gamma) scheduler_warmup2 = GradualWarmupScheduler(optimizer2, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler2) since = time.time() count_iter = 0 best_model = model1 best_result = 'Result1' best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter - 1) % args.display_iter == 0: begin_time = time.time() dataloaders['train_sup'].sampler.set_epoch(epoch) dataloaders['train_unsup'].sampler.set_epoch(epoch) model1.train() model2.train() train_loss_sup_1 = 0.0 
train_loss_sup_2 = 0.0 train_loss_unsup = 0.0 train_loss = 0.0 val_loss_sup_1 = 0.0 val_loss_sup_2 = 0.0 unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs dist.barrier() dataset_train_sup = iter(dataloaders['train_sup']) dataset_train_unsup = iter(dataloaders['train_unsup']) for i in range(num_batches['train_sup']): unsup_index = next(dataset_train_unsup) img_train_unsup = unsup_index['image'] img_train_unsup = Variable(img_train_unsup.cuda(non_blocking=True)) optimizer1.zero_grad() optimizer2.zero_grad() pred_train_unsup1 = model1(img_train_unsup) pred_train_unsup2 = model2(img_train_unsup) max_train1 = torch.max(pred_train_unsup1, dim=1)[1] max_train2 = torch.max(pred_train_unsup2, dim=1)[1] max_train1 = max_train1.long() max_train2 = max_train2.long() loss_train_unsup = criterion(pred_train_unsup1, max_train2) + criterion(pred_train_unsup2, max_train1) loss_train_unsup = loss_train_unsup * unsup_weight loss_train_unsup.backward(retain_graph=True) torch.cuda.empty_cache() sup_index = next(dataset_train_sup) img_train_sup = sup_index['image'] img_train_sup = Variable(img_train_sup.cuda(non_blocking=True)) mask_train_sup = sup_index['mask'] mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True)) pred_train_sup1 = model1(img_train_sup) pred_train_sup2 = model2(img_train_sup) if count_iter % args.display_iter == 0: if i == 0: score_list_train1 = pred_train_sup1 score_list_train2 = pred_train_sup2 mask_list_train = mask_train_sup # else: elif 0 < i <= num_batches['train_sup'] / 32: score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0) score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0) loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup) loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup) loss_train_sup = loss_train_sup1 + loss_train_sup2 loss_train_sup.backward() optimizer1.step() optimizer2.step() 
torch.cuda.empty_cache() loss_train = loss_train_unsup + loss_train_sup train_loss_unsup += loss_train_unsup.item() train_loss_sup_1 += loss_train_sup1.item() train_loss_sup_2 += loss_train_sup2.item() train_loss += loss_train.item() scheduler_warmup1.step() scheduler_warmup2.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train1, score_list_train1) score_list_train1 = torch.cat(score_gather_list_train1, dim=0) score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train2, score_list_train2) score_list_train2 = torch.cat(score_gather_list_train2, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: torch.cuda.empty_cache() print('=' * print_num) print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half) train_eval_list1, train_eval_list2, train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half) torch.cuda.empty_cache() with torch.no_grad(): model1.eval() model2.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val'] / 16: inputs_val = Variable(data['image'].cuda(non_blocking=True)) mask_val = Variable(data['mask'].cuda(non_blocking=True)) name_val = data['ID'] optimizer1.zero_grad() optimizer2.zero_grad() outputs_val1 = model1(inputs_val) outputs_val2 = 
model2(inputs_val) torch.cuda.empty_cache() if i == 0: score_list_val1 = outputs_val1 score_list_val2 = outputs_val2 mask_list_val = mask_val name_list_val = name_val else: score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0) score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0) mask_list_val = torch.cat((mask_list_val, mask_val), dim=0) name_list_val = np.append(name_list_val, name_val, axis=0) loss_val_sup1 = criterion(outputs_val1, mask_val) loss_val_sup2 = criterion(outputs_val2, mask_val) val_loss_sup_1 += loss_val_sup1.item() val_loss_sup_2 += loss_val_sup2.item() torch.cuda.empty_cache() score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val1, score_list_val1) score_list_val1 = torch.cat(score_gather_list_val1, dim=0) score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val2, score_list_val2) score_list_val2 = torch.cat(score_gather_list_val2, dim=0) mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_val, mask_list_val) mask_list_val = torch.cat(mask_gather_list_val, dim=0) name_gather_list_val = [None for _ in range(ngpus_per_node)] torch.distributed.all_gather_object(name_gather_list_val, name_list_val) name_list_val = np.concatenate(name_gather_list_val, axis=0) if rank == args.rank_index: val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half) val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half) best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val1, score_list_val2, name_list_val, val_eval_list1, 
val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE']) torch.cuda.empty_cache() if args.vis: draw_img = draw_pred_XNet(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, pred_train_sup2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2) visualization_XNet(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_m_jc1, train_m_jc2, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2) visual_image_XNet(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4], draw_img[5]) print('-' * print_num) print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|') torch.cuda.empty_cache() torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('=' * print_num) print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('-' * print_num) print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus) print('=' * print_num) ================================================ FILE: train_semi_CT_3d.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time import os import numpy as np import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel from torch.backends import cudnn import random import torchio as tio from config.dataset_config.dataset_cfg import dataset_cfg from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_3d, 
print_best from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet from config.warmup_config.warmup import GradualWarmupScheduler from config.augmentation.online_aug import data_transform_3d from loss.loss_function import segmentation_loss from models.getnetwork import get_network from dataload.dataset_3d import dataset_it from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi') parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/LiTS') parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial') parser.add_argument('--input1', default='image') parser.add_argument('--sup_mark', default='20') parser.add_argument('--unsup_mark', default='80') parser.add_argument('-b', '--batch_size', default=1, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) parser.add_argument('-l', '--lr', default=0.1, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('-c', '--unsup_weight', default=1, type=float) parser.add_argument('--loss', default='dice', type=str) parser.add_argument('--patch_size', default=(112, 112, 32)) parser.add_argument('-w', '--warm_up_duration', default=20, type=int) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('--queue_length', default=48,
type=int) parser.add_argument('--samples_per_volume_train', default=8, type=int) parser.add_argument('--samples_per_volume_val', default=12, type=int) parser.add_argument('-i', '--display_iter', default=5, type=int) parser.add_argument('-n1', '--network1', default='unet3d_min', type=str) parser.add_argument('-n2', '--network2', default='unertr', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') parser.add_argument('-v', '--vis', default=True, help='need visualization or not') parser.add_argument('--visdom_port', default=16672, type=int, help='16672') args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.cuda.device_count() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14 print_num_minus = print_num - 2 print_num_half = int(print_num / 2 - 1) path_trained_models = args.path_trained_models+'/'+str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_trained_models = path_trained_models+'/'+'CT'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s=' + str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+ '-cw' + str(args.unsup_weight)+'-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_seg_results = args.path_seg_results+'/' +str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results =
path_seg_results+'/'+'CT'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+ '-cw' + str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_mask_results = path_seg_results + '/mask' if not os.path.exists(path_mask_results) and rank == args.rank_index: os.mkdir(path_mask_results) path_seg_results = path_seg_results + '/pred' if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) if args.vis and rank == args.rank_index: visdom_env = str('Semi-CT-' + str(os.path.split(args.path_dataset)[1]) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)) visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port) # Dataset data_transform = data_transform_3d(cfg['NORMALIZE']) dataset_train_unsup = dataset_it( data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark, input1=args.input1, transform_1=data_transform['train'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_train, patch_size=args.patch_size, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=False, num_images=None ) num_images_unsup = len(dataset_train_unsup.dataset_1) dataset_train_sup = dataset_it( data_dir=args.path_dataset + '/train_sup_' + args.sup_mark, input1=args.input1, transform_1=data_transform['train'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_train, patch_size=args.patch_size, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=num_images_unsup ) dataset_val = 
dataset_it( data_dir=args.path_dataset + '/val', input1=args.input1, transform_1=data_transform['val'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_val, patch_size=args.patch_size, num_workers=8, shuffle_subjects=False, shuffle_patches=False, sup=True, num_images=None ) train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True) train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False) dataloaders = dict() dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup) dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup) dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler) num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])} # Model model1 = get_network(args.network1, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model2 = get_network(args.network2, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size) model1 = model1.cuda() model2 = model2.cuda() model1 = DistributedDataParallel(model1, device_ids=[args.local_rank]) model2 = DistributedDataParallel(model2, device_ids=[args.local_rank], find_unused_parameters=True) dist.barrier() # Training Strategy criterion = segmentation_loss(args.loss, False).cuda() optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd) exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, 
step_size=args.step_size, gamma=args.gamma) scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1) optimizer2 = optim.SGD(model2.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd) exp_lr_scheduler2 = lr_scheduler.StepLR(optimizer2, step_size=args.step_size, gamma=args.gamma) scheduler_warmup2 = GradualWarmupScheduler(optimizer2, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler2) # Train & Val since = time.time() count_iter = 0 best_model = model1 best_result = 'Result1' best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter - 1) % args.display_iter == 0: begin_time = time.time() dataloaders['train_sup'].sampler.set_epoch(epoch) dataloaders['train_unsup'].sampler.set_epoch(epoch) model1.train() model2.train() train_loss_sup_1 = 0.0 train_loss_sup_2 = 0.0 train_loss_unsup = 0.0 train_loss = 0.0 val_loss_sup_1 = 0.0 val_loss_sup_2 = 0.0 unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs dist.barrier() dataset_train_sup = iter(dataloaders['train_sup']) dataset_train_unsup = iter(dataloaders['train_unsup']) for i in range(num_batches['train_sup']): unsup_index = next(dataset_train_unsup) img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda()) optimizer1.zero_grad() optimizer2.zero_grad() pred_train_unsup1 = model1(img_train_unsup_1) pred_train_unsup2 = model2(img_train_unsup_1) max_train_unsup1 = torch.max(pred_train_unsup1, dim=1)[1] max_train_unsup2 = torch.max(pred_train_unsup2, dim=1)[1] max_train_unsup1 = max_train_unsup1.long() max_train_unsup2 = max_train_unsup2.long() loss_train_unsup = criterion(pred_train_unsup1, max_train_unsup2) + criterion(pred_train_unsup2, max_train_unsup1) loss_train_unsup = loss_train_unsup * unsup_weight loss_train_unsup.backward(retain_graph=True) torch.cuda.empty_cache() sup_index = next(dataset_train_sup) 
img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda()) mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda()) pred_train_sup1 = model1(img_train_sup_1) pred_train_sup2 = model2(img_train_sup_1) if count_iter % args.display_iter == 0: if i == 0: score_list_train1 = pred_train_sup1 score_list_train2 = pred_train_sup2 mask_list_train = mask_train_sup # else: elif 0 < i <= num_batches['train_sup'] / 32: score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0) score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0) loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup) loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup) loss_train_sup = loss_train_sup1 + loss_train_sup2 loss_train_sup.backward() optimizer1.step() optimizer2.step() torch.cuda.empty_cache() loss_train = loss_train_unsup + loss_train_sup train_loss_unsup += loss_train_unsup.item() train_loss_sup_1 += loss_train_sup1.item() train_loss_sup_2 += loss_train_sup2.item() train_loss += loss_train.item() scheduler_warmup1.step() scheduler_warmup2.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train1, score_list_train1) score_list_train1 = torch.cat(score_gather_list_train1, dim=0) score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train2, score_list_train2) score_list_train2 = torch.cat(score_gather_list_train2, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: torch.cuda.empty_cache() print('=' * 
print_num) print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half) train_eval_list_1, train_eval_list_2, train_m_jc_1, train_m_jc_2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half) torch.cuda.empty_cache() with torch.no_grad(): model1.eval() model2.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val']: inputs_val_1 = Variable(data['image'][tio.DATA].cuda()) mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda()) optimizer1.zero_grad() optimizer2.zero_grad() outputs_val_1 = model1(inputs_val_1) outputs_val_2 = model2(inputs_val_1) torch.cuda.empty_cache() if i == 0: score_list_val_1 = outputs_val_1 score_list_val_2 = outputs_val_2 mask_list_val = mask_val else: score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0) score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0) mask_list_val = torch.cat((mask_list_val, mask_val), dim=0) loss_val_sup_1 = criterion(outputs_val_1, mask_val) loss_val_sup_2 = criterion(outputs_val_2, mask_val) val_loss_sup_1 += loss_val_sup_1.item() val_loss_sup_2 += loss_val_sup_2.item() torch.cuda.empty_cache() score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1) score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0) score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2) score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0) mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)] 
torch.distributed.all_gather(mask_gather_list_val, mask_list_val) mask_list_val = torch.cat(mask_gather_list_val, dim=0) torch.cuda.empty_cache() if rank == args.rank_index: val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2,num_batches, print_num, print_num_half) val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half) best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val_1, score_list_val_2, mask_list_val, val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT']) torch.cuda.empty_cache() if args.vis: visualization_XNet(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_m_jc_1, train_m_jc_2, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2) print('-' * print_num) print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|') torch.cuda.empty_cache() torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('=' * print_num) print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('-' * print_num) print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus) print('=' * print_num) ================================================ FILE: train_semi_DTC.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time 
import os import numpy as np import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel from torch.backends import cudnn import random import torchio as tio from config.dataset_config.dataset_cfg import dataset_cfg from config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup from config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM from config.warmup_config.warmup import GradualWarmupScheduler from config.augmentation.online_aug import data_transform_3d from loss.loss_function import segmentation_loss from models.getnetwork import get_network from dataload.dataset_3d import dataset_it_dtc from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi') parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/LiTS') parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial') parser.add_argument('--input1', default='image') parser.add_argument('--sup_mark', default='20') parser.add_argument('--unsup_mark', default='80') parser.add_argument('-b', '--batch_size', default=1, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) parser.add_argument('-l', '--lr', default=0.1, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('-c', 
'--unsup_weight', default=1, type=float) parser.add_argument('--beta', default=0.3, type=float) parser.add_argument('--loss', default='dice', type=str) parser.add_argument('--patch_size', default=(112, 112, 32)) parser.add_argument('-w', '--warm_up_duration', default=20, type=int) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('--queue_length', default=48, type=int) parser.add_argument('--samples_per_volume_train', default=8, type=int) parser.add_argument('--samples_per_volume_val', default=12, type=int) parser.add_argument('-i', '--display_iter', default=5, type=int) parser.add_argument('-n', '--network', default='unet3d_dtc', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') parser.add_argument('-v', '--vis', default=True, help='need visualization or not') parser.add_argument('--visdom_port', default=16672, type=int, help='16672') args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = dist.get_world_size() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14 print_num_minus = print_num - 2 print_num_half = int(print_num / 2 - 1) path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_trained_models = path_trained_models + '/' + 'DTC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not
os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + 'DTC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_mask_results = path_seg_results + '/mask' if not os.path.exists(path_mask_results) and rank == args.rank_index: os.mkdir(path_mask_results) path_seg_results = path_seg_results + '/pred' if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) if args.vis and rank == args.rank_index: visdom_env = str('Semi-DTC-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)) visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port) # Dataset data_transform = data_transform_3d(cfg['NORMALIZE']) dataset_train_unsup = dataset_it_dtc( data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark, input1=args.input1, num_classes=cfg['NUM_CLASSES'], transform_1=data_transform['train'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_train, patch_size=args.patch_size, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=False, num_images=None ) num_images_unsup = 
len(dataset_train_unsup.dataset_1) dataset_train_sup = dataset_it_dtc( data_dir=args.path_dataset + '/train_sup_' + args.sup_mark, input1=args.input1, num_classes=cfg['NUM_CLASSES'], transform_1=data_transform['train'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_train, patch_size=args.patch_size, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=num_images_unsup ) dataset_val = dataset_it_dtc( data_dir=args.path_dataset + '/val', input1=args.input1, num_classes=cfg['NUM_CLASSES'], transform_1=data_transform['val'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_val, patch_size=args.patch_size, num_workers=8, shuffle_subjects=False, shuffle_patches=False, sup=True, num_images=None ) train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True) train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False) dataloaders = dict() dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup) dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup) dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler) num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])} # Model model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model1 = model1.cuda() model1 = DistributedDataParallel(model1, device_ids=[args.local_rank]) 
dist.barrier() # Training Strategy criterion = segmentation_loss(args.loss, False).cuda() mseloss = torch.nn.MSELoss().cuda() optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd) exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma) scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1) # Train & Val since = time.time() count_iter = 0 best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter - 1) % args.display_iter == 0: begin_time = time.time() dataloaders['train_sup'].sampler.set_epoch(epoch) dataloaders['train_unsup'].sampler.set_epoch(epoch) model1.train() train_loss_sup = 0.0 train_loss_unsup = 0.0 train_loss = 0.0 val_loss_sup_1 = 0.0 unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs dist.barrier() dataset_train_sup = iter(dataloaders['train_sup']) dataset_train_unsup = iter(dataloaders['train_unsup']) for i in range(num_batches['train_sup']): unsup_index = next(dataset_train_unsup) img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda()) optimizer1.zero_grad() pred_train_unsup_sdf, pred_train_unsup_seg = model1(img_train_unsup_1) pred_train_unsup_seg_soft = torch.sigmoid(pred_train_unsup_seg) dis_to_mask = torch.sigmoid(-1500 * pred_train_unsup_sdf) loss_train_unsup = torch.mean((dis_to_mask - pred_train_unsup_seg_soft) ** 2) loss_train_unsup = loss_train_unsup * unsup_weight loss_train_unsup.backward(retain_graph=True) torch.cuda.empty_cache() sup_index = next(dataset_train_sup) img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda()) mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda()) mask_train_sup_sdf1 = Variable(sup_index['mask2'][tio.DATA].squeeze(1).float().cuda()) if cfg['NUM_CLASSES'] == 3: mask_train_sup_sdf2 = 
Variable(sup_index['mask3'][tio.DATA].squeeze(1).float().cuda()) pred_train_sup_sdf, pred_train_sup_seg = model1(img_train_sup_1) if count_iter % args.display_iter == 0: if i == 0: score_list_train1 = pred_train_sup_seg mask_list_train = mask_train_sup # else: elif 0 < i <= num_batches['train_sup'] / 32: score_list_train1 = torch.cat((score_list_train1, pred_train_sup_seg), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0) if cfg['NUM_CLASSES'] == 3: loss_train_sdf = mseloss(pred_train_sup_sdf[:, 1, ...], mask_train_sup_sdf1) + mseloss(pred_train_sup_sdf[:, 2, ...], mask_train_sup_sdf2) else: loss_train_sdf = mseloss(pred_train_sup_sdf[:, 1, ...], mask_train_sup_sdf1) loss_train_seg = criterion(pred_train_sup_seg, mask_train_sup) loss_train_sup = loss_train_seg + args.beta * loss_train_sdf loss_train_sup.backward() optimizer1.step() torch.cuda.empty_cache() loss_train = loss_train_unsup + loss_train_sup train_loss_unsup += loss_train_unsup.item() train_loss_sup += loss_train_sup.item() train_loss += loss_train.item() scheduler_warmup1.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train1, score_list_train1) score_list_train1 = torch.cat(score_gather_list_train1, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: torch.cuda.empty_cache() print('=' * print_num) print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus) train_eval_list_1, train_m_jc_1 = 
print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus) torch.cuda.empty_cache() with torch.no_grad(): model1.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val']: inputs_val_1 = Variable(data['image'][tio.DATA].cuda()) mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda()) optimizer1.zero_grad() outputs_val_sdf, outputs_val_seg = model1(inputs_val_1) torch.cuda.empty_cache() if i == 0: score_list_val_1 = outputs_val_seg mask_list_val = mask_val else: score_list_val_1 = torch.cat((score_list_val_1, outputs_val_seg), dim=0) mask_list_val = torch.cat((mask_list_val, mask_val), dim=0) loss_val_sup_1 = criterion(outputs_val_seg, mask_val) val_loss_sup_1 += loss_val_sup_1.item() torch.cuda.empty_cache() score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1) score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0) mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_val, mask_list_val) mask_list_val = torch.cat(mask_gather_list_val, dim=0) torch.cuda.empty_cache() if rank == args.rank_index: val_epoch_loss_sup_1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus) val_eval_list_1, val_m_jc_1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val_1, mask_list_val, print_num_minus) best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val_1, mask_list_val, val_eval_list_1, path_trained_models, path_seg_results, path_mask_results, 'DTC', cfg['FORMAT']) torch.cuda.empty_cache() if args.vis: visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_m_jc_1) print('-' * print_num) print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / 
args.display_iter).ljust(print_num_minus, ' '), '|') torch.cuda.empty_cache() torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('=' * print_num) print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('-' * print_num) print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus) print('=' * print_num) ================================================ FILE: train_semi_EM.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time import os import numpy as np from torch.backends import cudnn import random from PIL import Image import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel import sys from config.dataset_config.dataset_cfg import dataset_cfg from config.augmentation.online_aug import data_transform_2d, data_normalize_2d from loss.loss_function import segmentation_loss, entropy_loss from models.getnetwork import get_network from dataload.dataset_2d import imagefloder_itn from config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM, visual_image_sup from config.warmup_config.warmup import GradualWarmupScheduler from config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] =
str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi') parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI') parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS') parser.add_argument('--input1', default='image') parser.add_argument('--sup_mark', default='20') parser.add_argument('--unsup_mark', default='80') parser.add_argument('-b', '--batch_size', default=2, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) parser.add_argument('-l', '--lr', default=0.5, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('-u', '--unsup_weight', default=1, type=float) parser.add_argument('--loss', default='dice') parser.add_argument('-w', '--warm_up_duration', default=20, type=int) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('-i', '--display_iter', default=5, type=int) parser.add_argument('-n', '--network', default='unet', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') parser.add_argument('-v', '--vis', default=True, help='need visualization or not') parser.add_argument('--visdom_port', default=16672, type=int) args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = dist.get_world_size() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
print_num_minus = print_num - 2 print_num_half = int(print_num / 2 - 1) # trained model save path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_trained_models = path_trained_models + '/' + 'EM' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) # seg results save path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + 'EM' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) # vis if args.vis and rank == args.rank_index: visdom_env = str('Semi-EM-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)) visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port) if args.input1 == 'image': input1_mean = 'MEAN' input1_std = 'STD' else: input1_mean = 'MEAN_' + args.input1
input1_std = 'STD_' + args.input1 data_transforms = data_transform_2d() data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std]) dataset_train_unsup = imagefloder_itn( data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark, input1=args.input1, data_transform_1=data_transforms['train'], data_normalize_1=data_normalize, sup=False, num_images=None, ) num_images_unsup = len(dataset_train_unsup) dataset_train_sup = imagefloder_itn( data_dir=args.path_dataset + '/train_sup_' + args.sup_mark, input1=args.input1, data_transform_1=data_transforms['train'], data_normalize_1=data_normalize, sup=True, num_images=num_images_unsup, ) dataset_val = imagefloder_itn( data_dir=args.path_dataset + '/val', input1=args.input1, data_transform_1=data_transforms['val'], data_normalize_1=data_normalize, sup=True, num_images=None, ) train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True) train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False) dataloaders = dict() dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup) dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup) dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler) num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])} model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model1 = model1.cuda() model1 = DistributedDataParallel(model1, device_ids=[args.local_rank]) dist.barrier() criterion = segmentation_loss(args.loss, False).cuda() 
optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd) exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma) scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1) since = time.time() count_iter = 0 best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter - 1) % args.display_iter == 0: begin_time = time.time() dataloaders['train_sup'].sampler.set_epoch(epoch) dataloaders['train_unsup'].sampler.set_epoch(epoch) model1.train() train_loss_sup_1 = 0.0 train_loss_unsup = 0.0 train_loss = 0.0 val_loss_sup_1 = 0.0 unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs dist.barrier() dataset_train_sup = iter(dataloaders['train_sup']) dataset_train_unsup = iter(dataloaders['train_unsup']) for i in range(num_batches['train_sup']): unsup_index = next(dataset_train_unsup) img_train_unsup_1 = unsup_index['image'] img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True)) optimizer1.zero_grad() pred_train_unsup1 = model1(img_train_unsup_1) pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1) loss_train_unsup = entropy_loss(pred_train_unsup1, C=2) loss_train_unsup = loss_train_unsup * unsup_weight loss_train_unsup.backward(retain_graph=True) torch.cuda.empty_cache() sup_index = next(dataset_train_sup) img_train_sup = sup_index['image'] img_train_sup = Variable(img_train_sup.cuda(non_blocking=True)) mask_train_sup = sup_index['mask'] mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True)) pred_train_sup1 = model1(img_train_sup) pred_train_sup1_soft = torch.softmax(pred_train_sup1, 1) if count_iter % args.display_iter == 0: if i == 0: score_list_train1 = pred_train_sup1 mask_list_train = mask_train_sup # else: elif 0 < i <= num_batches['train_sup'] / 32: score_list_train1 = torch.cat((score_list_train1, 
pred_train_sup1), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0) loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup) + entropy_loss(pred_train_sup1_soft, C=2) loss_train_sup = loss_train_sup1 loss_train_sup.backward() optimizer1.step() torch.cuda.empty_cache() loss_train = loss_train_unsup + loss_train_sup train_loss_unsup += loss_train_unsup.item() train_loss_sup_1 += loss_train_sup1.item() train_loss += loss_train.item() scheduler_warmup1.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train1, score_list_train1) score_list_train1 = torch.cat(score_gather_list_train1, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: torch.cuda.empty_cache() print('=' * print_num) print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus) train_eval_list1, train_m_jc1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus) torch.cuda.empty_cache() with torch.no_grad(): model1.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val'] / 16: inputs_val = Variable(data['image'].cuda(non_blocking=True)) mask_val = Variable(data['mask'].cuda(non_blocking=True)) name_val = data['ID'] optimizer1.zero_grad() outputs_val1 = model1(inputs_val) torch.cuda.empty_cache() if i == 0: score_list_val1 = outputs_val1 mask_list_val = mask_val name_list_val = name_val else: score_list_val1 = torch.cat((score_list_val1, 
outputs_val1), dim=0) mask_list_val = torch.cat((mask_list_val, mask_val), dim=0) name_list_val = np.append(name_list_val, name_val, axis=0) loss_val_sup1 = criterion(outputs_val1, mask_val) val_loss_sup_1 += loss_val_sup1.item() torch.cuda.empty_cache() score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val1, score_list_val1) score_list_val1 = torch.cat(score_gather_list_val1, dim=0) mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_val, mask_list_val) mask_list_val = torch.cat(mask_gather_list_val, dim=0) name_gather_list_val = [None for _ in range(ngpus_per_node)] torch.distributed.all_gather_object(name_gather_list_val, name_list_val) name_list_val = np.concatenate(name_gather_list_val, axis=0) if rank == args.rank_index: val_epoch_loss_sup1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus) val_eval_list1, val_m_jc1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val1, mask_list_val, print_num_minus) best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val1, name_list_val, val_eval_list1, path_trained_models, path_seg_results, cfg['PALETTE'], 'EM') torch.cuda.empty_cache() if args.vis: draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, outputs_val1, train_eval_list1, val_eval_list1) visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_cps, train_m_jc1, val_epoch_loss_sup1, val_m_jc1) visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3]) print('-' * print_num) print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|') torch.cuda.empty_cache() torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 
60) h, m = divmod(m, 60) print('=' * print_num) print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('-' * print_num) print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus) print('=' * print_num) ================================================ FILE: train_semi_EM_3d.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time import os import numpy as np import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel from torch.backends import cudnn import random import torchio as tio from config.dataset_config.dataset_cfg import dataset_cfg from config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup from config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM from config.warmup_config.warmup import GradualWarmupScheduler from config.augmentation.online_aug import data_transform_3d from loss.loss_function import segmentation_loss, entropy_loss from models.getnetwork import get_network from dataload.dataset_3d import dataset_it from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi') parser.add_argument('--path_seg_results', 
default='/mnt/data1/XNet/seg_pred/semi') parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial') parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial') parser.add_argument('--input1', default='image') parser.add_argument('--sup_mark', default='20') parser.add_argument('--unsup_mark', default='80') parser.add_argument('-b', '--batch_size', default=1, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) parser.add_argument('-l', '--lr', default=0.1, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('-c', '--unsup_weight', default=50, type=float) parser.add_argument('--loss', default='dice', type=str) parser.add_argument('--patch_size', default=(96, 96, 80)) parser.add_argument('-w', '--warm_up_duration', default=20, type=int) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('--queue_length', default=48, type=int) parser.add_argument('--samples_per_volume_train', default=4, type=int) parser.add_argument('--samples_per_volume_val', default=8, type=int) parser.add_argument('-i', '--display_iter', default=5, type=int) parser.add_argument('-n', '--network', default='unet3d', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') parser.add_argument('-v', '--vis', default=True, help='need visualization or not') parser.add_argument('--visdom_port', default=16672, type=int, help='16672') args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = dist.get_world_size() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14 print_num_minus =
print_num - 2 print_num_half = int(print_num / 2 - 1) path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_trained_models = path_trained_models + '/' + 'EM' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + 'EM' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_mask_results = path_seg_results + '/mask' if not os.path.exists(path_mask_results) and rank == args.rank_index: os.mkdir(path_mask_results) path_seg_results = path_seg_results + '/pred' if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) if args.vis and rank == args.rank_index: visdom_env = str('Semi-EM-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + 
str(args.unsup_mark) + '-' + str(args.input1)) visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port) # Dataset data_transform = data_transform_3d(cfg['NORMALIZE']) dataset_train_unsup = dataset_it( data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark, input1=args.input1, transform_1=data_transform['train'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_train, patch_size=args.patch_size, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=False, num_images=None ) num_images_unsup = len(dataset_train_unsup.dataset_1) dataset_train_sup = dataset_it( data_dir=args.path_dataset + '/train_sup_' + args.sup_mark, input1=args.input1, transform_1=data_transform['train'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_train, patch_size=args.patch_size, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=num_images_unsup ) dataset_val = dataset_it( data_dir=args.path_dataset + '/val', input1=args.input1, transform_1=data_transform['val'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_val, patch_size=args.patch_size, num_workers=8, shuffle_subjects=False, shuffle_patches=False, sup=True, num_images=None ) train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True) train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False) dataloaders = dict() dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup) dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup) 
dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler) num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])} # Model model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model1 = model1.cuda() model1 = DistributedDataParallel(model1, device_ids=[args.local_rank]) dist.barrier() # Training Strategy criterion = segmentation_loss(args.loss, False).cuda() optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd) exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma) scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1) # Train & Val since = time.time() count_iter = 0 best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter - 1) % args.display_iter == 0: begin_time = time.time() dataloaders['train_sup'].sampler.set_epoch(epoch) dataloaders['train_unsup'].sampler.set_epoch(epoch) model1.train() train_loss_sup_1 = 0.0 train_loss_unsup = 0.0 train_loss = 0.0 val_loss_sup_1 = 0.0 unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs dist.barrier() dataset_train_sup = iter(dataloaders['train_sup']) dataset_train_unsup = iter(dataloaders['train_unsup']) for i in range(num_batches['train_sup']): unsup_index = next(dataset_train_unsup) img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda()) optimizer1.zero_grad() pred_train_unsup1 = model1(img_train_unsup_1) pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1) loss_train_unsup = entropy_loss(pred_train_unsup1, C=2) loss_train_unsup = loss_train_unsup * unsup_weight loss_train_unsup.backward(retain_graph=True) torch.cuda.empty_cache() sup_index = 
next(dataset_train_sup) img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda()) mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda()) pred_train_sup1 = model1(img_train_sup_1) pred_train_sup1_soft = torch.softmax(pred_train_sup1, 1) if count_iter % args.display_iter == 0: if i == 0: score_list_train1 = pred_train_sup1 mask_list_train = mask_train_sup # else: elif 0 < i <= num_batches['train_sup'] / 32: score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0) loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup) + entropy_loss(pred_train_sup1_soft, C=2) loss_train_sup = loss_train_sup1 loss_train_sup.backward() optimizer1.step() torch.cuda.empty_cache() loss_train = loss_train_unsup + loss_train_sup train_loss_unsup += loss_train_unsup.item() train_loss_sup_1 += loss_train_sup1.item() train_loss += loss_train.item() scheduler_warmup1.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train1, score_list_train1) score_list_train1 = torch.cat(score_gather_list_train1, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: torch.cuda.empty_cache() print('=' * print_num) print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus) train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus) torch.cuda.empty_cache() with 
torch.no_grad(): model1.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val']: inputs_val_1 = Variable(data['image'][tio.DATA].cuda()) mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda()) optimizer1.zero_grad() outputs_val_1 = model1(inputs_val_1) torch.cuda.empty_cache() if i == 0: score_list_val_1 = outputs_val_1 mask_list_val = mask_val else: score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0) mask_list_val = torch.cat((mask_list_val, mask_val), dim=0) loss_val_sup_1 = criterion(outputs_val_1, mask_val) val_loss_sup_1 += loss_val_sup_1.item() torch.cuda.empty_cache() score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1) score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0) mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_val, mask_list_val) mask_list_val = torch.cat(mask_gather_list_val, dim=0) torch.cuda.empty_cache() if rank == args.rank_index: val_epoch_loss_sup_1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus) val_eval_list_1, val_m_jc_1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val_1, mask_list_val, print_num_minus) best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val_1, mask_list_val, val_eval_list_1, path_trained_models, path_seg_results, path_mask_results, 'EM', cfg['FORMAT']) torch.cuda.empty_cache() if args.vis: visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_m_jc_1) print('-' * print_num) print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust( print_num_minus, ' '), '|') torch.cuda.empty_cache() torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = 
time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('=' * print_num) print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('-' * print_num) print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus) print('=' * print_num) ================================================ FILE: train_semi_MT.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable from torch.utils.data import DataLoader from models.getnetwork import get_network import argparse import time import os import numpy as np from torch.backends import cudnn import random from PIL import Image import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel import sys from config.dataset_config.dataset_cfg import dataset_cfg from config.augmentation.online_aug import data_transform_2d, data_normalize_2d from loss.loss_function import segmentation_loss from models.getnetwork import get_network from dataload.dataset_2d import imagefloder_itn from config.visdom_config.visual_visdom import visdom_initialization_MT, visualization_MT, visual_image_MT from config.warmup_config.warmup import GradualWarmupScheduler from config.train_test_config.train_test_config import print_train_loss_MT, print_val_loss, print_train_eval_sup, print_val_eval, save_val_best_2d, draw_pred_MT, print_best from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def update_ema_variables(model, ema_model, alpha, global_step): # Use the true average until the exponential average is more correct alpha = min(1 - 1 / (global_step + 1), alpha) for ema_param, param in zip(ema_model.parameters(), model.parameters()): ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha) def init_seeds(seed): torch.manual_seed(seed) 
torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi') parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI') parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS') parser.add_argument('--input1', default='image') parser.add_argument('--sup_mark', default='20') parser.add_argument('--unsup_mark', default='80') parser.add_argument('-b', '--batch_size', default=2, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) parser.add_argument('-l', '--lr', default=0.5, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('-u', '--unsup_weight', default=5, type=float) parser.add_argument('--loss', default='dice') parser.add_argument('-w', '--warm_up_duration', default=20, type=int) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--ema_decay', default=0.99, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('-i', '--display_iter', default=5, type=int) parser.add_argument('-n', '--network', default='unet', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') parser.add_argument('-v', '--vis', default=True, help='need visualization or not') parser.add_argument('--visdom_port', default=16672, type=int) args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank =
torch.distributed.get_rank() ngpus_per_node = torch.cuda.device_count() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14 print_num_minus = print_num - 2 print_num_half = int(print_num / 2 - 1) # trained model save path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_trained_models = path_trained_models + '/' + 'MT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) # seg results save path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + 'MT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) # vis if args.vis and rank == args.rank_index: visdom_env = str('Semi-MT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) +
'-' + str(args.input1)) visdom = visdom_initialization_MT(env=visdom_env, port=args.visdom_port) if args.input1 == 'image': input1_mean = 'MEAN' input1_std = 'STD' else: input1_mean = 'MEAN_' + args.input1 input1_std = 'STD_' + args.input1 data_transforms = data_transform_2d() data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std]) dataset_train_unsup = imagefloder_itn( data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark, input1=args.input1, data_transform_1=data_transforms['train'], data_normalize_1=data_normalize, sup=False, num_images=None, ) num_images_unsup = len(dataset_train_unsup) dataset_train_sup = imagefloder_itn( data_dir=args.path_dataset + '/train_sup_' + args.sup_mark, input1=args.input1, data_transform_1=data_transforms['train'], data_normalize_1=data_normalize, sup=True, num_images=num_images_unsup, ) dataset_val = imagefloder_itn( data_dir=args.path_dataset + '/val', input1=args.input1, data_transform_1=data_transforms['val'], data_normalize_1=data_normalize, sup=True, num_images=None, ) train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True) train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False) dataloaders = dict() dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup) dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup) dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler) num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])} model1 = 
get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model1 = model1.cuda() model2 = model2.cuda() # for param in model2.parameters(): # param.detach_() model1 = DistributedDataParallel(model1, device_ids=[args.local_rank]) model2 = DistributedDataParallel(model2, device_ids=[args.local_rank]) dist.barrier() criterion = segmentation_loss(args.loss, False).cuda() optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd) exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma) scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1) since = time.time() count_iter = 0 best_model = model1 best_result = 'Result1' best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter - 1) % args.display_iter == 0: begin_time = time.time() dataloaders['train_sup'].sampler.set_epoch(epoch) dataloaders['train_unsup'].sampler.set_epoch(epoch) model1.train() model2.train() train_loss_sup_1 = 0.0 train_loss_unsup = 0.0 train_loss = 0.0 val_loss_sup_1 = 0.0 val_loss_sup_2 = 0.0 unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs dist.barrier() dataset_train_sup = iter(dataloaders['train_sup']) dataset_train_unsup = iter(dataloaders['train_unsup']) for i in range(num_batches['train_sup']): unsup_index = next(dataset_train_unsup) img_train_unsup_1 = unsup_index['image'] img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True)) noise = torch.clamp(torch.randn_like(img_train_unsup_1) * 0.1, -0.2, 0.2) img_train_unsup_2 = img_train_unsup_1 + noise optimizer1.zero_grad() pred_train_unsup1 = model1(img_train_unsup_1) pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1) with torch.no_grad(): pred_train_unsup2 = model2(img_train_unsup_2) pred_train_unsup2 = 
torch.softmax(pred_train_unsup2, 1) loss_train_unsup = torch.mean((pred_train_unsup1 - pred_train_unsup2)**2) loss_train_unsup = loss_train_unsup * unsup_weight loss_train_unsup.backward(retain_graph=True) torch.cuda.empty_cache() sup_index = next(dataset_train_sup) img_train_sup = sup_index['image'] img_train_sup = Variable(img_train_sup.cuda(non_blocking=True)) mask_train_sup = sup_index['mask'] mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True)) pred_train_sup1 = model1(img_train_sup) if count_iter % args.display_iter == 0: if i == 0: score_list_train1 = pred_train_sup1 mask_list_train = mask_train_sup # else: elif 0 < i <= num_batches['train_sup'] / 32: score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0) loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup) loss_train_sup = loss_train_sup1 loss_train_sup.backward() optimizer1.step() update_ema_variables(model1, model2, args.ema_decay, epoch) torch.cuda.empty_cache() loss_train = loss_train_unsup + loss_train_sup train_loss_unsup += loss_train_unsup.item() train_loss_sup_1 += loss_train_sup1.item() train_loss += loss_train.item() scheduler_warmup1.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train1, score_list_train1) score_list_train1 = torch.cat(score_gather_list_train1, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: torch.cuda.empty_cache() print('=' * print_num) print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss = 
print_train_loss_MT(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_half, print_num_minus) train_eval_list1, train_m_jc1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus) torch.cuda.empty_cache() with torch.no_grad(): model1.eval() model2.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val'] / 16: inputs_val = Variable(data['image'].cuda(non_blocking=True)) mask_val = Variable(data['mask'].cuda(non_blocking=True)) name_val = data['ID'] optimizer1.zero_grad() outputs_val1 = model1(inputs_val) outputs_val2 = model2(inputs_val) torch.cuda.empty_cache() if i == 0: score_list_val1 = outputs_val1 score_list_val2 = outputs_val2 mask_list_val = mask_val name_list_val = name_val else: score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0) score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0) mask_list_val = torch.cat((mask_list_val, mask_val), dim=0) name_list_val = np.append(name_list_val, name_val, axis=0) loss_val_sup1 = criterion(outputs_val1, mask_val) loss_val_sup2 = criterion(outputs_val2, mask_val) val_loss_sup_1 += loss_val_sup1.item() val_loss_sup_2 += loss_val_sup2.item() torch.cuda.empty_cache() score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val1, score_list_val1) score_list_val1 = torch.cat(score_gather_list_val1, dim=0) score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val2, score_list_val2) score_list_val2 = torch.cat(score_gather_list_val2, dim=0) mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_val, mask_list_val) mask_list_val = torch.cat(mask_gather_list_val, dim=0) name_gather_list_val = [None for _ in range(ngpus_per_node)] 
torch.distributed.all_gather_object(name_gather_list_val, name_list_val) name_list_val = np.concatenate(name_gather_list_val, axis=0) if rank == args.rank_index: val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half) val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half) best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE']) torch.cuda.empty_cache() if args.vis: draw_img = draw_pred_MT(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, outputs_val1, outputs_val2, train_eval_list1, val_eval_list1, val_eval_list2) visualization_MT(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_cps, train_m_jc1, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2) visual_image_MT(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4]) print('-' * print_num) print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|') torch.cuda.empty_cache() torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('=' * print_num) print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('-' * print_num) print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus) print('=' * print_num) ================================================ FILE: train_semi_MT_3d.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn 
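# Illustrative sketch only (not called anywhere in this script): the
# mean-teacher EMA rule implemented by update_ema_variables below, written
# with the keyword form Tensor.add_(other, alpha=...) that current PyTorch
# expects. The helper name ema_update_sketch is hypothetical, and plain lists
# of tensors stand in for model.parameters() so the rule can be checked in
# isolation.
import torch


def ema_update_sketch(student_params, teacher_params, alpha, global_step):
    # Early in training, cap the decay at 1 - 1 / (step + 1) so the teacher
    # starts out as a true running average of the student weights.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    with torch.no_grad():
        for teacher, student in zip(teacher_params, student_params):
            # teacher <- alpha * teacher + (1 - alpha) * student
            teacher.mul_(alpha).add_(student, alpha=1 - alpha)
    return alpha


# Usage: at step 0 the effective decay is 0, so the teacher copies the student;
# e.g. t = [torch.zeros(2)]; ema_update_sketch([torch.ones(2)], t, 0.99, 0)
# leaves t[0] equal to torch.ones(2). Later steps blend with decay `alpha`.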
import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time import os import numpy as np import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel from torch.backends import cudnn import random import torchio as tio from config.dataset_config.dataset_cfg import dataset_cfg from config.train_test_config.train_test_config import print_train_loss_MT, print_val_loss, print_train_eval_sup, print_val_eval, save_val_best_3d, print_best from config.visdom_config.visual_visdom import visdom_initialization_MT, visualization_MT from config.warmup_config.warmup import GradualWarmupScheduler from config.augmentation.online_aug import data_transform_3d from loss.loss_function import segmentation_loss from models.getnetwork import get_network from dataload.dataset_3d import dataset_it from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def update_ema_variables(model, ema_model, alpha, global_step): # Use the true average until the exponential average is more correct alpha = min(1 - 1 / (global_step + 1), alpha) for ema_param, param in zip(ema_model.parameters(), model.parameters()): ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi') parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial') parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')
parser.add_argument('--input1', default='image') parser.add_argument('--sup_mark', default='20') parser.add_argument('--unsup_mark', default='80') parser.add_argument('-b', '--batch_size', default=1, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) parser.add_argument('-l', '--lr', default=0.1, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('-c', '--unsup_weight', default=5, type=float) parser.add_argument('--loss', default='dice', type=str) parser.add_argument('--patch_size', default=(96, 96, 80)) parser.add_argument('-w', '--warm_up_duration', default=20, type=int) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--ema_decay', default=0.99, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('--queue_length', default=48, type=int) parser.add_argument('--samples_per_volume_train', default=4, type=int) parser.add_argument('--samples_per_volume_val', default=8, type=int) parser.add_argument('-i', '--display_iter', default=5, type=int) parser.add_argument('-n', '--network', default='unet3d', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') parser.add_argument('-v', '--vis', default=True, help='need visualization or not') parser.add_argument('--visdom_port', default=16672, type=int, help='16672') args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.cuda.device_count() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14 print_num_minus = print_num - 2 print_num_half = int(print_num / 2 - 1) path_trained_models = args.path_trained_models + '/' +
str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_trained_models = path_trained_models + '/' + 'MT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + 'MT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_mask_results = path_seg_results + '/mask' if not os.path.exists(path_mask_results) and rank == args.rank_index: os.mkdir(path_mask_results) path_seg_results = path_seg_results + '/pred' if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) if args.vis and rank == args.rank_index: visdom_env = str('Semi-MT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)) visdom = visdom_initialization_MT(env=visdom_env, port=args.visdom_port) #
Dataset data_transform = data_transform_3d(cfg['NORMALIZE']) dataset_train_unsup = dataset_it( data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark, input1=args.input1, transform_1=data_transform['train'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_train, patch_size=args.patch_size, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=False, num_images=None ) num_images_unsup = len(dataset_train_unsup.dataset_1) dataset_train_sup = dataset_it( data_dir=args.path_dataset + '/train_sup_' + args.sup_mark, input1=args.input1, transform_1=data_transform['train'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_train, patch_size=args.patch_size, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=num_images_unsup ) dataset_val = dataset_it( data_dir=args.path_dataset + '/val', input1=args.input1, transform_1=data_transform['val'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_val, patch_size=args.patch_size, num_workers=8, shuffle_subjects=False, shuffle_patches=False, sup=True, num_images=None ) train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True) train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False) dataloaders = dict() dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup) dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup) dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, 
num_workers=0, sampler=val_sampler) num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])} # Model model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model1 = model1.cuda() model2 = model2.cuda() # for param in model2.parameters(): # param.detach_() model1 = DistributedDataParallel(model1, device_ids=[args.local_rank]) model2 = DistributedDataParallel(model2, device_ids=[args.local_rank]) dist.barrier() # Training Strategy criterion = segmentation_loss(args.loss, False).cuda() optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd) exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma) scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1) # Train & Val since = time.time() count_iter = 0 best_model = model1 best_result = 'Result1' best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter - 1) % args.display_iter == 0: begin_time = time.time() dataloaders['train_sup'].sampler.set_epoch(epoch) dataloaders['train_unsup'].sampler.set_epoch(epoch) model1.train() model2.train() train_loss_sup_1 = 0.0 train_loss_unsup = 0.0 train_loss = 0.0 val_loss_sup_1 = 0.0 val_loss_sup_2 = 0.0 unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs dist.barrier() dataset_train_sup = iter(dataloaders['train_sup']) dataset_train_unsup = iter(dataloaders['train_unsup']) for i in range(num_batches['train_sup']): unsup_index = next(dataset_train_unsup) img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda()) noise = torch.clamp(torch.randn_like(img_train_unsup_1) * 0.1, -0.2, 0.2) img_train_unsup_2 = img_train_unsup_1 + noise optimizer1.zero_grad() 
pred_train_unsup1 = model1(img_train_unsup_1) pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1) with torch.no_grad(): pred_train_unsup2 = model2(img_train_unsup_2) pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1) loss_train_unsup = torch.mean((pred_train_unsup1 - pred_train_unsup2)**2) loss_train_unsup = loss_train_unsup * unsup_weight loss_train_unsup.backward(retain_graph=True) torch.cuda.empty_cache() sup_index = next(dataset_train_sup) img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda()) mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda()) pred_train_sup1 = model1(img_train_sup_1) if count_iter % args.display_iter == 0: if i == 0: score_list_train1 = pred_train_sup1 mask_list_train = mask_train_sup # else: elif 0 < i <= num_batches['train_sup'] / 32: score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0) loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup) loss_train_sup = loss_train_sup1 loss_train_sup.backward() optimizer1.step() update_ema_variables(model1, model2, args.ema_decay, epoch) torch.cuda.empty_cache() loss_train = loss_train_unsup + loss_train_sup train_loss_unsup += loss_train_unsup.item() train_loss_sup_1 += loss_train_sup1.item() train_loss += loss_train.item() scheduler_warmup1.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train1, score_list_train1) score_list_train1 = torch.cat(score_gather_list_train1, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: torch.cuda.empty_cache() print('=' * print_num) print('| Epoch 
{}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_MT(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_half, print_num_minus) train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus) torch.cuda.empty_cache() with torch.no_grad(): model1.eval() model2.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val']: inputs_val_1 = Variable(data['image'][tio.DATA].cuda()) mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda()) optimizer1.zero_grad() outputs_val_1 = model1(inputs_val_1) outputs_val_2 = model2(inputs_val_1) torch.cuda.empty_cache() if i == 0: score_list_val_1 = outputs_val_1 score_list_val_2 = outputs_val_2 mask_list_val = mask_val else: score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0) score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0) mask_list_val = torch.cat((mask_list_val, mask_val), dim=0) loss_val_sup_1 = criterion(outputs_val_1, mask_val) loss_val_sup_2 = criterion(outputs_val_2, mask_val) val_loss_sup_1 += loss_val_sup_1.item() val_loss_sup_2 += loss_val_sup_2.item() torch.cuda.empty_cache() score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1) score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0) score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2) score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0) mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_val, mask_list_val) mask_list_val = torch.cat(mask_gather_list_val, dim=0) 
torch.cuda.empty_cache() if rank == args.rank_index: val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half) val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half) best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val_1, score_list_val_2, mask_list_val, val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT']) torch.cuda.empty_cache() if args.vis: visualization_MT(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2) print('-' * print_num) print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|') torch.cuda.empty_cache() torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('=' * print_num) print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('-' * print_num) print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus) print('=' * print_num) ================================================ FILE: train_semi_UAMT.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable from torch.utils.data import DataLoader from models.getnetwork import get_network import argparse import time import os import numpy as np from torch.backends import cudnn import random from PIL import Image import 
torch.distributed as dist from torch.nn.parallel import DistributedDataParallel import sys from config.dataset_config.dataset_cfg import dataset_cfg from config.augmentation.online_aug import data_transform_2d, data_normalize_2d from loss.loss_function import segmentation_loss, softmax_mse_loss from models.getnetwork import get_network from dataload.dataset_2d import imagefloder_itn from config.visdom_config.visual_visdom import visdom_initialization_MT, visualization_MT, visual_image_MT from config.warmup_config.warmup import GradualWarmupScheduler from config.train_test_config.train_test_config import print_train_loss_MT, print_val_loss, print_train_eval_sup, print_val_eval, save_val_best_2d, draw_pred_MT, print_best from warnings import simplefilter from config.ramps import ramps simplefilter(action='ignore', category=FutureWarning) def update_ema_variables(model, ema_model, alpha, global_step): # Use the true average until the exponential average is more correct alpha = min(1 - 1 / (global_step + 1), alpha) for ema_param, param in zip(ema_model.parameters(), model.parameters()): ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi') parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI') parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS') parser.add_argument('--input1', default='image') parser.add_argument('--sup_mark', default='20') parser.add_argument('--unsup_mark', default='80') 
parser.add_argument('-b', '--batch_size', default=2, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) parser.add_argument('-l', '--lr', default=0.5, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('-u', '--unsup_weight', default=0.05, type=float) parser.add_argument('--loss', default='dice') parser.add_argument('-w', '--warm_up_duration', default=20, type=int) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--ema_decay', default=0.99, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('-i', '--display_iter', default=5, type=int) parser.add_argument('-n', '--network', default='unet', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') parser.add_argument('-v', '--vis', default=True, type=lambda s: s.lower() in ('true', '1', 'yes'), help='need visualization or not') parser.add_argument('--visdom_port', default=16672, type=int) args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = dist.get_world_size() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14 print_num_minus = print_num - 2 print_num_half = int(print_num / 2 - 1) # trained model save path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_trained_models = path_trained_models + '/' + 'UAMT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)
+ '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) # seg results save path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + 'UAMT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) # vis if args.vis and rank == args.rank_index: visdom_env = str('Semi-UAMT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)) visdom = visdom_initialization_MT(env=visdom_env, port=args.visdom_port) if args.input1 == 'image': input1_mean = 'MEAN' input1_std = 'STD' else: input1_mean = 'MEAN_' + args.input1 input1_std = 'STD_' + args.input1 data_transforms = data_transform_2d() data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std]) dataset_train_unsup = imagefloder_itn( data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark, input1=args.input1, data_transform_1=data_transforms['train'], data_normalize_1=data_normalize, sup=False, num_images=None, ) num_images_unsup = len(dataset_train_unsup) dataset_train_sup = imagefloder_itn( data_dir=args.path_dataset + '/train_sup_' + args.sup_mark, input1=args.input1, data_transform_1=data_transforms['train'], 
data_normalize_1=data_normalize, sup=True, num_images=num_images_unsup, ) dataset_val = imagefloder_itn( data_dir=args.path_dataset + '/val', input1=args.input1, data_transform_1=data_transforms['val'], data_normalize_1=data_normalize, sup=True, num_images=None, ) train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True) train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False) dataloaders = dict() dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup) dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup) dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler) num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])} model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model1 = model1.cuda() model2 = model2.cuda() # for param in model2.parameters(): # param.detach_() model1 = DistributedDataParallel(model1, device_ids=[args.local_rank]) model2 = DistributedDataParallel(model2, device_ids=[args.local_rank]) dist.barrier() criterion = segmentation_loss(args.loss, False).cuda() optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd) exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma) scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1) 
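The uncertainty-aware consistency term in the UAMT loop can be summarized one pixel at a time:

- The teacher runs T = 8 times on noisy copies of the input.
- The entropy of the averaged softmax serves as the uncertainty.
- Only pixels whose uncertainty falls below a threshold contribute to the loss. That threshold ramps from about 0.75·ln 2 up to ln 2.

Below is a minimal sketch of this logic in plain Python, with hypothetical function names. It rests on two assumptions, since neither helper is shown in this section:

- `softmax_mse_loss` returns the element-wise squared difference between the two softmax maps.
- `ramps.sigmoid_rampup` has the common `exp(-5 (1 - t)^2)` form.

The script hard-codes the factor 2 in the denominator and `np.log(2)` in the threshold, which fits a two-class setup. The sketch turns the class count into a parameter instead; with its default of 2, it reproduces the script's values.

```python
import math


def predictive_entropy(mc_probs):
    """Entropy of the mean class distribution over T stochastic teacher passes.

    mc_probs: list of T per-pass probability vectors for a single pixel.
    The 1e-6 offset mirrors the torch.log(preds + 1e-6) guard in the script.
    """
    T = len(mc_probs)
    n_classes = len(mc_probs[0])
    mean = [sum(p[c] for p in mc_probs) / T for c in range(n_classes)]
    return -sum(m * math.log(m + 1e-6) for m in mean)


def sigmoid_rampup(current, rampup_length):
    """Assumed form of ramps.sigmoid_rampup: exp(-5 * (1 - t)^2), t clipped to [0, 1]."""
    if rampup_length == 0:
        return 1.0
    t = min(max(current, 0.0), rampup_length) / rampup_length
    phase = 1.0 - t
    return math.exp(-5.0 * phase * phase)


def masked_consistency(sq_dists, entropies, epoch, num_epochs, n_classes=2):
    """Average squared student/teacher difference over low-uncertainty pixels only.

    sq_dists[i]: squared softmax difference at pixel i, summed over classes.
    entropies[i]: predictive entropy at pixel i.
    With n_classes=2 this matches the script's threshold and denominator.
    """
    threshold = (0.75 + 0.25 * sigmoid_rampup(epoch, num_epochs)) * math.log(n_classes)
    mask = [1.0 if u < threshold else 0.0 for u in entropies]
    masked_sum = sum(m * d for m, d in zip(mask, sq_dists))
    return masked_sum / (n_classes * sum(mask) + 1e-16)
```

Take `masked_consistency([0.2, 0.4], [0.1, 0.69], 0, 200)`. At epoch 0 the threshold is roughly 0.52, so only the first pixel counts, and the result is 0.2 / 2 = 0.1. By the final epoch the threshold reaches ln 2 ≈ 0.693, both pixels pass, and the result becomes 0.6 / 4 = 0.15.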
since = time.time() count_iter = 0 best_model = model1 best_result = 'Result1' best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter - 1) % args.display_iter == 0: begin_time = time.time() dataloaders['train_sup'].sampler.set_epoch(epoch) dataloaders['train_unsup'].sampler.set_epoch(epoch) model1.train() model2.train() train_loss_sup_1 = 0.0 train_loss_unsup = 0.0 train_loss = 0.0 val_loss_sup_1 = 0.0 val_loss_sup_2 = 0.0 unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs dist.barrier() dataset_train_sup = iter(dataloaders['train_sup']) dataset_train_unsup = iter(dataloaders['train_unsup']) for i in range(num_batches['train_sup']): unsup_index = next(dataset_train_unsup) img_train_unsup_1 = unsup_index['image'] img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True)) noise = torch.clamp(torch.randn_like(img_train_unsup_1) * 0.1, -0.2, 0.2) img_train_unsup_2 = img_train_unsup_1 + noise optimizer1.zero_grad() pred_train_unsup1 = model1(img_train_unsup_1) with torch.no_grad(): pred_train_unsup2 = model2(img_train_unsup_2) T = 8 _, _, w, h = img_train_unsup_1.shape volume_batch_r = img_train_unsup_1.repeat(2, 1, 1, 1) stride = volume_batch_r.shape[0] // 2 preds = torch.zeros([stride * T, cfg['NUM_CLASSES'], w, h]).cuda() for i_ in range(T // 2): ema_inputs = volume_batch_r + torch.clamp(torch.randn_like(volume_batch_r) * 0.1, -0.2, 0.2) with torch.no_grad(): preds[2 * stride * i_:2 * stride * (i_ + 1)] = model2(ema_inputs) preds = torch.softmax(preds, dim=1) preds = preds.reshape(T, stride, cfg['NUM_CLASSES'], w, h) preds = torch.mean(preds, dim=0) uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True) consistency_dist = softmax_mse_loss(pred_train_unsup1, pred_train_unsup2) # (batch, 2, 112,112,80) threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(epoch, args.num_epochs)) * np.log(2) mask = (uncertainty < threshold).float() loss_train_unsup = 
torch.sum(mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16) loss_train_unsup = loss_train_unsup * unsup_weight loss_train_unsup.backward(retain_graph=True) torch.cuda.empty_cache() sup_index = next(dataset_train_sup) img_train_sup = sup_index['image'] img_train_sup = Variable(img_train_sup.cuda(non_blocking=True)) mask_train_sup = sup_index['mask'] mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True)) pred_train_sup1 = model1(img_train_sup) if count_iter % args.display_iter == 0: if i == 0: score_list_train1 = pred_train_sup1 mask_list_train = mask_train_sup # else: elif 0 < i <= num_batches['train_sup'] / 32: score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0) loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup) loss_train_sup = loss_train_sup1 loss_train_sup.backward() optimizer1.step() update_ema_variables(model1, model2, args.ema_decay, epoch) torch.cuda.empty_cache() loss_train = loss_train_unsup + loss_train_sup train_loss_unsup += loss_train_unsup.item() train_loss_sup_1 += loss_train_sup1.item() train_loss += loss_train.item() scheduler_warmup1.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train1, score_list_train1) score_list_train1 = torch.cat(score_gather_list_train1, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: torch.cuda.empty_cache() print('=' * print_num) print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_MT(train_loss_sup_1, 
train_loss_unsup, train_loss, num_batches, print_num, print_num_half, print_num_minus) train_eval_list1, train_m_jc1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus) torch.cuda.empty_cache() with torch.no_grad(): model1.eval() model2.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val'] / 16: inputs_val = Variable(data['image'].cuda(non_blocking=True)) mask_val = Variable(data['mask'].cuda(non_blocking=True)) name_val = data['ID'] optimizer1.zero_grad() outputs_val1 = model1(inputs_val) outputs_val2 = model2(inputs_val) torch.cuda.empty_cache() if i == 0: score_list_val1 = outputs_val1 score_list_val2 = outputs_val2 mask_list_val = mask_val name_list_val = name_val else: score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0) score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0) mask_list_val = torch.cat((mask_list_val, mask_val), dim=0) name_list_val = np.append(name_list_val, name_val, axis=0) loss_val_sup1 = criterion(outputs_val1, mask_val) loss_val_sup2 = criterion(outputs_val2, mask_val) val_loss_sup_1 += loss_val_sup1.item() val_loss_sup_2 += loss_val_sup2.item() torch.cuda.empty_cache() score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val1, score_list_val1) score_list_val1 = torch.cat(score_gather_list_val1, dim=0) score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val2, score_list_val2) score_list_val2 = torch.cat(score_gather_list_val2, dim=0) mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_val, mask_list_val) mask_list_val = torch.cat(mask_gather_list_val, dim=0) name_gather_list_val = [None for _ in range(ngpus_per_node)] torch.distributed.all_gather_object(name_gather_list_val, 
name_list_val) name_list_val = np.concatenate(name_gather_list_val, axis=0) if rank == args.rank_index: val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half) val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half) best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE']) torch.cuda.empty_cache() if args.vis: draw_img = draw_pred_MT(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, outputs_val1, outputs_val2, train_eval_list1, val_eval_list1, val_eval_list2) visualization_MT(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_cps, train_m_jc1, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2) visual_image_MT(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4]) print('-' * print_num) print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|') torch.cuda.empty_cache() torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('=' * print_num) print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('-' * print_num) print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus) print('=' * print_num) ================================================ FILE: train_semi_UAMT_3d.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn import torch.optim as optim from torch.optim import 
lr_scheduler from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time import os import numpy as np import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel from torch.backends import cudnn import random import torchio as tio from config.dataset_config.dataset_cfg import dataset_cfg from config.train_test_config.train_test_config import print_train_loss_MT, print_val_loss, print_train_eval_sup, print_val_eval, save_val_best_3d, print_best from config.visdom_config.visual_visdom import visdom_initialization_MT, visualization_MT from config.warmup_config.warmup import GradualWarmupScheduler from config.augmentation.online_aug import data_transform_3d from config.ramps import ramps from loss.loss_function import segmentation_loss, softmax_mse_loss from models.getnetwork import get_network from dataload.dataset_3d import dataset_it from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def update_ema_variables(model, ema_model, alpha, global_step): # Use the true average until the exponential average is more correct alpha = min(1 - 1 / (global_step + 1), alpha) for ema_param, param in zip(ema_model.parameters(), model.parameters()): ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi') parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial') parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial') parser.add_argument('--input1',
default='image') parser.add_argument('--sup_mark', default='20') parser.add_argument('--unsup_mark', default='80') parser.add_argument('-b', '--batch_size', default=1, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) parser.add_argument('-l', '--lr', default=0.1, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('-c', '--unsup_weight', default=5, type=float) parser.add_argument('--loss', default='dice', type=str) parser.add_argument('--patch_size', default=(96, 96, 80), nargs=3, type=int) parser.add_argument('-w', '--warm_up_duration', default=20, type=int) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--ema_decay', default=0.99, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('--queue_length', default=48, type=int) parser.add_argument('--samples_per_volume_train', default=4, type=int) parser.add_argument('--samples_per_volume_val', default=8, type=int) parser.add_argument('-i', '--display_iter', default=5, type=int) parser.add_argument('-n', '--network', default='unet3d', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') parser.add_argument('-v', '--vis', default=True, type=lambda s: s.lower() in ('true', '1', 'yes'), help='need visualization or not') parser.add_argument('--visdom_port', default=16672, type=int, help='16672') args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = dist.get_world_size() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14 print_num_minus = print_num - 2 print_num_half = int(print_num / 2 - 1) path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1]) if not
os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_trained_models = path_trained_models + '/' + 'UAMT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + 'UAMT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_mask_results = path_seg_results + '/mask' if not os.path.exists(path_mask_results) and rank == args.rank_index: os.mkdir(path_mask_results) path_seg_results = path_seg_results + '/pred' if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) if args.vis and rank == args.rank_index: visdom_env = str('Semi-UAMT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)) visdom = visdom_initialization_MT(env=visdom_env, port=args.visdom_port) # Dataset data_transform = 
data_transform_3d(cfg['NORMALIZE']) dataset_train_unsup = dataset_it( data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark, input1=args.input1, transform_1=data_transform['train'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_train, patch_size=args.patch_size, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=False, num_images=None ) num_images_unsup = len(dataset_train_unsup.dataset_1) dataset_train_sup = dataset_it( data_dir=args.path_dataset + '/train_sup_' + args.sup_mark, input1=args.input1, transform_1=data_transform['train'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_train, patch_size=args.patch_size, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=num_images_unsup ) dataset_val = dataset_it( data_dir=args.path_dataset + '/val', input1=args.input1, transform_1=data_transform['val'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_val, patch_size=args.patch_size, num_workers=8, shuffle_subjects=False, shuffle_patches=False, sup=True, num_images=None ) train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True) train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False) dataloaders = dict() dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup) dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup) dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, 
sampler=val_sampler) num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])} # Model model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model1 = model1.cuda() model2 = model2.cuda() # for param in model2.parameters(): # param.detach_() model1 = DistributedDataParallel(model1, device_ids=[args.local_rank]) model2 = DistributedDataParallel(model2, device_ids=[args.local_rank]) dist.barrier() # Training Strategy criterion = segmentation_loss(args.loss, False).cuda() optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd) exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma) scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1) # Train & Val since = time.time() count_iter = 0 best_model = model1 best_result = 'Result1' best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter - 1) % args.display_iter == 0: begin_time = time.time() dataloaders['train_sup'].sampler.set_epoch(epoch) dataloaders['train_unsup'].sampler.set_epoch(epoch) model1.train() model2.train() train_loss_sup_1 = 0.0 train_loss_unsup = 0.0 train_loss = 0.0 val_loss_sup_1 = 0.0 val_loss_sup_2 = 0.0 unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs dist.barrier() dataset_train_sup = iter(dataloaders['train_sup']) dataset_train_unsup = iter(dataloaders['train_unsup']) for i in range(num_batches['train_sup']): unsup_index = next(dataset_train_unsup) img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda()) noise = torch.clamp(torch.randn_like(img_train_unsup_1) * 0.1, -0.2, 0.2) img_train_unsup_2 = img_train_unsup_1 + noise optimizer1.zero_grad() pred_train_unsup1 = 
model1(img_train_unsup_1) with torch.no_grad(): pred_train_unsup2 = model2(img_train_unsup_2) T = 8 _, _, d, w, h = img_train_unsup_1.shape volume_batch_r = img_train_unsup_1.repeat(2, 1, 1, 1, 1) stride = volume_batch_r.shape[0] // 2 preds = torch.zeros([stride * T, cfg['NUM_CLASSES'], d, w, h]).cuda() for i_ in range(T // 2): ema_inputs = volume_batch_r + torch.clamp(torch.randn_like(volume_batch_r) * 0.1, -0.2, 0.2) with torch.no_grad(): preds[2 * stride * i_:2 * stride * (i_ + 1)] = model2(ema_inputs) preds = torch.softmax(preds, dim=1) preds = preds.reshape(T, stride, cfg['NUM_CLASSES'], d, w, h) preds = torch.mean(preds, dim=0) uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True) consistency_dist = softmax_mse_loss(pred_train_unsup1, pred_train_unsup2) # (batch, 2, 112,112,80) threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(epoch, args.num_epochs)) * np.log(2) mask = (uncertainty < threshold).float() loss_train_unsup = torch.sum(mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16) loss_train_unsup = loss_train_unsup * unsup_weight loss_train_unsup.backward(retain_graph=True) torch.cuda.empty_cache() sup_index = next(dataset_train_sup) img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda()) mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda()) pred_train_sup1 = model1(img_train_sup_1) if count_iter % args.display_iter == 0: if i == 0: score_list_train1 = pred_train_sup1 mask_list_train = mask_train_sup # else: elif 0 < i <= num_batches['train_sup'] / 32: score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0) loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup) loss_train_sup = loss_train_sup1 loss_train_sup.backward() optimizer1.step() update_ema_variables(model1, model2, args.ema_decay, epoch) torch.cuda.empty_cache() loss_train = loss_train_unsup + loss_train_sup train_loss_unsup += 
loss_train_unsup.item() train_loss_sup_1 += loss_train_sup1.item() train_loss += loss_train.item() scheduler_warmup1.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train1, score_list_train1) score_list_train1 = torch.cat(score_gather_list_train1, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: torch.cuda.empty_cache() print('=' * print_num) print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_MT(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_half, print_num_minus) train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus) torch.cuda.empty_cache() with torch.no_grad(): model1.eval() model2.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val']: inputs_val_1 = Variable(data['image'][tio.DATA].cuda()) mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda()) optimizer1.zero_grad() outputs_val_1 = model1(inputs_val_1) outputs_val_2 = model2(inputs_val_1) torch.cuda.empty_cache() if i == 0: score_list_val_1 = outputs_val_1 score_list_val_2 = outputs_val_2 mask_list_val = mask_val else: score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0) score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0) mask_list_val = torch.cat((mask_list_val, mask_val), dim=0) loss_val_sup_1 = criterion(outputs_val_1, mask_val) loss_val_sup_2 = criterion(outputs_val_2, mask_val) val_loss_sup_1 += loss_val_sup_1.item() val_loss_sup_2 
+= loss_val_sup_2.item() torch.cuda.empty_cache() score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1) score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0) score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2) score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0) mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_val, mask_list_val) mask_list_val = torch.cat(mask_gather_list_val, dim=0) torch.cuda.empty_cache() if rank == args.rank_index: val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half) val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half) best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val_1, score_list_val_2, mask_list_val, val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT']) torch.cuda.empty_cache() if args.vis: visualization_MT(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2) print('-' * print_num) print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust( print_num_minus, ' '), '|') torch.cuda.empty_cache() torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('=' * print_num) print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, 
s).ljust(print_num_minus, ' '), '|') print('-' * print_num) print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus) print('=' * print_num) ================================================ FILE: train_semi_URPC.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable from torch.utils.data import DataLoader from models.getnetwork import get_network import argparse import time import os import numpy as np from torch.backends import cudnn import random from PIL import Image import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel import sys from config.dataset_config.dataset_cfg import dataset_cfg from config.augmentation.online_aug import data_transform_2d, data_normalize_2d from loss.loss_function import segmentation_loss, entropy_loss from models.getnetwork import get_network from dataload.dataset_2d import imagefloder_itn from config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM, visual_image_sup from config.warmup_config.warmup import GradualWarmupScheduler from config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi') 
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi') parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI') parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS') parser.add_argument('--input1', default='image') parser.add_argument('--sup_mark', default='20') parser.add_argument('--unsup_mark', default='80') parser.add_argument('-b', '--batch_size', default=2, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) parser.add_argument('-l', '--lr', default=0.5, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('-u', '--unsup_weight', default=1, type=float) parser.add_argument('--loss', default='dice') parser.add_argument('-w', '--warm_up_duration', default=20, type=int) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('-i', '--display_iter', default=5, type=int) parser.add_argument('-n', '--network', default='unet_urpc', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') parser.add_argument('-v', '--vis', default=True, help='need visualization or not') parser.add_argument('--visdom_port', default=16672, type=int) args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.cuda.device_count() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14 print_num_minus = print_num - 2 print_num_half = int(print_num / 2 - 1) # trained model save path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_trained_models) and rank 
== args.rank_index: os.mkdir(path_trained_models) path_trained_models = path_trained_models + '/' + 'URPC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) # seg results save path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + 'URPC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) # vis if args.vis and rank == args.rank_index: visdom_env = str('Semi-URPC-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)) visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port) if args.input1 == 'image': input1_mean = 'MEAN' input1_std = 'STD' else: input1_mean = 'MEAN_' + args.input1 input1_std = 'STD_' + args.input1 data_transforms = data_transform_2d() data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std]) dataset_train_unsup = imagefloder_itn( data_dir=args.path_dataset + '/train_unsup_' + 
args.unsup_mark, input1=args.input1, data_transform_1=data_transforms['train'], data_normalize_1=data_normalize, sup=False, num_images=None, ) num_images_unsup = len(dataset_train_unsup) dataset_train_sup = imagefloder_itn( data_dir=args.path_dataset + '/train_sup_' + args.sup_mark, input1=args.input1, data_transform_1=data_transforms['train'], data_normalize_1=data_normalize, sup=True, num_images=num_images_unsup, ) dataset_val = imagefloder_itn( data_dir=args.path_dataset + '/val', input1=args.input1, data_transform_1=data_transforms['val'], data_normalize_1=data_normalize, sup=True, num_images=None, ) train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True) train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False) dataloaders = dict() dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup) dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup) dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler) num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])} model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model1 = model1.cuda() model1 = DistributedDataParallel(model1, device_ids=[args.local_rank]) dist.barrier() criterion = segmentation_loss(args.loss, False).cuda() kl_distance = nn.KLDivLoss(reduction='none') optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd) exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, 
step_size=args.step_size, gamma=args.gamma) scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1) since = time.time() count_iter = 0 best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter - 1) % args.display_iter == 0: begin_time = time.time() dataloaders['train_sup'].sampler.set_epoch(epoch) dataloaders['train_unsup'].sampler.set_epoch(epoch) model1.train() train_loss_sup_1 = 0.0 train_loss_unsup = 0.0 train_loss = 0.0 val_loss_sup_1 = 0.0 unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs dist.barrier() dataset_train_sup = iter(dataloaders['train_sup']) dataset_train_unsup = iter(dataloaders['train_unsup']) for i in range(num_batches['train_sup']): unsup_index = next(dataset_train_unsup) img_train_unsup_1 = unsup_index['image'] img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True)) optimizer1.zero_grad() pred_train_unsup1, pred_train_unsup2, pred_train_unsup3, pred_train_unsup4 = model1(img_train_unsup_1) pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1) pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1) pred_train_unsup3 = torch.softmax(pred_train_unsup3, 1) pred_train_unsup4 = torch.softmax(pred_train_unsup4, 1) preds = (pred_train_unsup1 + pred_train_unsup2 + pred_train_unsup3 + pred_train_unsup4) / 4 variance_aux1 = torch.sum(kl_distance(torch.log(preds), pred_train_unsup1), dim=1, keepdim=True) exp_variance_aux1 = torch.exp(-variance_aux1) variance_aux2 = torch.sum(kl_distance(torch.log(preds), pred_train_unsup2), dim=1, keepdim=True) exp_variance_aux2 = torch.exp(-variance_aux2) variance_aux3 = torch.sum(kl_distance(torch.log(preds), pred_train_unsup3), dim=1, keepdim=True) exp_variance_aux3 = torch.exp(-variance_aux3) variance_aux4 = torch.sum(kl_distance(torch.log(preds), pred_train_unsup4), dim=1, keepdim=True) exp_variance_aux4 = torch.exp(-variance_aux4) 
consistency_dist_aux1 = (preds - pred_train_unsup1) ** 2 consistency_loss_aux1 = torch.mean(consistency_dist_aux1 * exp_variance_aux1) / (torch.mean(exp_variance_aux1) + 1e-8) + torch.mean(variance_aux1) consistency_dist_aux2 = (preds - pred_train_unsup2) ** 2 consistency_loss_aux2 = torch.mean(consistency_dist_aux2 * exp_variance_aux2) / (torch.mean(exp_variance_aux2) + 1e-8) + torch.mean(variance_aux2) consistency_dist_aux3 = (preds - pred_train_unsup3) ** 2 consistency_loss_aux3 = torch.mean(consistency_dist_aux3 * exp_variance_aux3) / (torch.mean(exp_variance_aux3) + 1e-8) + torch.mean(variance_aux3) consistency_dist_aux4 = (preds - pred_train_unsup4) ** 2 consistency_loss_aux4 = torch.mean(consistency_dist_aux4 * exp_variance_aux4) / (torch.mean(exp_variance_aux4) + 1e-8) + torch.mean(variance_aux4) loss_train_unsup = (consistency_loss_aux1 + consistency_loss_aux2 + consistency_loss_aux3 + consistency_loss_aux4) / 4 loss_train_unsup = loss_train_unsup * unsup_weight loss_train_unsup.backward(retain_graph=True) torch.cuda.empty_cache() sup_index = next(dataset_train_sup) img_train_sup = sup_index['image'] img_train_sup = Variable(img_train_sup.cuda(non_blocking=True)) mask_train_sup = sup_index['mask'] mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True)) pred_train_sup1, pred_train_sup2, pred_train_sup3, pred_train_sup4 = model1(img_train_sup) if count_iter % args.display_iter == 0: if i == 0: score_list_train1 = pred_train_sup1 mask_list_train = mask_train_sup # else: elif 0 < i <= num_batches['train_sup'] / 32: score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0) loss_train_sup1 = (criterion(pred_train_sup1, mask_train_sup)+criterion(pred_train_sup2, mask_train_sup)+criterion(pred_train_sup3, mask_train_sup)+criterion(pred_train_sup4, mask_train_sup)) / 4 loss_train_sup = loss_train_sup1 loss_train_sup.backward() optimizer1.step() torch.cuda.empty_cache() 
loss_train = loss_train_unsup + loss_train_sup train_loss_unsup += loss_train_unsup.item() train_loss_sup_1 += loss_train_sup1.item() train_loss += loss_train.item() scheduler_warmup1.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train1, score_list_train1) score_list_train1 = torch.cat(score_gather_list_train1, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: torch.cuda.empty_cache() print('=' * print_num) print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus) train_eval_list1, train_m_jc1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus) torch.cuda.empty_cache() with torch.no_grad(): model1.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val'] / 16: inputs_val = Variable(data['image'].cuda(non_blocking=True)) mask_val = Variable(data['mask'].cuda(non_blocking=True)) name_val = data['ID'] optimizer1.zero_grad() outputs_val1, outputs_val2, outputs_val3, outputs_val4 = model1(inputs_val) torch.cuda.empty_cache() if i == 0: score_list_val1 = outputs_val1 mask_list_val = mask_val name_list_val = name_val else: score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0) mask_list_val = torch.cat((mask_list_val, mask_val), dim=0) name_list_val = np.append(name_list_val, name_val, axis=0) loss_val_sup1 = criterion(outputs_val1, mask_val) val_loss_sup_1 += loss_val_sup1.item() torch.cuda.empty_cache() 
score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val1, score_list_val1) score_list_val1 = torch.cat(score_gather_list_val1, dim=0) mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_val, mask_list_val) mask_list_val = torch.cat(mask_gather_list_val, dim=0) name_gather_list_val = [None for _ in range(ngpus_per_node)] torch.distributed.all_gather_object(name_gather_list_val, name_list_val) name_list_val = np.concatenate(name_gather_list_val, axis=0) if rank == args.rank_index: val_epoch_loss_sup1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus) val_eval_list1, val_m_jc1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val1, mask_list_val, print_num_minus) best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val1, name_list_val, val_eval_list1, path_trained_models, path_seg_results, cfg['PALETTE'], 'URPC') torch.cuda.empty_cache() if args.vis: draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, outputs_val1, train_eval_list1, val_eval_list1) visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_cps, train_m_jc1, val_epoch_loss_sup1, val_m_jc1) visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3]) print('-' * print_num) print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|') torch.cuda.empty_cache() torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('=' * print_num) print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('-' * print_num) print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus) 
print('=' * print_num) ================================================ FILE: train_semi_URPC_3d.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time import os import numpy as np import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel from torch.backends import cudnn import random import torchio as tio from config.dataset_config.dataset_cfg import dataset_cfg from config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup from config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM from config.warmup_config.warmup import GradualWarmupScheduler from config.augmentation.online_aug import data_transform_3d from loss.loss_function import segmentation_loss from models.getnetwork import get_network from dataload.dataset_3d import dataset_it from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi') parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi') parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial') parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial') parser.add_argument('--input1', default='image') parser.add_argument('--sup_mark', 
default='20') parser.add_argument('--unsup_mark', default='80') parser.add_argument('-b', '--batch_size', default=1, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) parser.add_argument('-l', '--lr', default=0.1, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('-c', '--unsup_weight', default=5, type=float) parser.add_argument('--patch_size', default=(96, 96, 80)) parser.add_argument('--loss', default='dice', type=str) parser.add_argument('-w', '--warm_up_duration', default=20, type=int) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('--queue_length', default=48, type=int) parser.add_argument('--samples_per_volume_train', default=4, type=int) parser.add_argument('--samples_per_volume_val', default=8, type=int) parser.add_argument('-i', '--display_iter', default=5, type=int) parser.add_argument('-n', '--network', default='unet3d_urpc', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3') parser.add_argument('-v', '--vis', default=True, help='need visualization or not') parser.add_argument('--visdom_port', default=16672, type=int, help='16672') args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.cuda.device_count() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14 print_num_minus = print_num - 2 print_num_half = int(print_num / 2 - 1) path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_trained_models = 
path_trained_models + '/' + 'URPC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results + '/' + 'URPC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_mask_results = path_seg_results + '/mask' if not os.path.exists(path_mask_results) and rank == args.rank_index: os.mkdir(path_mask_results) path_seg_results = path_seg_results + '/pred' if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) if args.vis and rank == args.rank_index: visdom_env = str('Semi-URPC-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)) visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port) # Dataset data_transform = data_transform_3d(cfg['NORMALIZE']) dataset_train_unsup = dataset_it( data_dir=args.path_dataset + '/train_unsup_' + 
args.unsup_mark, input1=args.input1, transform_1=data_transform['train'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_train, patch_size=args.patch_size, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=False, num_images=None ) num_images_unsup = len(dataset_train_unsup.dataset_1) dataset_train_sup = dataset_it( data_dir=args.path_dataset + '/train_sup_' + args.sup_mark, input1=args.input1, transform_1=data_transform['train'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_train, patch_size=args.patch_size, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=num_images_unsup ) dataset_val = dataset_it( data_dir=args.path_dataset + '/val', input1=args.input1, transform_1=data_transform['val'], queue_length=args.queue_length, samples_per_volume=args.samples_per_volume_val, patch_size=args.patch_size, num_workers=8, shuffle_subjects=False, shuffle_patches=False, sup=True, num_images=None ) train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True) train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False) dataloaders = dict() dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup) dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup) dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler) num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 
'val': len(dataloaders['val'])} # Model model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model1 = model1.cuda() model1 = DistributedDataParallel(model1, device_ids=[args.local_rank]) dist.barrier() # Training Strategy criterion = segmentation_loss(args.loss, False).cuda() kl_distance = nn.KLDivLoss(reduction='none') optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd) exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma) scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1) # Train & Val since = time.time() count_iter = 0 best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter - 1) % args.display_iter == 0: begin_time = time.time() dataloaders['train_sup'].sampler.set_epoch(epoch) dataloaders['train_unsup'].sampler.set_epoch(epoch) model1.train() train_loss_sup_1 = 0.0 train_loss_unsup = 0.0 train_loss = 0.0 val_loss_sup_1 = 0.0 unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs dist.barrier() dataset_train_sup = iter(dataloaders['train_sup']) dataset_train_unsup = iter(dataloaders['train_unsup']) for i in range(num_batches['train_sup']): unsup_index = next(dataset_train_unsup) img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda()) optimizer1.zero_grad() pred_train_unsup1, pred_train_unsup2, pred_train_unsup3, pred_train_unsup4 = model1(img_train_unsup_1) pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1) pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1) pred_train_unsup3 = torch.softmax(pred_train_unsup3, 1) pred_train_unsup4 = torch.softmax(pred_train_unsup4, 1) preds = (pred_train_unsup1 + pred_train_unsup2 + pred_train_unsup3 + pred_train_unsup4) / 4 variance_aux1 = torch.sum(kl_distance(torch.log(pred_train_unsup1), preds), dim=1, keepdim=True) 
exp_variance_aux1 = torch.exp(-variance_aux1) variance_aux2 = torch.sum(kl_distance(torch.log(pred_train_unsup2), preds), dim=1, keepdim=True) exp_variance_aux2 = torch.exp(-variance_aux2) variance_aux3 = torch.sum(kl_distance(torch.log(pred_train_unsup3), preds), dim=1, keepdim=True) exp_variance_aux3 = torch.exp(-variance_aux3) variance_aux4 = torch.sum(kl_distance(torch.log(pred_train_unsup4), preds), dim=1, keepdim=True) exp_variance_aux4 = torch.exp(-variance_aux4) consistency_dist_aux1 = (preds - pred_train_unsup1) ** 2 consistency_loss_aux1 = torch.mean(consistency_dist_aux1 * exp_variance_aux1) / (torch.mean(exp_variance_aux1) + 1e-8) + torch.mean(variance_aux1) consistency_dist_aux2 = (preds - pred_train_unsup2) ** 2 consistency_loss_aux2 = torch.mean(consistency_dist_aux2 * exp_variance_aux2) / (torch.mean(exp_variance_aux2) + 1e-8) + torch.mean(variance_aux2) consistency_dist_aux3 = (preds - pred_train_unsup3) ** 2 consistency_loss_aux3 = torch.mean(consistency_dist_aux3 * exp_variance_aux3) / (torch.mean(exp_variance_aux3) + 1e-8) + torch.mean(variance_aux3) consistency_dist_aux4 = (preds - pred_train_unsup4) ** 2 consistency_loss_aux4 = torch.mean(consistency_dist_aux4 * exp_variance_aux4) / (torch.mean(exp_variance_aux4) + 1e-8) + torch.mean(variance_aux4) loss_train_unsup = (consistency_loss_aux1 + consistency_loss_aux2 + consistency_loss_aux3 + consistency_loss_aux4) / 4 loss_train_unsup = loss_train_unsup * unsup_weight loss_train_unsup.backward(retain_graph=True) torch.cuda.empty_cache() sup_index = next(dataset_train_sup) img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda()) mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda()) pred_train_sup1, pred_train_sup2, pred_train_sup3, pred_train_sup4 = model1(img_train_sup_1) if count_iter % args.display_iter == 0: if i == 0: score_list_train1 = pred_train_sup1 mask_list_train = mask_train_sup # else: elif 0 < i <= num_batches['train_sup'] / 32: score_list_train1 = 
torch.cat((score_list_train1, pred_train_sup1), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0) loss_train_sup1 = (criterion(pred_train_sup1, mask_train_sup)+criterion(pred_train_sup2, mask_train_sup)+criterion(pred_train_sup3, mask_train_sup)+criterion(pred_train_sup4, mask_train_sup)) / 4 loss_train_sup = loss_train_sup1 loss_train_sup.backward() optimizer1.step() torch.cuda.empty_cache() loss_train = loss_train_unsup + loss_train_sup train_loss_unsup += loss_train_unsup.item() train_loss_sup_1 += loss_train_sup1.item() train_loss += loss_train.item() scheduler_warmup1.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train1, score_list_train1) score_list_train1 = torch.cat(score_gather_list_train1, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: torch.cuda.empty_cache() print('=' * print_num) print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus) train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus) torch.cuda.empty_cache() with torch.no_grad(): model1.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val']: inputs_val_1 = Variable(data['image'][tio.DATA].cuda()) mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda()) optimizer1.zero_grad() outputs_val_1, outputs_val_2, outputs_val_3, outputs_val_4 = model1(inputs_val_1) torch.cuda.empty_cache() 
                    if i == 0:
                        score_list_val_1 = outputs_val_1
                        mask_list_val = mask_val
                    else:
                        score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)
                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)

                    loss_val_sup_1 = criterion(outputs_val_1, mask_val)
                    val_loss_sup_1 += loss_val_sup_1.item()
                    torch.cuda.empty_cache()

                score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)
                score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)

                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
                mask_list_val = torch.cat(mask_gather_list_val, dim=0)
                torch.cuda.empty_cache()

                if rank == args.rank_index:
                    val_epoch_loss_sup_1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus)
                    val_eval_list_1, val_m_jc_1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val_1, mask_list_val, print_num_minus)
                    best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val_1, mask_list_val, val_eval_list_1, path_trained_models, path_seg_results, path_mask_results, 'URPC', cfg['FORMAT'])
                    torch.cuda.empty_cache()

                    if args.vis:
                        visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_m_jc_1)

                    print('-' * print_num)
                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
            torch.cuda.empty_cache()
        torch.cuda.empty_cache()

    if rank == args.rank_index:
        time_elapsed = time.time() - since
        m, s = divmod(time_elapsed, 60)
        h, m = divmod(m, 60)
        print('=' * print_num)
        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
        print('-' * print_num)
        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
        print('=' * print_num)
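In the semi-supervised XNet scripts, the unlabeled term is a cross pseudo-supervision loss: each branch's hard argmax prediction (`torch.max(pred, dim=1)[1]`) is used as the target for the other branch, and the sum is scaled by a weight that grows linearly as `args.unsup_weight * (epoch + 1) / args.num_epochs`. The sketch below reproduces that arithmetic in NumPy so it runs without a GPU. It is an illustration under stated assumptions, not the repository's code: it uses plain softmax cross-entropy where the scripts call `segmentation_loss(args.loss, False)` (Dice by default), and the helper names `softmax`, `cross_entropy`, `ramped_weight` and `cross_pseudo_loss` are hypothetical.

```python
import numpy as np


def softmax(logits, axis=1):
    # Numerically stable softmax over the class axis.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def cross_entropy(logits, labels):
    # logits: (N, C, H, W) raw scores; labels: (N, H, W) integer class indices.
    # Stand-in for the repository's segmentation_loss (an assumption).
    log_probs = np.log(softmax(logits, axis=1) + 1e-12)
    # Pick the log-probability of the target class at every pixel.
    picked = np.take_along_axis(log_probs, labels[:, None, ...], axis=1)
    return float(-picked.mean())


def ramped_weight(max_weight, epoch, num_epochs):
    # Linear ramp mirroring `args.unsup_weight * (epoch + 1) / args.num_epochs`.
    return max_weight * (epoch + 1) / num_epochs


def cross_pseudo_loss(logits_1, logits_2, weight):
    # Each branch's hard argmax prediction supervises the other branch.
    pseudo_1 = logits_1.argmax(axis=1)
    pseudo_2 = logits_2.argmax(axis=1)
    loss = cross_entropy(logits_1, pseudo_2) + cross_entropy(logits_2, pseudo_1)
    return weight * loss


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    branch_1 = rng.normal(size=(2, 3, 4, 4))  # hypothetical logits from branch 1
    branch_2 = rng.normal(size=(2, 3, 4, 4))  # hypothetical logits from branch 2
    w = ramped_weight(5.0, epoch=0, num_epochs=200)
    print(w, cross_pseudo_loss(branch_1, branch_2, w))
```

Because the pseudo-labels are integer class indices, no gradient flows through them in the PyTorch scripts either; each branch is only pulled toward the other branch's current decision, and the small early weight keeps those noisy targets from dominating the supervised loss.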
================================================
FILE: train_semi_XNet.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
from torch.backends import cudnn
import random
from PIL import Image
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import sys

from config.dataset_config.dataset_cfg import dataset_cfg
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_2d import imagefloder_iitnn
from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)


def init_seeds(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi_xnet')
    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi_xnet')
    parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/GlaS')
    parser.add_argument('--dataset_name', default='GlaS', help='CREMI, ISIC-2017, GlaS')
    parser.add_argument('--input1', default='L')
    parser.add_argument('--input2', default='H')
    parser.add_argument('--sup_mark', default='20')
    parser.add_argument('--unsup_mark', default='80')
    parser.add_argument('-b', '--batch_size', default=2, type=int)
    parser.add_argument('-e', '--num_epochs', default=200, type=int)
    parser.add_argument('-s', '--step_size', default=50, type=int)
    parser.add_argument('-l', '--lr', default=0.5, type=float)
    parser.add_argument('-g', '--gamma', default=0.5, type=float)
    parser.add_argument('-u', '--unsup_weight', default=5, type=float)
    parser.add_argument('--loss', default='dice')
    parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
    parser.add_argument('-i', '--display_iter', default=5, type=int)
    parser.add_argument('-n', '--network', default='xnet', type=str)
    parser.add_argument('--local_rank', default=-1, type=int)
    # type=int so that `rank == args.rank_index` still matches when the value is passed on the command line
    parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
    parser.add_argument('--visdom_port', default=16672, type=int)
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    rank = torch.distributed.get_rank()
    # all_gather expects one output tensor per process in the group, so use the
    # world size rather than torch.cuda.device_count()
    ngpus_per_node = dist.get_world_size()

    init_seeds(rank + 1)

    dataset_name = args.dataset_name
    cfg = dataset_cfg(dataset_name)

    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
    print_num_minus = print_num - 2
    print_num_half = int(print_num / 2 - 1)

    # trained model save
    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_trained_models = path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)+'-'+str(args.input2)
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)

    # seg results save
    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)+'-'+str(args.input2)
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)

    # vis
    if args.vis and rank == args.rank_index:
        visdom_env = str('Semi-XNet-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)+'-'+str(args.input2))
        visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)

    if args.input1 == 'image':
        input1_mean = 'MEAN'
        input1_std = 'STD'
    else:
        input1_mean = 'MEAN_' + args.input1
        input1_std = 'STD_' + args.input1

    if args.input2 == 'image':
        input2_mean = 'MEAN'
        input2_std = 'STD'
    else:
        input2_mean = 'MEAN_' + args.input2
        input2_std = 'STD_' + args.input2

    data_transforms = data_transform_2d()
    data_normalize_1 = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
    data_normalize_2 = data_normalize_2d(cfg[input2_mean], cfg[input2_std])

    dataset_train_unsup = imagefloder_iitnn(
        data_dir=args.path_dataset + '/train_unsup_'+args.unsup_mark,
        input1=args.input1,
        input2=args.input2,
        data_transform_1=data_transforms['train'],
        data_normalize_1=data_normalize_1,
        data_normalize_2=data_normalize_2,
        sup=False,
        num_images=None,
    )
    num_images_unsup = len(dataset_train_unsup)

    dataset_train_sup = imagefloder_iitnn(
        data_dir=args.path_dataset + '/train_sup_'+args.sup_mark,
        input1=args.input1,
        input2=args.input2,
        data_transform_1=data_transforms['train'],
        data_normalize_1=data_normalize_1,
        data_normalize_2=data_normalize_2,
        sup=True,
        num_images=num_images_unsup,
    )
    dataset_val = imagefloder_iitnn(
        data_dir=args.path_dataset + '/val',
        input1=args.input1,
        input2=args.input2,
        data_transform_1=data_transforms['val'],
        data_normalize_1=data_normalize_1,
        data_normalize_2=data_normalize_2,
        sup=True,
        num_images=None,
    )

    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)
    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)

    dataloaders = dict()
    dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)
    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup)
    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)

    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}

    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
    model = model.cuda()
    model = DistributedDataParallel(model, device_ids=[args.local_rank])
    dist.barrier()

    criterion = segmentation_loss(args.loss, False).cuda()

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)

    since = time.time()
    count_iter = 0

    best_model = model
    best_result = 'Result1'
    best_val_eval_list = [0 for i in range(4)]

    for epoch in range(args.num_epochs):

        count_iter += 1
        if (count_iter-1) % args.display_iter == 0:
            begin_time = time.time()

        dataloaders['train_sup'].sampler.set_epoch(epoch)
        dataloaders['train_unsup'].sampler.set_epoch(epoch)
        model.train()

        train_loss_sup_1 = 0.0
        train_loss_sup_2 = 0.0
        train_loss_unsup = 0.0
        train_loss = 0.0
        val_loss_sup_1 = 0.0
        val_loss_sup_2 = 0.0

        unsup_weight = args.unsup_weight * (epoch+1) / args.num_epochs

        dist.barrier()

        dataset_train_sup = iter(dataloaders['train_sup'])
        dataset_train_unsup = iter(dataloaders['train_unsup'])

        for i in range(num_batches['train_sup']):

            unsup_index = next(dataset_train_unsup)
            img_train_unsup_1 = unsup_index['image']
            img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True))
            img_train_unsup_2 = unsup_index['image_2']
            img_train_unsup_2 = Variable(img_train_unsup_2.cuda(non_blocking=True))

            optimizer.zero_grad()
            pred_train_unsup1, pred_train_unsup2 = model(img_train_unsup_1, img_train_unsup_2)

            max_train1 = torch.max(pred_train_unsup1, dim=1)[1]
            max_train2 = torch.max(pred_train_unsup2, dim=1)[1]
            max_train1 = max_train1.long()
            max_train2 = max_train2.long()

            loss_train_unsup = criterion(pred_train_unsup1, max_train2) + criterion(pred_train_unsup2, max_train1)
            loss_train_unsup = loss_train_unsup * unsup_weight
            loss_train_unsup.backward(retain_graph=True)
            torch.cuda.empty_cache()

            sup_index = next(dataset_train_sup)
            img_train_sup_1 = sup_index['image']
            img_train_sup_1 = Variable(img_train_sup_1.cuda(non_blocking=True))
            img_train_sup_2 = sup_index['image_2']
            img_train_sup_2 = Variable(img_train_sup_2.cuda(non_blocking=True))
            mask_train_sup = sup_index['mask']
            mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))

            pred_train_sup1, pred_train_sup2 = model(img_train_sup_1, img_train_sup_2)

            if count_iter % args.display_iter == 0:
                if i == 0:
                    score_list_train1 = pred_train_sup1
                    score_list_train2 = pred_train_sup2
                    mask_list_train = mask_train_sup
                # else:
                elif 0 < i <= num_batches['train_sup'] / 32:
                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
                    score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0)
                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)

            loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)
            loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup)
            loss_train_sup = loss_train_sup1 + loss_train_sup2
            loss_train_sup.backward()

            optimizer.step()
            torch.cuda.empty_cache()

            loss_train = loss_train_unsup + loss_train_sup
            train_loss_unsup += loss_train_unsup.item()
            train_loss_sup_1 += loss_train_sup1.item()
            train_loss_sup_2 += loss_train_sup2.item()
            train_loss += loss_train.item()

        scheduler_warmup.step()
        torch.cuda.empty_cache()

        if count_iter % args.display_iter == 0:

            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)

            score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train2, score_list_train2)
            score_list_train2 = torch.cat(score_gather_list_train2, dim=0)

            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
            mask_list_train = torch.cat(mask_gather_list_train, dim=0)

            if rank == args.rank_index:
                torch.cuda.empty_cache()
                print('=' * print_num)
                print('| Epoch {}/{}'.format(epoch+1, args.num_epochs).ljust(print_num_minus, ' '), '|')
                train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)
                train_eval_list1, train_eval_list2, train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)
                torch.cuda.empty_cache()

            with torch.no_grad():
                model.eval()

                for i, data in enumerate(dataloaders['val']):

                    # if 0 <= i <= num_batches['val'] / 16:

                    inputs_val_1 = Variable(data['image'].cuda(non_blocking=True))
                    inputs_val_2 = Variable(data['image_2'].cuda(non_blocking=True))
                    mask_val = Variable(data['mask'].cuda(non_blocking=True))
                    name_val = data['ID']

                    optimizer.zero_grad()
                    outputs_val1, outputs_val2 = model(inputs_val_1, inputs_val_2)
                    torch.cuda.empty_cache()

                    if i == 0:
                        score_list_val1 = outputs_val1
                        score_list_val2 = outputs_val2
                        mask_list_val = mask_val
                        name_list_val = name_val
                    else:
                        score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)
                        score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)
                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
                        name_list_val = np.append(name_list_val, name_val, axis=0)

                    loss_val_sup1 = criterion(outputs_val1, mask_val)
                    loss_val_sup2 = criterion(outputs_val2, mask_val)
                    val_loss_sup_1 += loss_val_sup1.item()
                    val_loss_sup_2 += loss_val_sup2.item()
                    torch.cuda.empty_cache()

                score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val1, score_list_val1)
                score_list_val1 = torch.cat(score_gather_list_val1, dim=0)

                score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val2, score_list_val2)
                score_list_val2 = torch.cat(score_gather_list_val2, dim=0)
                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
                mask_list_val = torch.cat(mask_gather_list_val, dim=0)

                name_gather_list_val = [None for _ in range(ngpus_per_node)]
                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
                name_list_val = np.concatenate(name_gather_list_val, axis=0)

                if rank == args.rank_index:
                    val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
                    val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half)
                    best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model, model, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE'])
                    torch.cuda.empty_cache()

                    if args.vis:
                        draw_img = draw_pred_XNet(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, pred_train_sup2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2)
                        visualization_XNet(visdom, epoch+1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_m_jc1, train_m_jc2, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2)
                        visual_image_XNet(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4], draw_img[5])

                    print('-' * print_num)
                    print('| Epoch Time: {:.4f}s'.format((time.time()-begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
            torch.cuda.empty_cache()
        torch.cuda.empty_cache()

    if rank == args.rank_index:
        time_elapsed = time.time() - since
        m, s = divmod(time_elapsed, 60)
        h, m = divmod(m, 60)
        print('=' * print_num)
        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
        print('-' * print_num)
        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
        print('=' * print_num)

================================================
FILE: train_semi_XNet3d.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
import torchio as tio

from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_3d, print_best
from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_3d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_iit
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)


def init_seeds(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi_xnet')
    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi_xnet')
    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')
    parser.add_argument('--dataset_name',
                        default='Atrial', help='LiTS, Atrial')
    parser.add_argument('--input1', default='image')
    parser.add_argument('--input2', default='DB2_H')
    parser.add_argument('--sup_mark', default='20')
    parser.add_argument('--unsup_mark', default='80')
    parser.add_argument('-b', '--batch_size', default=1, type=int)
    parser.add_argument('-e', '--num_epochs', default=200, type=int)
    parser.add_argument('-s', '--step_size', default=50, type=int)
    parser.add_argument('-l', '--lr', default=0.1, type=float)
    parser.add_argument('-g', '--gamma', default=0.5, type=float)
    parser.add_argument('-u', '--unsup_weight', default=5, type=float)
    parser.add_argument('--loss', default='dice', type=str)
    parser.add_argument('--patch_size', default=(96, 96, 80))
    parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
    parser.add_argument('--queue_length', default=48, type=int)
    parser.add_argument('--samples_per_volume_train', default=4, type=int)
    parser.add_argument('--samples_per_volume_val', default=8, type=int)
    parser.add_argument('-i', '--display_iter', default=5, type=int)
    parser.add_argument('-n', '--network', default='xnet3d', type=str)
    parser.add_argument('--local_rank', default=-1, type=int)
    # type=int so that `rank == args.rank_index` still matches when the value is passed on the command line
    parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
    parser.add_argument('--visdom_port', default=16672, type=int, help='16672')
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    rank = torch.distributed.get_rank()
    # all_gather expects one output tensor per process in the group, so use the
    # world size rather than torch.cuda.device_count()
    ngpus_per_node = dist.get_world_size()

    init_seeds(rank + 1)

    dataset_name = args.dataset_name
    cfg = dataset_cfg(dataset_name)

    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
    print_num_minus = print_num - 2
    print_num_half = int(print_num / 2 - 1)

    # Run names follow the 2D script: include the label split and a '-' before input1,
    # so runs with different sup/unsup splits do not overwrite each other
    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_trained_models = path_trained_models + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) + '-' + str(args.input2)
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)

    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    path_seg_results = path_seg_results + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) + '-' + str(args.input2)
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)

    path_mask_results = path_seg_results + '/mask'
    if not os.path.exists(path_mask_results) and rank == args.rank_index:
        os.mkdir(path_mask_results)
    path_seg_results_1 = path_seg_results + '/pred'
    if not os.path.exists(path_seg_results_1) and rank == args.rank_index:
        os.mkdir(path_seg_results_1)

    if args.vis and rank == args.rank_index:
        visdom_env = str('Semi-XNet-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1) + '-' + str(args.input2))
        visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)

    # Dataset
    data_transform = data_transform_3d(cfg['NORMALIZE'])

    dataset_train_unsup = dataset_iit(
        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,
        input1=args.input1,
        input2=args.input2,
        transform_1=data_transform['train'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_train,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=True,
        shuffle_patches=True,
        sup=False,
        num_images=None
    )
    num_images_unsup = len(dataset_train_unsup.dataset_1)

    dataset_train_sup = dataset_iit(
        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
        input1=args.input1,
        input2=args.input2,
        transform_1=data_transform['train'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_train,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=True,
        shuffle_patches=True,
        sup=True,
        num_images=num_images_unsup
    )
    dataset_val = dataset_iit(
        data_dir=args.path_dataset + '/val',
        input1=args.input1,
        input2=args.input2,
        transform_1=data_transform['val'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_val,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=False,
        shuffle_patches=False,
        sup=True,
        num_images=None
    )

    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)
    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)

    dataloaders = dict()
    dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)
    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)
    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0,
                                    sampler=val_sampler)

    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}

    # Model
    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
    model = model.cuda()
    model = DistributedDataParallel(model, device_ids=[args.local_rank])

    # Training Strategy
    criterion = segmentation_loss(args.loss, False).cuda()

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)

    # Train & Val
    since = time.time()
    count_iter = 0

    best_model = model
    best_result = 'Result1'
    best_val_eval_list = [0 for i in range(4)]

    for epoch in range(args.num_epochs):

        count_iter += 1
        if (count_iter - 1) % args.display_iter == 0:
            begin_time = time.time()

        dataloaders['train_sup'].sampler.set_epoch(epoch)
        dataloaders['train_unsup'].sampler.set_epoch(epoch)
        model.train()

        train_loss_sup_1 = 0.0
        train_loss_sup_2 = 0.0
        train_loss_unsup = 0.0
        train_loss = 0.0
        val_loss_sup_1 = 0.0
        val_loss_sup_2 = 0.0

        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs

        dist.barrier()

        dataset_train_sup = iter(dataloaders['train_sup'])
        dataset_train_unsup = iter(dataloaders['train_unsup'])

        for i in range(num_batches['train_sup']):

            unsup_index = next(dataset_train_unsup)
            img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())
            img_train_unsup_2 = Variable(unsup_index['image2'][tio.DATA].cuda())

            optimizer.zero_grad()
            pred_train_unsup1, pred_train_unsup2 = model(img_train_unsup_1, img_train_unsup_2)

            max_train_unsup1 = torch.max(pred_train_unsup1, dim=1)[1]
            max_train_unsup2 = torch.max(pred_train_unsup2, dim=1)[1]
            max_train_unsup1 = max_train_unsup1.long()
            max_train_unsup2 = max_train_unsup2.long()

            loss_train_unsup = criterion(pred_train_unsup1, max_train_unsup2) + criterion(pred_train_unsup2, max_train_unsup1)
            loss_train_unsup = loss_train_unsup * unsup_weight
            loss_train_unsup.backward(retain_graph=True)
            torch.cuda.empty_cache()

            sup_index = next(dataset_train_sup)
            img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())
            img_train_sup_2 = Variable(sup_index['image2'][tio.DATA].cuda())
            mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())

            pred_train_sup1, pred_train_sup2 = model(img_train_sup_1, img_train_sup_2)
            torch.cuda.empty_cache()

            if count_iter % args.display_iter == 0:
                if i == 0:
                    score_list_train1 = pred_train_sup1
                    score_list_train2 = pred_train_sup2
                    mask_list_train = mask_train_sup
                # else:
                elif 0 < i <= num_batches['train_sup'] / 32:
                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)
                    score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0)
                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)

            loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)
            loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup)
            loss_train_sup = loss_train_sup1 + loss_train_sup2
            loss_train_sup.backward()

            optimizer.step()
            torch.cuda.empty_cache()

            loss_train = loss_train_unsup + loss_train_sup
            train_loss_unsup += loss_train_unsup.item()
            train_loss_sup_1 += loss_train_sup1.item()
            train_loss_sup_2 += loss_train_sup2.item()
            train_loss += loss_train.item()

        scheduler_warmup.step()
        torch.cuda.empty_cache()

        if count_iter % args.display_iter == 0:

            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)

            score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train2, score_list_train2)
            score_list_train2 = torch.cat(score_gather_list_train2, dim=0)
            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
            mask_list_train = torch.cat(mask_gather_list_train, dim=0)

            if rank == args.rank_index:
                torch.cuda.empty_cache()
                print('=' * print_num)
                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
                train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)
                train_eval_list_1, train_eval_list_2, train_m_jc_1, train_m_jc_2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)
                torch.cuda.empty_cache()

            with torch.no_grad():
                model.eval()

                for i, data in enumerate(dataloaders['val']):

                    # if 0 <= i <= num_batches['val']:

                    inputs_val_1 = Variable(data['image'][tio.DATA].cuda())
                    inputs_val_2 = Variable(data['image2'][tio.DATA].cuda())
                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())

                    optimizer.zero_grad()
                    outputs_val_1, outputs_val_2 = model(inputs_val_1, inputs_val_2)
                    torch.cuda.empty_cache()

                    if i == 0:
                        score_list_val_1 = outputs_val_1
                        score_list_val_2 = outputs_val_2
                        mask_list_val = mask_val
                    else:
                        score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)
                        score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0)
                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)

                    loss_val_sup_1 = criterion(outputs_val_1, mask_val)
                    loss_val_sup_2 = criterion(outputs_val_2, mask_val)
                    val_loss_sup_1 += loss_val_sup_1.item()
                    val_loss_sup_2 += loss_val_sup_2.item()
                    torch.cuda.empty_cache()

                score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)
                score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)

                score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2)
                score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0)

                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
                mask_list_val = torch.cat(mask_gather_list_val, dim=0)
                torch.cuda.empty_cache()

                if rank == args.rank_index:
                    val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
                    val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half)
                    best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model, model, score_list_val_1, score_list_val_2, mask_list_val, val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT'])
                    torch.cuda.empty_cache()

                    if args.vis:
                        visualization_XNet(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_m_jc_1, train_m_jc_2, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2)

                    print('-' * print_num)
                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
            torch.cuda.empty_cache()
        torch.cuda.empty_cache()

    if rank == args.rank_index:
        time_elapsed = time.time() - since
        m, s = divmod(time_elapsed, 60)
        h, m = divmod(m, 60)
        print('=' * print_num)
        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
        print('-' * print_num)
        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
        print('=' * print_num)

================================================
FILE: train_sup.py
================================================

from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random

from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_sup, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup
from config.visdom_config.visual_visdom import visdom_initialization_sup, visualization_sup, visual_image_sup
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_2d import imagefloder_itn
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)


def init_seeds(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup')
    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup')
    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/CREMI')
    parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS')
    parser.add_argument('--input1', default='image')
    parser.add_argument('--sup_mark', default='100')
    parser.add_argument('-b', '--batch_size', default=4, type=int)
    parser.add_argument('-e', '--num_epochs', default=200, type=int)
    parser.add_argument('-s', '--step_size', default=50, type=int)
    parser.add_argument('-l', '--lr', default=0.5, type=float)
    parser.add_argument('-g', '--gamma', default=0.5, type=float)
    parser.add_argument('--loss', default='dice', type=str)
    parser.add_argument('-ds', '--deep_supervision', default=False)
    # type=int so a value given on the command line is not left as a string
    parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
    parser.add_argument('-i', '--display_iter', default=5, type=int)
    parser.add_argument('-n', '--network', default='unet', type=str)
    parser.add_argument('--local_rank', default=-1, type=int)
    # type=int so that `rank == args.rank_index` still matches when the flag is passed explicitly
    parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
    parser.add_argument('--visdom_port', default=16672, help='16672')
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    rank = torch.distributed.get_rank()
    ngpus_per_node = torch.cuda.device_count()
    init_seeds(rank + 1)

    dataset_name = args.dataset_name
    cfg = dataset_cfg(dataset_name)

    print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7
    print_num_minus = print_num - 2

    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_trained_models = path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)

    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)

    if args.vis and rank == args.rank_index:
        visdom_env = str('Sup-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark))
        visdom = visdom_initialization_sup(env=visdom_env, port=args.visdom_port)

    if args.input1 == 'image':
        input1_mean = 'MEAN'
        input1_std = 'STD'
    else:
        input1_mean = 'MEAN_' + args.input1
        input1_std = 'STD_' + args.input1

    # Dataset
    data_transforms = data_transform_2d()
    data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std])

    dataset_train = imagefloder_itn(
        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
        input1=args.input1,
        data_transform_1=data_transforms['train'],
        data_normalize_1=data_normalize,
        sup=True,
        num_images=None,
    )
    dataset_val = imagefloder_itn(
        data_dir=args.path_dataset + '/val',
        input1=args.input1,
        data_transform_1=data_transforms['val'],
        data_normalize_1=data_normalize,
        sup=True,
        num_images=None,
    )

    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)

    dataloaders = dict()
    dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler)
    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)

    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}

    # Model
    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
    model = model.cuda()
    model = DistributedDataParallel(model, device_ids=[args.local_rank])

    # Training Strategy
    criterion = segmentation_loss(args.loss, False).cuda()

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)

    # Train & Val
    since = time.time()
    count_iter = 0

    best_val_eval_list = [0 for i in range(4)]

    for epoch in range(args.num_epochs):

        count_iter += 1
        if (count_iter-1) % args.display_iter == 0:
            begin_time = time.time()

        dataloaders['train'].sampler.set_epoch(epoch)
        model.train()

        train_loss = 0.0
        val_loss = 0.0

        dist.barrier()

        for i, data in enumerate(dataloaders['train']):

            inputs_train = Variable(data['image'].cuda())
            mask_train = Variable(data['mask'].cuda())

            optimizer.zero_grad()
            outputs_train = model(inputs_train)
            torch.cuda.empty_cache()

            if args.deep_supervision:
                loss_train = 0
                for output_train in outputs_train:
                    loss_train += criterion(output_train, mask_train)
                loss_train /= len(outputs_train)
                outputs_train = outputs_train[0]
            else:
                loss_train = criterion(outputs_train, mask_train)

            loss_train.backward()
            optimizer.step()
            train_loss += loss_train.item()

            if count_iter % args.display_iter == 0:
                if i == 0:
                    score_list_train = outputs_train
                    mask_list_train = mask_train
                # else:
                elif 0 < i <= num_batches['train_sup'] / 4:
                    score_list_train = torch.cat((score_list_train, outputs_train), dim=0)
                    mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)

        scheduler_warmup.step()
        torch.cuda.empty_cache()

        if count_iter % args.display_iter == 0:

            score_gather_list_train = [torch.zeros_like(score_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train, score_list_train)
            score_list_train = torch.cat(score_gather_list_train, dim=0)

            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
            mask_list_train = torch.cat(mask_gather_list_train, dim=0)

            if rank == args.rank_index:
                torch.cuda.empty_cache()
                print('=' * print_num)
                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
                train_epoch_loss = print_train_loss_sup(train_loss, num_batches, print_num, print_num_minus)
                train_eval_list, train_m_jc = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train, mask_list_train, print_num_minus)
                torch.cuda.empty_cache()

            with torch.no_grad():
                model.eval()

                for i, data in enumerate(dataloaders['val']):

                    # if 0 <= i <= num_batches['val']:

                    inputs_val = Variable(data['image'].cuda())
                    mask_val = Variable(data['mask'].cuda())
                    name_val = data['ID']

                    optimizer.zero_grad()
                    outputs_val = model(inputs_val)
                    torch.cuda.empty_cache()

                    if args.deep_supervision:
                        loss_val = 0
                        for output_val in outputs_val:
                            loss_val += criterion(output_val, mask_val)
                        loss_val /= len(outputs_val)
                        outputs_val = outputs_val[0]
                    else:
                        loss_val = criterion(outputs_val, mask_val)
                    val_loss += loss_val.item()

                    if i == 0:
                        score_list_val = outputs_val
                        mask_list_val = mask_val
                        name_list_val = name_val
                    else:
                        score_list_val = torch.cat((score_list_val, outputs_val), dim=0)
                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
                        name_list_val = np.append(name_list_val, name_val, axis=0)

                torch.cuda.empty_cache()
                score_gather_list_val = [torch.zeros_like(score_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val, score_list_val)
                score_list_val = torch.cat(score_gather_list_val, dim=0)

                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
                mask_list_val = torch.cat(mask_gather_list_val, dim=0)

                name_gather_list_val = [None for _ in range(ngpus_per_node)]
                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
                name_list_val = np.concatenate(name_gather_list_val, axis=0)
                torch.cuda.empty_cache()

                if rank == args.rank_index:
                    val_epoch_loss = print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus)
                    val_eval_list, val_m_jc = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val, mask_list_val, print_num_minus)
                    best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model, score_list_val, name_list_val, val_eval_list, path_trained_models, path_seg_results, cfg['PALETTE'], args.network)
                    torch.cuda.empty_cache()

                    if args.vis:
                        draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train, mask_val, outputs_train, outputs_val, train_eval_list, val_eval_list)
                        visualization_sup(visdom, epoch+1, train_epoch_loss, train_m_jc, val_epoch_loss, val_m_jc)
                        visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3])

                    print('-' * print_num)
                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
            torch.cuda.empty_cache()

        torch.cuda.empty_cache()

    if rank == args.rank_index:
        time_elapsed = time.time() - since
        m, s = divmod(time_elapsed, 60)
        h, m = divmod(m, 60)
        print('=' * print_num)
        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
        print('-' * print_num)
        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
        print('=' * print_num)

================================================
FILE: train_sup_3d.py
================================================

from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
import torchio as tio

from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_sup, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup
from config.visdom_config.visual_visdom import visdom_initialization_sup, visualization_sup
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_3d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_it
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)


def init_seeds(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup')
    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup')
    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')
    parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')
    parser.add_argument('--input1', default='image')
    parser.add_argument('--sup_mark', default='100')
    parser.add_argument('-b', '--batch_size', default=1, type=int)
    parser.add_argument('-e', '--num_epochs', default=200, type=int)
    parser.add_argument('-s', '--step_size', default=50, type=int)
    parser.add_argument('-l', '--lr', default=0.005, type=float)
    parser.add_argument('-g', '--gamma', default=0.5, type=float)
    parser.add_argument('--patch_size', default=(96, 96, 80))
    parser.add_argument('--loss', default='dice', type=str)
    # type=int so a value given on the command line is not left as a string
    parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
    parser.add_argument('--queue_length', default=48, type=int)
    parser.add_argument('--samples_per_volume_train', default=4, type=int)
    parser.add_argument('--samples_per_volume_val', default=8, type=int)
    parser.add_argument('-i', '--display_iter', default=5, type=int)
    parser.add_argument('-n', '--network', default='vnet', type=str)
    parser.add_argument('--local_rank', default=-1, type=int)
    # type=int so that `rank == args.rank_index` still matches when the flag is passed explicitly
    parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
    parser.add_argument('--visdom_port', default=16672, help='16672')
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    rank = torch.distributed.get_rank()
    ngpus_per_node = torch.cuda.device_count()
    init_seeds(rank + 1)

    dataset_name = args.dataset_name
    cfg = dataset_cfg(dataset_name)

    print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7
    print_num_minus = print_num - 2

    path_trained_models = args.path_trained_models+'/'+str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_trained_models = path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s=' + str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w=' + str(args.warm_up_duration)+'-'+str(args.sup_mark)
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)

    path_seg_results = args.path_seg_results+'/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)

    path_mask_results = path_seg_results + '/mask'
    if not os.path.exists(path_mask_results) and rank == args.rank_index:
        os.mkdir(path_mask_results)
    path_seg_results = path_seg_results + '/pred'
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)

    if args.vis and rank == args.rank_index:
        visdom_env = str('Sup-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-w=' + str(args.warm_up_duration)+'-'+str(args.sup_mark))
        visdom = visdom_initialization_sup(env=visdom_env, port=args.visdom_port)

    # Dataset
    data_transform = data_transform_3d(cfg['NORMALIZE'])

    dataset_train_sup = dataset_it(
        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
        input1=args.input1,
        transform_1=data_transform['train'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_train,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=True,
        shuffle_patches=True,
        sup=True,
        num_images=None
    )
    dataset_val = dataset_it(
        data_dir=args.path_dataset + '/val',
        input1=args.input1,
        transform_1=data_transform['val'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_val,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=False,
        shuffle_patches=False,
        sup=True,
        num_images=None
    )

    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)

    dataloaders = dict()
    dataloaders['train'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler)
    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)

    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}

    # Model
    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size)
    model = model.cuda()
    model = DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)

    # Training Strategy
    criterion = segmentation_loss(args.loss, False).cuda()

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)

    # Train & Val
    since = time.time()
    count_iter = 0

    best_val_eval_list = [0 for i in range(4)]

    for epoch in range(args.num_epochs):

        count_iter += 1
        if (count_iter-1) % args.display_iter == 0:
            begin_time = time.time()

        dataloaders['train'].sampler.set_epoch(epoch)
        model.train()

        train_loss = 0.0
        val_loss = 0.0

        dist.barrier()

        for i, data in enumerate(dataloaders['train']):

            inputs_train = Variable(data['image'][tio.DATA].cuda())
            mask_train = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())

            optimizer.zero_grad()
            outputs_train = model(inputs_train)
            torch.cuda.empty_cache()

            if count_iter % args.display_iter == 0:
                if i == 0:
                    score_list_train = outputs_train
                    mask_list_train = mask_train
                # else:
                elif 0 < i <= num_batches['train_sup'] / 32:
                    score_list_train = torch.cat((score_list_train, outputs_train), dim=0)
                    mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)

            loss_train = criterion(outputs_train, mask_train)
            loss_train.backward()
            optimizer.step()
            train_loss += loss_train.item()

        scheduler_warmup.step()
        torch.cuda.empty_cache()

        if count_iter % args.display_iter == 0:

            score_gather_list_train = [torch.zeros_like(score_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train, score_list_train)
            score_list_train = torch.cat(score_gather_list_train, dim=0)

            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
            mask_list_train = torch.cat(mask_gather_list_train, dim=0)

            if rank == args.rank_index:
                torch.cuda.empty_cache()
                print('=' * print_num)
                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
                train_epoch_loss = print_train_loss_sup(train_loss, num_batches, print_num, print_num_minus)
                train_eval_list, train_m_jc = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train, mask_list_train, print_num_minus)
                torch.cuda.empty_cache()

            with torch.no_grad():
                model.eval()

                for i, data in enumerate(dataloaders['val']):

                    # if 0 <= i <= num_batches['val']:

                    inputs_val = Variable(data['image'][tio.DATA].cuda())
                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())

                    optimizer.zero_grad()
                    outputs_val = model(inputs_val)
                    torch.cuda.empty_cache()

                    if i == 0:
                        score_list_val = outputs_val
                        mask_list_val = mask_val
                    else:
                        score_list_val = torch.cat((score_list_val, outputs_val), dim=0)
                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)

                    loss_val = criterion(outputs_val, mask_val)
                    val_loss += loss_val.item()

                torch.cuda.empty_cache()
                score_gather_list_val = [torch.zeros_like(score_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val, score_list_val)
                score_list_val = torch.cat(score_gather_list_val, dim=0)

                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
                mask_list_val = torch.cat(mask_gather_list_val, dim=0)
                torch.cuda.empty_cache()

                if rank == args.rank_index:
                    val_epoch_loss = print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus)
                    val_eval_list, val_m_jc = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val, mask_list_val, print_num_minus)
                    best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model, score_list_val, mask_list_val, val_eval_list, path_trained_models, path_seg_results, path_mask_results, args.network, cfg['FORMAT'])
                    torch.cuda.empty_cache()

                    if args.vis:
                        visualization_sup(visdom, epoch + 1, train_epoch_loss, train_m_jc, val_epoch_loss, val_m_jc)

                    print('-' * print_num)
                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
            torch.cuda.empty_cache()

        torch.cuda.empty_cache()

    if rank == args.rank_index:
        time_elapsed = time.time() - since
        m, s = divmod(time_elapsed, 60)
        h, m = divmod(m, 60)
        print('=' * print_num)
        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
        print('-' * print_num)
        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
        print('=' * print_num)

================================================
FILE: train_sup_ConResNet.py
================================================

from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
import torchio as tio

from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_ConResNet, print_val_loss_ConResNet, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup
from config.visdom_config.visual_visdom import visdom_initialization_ConResNet, visualization_ConResNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_3d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_iit_conresnet
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)


def init_seeds(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup')
    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup')
    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')
    parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')
    parser.add_argument('--input1', default='image')
    parser.add_argument('--input2', default='image_res')
    parser.add_argument('--sup_mark', default='100')
    parser.add_argument('-b', '--batch_size', default=1, type=int)
    parser.add_argument('-e', '--num_epochs', default=200, type=int)
    parser.add_argument('-s', '--step_size', default=50, type=int)
    parser.add_argument('-l', '--lr', default=0.1, type=float)
    parser.add_argument('-g', '--gamma', default=0.5, type=float)
    parser.add_argument('--patch_size', default=(96, 96, 80))
    # type=int so a value given on the command line is not left as a string
    parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')
    parser.add_argument('--queue_length', default=48, type=int)
    parser.add_argument('--samples_per_volume_train', default=4, type=int)
    parser.add_argument('--samples_per_volume_val', default=8, type=int)
    parser.add_argument('-i', '--display_iter', default=5, type=int)
    parser.add_argument('-n', '--network', default='conresnet', type=str)
    parser.add_argument('--local_rank', default=-1, type=int)
    # type=int so that `rank == args.rank_index` still matches when the flag is passed explicitly
    parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
    parser.add_argument('--visdom_port', default=16672, help='16672')
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    rank = torch.distributed.get_rank()
    ngpus_per_node = torch.cuda.device_count()
    init_seeds(rank + 1)

    dataset_name = args.dataset_name
    cfg = dataset_cfg(dataset_name)

    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
    print_num_minus = print_num - 2
    print_num_half = int(print_num / 2 - 1)

    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_trained_models = path_trained_models + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2)
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)

    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    path_seg_results = path_seg_results + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2)
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)

    path_mask_results = path_seg_results + '/mask'
    if not os.path.exists(path_mask_results) and rank == args.rank_index:
        os.mkdir(path_mask_results)
    path_seg_results = path_seg_results + '/pred'
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)

    if args.vis and rank == args.rank_index:
        visdom_env = str('Sup-ConResNet-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2))
        visdom = visdom_initialization_ConResNet(env=visdom_env, port=args.visdom_port)

    data_transform = data_transform_3d(cfg['NORMALIZE'])

    dataset_train_sup = dataset_iit_conresnet(
        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
        input1=args.input1,
        input2=args.input2,
        transform_1=data_transform['train'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_train,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=True,
        shuffle_patches=True,
        sup=True,
        num_images=None,
    )
    dataset_val = dataset_iit_conresnet(
        data_dir=args.path_dataset + '/val',
        input1=args.input1,
        input2=args.input2,
        transform_1=data_transform['val'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_val,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=False,
        shuffle_patches=False,
        sup=True,
        num_images=None
    )

    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)

    dataloaders = dict()
    dataloaders['train'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler)
    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)

    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}

    # Model
    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size)
    model = model.cuda()
    model = DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)

    # Training Strategy
    criterion_dice = segmentation_loss('dice', False).cuda()
    criterion_ce = segmentation_loss('CE', False).cuda()
    criterion_bound = segmentation_loss('bcebound', False, num_classes=cfg['NUM_CLASSES']).cuda()

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)

    # Train & Val
    since = time.time()
    count_iter = 0

    best_val_eval_list = [0 for i in range(4)]

    for epoch in range(args.num_epochs):

        count_iter += 1
        if (count_iter - 1) % args.display_iter == 0:
            begin_time = time.time()

        dataloaders['train'].sampler.set_epoch(epoch)
        model.train()

        train_loss_seg = 0.0
        train_loss_res = 0.0
        train_loss = 0.0
        val_loss_seg = 0.0
        val_loss_res = 0.0

        dist.barrier()

        for i, data in enumerate(dataloaders['train']):

            inputs_train_1 = Variable(data['image'][tio.DATA].cuda())
            inputs_train_2 = Variable(data['image2'][tio.DATA].cuda())
            mask_train = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
            mask_train_2 = Variable(data['mask2'][tio.DATA].squeeze(1).long().cuda())

            optimizer.zero_grad()
            outputs_train = model(inputs_train_1, inputs_train_2)
            torch.cuda.empty_cache()

            if count_iter % args.display_iter == 0:
                if i == 0:
                    score_list_train = outputs_train[0]
                    mask_list_train = mask_train
                # else:
                elif 0 < i <= num_batches['train_sup'] / 32:
                    score_list_train = torch.cat((score_list_train, outputs_train[0]), dim=0)
                    mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)

            loss_train_seg = criterion_dice(outputs_train[0], mask_train) + criterion_ce(outputs_train[0], mask_train)
            loss_train_res = criterion_bound(outputs_train[1], mask_train_2) + 0.5 * (criterion_bound(outputs_train[2], mask_train_2) + criterion_bound(outputs_train[3], mask_train_2))
            loss_train = loss_train_seg + loss_train_res

            loss_train.backward()
            optimizer.step()

            train_loss_seg += loss_train_seg.item()
            train_loss_res += loss_train_res.item()
            train_loss += loss_train.item()

        scheduler_warmup.step()
        torch.cuda.empty_cache()

        if count_iter % args.display_iter == 0:

            score_gather_list_train = [torch.zeros_like(score_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train, score_list_train)
            score_list_train = torch.cat(score_gather_list_train, dim=0)

            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
            mask_list_train = torch.cat(mask_gather_list_train, dim=0)

            if rank == args.rank_index:
                torch.cuda.empty_cache()
                print('=' * print_num)
                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
                train_epoch_loss_seg, train_epoch_loss_res, train_epoch_loss = print_train_loss_ConResNet(train_loss_seg, train_loss_res, train_loss, num_batches, print_num, print_num_half, print_num_minus)
                train_eval_list, train_m_jc = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train, mask_list_train, print_num_minus)
                torch.cuda.empty_cache()

            with torch.no_grad():
                model.eval()

                for i, data in enumerate(dataloaders['val']):

                    # if 0 <= i <= num_batches['val']:

                    inputs_val = Variable(data['image'][tio.DATA].cuda())
                    inputs_val_2 = Variable(data['image2'][tio.DATA].cuda())
                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())
                    mask_val_2 = Variable(data['mask2'][tio.DATA].squeeze(1).long().cuda())

                    optimizer.zero_grad()
                    outputs_val = model(inputs_val, inputs_val_2)
                    torch.cuda.empty_cache()

                    if i == 0:
                        score_list_val = outputs_val[0]
                        mask_list_val = mask_val
                    else:
                        score_list_val = torch.cat((score_list_val, outputs_val[0]), dim=0)
                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)

                    loss_val_seg = criterion_dice(outputs_val[0], mask_val) + criterion_ce(outputs_val[0], mask_val)
                    loss_val_res = criterion_bound(outputs_val[1], mask_val_2) + 0.5 * (criterion_bound(outputs_val[2], mask_val_2) + criterion_bound(outputs_val[3], mask_val_2))
                    val_loss_seg += loss_val_seg.item()
                    val_loss_res += loss_val_res.item()

                torch.cuda.empty_cache()
                score_gather_list_val = [torch.zeros_like(score_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val, score_list_val)
                score_list_val = torch.cat(score_gather_list_val, dim=0)

                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
                mask_list_val = torch.cat(mask_gather_list_val, dim=0)
                torch.cuda.empty_cache()

                if rank == args.rank_index:
                    val_epoch_loss_seg, val_epoch_loss_res = print_val_loss_ConResNet(val_loss_seg, val_loss_res, num_batches, print_num, print_num_half)
                    val_eval_list, val_m_jc = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val, mask_list_val, print_num_minus)
                    best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model, score_list_val, mask_list_val, val_eval_list, path_trained_models, path_seg_results, path_mask_results, args.network, cfg['FORMAT'])
                    torch.cuda.empty_cache()

                    if args.vis:
                        visualization_ConResNet(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_seg, train_epoch_loss_res, train_m_jc, val_epoch_loss_seg, val_epoch_loss_res, val_m_jc)

                    print('-' * print_num)
                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
            torch.cuda.empty_cache()

        torch.cuda.empty_cache()

    if rank == args.rank_index:
        time_elapsed = time.time() - since
        m, s = divmod(time_elapsed, 60)
        h, m = divmod(m, 60)
        print('=' * print_num)
        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
        print('-' * print_num)
        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
        print('=' * print_num)

================================================
FILE: train_sup_XNet.py
================================================

from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random

from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best
from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_2d import imagefloder_iitnn
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)


def init_seeds(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup_xnet')
    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup_xnet')
    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/GlaS')
    parser.add_argument('--dataset_name', default='GlaS', help='CREMI, ISIC-2017, GlaS')
    parser.add_argument('--input1', default='L')
    parser.add_argument('--input2', default='H')
    parser.add_argument('--sup_mark', default='100')
    parser.add_argument('-b', '--batch_size', default=2, type=int)
    parser.add_argument('-e', '--num_epochs', default=200, type=int)
    parser.add_argument('-s', '--step_size', default=50, type=int)
    parser.add_argument('-l', '--lr', default=0.5, type=float)
    parser.add_argument('-g', '--gamma', default=0.5, type=float)
    parser.add_argument('-u', '--unsup_weight', default=5, type=float)
    parser.add_argument('--loss', default='dice', type=str)
    # type=int so that a value passed on the command line is not kept as a string
    parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')

    parser.add_argument('-i', '--display_iter', default=5, type=int)
    parser.add_argument('-n', '--network', default='xnet', type=str)
    parser.add_argument('--local_rank', default=-1, type=int)
    # type=int: `rank == args.rank_index` never matches if this stays a string
    parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
    parser.add_argument('--visdom_port', default=16672, type=int, help='16672')
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')

    rank = torch.distributed.get_rank()
    # all_gather / all_gather_object need exactly one output slot per process,
    # i.e. the world size, which can differ from the number of visible GPUs
    ngpus_per_node = dist.get_world_size()
    init_seeds(rank + 1)

    dataset_name = args.dataset_name
    cfg = dataset_cfg(dataset_name)

    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
    print_num_minus = print_num - 2
    print_num_half = int(print_num / 2 - 1)

    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_trained_models = path_trained_models + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.input1) + '-' + str(args.input2)
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)

    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    path_seg_results = path_seg_results + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.input1) + '-' + str(args.input2)
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)

    if args.vis and rank == args.rank_index:
        visdom_env = str('Sup-XNet-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.input1) + '-' + str(args.input2))
        visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)

    # Dataset
    if args.input1 == 'image':
        input1_mean = 'MEAN'
        input1_std = 'STD'
    else:
        input1_mean = 'MEAN_' + args.input1
        input1_std = 'STD_' + args.input1

    if args.input2 == 'image':
        input2_mean = 'MEAN'
        input2_std = 'STD'
    else:
        input2_mean = 'MEAN_' + args.input2
        input2_std = 'STD_' + args.input2

    data_transforms = data_transform_2d()
    data_normalize_1 = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
    data_normalize_2 = data_normalize_2d(cfg[input2_mean], cfg[input2_std])

    dataset_train = imagefloder_iitnn(
        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
        input1=args.input1,
        input2=args.input2,
        data_transform_1=data_transforms['train'],
        data_normalize_1=data_normalize_1,
        data_normalize_2=data_normalize_2,
        sup=True,
        num_images=None,
    )
    dataset_val = imagefloder_iitnn(
        data_dir=args.path_dataset + '/val',
        input1=args.input1,
        input2=args.input2,
        data_transform_1=data_transforms['val'],
        data_normalize_1=data_normalize_1,
        data_normalize_2=data_normalize_2,
        sup=True,
        num_images=None,
    )

    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)

    dataloaders = dict()
    dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler)
    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)

    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}

    # Model
    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
    model = model.cuda()
    model = DistributedDataParallel(model, device_ids=[args.local_rank])

    # Training Strategy
    criterion = segmentation_loss(args.loss, False).cuda()

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)

    # Train & Val
    since = time.time()
    count_iter = 0

    best_model = model
    best_result = 'Result1'
    best_val_eval_list = [0 for i in range(4)]

    for epoch in range(args.num_epochs):

        count_iter += 1
        if (count_iter - 1) % args.display_iter == 0:
            begin_time = time.time()

        dataloaders['train'].sampler.set_epoch(epoch)
        model.train()

        train_loss_sup_1 = 0.0
        train_loss_sup_2 = 0.0
        train_loss_unsup = 0.0
        train_loss = 0.0
        val_loss_sup_1 = 0.0
        val_loss_sup_2 = 0.0

        # weight of the cross-supervision term, ramped up linearly over the epochs
        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs

        dist.barrier()

        for i, data in enumerate(dataloaders['train']):

            inputs_train_1 = Variable(data['image'].cuda())
            inputs_train_2 = Variable(data['image_2'].cuda())
            mask_train = Variable(data['mask'].cuda())

            optimizer.zero_grad()
            outputs_train1, outputs_train2 = model(inputs_train_1, inputs_train_2)
            torch.cuda.empty_cache()

            if count_iter % args.display_iter == 0:
                if i == 0:
                    score_list_train1 = outputs_train1
                    score_list_train2 = outputs_train2
                    mask_list_train = mask_train
                # else:
                elif 0 < i <= num_batches['train_sup'] / 4:
                    score_list_train1 = torch.cat((score_list_train1, outputs_train1), dim=0)
                    score_list_train2 = torch.cat((score_list_train2, outputs_train2), dim=0)
                    mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)

            # hard pseudo labels: per-pixel argmax of each branch's prediction
            max_train1 = torch.max(outputs_train1, dim=1)[1]
            max_train2 = torch.max(outputs_train2, dim=1)[1]
            max_train1 = max_train1.long()
            max_train2 = max_train2.long()

            loss_train_sup1 = criterion(outputs_train1, mask_train)
            loss_train_sup2 = criterion(outputs_train2, mask_train)
            # cross supervision: each branch is also trained on the other branch's pseudo labels
            loss_train_unsup = criterion(outputs_train1, max_train2) + criterion(outputs_train2, max_train1)
            loss_train_unsup = loss_train_unsup * unsup_weight
            loss_train = loss_train_sup1 + loss_train_sup2 + loss_train_unsup

            loss_train.backward()
            optimizer.step()

            train_loss_sup_1 += loss_train_sup1.item()
            train_loss_sup_2 += loss_train_sup2.item()
            train_loss_unsup += loss_train_unsup.item()
            train_loss += loss_train.item()

        scheduler_warmup.step()
        torch.cuda.empty_cache()

        if count_iter % args.display_iter == 0:

            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)
            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)

            score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train2, score_list_train2)
            score_list_train2 = torch.cat(score_gather_list_train2, dim=0)

            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
            mask_list_train = torch.cat(mask_gather_list_train, dim=0)

            if rank == args.rank_index:
                torch.cuda.empty_cache()
                print('=' * print_num)
                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
                train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)
                train_eval_list1, train_eval_list2, train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)
                torch.cuda.empty_cache()

            with torch.no_grad():
                model.eval()

                for i, data in enumerate(dataloaders['val']):

                    # if 0 <= i <= num_batches['val']:

                    inputs_val = Variable(data['image'].cuda())
                    inputs_val_wavelet = Variable(data['image_2'].cuda())
                    mask_val = Variable(data['mask'].cuda())
                    name_val = data['ID']

                    optimizer.zero_grad()
                    outputs_val1, outputs_val2 = model(inputs_val, inputs_val_wavelet)
                    torch.cuda.empty_cache()

                    if i == 0:
                        score_list_val1 = outputs_val1
                        score_list_val2 = outputs_val2
                        mask_list_val = mask_val
                        name_list_val = name_val
                    else:
                        score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)
                        score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)
                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
                        name_list_val = np.append(name_list_val, name_val, axis=0)

                    loss_val_sup1 = criterion(outputs_val1, mask_val)
                    loss_val_sup2 = criterion(outputs_val2, mask_val)
                    val_loss_sup_1 += loss_val_sup1.item()
                    val_loss_sup_2 += loss_val_sup2.item()

                torch.cuda.empty_cache()
                score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val1, score_list_val1)
                score_list_val1 = torch.cat(score_gather_list_val1, dim=0)

                score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val2, score_list_val2)
                score_list_val2 = torch.cat(score_gather_list_val2, dim=0)

                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
                mask_list_val = torch.cat(mask_gather_list_val, dim=0)

                name_gather_list_val = [None for _ in range(ngpus_per_node)]
                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
                name_list_val = np.concatenate(name_gather_list_val, axis=0)
                torch.cuda.empty_cache()

                if rank == args.rank_index:
                    val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
                    val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half)
                    best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model, model, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE'])
                    torch.cuda.empty_cache()

                    if args.vis:
                        draw_img = draw_pred_XNet(cfg['NUM_CLASSES'], mask_train, mask_val, outputs_train1, outputs_train2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2)
                        visualization_XNet(visdom, epoch+1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_m_jc1, train_m_jc2, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2)
                        visual_image_XNet(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4], draw_img[5])

                    print('-' * print_num)
                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
        torch.cuda.empty_cache()

    if rank == args.rank_index:
        time_elapsed = time.time() - since
        m, s = divmod(time_elapsed, 60)
        h, m = divmod(m, 60)

        print('=' * print_num)
        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
        print('-' * print_num)
        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
        print('=' * print_num)

================================================
FILE: train_sup_XNet3d.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
import torchio as tio

from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_3d, print_best
from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_3d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_3d import dataset_iit
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)


def init_seeds(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup_xnet')
    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup_xnet')
    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')
    parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')
    parser.add_argument('--input1', default='L')
    parser.add_argument('--input2', default='H')
    parser.add_argument('--sup_mark', default='100', help='100')
    parser.add_argument('-b', '--batch_size', default=1, type=int)
    parser.add_argument('-e', '--num_epochs', default=200, type=int)
    parser.add_argument('-s', '--step_size', default=50, type=int)
    parser.add_argument('-l', '--lr', default=0.5, type=float)
    parser.add_argument('-g', '--gamma', default=0.5, type=float)
    parser.add_argument('-u', '--unsup_weight', default=5, type=float)
    parser.add_argument('--loss', default='dice', type=str)
    parser.add_argument('--patch_size', default=(96, 96, 80))
    # type=int so that a value passed on the command line is not kept as a string
    parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')

    parser.add_argument('--queue_length', default=48, type=int)
    parser.add_argument('--samples_per_volume_train', default=4, type=int)
    parser.add_argument('--samples_per_volume_val', default=8, type=int)

    parser.add_argument('-i', '--display_iter', default=5, type=int)
    parser.add_argument('-n', '--network', default='xnet3d', type=str)
    parser.add_argument('--local_rank', default=-1, type=int)
    # type=int: `rank == args.rank_index` never matches if this stays a string
    parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
    parser.add_argument('--visdom_port', default=16672, type=int, help='16672')
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')

    rank = torch.distributed.get_rank()
    # all_gather needs exactly one output slot per process (the world size),
    # which can differ from the number of visible GPUs
    ngpus_per_node = dist.get_world_size()
    init_seeds(rank + 1)

    dataset_name = args.dataset_name
    cfg = dataset_cfg(dataset_name)

    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
    print_num_minus = print_num - 2
    print_num_half = int(print_num / 2 - 1)

    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_trained_models = path_trained_models + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.input1) + '-' + str(args.input2)
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)

    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    path_seg_results = path_seg_results + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.input1) + '-' + str(args.input2)
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)

    path_mask_results = path_seg_results + '/mask'
    if not os.path.exists(path_mask_results) and rank == args.rank_index:
        os.mkdir(path_mask_results)
    path_seg_results_1 = path_seg_results + '/pred'
    if not os.path.exists(path_seg_results_1) and rank == args.rank_index:
        os.mkdir(path_seg_results_1)

    if args.vis and rank == args.rank_index:
        visdom_env = str('Sup-XNet-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.input1) + '-' + str(args.input2))
        visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)

    # Dataset
    data_transform = data_transform_3d(cfg['NORMALIZE'])

    dataset_train_sup = dataset_iit(
        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
        input1=args.input1,
        input2=args.input2,
        transform_1=data_transform['train'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_train,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=True,
        shuffle_patches=True,
        sup=True,
        num_images=None
    )
    dataset_val = dataset_iit(
        data_dir=args.path_dataset + '/val',
        input1=args.input1,
        input2=args.input2,
        transform_1=data_transform['val'],
        queue_length=args.queue_length,
        samples_per_volume=args.samples_per_volume_val,
        patch_size=args.patch_size,
        num_workers=8,
        shuffle_subjects=False,
        shuffle_patches=False,
        sup=True,
        num_images=None
    )

    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)

    dataloaders = dict()
    dataloaders['train'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler)
    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)

    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}

    # Model
    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
    model = model.cuda()
    model = DistributedDataParallel(model, device_ids=[args.local_rank])

    # Training Strategy
    criterion = segmentation_loss(args.loss, False).cuda()

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)

    # Train & Val
    since = time.time()
    count_iter = 0

    best_model = model
    best_result = 'Result1'
    best_val_eval_list = [0 for i in range(4)]

    for epoch in range(args.num_epochs):

        count_iter += 1
        if (count_iter - 1) % args.display_iter == 0:
            begin_time = time.time()

        dataloaders['train'].sampler.set_epoch(epoch)
        model.train()

        train_loss_sup_1 = 0.0
        train_loss_sup_2 = 0.0
        train_loss_unsup = 0.0
        train_loss = 0.0
        val_loss_sup_1 = 0.0
        val_loss_sup_2 = 0.0

        # weight of the cross-supervision term, ramped up linearly over the epochs
        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs

        dist.barrier()

        for i, data in enumerate(dataloaders['train']):

            inputs_train_1 = Variable(data['image'][tio.DATA].cuda())
            inputs_train_2 = Variable(data['image2'][tio.DATA].cuda())
            mask_train = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())

            optimizer.zero_grad()
            outputs_train_1, outputs_train_2 = model(inputs_train_1, inputs_train_2)
            torch.cuda.empty_cache()

            if count_iter % args.display_iter == 0:
                if i == 0:
                    score_list_train_1 = outputs_train_1
                    score_list_train_2 = outputs_train_2
                    mask_list_train = mask_train
                # else:
                elif 0 < i <= num_batches['train_sup'] / 32:
                    score_list_train_1 = torch.cat((score_list_train_1, outputs_train_1), dim=0)
                    score_list_train_2 = torch.cat((score_list_train_2, outputs_train_2), dim=0)
                    mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)

            # hard pseudo labels: per-voxel argmax of each branch's prediction
            max_train1 = torch.max(outputs_train_1, dim=1)[1]
            max_train2 = torch.max(outputs_train_2, dim=1)[1]
            max_train1 = max_train1.long()
            max_train2 = max_train2.long()

            loss_train_sup1 = criterion(outputs_train_1, mask_train)
            loss_train_sup2 = criterion(outputs_train_2, mask_train)
            # cross supervision: each branch is also trained on the other branch's pseudo labels
            loss_train_unsup = criterion(outputs_train_1, max_train2) + criterion(outputs_train_2, max_train1)
            loss_train_unsup = loss_train_unsup * unsup_weight
            loss_train = loss_train_sup1 + loss_train_sup2 + loss_train_unsup

            loss_train.backward()
            optimizer.step()

            train_loss_sup_1 += loss_train_sup1.item()
            train_loss_sup_2 += loss_train_sup2.item()
            train_loss_unsup += loss_train_unsup.item()
            train_loss += loss_train.item()

        scheduler_warmup.step()
        torch.cuda.empty_cache()

        if count_iter % args.display_iter == 0:

            score_gather_list_train_1 = [torch.zeros_like(score_list_train_1) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train_1, score_list_train_1)
            score_list_train_1 = torch.cat(score_gather_list_train_1, dim=0)

            score_gather_list_train_2 = [torch.zeros_like(score_list_train_2) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train_2, score_list_train_2)
            score_list_train_2 = torch.cat(score_gather_list_train_2, dim=0)

            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
            mask_list_train = torch.cat(mask_gather_list_train, dim=0)

            if rank == args.rank_index:
                torch.cuda.empty_cache()
                print('=' * print_num)
                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
                train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)
                train_eval_list_1, train_eval_list_2, train_m_jc_1, train_m_jc_2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train_1, score_list_train_2, mask_list_train, print_num_half)
                torch.cuda.empty_cache()

            with torch.no_grad():
                model.eval()

                for i, data in enumerate(dataloaders['val']):

                    # if 0 <= i <= num_batches['val']:

                    inputs_val_1 = Variable(data['image'][tio.DATA].cuda())
                    inputs_val_2 = Variable(data['image2'][tio.DATA].cuda())
                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())

                    optimizer.zero_grad()
                    outputs_val_1, outputs_val_2 = model(inputs_val_1, inputs_val_2)
                    torch.cuda.empty_cache()

                    if i == 0:
                        score_list_val_1 = outputs_val_1
                        score_list_val_2 = outputs_val_2
                        mask_list_val = mask_val
                    else:
                        score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)
                        score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0)
                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)

                    loss_val_sup_1 = criterion(outputs_val_1, mask_val)
                    loss_val_sup_2 = criterion(outputs_val_2, mask_val)
                    val_loss_sup_1 += loss_val_sup_1.item()
                    val_loss_sup_2 += loss_val_sup_2.item()

                torch.cuda.empty_cache()
                score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)
                score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)

                score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2)
                score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0)

                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
                mask_list_val = torch.cat(mask_gather_list_val, dim=0)
                torch.cuda.empty_cache()

                if rank == args.rank_index:
                    val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)
                    val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half)
                    best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model, model, score_list_val_1, score_list_val_2, mask_list_val, val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT'])
                    torch.cuda.empty_cache()

                    if args.vis:
                        visualization_XNet(visdom, epoch+1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_m_jc_1, train_m_jc_2, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2)

                    print('-' * print_num)
                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
        torch.cuda.empty_cache()

    torch.cuda.empty_cache()

    if rank == args.rank_index:
        time_elapsed = time.time() - since
        m, s = divmod(time_elapsed, 60)
        h, m = divmod(m, 60)

        print('=' * print_num)
        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
        print('-' * print_num)
        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)
        print('=' * print_num)

================================================
FILE: train_sup_XNet_sb.py
================================================
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random

from config.dataset_config.dataset_cfg import dataset_cfg
from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best
from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_2d import imagefloder_iitnn
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)


def init_seeds(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup_xnet')
    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup_xnet')
    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/GlaS')
    parser.add_argument('--dataset_name', default='GlaS', help='CREMI, ISIC-2017, GlaS')
    parser.add_argument('--input1', default='L')
    parser.add_argument('--input2', default='H')
    parser.add_argument('--sup_mark', default='100')
    parser.add_argument('-b', '--batch_size', default=2, type=int)
    parser.add_argument('-e', '--num_epochs', default=200, type=int)
    parser.add_argument('-s', '--step_size', default=50, type=int)
    parser.add_argument('-l', '--lr', default=0.5, type=float)
    parser.add_argument('-g', '--gamma', default=0.5, type=float)
    parser.add_argument('-u', '--unsup_weight', default=5, type=float)
    parser.add_argument('--loss', default='dice', type=str)
    # type=int so that a value passed on the command line is not kept as a string
    parser.add_argument('-w', '--warm_up_duration', default=20, type=int)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')

    parser.add_argument('-i', '--display_iter', default=5, type=int)
    parser.add_argument('-n', '--network', default='xnet_sb', type=str)
    parser.add_argument('--local_rank', default=-1, type=int)
    # type=int: `rank == args.rank_index` never matches if this stays a string
    parser.add_argument('--rank_index', default=0, type=int, help='0, 1, 2, 3')
    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')
    parser.add_argument('--visdom_port', default=16672, type=int, help='16672')
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')

    rank = torch.distributed.get_rank()
    # all_gather / all_gather_object need exactly one output slot per process,
    # i.e. the world size, which can differ from the number of visible GPUs
    ngpus_per_node = dist.get_world_size()
    init_seeds(rank + 1)

    dataset_name = args.dataset_name
    cfg = dataset_cfg(dataset_name)

    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
    print_num_minus = print_num - 2
    print_num_half = int(print_num / 2 - 1)

    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)
    path_trained_models = path_trained_models + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.input1) + '-' + str(args.input2)
    if not os.path.exists(path_trained_models) and rank == args.rank_index:
        os.mkdir(path_trained_models)

    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)
    path_seg_results = path_seg_results + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.input1) + '-' + str(args.input2)
    if not os.path.exists(path_seg_results) and rank == args.rank_index:
        os.mkdir(path_seg_results)

    if args.vis and rank == args.rank_index:
        visdom_env = str('Sup-XNet-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.input1) + '-' + str(args.input2))
        visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)

    # Dataset
    if args.input1 == 'image':
        input1_mean = 'MEAN'
        input1_std = 'STD'
    else:
        input1_mean = 'MEAN_' + args.input1
        input1_std = 'STD_' + args.input1

    if args.input2 == 'image':
        input2_mean = 'MEAN'
        input2_std = 'STD'
    else:
        input2_mean = 'MEAN_' + args.input2
        input2_std = 'STD_' + args.input2

    data_transforms = data_transform_2d()
    data_normalize_1 = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
    data_normalize_2 = data_normalize_2d(cfg[input2_mean], cfg[input2_std])

    dataset_train = imagefloder_iitnn(
        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,
        input1=args.input1,
        input2=args.input2,
        data_transform_1=data_transforms['train'],
        data_normalize_1=data_normalize_1,
        data_normalize_2=data_normalize_2,
        sup=True,
        num_images=None,
    )
    dataset_val = imagefloder_iitnn(
        data_dir=args.path_dataset + '/val',
        input1=args.input1,
        input2=args.input2,
        data_transform_1=data_transforms['val'],
        data_normalize_1=data_normalize_1,
        data_normalize_2=data_normalize_2,
        sup=True,
        num_images=None,
    )

    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)

    dataloaders = dict()
    dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler)
    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)

    num_batches = {'train_sup': len(dataloaders['train']), 'val': 
len(dataloaders['val'])} # Model model1 = get_network(args.network, 3, cfg['NUM_CLASSES']) model2 = get_network(args.network, 1, cfg['NUM_CLASSES']) model1 = model1.cuda() model2 = model2.cuda() model1 = DistributedDataParallel(model1, device_ids=[args.local_rank]) model2 = DistributedDataParallel(model2, device_ids=[args.local_rank]) dist.barrier() # Training Strategy criterion = segmentation_loss(args.loss, False).cuda() optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd) exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma) scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1) optimizer2 = optim.SGD(model2.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd) exp_lr_scheduler2 = lr_scheduler.StepLR(optimizer2, step_size=args.step_size, gamma=args.gamma) scheduler_warmup2 = GradualWarmupScheduler(optimizer2, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler2) # Train & Val since = time.time() count_iter = 0 best_model = model1 best_result = 'Result1' best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter - 1) % args.display_iter == 0: begin_time = time.time() dataloaders['train'].sampler.set_epoch(epoch) model1.train() model2.train() train_loss_sup_1 = 0.0 train_loss_sup_2 = 0.0 train_loss_unsup = 0.0 train_loss = 0.0 val_loss_sup_1 = 0.0 val_loss_sup_2 = 0.0 unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs dist.barrier() for i, data in enumerate(dataloaders['train']): inputs_train_1 = Variable(data['image'].cuda()) inputs_train_2 = Variable(data['image_2'].cuda()) mask_train = Variable(data['mask'].cuda()) optimizer1.zero_grad() optimizer2.zero_grad() outputs_train1 = model1(inputs_train_1) outputs_train2 = model2(inputs_train_2) 
torch.cuda.empty_cache() if count_iter % args.display_iter == 0: if i == 0: score_list_train1 = outputs_train1 score_list_train2 = outputs_train2 mask_list_train = mask_train # else: elif 0 < i <= num_batches['train_sup'] / 4: score_list_train1 = torch.cat((score_list_train1, outputs_train1), dim=0) score_list_train2 = torch.cat((score_list_train2, outputs_train2), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train), dim=0) max_train1 = torch.max(outputs_train1, dim=1)[1] max_train2 = torch.max(outputs_train2, dim=1)[1] max_train1 = max_train1.long() max_train2 = max_train2.long() loss_train_sup1 = criterion(outputs_train1, mask_train) loss_train_sup2 = criterion(outputs_train2, mask_train) loss_train_unsup = criterion(outputs_train1, max_train2) + criterion(outputs_train2, max_train1) loss_train_unsup = loss_train_unsup * unsup_weight loss_train = loss_train_sup1 + loss_train_sup2 + loss_train_unsup loss_train.backward() optimizer1.step() optimizer2.step() train_loss_sup_1 += loss_train_sup1.item() train_loss_sup_2 += loss_train_sup2.item() train_loss_unsup += loss_train_unsup.item() train_loss += loss_train.item() scheduler_warmup1.step() scheduler_warmup2.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train1, score_list_train1) score_list_train1 = torch.cat(score_gather_list_train1, dim=0) score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train2, score_list_train2) score_list_train2 = torch.cat(score_gather_list_train2, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: 
torch.cuda.empty_cache() print('=' * print_num) print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half) train_eval_list1, train_eval_list2, train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half) torch.cuda.empty_cache() with torch.no_grad(): model1.eval() model2.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val']: inputs_val_1 = Variable(data['image'].cuda()) inputs_val_2 = Variable(data['image_2'].cuda()) mask_val = Variable(data['mask'].cuda()) name_val = data['ID'] optimizer1.zero_grad() optimizer2.zero_grad() outputs_val1 = model1(inputs_val_1) outputs_val2 = model2(inputs_val_2) torch.cuda.empty_cache() if i == 0: score_list_val1 = outputs_val1 score_list_val2 = outputs_val2 mask_list_val = mask_val name_list_val = name_val else: score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0) score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0) mask_list_val = torch.cat((mask_list_val, mask_val), dim=0) name_list_val = np.append(name_list_val, name_val, axis=0) loss_val_sup1 = criterion(outputs_val1, mask_val) loss_val_sup2 = criterion(outputs_val2, mask_val) val_loss_sup_1 += loss_val_sup1.item() val_loss_sup_2 += loss_val_sup2.item() torch.cuda.empty_cache() score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val1, score_list_val1) score_list_val1 = torch.cat(score_gather_list_val1, dim=0) score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val2, score_list_val2) score_list_val2 = 
torch.cat(score_gather_list_val2, dim=0) mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_val, mask_list_val) mask_list_val = torch.cat(mask_gather_list_val, dim=0) name_gather_list_val = [None for _ in range(ngpus_per_node)] torch.distributed.all_gather_object(name_gather_list_val, name_list_val) name_list_val = np.concatenate(name_gather_list_val, axis=0) torch.cuda.empty_cache() if rank == args.rank_index: val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half) val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half) best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE']) torch.cuda.empty_cache() if args.vis: draw_img = draw_pred_XNet(cfg['NUM_CLASSES'], mask_train, mask_val, outputs_train1, outputs_train2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2) visualization_XNet(visdom, epoch+1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_m_jc1, train_m_jc2, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2) visual_image_XNet(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4], draw_img[5]) print('-' * print_num) print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|') torch.cuda.empty_cache() torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('=' * print_num) print('| Training Completed In {:.0f}h {:.0f}mins 
{:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('-' * print_num) print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus) print('=' * print_num) ================================================ FILE: train_sup_alnet.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time import os import numpy as np import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel from torch.backends import cudnn import random from config.dataset_config.dataset_cfg import dataset_cfg from config.train_test_config.train_test_config import print_train_loss_sup, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup from config.visdom_config.visual_visdom import visdom_initialization_sup, visualization_sup, visual_image_sup from config.warmup_config.warmup import GradualWarmupScheduler from config.augmentation.online_aug import data_transform_2d, data_normalize_2d, data_transform_aerial_lanenet from loss.loss_function import segmentation_loss from models.getnetwork import get_network from dataload.dataset_2d import imagefloder_aerial_lanenet from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/GeYang_shared/XNet/checkpoints/sup') parser.add_argument('--path_seg_results', 
default='/mnt/data1/GeYang_shared/XNet/seg_pred/sup') parser.add_argument('--path_dataset', default='/mnt/data1/GeYang_shared/XNet/dataset/CREMI') parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS') parser.add_argument('--sup_mark', default='100', help='20, 100') parser.add_argument('-b', '--batch_size', default=32, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) parser.add_argument('-l', '--lr', default=0.5, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('--loss', default='dice', type=str) parser.add_argument('-w', '--warm_up_duration', default=20) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('-i', '--display_iter', default=5, type=int) parser.add_argument('-n', '--network', default='alnet', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3') parser.add_argument('-v', '--vis', default=True, help='need visualization or not') parser.add_argument('--visdom_port', default=16672, help='16672') args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.cuda.device_count() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7 print_num_minus = print_num - 2 path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_trained_models = 
path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) if args.vis and rank == args.rank_index: visdom_env = str('Sup-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)) visdom = visdom_initialization_sup(env=visdom_env, port=args.visdom_port) # Dataset data_transforms = data_transform_2d() data_normalize = data_normalize_2d(cfg['MEAN'], cfg['STD']) data_normalize_l1 = data_transform_aerial_lanenet(64, 64) data_normalize_l2 = data_transform_aerial_lanenet(32, 32) data_normalize_l3 = data_transform_aerial_lanenet(16, 16) data_normalize_l4 = data_transform_aerial_lanenet(8, 8) dataset_train = imagefloder_aerial_lanenet( data_dir=args.path_dataset + '/train_sup_' + args.sup_mark, data_transform=data_transforms['train'], data_normalize=data_normalize, data_normalize_l1=data_normalize_l1, data_normalize_l2=data_normalize_l2, data_normalize_l3=data_normalize_l3, data_normalize_l4=data_normalize_l4 ) dataset_val = imagefloder_aerial_lanenet( data_dir=args.path_dataset + '/val', data_transform=data_transforms['val'], 
data_normalize=data_normalize, data_normalize_l1=data_normalize_l1, data_normalize_l2=data_normalize_l2, data_normalize_l3=data_normalize_l3, data_normalize_l4=data_normalize_l4 ) train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True) val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False) dataloaders = dict() dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler) dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler) num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])} # Model model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES']) model = model.cuda() model = DistributedDataParallel(model, device_ids=[args.local_rank]) # Training Strategy criterion = segmentation_loss(args.loss, False).cuda() optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd) exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma) scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler) # Train & Val since = time.time() count_iter = 0 best_val_eval_list = [0 for i in range(4)] for epoch in range(args.num_epochs): count_iter += 1 if (count_iter-1) % args.display_iter == 0: begin_time = time.time() dataloaders['train'].sampler.set_epoch(epoch) model.train() train_loss = 0.0 val_loss = 0.0 dist.barrier() for i, data in enumerate(dataloaders['train']): inputs_train = Variable(data['image'].cuda()) inputs_train_l1 = Variable(data['image_l1'].cuda()) inputs_train_l2 = Variable(data['image_l2'].cuda()) inputs_train_l3 = Variable(data['image_l3'].cuda()) inputs_train_l4 = Variable(data['image_l4'].cuda()) mask_train = 
Variable(data['mask'].cuda()) optimizer.zero_grad() outputs_train = model(inputs_train, inputs_train_l1, inputs_train_l2, inputs_train_l3, inputs_train_l4) torch.cuda.empty_cache() loss_train = criterion(outputs_train, mask_train) loss_train.backward() optimizer.step() train_loss += loss_train.item() if count_iter % args.display_iter == 0: if i == 0: score_list_train = outputs_train mask_list_train = mask_train else: # elif 0 < i <= num_batches['train_sup'] / 16: score_list_train = torch.cat((score_list_train, outputs_train), dim=0) mask_list_train = torch.cat((mask_list_train, mask_train), dim=0) scheduler_warmup.step() torch.cuda.empty_cache() if count_iter % args.display_iter == 0: score_gather_list_train = [torch.zeros_like(score_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_train, score_list_train) score_list_train = torch.cat(score_gather_list_train, dim=0) mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_train, mask_list_train) mask_list_train = torch.cat(mask_gather_list_train, dim=0) if rank == args.rank_index: torch.cuda.empty_cache() print('=' * print_num) print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|') train_epoch_loss = print_train_loss_sup(train_loss, num_batches, print_num, print_num_minus) train_eval_list, train_m_jc = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train, mask_list_train, print_num_minus) torch.cuda.empty_cache() with torch.no_grad(): model.eval() for i, data in enumerate(dataloaders['val']): # if 0 <= i <= num_batches['val']: inputs_val = Variable(data['image'].cuda()) inputs_val_l1 = Variable(data['image_l1'].cuda()) inputs_val_l2 = Variable(data['image_l2'].cuda()) inputs_val_l3 = Variable(data['image_l3'].cuda()) inputs_val_l4 = Variable(data['image_l4'].cuda()) mask_val = Variable(data['mask'].cuda()) name_val = data['ID'] optimizer.zero_grad() 
outputs_val = model(inputs_val, inputs_val_l1, inputs_val_l2, inputs_val_l3, inputs_val_l4) torch.cuda.empty_cache() loss_val = criterion(outputs_val, mask_val) val_loss += loss_val.item() if i == 0: score_list_val = outputs_val mask_list_val = mask_val name_list_val = name_val else: score_list_val = torch.cat((score_list_val, outputs_val), dim=0) mask_list_val = torch.cat((mask_list_val, mask_val), dim=0) name_list_val = np.append(name_list_val, name_val, axis=0) torch.cuda.empty_cache() score_gather_list_val = [torch.zeros_like(score_list_val) for _ in range(ngpus_per_node)] torch.distributed.all_gather(score_gather_list_val, score_list_val) score_list_val = torch.cat(score_gather_list_val, dim=0) mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)] torch.distributed.all_gather(mask_gather_list_val, mask_list_val) mask_list_val = torch.cat(mask_gather_list_val, dim=0) name_gather_list_val = [None for _ in range(ngpus_per_node)] torch.distributed.all_gather_object(name_gather_list_val, name_list_val) name_list_val = np.concatenate(name_gather_list_val, axis=0) torch.cuda.empty_cache() if rank == args.rank_index: val_epoch_loss = print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus) val_eval_list, val_m_jc = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val, mask_list_val, print_num_minus) best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model, score_list_val, name_list_val, val_eval_list, path_trained_models, path_seg_results, cfg['PALETTE'], args.network) torch.cuda.empty_cache() if args.vis: draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train, mask_val, outputs_train, outputs_val, train_eval_list, val_eval_list) visualization_sup(visdom, epoch+1, train_epoch_loss, train_m_jc, val_epoch_loss, val_m_jc) visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3]) print('-' * print_num) print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / 
args.display_iter).ljust(print_num_minus, ' '), '|') torch.cuda.empty_cache() torch.cuda.empty_cache() if rank == args.rank_index: time_elapsed = time.time() - since m, s = divmod(time_elapsed, 60) h, m = divmod(m, 60) print('=' * print_num) print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|') print('-' * print_num) print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus) print('=' * print_num) ================================================ FILE: train_sup_wds.py ================================================ from torchvision import transforms, datasets import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable from torch.utils.data import DataLoader import argparse import time import os import numpy as np import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel from torch.backends import cudnn import random from config.dataset_config.dataset_cfg import dataset_cfg from config.train_test_config.train_test_config import print_train_loss_sup, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup from config.visdom_config.visual_visdom import visdom_initialization_sup, visualization_sup, visual_image_sup from config.warmup_config.warmup import GradualWarmupScheduler from config.augmentation.online_aug import data_transform_2d, data_normalize_2d from loss.loss_function import segmentation_loss from models.getnetwork import get_network from dataload.dataset_2d import imagefloder_wds from warnings import simplefilter simplefilter(action='ignore', category=FutureWarning) def init_seeds(seed): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True if 
__name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--path_trained_models', default='/mnt/data1/GeYang_shared/XNet/checkpoints/sup') parser.add_argument('--path_seg_results', default='/mnt/data1/GeYang_shared/XNet/seg_pred/sup') parser.add_argument('--path_dataset', default='/mnt/data1/GeYang_shared/XNet/dataset/CREMI') parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS') parser.add_argument('--sup_mark', default='100') parser.add_argument('-b', '--batch_size', default=32, type=int) parser.add_argument('-e', '--num_epochs', default=200, type=int) parser.add_argument('-s', '--step_size', default=50, type=int) parser.add_argument('-l', '--lr', default=0.5, type=float) parser.add_argument('-g', '--gamma', default=0.5, type=float) parser.add_argument('--loss', default='dice', type=str) parser.add_argument('-w', '--warm_up_duration', default=20) parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--wd', default=-5, type=float, help='weight decay pow') parser.add_argument('-i', '--display_iter', default=5, type=int) parser.add_argument('-n', '--network', default='wds', type=str) parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3') parser.add_argument('-v', '--vis', default=True, help='need visualization or not') parser.add_argument('--visdom_port', default=16672, help='16672') args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl', init_method='env://') rank = torch.distributed.get_rank() ngpus_per_node = torch.cuda.device_count() init_seeds(rank + 1) dataset_name = args.dataset_name cfg = dataset_cfg(dataset_name) print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7 print_num_minus = print_num - 2 path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_trained_models) and rank == args.rank_index: 
os.mkdir(path_trained_models) path_trained_models = path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark) if not os.path.exists(path_trained_models) and rank == args.rank_index: os.mkdir(path_trained_models) path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1]) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark) if not os.path.exists(path_seg_results) and rank == args.rank_index: os.mkdir(path_seg_results) if args.vis and rank == args.rank_index: visdom_env = str('Sup-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)) visdom = visdom_initialization_sup(env=visdom_env, port=args.visdom_port) # Dataset data_transforms = data_transform_2d() data_normalize_LL = data_normalize_2d(cfg['MEAN_LL'], cfg['STD_LL']) data_normalize_LH = data_normalize_2d(cfg['MEAN_LH'], cfg['STD_LH']) data_normalize_HL = data_normalize_2d(cfg['MEAN_HL'], cfg['STD_HL']) data_normalize_HH = data_normalize_2d(cfg['MEAN_HH'], cfg['STD_HH']) dataset_train = imagefloder_wds( data_dir=args.path_dataset + '/train_sup_' + args.sup_mark, data_transform_1=data_transforms['train'], data_normalize_LL=data_normalize_LL, data_normalize_LH=data_normalize_LH, data_normalize_HL=data_normalize_HL, data_normalize_HH=data_normalize_HH ) dataset_val = imagefloder_wds( data_dir=args.path_dataset + '/val', data_transform_1=data_transforms['val'], 
        data_normalize_LL=data_normalize_LL,
        data_normalize_LH=data_normalize_LH,
        data_normalize_HL=data_normalize_HL,
        data_normalize_HH=data_normalize_HH
    )

    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True)
    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)

    dataloaders = dict()
    dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler)
    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)

    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}

    # Model
    model = get_network(args.network, 1, cfg['NUM_CLASSES'])
    model = model.cuda()
    model = DistributedDataParallel(model, device_ids=[args.local_rank])

    # Training Strategy
    criterion = segmentation_loss(args.loss, False).cuda()

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)

    # Train & Val
    since = time.time()
    count_iter = 0

    best_val_eval_list = [0 for i in range(4)]

    for epoch in range(args.num_epochs):

        count_iter += 1
        if (count_iter-1) % args.display_iter == 0:
            begin_time = time.time()

        # Reshuffle the distributed training shards differently every epoch
        dataloaders['train'].sampler.set_epoch(epoch)
        model.train()

        train_loss = 0.0
        val_loss = 0.0

        dist.barrier()

        for i, data in enumerate(dataloaders['train']):

            inputs_train_LL = Variable(data['image_LL'].cuda())
            inputs_train_LH = Variable(data['image_LH'].cuda())
            inputs_train_HL = Variable(data['image_HL'].cuda())
            inputs_train_HH = Variable(data['image_HH'].cuda())
            mask_train = Variable(data['mask'].cuda())

            optimizer.zero_grad()
            outputs_train = model(inputs_train_LL, inputs_train_LH, inputs_train_HL, inputs_train_HH)
            torch.cuda.empty_cache()

            loss_train = criterion(outputs_train, mask_train)

            loss_train.backward()
            optimizer.step()
            train_loss += loss_train.item()

            # Accumulate predictions and masks only on epochs that are displayed
            if count_iter % args.display_iter == 0:
                if i == 0:
                    score_list_train = outputs_train
                    mask_list_train = mask_train
                else:
                    # elif 0 < i <= num_batches['train_sup'] / 16:
                    score_list_train = torch.cat((score_list_train, outputs_train), dim=0)
                    mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)

        scheduler_warmup.step()
        torch.cuda.empty_cache()

        if count_iter % args.display_iter == 0:

            # Gather training predictions and masks from every rank
            score_gather_list_train = [torch.zeros_like(score_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(score_gather_list_train, score_list_train)
            score_list_train = torch.cat(score_gather_list_train, dim=0)

            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]
            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)
            mask_list_train = torch.cat(mask_gather_list_train, dim=0)

            if rank == args.rank_index:
                torch.cuda.empty_cache()
                print('=' * print_num)
                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
                train_epoch_loss = print_train_loss_sup(train_loss, num_batches, print_num, print_num_minus)
                train_eval_list, train_m_jc = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train, mask_list_train, print_num_minus)
                torch.cuda.empty_cache()

            with torch.no_grad():
                model.eval()

                for i, data in enumerate(dataloaders['val']):

                    # if 0 <= i <= num_batches['val']:

                    inputs_val_LL = Variable(data['image_LL'].cuda())
                    inputs_val_LH = Variable(data['image_LH'].cuda())
                    inputs_val_HL = Variable(data['image_HL'].cuda())
                    inputs_val_HH = Variable(data['image_HH'].cuda())
                    mask_val = Variable(data['mask'].cuda())
                    name_val = data['ID']

                    optimizer.zero_grad()
                    # Pass the four sub-bands in the same order as in training (LL, LH, HL, HH)
                    outputs_val = model(inputs_val_LL, inputs_val_LH, inputs_val_HL, inputs_val_HH)
                    torch.cuda.empty_cache()

                    loss_val = criterion(outputs_val, mask_val)
                    val_loss += loss_val.item()

                    if i == 0:
                        score_list_val = outputs_val
                        mask_list_val = mask_val
                        name_list_val = name_val
                    else:
                        score_list_val = torch.cat((score_list_val, outputs_val), dim=0)
                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
                        name_list_val = np.append(name_list_val, name_val, axis=0)

                torch.cuda.empty_cache()

                # Gather validation predictions, masks and sample IDs from every rank
                score_gather_list_val = [torch.zeros_like(score_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(score_gather_list_val, score_list_val)
                score_list_val = torch.cat(score_gather_list_val, dim=0)

                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]
                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)
                mask_list_val = torch.cat(mask_gather_list_val, dim=0)

                name_gather_list_val = [None for _ in range(ngpus_per_node)]
                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)
                name_list_val = np.concatenate(name_gather_list_val, axis=0)
                torch.cuda.empty_cache()

                if rank == args.rank_index:
                    val_epoch_loss = print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus)
                    val_eval_list, val_m_jc = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val, mask_list_val, print_num_minus)
                    best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model, score_list_val, name_list_val, val_eval_list, path_trained_models, path_seg_results, cfg['PALETTE'], args.network)
                    torch.cuda.empty_cache()

                    if args.vis:
                        draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train, mask_val, outputs_train, outputs_val, train_eval_list, val_eval_list)
                        visualization_sup(visdom, epoch+1, train_epoch_loss, train_m_jc, val_epoch_loss, val_m_jc)
                        visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3])

                    print('-' * print_num)
                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')
                    torch.cuda.empty_cache()

            torch.cuda.empty_cache()

    if rank == args.rank_index:
        time_elapsed = time.time() - since
        m, s = divmod(time_elapsed, 60)
        h, m = divmod(m, 60)
        print('=' * print_num)
        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')
        print('-' * print_num)
        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)
        print('=' * print_num)
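The model above consumes four wavelet sub-bands (`image_LL`, `image_LH`, `image_HL`, `image_HH`), and they must be passed in the same order at training and validation time. The following is a minimal sketch of a single-level 2D Haar decomposition that illustrates how the four bands differ. It assumes one common orthonormal Haar convention. The repo's own `tools/wavelet2D.py` may use a different wavelet, normalization or LH/HL naming, and `haar_dwt2`/`haar_idwt2` are hypothetical helpers, not functions from the repo.

```python
import numpy as np


def haar_dwt2(image: np.ndarray):
    """Hypothetical single-level orthonormal 2D Haar DWT of an (H, W) array.

    Returns (LL, LH, HL, HH), each of shape (H // 2, W // 2). Convention assumed here:
    LH = horizontal low-pass / vertical high-pass, HL = the reverse, HH = diagonal.
    """
    if image.ndim != 2 or image.shape[0] % 2 or image.shape[1] % 2:
        raise ValueError("expected a 2D array with even height and width")
    a = image[0::2, 0::2]  # top-left pixel of each 2x2 block
    b = image[0::2, 1::2]  # top-right
    c = image[1::2, 0::2]  # bottom-left
    d = image[1::2, 1::2]  # bottom-right
    LL = (a + b + c + d) / 2  # local average (approximation)
    LH = (a + b - c - d) / 2  # top minus bottom: responds to horizontal edges
    HL = (a - b + c - d) / 2  # left minus right: responds to vertical edges
    HH = (a - b - c + d) / 2  # diagonal detail
    return LL, LH, HL, HH


def haar_idwt2(LL, LH, HL, HH):
    """Inverse of haar_dwt2: rebuilds the (2h, 2w) array from the four bands."""
    h, w = LL.shape
    out = np.empty((2 * h, 2 * w), dtype=np.result_type(LL, LH, HL, HH))
    out[0::2, 0::2] = (LL + LH + HL + HH) / 2
    out[0::2, 1::2] = (LL + LH - HL - HH) / 2
    out[1::2, 0::2] = (LL - LH + HL - HH) / 2
    out[1::2, 1::2] = (LL - LH - HL + HH) / 2
    return out


if __name__ == "__main__":
    stripes = np.zeros((4, 4))
    stripes[:, ::2] = 1.0  # vertical stripes
    LL, LH, HL, HH = haar_dwt2(stripes)
    print(LL)  # all ones: the approximation band
    print(LH)  # all zeros: no horizontal edges
    print(HL)  # all ones: the stripes are vertical edges
```

On this vertical-stripe image, LL and HL are all ones while LH and HH are all zeros. Feeding LH into the LL slot would therefore give the network an empty approximation band, which is why the argument order at validation time matters.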
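The validation metrics above are computed on predictions gathered from all ranks with `all_gather`, followed by `torch.cat(..., dim=0)`. PyTorch's `DistributedSampler` (with `shuffle=False` and the default `drop_last=False`) pads the index list by wrapping around. Every rank then receives `ceil(N / world_size)` samples, and each rank takes `indices[rank::world_size]`. The gathered set is therefore concatenated rank by rank and can contain up to `world_size - 1` duplicated samples. The script sizes its gather lists with `ngpus_per_node`, which matches the world size only for single-node runs. The pure-Python sketch below reproduces that index assignment so the effect can be checked without GPUs. `distributed_sampler_indices` and `gathered_order` are hypothetical helpers, not part of PyTorch or the repo.

```python
import math
from typing import List


def distributed_sampler_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Hypothetical re-implementation of DistributedSampler(shuffle=False, drop_last=False)."""
    if dataset_len <= 0:
        raise ValueError("dataset_len must be positive")
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    padding = total_size - len(indices)
    # Wrap around so every rank gets exactly num_samples indices
    indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Interleaved assignment: rank r takes every num_replicas-th index starting at r
    return indices[rank:total_size:num_replicas]


def gathered_order(dataset_len: int, num_replicas: int) -> List[int]:
    """Sample order after all_gather + torch.cat(dim=0): rank 0's samples, then rank 1's, ..."""
    order: List[int] = []
    for rank in range(num_replicas):
        order.extend(distributed_sampler_indices(dataset_len, num_replicas, rank))
    return order


if __name__ == "__main__":
    order = gathered_order(10, 4)
    print(order)
    print(len(order) - len(set(order)), "duplicated samples")
```

For 10 validation samples on 4 ranks, 12 predictions are gathered and samples 0 and 1 are counted twice. With realistic validation set sizes the distortion is small, but it explains why the gathered count can exceed the dataset length.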