================================================
FILE: Datasets/ISIC2018.py
================================================
import os
import PIL
import torch
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from os import listdir
from os.path import join
from PIL import Image
from utils.transform import itensity_normalize
from torch.utils.data.dataset import Dataset
class ISIC2018_dataset(Dataset):
    """ISIC 2018 Task 1 segmentation dataset backed by pre-converted ``.npy`` files.

    The split is defined by a list file at
    ``./Datasets/<folder>/<folder>_<train_type>.list`` (relative to the current
    working directory) containing one image file name per line.  For every listed
    name ``<stem>.npy`` the image is read from ``<dataset_folder>/image/<stem>.npy``
    and the mask from ``<dataset_folder>/label/<stem>_segmentation.npy``.

    Args:
        dataset_folder: root directory holding the ``image`` and ``label`` folders.
        folder: cross-validation fold directory name under ``./Datasets``.
        train_type: one of ``'train'``, ``'validation'`` or ``'test'``.
        transform: optional callable ``transform(sample, train_type)`` that takes
            and returns a ``{'image': ..., 'label': ...}`` dict.

    Raises:
        ValueError: if ``train_type`` is not a supported split name.
    """

    _SPLITS = ('train', 'validation', 'test')

    def __init__(self, dataset_folder='/ISIC2018_Task1_npy_all',
                 folder='folder0', train_type='train', transform=None):
        self.transform = transform
        self.train_type = train_type
        self.folder_file = './Datasets/' + folder
        if self.train_type not in self._SPLITS:
            # Previously this only printed the message and then crashed with an
            # AttributeError on ``self.folder``; fail loudly with a clear error instead.
            raise ValueError("Choosing type error, You have to choose the loading data type including: "
                             "train, validation, test")

        # this is for cross validation: the list file is named after the fold folder
        list_name = self.folder_file.split('/')[-1] + '_' + self.train_type + '.list'
        with open(join(self.folder_file, list_name), 'r') as f:
            # strip surrounding whitespace and ignore blank lines, which would
            # otherwise become paths to the image/label directories themselves
            self.image_list = [line.strip() for line in f if line.strip()]

        self.folder = [join(dataset_folder, 'image', x) for x in self.image_list]
        self.mask = [join(dataset_folder, 'label', x.split('.')[0] + '_segmentation.npy')
                     for x in self.image_list]
        assert len(self.folder) == len(self.mask)

    def __getitem__(self, item: int):
        """Load sample ``item`` and return ``(image, label)`` after the optional transform."""
        image = np.load(self.folder[item])
        label = np.load(self.mask[item])
        sample = {'image': image, 'label': label}
        if self.transform is not None:
            # TODO: add data-augmentation transforms
            sample = self.transform(sample, self.train_type)
        return sample['image'], sample['label']

    def __len__(self):
        """Number of samples listed in the split file."""
        return len(self.folder)
# a = ISIC2018_dataset()
================================================
FILE: Datasets/folder0/folder0_test.list
================================================
ISIC_0010854.npy
================================================
FILE: Models/__init__.py
================================================
================================================
FILE: Models/layers/__init__.py
================================================
================================================
FILE: Models/layers/channel_attention_layer.py
================================================
import torch.nn as nn
# # SE (squeeze-and-excitation) block added to U-Net
def conv3x3(in_planes, out_planes, stride=1, bias=False, group=1):
    """Build a 3x3 ``nn.Conv2d`` with padding 1 (spatial size preserved at stride 1)."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        groups=group,
        bias=bias,
    )
    return layer
class SE_Conv_Block(nn.Module):
    """Residual conv block with squeeze-and-excitation style channel attention.

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 (``planes * 2`` channels) -> BN,
    giving a feature map ``f``.  Two channel-attention vectors are computed from ``f``,
    one from global average pooling and one from global max pooling, each passed
    through the shared MLP ``fc1 -> ReLU -> fc2 -> sigmoid``.  The block output is::

        relu(bn3(conv3(relu(att_avg * f + att_max * f + residual))))

    optionally followed by 2-D dropout (p=0.5) while training.

    Args:
        inplanes: input channels.
        planes: output channels (the intermediate map ``f`` has ``planes * 2``).
        stride: stride of the first conv and of the residual projection.
        downsample: stored on the module but not used by this block.
        drop_out: apply ``dropout2d(p=0.5)`` to the output in training mode.

    ``forward`` returns ``(out, att_weight)`` where ``att_weight = att_avg + att_max``
    has shape (b, planes * 2, 1, 1).
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, drop_out=False):
        super(SE_Conv_Block, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes * 2)
        self.bn2 = nn.BatchNorm2d(planes * 2)
        self.conv3 = conv3x3(planes * 2, planes)
        self.bn3 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.dropout = drop_out

        # Global pooling to 1x1.  The former fixed kernels ((224, 300) for planes <= 16,
        # (112, 150) for 32, (56, 75) for 64, (28, 37) for 128, (14, 18) for 256 -- the
        # ISIC2018 feature-map sizes) produced this same result only for exactly those
        # input sizes and left the attributes undefined for any other ``planes``.
        self.globalAvgPool = nn.AdaptiveAvgPool2d(1)
        self.globalMaxPool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Linear(in_features=planes * 2, out_features=round(planes / 2))
        self.fc2 = nn.Linear(in_features=round(planes / 2), out_features=planes * 2)
        self.sigmoid = nn.Sigmoid()

        # The main path always ends with ``planes * 2`` channels, so the residual is
        # always projected with a 1x1 conv + BN.  It used to be skipped when
        # inplanes == planes, which left ``out + residual`` with mismatched channels
        # (planes vs planes * 2); every working configuration already built this
        # projection, so parameter names are unchanged.
        self.downchannel = nn.Sequential(
            nn.Conv2d(inplanes, planes * 2, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(planes * 2),
        )

    def _channel_attention(self, pooled):
        """Map pooled (b, c, 1, 1) descriptors to (b, c, 1, 1) sigmoid gates via the shared MLP."""
        weights = pooled.view(pooled.size(0), -1)
        weights = self.relu(self.fc1(weights))
        weights = self.sigmoid(self.fc2(weights))
        return weights.view(weights.size(0), weights.size(1), 1, 1)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        features = self.bn2(self.conv2(out))
        residual = self.downchannel(x) if self.downchannel is not None else x

        # channel attention from global average pooling and from global max pooling
        avg_att = self._channel_attention(self.globalAvgPool(features))
        max_att = self._channel_attention(self.globalMaxPool(features))
        att_weight = avg_att + max_att

        out = avg_att * features
        out = out + max_att * features
        out = out + residual
        out = self.relu(out)

        out = self.relu(self.bn3(self.conv3(out)))
        if self.dropout:
            # Functional dropout honours train()/eval(); the previous freshly built
            # nn.Dropout2d was always in training mode and also dropped at inference.
            out = nn.functional.dropout2d(out, p=0.5, training=self.training)
        return out, att_weight
================================================
FILE: Models/layers/grid_attention_layer.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
from Models.networks_other import init_weights
class _GridAttentionBlockND(nn.Module):
    """Additive grid attention gate for 2-D or 3-D feature maps.

    The skip features ``x`` are projected by ``theta`` (strided conv with
    kernel = stride = ``sub_sample_factor``) and the gating signal ``g`` by ``phi``
    (1x1 conv).  ``phi(g)`` is resized to ``theta(x)``'s spatial size and the two
    are summed, passed through a non-linearity and ``psi`` (1x1 conv to one
    channel), and normalised into an attention map.  That map is upsampled to
    ``x``'s spatial size and multiplied into ``x``; ``W`` (1x1 conv + BN) projects
    the result.

    Modes:
        'concatenation'           ReLU, then sigmoid attention
        'concatenation_debug'     softplus, then sigmoid attention
        'concatenation_residual'  ReLU, then softmax over all positions

    ``forward(x, g)`` returns ``(W_y, attention_map)``.
    """

    def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation',
                 sub_sample_factor=(2, 2, 2)):
        super(_GridAttentionBlockND, self).__init__()

        assert dimension in [2, 3]
        assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual']

        # Downsampling rate for the input featuremap
        if isinstance(sub_sample_factor, tuple):
            self.sub_sample_factor = sub_sample_factor
        elif isinstance(sub_sample_factor, list):
            self.sub_sample_factor = tuple(sub_sample_factor)
        else:
            self.sub_sample_factor = tuple([sub_sample_factor]) * dimension

        # Default parameter set
        self.mode = mode
        self.dimension = dimension
        self.sub_sample_kernel_size = self.sub_sample_factor

        # Number of channels (pixel dimensions)
        self.in_channels = in_channels
        self.gating_channels = gating_channels
        self.inter_channels = inter_channels
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            bn = nn.BatchNorm3d
            self.upsample_mode = 'trilinear'
        elif dimension == 2:
            conv_nd = nn.Conv2d
            bn = nn.BatchNorm2d
            self.upsample_mode = 'bilinear'
        else:
            raise NotImplementedError('Only 2-D and 3-D grid attention is supported.')

        # Output transform
        self.W = nn.Sequential(
            conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0),
            bn(self.in_channels),
        )

        # Theta^T * x_ij + Phi^T * gating_signal + bias
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor,
                             padding=0, bias=True)
        # kernel_size=1 (not the 2-tuple (1, 1)) so phi is also a valid 1x1x1 Conv3d
        self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0, bias=True)
        self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1,
                           padding=0, bias=True)

        # Initialise weights
        for m in self.children():
            init_weights(m, init_type='kaiming')

        # Define the operation
        if mode == 'concatenation':
            self.operation_function = self._concatenation
        elif mode == 'concatenation_debug':
            self.operation_function = self._concatenation_debug
        elif mode == 'concatenation_residual':
            self.operation_function = self._concatenation_residual
        else:
            raise NotImplementedError('Unknown operation function.')

    def forward(self, x, g):
        '''
        :param x: skip-connection features, (b, c, [t,] h, w)
        :param g: gating features, (b, g_c, [t',] h', w'); batch size must match x
        :return: (W_y, attention map upsampled to x's spatial size)
        '''
        output = self.operation_function(x, g)
        return output

    def _compatibility(self, x, g):
        """Return ``(theta_x, theta(x) + phi(g) resized to theta(x)'s spatial size)``."""
        assert x.size(0) == g.size(0)
        # theta => (b, c, t, h, w) -> (b, i_c, t/s1, h/s2, w/s3)
        theta_x = self.theta(x)
        # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') -> resized to theta_x
        phi_g = F.interpolate(self.phi(g), size=theta_x.size()[2:], mode=self.upsample_mode)
        return theta_x, theta_x + phi_g

    def _gate(self, x, attention):
        """Upsample ``attention`` to ``x``'s spatial size, apply it to ``x`` and project with ``W``."""
        attention = F.interpolate(attention, size=x.size()[2:], mode=self.upsample_mode)
        y = attention.expand_as(x) * x
        return self.W(y), attention

    def _concatenation(self, x, g):
        _, f = self._compatibility(x, g)
        f = F.relu(f, inplace=True)
        # psi^T * f -> (b, 1, t/s1, h/s2, w/s3)
        return self._gate(x, torch.sigmoid(self.psi(f)))

    def _concatenation_debug(self, x, g):
        _, f = self._compatibility(x, g)
        f = F.softplus(f)
        return self._gate(x, torch.sigmoid(self.psi(f)))

    def _concatenation_residual(self, x, g):
        theta_x, f = self._compatibility(x, g)
        f = F.relu(f, inplace=True)
        batch_size = x.size(0)
        # softmax of psi(f) over all spatial positions, reshaped back to theta_x's grid
        f = self.psi(f).view(batch_size, 1, -1)
        attention = F.softmax(f, dim=2).view(batch_size, 1, *theta_x.size()[2:])
        return self._gate(x, attention)
class GridAttentionBlock2D(_GridAttentionBlockND):
    """2-D grid attention gate: ``_GridAttentionBlockND`` with ``dimension=2``."""

    def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation',
                 sub_sample_factor=(2, 2)):
        options = dict(
            inter_channels=inter_channels,
            gating_channels=gating_channels,
            dimension=2,
            mode=mode,
            sub_sample_factor=sub_sample_factor,
        )
        super().__init__(in_channels, **options)
class GridAttentionBlock3D(_GridAttentionBlockND):
    """3-D grid attention gate: ``_GridAttentionBlockND`` with ``dimension=3``."""

    def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation',
                 sub_sample_factor=(2, 2, 2)):
        options = dict(
            inter_channels=inter_channels,
            gating_channels=gating_channels,
            dimension=3,
            mode=mode,
            sub_sample_factor=sub_sample_factor,
        )
        super().__init__(in_channels, **options)
class _GridAttentionBlockND_TORR(nn.Module):
    """Configurable additive grid attention gate for 2-D or 3-D feature maps.

    ``theta(x)`` and ``phi(g)`` are both strided convs (kernel = stride =
    ``sub_sample_factor``, no bias).  ``phi(g)`` is resized to ``theta(x)``'s
    spatial size and the two are summed, passed through ``nl1`` and ``psi`` (1x1
    conv to one channel), and normalised according to ``mode``.  The result is
    upsampled to ``x``'s spatial size and multiplied into ``x``; ``W`` projects it.
    Each of ``W``, ``theta``, ``phi`` and ``psi`` can be disabled, in which case it
    is the identity.

    Normalisation modes:
        'concatenation_softmax'          softmax over all positions
        'concatenation_mean'             divide by the sum over positions
        'concatenation_mean_flow'        subtract the minimum, then divide by the sum
        'concatenation_range_normalise'  min-max scaling
        'concatenation_sigmoid'          element-wise sigmoid

    NOTE(review): ``mode='concatenation'`` (the default) passes the constructor's
    validation but has no normalisation branch, so ``forward`` raises
    NotImplementedError -- pass one of the modes above.

    ``forward(x, g)`` returns ``(W_y, attention_map)``.
    """

    def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation',
                 sub_sample_factor=(1, 1, 1), bn_layer=True, use_W=True, use_phi=True, use_theta=True,
                 use_psi=True, nonlinearity1='relu'):
        super(_GridAttentionBlockND_TORR, self).__init__()

        assert dimension in [2, 3]
        assert mode in ['concatenation', 'concatenation_softmax',
                        'concatenation_sigmoid', 'concatenation_mean',
                        'concatenation_range_normalise', 'concatenation_mean_flow']

        # Default parameter set
        self.mode = mode
        self.dimension = dimension
        # lists are accepted too, consistent with _GridAttentionBlockND
        if isinstance(sub_sample_factor, tuple):
            self.sub_sample_factor = sub_sample_factor
        elif isinstance(sub_sample_factor, list):
            self.sub_sample_factor = tuple(sub_sample_factor)
        else:
            self.sub_sample_factor = tuple([sub_sample_factor]) * dimension
        self.sub_sample_kernel_size = self.sub_sample_factor

        # Number of channels (pixel dimensions)
        self.in_channels = in_channels
        self.gating_channels = gating_channels
        self.inter_channels = inter_channels
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            bn = nn.BatchNorm3d
            self.upsample_mode = 'trilinear'
        elif dimension == 2:
            conv_nd = nn.Conv2d
            bn = nn.BatchNorm2d
            self.upsample_mode = 'bilinear'
        else:
            raise NotImplementedError('Only 2-D and 3-D grid attention is supported.')

        # Identity placeholders (plain attributes, no parameters); replaced below
        # by learned layers when the corresponding use_* flag is set.
        # Theta^T * x_ij + Phi^T * gating_signal + bias
        self.W = lambda x: x
        self.theta = lambda x: x
        self.psi = lambda x: x
        self.phi = lambda x: x
        self.nl1 = lambda x: x

        if use_W:
            if bn_layer:
                self.W = nn.Sequential(
                    conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1,
                            padding=0),
                    bn(self.in_channels),
                )
            else:
                self.W = conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1,
                                 stride=1, padding=0)

        if use_theta:
            self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                                 kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor,
                                 padding=0, bias=False)

        if use_phi:
            self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels,
                               kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor,
                               padding=0, bias=False)

        if use_psi:
            self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1,
                               padding=0, bias=True)

        # Only 'relu' is recognised; any other value keeps nl1 as the identity.
        if nonlinearity1 == 'relu':
            self.nl1 = lambda x: F.relu(x, inplace=True)

        if 'concatenation' in mode:
            self.operation_function = self._concatenation
        else:
            raise NotImplementedError('Unknown operation function.')

        # Initialise weights
        for m in self.children():
            init_weights(m, init_type='kaiming')

        # Mode-specific initial bias for psi
        if use_psi and self.mode == 'concatenation_sigmoid':
            nn.init.constant_(self.psi.bias.data, 3.0)
        if use_psi and self.mode == 'concatenation_softmax':
            nn.init.constant_(self.psi.bias.data, 10.0)

    def forward(self, x, g):
        '''
        :param x: skip-connection features, (b, c, [t,] h, w)
        :param g: gating features; batch size must match x
        :return: (W_y, attention map upsampled to x's spatial size)
        '''
        output = self.operation_function(x, g)
        return output

    def _concatenation(self, x, g):
        input_size = x.size()
        batch_size = input_size[0]
        assert batch_size == g.size(0)

        #############################
        # compute compatibility score
        # theta => (b, c, t, h, w) -> (b, i_c, t, h, w)
        # phi   => (b, c, t, h, w) -> (b, i_c, t, h, w)
        theta_x = self.theta(x)
        theta_x_size = theta_x.size()

        # nl(theta.x + phi.g + bias) -> f = (b, i_c, t/s1, h/s2, w/s3)
        phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode)
        f = self.nl1(theta_x + phi_g)
        psi_f = self.psi(f)

        ############################################
        # normalisation -- scale compatibility score
        # psi^T . f -> (b, 1, t/s1, h/s2, w/s3)
        if self.mode == 'concatenation_softmax':
            sigm_psi_f = F.softmax(psi_f.view(batch_size, 1, -1), dim=2)
            sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:])
        elif self.mode == 'concatenation_mean':
            psi_f_flat = psi_f.view(batch_size, 1, -1)
            psi_f_sum = torch.sum(psi_f_flat, dim=2)  # NOTE(review): not clamped; zero sum divides by 0
            psi_f_sum = psi_f_sum[:, :, None].expand_as(psi_f_flat)
            sigm_psi_f = psi_f_flat / psi_f_sum
            sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:])
        elif self.mode == 'concatenation_mean_flow':
            psi_f_flat = psi_f.view(batch_size, 1, -1)
            ss = psi_f_flat.shape
            psi_f_min = psi_f_flat.min(dim=2)[0].view(ss[0], ss[1], 1)
            psi_f_flat = psi_f_flat - psi_f_min
            psi_f_sum = torch.sum(psi_f_flat, dim=2).view(ss[0], ss[1], 1).expand_as(psi_f_flat)
            sigm_psi_f = psi_f_flat / psi_f_sum
            sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:])
        elif self.mode == 'concatenation_range_normalise':
            psi_f_flat = psi_f.view(batch_size, 1, -1)
            ss = psi_f_flat.shape
            psi_f_max = torch.max(psi_f_flat, dim=2)[0].view(ss[0], ss[1], 1)
            psi_f_min = torch.min(psi_f_flat, dim=2)[0].view(ss[0], ss[1], 1)
            sigm_psi_f = (psi_f_flat - psi_f_min) / (psi_f_max - psi_f_min).expand_as(psi_f_flat)
            sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:])
        elif self.mode == 'concatenation_sigmoid':
            sigm_psi_f = torch.sigmoid(psi_f)
        else:
            raise NotImplementedError('Unsupported attention normalisation mode: {}'.format(self.mode))

        # sigm_psi_f is the attention map: upsample it and multiply
        sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode)
        y = sigm_psi_f.expand_as(x) * x
        W_y = self.W(y)

        return W_y, sigm_psi_f
class GridAttentionBlock2D_TORR(_GridAttentionBlockND_TORR):
    """2-D ``_GridAttentionBlockND_TORR``; all options are forwarded with ``dimension=2``."""

    def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation',
                 sub_sample_factor=(1, 1), bn_layer=True,
                 use_W=True, use_phi=True, use_theta=True, use_psi=True,
                 nonlinearity1='relu'):
        options = dict(
            inter_channels=inter_channels,
            gating_channels=gating_channels,
            dimension=2,
            mode=mode,
            sub_sample_factor=sub_sample_factor,
            bn_layer=bn_layer,
            use_W=use_W,
            use_phi=use_phi,
            use_theta=use_theta,
            use_psi=use_psi,
            nonlinearity1=nonlinearity1,
        )
        super().__init__(in_channels, **options)
class GridAttentionBlock3D_TORR(_GridAttentionBlockND_TORR):
    """3-D ``_GridAttentionBlockND_TORR``.

    Accepts the same optional switches as the 2-D variant (``use_W``, ``use_phi``,
    ``use_theta``, ``use_psi``, ``nonlinearity1``); their defaults equal the base
    class defaults, so existing callers are unaffected.
    """

    def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation',
                 sub_sample_factor=(1, 1, 1), bn_layer=True,
                 use_W=True, use_phi=True, use_theta=True, use_psi=True,
                 nonlinearity1='relu'):
        super(GridAttentionBlock3D_TORR, self).__init__(in_channels,
                                                        inter_channels=inter_channels,
                                                        gating_channels=gating_channels,
                                                        dimension=3, mode=mode,
                                                        sub_sample_factor=sub_sample_factor,
                                                        bn_layer=bn_layer,
                                                        use_W=use_W,
                                                        use_phi=use_phi,
                                                        use_theta=use_theta,
                                                        use_psi=use_psi,
                                                        nonlinearity1=nonlinearity1)
class MultiAttentionBlock(nn.Module):
    """Two parallel 2-D grid attention gates whose outputs are fused by conv1x1 -> BN -> ReLU.

    ``forward(input, gating_signal)`` returns the fused gated features and the two
    attention maps concatenated along dim 1.
    """

    def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor):
        super(MultiAttentionBlock, self).__init__()
        gate_options = dict(
            in_channels=in_size,
            gating_channels=gate_size,
            inter_channels=inter_size,
            mode=nonlocal_mode,
            sub_sample_factor=sub_sample_factor,
        )
        self.gate_block_1 = GridAttentionBlock2D(**gate_options)
        self.gate_block_2 = GridAttentionBlock2D(**gate_options)
        self.combine_gates = nn.Sequential(
            nn.Conv2d(in_size * 2, in_size, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(in_size),
            nn.ReLU(inplace=True),
        )

        # The gate blocks initialise themselves; only the remaining children are initialised here.
        for child in self.children():
            if 'GridAttentionBlock2D' in child.__class__.__name__:
                continue
            init_weights(child, init_type='kaiming')

    def forward(self, input, gating_signal):
        gated_1, att_1 = self.gate_block_1(input, gating_signal)
        gated_2, att_2 = self.gate_block_2(input, gating_signal)
        fused = self.combine_gates(torch.cat([gated_1, gated_2], 1))
        attentions = torch.cat([att_1, att_2], 1)
        return fused, attentions
if __name__ == '__main__':
    # Smoke test: run a 3-D grid attention gate on random tensors and print the output shape.
    # (torch.autograd.Variable is deprecated; plain tensors are used directly.)
    mode_list = ['concatenation']

    for mode in mode_list:
        img = torch.rand(2, 16, 10, 10, 10)
        gat = torch.rand(2, 64, 4, 4, 4)
        net = GridAttentionBlock3D(in_channels=16, inter_channels=16, gating_channels=64, mode=mode,
                                   sub_sample_factor=(2, 2, 2))
        out, sigma = net(img, gat)
        print(out.size())
================================================
FILE: Models/layers/modules.py
================================================
import torch
import torch.nn as nn
def conv1x1(in_planes, out_planes, stride=1, bias=False):
    """Return an unpadded 1x1 ``nn.Conv2d`` mapping ``in_planes`` to ``out_planes`` channels."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        padding=0,
        bias=bias,
    )
def conv3x3(in_planes, out_planes, stride=1, bias=False, group=1):
    """Build a 3x3 ``nn.Conv2d`` with padding 1 (spatial size preserved at stride 1)."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        groups=group,
        bias=bias,
    )
    return layer
# conv_block(nn.Module) for U-net convolution block
class conv_block(nn.Module):
    """U-Net double convolution: (3x3 conv -> BatchNorm -> ReLU) x 2.

    Args:
        ch_in: input channels.
        ch_out: output channels of both convolutions.
        drop_out: if True, apply channel-wise 2-D dropout (p=0.5) to the output,
            only while the module is in training mode.
    """

    def __init__(self, ch_in, ch_out, drop_out=False):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )
        self.dropout = drop_out

    def forward(self, x):
        x = self.conv(x)
        if self.dropout:
            # Functional dropout honours train()/eval(); the previous freshly built
            # nn.Dropout2d was always in training mode and also dropped at inference.
            x = nn.functional.dropout2d(x, p=0.5, training=self.training)
        return x
# # UpCat(nn.Module) for U-net UP convolution
class UpCat(nn.Module):
    """Upsample the deeper feature map and concatenate it after the skip features.

    Args:
        in_feat: channels of ``down_outputs``.
        out_feat: channels produced by the 2x2 stride-2 transposed conv; unused when
            ``is_deconv`` is False (bilinear x2 upsampling keeps the channel count).
        is_deconv: use the transposed conv instead of bilinear upsampling.
    """

    def __init__(self, in_feat, out_feat, is_deconv=True):
        super(UpCat, self).__init__()
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, inputs, down_outputs):
        """Return ``cat([inputs, up(down_outputs)], dim=1)``.

        If the upsampled map is narrower (dim 3) than ``inputs``, it is first padded
        on the right with uniform random values in [0, 1).  Height is not adjusted.
        """
        # TODO: Upsampling required after deconv?
        outputs = self.up(down_outputs)
        offset = inputs.size(3) - outputs.size(3)
        if offset > 0:
            # Create the padding on the activations' device/dtype instead of calling
            # .cuda(), which failed on CPU and on non-default GPUs.
            addition = torch.rand(outputs.size(0), outputs.size(1), outputs.size(2), offset,
                                  device=outputs.device, dtype=outputs.dtype)
            outputs = torch.cat([outputs, addition], dim=3)
        return torch.cat([inputs, outputs], dim=1)
# # UpCatconv(nn.Module) for up convolution
class UpCatconv(nn.Module):
    """Upsample, concatenate with the skip features, then apply a ``conv_block``.

    Args:
        in_feat: channels of ``down_outputs``.
        out_feat: output channels of the block (and of the transposed conv).
        is_deconv: use a 2x2 stride-2 transposed conv; otherwise bilinear x2
            upsampling, with the conv block expecting ``in_feat + out_feat`` inputs.
        drop_out: forwarded to ``conv_block``.
    """

    def __init__(self, in_feat, out_feat, is_deconv=True, drop_out=False):
        super(UpCatconv, self).__init__()
        if is_deconv:
            self.conv = conv_block(in_feat, out_feat, drop_out=drop_out)
            self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2)
        else:
            self.conv = conv_block(in_feat + out_feat, out_feat, drop_out=drop_out)
            self.up = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, inputs, down_outputs):
        """Return ``conv(cat([inputs, up(down_outputs)], dim=1))``.

        If the upsampled map is narrower (dim 3) than ``inputs``, it is first padded
        on the right with uniform random values in [0, 1).  Height is not adjusted.
        """
        # TODO: Upsampling required after deconv
        outputs = self.up(down_outputs)
        offset = inputs.size(3) - outputs.size(3)
        if offset > 0:
            # Create the padding on the activations' device/dtype instead of calling
            # .cuda(), which failed on CPU and on non-default GPUs.
            addition = torch.rand(outputs.size(0), outputs.size(1), outputs.size(2), offset,
                                  device=outputs.device, dtype=outputs.dtype)
            outputs = torch.cat([outputs, addition], dim=3)
        return self.conv(torch.cat([inputs, outputs], dim=1))
# # UnetGridGatingSignal3(nn.Module)
class UnetGridGatingSignal3(nn.Module):
    """Gating-signal projection: conv (stride 1, no padding) -> [BatchNorm] -> ReLU."""

    def __init__(self, in_size, out_size, kernel_size=(1, 1), is_batchnorm=True):
        super(UnetGridGatingSignal3, self).__init__()
        stages = [nn.Conv2d(in_size, out_size, kernel_size, (1, 1), (0, 0))]
        if is_batchnorm:
            stages.append(nn.BatchNorm2d(out_size))
        stages.append(nn.ReLU(inplace=True))
        self.conv1 = nn.Sequential(*stages)

    def forward(self, inputs):
        return self.conv1(inputs)
class UnetDsv3(nn.Module):
    """Deep-supervision head: 1x1 conv to ``out_size`` channels, then bilinear resize.

    Despite its name, ``scale_factor`` is passed to ``nn.Upsample`` as the target ``size``.
    """

    def __init__(self, in_size, out_size, scale_factor):
        super(UnetDsv3, self).__init__()
        projection = nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0)
        resize = nn.Upsample(size=scale_factor, mode='bilinear')
        self.dsv = nn.Sequential(projection, resize)

    def forward(self, input):
        return self.dsv(input)
================================================
FILE: Models/layers/nonlocal_layer.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
from Models.networks_other import init_weights
class _NonLocalBlockND(nn.Module):
    """Non-local block for 1-D, 2-D or 3-D feature maps.

    Computes ``z = W(y) + x``, where ``y`` aggregates ``g(x)`` over positions,
    weighted by a pairwise relation chosen by ``mode``:
        'embedded_gaussian'   softmax(theta(x)^T phi(x))
        'gaussian'            softmax(x^T x), with x max-pooled on the phi side when subsampling
        'dot_product'         theta(x)^T phi(x) / N
        'concatenation'       relu(w_theta . theta(x) + w_phi . phi(x)) / N
        'concat_proper'       softmax over phi positions of psi(relu(theta(x) + phi(x)))
        'concat_proper_down'  as 'concat_proper' with theta also pooled; y is upsampled back

    Args:
        in_channels: channels of the input ``x``.
        inter_channels: embedding channels; defaults to ``in_channels // 2`` (at least 1).
        dimension: 1, 2 or 3 spatial dimensions.
        mode: relation type (see above).
        sub_sample_factor: int or list; when any entry is > 1, ``g`` and ``phi`` (and
            ``theta`` for 'concat_proper_down') are followed by max pooling with it.
        bn_layer: if True, ``W`` is a 1x1 conv followed by BatchNorm, else a single conv.
    """

    def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian',
                 sub_sample_factor=4, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]
        assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation', 'concat_proper',
                        'concat_proper_down']

        self.mode = mode
        self.dimension = dimension
        self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, list) else [sub_sample_factor]

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool = nn.MaxPool3d
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool = nn.MaxPool2d
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool = nn.MaxPool1d
            bn = nn.BatchNorm1d
        # interpolation mode matching the spatial rank ('concat_proper_down' used to hard-code 'trilinear')
        self.upsample_mode = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}[dimension]

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        # W is zero-initialised here so W(y) starts at 0.
        # NOTE(review): the init_weights() loop below also visits W and may overwrite this -- confirm.
        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = None
        self.phi = None

        if mode in ['embedded_gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']:
            self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                                 kernel_size=1, stride=1, padding=0)
            self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)

            if mode in ['concatenation']:
                self.wf_phi = nn.Linear(self.inter_channels, 1, bias=False)
                self.wf_theta = nn.Linear(self.inter_channels, 1, bias=False)
            elif mode in ['concat_proper', 'concat_proper_down']:
                # Conv2d regardless of `dimension`: it is applied to the 4-D pairwise tensor f
                self.psi = nn.Conv2d(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1,
                                     padding=0, bias=True)

        if mode == 'embedded_gaussian':
            self.operation_function = self._embedded_gaussian
        elif mode == 'dot_product':
            self.operation_function = self._dot_product
        elif mode == 'gaussian':
            self.operation_function = self._gaussian
        elif mode == 'concatenation':
            self.operation_function = self._concatenation
        elif mode == 'concat_proper':
            self.operation_function = self._concatenation_proper
        elif mode == 'concat_proper_down':
            self.operation_function = self._concatenation_proper_down
        else:
            raise NotImplementedError('Unknown operation function.')

        if any(ss > 1 for ss in self.sub_sample_factor):
            self.g = nn.Sequential(self.g, max_pool(kernel_size=sub_sample_factor))
            if self.phi is None:
                self.phi = max_pool(kernel_size=sub_sample_factor)
            else:
                self.phi = nn.Sequential(self.phi, max_pool(kernel_size=sub_sample_factor))
            if mode == 'concat_proper_down':
                self.theta = nn.Sequential(self.theta, max_pool(kernel_size=sub_sample_factor))

        # Initialise weights
        for m in self.children():
            init_weights(m, init_type='kaiming')

    def forward(self, x):
        '''
        :param x: (b, c, [t,] [h,] w) -- rank set by `dimension`
        :return: z = W(y) + x, same shape as x
        '''
        output = self.operation_function(x)
        return output

    def _embedded_gaussian(self, x):
        batch_size = x.size(0)

        # g=>(b, c, t, h, w)->(b, 0.5c, t, h, w)->(b, thw, 0.5c)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c)
        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)
        # f=>(b, thw, 0.5c)dot(b, 0.5c, twh) = (b, thw, thw)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        # (b, thw, thw)dot(b, thw, 0.5c) = (b, thw, 0.5c)->(b, 0.5c, t, h, w)->(b, c, t, h, w)
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _gaussian(self, x):
        batch_size = x.size(0)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = x.view(batch_size, self.in_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        # sub_sample_factor is a list; comparing it directly with an int raised TypeError
        if any(ss > 1 for ss in self.sub_sample_factor):
            phi_x = self.phi(x).view(batch_size, self.in_channels, -1)
        else:
            phi_x = x.view(batch_size, self.in_channels, -1)

        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _dot_product(self, x):
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        N = f.size(-1)
        f_div_C = f / N

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _concatenation(self, x):
        batch_size = x.size(0)

        # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)

        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c)
        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw/s**2, 0.5c)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1)

        # theta => (b, thw, 0.5c) -> (b, thw, 1) -> (b, 1, thw) -> (expand) (b, thw/s**2, thw)
        # phi => (b, thw/s**2, 0.5c) -> (b, thw/s**2, 1) -> (expand) (b, thw/s**2, thw)
        # f=> RELU[(b, thw/s**2, thw) + (b, thw/s**2, thw)] = (b, thw/s**2, thw)
        f = self.wf_theta(theta_x).permute(0, 2, 1).repeat(1, phi_x.size(1), 1) + \
            self.wf_phi(phi_x).repeat(1, 1, theta_x.size(1))
        f = F.relu(f, inplace=True)

        # Normalise the relations
        N = f.size(-1)
        f_div_c = f / N

        # g(x_j) * f(x_j, x_i)
        # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)
        y = torch.matmul(g_x, f_div_c)
        y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _concatenation_proper(self, x):
        batch_size = x.size(0)

        # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)

        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)
        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw)
        # phi => (b, 0.5c, thw/s**2) -> (expand) (b, 0.5c, thw/s**2, thw)
        # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw)
        f = theta_x.unsqueeze(dim=2).repeat(1, 1, phi_x.size(2), 1) + \
            phi_x.unsqueeze(dim=3).repeat(1, 1, 1, theta_x.size(2))
        f = F.relu(f, inplace=True)

        # psi -> W_psi^t * f -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw)
        f = torch.squeeze(self.psi(f), dim=1)

        # Normalise the relations
        f_div_c = F.softmax(f, dim=1)

        # g(x_j) * f(x_j, x_i)
        # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)
        y = torch.matmul(g_x, f_div_c)
        y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _concatenation_proper_down(self, x):
        batch_size = x.size(0)

        # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)

        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)
        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2)
        theta_x = self.theta(x)
        downsampled_size = theta_x.size()
        theta_x = theta_x.view(batch_size, self.inter_channels, -1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw)
        # phi => (b, 0.5, thw/s**2) -> (expand) (b, 0.5c, thw/s**2, thw)
        # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw)
        f = theta_x.unsqueeze(dim=2).repeat(1, 1, phi_x.size(2), 1) + \
            phi_x.unsqueeze(dim=3).repeat(1, 1, 1, theta_x.size(2))
        f = F.relu(f, inplace=True)

        # psi -> W_psi^t * f -> (b, 0.5c, thw/s**2, thw) -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw)
        f = torch.squeeze(self.psi(f), dim=1)

        # Normalise the relations
        f_div_c = F.softmax(f, dim=1)

        # g(x_j) * f(x_j, x_i)
        # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)
        y = torch.matmul(g_x, f_div_c)
        y = y.contiguous().view(batch_size, self.inter_channels, *downsampled_size[2:])

        # upsample the final featuremaps back to x's spatial size (mode matches `dimension`)
        y = F.interpolate(y, size=x.size()[2:], mode=self.upsample_mode)

        # attention block output
        W_y = self.W(y)
        z = W_y + x

        return z
class NONLocalBlock1D(_NonLocalBlockND):
    """Non-local block over one spatial dimension (``_NonLocalBlockND`` with ``dimension=1``)."""

    def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=1,
            mode=mode,
            sub_sample_factor=sub_sample_factor,
            bn_layer=bn_layer,
        )
class NONLocalBlock2D(_NonLocalBlockND):
    """Non-local block over two spatial dimensions (``_NonLocalBlockND`` with ``dimension=2``)."""

    def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=2,
            mode=mode,
            sub_sample_factor=sub_sample_factor,
            bn_layer=bn_layer,
        )
class NONLocalBlock3D(_NonLocalBlockND):
    """Non-local block over three spatial dimensions (``_NonLocalBlockND`` with ``dimension=3``)."""

    def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=3,
            mode=mode,
            sub_sample_factor=sub_sample_factor,
            bn_layer=bn_layer,
        )
if __name__ == '__main__':
    # Smoke test: run each non-local variant on all-zero inputs and print the output shapes.
    from torch.autograd import Variable
    mode_list = ['concatenation']
    #mode_list = ['embedded_gaussian', 'gaussian', 'dot_product', ]
    for mode in mode_list:
        print(mode)
        smoke_cases = (
            (NONLocalBlock1D, (2, 4, 5), {'sub_sample_factor': 2}),
            (NONLocalBlock2D, (2, 4, 5, 3), {'sub_sample_factor': 1, 'bn_layer': False}),
            (NONLocalBlock3D, (2, 4, 5, 4, 5), {}),
        )
        for block_cls, input_shape, extra_kwargs in smoke_cases:
            img = Variable(torch.zeros(*input_shape))
            net = block_cls(4, mode=mode, **extra_kwargs)
            print(net(img).size())
================================================
FILE: Models/layers/scale_attention_layer.py
================================================
import torch
import torch.nn as nn
from torch.nn import functional as F
def conv1x1(in_planes, out_planes, stride=1, bias=False):
    """Return a 1x1 ``nn.Conv2d`` without padding (bias disabled by default)."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        padding=0,
        bias=bias,
    )
    return layer
# # SE block add to U-net
def conv3x3(in_planes, out_planes, stride=1, bias=False, group=1):
    """Return a 3x3 ``nn.Conv2d`` with padding 1 (spatial size kept at stride 1).

    ``group`` is forwarded as ``groups``; bias is disabled by default.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        groups=group,
        bias=bias,
    )
    return layer
# # CBAM Convolutional block attention module
class BasicConv(nn.Module):
    """``Conv2d`` optionally followed by ``BatchNorm2d`` and ``ReLU``.

    Args:
        in_planes / out_planes: input / output channels.
        kernel_size, stride, padding, dilation, groups, bias: forwarded to ``nn.Conv2d``.
        relu: append a ``ReLU`` when true (``self.relu`` is ``None`` otherwise).
        bn: append a ``BatchNorm2d(eps=1e-5, momentum=0.01)`` when true
            (``self.bn`` is ``None`` otherwise).
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        if bn:
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None
        if relu:
            self.relu = nn.ReLU()
        else:
            self.relu = None

    def forward(self, x):
        # Apply conv, then whichever of bn / relu were enabled, in that order.
        for stage in (self.conv, self.bn, self.relu):
            if stage is not None:
                x = stage(x)
        return x
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension: ``(b, ...) -> (b, prod(...))``."""

    def forward(self, x):
        batch = x.size(0)
        flat = x.view(batch, -1)
        return flat
class ChannelGate(nn.Module):
    """Scale-wise channel gate used by the scale attention block.

    Each global pooling listed in ``pool_types`` is fed through a shared
    two-layer MLP and the results are summed into per-channel logits. The
    logits are then averaged inside each of ``num_scales`` equal groups of
    consecutive channels, so all channels of one group share the same weight.
    The weight is squashed with a sigmoid.

    Args:
        gate_channels: number of input channels ``c``; must be divisible by ``num_scales``.
        reduction_ratio: the MLP hidden width is ``gate_channels // reduction_ratio``.
        pool_types: non-empty iterable of ``'avg'``, ``'max'``, ``'lp'`` or ``'lse'``.
        num_scales: number of channel groups sharing one weight. The default of 4
            reproduces the previous fixed ``(b, 4, 4)`` grouping for 16 channels.

    Raises:
        ValueError: on an empty/unknown ``pool_types`` entry or indivisible channel count.

    Forward returns:
        ``(x * scale, scale)`` where ``scale`` is expanded to the shape of ``x``.
    """

    _SUPPORTED_POOLS = ('avg', 'max', 'lp', 'lse')

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max'), num_scales=4):
        super(ChannelGate, self).__init__()
        pool_types = list(pool_types)
        unknown = [p for p in pool_types if p not in self._SUPPORTED_POOLS]
        if not pool_types or unknown:
            raise ValueError('pool_types must be a non-empty subset of {}, got {}'.format(
                self._SUPPORTED_POOLS, pool_types))
        if gate_channels % num_scales != 0:
            raise ValueError('gate_channels ({}) must be divisible by num_scales ({})'.format(
                gate_channels, num_scales))
        self.gate_channels = gate_channels
        self.num_scales = num_scales
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        window = (x.size(2), x.size(3))
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                pooled = F.avg_pool2d(x, window, stride=window)
            elif pool_type == 'max':
                pooled = F.max_pool2d(x, window, stride=window)
            elif pool_type == 'lp':
                pooled = F.lp_pool2d(x, 2, window, stride=window)
            else:  # 'lse' -- the only remaining type accepted by __init__
                pooled = logsumexp_2d(x)
            channel_att_raw = self.mlp(pooled)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        batch = channel_att_sum.shape[0]
        # (b, c) -> (b, num_scales, c // num_scales): average the logits inside each
        # channel group so every channel of one group receives the same weight.
        grouped = channel_att_sum.reshape(batch, self.num_scales, -1)
        avg_weight = grouped.mean(dim=2, keepdim=True).expand_as(grouped).reshape(batch, -1)
        scale = torch.sigmoid(avg_weight).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale, scale
def logsumexp_2d(tensor):
    """Numerically stable log-sum-exp over all positions after dim 1.

    ``(b, c, ...) -> (b, c, 1)``: the per-channel maximum is subtracted before
    exponentiating and added back afterwards.
    """
    flat = tensor.view(tensor.size(0), tensor.size(1), -1)
    peak = flat.max(dim=2, keepdim=True)[0]
    shifted_total = (flat - peak).exp().sum(dim=2, keepdim=True)
    return peak + shifted_total.log()
class ChannelPool(nn.Module):
    """Stack the channel-wise max and mean into two channels: ``(b, c, ...) -> (b, 2, ...)``."""

    def forward(self, x):
        max_map = x.max(dim=1, keepdim=True)[0]
        mean_map = x.mean(dim=1, keepdim=True)
        return torch.cat((max_map, mean_map), dim=1)
class SpatialGate(nn.Module):
    """CBAM-style spatial gate.

    ``ChannelPool`` stacks channel-wise max and mean maps. A 7x7 ``BasicConv``
    (BatchNorm, no ReLU, padding 3 so the spatial size is kept) fuses them into
    a single logit map. Its sigmoid re-weights every channel of ``x`` at each
    position.

    Forward returns:
        ``(x * scale, scale)`` where ``scale`` has one channel and broadcasts over ``x``.
    """

    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        # torch.sigmoid replaces the deprecated F.sigmoid; the 1-channel map broadcasts over x.
        scale = torch.sigmoid(x_out)
        return x * scale, scale
class SpatialAtten(nn.Module):
    """Group-wise spatial attention with a residual connection.

    ``conv1`` (k x k conv + BN + ReLU) and ``conv2`` (1x1 conv + ReLU) map the
    ``in_size``-channel input to ``out_size`` attention maps. After a sigmoid,
    map ``j`` is shared by the ``in_size // out_size`` consecutive input
    channels of group ``j``. The network in this repository uses
    ``in_size=16, out_size=4``, which reproduces the former fixed 4x4 -> 16
    expansion.

    Args:
        in_size: channels of the input (must be a multiple of ``out_size``).
        out_size: number of attention maps / channel groups.
        kernel_size: kernel of ``conv1`` (padding keeps the spatial size).
        stride: stride of both convs. Values other than 1 change the spatial
            size and break the residual product.

    Forward returns:
        ``(x * att + x, att)`` with ``att`` expanded to the shape of ``x``.
    """

    def __init__(self, in_size, out_size, kernel_size=3, stride=1):
        super(SpatialAtten, self).__init__()
        self.conv1 = BasicConv(in_size, out_size, kernel_size, stride=stride,
                               padding=(kernel_size - 1) // 2, relu=True)
        self.conv2 = BasicConv(out_size, out_size, kernel_size=1, stride=stride,
                               padding=0, relu=True, bn=False)

    def forward(self, x):
        residual = x
        x_out = self.conv2(self.conv1(x))
        batch, groups, height, width = x_out.shape
        channels = residual.size(1)
        if channels % groups != 0:
            raise ValueError('input channels ({}) must be a multiple of the attention maps ({})'.format(
                channels, groups))
        # (b, groups, h, w) -> (b, groups, channels // groups, h, w) -> (b, channels, h, w):
        # every attention map is repeated over the channels of its group.
        spatial_att = torch.sigmoid(x_out).unsqueeze(2)
        spatial_att = spatial_att.expand(batch, groups, channels // groups, height, width).reshape(
            batch, channels, height, width)
        x_out = residual * spatial_att
        x_out += residual
        return x_out, spatial_att
class Scale_atten_block(nn.Module):
    """Channel gate followed, unless ``no_spatial``, by a spatial attention gate.

    Args:
        gate_channels: number of input channels.
        reduction_ratio: forwarded to ``ChannelGate``. The spatial gate produces
            ``gate_channels // reduction_ratio`` attention maps.
        pool_types: pooling types forwarded to ``ChannelGate``.
        no_spatial: skip the spatial gate entirely.

    Forward returns:
        ``(out, channel_attention, spatial_attention)``. ``spatial_attention`` is
        ``None`` when ``no_spatial`` is set.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max'), no_spatial=False):
        super(Scale_atten_block, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialAtten(gate_channels, gate_channels // reduction_ratio)

    def forward(self, x):
        x_out, ca_atten = self.ChannelGate(x)
        # Defined up front so the no_spatial path returns None instead of raising
        # UnboundLocalError.
        sa_atten = None
        if not self.no_spatial:
            x_out, sa_atten = self.SpatialGate(x_out)
        return x_out, ca_atten, sa_atten
class scale_atten_convblock(nn.Module):
    """Scale attention (optional) + residual, followed by 3x3 conv, BN and ReLU.

    Args:
        in_size: input channels (also the width of the scale attention block).
        out_size: output channels of the 3x3 conv.
        stride: stored, not used by ``forward``.
        downsample: optional module applied to ``x`` to form the residual.
        use_cbam: enable the ``Scale_atten_block``. When false, the input itself is
            used in its place.
        no_spatial: forwarded to ``Scale_atten_block``.
        drop_out: apply ``Dropout2d(p=0.5)`` to the output, in training mode only.
    """

    def __init__(self, in_size, out_size, stride=1, downsample=None, use_cbam=True, no_spatial=False, drop_out=False):
        super(scale_atten_convblock, self).__init__()
        self.downsample = downsample
        self.stride = stride
        self.no_spatial = no_spatial
        self.dropout = drop_out
        self.relu = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(in_size, out_size)
        self.bn3 = nn.BatchNorm2d(out_size)
        if use_cbam:
            self.cbam = Scale_atten_block(in_size, reduction_ratio=4, no_spatial=self.no_spatial)  # out_size
        else:
            self.cbam = None

    def forward(self, x):
        residual = x
        if self.downsample is not None:
            residual = self.downsample(x)
        if self.cbam is not None:
            out, _, _ = self.cbam(x)
        else:
            # Without attention the block degrades to a plain residual conv block
            # (previously `out` was left unbound here).
            out = x
        # Out-of-place add: when cbam is disabled `out` aliases the caller's tensor.
        out = out + residual
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out = self.relu(out)
        if self.dropout:
            # Functional dropout honours self.training. A freshly built nn.Dropout2d
            # is always in training mode and would also drop at eval time.
            out = F.dropout2d(out, p=0.5, training=self.training)
        return out
================================================
FILE: Models/networks/network.py
================================================
import torch
import torch.nn as nn
from Models.layers.modules import conv_block, UpCat, UpCatconv, UnetDsv3, UnetGridGatingSignal3
from Models.layers.grid_attention_layer import GridAttentionBlock2D, MultiAttentionBlock
from Models.layers.channel_attention_layer import SE_Conv_Block
from Models.layers.scale_attention_layer import scale_atten_convblock
from Models.layers.nonlocal_layer import NONLocalBlock2D
class Comprehensive_Atten_Unet(nn.Module):
    """CA-Net: a U-Net combining non-local, attention-gate, channel and scale attention.

    Encoder: four ``conv_block`` stages, each followed by 2x2 max-pooling, plus
    a ``center`` block. Decoder: ``UpCat`` up-sampling/concatenation followed by
    ``SE_Conv_Block`` (returns features and an attention weight; presumably
    squeeze-and-excitation, TODO confirm). The conv3 and conv2 skips pass
    through ``MultiAttentionBlock`` (conditioned on the decoder feature) before
    concatenation. The deepest up-sampled feature is refined by a 2-D non-local
    block. Four deep-supervision heads are concatenated and re-weighted by
    ``scale_atten_convblock``, then a 1x1 conv and ``Softmax2d`` give the output.

    Args:
        args: options object; only ``args.out_size`` is read here. It is passed as
            ``scale_factor`` to the ``UnetDsv3`` heads (presumably the output
            spatial size, TODO confirm).
        in_ch: number of input image channels.
        n_classes: number of output channels of the final softmax.
        feature_scale: divisor applied to the base widths [64, 128, 256, 512, 1024].
        is_deconv: forwarded to every ``UpCat``.
        is_batchnorm: stored on the instance; not used by the code in this class.
        nonlocal_mode: forwarded to both ``MultiAttentionBlock`` gates.
        attention_dsample: ``sub_sample_factor`` of both ``MultiAttentionBlock`` gates.
    """
    def __init__(self, args, in_ch=3, n_classes=2, feature_scale=4, is_deconv=True, is_batchnorm=True,
                 nonlocal_mode='concatenation', attention_dsample=(1, 1)):
        super(Comprehensive_Atten_Unet, self).__init__()
        self.args = args
        self.is_deconv = is_deconv
        self.in_channels = in_ch
        self.num_classes = n_classes
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale
        self.out_size = args.out_size
        # Channel width per level, divided by feature_scale (default 4 -> [16, 32, 64, 128, 256]).
        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]
        # downsampling
        self.conv1 = conv_block(self.in_channels, filters[0])
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv2 = conv_block(filters[0], filters[1])
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv3 = conv_block(filters[1], filters[2])
        self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv4 = conv_block(filters[2], filters[3], drop_out=True)
        self.maxpool4 = nn.MaxPool2d(kernel_size=(2, 2))
        self.center = conv_block(filters[3], filters[4], drop_out=True)
        # attention blocks
        # self.attentionblock1 = GridAttentionBlock2D(in_channels=filters[0], gating_channels=filters[1],
        #                                             inter_channels=filters[0])
        self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1],
                                                   nonlocal_mode=nonlocal_mode, sub_sample_factor=attention_dsample)
        self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2],
                                                   nonlocal_mode=nonlocal_mode, sub_sample_factor=attention_dsample)
        # Applied to the output of up_concat4, which must therefore have filters[4] channels.
        self.nonlocal4_2 = NONLocalBlock2D(in_channels=filters[4], inter_channels=filters[4] // 4)
        # upsampling
        self.up_concat4 = UpCat(filters[4], filters[3], self.is_deconv)
        self.up_concat3 = UpCat(filters[3], filters[2], self.is_deconv)
        self.up_concat2 = UpCat(filters[2], filters[1], self.is_deconv)
        self.up_concat1 = UpCat(filters[1], filters[0], self.is_deconv)
        self.up4 = SE_Conv_Block(filters[4], filters[3], drop_out=True)
        self.up3 = SE_Conv_Block(filters[3], filters[2])
        self.up2 = SE_Conv_Block(filters[2], filters[1])
        self.up1 = SE_Conv_Block(filters[1], filters[0])
        # deep supervision
        # UnetDsv3 heads are built with out_size=4 and dsv1 has 4 output channels;
        # scale_att (in_size=16) assumes the four heads concatenate to 4 x 4 channels.
        self.dsv4 = UnetDsv3(in_size=filters[3], out_size=4, scale_factor=self.out_size)
        self.dsv3 = UnetDsv3(in_size=filters[2], out_size=4, scale_factor=self.out_size)
        self.dsv2 = UnetDsv3(in_size=filters[1], out_size=4, scale_factor=self.out_size)
        self.dsv1 = nn.Conv2d(in_channels=filters[0], out_channels=4, kernel_size=1)
        self.scale_att = scale_atten_convblock(in_size=16, out_size=4)
        # final conv (without any concat)
        self.final = nn.Sequential(nn.Conv2d(4, n_classes, kernel_size=1), nn.Softmax2d())

    def forward(self, inputs):
        """Return the per-pixel softmax over ``n_classes`` channels for ``inputs``."""
        # Feature Extraction
        conv1 = self.conv1(inputs)
        maxpool1 = self.maxpool1(conv1)
        conv2 = self.conv2(maxpool1)
        maxpool2 = self.maxpool2(conv2)
        conv3 = self.conv3(maxpool2)
        maxpool3 = self.maxpool3(conv3)
        conv4 = self.conv4(maxpool3)
        maxpool4 = self.maxpool4(conv4)
        # Gating Signal Generation
        center = self.center(maxpool4)
        # Attention Mechanism
        # Upscaling Part (Decoder)
        up4 = self.up_concat4(conv4, center)
        # Non-local refinement of the deepest decoder feature.
        g_conv4 = self.nonlocal4_2(up4)
        # SE_Conv_Block returns (features, attention weight); the weights are not used further here.
        up4, att_weight4 = self.up4(g_conv4)
        # The conv3 skip is re-weighted by the attention block, conditioned on up4.
        g_conv3, att3 = self.attentionblock3(conv3, up4)
        # atten3_map = att3.cpu().detach().numpy().astype(np.float)
        # atten3_map = ndimage.interpolation.zoom(atten3_map, [1.0, 1.0, 224 / atten3_map.shape[2],
        #                                                      300 / atten3_map.shape[3]], order=0)
        up3 = self.up_concat3(g_conv3, up4)
        up3, att_weight3 = self.up3(up3)
        g_conv2, att2 = self.attentionblock2(conv2, up3)
        # atten2_map = att2.cpu().detach().numpy().astype(np.float)
        # atten2_map = ndimage.interpolation.zoom(atten2_map, [1.0, 1.0, 224 / atten2_map.shape[2],
        #                                                      300 / atten2_map.shape[3]], order=0)
        up2 = self.up_concat2(g_conv2, up3)
        up2, att_weight2 = self.up2(up2)
        # g_conv1, att1 = self.attentionblock1(conv1, up2)
        # atten1_map = att1.cpu().detach().numpy().astype(np.float)
        # atten1_map = ndimage.interpolation.zoom(atten1_map, [1.0, 1.0, 224 / atten1_map.shape[2],
        #                                                      300 / atten1_map.shape[3]], order=0)
        # The shallowest skip (conv1) is concatenated without an attention gate.
        up1 = self.up_concat1(conv1, up2)
        up1, att_weight1 = self.up1(up1)
        # Deep Supervision
        dsv4 = self.dsv4(up4)
        dsv3 = self.dsv3(up3)
        dsv2 = self.dsv2(up2)
        dsv1 = self.dsv1(up1)
        # Concatenate the four heads (shallowest first) along channels, then apply scale attention.
        dsv_cat = torch.cat([dsv1, dsv2, dsv3, dsv4], dim=1)
        out = self.scale_att(dsv_cat)
        out = self.final(out)
        return out
================================================
FILE: Models/networks_other.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
from torch.optim import lr_scheduler
import time
import numpy as np
###############################################################################
# Functions
###############################################################################
def weights_init_normal(m):
    """Initialise one module from normal distributions; use via ``net.apply(weights_init_normal)``.

    Conv*/Linear weights ~ N(0, 0.02); their biases are left untouched.
    BatchNorm* weight ~ N(1, 0.02) and bias = 0.
    Modules are matched by class-name substring. A module whose name merely
    contains ``Conv``/``Linear``/``BatchNorm`` but has no weight of its own (a
    wrapper such as ``BasicConv``, or a BatchNorm with ``affine=False``) is
    skipped instead of raising ``AttributeError``.

    Args:
        m: module visited by ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname or 'Linear' in classname:
        if weight is not None:
            init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            init.constant_(bias.data, 0.0)
def weights_init_xavier(m):
    """Xavier-normal initialisation for one module; use via ``net.apply(weights_init_xavier)``.

    Conv*/Linear weights use ``xavier_normal_`` with gain 1; their biases are left
    untouched. BatchNorm* weight ~ N(1, 0.02) and bias = 0.
    Modules are matched by class-name substring. Modules without their own weight
    (wrappers named ``*Conv*``, ``affine=False`` BatchNorm) are skipped instead of
    raising ``AttributeError``.

    Args:
        m: module visited by ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname or 'Linear' in classname:
        if weight is not None:
            init.xavier_normal_(weight.data, gain=1)
    elif 'BatchNorm' in classname:
        if weight is not None:
            init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            init.constant_(bias.data, 0.0)
def weights_init_kaiming(m):
    """Kaiming-normal initialisation for one module; use via ``net.apply(weights_init_kaiming)``.

    Conv*/Linear weights use ``kaiming_normal_(a=0, mode='fan_in')``; their biases
    are left untouched. BatchNorm* weight ~ N(1, 0.02) and bias = 0.
    Modules are matched by class-name substring. Modules without their own weight
    (wrappers named ``*Conv*``, ``affine=False`` BatchNorm) are skipped instead of
    raising ``AttributeError``.

    Args:
        m: module visited by ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname or 'Linear' in classname:
        if weight is not None:
            init.kaiming_normal_(weight.data, a=0, mode='fan_in')
    elif 'BatchNorm' in classname:
        if weight is not None:
            init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            init.constant_(bias.data, 0.0)
def weights_init_orthogonal(m):
    """Orthogonal initialisation for one module; use via ``net.apply(weights_init_orthogonal)``.

    Conv*/Linear weights use ``orthogonal_`` with gain 1; their biases are left
    untouched. BatchNorm* weight ~ N(1, 0.02) and bias = 0.
    Modules are matched by class-name substring. Modules without their own weight
    (wrappers named ``*Conv*``, ``affine=False`` BatchNorm) are skipped instead of
    raising ``AttributeError``.

    Args:
        m: module visited by ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname or 'Linear' in classname:
        if weight is not None:
            init.orthogonal_(weight.data, gain=1)
    elif 'BatchNorm' in classname:
        if weight is not None:
            init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            init.constant_(bias.data, 0.0)
def init_weights(net, init_type='normal'):
    """Apply the ``weights_init_<init_type>`` initialiser to every sub-module of ``net``.

    Args:
        net: module to initialise in place.
        init_type: one of ``'normal'``, ``'xavier'``, ``'kaiming'``, ``'orthogonal'``.

    Raises:
        NotImplementedError: for any other ``init_type``.
    """
    initialisers = (
        ('normal', weights_init_normal),
        ('xavier', weights_init_xavier),
        ('kaiming', weights_init_kaiming),
        ('orthogonal', weights_init_orthogonal),
    )
    for name, initialiser in initialisers:
        if init_type == name:
            net.apply(initialiser)
            return
    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
def get_norm_layer(norm_type='instance'):
    """Return a normalisation-layer factory for ``norm_type``.

    ``'batch'`` gives ``BatchNorm2d(affine=True)``, ``'instance'`` gives
    ``InstanceNorm2d(affine=False)`` (both as ``functools.partial``), and
    ``'none'`` gives ``None``.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    if norm_type == 'none':
        return None
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def adjust_learning_rate(optimizer, lr):
    """Overwrite the learning rate of every parameter group of ``optimizer`` with ``lr``."""
    for group in optimizer.param_groups:
        group.update(lr=lr)
def get_scheduler(optimizer, opt):
    """Build a learning-rate scheduler for ``optimizer`` according to ``opt.lr_policy``.

    Supported policies and the ``opt`` fields they read:

    * ``'lambda'``: factor 1 until ``epoch + 1 + opt.epoch_count`` exceeds ``opt.niter``,
      then a linear decay by ``1 / (opt.niter_decay + 1)`` per epoch.
    * ``'step'`` / ``'step2'``: ``StepLR`` every ``opt.lr_decay_iters`` epochs with gamma 0.5 / 0.1.
    * ``'plateau'`` / ``'plateau2'``: ``ReduceLROnPlateau`` (min mode, threshold 0.01,
      patience 5) with factor 0.1 / 0.2.
    * ``'step_warmstart'``: factor 0.1 for epochs < 5, 1 until 100, 0.1 until 200, then 0.01.
    * ``'step_warmstart2'``: factor 0.1 for epochs < 5, 1 until 50, 0.1 until 100, then 0.01.

    Raises:
        NotImplementedError: for an unknown policy. The exception was previously
            returned instead of raised, so callers received it as the "scheduler".
    """
    print('opt.lr_policy = [{}]'.format(opt.lr_policy))
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.5)
    elif opt.lr_policy == 'step2':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        print('schedular=plateau')
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, threshold=0.01, patience=5)
    elif opt.lr_policy == 'plateau2':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'step_warmstart':
        def lambda_rule(epoch):
            if epoch < 5:
                return 0.1
            if epoch < 100:
                return 1
            if epoch < 200:
                return 0.1
            return 0.01
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step_warmstart2':
        def lambda_rule(epoch):
            if epoch < 5:
                return 0.1
            if epoch < 50:
                return 1
            if epoch < 100:
                return 0.1
            return 0.01
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', gpu_ids=None):
    """Build, place and initialise a generator network.

    Args:
        input_nc / output_nc: input / output image channels.
        ngf: number of filters in the first conv layer.
        which_model_netG: ``'resnet_9blocks'``, ``'resnet_6blocks'``, ``'unet_128'`` or ``'unet_256'``.
        norm: normalisation type passed to ``get_norm_layer``.
        use_dropout: forwarded to the generator.
        init_type: weight initialisation passed to ``init_weights``.
        gpu_ids: GPU ids. ``None`` (default) or an empty list keeps the model on the CPU.
            A fresh list is created per call instead of sharing a mutable default
            with every generator built.

    Raises:
        NotImplementedError: for an unknown ``which_model_netG``.
    """
    gpu_ids = [] if gpu_ids is None else gpu_ids
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)
    if use_gpu:
        assert(torch.cuda.is_available())
    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_128':
        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    if use_gpu:
        netG.cuda(gpu_ids[0])
    init_weights(netG, init_type=init_type)
    return netG
def define_D(input_nc, ndf, which_model_netD,
             n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=None):
    """Build, place and initialise a PatchGAN discriminator.

    Args:
        input_nc: input image channels.
        ndf: number of filters in the first conv layer.
        which_model_netD: ``'basic'`` (3 layers) or ``'n_layers'`` (``n_layers_D`` layers).
        norm: normalisation type passed to ``get_norm_layer``.
        use_sigmoid: append a sigmoid to the discriminator output.
        init_type: weight initialisation passed to ``init_weights``.
        gpu_ids: GPU ids. ``None`` (default) or an empty list keeps the model on the CPU.

    Raises:
        NotImplementedError: for an unknown ``which_model_netD``.
    """
    # Fresh list per call: a mutable default would be shared by every discriminator.
    gpu_ids = [] if gpu_ids is None else gpu_ids
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)
    if use_gpu:
        assert(torch.cuda.is_available())
    if which_model_netD == 'basic':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'n_layers':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' %
                                  which_model_netD)
    if use_gpu:
        netD.cuda(gpu_ids[0])
    init_weights(netD, init_type=init_type)
    return netD
def print_network(net):
    """Print ``net`` followed by its total parameter count."""
    total = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
def get_n_parameters(net):
    """Return the total number of scalar parameters of ``net``."""
    return sum(param.numel() for param in net.parameters())
def measure_fp_bp_time(model, x, y):
    """Time one forward and one backward pass of ``model`` on ``x``.

    The forward output is summed to a scalar (for a tuple output, the sum of
    each element's sum) and back-propagated. ``y`` is accepted for interface
    compatibility but not used. CUDA is synchronised around each timed region
    when available, so queued kernels are counted. On CPU-only machines the
    synchronisation is skipped instead of failing. ``time.perf_counter`` is
    used because, unlike ``time.time``, it is monotonic and high resolution.

    Returns:
        ``(elapsed_fp, elapsed_bp)`` in seconds.
    """
    def _sync():
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # synchronize gpu time and measure fp
    _sync()
    t0 = time.perf_counter()
    y_pred = model(x)
    _sync()
    elapsed_fp = time.perf_counter() - t0
    if isinstance(y_pred, tuple):
        y_pred = sum(y_p.sum() for y_p in y_pred)
    else:
        y_pred = y_pred.sum()
    # zero gradients, synchronize time and measure
    model.zero_grad()
    _sync()
    t0 = time.perf_counter()
    y_pred.backward()
    _sync()
    elapsed_bp = time.perf_counter() - t0
    return elapsed_fp, elapsed_bp
def benchmark_fp_bp_time(model, x, y, n_trial=1000):
    """Average forward / backward time of ``model`` over ``n_trial`` runs on the GPU.

    The model is moved to CUDA and run 10 times untimed to warm up before the
    measured trials.

    Returns:
        ``(mean_forward_seconds, mean_backward_seconds)``.
    """
    # transfer the model on GPU
    model.cuda()
    # DRY RUNS
    for _ in range(10):
        _, _ = measure_fp_bp_time(model, x, y)
    print('DONE WITH DRY RUNS, NOW BENCHMARKING')
    # START BENCHMARKING
    print('trial: {}'.format(n_trial))
    timings = [measure_fp_bp_time(model, x, y) for _ in range(n_trial)]
    t_forward = [fp for fp, _ in timings]
    t_backward = [bp for _, bp in timings]
    # Only drops this function's local reference; the caller still owns the model.
    del model
    return np.mean(t_forward), np.mean(t_backward)
##############################################################################
# Classes
##############################################################################
# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input
class GANLoss(nn.Module):
    """GAN objective against a constant "real" or "fake" target.

    With ``use_lsgan`` the criterion is ``nn.MSELoss``; otherwise it is
    ``nn.BCELoss``, which expects probabilities as input. The target is built
    with the tensor type ``tensor``, filled with ``target_real_label`` or
    ``target_fake_label`` and cached. It is reused as long as the input's
    element count matches the cached target's.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return the (possibly cached) target tensor for ``input``."""
        if target_is_real:
            cached = self.real_label_var
            if cached is None or cached.numel() != input.numel():
                cached = Variable(self.Tensor(input.size()).fill_(self.real_label), requires_grad=False)
                self.real_label_var = cached
            return cached
        cached = self.fake_label_var
        if cached is None or cached.numel() != input.numel():
            cached = Variable(self.Tensor(input.size()).fill_(self.fake_label), requires_grad=False)
            self.fake_label_var = cached
        return cached

    def __call__(self, input, target_is_real):
        return self.loss(input, self.get_target_tensor(input, target_is_real))
# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
    """Image-to-image generator built from residual blocks.

    Layout: reflection-padded 7x7 conv stem, two stride-2 down-sampling convs,
    ``n_blocks`` ``ResnetBlock``s, two stride-2 transposed convs and a 7x7 conv
    with ``Tanh`` output.

    Args:
        input_nc / output_nc: input / output image channels.
        ngf: number of filters in the first conv layer.
        norm_layer: normalisation constructor (class or ``functools.partial``).
        use_dropout: forwarded to every ``ResnetBlock``.
        n_blocks: number of residual blocks (must be >= 0).
        gpu_ids: GPU ids used for ``data_parallel`` in ``forward``. ``None`` means none.
        padding_type: padding mode forwarded to every ``ResnetBlock``.
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=None, padding_type='reflect'):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        # Fresh list per instance: a mutable default would be aliased by every generator.
        self.gpu_ids = [] if gpu_ids is None else gpu_ids
        # Convs carry their own bias only when followed by InstanceNorm2d; a
        # BatchNorm2d layer already provides the shift.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
                           bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)
# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block: ``x + [pad, conv3x3, norm, ReLU, (dropout), pad, conv3x3, norm](x)``.

    Args:
        dim: channels (unchanged by the block).
        padding_type: ``'reflect'``, ``'replicate'`` (explicit padding layers) or
            ``'zero'`` (padding done by the convs).
        norm_layer: normalisation constructor called with ``dim``.
        use_dropout: insert ``Dropout(0.5)`` after the first ReLU.
        use_bias: bias flag of both convs.

    Raises:
        NotImplementedError: for an unknown ``padding_type``.
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _padding(padding_type):
        """Return (explicit padding layers, conv padding) for ``padding_type``."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        layers = []
        for stage in range(2):
            pad_layers, conv_padding = self._padding(padding_type)
            layers.extend(pad_layers)
            layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=conv_padding, bias=use_bias))
            layers.append(norm_layer(dim))
            # Only the first half is followed by a non-linearity (and optional dropout).
            if stage == 0:
                layers.append(nn.ReLU(True))
                if use_dropout:
                    layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    def forward(self, x):
        return x + self.conv_block(x)
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class UnetGenerator(nn.Module):
    """U-Net generator built recursively from ``UnetSkipConnectionBlock``s.

    Args:
        input_nc / output_nc: input / output image channels.
        num_downs: number of down-samplings (7 turns a 128x128 image into 1x1 at
            the bottleneck). ``num_downs - 5`` extra ``ngf * 8`` blocks are inserted
            below the fixed outer levels.
        ngf: number of filters in the outermost layer.
        norm_layer: normalisation constructor.
        use_dropout: dropout in the extra ``ngf * 8`` blocks.
        gpu_ids: GPU ids used for ``data_parallel`` in ``forward``. ``None`` means none.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=None):
        super(UnetGenerator, self).__init__()
        # Fresh list per instance: a mutable default would be aliased by every generator.
        self.gpu_ids = [] if gpu_ids is None else gpu_ids

        # construct unet structure from the innermost block outwards
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)

        self.model = unet_block

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: down-sample, run ``submodule``, up-sample, then concatenate with the input.

    The outermost block returns the up-sampled output (with ``Tanh``) without
    concatenation. The innermost block has no submodule. Other blocks may
    append ``Dropout(0.5)``.

    Args:
        outer_nc: channels produced by the up-sampling conv.
        inner_nc: channels produced by the down-sampling conv.
        input_nc: channels of the block input (defaults to ``outer_nc``).
        submodule: the next-inner level (ignored for the innermost block).
        outermost / innermost: position of this level.
        norm_layer: normalisation constructor (class or ``functools.partial``).
        use_dropout: append dropout (middle levels only).
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # Convs keep a bias only when the norm layer is InstanceNorm2d.
        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d
        in_channels = outer_nc if input_nc is None else input_nc

        # Layers are created in a fixed order (down path first) regardless of position.
        downconv = nn.Conv2d(in_channels, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            layers = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            layers = [downrelu, downconv, uprelu, upconv, upnorm]
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            layers = [downrelu, downconv, downnorm, submodule, uprelu, upconv, upnorm]
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        y = self.model(x)
        if self.outermost:
            return y
        return torch.cat([x, y], 1)
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator producing a map of per-patch real/fake scores.

    ``n_layers`` stride-2 4x4 convs (the first without normalisation), a
    stride-1 4x4 conv, and a final 1-channel 4x4 conv. Filter counts double per
    layer, capped at ``8 * ndf``.

    Args:
        input_nc: input image channels.
        ndf: filters in the first conv layer.
        n_layers: number of stride-2 convs.
        norm_layer: normalisation constructor (class or ``functools.partial``).
        use_sigmoid: append a ``Sigmoid`` to the output.
        gpu_ids: GPU ids used for ``data_parallel`` in ``forward``. ``None`` means none.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=None):
        super(NLayerDiscriminator, self).__init__()
        # Fresh list per instance: a mutable default would be aliased by every discriminator.
        self.gpu_ids = [] if gpu_ids is None else gpu_ids
        # Convs keep a bias only when the norm layer is InstanceNorm2d.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)
================================================
FILE: README.md
================================================
## CA-Net: Comprehensive Attention Convolutional Neural Networks for Explainable Medical Image Segmentation
This repository provides the code for "CA-Net: Comprehensive Attention Convolutional Neural Networks for Explainable Medical Image Segmentation". The paper is available on [arXiv][paper_link] and has been accepted by [IEEE TMI][tmi_link].
[paper_link]:https://arxiv.org/pdf/2009.10549.pdf
[tmi_link]:https://ieeexplore.ieee.org/document/9246575

Fig. 1. Structure of CA-Net.

Fig. 2. Skin lesion segmentation.

Fig. 3. Placenta and fetal brain segmentation.
### Requirements
Some important required packages include:
* [PyTorch][torch_link] version >= 0.4.1.
* Visdom
* Python == 3.7
* Some basic python packages such as Numpy.
Follow the official guidance to install [PyTorch][torch_link].
[torch_link]:https://pytorch.org/
## Usages
### For skin lesion segmentation
1. First, download the dataset from [ISIC 2018][data_link]. We only used the ISIC 2018 Task 1 training set. To preprocess the dataset and save it as ".npy" files, run:
[data_link]:https://challenge.isic-archive.com/data#2018
```
python isic_preprocess.py
```
2. To conduct 5-fold cross-validation, split the preprocessed data into 5 folds and save their file names by running:
```
python create_folder.py
```
3. To train CA-Net on ISIC 2018 (taking the 1st-fold validation as an example), run:
```
python main.py --data ISIC2018 --val_folder folder1 --id Comp_Atten_Unet
```
4. To evaluate the trained model on ISIC 2018 (we added test data in folder0; here the 0th-fold validation is tested as an example), run:
```
python validation.py --data ISIC2018 --val_folder folder0 --id Comp_Atten_Unet
```
Our experimental results are shown in the table below:

5. You can save the intermediate attention weight maps of the network to the '/result' folder. To visualize the attention weights overlaid on the original images, run:
```
python show_fused_heatmap.py
```
Visualization of the spatial attention weight map:

Visualization of the scale attention weight map:

## Citation
If you find our work helpful for your research, please consider citing:
```
@article{gu2020net,
title={CA-Net: Comprehensive Attention Convolutional Neural Networks for Explainable Medical Image Segmentation},
author={Gu, Ran and Wang, Guotai and Song, Tao and Huang, Rui and Aertsen, Michael and Deprest, Jan and Ourselin, S{\'e}bastien and Vercauteren, Tom and Zhang, Shaoting},
journal={IEEE Transactions on Medical Imaging},
year={2020},
publisher={IEEE}
}
```
## Acknowledgement
Part of the code is adapted from [Attention-Gated-Networks][AG].
[AG]:https://github.com/ozan-oktay/Attention-Gated-Networks
================================================
FILE: create_folder.py
================================================
import os
import numpy
from random import shuffle
# Directory of preprocessed ``.npy`` images (written by isic_preprocess.py);
# its file listing is what gets split into cross-validation folds.
PATH = './data/ISIC2018_Task1_npy_all/image'
# Root directory under which the ``folder<k>`` list directories are written.
SAVE_PATH = './Datasets'
def create_5_floder(folder, save_foler, fold_size=518, valid_size=260, n_folds=5, seed=None):
    """Split the files in ``folder`` into ``n_folds`` cross-validation folds.

    For fold ``k`` (1-based), three list files are written to
    ``save_foler/folder<k>``: ``folder<k>_train.list``,
    ``folder<k>_validation.list`` and ``folder<k>_test.list``.

    The shuffled listing is cut into ``n_folds`` consecutive test chunks of
    ``fold_size`` files. For every fold except the last, the ``valid_size``
    files that follow its test chunk form the validation set and all remaining
    files form the training set. The last fold's validation set wraps around:
    it takes the files left over after the ``n_folds * fold_size`` test chunks
    and tops them up from the start of the listing.

    The defaults (518 / 260 / 5) reproduce the original hard-coded split for the
    2594 images of the ISIC 2018 Task 1 training set.

    Args:
        folder: directory whose file names are split.
        save_foler: directory under which the ``folder<k>`` sub-directories are created.
        fold_size: number of test files per fold.
        valid_size: number of validation files per fold.
        n_folds: number of folds.
        seed: optional seed. When given, the listing is sorted and then shuffled
            with a private ``random.Random(seed)``, so the split is reproducible.
            ``None`` keeps the unseeded global shuffle.

    Raises:
        ValueError: if the folder size is incompatible with the requested sizes
            (the splits would otherwise silently overlap).
    """
    file_list = os.listdir(folder)
    if seed is None:
        shuffle(file_list)
    else:
        import random
        file_list.sort()  # listdir order is arbitrary; sort so the seed alone fixes the split
        random.Random(seed).shuffle(file_list)

    # Files that do not fit into any test chunk (4 of them for the 2594-image default).
    leftover = file_list[n_folds * fold_size:]
    # Number of files the last fold's validation set borrows from the start of the list.
    wrap = valid_size - len(leftover)
    if len(file_list) < n_folds * fold_size or not 0 <= wrap <= (n_folds - 1) * fold_size:
        raise ValueError(
            '{} files cannot be split into {} folds of {} test / {} validation files'.format(
                len(file_list), n_folds, fold_size, valid_size))

    for i in range(n_folds):
        test_list = file_list[i * fold_size:(i + 1) * fold_size]
        if i < n_folds - 1:
            valid_start = (i + 1) * fold_size
            valid_list = file_list[valid_start:valid_start + valid_size]
            # Everything after the validation window plus everything before the test chunk.
            train_list = file_list[valid_start + valid_size:] + file_list[:i * fold_size]
        else:
            valid_list = leftover + file_list[:wrap]
            train_list = file_list[wrap:i * fold_size]
        name = 'folder' + str(i + 1)
        fold_dir = os.path.join(save_foler, name)
        os.makedirs(fold_dir, exist_ok=True)
        text_save(os.path.join(fold_dir, name + '_train.list'), train_list)
        text_save(os.path.join(fold_dir, name + '_validation.list'), valid_list)
        text_save(os.path.join(fold_dir, name + '_test.list'), test_list)
def text_save(filename, data):
    """Write ``data`` to ``filename``, one item per line, overwriting the file.

    Each item is converted with ``str`` and stripped of ``[``, ``]``, ``'`` and
    ``,``, so list-like items are written as space-separated tokens.

    Args:
        filename: path of the list file to write.
        data: iterable of items to write.
    """
    # ``with`` guarantees the file is closed even if a write fails.
    with open(filename, 'w') as file:
        for item in data:
            s = str(item).replace('[', '').replace(']', '')
            s = s.replace("'", '').replace(',', '') + '\n'
            file.write(s)
    # basename works for both '/' and OS-specific separators.
    print("Save {} successfully".format(os.path.basename(filename)))
if __name__ == "__main__":
create_5_floder(PATH, SAVE_PATH)
================================================
FILE: isic_preprocess.py
================================================
#!/usr/bin/python3
# these code is for ISIC 2018: Skin Lesion Analysis Towards Melanoma Detection
# -*- coding: utf-8 -*-
# @Author : Ran Gu
import os
import random
import numpy as np
from skimage import io
from PIL import Image
# Directory holding the original ISIC 2018 Task 1 training data; change it to where you saved it.
root_dir = 'gr/Skin Segmentation'
# Output directory; the preprocessed arrays go to its 'image' and 'label' sub-directories.
save_dir = './data/ISIC2018_Task1_npy_all'
if __name__ == '__main__':
    # Resize every training image and its mask to 342x256 and store both as .npy arrays.
    imgfile = os.path.join(root_dir, 'ISIC2018_Task1-2_Training_Input')
    filename = sorted([os.path.join(imgfile, x) for x in os.listdir(imgfile) if x.endswith('.jpg')])
    random.shuffle(filename)
    # Each image's mask is <name>_segmentation.png in the ground-truth directory.
    labname = [f.replace('ISIC2018_Task1-2_Training_Input', 'ISIC2018_Task1_Training_GroundTruth'
                         ).replace('.jpg', '_segmentation.png') for f in filename]
    # exist_ok: also works when save_dir exists but a sub-directory is missing.
    os.makedirs(os.path.join(save_dir, 'image'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'label'), exist_ok=True)
    for img_path, lab_path in zip(filename, labname):
        fname = os.path.basename(img_path).split('.')[0]
        lname = os.path.basename(lab_path).split('.')[0]
        with Image.open(img_path) as img, Image.open(lab_path) as lab:
            # PIL sizes are (width, height).
            image = np.array(img.resize((342, 256)))
            # NEAREST keeps the mask's original label values; Pillow's default filter
            # (bicubic for most modes since Pillow 7) would blend object borders
            # into intermediate, non-label values.
            label = np.array(lab.resize((342, 256), resample=Image.NEAREST))
        np.save(os.path.join(save_dir, 'image', fname), image)
        np.save(os.path.join(save_dir, 'label', lname), label)
    print('Successfully saved preprocessed data')
================================================
FILE: main.py
================================================
#!/usr/bin/python3
# these code is for ISIC 2018: Skin Lesion Analysis Towards Melanoma Detection
# -*- coding: utf-8 -*-
# @Author : Ran Gu
import os
import torch
import math
import visdom
import torch.utils.data as Data
import argparse
import numpy as np
from tqdm import tqdm
from distutils.version import LooseVersion
from Datasets.ISIC2018 import ISIC2018_dataset
from utils.transform import ISIC2018_transform
from Models.networks.network import Comprehensive_Atten_Unet
from utils.dice_loss import SoftDiceLoss, get_soft_label, val_dice_fetus, val_dice_isic
from utils.dice_loss import Intersection_over_Union_fetus, Intersection_over_Union_isic
from utils.evaluation import AverageMeter
from utils.binary import assd
from torch.optim.lr_scheduler import StepLR
# Registries used by main(): --id selects the model class, --data selects the
# dataset class and its transform.
# NOTE(review): only 'ISIC2018' is registered here, so the 'Fetus' branches in
# main() would fail on these lookups until a Fetus dataset/transform is added.
Test_Model = {'Comp_Atten_Unet': Comprehensive_Atten_Unet}
Test_Dataset = {'ISIC2018': ISIC2018_dataset}
Test_Transform = {'ISIC2018': ISIC2018_transform}
def train(train_loader, model, criterion, optimizer, args, epoch):
    """Run one training epoch.

    Args:
        train_loader: DataLoader yielding ``(image, label)`` batches.
        model: network to optimise; switched to train mode here.
        criterion: loss called as ``criterion(output, soft_target, num_classes)``.
        optimizer: optimiser stepped once per batch.
        args: namespace providing ``num_classes`` and ``batch_size``.
        epoch: current epoch number, used only for logging.

    Returns:
        ``losses.avg`` of the AverageMeter fed with each batch loss and batch size
        (presumably the sample-weighted mean loss — see utils.evaluation).
    """
    losses = AverageMeter()
    model.train()
    # Batches per epoch; the progress line prints whenever step is a multiple of this.
    log_interval = math.ceil(float(len(train_loader.dataset)) / args.batch_size)
    for step, (x, y) in tqdm(enumerate(train_loader), total=len(train_loader)):
        image = x.float().cuda()
        target = y.float().cuda()
        output = model(image)  # model output
        target_soft = get_soft_label(target, args.num_classes)  # get soft label
        loss = criterion(output, target_soft, args.num_classes)  # the dice losses
        # .item() stores a Python float so the meter does not keep GPU tensors alive.
        losses.update(loss.item(), image.size(0))
        # compute gradient and do an optimiser step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {losses.avg:.6f}'.format(
                epoch, step * len(image), len(train_loader.dataset),
                100. * step / len(train_loader), losses=losses))
    print('The average loss:{losses.avg:.4f}'.format(losses=losses))
    return losses.avg
def valid_fetus(valid_loader, model, criterion, optimizer, args, epoch, minloss):
    """Validate on the Fetus data and checkpoint the model on a new lowest loss.

    Args:
        valid_loader: DataLoader yielding ``(image, label)`` batches.
        model: network under evaluation; switched to eval mode here.
        criterion: loss called as ``criterion(output, soft_target, num_classes)``.
        optimizer: only its ``state_dict`` is stored in the checkpoint.
        args: namespace providing ``num_classes``, ``batch_size``, ``ckpt`` and ``data``.
        epoch: current epoch, used for logging and stored in the checkpoint.
        minloss: list of best losses so far. It is mutated in place: the new average
            loss is appended (and the model saved) when it is below ``min(minloss)``.

    Returns:
        tuple: (average loss, average placenta dice, average brain dice).
    """
    val_losses = AverageMeter()
    val_placenta_dice = AverageMeter()
    val_brain_dice = AverageMeter()
    model.eval()
    log_interval = math.ceil(float(len(valid_loader.dataset)) / args.batch_size)
    # Validation needs no gradients; disabling autograd saves memory and time.
    with torch.no_grad():
        for step, (t, k) in tqdm(enumerate(valid_loader), total=len(valid_loader)):
            image = t.float().cuda()
            target = k.float().cuda()
            output = model(image)  # model output
            # Hard prediction: arg-max over the class dimension, then one-hot encoded.
            output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)
            output_soft = get_soft_label(output_dis, args.num_classes)  # get soft label
            target_soft = get_soft_label(target, args.num_classes)
            val_loss = criterion(output, target_soft, args.num_classes)  # the dice losses
            val_losses.update(val_loss.item(), image.size(0))
            placenta, brain = val_dice_fetus(output_soft, target_soft, args.num_classes)  # the dice score
            val_placenta_dice.update(placenta.item(), image.size(0))
            val_brain_dice.update(brain.item(), image.size(0))
            if step % log_interval == 0:
                print('Valid Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {losses.avg:.6f}'.format(
                    epoch, step * len(image), len(valid_loader.dataset), 100. * step / len(valid_loader),
                    losses=val_losses))
    print('The Placenta Mean Average Dice score: {placenta.avg: .4f}; '
          'The Brain Mean Average Dice score: {brain.avg: .4f}; '
          'The Average Loss score: {loss.avg: .4f}'.format(
              placenta=val_placenta_dice, brain=val_brain_dice, loss=val_losses))
    if val_losses.avg < min(minloss):
        minloss.append(val_losses.avg)
        print(minloss)
        modelname = args.ckpt + '/' + 'min_loss' + '_' + args.data + '_checkpoint.pth.tar'
        print('the best model will be saved at {}'.format(modelname))
        state = {'epoch': epoch, 'state_dict': model.state_dict(), 'opt_dict': optimizer.state_dict()}
        torch.save(state, modelname)
    return val_losses.avg, val_placenta_dice.avg, val_brain_dice.avg
def valid_isic(valid_loader, model, criterion, optimizer, args, epoch, minloss):
    """Validate on the ISIC data and checkpoint the model on a new lowest loss.

    Args:
        valid_loader: DataLoader yielding ``(image, label)`` batches.
        model: network under evaluation; switched to eval mode here.
        criterion: loss called as ``criterion(output, soft_target, num_classes)``.
        optimizer: only its ``state_dict`` is stored in the checkpoint.
        args: namespace providing ``num_classes``, ``batch_size``, ``ckpt`` and ``data``.
        epoch: current epoch, used for logging and stored in the checkpoint.
        minloss: list of best losses so far. It is mutated in place: the new average
            loss is appended (and the model saved) when it is below ``min(minloss)``.

    Returns:
        tuple: (average loss, average ISIC dice).
    """
    val_losses = AverageMeter()
    val_isic_dice = AverageMeter()
    model.eval()
    log_interval = math.ceil(float(len(valid_loader.dataset)) / args.batch_size)
    # Validation needs no gradients; disabling autograd saves memory and time.
    with torch.no_grad():
        for step, (t, k) in tqdm(enumerate(valid_loader), total=len(valid_loader)):
            image = t.float().cuda()
            target = k.float().cuda()
            output = model(image)  # model output
            # Hard prediction: arg-max over the class dimension, then one-hot encoded.
            output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)
            output_soft = get_soft_label(output_dis, args.num_classes)
            target_soft = get_soft_label(target, args.num_classes)  # get soft label
            val_loss = criterion(output, target_soft, args.num_classes)  # the dice losses
            val_losses.update(val_loss.item(), image.size(0))
            isic = val_dice_isic(output_soft, target_soft, args.num_classes)  # the dice score
            val_isic_dice.update(isic.item(), image.size(0))
            if step % log_interval == 0:
                print('Valid Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {losses.avg:.6f}'.format(
                    epoch, step * len(image), len(valid_loader.dataset), 100. * step / len(valid_loader),
                    losses=val_losses))
    print('The ISIC Mean Average Dice score: {isic.avg: .4f}; '
          'The Average Loss score: {loss.avg: .4f}'.format(
              isic=val_isic_dice, loss=val_losses))
    if val_losses.avg < min(minloss):
        minloss.append(val_losses.avg)
        print(minloss)
        modelname = args.ckpt + '/' + 'min_loss' + '_' + args.data + '_checkpoint.pth.tar'
        print('the best model will be saved at {}'.format(modelname))
        state = {'epoch': epoch, 'state_dict': model.state_dict(), 'opt_dict': optimizer.state_dict()}
        torch.save(state, modelname)
    return val_losses.avg, val_isic_dice.avg
def test_fetus(test_loader, model, args):
    """Evaluate the minimum-validation-loss checkpoint on the Fetus test split.

    Loads ``<args.ckpt>/min_loss_<args.data>_checkpoint.pth.tar`` into ``model``
    when it exists; otherwise a message is printed and the current weights are
    evaluated. Prints the mean and standard deviation over batches of the
    placenta and brain dice, IoU and ASSD.

    Args:
        test_loader: DataLoader yielding ``(image, label)`` batches.
        model: network to evaluate; switched to eval mode here.
        args: namespace providing ``ckpt``, ``data`` and ``num_classes``.
    """
    placenta_dice = []
    brain_dice = []
    placenta_iou = []
    brain_iou = []
    placenta_assd = []
    brain_assd = []
    modelname = args.ckpt + '/' + 'min_loss' + '_' + args.data + '_checkpoint.pth.tar'
    if os.path.isfile(modelname):
        print("=> Loading checkpoint '{}'".format(modelname))
        checkpoint = torch.load(modelname)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> Loaded saved the best model at (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> No checkpoint found at '{}'".format(modelname))
    model.eval()
    # Inference only: disable autograd so no graph is built for the test batches.
    with torch.no_grad():
        for img, lab in tqdm(test_loader, total=len(test_loader)):
            image = img.float().cuda()
            target = lab.float().cuda()
            output = model(image)  # model output
            output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)
            output_soft = get_soft_label(output_dis, args.num_classes)
            target_soft = get_soft_label(target, args.num_classes)  # get soft label
            label_arr = np.squeeze(target_soft.cpu().numpy()).astype(np.uint8)
            output_arr = np.squeeze(output_soft.cpu().byte().numpy()).astype(np.uint8)
            placenta_b_dice, brain_b_dice = val_dice_fetus(output_soft, target_soft, args.num_classes)  # the dice accuracy
            placenta_b_iou, brain_b_iou = Intersection_over_Union_fetus(output_soft, target_soft, args.num_classes)  # the iou accuracy
            # NOTE(review): assumes the soft labels are channel-last, so index 1 / 2 on
            # the last axis select the placenta / brain masks — confirm with get_soft_label.
            placenta_b_asd = assd(output_arr[:, :, :, 1], label_arr[:, :, :, 1])
            brain_b_asd = assd(output_arr[:, :, :, 2], label_arr[:, :, :, 2])
            placenta_dice.append(placenta_b_dice.data.cpu().numpy())
            brain_dice.append(brain_b_dice.data.cpu().numpy())
            placenta_iou.append(placenta_b_iou.data.cpu().numpy())
            brain_iou.append(brain_b_iou.data.cpu().numpy())
            placenta_assd.append(placenta_b_asd)
            brain_assd.append(brain_b_asd)
    placenta_dice_mean = np.average(placenta_dice)
    placenta_dice_std = np.std(placenta_dice)
    brain_dice_mean = np.average(brain_dice)
    brain_dice_std = np.std(brain_dice)
    placenta_iou_mean = np.average(placenta_iou)
    placenta_iou_std = np.std(placenta_iou)
    brain_iou_mean = np.average(brain_iou)
    brain_iou_std = np.std(brain_iou)
    placenta_assd_mean = np.average(placenta_assd)
    placenta_assd_std = np.std(placenta_assd)
    brain_assd_mean = np.average(brain_assd)
    brain_assd_std = np.std(brain_assd)
    print('The Placenta mean Accuracy: {placenta_dice_mean: .4f}; The Placenta Accuracy std: {placenta_dice_std: .4f}; '
          'The Brain mean Accuracy: {brain_dice_mean: .4f}; The Brain Accuracy std: {brain_dice_std: .4f}'.format(
              placenta_dice_mean=placenta_dice_mean, placenta_dice_std=placenta_dice_std,
              brain_dice_mean=brain_dice_mean, brain_dice_std=brain_dice_std))
    print('The Placenta mean IoU: {placenta_iou_mean: .4f}; The Placenta IoU std: {placenta_iou_std: .4f}; '
          'The Brain mean IoU: {brain_iou_mean: .4f}; The Brain IoU std: {brain_iou_std: .4f}'.format(
              placenta_iou_mean=placenta_iou_mean, placenta_iou_std=placenta_iou_std,
              brain_iou_mean=brain_iou_mean, brain_iou_std=brain_iou_std))
    print('The Placenta mean assd: {placenta_asd_mean: .4f}; The Placenta assd std: {placenta_asd_std: .4f}; '
          'The Brain mean assd: {brain_asd_mean: .4f}; The Brain assd std: {brain_asd_std: .4f}'.format(
              placenta_asd_mean=placenta_assd_mean, placenta_asd_std=placenta_assd_std,
              brain_asd_mean=brain_assd_mean, brain_asd_std=brain_assd_std))
def test_isic(test_loader, model, args):
    """Evaluate the minimum-validation-loss checkpoint on the ISIC test split.

    Loads ``<args.ckpt>/min_loss_<args.data>_checkpoint.pth.tar`` into ``model``
    when it exists; otherwise a message is printed and the current weights are
    evaluated. Prints the mean and standard deviation over batches of the dice,
    IoU and ASSD.

    Args:
        test_loader: DataLoader yielding ``(image, label)`` batches.
        model: network to evaluate; switched to eval mode here.
        args: namespace providing ``ckpt``, ``data`` and ``num_classes``.
    """
    isic_dice = []
    isic_iou = []
    isic_assd = []
    modelname = args.ckpt + '/' + 'min_loss' + '_' + args.data + '_checkpoint.pth.tar'
    if os.path.isfile(modelname):
        print("=> Loading checkpoint '{}'".format(modelname))
        checkpoint = torch.load(modelname)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> Loaded saved the best model at (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> No checkpoint found at '{}'".format(modelname))
    model.eval()
    # Inference only: disable autograd so no graph is built for the test batches.
    with torch.no_grad():
        for img, lab in tqdm(test_loader, total=len(test_loader)):
            image = img.float().cuda()
            target = lab.float().cuda()
            output = model(image)  # model output
            output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)
            output_soft = get_soft_label(output_dis, args.num_classes)
            target_soft = get_soft_label(target, args.num_classes)  # get soft label
            label_arr = np.squeeze(target_soft.cpu().numpy()).astype(np.uint8)
            output_arr = np.squeeze(output_soft.cpu().byte().numpy()).astype(np.uint8)
            isic_b_dice = val_dice_isic(output_soft, target_soft, args.num_classes)  # the dice accuracy
            isic_b_iou = Intersection_over_Union_isic(output_soft, target_soft, args.num_classes)  # the iou accuracy
            # NOTE(review): assumes the soft labels are channel-last, so index 1 on the
            # last axis selects the lesion mask — confirm with get_soft_label.
            isic_b_asd = assd(output_arr[:, :, :, 1], label_arr[:, :, :, 1])  # the assd
            isic_dice.append(isic_b_dice.data.cpu().numpy())
            isic_iou.append(isic_b_iou.data.cpu().numpy())
            isic_assd.append(isic_b_asd)
    isic_dice_mean = np.average(isic_dice)
    isic_dice_std = np.std(isic_dice)
    isic_iou_mean = np.average(isic_iou)
    isic_iou_std = np.std(isic_iou)
    isic_assd_mean = np.average(isic_assd)
    isic_assd_std = np.std(isic_assd)
    print('The ISIC mean Accuracy: {isic_dice_mean: .4f}; The ISIC Accuracy std: {isic_dice_std: .4f}'.format(
        isic_dice_mean=isic_dice_mean, isic_dice_std=isic_dice_std))
    print('The ISIC mean IoU: {isic_iou_mean: .4f}; The ISIC IoU std: {isic_iou_std: .4f}'.format(
        isic_iou_mean=isic_iou_mean, isic_iou_std=isic_iou_std))
    print('The ISIC mean assd: {isic_asd_mean: .4f}; The ISIC assd std: {isic_asd_std: .4f}'.format(
        isic_asd_mean=isic_assd_mean, isic_asd_std=isic_assd_std))
def main(args):
    """Train, validate and test the model/dataset pair selected by ``args``.

    Builds train/validation/test loaders for ``args.val_folder`` and creates the
    network named by ``args.id``. Trains for ``args.epochs`` epochs with Adam and
    a StepLR schedule (step_size=256, gamma=0.5), plotting the curves to visdom
    and saving periodic checkpoints, then runs the dataset's test routine.

    Args:
        args: parsed command-line namespace. ``num_input``, ``num_classes`` and
            ``out_size`` are overwritten here for 'Fetus' and 'ISIC2018'.
    """
    minloss = [1.0]
    start_epoch = args.start_epoch
    # loading the dataset
    print('loading the {0},{1},{2} dataset ...'.format('train', 'validation', 'test'))
    trainset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='train',
                                       transform=Test_Transform[args.data])
    validset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='validation',
                                       transform=Test_Transform[args.data])
    testset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='test',
                                      transform=Test_Transform[args.data])
    trainloader = Data.DataLoader(dataset=trainset, batch_size=args.batch_size, shuffle=True, pin_memory=True)
    validloader = Data.DataLoader(dataset=validset, batch_size=args.batch_size, shuffle=True, pin_memory=True)
    testloader = Data.DataLoader(dataset=testset, batch_size=args.batch_size, shuffle=False, pin_memory=True)
    print('Loading is done\n')

    # Define model
    if args.data == 'Fetus':
        args.num_input = 1
        args.num_classes = 3
        args.out_size = (256, 256)
    elif args.data == 'ISIC2018':
        args.num_input = 3
        args.num_classes = 2
        args.out_size = (224, 300)
    model = Test_Model[args.id](args, args.num_input, args.num_classes)
    if torch.cuda.is_available():
        print('We can use', torch.cuda.device_count(), 'GPUs to train the network')
        model = model.cuda()

    # collect the number of parameters in the network
    print("------------------------------------------")
    print("Network Architecture of Model {}:".format(args.id))
    # Counts every registered parameter, including any with requires_grad=False.
    num_para = sum(param.numel() for param in model.parameters())
    print(model)
    print("Number of trainable parameters {0} in Model {1}".format(num_para, args.id))
    print("------------------------------------------")

    # Define optimizers and loss function
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr_rate,
                                 weight_decay=args.weight_decay)  # optimize all model parameters
    criterion = SoftDiceLoss()
    scheduler = StepLR(optimizer, step_size=256, gamma=0.5)

    # resume
    # NOTE(review): the scheduler state is not stored in checkpoints, so after a
    # resume the StepLR step counter restarts from 0.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['opt_dict'])
            print("=> Loaded checkpoint (epoch {})".format(checkpoint['epoch']))
        else:
            print("=> No checkpoint found at '{}'".format(args.resume))

    # visualiser
    vis = visdom.Visdom(env='CA-net')
    print("Start training ...")
    for epoch in range(start_epoch + 1, args.epochs + 1):
        train_avg_loss = train(trainloader, model, criterion, optimizer, args, epoch)
        # Since PyTorch 1.1 the scheduler must be stepped after the epoch's
        # optimizer.step() calls; stepping first skips the initial learning rate.
        scheduler.step()
        vis.line(X=torch.Tensor([epoch]), Y=torch.Tensor([train_avg_loss]),
                 win=args.id + args.data,
                 update='append',
                 opts=dict(title=args.id + '_' + args.data,
                           xlabel='Epochs',
                           ylabel='Train_avg_loss'))
        if args.data == 'Fetus':
            val_avg_loss, val_placenta_dice, val_brain_dice = valid_fetus(validloader, model, criterion,
                                                                          optimizer, args, epoch, minloss)
            vis.line(X=torch.Tensor([epoch]), Y=torch.Tensor([val_avg_loss]),
                     win=args.id + args.data + 'valid_avg',
                     name='loss',
                     update='append',
                     opts=dict(title=args.id + '_' + args.data,
                               xlabel='Epochs',
                               ylabel='Dice&loss'))
            vis.line(X=torch.Tensor([epoch]), Y=torch.Tensor([val_placenta_dice]),
                     win=args.id + args.data + 'valid_avg',
                     name='placenta_dice',
                     update='append',
                     opts=dict(title=args.id + '_' + args.data,
                               xlabel='Epochs',
                               ylabel='Dice&loss'))
            vis.line(X=torch.Tensor([epoch]), Y=torch.Tensor([val_brain_dice]),
                     win=args.id + args.data + 'valid_avg',
                     name='brain_dice',
                     update='append',
                     opts=dict(title=args.id + '_' + args.data,
                               xlabel='Epochs',
                               ylabel='Dice&loss'))
        elif args.data == 'ISIC2018':
            val_avg_loss, val_isic_dice = valid_isic(validloader, model, criterion, optimizer, args, epoch, minloss)
            vis.line(X=torch.Tensor([epoch]), Y=torch.Tensor([val_avg_loss]),
                     win=args.id + args.data + 'valid_avg',
                     name='loss',
                     update='append',
                     opts=dict(title=args.id + '_' + args.data + '_',
                               xlabel='Epochs',
                               ylabel='Dice&loss'))
            vis.line(X=torch.Tensor([epoch]), Y=torch.Tensor([val_isic_dice]),
                     win=args.id + args.data + 'valid_avg',
                     name='isic_dice',
                     update='append',
                     opts=dict(title=args.id + '_' + args.data,
                               xlabel='Epochs',
                               ylabel='Dice&loss'))
        # save models
        if epoch > args.particular_epoch:
            if epoch % args.save_epochs_steps == 0:
                filename = args.ckpt + '/' + str(epoch) + '_' + args.data + '_checkpoint.pth.tar'
                print('the model will be saved at {}'.format(filename))
                state = {'epoch': epoch, 'state_dict': model.state_dict(), 'opt_dict': optimizer.state_dict()}
                torch.save(state, filename)
    print('Training Done! Start testing')
    if args.data == 'Fetus':
        test_fetus(testloader, model, args)
    elif args.data == 'ISIC2018':
        test_isic(testloader, model, args)
    print('Testing Done!')
if __name__ == '__main__':
    assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \
        'PyTorch>=0.4.0 is required'
    parser = argparse.ArgumentParser(description='Comprehensive attention network for biomedical Dataset')
    # Model related arguments
    parser.add_argument('--id', default='Comp_Atten_Unet',
                        help='a name for identifying the model. Choose from the following options: Comp_Atten_Unet')
    # Path related arguments
    parser.add_argument('--root_path', default='./data/ISIC2018_Task1_npy_all',
                        help='root directory of data')
    parser.add_argument('--ckpt', default='./saved_models',
                        help='folder to output checkpoints')
    # optimization related arguments; %(default)s keeps the help text in sync with the defaults
    parser.add_argument('--epochs', type=int, default=300, metavar='N',
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--start_epoch', default=0, type=int,
                        help='epoch to start training. useful if continue from a checkpoint')
    parser.add_argument('--batch_size', type=int, default=16, metavar='N',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--lr_rate', type=float, default=1e-4, metavar='LR',
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--num_classes', default=2, type=int,
                        help='number of classes')
    parser.add_argument('--num_input', default=3, type=int,
                        help='number of input image for each patient')
    parser.add_argument('--weight_decay', default=1e-8, type=float, help='weights regularizer')
    parser.add_argument('--particular_epoch', default=30, type=int,
                        help='after this number, we will save models more frequently')
    parser.add_argument('--save_epochs_steps', default=200, type=int,
                        help='frequency to save models after a particular number of epochs')
    parser.add_argument('--resume', default='',
                        help='the checkpoint that resumes from')
    # other arguments
    parser.add_argument('--data', default='ISIC2018', help='choose the dataset')
    # NOTE: main() overwrites out_size for 'Fetus' and 'ISIC2018'; a command-line value arrives as a string.
    parser.add_argument('--out_size', default=(224, 300), help='the output image size')
    parser.add_argument('--val_folder', default='folder0', type=str,
                        help='which cross validation folder')
    args = parser.parse_args()
    print("Input arguments:")
    for key, value in vars(args).items():
        print("{:16} {}".format(key, value))

    # Checkpoints are grouped per dataset / fold / model id.
    args.ckpt = os.path.join(args.ckpt, args.data, args.val_folder, args.id)
    print('Models are saved at %s' % (args.ckpt))
    os.makedirs(args.ckpt, exist_ok=True)

    # A start epoch implies resuming from that epoch's periodic checkpoint.
    if args.start_epoch > 1:
        args.resume = args.ckpt + '/' + str(args.start_epoch) + '_' + args.data + '_checkpoint.pth.tar'
    main(args)
================================================
FILE: show_fused_heatmap.py
================================================
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
def map_scalar_to_color(x):
    """Map a scalar attention value to an RGB colour on a blue-to-red ramp.

    The colour is linearly interpolated between the control points
    0.0 -> blue, 0.25 -> cyan, 0.5 -> green, 0.75 -> yellow, 1.0 -> red,
    and each channel is truncated with ``int``. Values below 0 (and NaN) map to
    blue and values above 1 map to red, instead of raising IndexError.

    Args:
        x: scalar value, nominally in [0, 1].

    Returns:
        tuple[int, int, int]: the (R, G, B) colour.
    """
    x_list = [0.0, 0.25, 0.5, 0.75, 1.0]
    c_list = [[0, 0, 255],
              [0, 255, 255],
              [0, 255, 0],
              [255, 255, 0],
              [255, 0, 0]]
    # Clamp to the ramp; ``not x >= ...`` also catches NaN, which fails every comparison.
    if not x >= x_list[0]:
        x = x_list[0]
    elif x > x_list[-1]:
        x = x_list[-1]
    # There are len(x_list) - 1 segments; find the one containing x and interpolate.
    for i in range(len(x_list) - 1):
        if x <= x_list[i + 1]:
            x0, x1 = x_list[i], x_list[i + 1]
            c0, c1 = c_list[i], c_list[i + 1]
            alpha = (x - x0) / (x1 - x0)
            return tuple(int(c0[j] * (1 - alpha) + c1[j] * alpha) for j in range(3))
def get_fused_heat_map(image, att):
    """Blend a colour-mapped attention map over ``image``.

    Args:
        image: PIL image to draw on. It is converted to RGB first, so grayscale,
            palette or RGBA inputs work (per-channel indexing needs 3-tuples).
        att: PIL image of the same size whose ``getpixel`` returns a number,
            e.g. a mode 'F' image with values in [0, 1].

    Returns:
        A new RGB image where each pixel is ``pixel * (1 - a) + colour * a``,
        with ``colour = map_scalar_to_color(v)`` and ``a = 0.3 + 0.5 * v`` for
        the attention value ``v``; channels are truncated with ``int``.
    """
    rgb = image.convert('RGB')
    width, height = rgb.size  # PIL sizes are (width, height); getpixel takes (x, y)
    img = Image.new('RGB', rgb.size, (255, 0, 0))
    for i in range(width):
        for j in range(height):
            p0 = rgb.getpixel((i, j))
            value = att.getpixel((i, j))
            p1 = map_scalar_to_color(value)
            alpha = 0.3 + value * 0.5
            p = tuple(int(p0[c] * (1 - alpha) + p1[c] * alpha) for c in range(3))
            img.putpixel((i, j), p)
    return img
if __name__ == "__main__":
image_name = "./result/atten_map/ISIC_0015937.jpg"
scalar_name = "./result/atten_map/25_2_8_wgt"
save_name = "./result/atten_map/15937_wgt3_fused"
img = Image.open(image_name)
# img = np.load(image_name)
# img = Image.fromarray(np.uint8(img*255))
# load the scalar map, and normalize the inteinsty to 0 - 1
scl = Image.open(scalar_name).convert('L')
scl = np.asarray(scl)
scl = cv2.resize(scl, dsize=(img.size[0], img.size[1]), interpolation=cv2.INTER_NEAREST)
scl_norm = np.asarray(scl, np.float32)/255
scl_norm = Image.fromarray(scl_norm)
# convert the scalar map to heat map, and fuse it with the original image
img_scl = get_fused_heat_map(img, scl_norm)
# img_scl.save(save_name, format='png')
plt.imshow(img_scl), plt.title('fused result')
# plt.colorbar()
plt.show()
================================================
FILE: utils/binary.py
================================================
# Copyright (C) 2013 Oskar Maier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# author Oskar Maier
# version r0.1.1
# since 2014-03-13
# status Release
# build-in modules
# third-party modules
import numpy
from scipy.ndimage import _ni_support
from scipy.ndimage.morphology import distance_transform_edt, binary_erosion,\
generate_binary_structure
from scipy.ndimage.measurements import label, find_objects
from scipy.stats import pearsonr
# own modules
# code
def dc(result, reference):
    r"""
    Dice coefficient

    Computes the Dice coefficient (also known as Sorensen index) between the binary
    objects in two images.

    The metric is defined as

    .. math::

        DC=\frac{2|A\cap B|}{|A|+|B|}

    , where :math:`A` is the first and :math:`B` the second set of samples (here: binary objects).

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    dc : float
        The Dice coefficient between the object(s) in ```result``` and the
        object(s) in ```reference```. It ranges from 0 (no overlap) to 1 (perfect overlap).
        0.0 is returned when neither input contains any object.

    Notes
    -----
    The measure is symmetric, so the binary images can be supplied in any order.
    """
    # Builtin ``bool`` instead of ``numpy.bool`` (removed in NumPy 1.24); asarray
    # also accepts plain sequences, as promised by "array_like".
    result = numpy.atleast_1d(numpy.asarray(result, dtype=bool))
    reference = numpy.atleast_1d(numpy.asarray(reference, dtype=bool))
    intersection = numpy.count_nonzero(result & reference)
    size_i1 = numpy.count_nonzero(result)
    size_i2 = numpy.count_nonzero(reference)
    try:
        dc = 2. * intersection / float(size_i1 + size_i2)
    except ZeroDivisionError:
        dc = 0.0
    return dc
def jc(result, reference):
    """
    Jaccard coefficient

    Computes the Jaccard coefficient between the binary objects in two images.

    Parameters
    ----------
    result: array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference: array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    jc: float
        The Jaccard coefficient between the object(s) in `result` and the
        object(s) in `reference`. It ranges from 0 (no overlap) to 1 (perfect overlap).
        0.0 is returned when neither input contains any object (same
        convention as :func:`dc`).

    Notes
    -----
    The measure is symmetric, so the binary images can be supplied in any order.
    """
    # Builtin ``bool`` instead of ``numpy.bool`` (removed in NumPy 1.24).
    result = numpy.atleast_1d(numpy.asarray(result, dtype=bool))
    reference = numpy.atleast_1d(numpy.asarray(reference, dtype=bool))
    intersection = numpy.count_nonzero(result & reference)
    union = numpy.count_nonzero(result | reference)
    try:
        jc = float(intersection) / float(union)
    except ZeroDivisionError:
        jc = 0.0
    return jc
def precision(result, reference):
    """
    Precison.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    precision : float
        The precision between two binary datasets, here mostly binary objects in images,
        which is defined as the fraction of retrieved instances that are relevant,
        i.e. TP / (TP + FP). 0.0 is returned when ``result`` contains no object.

    See also
    --------
    :func:`recall`

    Notes
    -----
    Not symmetric. The counterpart of the precision is :func:`recall`.
    High precision means that an algorithm returned substantially more relevant results than irrelevant.

    References
    ----------
    .. [1] http://en.wikipedia.org/wiki/Precision_and_recall
    .. [2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion
    """
    # Builtin ``bool`` instead of ``numpy.bool`` (removed in NumPy 1.24).
    result = numpy.atleast_1d(numpy.asarray(result, dtype=bool))
    reference = numpy.atleast_1d(numpy.asarray(reference, dtype=bool))
    tp = numpy.count_nonzero(result & reference)
    fp = numpy.count_nonzero(result & ~reference)
    try:
        precision = tp / float(tp + fp)
    except ZeroDivisionError:
        precision = 0.0
    return precision
def recall(result, reference):
    """
    Recall.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    recall : float
        The recall between two binary datasets, here mostly binary objects in images,
        which is defined as the fraction of relevant instances that are retrieved,
        i.e. TP / (TP + FN). 0.0 is returned when ``reference`` contains no object.

    See also
    --------
    :func:`precision`

    Notes
    -----
    Not symmetric. The counterpart of the recall is :func:`precision`.
    High recall means that an algorithm returned most of the relevant results.

    References
    ----------
    .. [1] http://en.wikipedia.org/wiki/Precision_and_recall
    .. [2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion
    """
    # Builtin ``bool`` instead of ``numpy.bool`` (removed in NumPy 1.24).
    result = numpy.atleast_1d(numpy.asarray(result, dtype=bool))
    reference = numpy.atleast_1d(numpy.asarray(reference, dtype=bool))
    tp = numpy.count_nonzero(result & reference)
    fn = numpy.count_nonzero(~result & reference)
    try:
        recall = tp / float(tp + fn)
    except ZeroDivisionError:
        recall = 0.0
    return recall
def sensitivity(result, reference):
    """
    Sensitivity (true positive rate).

    Identical to :func:`recall`; refer to it for the full description.

    See also
    --------
    :func:`specificity`
    """
    return recall(result=result, reference=reference)
def specificity(result, reference):
    """
    Specificity.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    specificity : float
        The specificity between two binary datasets, here mostly binary objects in images,
        which denotes the fraction of correctly returned negatives, i.e. TN / (TN + FP).
        0.0 is returned when ``reference`` contains no background.

    See also
    --------
    :func:`sensitivity`

    Notes
    -----
    Not symmetric. The complement of the specificity is :func:`sensitivity`.
    High specificity means that an algorithm correctly left most of the negatives unmarked.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Sensitivity_and_specificity
    .. [2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion
    """
    # Builtin ``bool`` instead of ``numpy.bool`` (removed in NumPy 1.24).
    result = numpy.atleast_1d(numpy.asarray(result, dtype=bool))
    reference = numpy.atleast_1d(numpy.asarray(reference, dtype=bool))
    tn = numpy.count_nonzero(~result & ~reference)
    fp = numpy.count_nonzero(result & ~reference)
    try:
        specificity = tn / float(tn + fp)
    except ZeroDivisionError:
        specificity = 0.0
    return specificity
def true_negative_rate(result, reference):
    """
    True negative rate.

    Identical to :func:`specificity`; refer to it for the full description.

    See also
    --------
    :func:`true_positive_rate`
    :func:`positive_predictive_value`
    """
    return specificity(result=result, reference=reference)
def true_positive_rate(result, reference):
    """
    True positive rate.

    Identical to :func:`recall` and :func:`sensitivity`; refer to them for the full description.

    See also
    --------
    :func:`positive_predictive_value`
    :func:`true_negative_rate`
    """
    return recall(result=result, reference=reference)
def positive_predictive_value(result, reference):
    """
    Positive predictive value.

    Identical to :func:`precision`; refer to it for the full description.

    See also
    --------
    :func:`true_positive_rate`
    :func:`true_negative_rate`
    """
    return precision(result=result, reference=reference)
def hd(result, reference, voxelspacing=None, connectivity=1):
    """
    Hausdorff Distance (HD).

    Symmetric Hausdorff distance between the binary objects in two images:
    the larger of the two directed maximum surface distances
    (``result`` -> ``reference`` and ``reference`` -> ``result``).

    Parameters
    ----------
    result : array_like
        Input data containing objects. Converted to binary: background where
        0, object everywhere else.
    reference : array_like
        Input data containing objects. Converted to binary: background where
        0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        Spacing of elements along each dimension. A sequence must have one
        entry per input dimension; a scalar is used for all axes. Unit
        spacing is used when omitted.
    connectivity : int
        Passed to ``scipy.ndimage.generate_binary_structure`` to decide which
        voxels make up the object surface. The connectivity influences the
        Hausdorff distance.

    Returns
    -------
    hd : float
        The symmetric Hausdorff distance between the object(s) in ``result``
        and the object(s) in ``reference``, in the unit of ``voxelspacing``.

    Raises
    ------
    RuntimeError
        If either input contains no binary object.

    See also
    --------
    :func:`hd95`
    :func:`assd`
    :func:`asd`

    Notes
    -----
    This is a real metric; the inputs can be supplied in any order.
    """
    directed_maxima = [
        __surface_distances(source, target, voxelspacing, connectivity).max()
        for source, target in ((result, reference), (reference, result))
    ]
    return max(directed_maxima)
def hd95(result, reference, voxelspacing=None, connectivity=1):
    """
    95th percentile of the (symmetric) Hausdorff Distance.

    The surface distances from ``result`` to ``reference`` and from
    ``reference`` to ``result`` are pooled, and their 95th percentile is
    returned. Compared to :func:`hd` this is less sensitive to a few outlying
    surface voxels and is commonly used in biomedical segmentation challenges.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Converted to binary: background where
        0, object everywhere else.
    reference : array_like
        Input data containing objects. Converted to binary: background where
        0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        Spacing of elements along each dimension. A sequence must have one
        entry per input dimension; a scalar is used for all axes. Unit
        spacing is used when omitted.
    connectivity : int
        Passed to ``scipy.ndimage.generate_binary_structure`` to decide which
        voxels make up the object surface. It influences the result.

    Returns
    -------
    hd95 : float
        95th percentile of the pooled surface distances, in the unit of
        ``voxelspacing``.

    Raises
    ------
    RuntimeError
        If either input contains no binary object.

    See also
    --------
    :func:`hd`

    Notes
    -----
    Symmetric: the inputs can be supplied in any order.
    """
    forward = __surface_distances(result, reference, voxelspacing, connectivity)
    backward = __surface_distances(reference, result, voxelspacing, connectivity)
    # Pool both directed distance sets before taking the percentile.
    pooled = numpy.concatenate((forward, backward))
    return numpy.percentile(pooled, 95)
def assd(result, reference, voxelspacing=None, connectivity=1):
    """
    Average symmetric surface distance (ASSD).

    Mean of the two directed average surface distances, ``asd(result,
    reference)`` and ``asd(reference, result)``.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Converted to binary: background where
        0, object everywhere else.
    reference : array_like
        Input data containing objects. Converted to binary: background where
        0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        Spacing of elements along each dimension. A sequence must have one
        entry per input dimension; a scalar is used for all axes. Unit
        spacing is used when omitted.
    connectivity : int
        Passed to ``scipy.ndimage.generate_binary_structure`` to decide which
        voxels make up the object surface. This choice can influence the
        result strongly; if in doubt, keep the default.

    Returns
    -------
    assd : float
        The average symmetric surface distance, in the unit of
        ``voxelspacing``.

    See also
    --------
    :func:`asd`
    :func:`hd`

    Notes
    -----
    This is a real metric; the inputs can be supplied in any order.
    """
    forward = asd(result, reference, voxelspacing, connectivity)
    backward = asd(reference, result, voxelspacing, connectivity)
    return numpy.mean((forward, backward))
def asd(result, reference, voxelspacing=None, connectivity=1):
    """
    Average surface distance (ASD), directed.

    Mean distance from every surface voxel of the binary object(s) in
    ``result`` to the nearest surface voxel of the binary object(s) in
    ``reference``.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Converted to binary: background where
        0, object everywhere else.
    reference : array_like
        Input data containing objects. Converted to binary: background where
        0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        Spacing of elements along each dimension. A sequence must have one
        entry per input dimension; a scalar is used for all axes. Unit
        spacing is used when omitted.
    connectivity : int
        Passed to ``scipy.ndimage.generate_binary_structure`` to decide which
        voxels make up the object surface. This choice can influence the
        result strongly; if in doubt, keep the default.

    Returns
    -------
    asd : float
        The average surface distance from ``result`` to ``reference``, in the
        unit of ``voxelspacing``.

    Raises
    ------
    RuntimeError
        If either input contains no binary object.

    See also
    --------
    :func:`assd`
    :func:`hd`

    Notes
    -----
    Not a real metric because it is directed; see :func:`assd` for the
    symmetric variant. A voxel belongs to the surface when one binary erosion
    (with the structure selected by ``connectivity``) removes it.

    Examples
    --------
    >>> from scipy.ndimage import generate_binary_structure
    >>> cross = generate_binary_structure(2, 1)   # plus-shaped object
    >>> cube = generate_binary_structure(2, 2)    # 3x3 block of ones

    With ``connectivity=1`` the centre of the cross survives erosion, so only
    its four arms are surface; they all lie on the cube's border ring:

    >>> asd(cross, cube, connectivity=1)
    0.0

    With ``connectivity=2`` the centre of the cross is surface too and lies at
    distance 1 from the cube's border ring:

    >>> asd(cross, cube, connectivity=2)
    0.2
    """
    surface_dists = __surface_distances(result, reference, voxelspacing, connectivity)
    return surface_dists.mean()
def ravd(result, reference):
    """
    Relative (signed) volume difference.

    Compute the relative volume difference between the (joined) binary
    objects in the two images: ``(|result| - |reference|) / |reference|``.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    ravd : float
        The relative volume difference between the object(s) in ``result``
        and the object(s) in ``reference``, in the range
        :math:`[-1.0, +inf]`; :math:`0` denotes an ideal score.

    Raises
    ------
    RuntimeError
        If the reference object is empty.

    See also
    --------
    :func:`dc`
    :func:`precision`
    :func:`recall`

    Notes
    -----
    This is not a real metric, as it is directed. Negative values denote a
    smaller and positive values a larger volume than the reference.
    This implementation does not check whether the two supplied arrays are of
    the same size.

    Examples
    --------
    >>> import numpy
    >>> arr1 = numpy.asarray([[0,1,0],[1,1,1],[0,1,0]])   # 5 voxels
    >>> arr2 = numpy.asarray([[0,1,0],[1,0,1],[0,1,0]])   # 4 voxels
    >>> ravd(arr1, arr2)
    0.25
    >>> ravd(arr2, arr1)
    -0.2

    A perfect score of `0` does not mean that the binary objects fit exactly,
    as only the volumes are compared:

    >>> ravd(numpy.asarray([1,0,0]), numpy.asarray([0,0,1]))
    0.0
    """
    # ``numpy.bool`` was removed in NumPy 1.24; the builtin ``bool`` works on
    # every version, and ``asarray`` also accepts plain sequences.
    result = numpy.atleast_1d(numpy.asarray(result, dtype=bool))
    reference = numpy.atleast_1d(numpy.asarray(reference, dtype=bool))
    vol1 = numpy.count_nonzero(result)
    vol2 = numpy.count_nonzero(reference)
    if 0 == vol2:
        raise RuntimeError('The second supplied array does not contain any binary object.')
    return (vol1 - vol2) / float(vol2)
def volume_correlation(results, references):
    r"""
    Volume correlation.

    Computes the linear correlation in binary object volume between the
    contents of the successive binary images supplied. Measured through
    the Pearson product-moment correlation coefficient.

    Parameters
    ----------
    results : sequence of array_like
        Ordered list of input data containing objects. Each array_like will be
        converted into binary: background where 0, object everywhere else.
    references : sequence of array_like
        Ordered list of input data containing objects. Each array_like will be
        converted into binary: background where 0, object everywhere else.
        The order must be the same as for ``results``.

    Returns
    -------
    r : float
        The correlation coefficient between -1 and 1.
    p : float
        The two-sided p value.
        (Both come from ``scipy.stats.pearsonr``, whose result unpacks as
        ``(r, p)``.)
    """
    # ``numpy.bool`` was removed in NumPy 1.24; the builtin ``bool`` is portable.
    results = numpy.atleast_2d(numpy.array(results).astype(bool))
    references = numpy.atleast_2d(numpy.array(references).astype(bool))
    results_volumes = [numpy.count_nonzero(r) for r in results]
    references_volumes = [numpy.count_nonzero(r) for r in references]
    return pearsonr(results_volumes, references_volumes)  # (Pearson's r, 2-tailed p-value)
def volume_change_correlation(results, references):
    r"""
    Volume change correlation.

    Computes the linear correlation of change in binary object volume between
    the contents of the successive binary images supplied. Measured through
    the Pearson product-moment correlation coefficient.

    Parameters
    ----------
    results : sequence of array_like
        Ordered list of input data containing objects. Each array_like will be
        converted into binary: background where 0, object everywhere else.
    references : sequence of array_like
        Ordered list of input data containing objects. Each array_like will be
        converted into binary: background where 0, object everywhere else.
        The order must be the same as for ``results``.

    Returns
    -------
    r : float
        The correlation coefficient between -1 and 1.
    p : float
        The two-sided p value.
        (Both come from ``scipy.stats.pearsonr``, whose result unpacks as
        ``(r, p)``.)
    """
    # ``numpy.bool`` was removed in NumPy 1.24; the builtin ``bool`` is portable.
    results = numpy.atleast_2d(numpy.array(results).astype(bool))
    references = numpy.atleast_2d(numpy.array(references).astype(bool))
    results_volumes = numpy.asarray([numpy.count_nonzero(r) for r in results])
    references_volumes = numpy.asarray([numpy.count_nonzero(r) for r in references])
    # Volume change between consecutive images.
    results_volumes_changes = results_volumes[1:] - results_volumes[:-1]
    references_volumes_changes = references_volumes[1:] - references_volumes[:-1]
    return pearsonr(results_volumes_changes, references_volumes_changes)  # (Pearson's r, 2-tailed p-value)
def obj_assd(result, reference, voxelspacing=None, connectivity=1):
    """
    Average symmetric surface distance between corresponding objects.

    Mean of the two directed object-wise average surface distances,
    ``obj_asd(result, reference)`` and ``obj_asd(reference, result)``.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Converted to binary: background where
        0, object everywhere else.
    reference : array_like
        Input data containing objects. Converted to binary: background where
        0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        Spacing of elements along each dimension. A sequence must have one
        entry per input dimension; a scalar is used for all axes. Unit
        spacing is used when omitted.
    connectivity : int
        Passed to ``scipy.ndimage.generate_binary_structure``; it determines
        both what counts as a distinct binary object and which voxels make up
        its surface. This choice can influence the result strongly; if in
        doubt, keep the default.

    Returns
    -------
    assd : float
        The average symmetric surface distance between all mutually existing
        distinct binary objects in ``result`` and ``reference``, in the unit
        of ``voxelspacing``.

    See also
    --------
    :func:`obj_asd`

    Notes
    -----
    Symmetric: the inputs can be supplied in any order.
    """
    return numpy.mean([
        obj_asd(result, reference, voxelspacing, connectivity),
        obj_asd(reference, result, voxelspacing, connectivity),
    ])
def obj_asd(result, reference, voxelspacing=None, connectivity=1):
    """
    Average surface distance between corresponding objects (directed).

    First, one-to-one correspondences between distinct binary objects in
    ``result`` and ``reference`` are established (at least one voxel of
    overlap is required). Then the surface distances from each result object
    to its corresponding reference object are pooled and averaged.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension. If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        Passed to ``scipy.ndimage.generate_binary_structure``; it determines
        both what counts as a distinct binary object and which voxels make up
        its surface. This choice can influence the result strongly; if in
        doubt, leave it as it is.

    Returns
    -------
    asd : float
        The average surface distance from the object(s) in ``result`` to their
        corresponding object(s) in ``reference``. The distance unit is the same
        as for ``voxelspacing``. If no correspondence exists, the mean of an
        empty list is returned (``nan``, with a NumPy RuntimeWarning).

    See also
    --------
    :func:`obj_assd`
    :func:`obj_tpr`
    :func:`obj_fpr`

    Notes
    -----
    This is not a real metric, as it is directed; see `obj_assd` for the
    symmetric variant. Like :func:`asd`, distances are measured from the
    surface of the ``result`` objects to the surface of the ``reference``
    objects.

    Examples
    --------
    >>> arr1 = numpy.asarray([[1,1,1],[1,1,1],[1,1,1]])
    >>> arr2 = numpy.asarray([[0,1,0],[0,1,0],[0,1,0]])
    >>> obj_asd(arr1, arr2)
    0.75
    >>> obj_asd(arr2, arr1)
    0.333333333333

    With the `voxelspacing` parameter, the distances between the voxels can be
    set for each dimension separately:

    >>> obj_asd(arr1, arr2, voxelspacing=(1,2))
    1.5
    >>> obj_asd(arr2, arr1, voxelspacing=(1,2))
    0.333333333333

    More examples depicting the notion of object connectedness:

    >>> arr1 = numpy.asarray([[1,0,1],[1,0,0],[0,0,0]])
    >>> arr2 = numpy.asarray([[1,0,1],[1,0,0],[0,0,1]])
    >>> obj_asd(arr1, arr2)
    0.0
    >>> obj_asd(arr2, arr1)
    0.0
    >>> arr1 = numpy.asarray([[1,0,1],[1,0,1],[0,0,1]])
    >>> arr2 = numpy.asarray([[1,0,1],[1,0,0],[0,0,1]])
    >>> obj_asd(arr1, arr2)
    0.6
    >>> obj_asd(arr2, arr1)
    0.0

    Influence of the `connectivity` parameter: with the (default)
    connectivity of `1` the first array contains two objects, while with a
    connectivity of `2` just one large object is detected.

    >>> arr1 = numpy.asarray([[1,0,0],[0,1,1],[0,1,1]])
    >>> arr2 = numpy.asarray([[1,0,0],[0,0,0],[0,0,0]])
    >>> obj_asd(arr1, arr2)
    0.0
    >>> obj_asd(arr1, arr2, connectivity=2)
    1.742955328

    Note that the connectivity also influences which voxels are considered
    object surface.
    """
    sds = list()
    # Called as (result, reference) against the helper's (reference, result)
    # signature: its first labelmap labels our ``reference``, its second our
    # ``result``, and ``mapping`` maps result labels -> reference labels.
    reference_labels, result_labels, _, _, mapping = __distinct_binary_object_correspondences(result, reference, connectivity)
    reference_windows = find_objects(reference_labels)
    result_windows = find_objects(result_labels)
    for result_id, reference_id in mapping.items():
        # Crop to the joint bounding box of the pair to keep the distance
        # transform small.
        window = __combine_windows(result_windows[result_id - 1], reference_windows[reference_id - 1])
        result_object = result_labels[window] == result_id
        reference_object = reference_labels[window] == reference_id
        # Directed result -> reference, consistent with ``asd``.
        sds.extend(__surface_distances(result_object, reference_object, voxelspacing, connectivity))
    return numpy.mean(sds)
def obj_fpr(result, reference, connectivity=1):
    """
    The false positive rate of distinct binary object detection.

    The fraction of distinct binary objects in ``result`` that have no
    corresponding binary object in ``reference``. A partial overlap (of
    minimum one voxel) is considered sufficient for correspondence.
    Correspondences are one-to-one: when two distinct binary objects in
    ``result`` overlap a single distinct object in ``reference``, only one is
    considered to have been detected successfully and the other is added to
    the count of false positives.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    connectivity : int
        The neighbourhood/connectivity considered when determining what
        accounts for a distinct binary object. This value is passed to
        `scipy.ndimage.generate_binary_structure`. The decision on the
        connectivity is important, as it can influence the results strongly.
        If in doubt, leave it as it is.

    Returns
    -------
    fpr : float
        The fraction of distinct binary objects in ``result`` that have no
        corresponding binary object in ``reference``. It has the range
        :math:`[0, 1]`, where a :math:`0` denotes an ideal score.

    Raises
    ------
    RuntimeError
        If ``result`` does not contain any binary object.

    See also
    --------
    :func:`obj_tpr`

    Notes
    -----
    This is not a real metric, as it is directed. Whatever array is considered
    as reference should be passed second. A perfect score of :math:`0` tells
    that every distinct binary object in ``result`` also exists in
    ``reference``, but does not reveal anything about reference objects
    missing from ``result`` (use :func:`obj_tpr` for this).

    Examples
    --------
    >>> arr2 = numpy.asarray([[1,0,0],[1,0,1],[0,0,1]])
    >>> arr1 = numpy.asarray([[0,0,1],[1,0,1],[0,0,1]])
    >>> obj_fpr(arr1, arr2)
    0.0
    >>> obj_fpr(arr2, arr1)
    0.0

    Example of directedness:

    >>> arr2 = numpy.asarray([1,0,1,0,1])
    >>> arr1 = numpy.asarray([1,0,1,0,0])
    >>> obj_fpr(arr1, arr2)
    0.0
    >>> obj_fpr(arr2, arr1)
    0.3333333333333333

    Examples of multiple overlap treatment:

    >>> arr2 = numpy.asarray([1,0,1,0,1,1,1])
    >>> arr1 = numpy.asarray([1,1,1,0,1,0,1])
    >>> obj_fpr(arr1, arr2)
    0.3333333333333333
    >>> obj_fpr(arr2, arr1)
    0.3333333333333333
    >>> arr2 = numpy.asarray([1,0,1,1,1,0,1])
    >>> arr1 = numpy.asarray([1,1,1,0,1,1,1])
    >>> obj_fpr(arr1, arr2)
    0.0
    >>> obj_fpr(arr2, arr1)
    0.3333333333333333
    >>> arr2 = numpy.asarray([[1,0,1,0,0],
                              [1,0,0,0,0],
                              [1,0,1,1,1],
                              [0,0,0,0,0],
                              [1,0,1,0,0]])
    >>> arr1 = numpy.asarray([[1,1,1,0,0],
                              [0,0,0,0,0],
                              [1,1,1,0,1],
                              [0,0,0,0,0],
                              [1,1,1,0,0]])
    >>> obj_fpr(arr1, arr2)
    0.0
    >>> obj_fpr(arr2, arr1)
    0.2
    """
    _, _, n_obj_result, _, mapping = __distinct_binary_object_correspondences(reference, result, connectivity)
    if 0 == n_obj_result:
        raise RuntimeError('The first supplied array does not contain any binary object.')
    # False positives are result objects left without a one-to-one partner,
    # so the rate is normalised by the number of result objects.
    return (n_obj_result - len(mapping)) / float(n_obj_result)
def obj_tpr(result, reference, connectivity=1):
    """
    The true positive rate of distinct binary object detection.

    The fraction of distinct binary objects in ``reference`` that have a
    corresponding binary object in ``result``. A partial overlap (of minimum
    one voxel) is considered sufficient for correspondence. Correspondences
    are one-to-one: when several distinct objects on one side overlap a
    single distinct object on the other side, only one pair is counted as
    detected.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    connectivity : int
        The neighbourhood/connectivity considered when determining what
        accounts for a distinct binary object. This value is passed to
        `scipy.ndimage.generate_binary_structure`. The decision on the
        connectivity is important, as it can influence the results strongly.
        If in doubt, leave it as it is.

    Returns
    -------
    tpr : float
        The fraction of distinct binary objects in ``reference`` that also
        exist in ``result``. It has the range :math:`[0, 1]`, where a
        :math:`1` denotes an ideal score.

    Raises
    ------
    RuntimeError
        If the reference does not contain any binary object.

    See also
    --------
    :func:`obj_fpr`

    Notes
    -----
    This is not a real metric, as it is directed. Whatever array is considered
    as reference should be passed second. A perfect score of :math:`1` tells
    that all distinct binary objects in the reference array also exist in the
    result array, but does not reveal anything about additional binary objects
    in the result array (use :func:`obj_fpr` for this).

    Examples
    --------
    >>> arr2 = numpy.asarray([[1,0,0],[1,0,1],[0,0,1]])
    >>> arr1 = numpy.asarray([[0,0,1],[1,0,1],[0,0,1]])
    >>> obj_tpr(arr1, arr2)
    1.0
    >>> obj_tpr(arr2, arr1)
    1.0

    Example of directedness:

    >>> arr2 = numpy.asarray([1,0,1,0,1])
    >>> arr1 = numpy.asarray([1,0,1,0,0])
    >>> obj_tpr(arr1, arr2)
    0.6666666666666666
    >>> obj_tpr(arr2, arr1)
    1.0

    Examples of multiple overlap treatment:

    >>> arr2 = numpy.asarray([1,0,1,0,1,1,1])
    >>> arr1 = numpy.asarray([1,1,1,0,1,0,1])
    >>> obj_tpr(arr1, arr2)
    0.6666666666666666
    >>> obj_tpr(arr2, arr1)
    0.6666666666666666
    >>> arr2 = numpy.asarray([1,0,1,1,1,0,1])
    >>> arr1 = numpy.asarray([1,1,1,0,1,1,1])
    >>> obj_tpr(arr1, arr2)
    0.6666666666666666
    >>> obj_tpr(arr2, arr1)
    1.0
    >>> arr2 = numpy.asarray([[1,0,1,0,0],
                              [1,0,0,0,0],
                              [1,0,1,1,1],
                              [0,0,0,0,0],
                              [1,0,1,0,0]])
    >>> arr1 = numpy.asarray([[1,1,1,0,0],
                              [0,0,0,0,0],
                              [1,1,1,0,1],
                              [0,0,0,0,0],
                              [1,1,1,0,0]])
    >>> obj_tpr(arr1, arr2)
    0.8
    >>> obj_tpr(arr2, arr1)
    1.0
    """
    _, _, _, n_obj_reference, mapping = __distinct_binary_object_correspondences(reference, result, connectivity)
    if 0 == n_obj_reference:
        raise RuntimeError('The second supplied array does not contain any binary object.')
    # Detected reference objects over all reference objects.
    return len(mapping) / float(n_obj_reference)
def __distinct_binary_object_correspondences(reference, result, connectivity=1):
    """
    One-to-one correspondences between distinct binary objects.

    Labels the distinct binary objects (connectivity as defined by scipy's
    `generate_binary_structure`) in both inputs and returns a one-to-one
    mapping from objects in ``reference`` to overlapping objects in
    ``result``. A one-voxel overlap suffices for correspondence.

    The overlap relation can be many-to-many. Objects with exactly one
    overlapping partner are assigned first, in label order, if that partner is
    still free. The remaining ambiguous objects are then assigned greedily,
    always handling the one with the fewest still-free candidates first; ties
    and the choice among candidates are arbitrary.

    Returns
    -------
    labelmap1 : ndarray
        Labels of the objects in ``result``.
    labelmap2 : ndarray
        Labels of the objects in ``reference``.
    n_obj_result : int
        Number of objects in ``result``.
    n_obj_reference : int
        Number of objects in ``reference``.
    mapping : dict
        Label in ``labelmap2`` (reference) -> label in ``labelmap1`` (result).
    """
    # ``numpy.bool`` was removed in NumPy 1.24; the builtin ``bool`` is portable.
    result = numpy.atleast_1d(numpy.asarray(result, dtype=bool))
    reference = numpy.atleast_1d(numpy.asarray(reference, dtype=bool))

    footprint = generate_binary_structure(result.ndim, connectivity)

    labelmap1, n_obj_result = label(result, footprint)
    labelmap2, n_obj_reference = label(reference, footprint)

    slicers = find_objects(labelmap2)  # bounding box of each reference object
    mapping = dict()  # reference label (labelmap2) -> result label (labelmap1)
    used_labels = set()  # result labels (labelmap1) already assigned
    one_to_many = list()  # (reference label, set of candidate result labels)
    for ref_id, slicer in enumerate(slicers, start=1):  # labels start at 1
        ref_object = labelmap2[slicer] == ref_id
        res_ids = numpy.unique(labelmap1[slicer][ref_object])
        res_ids = res_ids[res_ids != 0]  # drop background
        if len(res_ids) == 1:
            res_id = res_ids[0]
            if res_id not in used_labels:
                mapping[ref_id] = res_id
                used_labels.add(res_id)
        elif len(res_ids) > 1:
            one_to_many.append((ref_id, set(res_ids)))

    # Resolve the ambiguous cases, most constrained first.
    while True:
        one_to_many = [(ref_id, res_ids - used_labels) for ref_id, res_ids in one_to_many]
        one_to_many = [entry for entry in one_to_many if entry[1]]
        one_to_many = sorted(one_to_many, key=lambda entry: len(entry[1]))
        if not one_to_many:
            break
        ref_id, candidates = one_to_many[0]
        res_id = candidates.pop()  # arbitrary candidate from the smallest set
        mapping[ref_id] = res_id
        used_labels.add(res_id)
        one_to_many = one_to_many[1:]
    return labelmap1, labelmap2, n_obj_result, n_obj_reference, mapping
def __surface_distances(result, reference, voxelspacing=None, connectivity=1):
    """
    Surface-to-surface distances from ``result`` to ``reference``.

    For every surface voxel of the binary object(s) in ``result``, return the
    distance to the nearest surface voxel of the binary object(s) in
    ``reference``.

    Parameters
    ----------
    result, reference : array_like
        Converted to binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        Spacing of elements along each axis; a scalar is used for all axes.
        ``None`` means unit spacing.
    connectivity : int
        Passed to ``generate_binary_structure``. A voxel is surface when one
        binary erosion with that structure removes it.

    Returns
    -------
    sds : numpy.ndarray
        1-D array with one distance per surface voxel of ``result``.

    Raises
    ------
    RuntimeError
        If either input contains no object voxel.
    """
    # ``numpy.bool`` was removed in NumPy 1.24; the builtin ``bool`` is portable,
    # and ``asarray`` also accepts plain sequences.
    result = numpy.atleast_1d(numpy.asarray(result, dtype=bool))
    reference = numpy.atleast_1d(numpy.asarray(reference, dtype=bool))
    if voxelspacing is not None:
        voxelspacing = _ni_support._normalize_sequence(voxelspacing, result.ndim)
        voxelspacing = numpy.asarray(voxelspacing, dtype=numpy.float64)
        if not voxelspacing.flags.contiguous:
            voxelspacing = voxelspacing.copy()

    # binary structure
    footprint = generate_binary_structure(result.ndim, connectivity)

    # test for emptiness
    if 0 == numpy.count_nonzero(result):
        raise RuntimeError('The first supplied array does not contain any binary object.')
    if 0 == numpy.count_nonzero(reference):
        raise RuntimeError('The second supplied array does not contain any binary object.')

    # extract only 1-pixel border line of objects
    result_border = result ^ binary_erosion(result, structure=footprint, iterations=1)
    reference_border = reference ^ binary_erosion(reference, structure=footprint, iterations=1)

    # The EDT measures the distance to the nearest zero element, so the
    # reference border is inverted to make its voxels the zeros.
    dt = distance_transform_edt(~reference_border, sampling=voxelspacing)
    sds = dt[result_border]
    return sds
def __combine_windows(w1, w2):
    """
    Smallest window covering two windows.

    ``w1`` and ``w2`` are tuples of slices (as returned by ``find_objects``).
    The result takes, per axis, the minimum start and the maximum stop.
    """
    return tuple(
        slice(min(first.start, second.start), max(first.stop, second.stop))
        for first, second in zip(w1, w2)
    )
================================================
FILE: utils/dice_loss.py
================================================
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
class SoftDiceLoss(_Loss):
    """
    Soft Dice loss as a module.

    Thin ``_Loss`` wrapper that delegates to :func:`soft_dice_loss`. That
    function computes, per class,
    ``dice = (2*|A.B| + 1e-5) / (|A| + |B| + 1 + 1e-5)`` and returns the mean
    over classes of ``-log(dice)``.
    """

    def __init__(self, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        super().__init__()

    def forward(self, prediction, soft_ground_truth, num_class=3, weight_map=None, eps=1e-8):
        # NOTE(review): ``eps`` is accepted for API compatibility but is not
        # forwarded; ``soft_dice_loss`` uses its own fixed smoothing terms.
        loss = soft_dice_loss(prediction, soft_ground_truth, num_class, weight_map)
        return loss
def get_soft_label(input_tensor, num_class):
    """
    One-hot encode a label tensor, channels last.

    input_tensor: label tensor of shape [N, C, H, W] (presumably C == 1).
    Returns a float tensor of shape [N, H, W, C * num_class]. With C == 1,
    entry [..., k] is 1.0 where the label equals k and 0.0 elsewhere.
    """
    channels_last = input_tensor.permute(0, 2, 3, 1)
    masks = [channels_last == cls for cls in range(num_class)]
    return torch.cat(masks, dim=-1).float()
def soft_dice_loss(prediction, soft_ground_truth, num_class, weight_map=None):
    """
    Soft Dice loss, ``mean_c(-log(dice_c))``.

    prediction: tensor [N, num_class, H, W]; permuted to channels-last and
        flattened to (N*H*W, num_class).
    soft_ground_truth: tensor flattened to (-1, num_class); presumably the
        channels-last one-hot output of ``get_soft_label`` — confirm at call
        sites.
    weight_map: optional tensor with one weight per voxel (N*H*W elements).
        Each voxel's weight scales all of its class entries.

    Per class, ``dice = (2*intersect + 1e-5) / (ref_vol + seg_vol + 1 + 1e-5)``.
    Returns a scalar tensor: the mean over classes of ``-log(dice)``.
    """
    pred = prediction.permute(0, 2, 3, 1).contiguous().view(-1, num_class)
    ground = soft_ground_truth.view(-1, num_class)
    if weight_map is not None:
        # A (n_voxels, 1) column broadcasts each voxel's weight over its own
        # class row. The former ``repeat(num_class).view_as(pred)`` tiled the
        # weights, so row k received the weights of voxels k*C..k*C+C-1.
        weights = weight_map.reshape(-1, 1)
        ref_vol = torch.sum(weights * ground, 0)
        intersect = torch.sum(weights * ground * pred, 0)
        seg_vol = torch.sum(weights * pred, 0)
    else:
        ref_vol = torch.sum(ground, 0)
        intersect = torch.sum(ground * pred, 0)
        seg_vol = torch.sum(pred, 0)
    dice_score = (2.0 * intersect + 1e-5) / (ref_vol + seg_vol + 1.0 + 1e-5)
    return torch.mean(-torch.log(dice_score))
def val_dice_fetus(prediction, soft_ground_truth, num_class):
    """Return the Dice scores of class 1 (placenta) and class 2 (brain).

    Both inputs are flattened to rows of ``num_class`` values, so they are
    expected to be channels-last (class dimension last), e.g. one-hot outputs
    of ``get_soft_label``. Per class:
    ``dice = 2*sum(g*p) / (sum(g) + sum(p) + 1)``; the +1 avoids 0/0 when a
    class is absent from both inputs (yielding 0 instead of NaN).

    Returns:
        (placenta_dice, brain_dice) as 0-d tensors. Requires num_class >= 3.
    """
    pred = prediction.reshape(-1, num_class)
    ground = soft_ground_truth.reshape(-1, num_class)
    ref_vol = torch.sum(ground, 0)
    intersect = torch.sum(ground * pred, 0)
    seg_vol = torch.sum(pred, 0)
    dice_score = 2.0 * intersect / (ref_vol + seg_vol + 1.0)
    return dice_score[1], dice_score[2]
def Intersection_over_Union_fetus(prediction, soft_ground_truth, num_class):
    """Return the IoU of class 1 (placenta) and class 2 (brain).

    Both inputs are flattened to rows of ``num_class`` values, so they are
    expected to be channels-last (class dimension last). Per class:
    ``iou = sum(g*p) / (sum(g) + sum(p) - sum(g*p) + 1)``; the +1 avoids 0/0
    for classes absent from both inputs.

    Returns:
        (placenta_iou, brain_iou) as 0-d tensors. Requires num_class >= 3.
    """
    pred = prediction.reshape(-1, num_class)
    ground = soft_ground_truth.reshape(-1, num_class)
    ref_vol = torch.sum(ground, 0)
    intersect = torch.sum(ground * pred, 0)
    seg_vol = torch.sum(pred, 0)
    iou_score = intersect / (ref_vol + seg_vol - intersect + 1.0)
    return iou_score[1], iou_score[2]
def val_dice_isic(prediction, soft_ground_truth, num_class):
    """Return the Dice score averaged over all ``num_class`` classes.

    Inputs are flattened to rows of ``num_class`` values (class dimension
    last). Per class: ``2*sum(g*p) / (sum(g) + sum(p) + 1)``.
    """
    flat_pred = prediction.contiguous().view(-1, num_class)
    flat_true = soft_ground_truth.view(-1, num_class)
    overlap = (flat_true * flat_pred).sum(dim=0)
    denominator = flat_true.sum(dim=0) + flat_pred.sum(dim=0) + 1.0
    return (2.0 * overlap / denominator).mean()
def Intersection_over_Union_isic(prediction, soft_ground_truth, num_class):
    """Return the IoU averaged over all ``num_class`` classes.

    Inputs are flattened to rows of ``num_class`` values (class dimension
    last). Per class: ``sum(g*p) / (sum(g) + sum(p) - sum(g*p) + 1)``.
    """
    flat_pred = prediction.contiguous().view(-1, num_class)
    flat_true = soft_ground_truth.view(-1, num_class)
    overlap = (flat_true * flat_pred).sum(dim=0)
    union = flat_true.sum(dim=0) + flat_pred.sum(dim=0) - overlap + 1.0
    return (overlap / union).mean()
================================================
FILE: utils/evaluation.py
================================================
class AverageMeter:
    """Keep the latest value of a metric together with its running weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, sample count and average."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Store ``val`` as the latest value and add it ``n`` times to the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
================================================
FILE: utils/transform.py
================================================
import torch
import random
import PIL
import numbers
import numpy as np
import torch.nn as nn
import collections
import matplotlib.pyplot as plt
import torchvision.transforms as ts
import torchvision.transforms.functional as TF
from PIL import Image, ImageDraw
# Printable names for the PIL resampling filters, looked up by ``resize.__repr__``.
_pil_interpolation_to_str = {
    getattr(Image, _filter_name): 'PIL.Image.' + _filter_name
    for _filter_name in ('NEAREST', 'BILINEAR', 'BICUBIC', 'LANCZOS')
}
def ISIC2018_transform(sample, train_type):
    """Turn an ISIC ``{'image', 'label'}`` sample of arrays into tensors.

    Both arrays are cast to uint8 and wrapped as PIL images (RGB image,
    single-channel 'L' mask). For ``train_type == 'train'`` the pair is
    randomly cropped to (224, 300) and then randomly flipped/rotated (up to
    +/-30 degrees). Any other ``train_type`` is resized to (224, 300).
    The image is converted with ``ToTensor`` and normalized with mean 0.5 /
    std 0.5 per channel; the mask is converted with ``ToTensor`` only.

    Returns:
        dict with keys 'image' and 'label' holding the transformed tensors.
    """
    image = Image.fromarray(np.uint8(sample['image']), mode='RGB')
    label = Image.fromarray(np.uint8(sample['label']), mode='L')
    if train_type != 'train':
        image, label = resize(size=(224, 300))(image, label)
    else:
        image, label = randomcrop(size=(224, 300))(image, label)
        image, label = randomflip_rotate(image, label, p=0.5, degrees=30)
    to_normalized_tensor = ts.Compose([ts.ToTensor(),
                                       ts.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    return {'image': to_normalized_tensor(image), 'label': ts.ToTensor()(label)}
# these are functional helpers used by the transforms
def randomflip_rotate(img, lab, p=0.5, degrees=0):
    """Apply the same random flips and rotation to an image and its mask.

    A horizontal flip and then a vertical flip are each applied with
    probability ``p``. Both images are then rotated by one angle drawn
    uniformly from ``(-degrees, degrees)`` when ``degrees`` is a number, or
    from ``(degrees[0], degrees[1])`` when it is a 2-element sequence.

    Raises:
        ValueError: if ``degrees`` is a negative number or a sequence whose
            length is not 2.
    """
    if random.random() < p:
        img, lab = TF.hflip(img), TF.hflip(lab)
    if random.random() < p:
        img, lab = TF.vflip(img), TF.vflip(lab)

    if not isinstance(degrees, numbers.Number):
        if len(degrees) != 2:
            raise ValueError("If degrees is a sequence, it must be of len 2.")
        low, high = degrees[0], degrees[1]
    elif degrees < 0:
        raise ValueError("If degrees is a single number, it must be positive.")
    else:
        low, high = -degrees, degrees

    angle = random.uniform(low, high)
    return TF.rotate(img, angle), TF.rotate(lab, angle)
class randomcrop(object):
    """Crop the given PIL Image and its mask at the same random location.

    Args:
        size (sequence or int): Desired output size ``(h, w)`` of the crop. If
            size is an int instead of a sequence, a square crop
            ``(size, size)`` is made.
        padding (int or sequence, optional): Optional padding applied to the
            borders of both image and mask before cropping. Default is 0, i.e.
            no padding. If a sequence of length 4 is provided, it is used to
            pad left, top, right, bottom borders respectively.
        pad_if_needed (boolean): Pad both images if they are smaller than the
            desired size, to avoid raising an exception.
    """

    def __init__(self, size, padding=0, pad_if_needed=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped; only ``img.size`` (w, h) is read.
            output_size (tuple): Expected output size ``(h, w)`` of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop,
            where (i, j) is the top-left corner as (row, column).
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        # randint is inclusive on both ends; it raises ValueError when the image
        # is smaller than the crop (use pad_if_needed to avoid that).
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def _has_padding(self):
        """Return True when ``self.padding`` requests any padding."""
        if self.padding is None:
            return False
        if isinstance(self.padding, numbers.Number):
            return self.padding > 0
        # Sequence padding: comparing a tuple with ``> 0`` raises TypeError in
        # Python 3, so treat it as "pad" when any border is non-zero.
        return any(self.padding)

    def __call__(self, img, lab):
        """
        Args:
            img (PIL Image): Image to be cropped.
            lab (PIL Image): Mask to be cropped with the same window.
        Returns:
            tuple of PIL Image: cropped image and mask.
        """
        if self._has_padding():
            img = TF.pad(img, self.padding)
            lab = TF.pad(lab, self.padding)
        # pad the width if needed
        if self.pad_if_needed and img.size[0] < self.size[1]:
            img = TF.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0))
            lab = TF.pad(lab, (int((1 + self.size[1] - lab.size[0]) / 2), 0))
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            img = TF.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2)))
            lab = TF.pad(lab, (0, int((1 + self.size[0] - lab.size[1]) / 2)))
        # One set of crop parameters for both, so image and mask stay aligned.
        i, j, h, w = self.get_params(img, self.size)
        return TF.crop(img, i, j, h, w), TF.crop(lab, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
class resize(object):
    """Resize the input PIL Image and its mask to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Interpolation used for the image.
            Default is ``PIL.Image.BILINEAR``.
        label_interpolation (int, optional): Interpolation used for the mask.
            Default is ``PIL.Image.NEAREST``: blending filters would create
            intermediate label values that match no class id when the mask is
            later one-hot encoded.
    """

    def __init__(self, size, interpolation=Image.BILINEAR, label_interpolation=Image.NEAREST):
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in collections.abc.
        from collections.abc import Iterable
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation
        self.label_interpolation = label_interpolation

    def __call__(self, img, lab):
        """
        Args:
            img (PIL Image): Image to be scaled.
            lab (PIL Image): Mask to be scaled.
        Returns:
            tuple of PIL Image: rescaled image and mask.
        """
        return (TF.resize(img, self.size, self.interpolation),
                TF.resize(lab, self.size, self.label_interpolation))

    def __repr__(self):
        # Fall back to the raw value for filters missing from the name table.
        interpolate_str = _pil_interpolation_to_str.get(self.interpolation, str(self.interpolation))
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
def itensity_normalize(volume):
    """
    Normalize the intensity of an n-d volume to zero mean and unit std.

    The mean and std are computed over the WHOLE volume, including zero-valued
    voxels. Afterwards every voxel whose original value is exactly 0 is replaced
    with a sample from a standard normal distribution (np.random.normal).
    A constant volume (std == 0) is only mean-centred instead of producing NaN/inf.

    inputs:
        volume: the input n-d numpy array
    outputs:
        out: the normalized n-d float array
    """
    mean = volume.mean()
    std = volume.std()
    if std == 0:
        # Avoid 0/0 on constant volumes.
        std = 1.0
    out = (volume - mean) / std
    background = volume == 0
    out_random = np.random.normal(0, 1, size=volume.shape)
    out[background] = out_random[background]
    return out
================================================
FILE: validation.py
================================================
import os
import torch
import argparse
import numpy as np
import pandas as pd
import torch.utils.data as Data
from utils.binary import assd
from distutils.version import LooseVersion
from Datasets.ISIC2018 import ISIC2018_dataset
from utils.transform import ISIC2018_transform
from Models.networks.network import Comprehensive_Atten_Unet
from utils.dice_loss import get_soft_label, val_dice_isic
from utils.dice_loss import Intersection_over_Union_isic
from time import *
# Registries mapping the --id / --data command-line values to their implementations.
Test_Model = dict(Comp_Atten_Unet=Comprehensive_Atten_Unet)
Test_Dataset = dict(ISIC2018=ISIC2018_dataset)
Test_Transform = dict(ISIC2018=ISIC2018_transform)
def test_isic(test_loader, model, num_classes=None):
    """Evaluate ``model`` on ``test_loader`` and print ISIC segmentation metrics.

    For each batch the network output is argmaxed over dim 1, one-hot encoded
    with ``get_soft_label`` and compared with the one-hot label using
    ``val_dice_isic`` (class-averaged Dice), ``Intersection_over_Union_isic``
    (class-averaged IoU) and ``assd`` on channel 1 of the one-hot arrays.
    The mean and std of each metric over batches, and the summed forward-pass
    time, are printed.

    Args:
        test_loader: iterable of ``(image, label)`` tensor batches.
        model: network whose output holds class scores along dim 1; it is put
            in eval mode. Batches are moved to the GPU with ``.cuda()``.
        num_classes: number of classes for one-hot encoding. When None (the
            default) the module-level ``args.num_classes`` parsed in
            ``__main__`` is used, as before.
    """
    if num_classes is None:
        # Backward-compatible fallback to the CLI arguments parsed in __main__.
        num_classes = args.num_classes
    isic_dice, isic_iou, isic_assd, infer_time = [], [], [], []
    model.eval()
    # Pure inference: without no_grad every forward pass records an autograd graph.
    with torch.no_grad():
        for img, lab in test_loader:
            image = img.float().cuda()
            target = lab.float().cuda()
            # CUDA work is asynchronous; synchronize on both sides so the timer
            # measures the forward pass instead of only the kernel launches.
            torch.cuda.synchronize()
            begin_time = time()
            output = model(image)
            torch.cuda.synchronize()
            infer_time.append(time() - begin_time)

            # Predicted class ids, with a singleton channel dim re-inserted at dim 1.
            output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)
            output_soft = get_soft_label(output_dis, num_classes)
            target_soft = get_soft_label(target, num_classes)
            label_arr = target_soft.cpu().numpy().astype(np.uint8)
            output_arr = output_soft.cpu().numpy().astype(np.uint8)

            isic_dice.append(val_dice_isic(output_soft, target_soft, num_classes).cpu().numpy())
            isic_iou.append(Intersection_over_Union_isic(output_soft, target_soft, num_classes).cpu().numpy())
            # Surface distance on one-hot channel 1 (presumably the lesion class for ISIC).
            isic_assd.append(assd(output_arr[:, :, :, 1], label_arr[:, :, :, 1]))

    isic_dice_mean = np.average(isic_dice)
    isic_dice_std = np.std(isic_dice)
    isic_iou_mean = np.average(isic_iou)
    isic_iou_std = np.std(isic_iou)
    isic_assd_mean = np.average(isic_assd)
    isic_assd_std = np.std(isic_assd)
    all_time = np.sum(infer_time)
    print('The ISIC mean Accuracy: {isic_dice_mean: .4f}; The ISIC Accuracy std: {isic_dice_std: .4f}'.format(
        isic_dice_mean=isic_dice_mean, isic_dice_std=isic_dice_std))
    print('The ISIC mean IoU: {isic_iou_mean: .4f}; The ISIC IoU std: {isic_iou_std: .4f}'.format(
        isic_iou_mean=isic_iou_mean, isic_iou_std=isic_iou_std))
    print('The ISIC mean assd: {isic_asd_mean: .4f}; The ISIC assd std: {isic_asd_std: .4f}'.format(
        isic_asd_mean=isic_assd_mean, isic_asd_std=isic_assd_std))
    print('The inference time: {time: .4f}'.format(time=all_time))
if __name__ == '__main__':
    # NOTE(review): LooseVersion is imported from distutils at the top of this
    # file; distutils was removed in Python 3.12, so that import fails there.
    assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), 'PyTorch>=0.4.0 is required'
    parser = argparse.ArgumentParser(description='U-net add Attention mechanism for biomedical Dataset')
    # Model related arguments
    parser.add_argument('--id', default='Comp_Atten_Unet',
                        help='a name for identitying the model. Choose from the following options: '
                             + ', '.join(Test_Model))
    # Path related arguments
    parser.add_argument('--root_path', default='./data/ISIC2018_Task1_npy_all',
                        help='root directory of data')
    parser.add_argument('--ckpt', default='./saved_models',
                        help='folder to output checkpoints')
    parser.add_argument('--save', default='./result',
                        help='folder to outoput result')
    parser.add_argument('--batch_size', type=int, default=1, metavar='N',
                        help='input batch size for testing (default: 1)')
    parser.add_argument('--num_classes', default=2, type=int,
                        help='number of classes')
    parser.add_argument('--num_input', default=3, type=int,
                        help='number of input image for each patient')
    parser.add_argument('--epoch', type=int, default=300, metavar='N',
                        help='choose the specific epoch checkpoints')
    # other arguments
    parser.add_argument('--data', default='ISIC2018', help='choose the dataset')
    # No ``type=``: a command-line value arrives as a string. It is overwritten
    # below for the known datasets.
    parser.add_argument('--out_size', default=(224, 300), help='the output image size')
    # Raw string: same value as before, without the invalid-escape warning for "\&".
    parser.add_argument('--att_pos', default='dec', type=str,
                        help=r'where attention to plug in (enc, dec, enc\&dec)')
    parser.add_argument('--view', default='axial', type=str,
                        help='use what views data to test (for fetal MRI)')
    parser.add_argument('--val_folder', default='folder0', type=str,
                        help='which cross validation folder')
    # ``args`` stays module-global: test_isic reads args.num_classes by default.
    args = parser.parse_args()
    args.ckpt = os.path.join(args.ckpt, args.data, args.val_folder, args.id)

    # loading the dataset
    print('loading the {0} dataset ...'.format('test'))
    testset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='test',
                                      transform=Test_Transform[args.data])
    testloader = Data.DataLoader(dataset=testset, batch_size=args.batch_size, shuffle=False)
    print('Loading is done\n')

    # Define model
    # The model and every test batch are moved to the GPU with ``.cuda()``, so
    # stop early with a readable message when CUDA is unavailable.
    if not torch.cuda.is_available():
        raise SystemExit('CUDA is not available: this script needs a GPU to run the model.')
    print('We can use', torch.cuda.device_count(), 'GPUs to train the network')
    # Dataset-specific settings override the command-line values.
    # NOTE(review): only 'ISIC2018' is registered in Test_Dataset, so the Fetus
    # branch is currently unreachable (the dataset lookup above would fail first).
    if args.data == 'Fetus':
        args.num_input = 1
        args.num_classes = 3
        args.out_size = (256, 256)
    elif args.data == 'ISIC2018':
        args.num_input = 3
        args.num_classes = 2
        args.out_size = (224, 300)
    model = Test_Model[args.id](args, args.num_input, args.num_classes).cuda()

    # Load the trained best model
    modelname = args.ckpt + '/' + 'min_loss' + '_' + args.data + '_checkpoint.pth.tar'
    if os.path.isfile(modelname):
        print("=> Loading checkpoint '{}'".format(modelname))
        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
        checkpoint = torch.load(modelname)
        # Checkpoints saved from nn.DataParallel carry a 'module.' key prefix;
        # strip it (k[7:]) before load_state_dict if loading such a file.
        model.load_state_dict(checkpoint['state_dict'])
        print("=> Loaded saved the best model at (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> No checkpoint found at '{}'".format(modelname))
    test_isic(testloader, model)