Repository: HiLab-git/CA-Net
Branch: master
Commit: 94f2624ee634
Files: 30
Total size: 184.9 KB
Directory structure:
gitextract_zlon6qmm/
├── .idea/
│ ├── CA-Net.iml
│ ├── encodings.xml
│ ├── misc.xml
│ ├── modules.xml
│ ├── vcs.xml
│ └── workspace.xml
├── Datasets/
│ ├── ISIC2018.py
│ └── folder0/
│ └── folder0_test.list
├── Models/
│ ├── __init__.py
│ ├── layers/
│ │ ├── __init__.py
│ │ ├── channel_attention_layer.py
│ │ ├── grid_attention_layer.py
│ │ ├── modules.py
│ │ ├── nonlocal_layer.py
│ │ └── scale_attention_layer.py
│ ├── networks/
│ │ └── network.py
│ └── networks_other.py
├── README.md
├── create_folder.py
├── data/
│ └── ISIC2018_Task1_npy_all/
│ ├── image/
│ │ └── ISIC_0010854.npy
│ └── label/
│ └── ISIC_0010854_segmentation.npy
├── isic_preprocess.py
├── main.py
├── result/
│ └── atten_map/
│ └── 25_2_8_wgt
├── show_fused_heatmap.py
├── utils/
│ ├── binary.py
│ ├── dice_loss.py
│ ├── evaluation.py
│ └── transform.py
└── validation.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .idea/CA-Net.iml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">
<option name="PROJECT_TEST_RUNNER" value="Unittests" />
</component>
</module>
================================================
FILE: .idea/encodings.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Encoding" addBOMForNewFiles="with NO BOM" />
</project>
================================================
FILE: .idea/misc.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (pytorch)" project-jdk-type="Python SDK" />
</project>
================================================
FILE: .idea/modules.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/CA-Net.iml" filepath="$PROJECT_DIR$/.idea/CA-Net.iml" />
</modules>
</component>
</project>
================================================
FILE: .idea/vcs.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>
================================================
FILE: .idea/workspace.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ChangeListManager">
<list default="true" id="2541e3bb-fbe2-4fc6-be8f-b6401cb16713" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/CA-Net.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/CA-Net.iml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/create_folder.py" beforeDir="false" afterPath="$PROJECT_DIR$/create_folder.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/data/ISIC2018_Task1_npy_all/image/ISIC_0010854.npy" beforeDir="false" afterPath="$PROJECT_DIR$/data/ISIC2018_Task1_npy_all/image/ISIC_0010854.npy" afterDir="false" />
<change beforePath="$PROJECT_DIR$/data/ISIC2018_Task1_npy_all/label/ISIC_0010854_segmentation.npy" beforeDir="false" afterPath="$PROJECT_DIR$/data/ISIC2018_Task1_npy_all/label/ISIC_0010854_segmentation.npy" afterDir="false" />
<change beforePath="$PROJECT_DIR$/isic_preprocess.py" beforeDir="false" afterPath="$PROJECT_DIR$/isic_preprocess.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/utils/transform.py" beforeDir="false" afterPath="$PROJECT_DIR$/utils/transform.py" afterDir="false" />
</list>
<option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
<option name="LAST_RESOLUTION" value="IGNORE" />
</component>
<component name="CoverageDataManager">
<SUITE FILE_PATH="coverage/CA_Net$validation.coverage" NAME="validation Coverage Results" MODIFIED="1598537010616" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
<SUITE FILE_PATH="coverage/CA_Net$isic_preprocess.coverage" NAME="isic_preprocess Coverage Results" MODIFIED="1598536798821" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
</component>
<component name="FileEditorManager">
<leaf SIDE_TABS_SIZE_LIMIT_KEY="300">
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/create_folder.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="374">
<caret line="17" column="46" lean-forward="true" selection-start-line="17" selection-start-column="46" selection-end-line="17" selection-end-column="46" />
<folding>
<element signature="e#0#9#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
</file>
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/README.md">
<provider selected="true" editor-type-id="split-provider[text-editor;markdown-preview-editor]">
<state split_layout="SPLIT">
<first_editor relative-caret-position="1254">
<caret line="57" selection-start-line="57" selection-end-line="57" />
</first_editor>
<second_editor />
</state>
</provider>
</entry>
</file>
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/validation.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="743">
<caret line="118" selection-start-line="118" selection-end-line="118" />
<folding>
<element signature="e#0#9#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
</file>
<file pinned="false" current-in-tab="true">
<entry file="file://$PROJECT_DIR$/isic_preprocess.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="550">
<caret line="25" column="38" lean-forward="true" selection-start-line="25" selection-start-column="38" selection-end-line="25" selection-end-column="38" />
<folding>
<element signature="e#143#152#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
</file>
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/utils/transform.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="479">
<caret line="85" column="17" lean-forward="true" selection-start-line="85" selection-start-column="17" selection-end-line="85" selection-end-column="17" />
<folding>
<element signature="e#0#12#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
</file>
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/Models/networks/network.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="374">
<caret line="17" column="36" selection-start-line="17" selection-start-column="36" selection-end-line="17" selection-end-column="36" />
<folding>
<element signature="e#0#12#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
</file>
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/Models/layers/nonlocal_layer.py">
<provider selected="true" editor-type-id="text-editor">
<state>
<folding>
<element signature="e#0#12#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
</file>
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/Models/layers/scale_attention_layer.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="2640">
<caret line="120" column="25" selection-start-line="120" selection-start-column="25" selection-end-line="120" selection-end-column="25" />
<folding>
<element signature="e#0#12#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
</file>
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/main.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="545">
<caret line="428" column="36" selection-start-line="428" selection-start-column="36" selection-end-line="428" selection-end-column="36" />
</state>
</provider>
</entry>
</file>
</leaf>
</component>
<component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
</component>
<component name="IdeDocumentHistory">
<option name="CHANGED_PATHS">
<list>
<option value="$PROJECT_DIR$/Datasets/ISIC2018.py" />
<option value="$PROJECT_DIR$/main.py" />
<option value="$PROJECT_DIR$/README.md" />
<option value="$PROJECT_DIR$/validation.py" />
<option value="$PROJECT_DIR$/Models/networks/network.py" />
<option value="$PROJECT_DIR$/utils/transform.py" />
<option value="$PROJECT_DIR$/isic_preprocess.py" />
<option value="$PROJECT_DIR$/create_folder.py" />
</list>
</option>
</component>
<component name="ProjectFrameBounds" extendedState="6">
<option name="x" value="65" />
<option name="y" value="-4" />
<option name="width" value="1855" />
<option name="height" value="1084" />
</component>
<component name="ProjectView">
<navigator proportions="" version="1">
<foldersAlwaysOnTop value="true" />
</navigator>
<panes>
<pane id="Scope" />
<pane id="ProjectPane">
<subPane>
<expand>
<path>
<item name="CA-Net" type="b2602c69:ProjectViewProjectNode" />
<item name="CA-Net" type="462c0819:PsiDirectoryNode" />
</path>
</expand>
<select />
</subPane>
</pane>
</panes>
</component>
<component name="PropertiesComponent">
<property name="WebServerToolWindowFactoryState" value="false" />
<property name="last_opened_file_path" value="$PROJECT_DIR$" />
<property name="nodejs_interpreter_path.stuck_in_default_project" value="undefined stuck path" />
<property name="nodejs_npm_path_reset_for_default_project" value="true" />
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
</component>
<component name="RecentsManager">
<key name="CopyFile.RECENT_KEYS">
<recent name="$PROJECT_DIR$/data/11/ISIC2018_Task1-2_Training_Input" />
<recent name="$PROJECT_DIR$/data/11/ISIC2018_Task1_Training_GroundTruth" />
<recent name="$PROJECT_DIR$/Datasets" />
</key>
<key name="MoveFile.RECENT_KEYS">
<recent name="$PROJECT_DIR$/data/11/ISIC2018_Task1_Training_GroundTruth" />
<recent name="$PROJECT_DIR$/Datasets" />
</key>
</component>
<component name="RunDashboard">
<option name="ruleStates">
<list>
<RuleState>
<option name="name" value="ConfigurationTypeDashboardGroupingRule" />
</RuleState>
<RuleState>
<option name="name" value="StatusDashboardGroupingRule" />
</RuleState>
</list>
</option>
</component>
<component name="RunManager">
<configuration name="validation" type="PythonConfigurationType" factoryName="Python" temporary="true">
<module name="CA-Net" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/validation.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<recent_temporary>
<list>
<item itemvalue="Python.validation" />
</list>
</recent_temporary>
</component>
<component name="SvnConfiguration">
<configuration />
</component>
<component name="TaskManager">
<task active="true" id="Default" summary="Default task">
<changelist id="2541e3bb-fbe2-4fc6-be8f-b6401cb16713" name="Default Changelist" comment="" />
<created>1598531699004</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1598531699004</updated>
<workItem from="1598531700459" duration="3016000" />
<workItem from="1598535247980" duration="2358000" />
<workItem from="1598537682047" duration="117000" />
<workItem from="1598538259149" duration="289000" />
</task>
<servers />
</component>
<component name="TimeTrackingManager">
<option name="totallyTimeSpent" value="5780000" />
</component>
<component name="ToolWindowManager">
<frame x="65" y="-4" width="1855" height="1084" extended-state="6" />
<editor active="true" />
<layout>
<window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.17819591" />
<window_info id="Structure" order="1" side_tool="true" weight="0.25" />
<window_info id="Favorites" order="2" side_tool="true" />
<window_info anchor="bottom" id="Message" order="0" />
<window_info anchor="bottom" id="Find" order="1" />
<window_info anchor="bottom" id="Run" order="2" />
<window_info anchor="bottom" id="Debug" order="3" weight="0.2761506" />
<window_info anchor="bottom" id="Cvs" order="4" weight="0.25" />
<window_info anchor="bottom" id="Inspection" order="5" weight="0.4" />
<window_info anchor="bottom" id="TODO" order="6" />
<window_info anchor="bottom" id="Docker" order="7" show_stripe_button="false" />
<window_info anchor="bottom" id="Version Control" order="8" />
<window_info anchor="bottom" id="Database Changes" order="9" />
<window_info anchor="bottom" id="Event Log" order="10" side_tool="true" />
<window_info anchor="bottom" id="Terminal" order="11" visible="true" weight="0.3294979" />
<window_info anchor="bottom" id="Python Console" order="12" />
<window_info anchor="right" id="Commander" internal_type="SLIDING" order="0" type="SLIDING" weight="0.4" />
<window_info anchor="right" id="Ant Build" order="1" weight="0.25" />
<window_info anchor="right" content_ui="combo" id="Hierarchy" order="2" weight="0.25" />
<window_info anchor="right" id="SciView" order="3" visible="true" weight="0.12562259" />
<window_info anchor="right" id="Database" order="4" />
</layout>
</component>
<component name="TypeScriptGeneratedFilesManager">
<option name="version" value="1" />
</component>
<component name="editorHistoryManager">
<entry file="file://$PROJECT_DIR$/Datasets/folder1/folder1_test.list" />
<entry file="file://$USER_HOME$/anaconda3/envs/pytorch/lib/python3.7/site-packages/numpy/lib/npyio.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="358">
<caret line="293" lean-forward="true" selection-start-line="293" selection-end-line="293" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/Datasets/ISIC2018.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="286">
<caret line="13" selection-start-line="13" selection-end-line="13" />
<folding>
<element signature="e#0#9#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/README.md">
<provider selected="true" editor-type-id="split-provider[text-editor;markdown-preview-editor]">
<state split_layout="SPLIT">
<first_editor relative-caret-position="1254">
<caret line="57" selection-start-line="57" selection-end-line="57" />
</first_editor>
<second_editor />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/Models/networks/network.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="374">
<caret line="17" column="36" selection-start-line="17" selection-start-column="36" selection-end-line="17" selection-end-column="36" />
<folding>
<element signature="e#0#12#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/Models/layers/nonlocal_layer.py">
<provider selected="true" editor-type-id="text-editor">
<state>
<folding>
<element signature="e#0#12#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/Models/layers/scale_attention_layer.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="2640">
<caret line="120" column="25" selection-start-line="120" selection-start-column="25" selection-end-line="120" selection-end-column="25" />
<folding>
<element signature="e#0#12#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/Datasets/folder0/folder0_test.list">
<provider selected="true" editor-type-id="text-editor">
<state>
<caret column="16" lean-forward="true" selection-start-column="16" selection-end-column="16" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/utils/transform.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="479">
<caret line="85" column="17" lean-forward="true" selection-start-line="85" selection-start-column="17" selection-end-line="85" selection-end-column="17" />
<folding>
<element signature="e#0#12#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/validation.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="743">
<caret line="118" selection-start-line="118" selection-end-line="118" />
<folding>
<element signature="e#0#9#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/main.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="545">
<caret line="428" column="36" selection-start-line="428" selection-start-column="36" selection-end-line="428" selection-end-column="36" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/create_folder.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="374">
<caret line="17" column="46" lean-forward="true" selection-start-line="17" selection-start-column="46" selection-end-line="17" selection-end-column="46" />
<folding>
<element signature="e#0#9#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/isic_preprocess.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="550">
<caret line="25" column="38" lean-forward="true" selection-start-line="25" selection-start-column="38" selection-end-line="25" selection-end-column="38" />
<folding>
<element signature="e#143#152#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
</component>
</project>
================================================
FILE: Datasets/ISIC2018.py
================================================
import os
import PIL
import torch
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from os import listdir
from os.path import join
from PIL import Image
from utils.transform import itensity_normalize
from torch.utils.data.dataset import Dataset
class ISIC2018_dataset(Dataset):
    """ISIC 2018 Task 1 segmentation dataset backed by pre-converted ``.npy`` files.

    The split is read from the list file
    ``./Datasets/<folder>/<folder>_<train_type>.list`` (relative to the current
    working directory), which holds one image file name per line, e.g.
    ``ISIC_0010854.npy``. For a name ``N.ext`` the image is loaded from
    ``<dataset_folder>/image/N.ext`` and the mask from
    ``<dataset_folder>/label/N_segmentation.npy``.

    Args:
        dataset_folder: root directory containing ``image/`` and ``label/``.
        folder: cross-validation fold name, e.g. ``'folder0'``.
        train_type: one of ``'train'``, ``'validation'``, ``'test'``.
        transform: optional callable ``transform(sample, train_type) -> sample``
            where ``sample`` is ``{'image': ndarray, 'label': ndarray}``.

    Raises:
        ValueError: if ``train_type`` is not a supported split name.
    """

    def __init__(self, dataset_folder='/ISIC2018_Task1_npy_all',
                 folder='folder0', train_type='train', transform=None):
        self.transform = transform
        self.train_type = train_type
        self.folder_file = './Datasets/' + folder

        if self.train_type not in ['train', 'validation', 'test']:
            # Previously this only printed a message and then failed with an
            # unrelated AttributeError on self.folder; fail early and clearly.
            raise ValueError("Choosing type error, You have to choose the loading data type "
                             "including: train, validation, test (got {!r})".format(train_type))

        # This is for cross validation: one list file per fold and split,
        # e.g. ./Datasets/folder0/folder0_train.list
        list_path = join(self.folder_file, self.folder_file.split('/')[-1] + '_' + self.train_type + '.list')
        with open(list_path, 'r') as f:
            # Strip line endings (including stray '\r') and skip blank lines, which
            # would otherwise become entries pointing at the image directory itself.
            self.image_list = [line.strip() for line in f if line.strip()]

        self.folder = [join(dataset_folder, 'image', x) for x in self.image_list]
        self.mask = [join(dataset_folder, 'label', x.split('.')[0] + '_segmentation.npy')
                     for x in self.image_list]
        assert len(self.folder) == len(self.mask)

    def __getitem__(self, item: int):
        """Load the ``item``-th (image, label) pair, applying ``transform`` if set."""
        image = np.load(self.folder[item])
        label = np.load(self.mask[item])
        sample = {'image': image, 'label': label}
        if self.transform is not None:
            # TODO: transformations to augment the dataset
            sample = self.transform(sample, self.train_type)
        return sample['image'], sample['label']

    def __len__(self):
        """Number of samples listed for this fold/split."""
        return len(self.folder)
# a = ISIC2018_dataset()
================================================
FILE: Datasets/folder0/folder0_test.list
================================================
ISIC_0010854.npy
================================================
FILE: Models/__init__.py
================================================
================================================
FILE: Models/layers/__init__.py
================================================
================================================
FILE: Models/layers/channel_attention_layer.py
================================================
import torch.nn as nn
# # SE block add to U-net
def conv3x3(in_planes, out_planes, stride=1, bias=False, group=1):
    """Create a 3x3 ``nn.Conv2d`` with one pixel of zero padding on each side.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (with stride 1 the spatial size is preserved).
        bias: whether the convolution has a learnable bias.
        group: number of channel groups (``groups`` of ``nn.Conv2d``).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        groups=group,
        bias=bias,
    )
    return layer
class SE_Conv_Block(nn.Module):
    """Residual conv block with squeeze-and-excitation style channel attention.

    conv1 (3x3, ``stride``) -> BN -> ReLU -> conv2 (3x3, planes -> planes * 2) -> BN
    yields a feature map F with ``planes * 2`` channels. Two channel-attention
    vectors are computed from F by a shared MLP (fc1 -> ReLU -> fc2 -> sigmoid):
    one from global average pooling, one from global max pooling. The output is
    ``relu(bn3(conv3(relu(avg_att * F + max_att * F + residual))))`` with
    ``planes`` channels, optionally followed by channel dropout.

    Args:
        inplanes: number of input channels.
        planes: number of output channels (attention operates on ``planes * 2``).
        stride: stride of conv1 and of the residual projection.
        downsample: stored on the module but not used in ``forward``.
        drop_out: if True, apply 2-D channel dropout (p=0.5) to the output
            while the module is in training mode.

    ``forward(x)`` returns ``(out, att_weight)`` where
    ``att_weight = avg_att + max_att`` has shape ``(batch, planes * 2, 1, 1)``.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, drop_out=False):
        super(SE_Conv_Block, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes * 2)
        self.bn2 = nn.BatchNorm2d(planes * 2)
        self.conv3 = conv3x3(planes * 2, planes)
        self.bn3 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.dropout = drop_out

        # Global pooling down to 1x1. The original used fixed-size AvgPool2d /
        # MaxPool2d windows matching the ISIC2018 224x300 input (e.g. (224, 300)
        # for planes <= 16 down to (14, 18) for planes == 256). At exactly those
        # sizes this is equivalent to adaptive pooling, but it broke for any other
        # resolution and left the attributes undefined for other `planes` values.
        # Pooling layers have no parameters, so existing checkpoints still load.
        self.globalAvgPool = nn.AdaptiveAvgPool2d(1)
        self.globalMaxPool = nn.AdaptiveMaxPool2d(1)

        # Shared channel-attention MLP with a bottleneck of planes / 2 units.
        self.fc1 = nn.Linear(in_features=planes * 2, out_features=round(planes / 2))
        self.fc2 = nn.Linear(in_features=round(planes / 2), out_features=planes * 2)
        self.sigmoid = nn.Sigmoid()

        # The residual is added to a (planes * 2)-channel map. The original only
        # projected x when inplanes != planes (kept, so checkpoints trained with it
        # load unchanged); for inplanes == planes the residual had `planes`
        # channels and the addition failed, so that case is projected as well.
        self.downchannel = nn.Sequential(nn.Conv2d(inplanes, planes * 2, kernel_size=1, stride=stride, bias=False),
                                         nn.BatchNorm2d(planes * 2),)

    def _channel_attention(self, pooled):
        """Map a pooled (b, c, 1, 1) descriptor to (b, c, 1, 1) sigmoid channel weights."""
        weights = pooled.view(pooled.size(0), -1)
        weights = self.sigmoid(self.fc2(self.relu(self.fc1(weights))))
        return weights.view(weights.size(0), weights.size(1), 1, 1)

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downchannel is not None:
            residual = self.downchannel(x)

        original_out = out
        # The same MLP scores channels from the average- and the max-pooled descriptors.
        avg_att = self._channel_attention(self.globalAvgPool(original_out))
        max_att = self._channel_attention(self.globalMaxPool(original_out))
        att_weight = avg_att + max_att

        out = avg_att * original_out + max_att * original_out
        out = out + residual
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out = self.relu(out)
        if self.dropout:
            # Functional dropout honours self.training. The original created a fresh
            # nn.Dropout2d here, which is always in training mode, so dropout stayed
            # active after model.eval() and made inference non-deterministic.
            out = nn.functional.dropout2d(out, p=0.5, training=self.training)
        return out, att_weight
================================================
FILE: Models/layers/grid_attention_layer.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
from Models.networks_other import init_weights
class _GridAttentionBlockND(nn.Module):
def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation',
sub_sample_factor=(2,2,2)):
super(_GridAttentionBlockND, self).__init__()
assert dimension in [2, 3]
assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual']
# Downsampling rate for the input featuremap
if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor
elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor)
else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension
# Default parameter set
self.mode = mode
self.dimension = dimension
self.sub_sample_kernel_size = self.sub_sample_factor
# Number of channels (pixel dimensions)
self.in_channels = in_channels
self.gating_channels = gating_channels
self.inter_channels = inter_channels
if self.inter_channels is None:
self.inter_channels = in_channels // 2
if self.inter_channels == 0:
self.inter_channels = 1
if dimension == 3:
conv_nd = nn.Conv3d
bn = nn.BatchNorm3d
self.upsample_mode = 'trilinear'
elif dimension == 2:
conv_nd = nn.Conv2d
bn = nn.BatchNorm2d
self.upsample_mode = 'bilinear'
else:
raise NotImplemented
# Output transform
self.W = nn.Sequential(
conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0),
bn(self.in_channels),
)
# Theta^T * x_ij + Phi^T * gating_signal + bias
self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=True)
self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels,
kernel_size=(1, 1), stride=1, padding=0, bias=True)
self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)
# Initialise weights
for m in self.children():
init_weights(m, init_type='kaiming')
# Define the operation
if mode == 'concatenation':
self.operation_function = self._concatenation
elif mode == 'concatenation_debug':
self.operation_function = self._concatenation_debug
elif mode == 'concatenation_residual':
self.operation_function = self._concatenation_residual
else:
raise NotImplementedError('Unknown operation function.')
def forward(self, x, g):
'''
:param x: (b, c, t, h, w)
:param g: (b, g_d)
:return:
'''
output = self.operation_function(x, g)
return output
def _concatenation(self, x, g):
input_size = x.size()
batch_size = input_size[0]
assert batch_size == g.size(0)
# theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw)
# phi => (b, g_d) -> (b, i_c)
theta_x = self.theta(x)
theta_x_size = theta_x.size()
# g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w')
# Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3)
phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode)
f = F.relu(theta_x + phi_g, inplace=True)
# psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3)
sigm_psi_f = F.sigmoid(self.psi(f))
# upsample the attentions and multiply
sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode)
y = sigm_psi_f.expand_as(x) * x
W_y = self.W(y)
return W_y, sigm_psi_f
def _concatenation_debug(self, x, g):
input_size = x.size()
batch_size = input_size[0]
assert batch_size == g.size(0)
# theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw)
# phi => (b, g_d) -> (b, i_c)
theta_x = self.theta(x)
theta_x_size = theta_x.size()
# g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w')
# Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3)
phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode)
f = F.softplus(theta_x + phi_g)
# psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3)
sigm_psi_f = F.sigmoid(self.psi(f))
# upsample the attentions and multiply
sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode)
y = sigm_psi_f.expand_as(x) * x
W_y = self.W(y)
return W_y, sigm_psi_f
def _concatenation_residual(self, x, g):
input_size = x.size()
batch_size = input_size[0]
assert batch_size == g.size(0)
# theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw)
# phi => (b, g_d) -> (b, i_c)
theta_x = self.theta(x)
theta_x_size = theta_x.size()
# g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w')
# Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3)
phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode)
f = F.relu(theta_x + phi_g, inplace=True)
# psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3)
f = self.psi(f).view(batch_size, 1, -1)
sigm_psi_f = F.softmax(f, dim=2).view(batch_size, 1, *theta_x.size()[2:])
# upsample the attentions and multiply
sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode)
y = sigm_psi_f.expand_as(x) * x
W_y = self.W(y)
return W_y, sigm_psi_f
class GridAttentionBlock2D(_GridAttentionBlockND):
    """Two-dimensional attention gate: ``_GridAttentionBlockND`` fixed to ``dimension=2``.

    ``sub_sample_factor`` defaults to ``(2, 2)``; every other argument is
    forwarded unchanged to the N-D base class.
    """

    def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation',
                 sub_sample_factor=(2, 2)):
        base_kwargs = dict(
            gating_channels=gating_channels,
            inter_channels=inter_channels,
            dimension=2,
            mode=mode,
            sub_sample_factor=sub_sample_factor,
        )
        super(GridAttentionBlock2D, self).__init__(in_channels, **base_kwargs)
class GridAttentionBlock3D(_GridAttentionBlockND):
    """Three-dimensional attention gate: ``_GridAttentionBlockND`` fixed to ``dimension=3``.

    ``sub_sample_factor`` defaults to ``(2, 2, 2)``; every other argument is
    forwarded unchanged to the N-D base class.
    """

    def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation',
                 sub_sample_factor=(2, 2, 2)):
        base_kwargs = dict(
            gating_channels=gating_channels,
            inter_channels=inter_channels,
            dimension=3,
            mode=mode,
            sub_sample_factor=sub_sample_factor,
        )
        super(GridAttentionBlock3D, self).__init__(in_channels, **base_kwargs)
class _GridAttentionBlockND_TORR(nn.Module):
    """Additive grid attention gate with configurable projections and normalisation.

    Computes ``psi(nl1(theta(x) + phi(g)))`` on theta's (optionally strided) grid,
    normalises it into an attention map according to ``mode``, resamples the map
    to ``x``'s spatial size, multiplies it into ``x`` and applies ``W``.

    Normalisation per ``mode``:
        'concatenation', 'concatenation_sigmoid' -- element-wise sigmoid
        'concatenation_softmax'                  -- softmax over all spatial positions
        'concatenation_mean'                     -- divide by the spatial sum
        'concatenation_mean_flow'                -- subtract the spatial minimum, then divide by the sum
        'concatenation_range_normalise'          -- min-max scaling over spatial positions

    Any of ``W`` / ``theta`` / ``phi`` / ``psi`` can be switched off with the
    matching ``use_*`` flag, in which case it is the identity function.
    """

    def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation',
                 sub_sample_factor=(1,1,1), bn_layer=True, use_W=True, use_phi=True, use_theta=True, use_psi=True, nonlinearity1='relu'):
        super(_GridAttentionBlockND_TORR, self).__init__()

        assert dimension in [2, 3]
        assert mode in ['concatenation', 'concatenation_softmax',
                        'concatenation_sigmoid', 'concatenation_mean',
                        'concatenation_range_normalise', 'concatenation_mean_flow']

        # Default parameter set
        self.mode = mode
        self.dimension = dimension
        # a scalar factor is repeated for every spatial dimension
        self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, tuple) else tuple([sub_sample_factor])*dimension
        self.sub_sample_kernel_size = self.sub_sample_factor

        # Number of channels (pixel dimensions)
        self.in_channels = in_channels
        self.gating_channels = gating_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            bn = nn.BatchNorm3d
            self.upsample_mode = 'trilinear'
        elif dimension == 2:
            conv_nd = nn.Conv2d
            bn = nn.BatchNorm2d
            self.upsample_mode = 'bilinear'
        else:
            # ``NotImplemented`` is a constant, not an exception class; raising it is a TypeError
            raise NotImplementedError('Unsupported dimension: {}'.format(dimension))

        # Disabled components default to the identity function.
        # Theta^T * x_ij + Phi^T * gating_signal + bias
        self.W = lambda x: x
        self.theta = lambda x: x
        self.psi = lambda x: x
        self.phi = lambda x: x
        self.nl1 = lambda x: x

        if use_W:
            if bn_layer:
                self.W = nn.Sequential(
                    conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0),
                    bn(self.in_channels),
                )
            else:
                self.W = conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0)

        if use_theta:
            self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                                 kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False)

        if use_phi:
            self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels,
                               kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False)

        if use_psi:
            self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)

        if nonlinearity1:
            if nonlinearity1 == 'relu':
                self.nl1 = lambda x: F.relu(x, inplace=True)
            # NOTE(review): any other non-empty value leaves nl1 as the identity

        if 'concatenation' in mode:
            self.operation_function = self._concatenation
        else:
            raise NotImplementedError('Unknown operation function.')

        # Initialise weights
        for m in self.children():
            init_weights(m, init_type='kaiming')

        # sigmoid(3) ~= 0.95: sigmoid gates start mostly open
        if use_psi and self.mode == 'concatenation_sigmoid':
            nn.init.constant_(self.psi.bias.data, 3.0)
        # NOTE(review): softmax is shift-invariant, so a constant psi bias does not change
        # the softmax attention values; kept for parity with the original initialisation
        if use_psi and self.mode == 'concatenation_softmax':
            nn.init.constant_(self.psi.bias.data, 10.0)

        # if use_psi and self.mode == 'concatenation_mean':
        #     nn.init.constant(self.psi.bias.data, 3.0)

        # if use_psi and self.mode == 'concatenation_range_normalise':
        #     nn.init.constant(self.psi.bias.data, 3.0)

        # Set to True to wrap the enabled sub-modules in nn.DataParallel.
        parallel = False
        if parallel:
            if use_W: self.W = nn.DataParallel(self.W)
            if use_phi: self.phi = nn.DataParallel(self.phi)
            if use_psi: self.psi = nn.DataParallel(self.psi)
            if use_theta: self.theta = nn.DataParallel(self.theta)

    def forward(self, x, g):
        '''
        :param x: input features (b, in_channels, *spatial) with 2 or 3 spatial dims
        :param g: gating signal (b, gating_channels, *spatial'); phi(g) is resampled onto theta(x)'s grid
        :return: (W_y, attention) -- gated features and the attention map resampled to x's spatial size
        '''
        output = self.operation_function(x, g)
        return output

    def _concatenation(self, x, g):
        input_size = x.size()
        batch_size = input_size[0]
        assert batch_size == g.size(0)

        #############################
        # compute compatibility score

        # theta => (b, c, t, h, w) -> (b, i_c, t, h, w)
        # phi   => (b, c, t, h, w) -> (b, i_c, t, h, w)
        theta_x = self.theta(x)
        theta_x_size = theta_x.size()

        # nl(theta.x + phi.g + bias) -> f = (b, i_c, t/s1, h/s2, w/s3)
        phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode)

        f = theta_x + phi_g
        f = self.nl1(f)

        psi_f = self.psi(f)

        ############################################
        # normalisation -- scale compatibility score
        #  psi^T . f -> (b, 1, t/s1, h/s2, w/s3)
        if self.mode == 'concatenation_softmax':
            sigm_psi_f = F.softmax(psi_f.view(batch_size, 1, -1), dim=2)
            sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:])
        elif self.mode == 'concatenation_mean':
            psi_f_flat = psi_f.view(batch_size, 1, -1)
            # NOTE: no clamp -- a zero spatial sum produces inf/nan
            psi_f_sum = torch.sum(psi_f_flat, dim=2)#clamp(1e-6)
            psi_f_sum = psi_f_sum[:,:,None].expand_as(psi_f_flat)

            sigm_psi_f = psi_f_flat / psi_f_sum
            sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:])
        elif self.mode == 'concatenation_mean_flow':
            psi_f_flat = psi_f.view(batch_size, 1, -1)
            ss = psi_f_flat.shape
            psi_f_min = psi_f_flat.min(dim=2)[0].view(ss[0],ss[1],1)
            psi_f_flat = psi_f_flat - psi_f_min
            psi_f_sum = torch.sum(psi_f_flat, dim=2).view(ss[0],ss[1],1).expand_as(psi_f_flat)

            sigm_psi_f = psi_f_flat / psi_f_sum
            sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:])
        elif self.mode == 'concatenation_range_normalise':
            psi_f_flat = psi_f.view(batch_size, 1, -1)
            ss = psi_f_flat.shape
            psi_f_max = torch.max(psi_f_flat, dim=2)[0].view(ss[0], ss[1], 1)
            psi_f_min = torch.min(psi_f_flat, dim=2)[0].view(ss[0], ss[1], 1)

            sigm_psi_f = (psi_f_flat - psi_f_min) / (psi_f_max - psi_f_min).expand_as(psi_f_flat)
            sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:])
        elif self.mode in ('concatenation', 'concatenation_sigmoid'):
            # plain 'concatenation' (the default mode) previously had no branch and always
            # raised here; it now uses sigmoid, like the non-TORR grid attention gate
            sigm_psi_f = torch.sigmoid(psi_f)
        else:
            raise NotImplementedError

        # sigm_psi_f is attention map! upsample the attentions and multiply
        sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode)
        y = sigm_psi_f.expand_as(x) * x
        W_y = self.W(y)

        return W_y, sigm_psi_f
class GridAttentionBlock2D_TORR(_GridAttentionBlockND_TORR):
    """2-D variant of ``_GridAttentionBlockND_TORR`` (``dimension=2``).

    All options are forwarded unchanged to the base class.
    """

    def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation',
                 sub_sample_factor=(1,1), bn_layer=True,
                 use_W=True, use_phi=True, use_theta=True, use_psi=True,
                 nonlinearity1='relu'):
        options = dict(
            inter_channels=inter_channels,
            gating_channels=gating_channels,
            dimension=2,
            mode=mode,
            sub_sample_factor=sub_sample_factor,
            bn_layer=bn_layer,
            use_W=use_W,
            use_phi=use_phi,
            use_theta=use_theta,
            use_psi=use_psi,
            nonlinearity1=nonlinearity1,
        )
        super(GridAttentionBlock2D_TORR, self).__init__(in_channels, **options)
class GridAttentionBlock3D_TORR(_GridAttentionBlockND_TORR):
    """3-D variant of ``_GridAttentionBlockND_TORR`` (``dimension=3``).

    Exposes the same ``use_*`` / ``nonlinearity1`` switches as the 2-D variant.
    Their defaults equal the base-class defaults, so existing callers are unaffected.
    """

    def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation',
                 sub_sample_factor=(1,1,1), bn_layer=True,
                 use_W=True, use_phi=True, use_theta=True, use_psi=True,
                 nonlinearity1='relu'):
        super(GridAttentionBlock3D_TORR, self).__init__(in_channels,
                                                        inter_channels=inter_channels,
                                                        gating_channels=gating_channels,
                                                        dimension=3, mode=mode,
                                                        sub_sample_factor=sub_sample_factor,
                                                        bn_layer=bn_layer,
                                                        use_W=use_W,
                                                        use_phi=use_phi,
                                                        use_theta=use_theta,
                                                        use_psi=use_psi,
                                                        nonlinearity1=nonlinearity1)
class MultiAttentionBlock(nn.Module):
    """Two parallel 2-D grid attention gates fused by a 1x1 Conv-BN-ReLU.

    ``forward(input, gating_signal)`` returns ``(fused, attentions)``:
    ``fused`` has ``in_size`` channels, and ``attentions`` is the channel-wise
    concatenation of both gates' attention maps.
    """

    def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor):
        super(MultiAttentionBlock, self).__init__()
        gate_kwargs = dict(in_channels=in_size, gating_channels=gate_size,
                           inter_channels=inter_size, mode=nonlocal_mode,
                           sub_sample_factor=sub_sample_factor)
        self.gate_block_1 = GridAttentionBlock2D(**gate_kwargs)
        self.gate_block_2 = GridAttentionBlock2D(**gate_kwargs)
        self.combine_gates = nn.Sequential(
            nn.Conv2d(in_size * 2, in_size, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(in_size),
            nn.ReLU(inplace=True),
        )

        # The gate blocks are skipped here; only the fusion head is initialised by this loop.
        for child in self.children():
            if 'GridAttentionBlock2D' in type(child).__name__:
                continue
            init_weights(child, init_type='kaiming')

    def forward(self, input, gating_signal):
        gated_a, att_a = self.gate_block_1(input, gating_signal)
        gated_b, att_b = self.gate_block_2(input, gating_signal)
        fused = self.combine_gates(torch.cat([gated_a, gated_b], 1))
        attentions = torch.cat([att_a, att_b], 1)
        return fused, attentions
if __name__ == '__main__':
    # Smoke test: run a 3-D grid attention gate on random tensors and print the output shape.
    mode_list = ['concatenation']

    for mode in mode_list:
        # plain tensors: torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4
        img = torch.rand(2, 16, 10, 10, 10)
        gat = torch.rand(2, 64, 4, 4, 4)
        net = GridAttentionBlock3D(in_channels=16, inter_channels=16, gating_channels=64, mode=mode, sub_sample_factor=(2,2,2))
        out, sigma = net(img, gat)
        print(out.size())
================================================
FILE: Models/layers/modules.py
================================================
import torch
import torch.nn as nn
def conv1x1(in_planes, out_planes, stride=1, bias=False):
    """Build a 1x1 ``nn.Conv2d`` without padding.

    Args:
        in_planes: input channels.
        out_planes: output channels.
        stride: convolution stride.
        bias: whether the convolution has a bias term.
    """
    layer = nn.Conv2d(in_planes, out_planes,
                      kernel_size=1,
                      stride=stride,
                      padding=0,
                      bias=bias)
    return layer
def conv3x3(in_planes, out_planes, stride=1, bias=False, group=1):
    """Build a 3x3 ``nn.Conv2d`` with padding 1 (size-preserving at stride 1).

    Args:
        in_planes: input channels.
        out_planes: output channels.
        stride: convolution stride.
        bias: whether the convolution has a bias term.
        group: number of convolution groups.
    """
    options = {'kernel_size': 3, 'stride': stride, 'padding': 1, 'groups': group, 'bias': bias}
    return nn.Conv2d(in_planes, out_planes, **options)
# conv_block(nn.Module) for U-net convolution block
class conv_block(nn.Module):
    """U-Net double convolution: two (3x3 conv -> BatchNorm2d -> ReLU) stages.

    Args:
        ch_in: input channels.
        ch_out: output channels of both convolutions.
        drop_out: when True, apply channel-wise ``Dropout2d(0.5)`` after the convolutions.
            The dropout layer is a registered submodule, so ``model.eval()`` disables it.
    """

    def __init__(self, ch_in, ch_out, drop_out=False):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )
        self.dropout = drop_out
        # Registered once so that it follows train()/eval(). The previous per-call
        # ``nn.Dropout2d(0.5)(x)`` built a fresh module that is always in training mode,
        # so activations were dropped at inference time too. Dropout has no parameters,
        # so state_dict keys are unchanged.
        self.drop = nn.Dropout2d(0.5)

    def forward(self, x):
        x = self.conv(x)
        if self.dropout:
            x = self.drop(x)
        return x
# # UpCat(nn.Module) for U-net UP convolution
class UpCat(nn.Module):
    """Upsample decoder features x2 and concatenate them after the skip features.

    Args:
        in_feat: channels of ``down_outputs`` (the coarser decoder input).
        out_feat: channels produced by the 2x2 / stride-2 transposed convolution
            (used only when ``is_deconv=True``).
        is_deconv: learned transposed convolution if True; otherwise bilinear x2
            upsampling, which keeps the channel count of ``down_outputs``.

    ``forward(inputs, down_outputs)`` returns ``cat([inputs, up(down_outputs)], dim=1)``.
    When the upsampled tensor is narrower than ``inputs``, it is first right-padded
    along the width with uniform [0, 1) noise.
    """

    def __init__(self, in_feat, out_feat, is_deconv=True):
        super(UpCat, self).__init__()
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear')

    @staticmethod
    def _match_width(outputs, inputs):
        """Right-pad ``outputs`` with uniform noise along dim 3 up to the width of ``inputs``.

        The noise is created on ``outputs``' device and dtype. The previous code always
        called ``.cuda()``, which failed on CPU-only machines and produced device
        mismatches on non-default GPUs. Returns ``outputs`` unchanged when no padding
        is needed.
        """
        offset = inputs.size(3) - outputs.size(3)
        if offset <= 0:
            return outputs
        addition = torch.rand((outputs.size(0), outputs.size(1), outputs.size(2), offset),
                              dtype=outputs.dtype, device=outputs.device)
        return torch.cat([outputs, addition], dim=3)

    def forward(self, inputs, down_outputs):
        # TODO: Upsampling required after deconv?
        outputs = self.up(down_outputs)
        outputs = self._match_width(outputs, inputs)
        out = torch.cat([inputs, outputs], dim=1)
        return out
# # UpCatconv(nn.Module) for up convolution
class UpCatconv(nn.Module):
    """Upsample decoder features x2, concatenate with the skip features, then apply ``conv_block``.

    Args:
        in_feat: channels of ``down_outputs``.
        out_feat: output channels (and transposed-conv channels when ``is_deconv=True``).
        is_deconv: learned 2x2 / stride-2 transposed convolution if True; otherwise
            bilinear x2 upsampling.
        drop_out: forwarded to ``conv_block``.

    If the upsampled tensor is narrower than ``inputs``, it is right-padded along the
    width with uniform [0, 1) noise before concatenation.
    """

    def __init__(self, in_feat, out_feat, is_deconv=True, drop_out=False):
        super(UpCatconv, self).__init__()

        if is_deconv:
            self.conv = conv_block(in_feat, out_feat, drop_out=drop_out)
            self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2)
        else:
            self.conv = conv_block(in_feat + out_feat, out_feat, drop_out=drop_out)
            self.up = nn.Upsample(scale_factor=2, mode='bilinear')

    @staticmethod
    def _match_width(outputs, inputs):
        """Right-pad ``outputs`` with uniform noise along dim 3 up to the width of ``inputs``.

        The noise lives on ``outputs``' device and dtype (previously ``.cuda()`` was
        hard-coded). Returns ``outputs`` unchanged when no padding is needed.
        """
        offset = inputs.size(3) - outputs.size(3)
        if offset <= 0:
            return outputs
        addition = torch.rand((outputs.size(0), outputs.size(1), outputs.size(2), offset),
                              dtype=outputs.dtype, device=outputs.device)
        return torch.cat([outputs, addition], dim=3)

    def forward(self, inputs, down_outputs):
        # TODO: Upsampling required after deconv
        outputs = self.up(down_outputs)
        outputs = self._match_width(outputs, inputs)
        out = self.conv(torch.cat([inputs, outputs], dim=1))
        return out
# # UnetGridGatingSignal3(nn.Module)
class UnetGridGatingSignal3(nn.Module):
    """Gating-signal head: Conv2d (stride 1, no padding) -> optional BatchNorm2d -> ReLU.

    Args:
        in_size: input channels.
        out_size: output channels.
        kernel_size: convolution kernel size.
        is_batchnorm: insert ``BatchNorm2d`` between the convolution and the ReLU.
    """

    def __init__(self, in_size, out_size, kernel_size=(1, 1), is_batchnorm=True):
        super(UnetGridGatingSignal3, self).__init__()
        layers = [nn.Conv2d(in_size, out_size, kernel_size, (1, 1), (0, 0))]
        if is_batchnorm:
            layers.append(nn.BatchNorm2d(out_size))
        layers.append(nn.ReLU(inplace=True))
        self.conv1 = nn.Sequential(*layers)

    def forward(self, inputs):
        return self.conv1(inputs)
class UnetDsv3(nn.Module):
    """Deep-supervision head: 1x1 Conv2d followed by bilinear resampling to a fixed size.

    Note: despite its name, ``scale_factor`` is passed as ``nn.Upsample(size=...)``,
    i.e. it is the target output spatial size, not a multiplicative factor.
    """

    def __init__(self, in_size, out_size, scale_factor):
        super(UnetDsv3, self).__init__()
        project = nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0)
        resize = nn.Upsample(size=scale_factor, mode='bilinear')
        self.dsv = nn.Sequential(project, resize)

    def forward(self, input):
        return self.dsv(input)
================================================
FILE: Models/layers/nonlocal_layer.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
from Models.networks_other import init_weights
class _NonLocalBlockND(nn.Module):
    """Non-local (self-attention) block for 1-D, 2-D or 3-D feature maps.

    ``forward(x)`` returns ``x + W(y)``, where ``y`` aggregates ``g(x)`` over all
    (optionally max-pooled) positions, weighted by a pairwise relation selected by ``mode``:

        'embedded_gaussian'  -- softmax(theta(x)^T phi(x))
        'gaussian'           -- softmax(x^T x'), with x' = x or its max-pooled version
        'dot_product'        -- theta(x)^T phi(x) / N
        'concatenation'      -- relu(w_theta . theta(x_i) + w_phi . phi(x_j)) / N
        'concat_proper'      -- softmax over j of psi(relu(theta(x_i) + phi(x_j)))
        'concat_proper_down' -- as 'concat_proper', but theta is pooled too and the
                                result is interpolated back to x's spatial size

    Args:
        in_channels: channels of ``x``.
        inter_channels: bottleneck width (default ``in_channels // 2``, at least 1).
        dimension: number of spatial dimensions (1, 2 or 3).
        mode: relation function, see above.
        sub_sample_factor: max-pool kernel applied to ``g``/``phi`` when any entry is > 1.
            May be an int or a list/tuple of per-dimension factors.
        bn_layer: append a BatchNorm to the output projection ``W``.
    """

    def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian',
                 sub_sample_factor=4, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]
        assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']

        # print('Dimension: %d, mode: %s' % (dimension, mode))

        self.mode = mode
        self.dimension = dimension
        # Normalise to a list. Tuples are accepted too; previously a tuple was wrapped as a
        # single element and the ``ss > 1`` check below raised TypeError.
        if isinstance(sub_sample_factor, (list, tuple)):
            self.sub_sample_factor = list(sub_sample_factor)
        else:
            self.sub_sample_factor = [sub_sample_factor]

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool = nn.MaxPool3d
            bn = nn.BatchNorm3d
            self.upsample_mode = 'trilinear'
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool = nn.MaxPool2d
            bn = nn.BatchNorm2d
            self.upsample_mode = 'bilinear'
        else:
            conv_nd = nn.Conv1d
            max_pool = nn.MaxPool1d
            bn = nn.BatchNorm1d
            self.upsample_mode = 'linear'

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        # W is zero-initialised so that, at this point, W(y) == 0 and the block outputs z == x.
        # NOTE(review): the init_weights loop at the end of __init__ visits every child,
        # including W. If init_weights (Models.networks_other) re-initialises conv/BatchNorm
        # layers, this zero start is overwritten -- confirm.
        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = None
        self.phi = None

        if mode in ['embedded_gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']:
            self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                                 kernel_size=1, stride=1, padding=0)
            self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)

            if mode in ['concatenation']:
                self.wf_phi = nn.Linear(self.inter_channels, 1, bias=False)
                self.wf_theta = nn.Linear(self.inter_channels, 1, bias=False)
            elif mode in ['concat_proper', 'concat_proper_down']:
                # psi acts on the (b, c, N_j, N_i) relation tensor, which is always 4-D
                self.psi = nn.Conv2d(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1,
                                     padding=0, bias=True)

        if mode == 'embedded_gaussian':
            self.operation_function = self._embedded_gaussian
        elif mode == 'dot_product':
            self.operation_function = self._dot_product
        elif mode == 'gaussian':
            self.operation_function = self._gaussian
        elif mode == 'concatenation':
            self.operation_function = self._concatenation
        elif mode == 'concat_proper':
            self.operation_function = self._concatenation_proper
        elif mode == 'concat_proper_down':
            self.operation_function = self._concatenation_proper_down
        else:
            raise NotImplementedError('Unknown operation function.')

        if any(ss > 1 for ss in self.sub_sample_factor):
            self.g = nn.Sequential(self.g, max_pool(kernel_size=sub_sample_factor))
            if self.phi is None:
                self.phi = max_pool(kernel_size=sub_sample_factor)
            else:
                self.phi = nn.Sequential(self.phi, max_pool(kernel_size=sub_sample_factor))
            if mode == 'concat_proper_down':
                self.theta = nn.Sequential(self.theta, max_pool(kernel_size=sub_sample_factor))

        # Initialise weights
        for m in self.children():
            init_weights(m, init_type='kaiming')

    def forward(self, x):
        '''
        :param x: (b, c, *spatial) with ``dimension`` spatial dims, e.g. (b, c, t, h, w) in 3-D
        :return: z = x + W(y), same shape as x
        '''

        output = self.operation_function(x)
        return output

    def _embedded_gaussian(self, x):
        batch_size = x.size(0)

        # g=>(b, c, t, h, w)->(b, 0.5c, t, h, w)->(b, thw, 0.5c)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c)
        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)
        # f=>(b, thw, 0.5c)dot(b, 0.5c, twh) = (b, thw, thw)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        # (b, thw, thw)dot(b, thw, 0.5c) = (b, thw, 0.5c)->(b, 0.5c, t, h, w)->(b, c, t, h, w)
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _gaussian(self, x):
        batch_size = x.size(0)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = x.view(batch_size, self.in_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        # sub_sample_factor is a list; comparing it to an int raised TypeError
        if any(ss > 1 for ss in self.sub_sample_factor):
            phi_x = self.phi(x).view(batch_size, self.in_channels, -1)
        else:
            phi_x = x.view(batch_size, self.in_channels, -1)

        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _dot_product(self, x):
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        N = f.size(-1)
        f_div_C = f / N

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _concatenation(self, x):
        batch_size = x.size(0)

        # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)

        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c)
        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw/s**2, 0.5c)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1)

        # theta => (b, thw, 0.5c) -> (b, thw, 1) -> (b, 1, thw) -> (expand) (b, thw/s**2, thw)
        # phi => (b, thw/s**2, 0.5c) -> (b, thw/s**2, 1) -> (expand) (b, thw/s**2, thw)
        # f=> RELU[(b, thw/s**2, thw) + (b, thw/s**2, thw)] = (b, thw/s**2, thw)
        f = self.wf_theta(theta_x).permute(0, 2, 1).repeat(1, phi_x.size(1), 1) + \
            self.wf_phi(phi_x).repeat(1, 1, theta_x.size(1))
        f = F.relu(f, inplace=True)

        # Normalise the relations
        N = f.size(-1)
        f_div_c = f / N

        # g(x_j) * f(x_j, x_i)
        # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)
        y = torch.matmul(g_x, f_div_c)
        y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _concatenation_proper(self, x):
        batch_size = x.size(0)

        # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)

        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)
        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw)
        # phi => (b, 0.5c, thw/s**2) ->  (expand) (b, 0.5c, thw/s**2, thw)
        # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw)
        f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \
            phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2))
        f = F.relu(f, inplace=True)

        # psi -> W_psi^t * f -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw)
        f = torch.squeeze(self.psi(f), dim=1)

        # Normalise the relations
        f_div_c = F.softmax(f, dim=1)

        # g(x_j) * f(x_j, x_i)
        # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)
        y = torch.matmul(g_x, f_div_c)
        y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _concatenation_proper_down(self, x):
        batch_size = x.size(0)

        # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)

        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)
        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2)
        theta_x = self.theta(x)
        downsampled_size = theta_x.size()
        theta_x = theta_x.view(batch_size, self.inter_channels, -1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw)
        # phi => (b, 0.5, thw/s**2) ->  (expand) (b, 0.5c, thw/s**2, thw)
        # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw)
        f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \
            phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2))
        f = F.relu(f, inplace=True)

        # psi -> W_psi^t * f -> (b, 0.5c, thw/s**2, thw) -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw)
        f = torch.squeeze(self.psi(f), dim=1)

        # Normalise the relations
        f_div_c = F.softmax(f, dim=1)

        # g(x_j) * f(x_j, x_i)
        # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)
        y = torch.matmul(g_x, f_div_c)
        y = y.contiguous().view(batch_size, self.inter_channels, *downsampled_size[2:])

        # upsample the final featuremaps back to x's spatial size; the interpolation mode
        # follows the block's dimension (it was hard-coded to 'trilinear', breaking 1-D/2-D)
        y = F.interpolate(y, size=x.size()[2:], mode=self.upsample_mode)

        # attention block output
        W_y = self.W(y)
        z = W_y + x

        return z
class NONLocalBlock1D(_NonLocalBlockND):
    """Non-local block for 1-D signals: ``_NonLocalBlockND`` with ``dimension=1``."""

    def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(
            in_channels,
            dimension=1,
            inter_channels=inter_channels,
            mode=mode,
            sub_sample_factor=sub_sample_factor,
            bn_layer=bn_layer,
        )
class NONLocalBlock2D(_NonLocalBlockND):
    """Non-local block for images: ``_NonLocalBlockND`` with ``dimension=2``."""

    def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(
            in_channels,
            dimension=2,
            inter_channels=inter_channels,
            mode=mode,
            sub_sample_factor=sub_sample_factor,
            bn_layer=bn_layer,
        )
class NONLocalBlock3D(_NonLocalBlockND):
    """Non-local block for volumes: ``_NonLocalBlockND`` with ``dimension=3``."""

    def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(
            in_channels,
            dimension=3,
            inter_channels=inter_channels,
            mode=mode,
            sub_sample_factor=sub_sample_factor,
            bn_layer=bn_layer,
        )
if __name__ == '__main__':
    # Smoke test: run 1-D, 2-D and 3-D non-local blocks on zero tensors and print output shapes.
    mode_list = ['concatenation']
    #mode_list = ['embedded_gaussian', 'gaussian', 'dot_product', ]

    for mode in mode_list:
        print(mode)
        # plain tensors: torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4
        img = torch.zeros(2, 4, 5)
        net = NONLocalBlock1D(4, mode=mode, sub_sample_factor=2)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 4, 5, 3)
        net = NONLocalBlock2D(4, mode=mode, sub_sample_factor=1, bn_layer=False)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 4, 5, 4, 5)
        net = NONLocalBlock3D(4, mode=mode)
        out = net(img)
        print(out.size())
================================================
FILE: Models/layers/scale_attention_layer.py
================================================
import torch
import torch.nn as nn
from torch.nn import functional as F
def conv1x1(in_planes, out_planes, stride=1, bias=False):
    """Return a padding-free 1x1 ``nn.Conv2d`` from ``in_planes`` to ``out_planes`` channels."""
    settings = dict(kernel_size=1, stride=stride, padding=0, bias=bias)
    return nn.Conv2d(in_planes, out_planes, **settings)
# # SE block add to U-net
def conv3x3(in_planes, out_planes, stride=1, bias=False, group=1):
    """Return a 3x3 ``nn.Conv2d`` with padding 1 and ``group`` convolution groups."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        groups=group,
        bias=bias,
    )
    return layer
# # CBAM Convolutional block attention module
class BasicConv(nn.Module):
    """Conv2d, optionally followed by BatchNorm2d and ReLU (CBAM building block).

    When ``bn`` or ``relu`` is False, the matching attribute is ``None`` and the
    step is skipped in ``forward``.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        if bn:
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None
        if relu:
            self.relu = nn.ReLU()
        else:
            self.relu = None

    def forward(self, x):
        for stage in (self.conv, self.bn, self.relu):
            if stage is not None:
                x = stage(x)
        return x
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one: (b, ...) -> (b, prod(...))."""

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)
class ChannelGate(nn.Module):
    """Scale-wise channel attention over equally sized, contiguous channel groups.

    Each pooling in ``pool_types`` (global avg / max / L2 / log-sum-exp) is fed through a
    shared MLP and the logits are summed. The ``gate_channels`` channels are split into
    ``num_scales`` contiguous groups; every channel of a group receives the same weight,
    the sigmoid of the group's mean logit.

    Args:
        gate_channels: input channels; must be divisible by ``num_scales``.
        reduction_ratio: MLP bottleneck divisor.
        pool_types: iterable drawn from {'avg', 'max', 'lp', 'lse'}.
        num_scales: number of channel groups (default 4, the previously hard-coded value,
            which together with ``gate_channels=16`` reproduces the original behaviour).

    ``forward(x)`` returns ``(x * scale, scale)`` with ``scale`` expanded to ``x``'s shape.
    """

    _POOL_TYPES = ('avg', 'max', 'lp', 'lse')

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max'), num_scales=4):
        super(ChannelGate, self).__init__()
        if num_scales <= 0 or gate_channels % num_scales != 0:
            raise ValueError('gate_channels ({}) must be divisible by num_scales ({})'.format(
                gate_channels, num_scales))
        unknown = [p for p in pool_types if p not in self._POOL_TYPES]
        if unknown:
            raise ValueError('Unknown pool type(s): {}'.format(unknown))
        self.gate_channels = gate_channels
        self.num_scales = num_scales
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        # copy so a caller-provided list is never shared with (or mutated through) this module
        self.pool_types = list(pool_types)

    def forward(self, x):
        h, w = x.size(2), x.size(3)
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                pooled = F.avg_pool2d(x, (h, w), stride=(h, w))
            elif pool_type == 'max':
                pooled = F.max_pool2d(x, (h, w), stride=(h, w))
            elif pool_type == 'lp':
                pooled = F.lp_pool2d(x, 2, (h, w), stride=(h, w))
            elif pool_type == 'lse':
                # LSE pool only
                pooled = logsumexp_2d(x)
            else:
                # previously an unknown type silently re-added the previous logits
                raise ValueError('Unknown pool type: {}'.format(pool_type))
            channel_att_raw = self.mlp(pooled)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        # scalecoe = F.sigmoid(channel_att_sum)
        batch = channel_att_sum.shape[0]
        # average the logits inside each contiguous group so all channels of a scale share one weight
        grouped = channel_att_sum.reshape(batch, self.num_scales, -1)
        avg_weight = grouped.mean(dim=2, keepdim=True).expand_as(grouped).reshape(batch, self.gate_channels)
        scale = torch.sigmoid(avg_weight).unsqueeze(2).unsqueeze(3).expand_as(x)

        return x * scale, scale
def logsumexp_2d(tensor):
    """Numerically stable log-sum-exp over all positions after dim 1.

    The per-channel maximum is subtracted before exponentiating and added back afterwards.
    Returns a tensor of shape (tensor.size(0), tensor.size(1), 1).
    """
    flat = tensor.view(tensor.size(0), tensor.size(1), -1)
    peak = flat.max(dim=2, keepdim=True)[0]
    shifted_sum = torch.exp(flat - peak).sum(dim=2, keepdim=True)
    return peak + torch.log(shifted_sum)
class ChannelPool(nn.Module):
    """Concatenate the channel-wise max and channel-wise mean along dim 1 (max first)."""

    def forward(self, x):
        channel_max = torch.max(x, 1)[0].unsqueeze(1)
        channel_mean = torch.mean(x, 1).unsqueeze(1)
        stacked = (channel_max, channel_mean)
        return torch.cat(stacked, dim=1)
class SpatialGate(nn.Module):
    """CBAM spatial attention: a 7x7 conv over [channel-max, channel-mean] -> sigmoid mask.

    ``forward(x)`` returns ``(x * mask, mask)``. The mask has a single channel and is
    broadcast over ``x``'s channels.
    """

    def __init__(self):
        super(SpatialGate, self).__init__()
        k = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, k, stride=1, padding=(k - 1) // 2, relu=False)

    def forward(self, x):
        logits = self.spatial(self.compress(x))
        mask = torch.sigmoid(logits)  # (b, 1, h, w), broadcast in the product below
        return x * mask, mask
class SpatialAtten(nn.Module):
    """Residual spatial attention with one sigmoid map per contiguous group of input channels.

    Two convolutions reduce ``in_size`` channels to ``out_size`` attention maps. Map ``k``
    is shared by the ``in_size // out_size`` consecutive input channels of group ``k``.
    ``forward(x)`` returns ``(x * att + x, att)``, where ``att`` is the expanded map with
    ``in_size`` channels.

    ``in_size`` must be a positive multiple of ``out_size``. The expansion was previously
    hard-coded to 4 maps x 4 channels, so only ``in_size=16, out_size=4`` worked.
    """

    def __init__(self, in_size, out_size, kernel_size=3, stride=1):
        super(SpatialAtten, self).__init__()
        if out_size <= 0 or in_size % out_size != 0:
            raise ValueError('in_size ({}) must be a positive multiple of out_size ({})'.format(
                in_size, out_size))
        self.conv1 = BasicConv(in_size, out_size, kernel_size, stride=stride,
                               padding=(kernel_size-1) // 2, relu=True)
        self.conv2 = BasicConv(out_size, out_size, kernel_size=1, stride=stride,
                               padding=0, relu=True, bn=False)

    def forward(self, x):
        residual = x
        x_out = self.conv1(x)
        x_out = self.conv2(x_out)
        batch, n_maps, height, width = x_out.shape
        repeats = residual.size(1) // n_maps
        # (b, n_maps, h, w) -> (b, n_maps, repeats, h, w) -> (b, n_maps * repeats, h, w):
        # each map is repeated for the consecutive channels of its group
        spatial_att = torch.sigmoid(x_out).unsqueeze(2)
        spatial_att = spatial_att.expand(batch, n_maps, repeats, height, width).reshape(
            batch, n_maps * repeats, height, width)
        x_out = residual * spatial_att

        x_out += residual

        return x_out, spatial_att
class Scale_atten_block(nn.Module):
    """Channel (scale-wise) attention, optionally followed by spatial attention.

    ``forward(x)`` returns ``(x_out, ca_atten, sa_atten)``. ``sa_atten`` is ``None``
    when ``no_spatial=True``; that path previously raised UnboundLocalError.

    Args:
        gate_channels: input channels.
        reduction_ratio: channel-MLP bottleneck divisor; also sets the number of spatial
            maps (``gate_channels // reduction_ratio``).
        pool_types: pooling types forwarded to ``ChannelGate``.
        no_spatial: skip the spatial attention stage.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max'), no_spatial=False):
        super(Scale_atten_block, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialAtten(gate_channels, gate_channels //reduction_ratio)

    def forward(self, x):
        x_out, ca_atten = self.ChannelGate(x)
        sa_atten = None
        if not self.no_spatial:
            x_out, sa_atten = self.SpatialGate(x_out)

        return x_out, ca_atten, sa_atten
class scale_atten_convblock(nn.Module):
    """Scale-attention fusion: attention + residual, then 3x3 Conv-BN-ReLU.

    ``forward(x)`` computes ``relu(bn3(conv3(relu(att(x) + residual))))``, optionally
    followed by dropout. ``att`` is a ``Scale_atten_block``, or the identity when
    ``use_cbam=False``. ``residual`` is ``x``, or ``downsample(x)`` if given.

    Args:
        in_size: channels of ``x`` (and width of the attention block).
        out_size: output channels of the 3x3 convolution.
        stride: stored on the instance; the convolution itself always uses stride 1.
        downsample: optional module applied to ``x`` to form the residual.
        use_cbam: enable the ``Scale_atten_block``.
        no_spatial: disable its spatial-attention stage.
        drop_out: apply ``Dropout2d(0.5)`` to the output (active only in training mode).
    """

    def __init__(self, in_size, out_size, stride=1, downsample=None, use_cbam=True, no_spatial=False, drop_out=False):
        super(scale_atten_convblock, self).__init__()
        # if stride != 1 or in_size != out_size:
        #     downsample = nn.Sequential(
        #         nn.Conv2d(in_size, out_size,
        #                   kernel_size=1, stride=stride, bias=False),
        #         nn.BatchNorm2d(out_size),
        #     )
        self.downsample = downsample
        self.stride = stride
        self.no_spatial = no_spatial
        self.dropout = drop_out

        self.relu = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(in_size, out_size)
        self.bn3 = nn.BatchNorm2d(out_size)

        if use_cbam:
            self.cbam = Scale_atten_block(in_size, reduction_ratio=4, no_spatial=self.no_spatial)  # out_size
        else:
            self.cbam = None

        # Registered once so model.eval() disables it; a per-call nn.Dropout2d(0.5) is always
        # in training mode. Dropout has no parameters, so state_dict keys are unchanged.
        self.drop = nn.Dropout2d(0.5)

    def forward(self, x):
        residual = x
        if self.downsample is not None:
            residual = self.downsample(x)

        if self.cbam is not None:
            out, scale_c_atten, scale_s_atten = self.cbam(x)

            # scale_c_atten = nn.Sigmoid()(scale_c_atten)
            # scale_s_atten = nn.Sigmoid()(scale_s_atten)
            # scale_atten = channel_atten_c * spatial_atten_s

            # scale_max = torch.argmax(scale_atten, dim=1, keepdim=True)
            # scale_max_soft = get_soft_label(input_tensor=scale_max, num_class=8)
            # scale_max_soft = scale_max_soft.permute(0, 3, 1, 2)
            # scale_atten_soft = scale_atten * scale_max_soft
        else:
            # previously ``out`` was undefined on this path (NameError)
            out = x

        # out-of-place so that ``x`` is never modified when the attention block is disabled
        out = out + residual
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out = self.relu(out)

        if self.dropout:
            out = self.drop(out)

        return out
================================================
FILE: Models/networks/network.py
================================================
import torch
import torch.nn as nn
from Models.layers.modules import conv_block, UpCat, UpCatconv, UnetDsv3, UnetGridGatingSignal3
from Models.layers.grid_attention_layer import GridAttentionBlock2D, MultiAttentionBlock
from Models.layers.channel_attention_layer import SE_Conv_Block
from Models.layers.scale_attention_layer import scale_atten_convblock
from Models.layers.nonlocal_layer import NONLocalBlock2D
class Comprehensive_Atten_Unet(nn.Module):
def __init__(self, args, in_ch=3, n_classes=2, feature_scale=4, is_deconv=True, is_batchnorm=True,
nonlocal_mode='concatenation', attention_dsample=(1, 1)):
super(Comprehensive_Atten_Unet, self).__init__()
self.args = args
self.is_deconv = is_deconv
self.in_channels = in_ch
self.num_classes = n_classes
self.is_batchnorm = is_batchnorm
self.feature_scale = feature_scale
self.out_size = args.out_size
filters = [64, 128, 256, 512, 1024]
filters = [int(x / self.feature_scale) for x in filters]
# downsampling
self.conv1 = conv_block(self.in_channels, filters[0])
self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2))
self.conv2 = conv_block(filters[0], filters[1])
self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2))
self.conv3 = conv_block(filters[1], filters[2])
self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 2))
self.conv4 = conv_block(filters[2], filters[3], drop_out=True)
self.maxpool4 = nn.MaxPool2d(kernel_size=(2, 2))
self.center = conv_block(filters[3], filters[4], drop_out=True)
# attention blocks
# self.attentionblock1 = GridAttentionBlock2D(in_channels=filters[0], gating_channels=filters[1],
# inter_channels=filters[0])
self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1],
nonlocal_mode=nonlocal_mode, sub_sample_factor=attention_dsample)
self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2],
nonlocal_mode=nonlocal_mode, sub_sample_factor=attention_dsample)
self.nonlocal4_2 = NONLocalBlock2D(in_channels=filters[4], inter_channels=filters[4] // 4)
# upsampling
self.up_concat4 = UpCat(filters[4], filters[3], self.is_deconv)
self.up_concat3 = UpCat(filters[3], filters[2], self.is_deconv)
self.up_concat2 = UpCat(filters[2], filters[1], self.is_deconv)
self.up_concat1 = UpCat(filters[1], filters[0], self.is_deconv)
self.up4 = SE_Conv_Block(filters[4], filters[3], drop_out=True)
self.up3 = SE_Conv_Block(filters[3], filters[2])
self.up2 = SE_Conv_Block(filters[2], filters[1])
self.up1 = SE_Conv_Block(filters[1], filters[0])
# deep supervision
self.dsv4 = UnetDsv3(in_size=filters[3], out_size=4, scale_factor=self.out_size)
self.dsv3 = UnetDsv3(in_size=filters[2], out_size=4, scale_factor=self.out_size)
self.dsv2 = UnetDsv3(in_size=filters[1], out_size=4, scale_factor=self.out_size)
self.dsv1 = nn.Conv2d(in_channels=filters[0], out_channels=4, kernel_size=1)
self.scale_att = scale_atten_convblock(in_size=16, out_size=4)
# final conv (without any concat)
self.final = nn.Sequential(nn.Conv2d(4, n_classes, kernel_size=1), nn.Softmax2d())
def forward(self, inputs):
    """Run CA-Net on ``inputs`` and return per-pixel class probabilities.

    Encoder: four ``conv_block`` stages separated by 2x2 max-pooling, followed
    by ``center`` (the bottleneck / gating signal).
    Decoder: every level goes through an ``UpCat`` module and an
    ``SE_Conv_Block``; level 4 additionally passes through the non-local block,
    and the skip features of levels 3 and 2 are gated by ``MultiAttentionBlock``
    (the level-1 skip is used ungated).
    Output: the four decoder levels are projected to 4 channels each
    (``dsv1``..``dsv4``), concatenated, fused by ``scale_att`` and mapped by
    ``final`` (1x1 conv + Softmax2d).

    The attention maps / weights returned by the attention and SE blocks are
    not used here.
    """
    # ---- encoder ----
    enc1 = self.conv1(inputs)
    enc2 = self.conv2(self.maxpool1(enc1))
    enc3 = self.conv3(self.maxpool2(enc2))
    enc4 = self.conv4(self.maxpool3(enc3))
    bottleneck = self.center(self.maxpool4(enc4))

    # ---- decoder, level 4: upsample + non-local block + SE block ----
    dec4 = self.nonlocal4_2(self.up_concat4(enc4, bottleneck))
    dec4, _ = self.up4(dec4)

    # ---- level 3: gate the encoder skip with the coarser decoder output ----
    skip3, _ = self.attentionblock3(enc3, dec4)
    dec3, _ = self.up3(self.up_concat3(skip3, dec4))

    # ---- level 2 ----
    skip2, _ = self.attentionblock2(enc2, dec3)
    dec2, _ = self.up2(self.up_concat2(skip2, dec3))

    # ---- level 1 (no attention gate on this skip) ----
    dec1, _ = self.up1(self.up_concat1(enc1, dec2))

    # ---- deep supervision heads (same call order as before), scale attention ----
    side4 = self.dsv4(dec4)
    side3 = self.dsv3(dec3)
    side2 = self.dsv2(dec2)
    side1 = self.dsv1(dec1)
    fused = self.scale_att(torch.cat([side1, side2, side3, side4], dim=1))
    return self.final(fused)
================================================
FILE: Models/networks_other.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
from torch.optim import lr_scheduler
import time
import numpy as np
###############################################################################
# Functions
###############################################################################
def _init_by_classname(m, weight_init):
    """Initialise a single module the way every ``weights_init_*`` helper does.

    Modules are matched by class name:
      * ``*Conv*`` / ``*Linear*``: ``weight_init`` is applied to the weight;
      * ``*BatchNorm*``: weight ~ N(1, 0.02), bias = 0.
    Modules without a tensor ``weight`` are skipped. This covers containers
    whose class name merely contains "Conv" and ``BatchNorm(affine=False)``,
    whose weight is ``None``. These used to raise ``AttributeError``.
    """
    weight = getattr(m, 'weight', None)
    if not isinstance(weight, torch.Tensor):
        return
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        weight_init(weight.data)
    elif 'BatchNorm' in classname:
        init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if isinstance(bias, torch.Tensor):
            init.constant_(bias.data, 0.0)


def weights_init_normal(m):
    """Conv/Linear weights ~ N(0, 0.02); BatchNorm weight ~ N(1, 0.02), bias 0."""
    _init_by_classname(m, lambda w: init.normal_(w, 0.0, 0.02))


def weights_init_xavier(m):
    """Conv/Linear weights: Xavier-normal (gain 1); BatchNorm as in ``weights_init_normal``."""
    _init_by_classname(m, lambda w: init.xavier_normal_(w, gain=1))


def weights_init_kaiming(m):
    """Conv/Linear weights: Kaiming-normal (a=0, fan_in); BatchNorm as in ``weights_init_normal``."""
    _init_by_classname(m, lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'))


def weights_init_orthogonal(m):
    """Conv/Linear weights: orthogonal (gain 1); BatchNorm as in ``weights_init_normal``."""
    _init_by_classname(m, lambda w: init.orthogonal_(w, gain=1))


def init_weights(net, init_type='normal'):
    """Apply the ``init_type`` initialisation scheme to every submodule of ``net``.

    Args:
        net: module whose submodules are visited via ``net.apply``.
        init_type: one of 'normal', 'xavier', 'kaiming', 'orthogonal'.

    Raises:
        NotImplementedError: if ``init_type`` is not a known scheme.
    """
    initializers = {
        'normal': weights_init_normal,
        'xavier': weights_init_xavier,
        'kaiming': weights_init_kaiming,
        'orthogonal': weights_init_orthogonal,
    }
    if init_type not in initializers:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    net.apply(initializers[init_type])
def get_norm_layer(norm_type='instance'):
    """Return a constructor for the requested 2-D normalisation layer.

    Args:
        norm_type: 'batch' -> ``BatchNorm2d(affine=True)``,
            'instance' -> ``InstanceNorm2d(affine=False)``, 'none' -> ``None``.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    if norm_type == 'none':
        return None
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def adjust_learning_rate(optimizer, lr):
    """Overwrite the learning rate of every parameter group of ``optimizer`` with ``lr``."""
    for group in optimizer.param_groups:
        group.update(lr=lr)
def _warmstart_rule(first_decay, second_decay, warmup=5):
    """Return an epoch -> LR-multiplier function for the 'step_warmstart*' policies.

    The multiplier is 0.1 for epochs ``< warmup``, 1 until ``first_decay``,
    0.1 until ``second_decay`` and 0.01 afterwards.
    """
    def lambda_rule(epoch):
        if epoch < warmup:
            return 0.1
        if epoch < first_decay:
            return 1
        if epoch < second_decay:
            return 0.1
        return 0.01
    return lambda_rule


def get_scheduler(optimizer, opt):
    """Build a learning-rate scheduler for ``optimizer`` according to ``opt.lr_policy``.

    Policies and the extra ``opt`` attributes they read:
      * 'lambda': linear decay; ``opt.epoch_count``, ``opt.niter``, ``opt.niter_decay``.
      * 'step' / 'step2': StepLR with gamma 0.5 / 0.1; ``opt.lr_decay_iters``.
      * 'plateau' / 'plateau2': ReduceLROnPlateau (mode 'min') with factor 0.1 / 0.2.
      * 'step_warmstart' / 'step_warmstart2': piecewise-constant multiplier
        (see ``_warmstart_rule``) decaying at epochs 100/200 and 50/100.

    Raises:
        NotImplementedError: for an unknown policy. Previously the exception
            object was *returned*, so callers silently received a non-scheduler.
    """
    print('opt.lr_policy = [{}]'.format(opt.lr_policy))
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.5)
    elif opt.lr_policy == 'step2':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        print('schedular=plateau')
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, threshold=0.01, patience=5)
    elif opt.lr_policy == 'plateau2':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'step_warmstart':
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=_warmstart_rule(100, 200))
    elif opt.lr_policy == 'step_warmstart2':
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=_warmstart_rule(50, 100))
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', gpu_ids=None):
    """Build, place and initialise a generator network.

    Args:
        input_nc / output_nc: input / output channel counts.
        ngf: number of filters in the first conv layer.
        which_model_netG: 'resnet_9blocks', 'resnet_6blocks', 'unet_128'
            (7 downsamplings) or 'unet_256' (8 downsamplings).
        norm: normalisation type passed to ``get_norm_layer``.
        use_dropout: enable dropout inside the generator.
        init_type: weight-initialisation scheme passed to ``init_weights``.
        gpu_ids: CUDA device ids; when non-empty the network is moved to
            ``gpu_ids[0]``. ``None`` (default) means CPU. A ``None`` default
            replaces the old ``[]`` default, which was one list object shared by
            every generator built without ``gpu_ids``.

    Raises:
        NotImplementedError: unknown ``which_model_netG``.
    """
    gpu_ids = [] if gpu_ids is None else gpu_ids
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)
    if use_gpu:
        assert torch.cuda.is_available(), 'gpu_ids were given but CUDA is not available'
    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_128':
        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    if use_gpu:
        netG.cuda(gpu_ids[0])
    init_weights(netG, init_type=init_type)
    return netG
def define_D(input_nc, ndf, which_model_netD,
             n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=None):
    """Build, place and initialise a PatchGAN discriminator.

    Args:
        input_nc: input channel count.
        ndf: number of filters in the first conv layer.
        which_model_netD: 'basic' (``n_layers=3``) or 'n_layers' (uses ``n_layers_D``).
        n_layers_D: number of strided conv layers for the 'n_layers' variant.
        norm: normalisation type passed to ``get_norm_layer``.
        use_sigmoid: append a Sigmoid to the output.
        init_type: weight-initialisation scheme passed to ``init_weights``.
        gpu_ids: CUDA device ids; when non-empty the network is moved to
            ``gpu_ids[0]``. ``None`` (default) means CPU. It replaces a shared
            mutable ``[]`` default.

    Raises:
        NotImplementedError: unknown ``which_model_netD``.
    """
    gpu_ids = [] if gpu_ids is None else gpu_ids
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)
    if use_gpu:
        assert torch.cuda.is_available(), 'gpu_ids were given but CUDA is not available'
    if which_model_netD == 'basic':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'n_layers':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' %
                                  which_model_netD)
    if use_gpu:
        netD.cuda(gpu_ids[0])
    init_weights(netD, init_type=init_type)
    return netD
def print_network(net):
    """Print ``net`` followed by its total number of scalar parameters."""
    print(net)
    print('Total number of parameters: %d' % get_n_parameters(net))
def get_n_parameters(net):
    """Return the total number of scalar parameters (sum of ``numel``) in ``net``."""
    return sum(param.numel() for param in net.parameters())
def _synchronize_if_cuda():
    """Wait for all queued CUDA work to finish; no-op when CUDA is unavailable."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def measure_fp_bp_time(model, x, y):
    """Time one forward pass and one backward pass of ``model`` on ``x``.

    The backward pass is driven by the sum of the output(s); when the model
    returns a tuple, the sums of all elements are added. ``y`` is accepted for
    interface compatibility but is not used.

    The device is synchronised before every clock read so asynchronous CUDA
    kernels are attributed to the right phase. The output reduction and
    ``zero_grad`` are no longer billed to the backward time. Runs on CPU too:
    synchronisation is skipped when CUDA is unavailable.

    Returns:
        (elapsed_fp, elapsed_bp) in seconds.
    """
    _synchronize_if_cuda()
    t0 = time.perf_counter()
    y_pred = model(x)
    _synchronize_if_cuda()
    elapsed_fp = time.perf_counter() - t0

    if isinstance(y_pred, tuple):
        y_pred = sum(y_p.sum() for y_p in y_pred)
    else:
        y_pred = y_pred.sum()

    model.zero_grad()
    _synchronize_if_cuda()
    t0 = time.perf_counter()
    y_pred.backward()
    _synchronize_if_cuda()
    elapsed_bp = time.perf_counter() - t0
    return elapsed_fp, elapsed_bp
def benchmark_fp_bp_time(model, x, y, n_trial=1000):
    """Average forward / backward wall-clock time of ``model`` over ``n_trial`` runs.

    The model is moved to the GPU and 10 warm-up iterations (results
    discarded) are run before timing.

    Returns:
        (mean forward seconds, mean backward seconds).
    """
    model.cuda()
    for _warmup in range(10):
        _fp, _bp = measure_fp_bp_time(model, x, y)
    print('DONE WITH DRY RUNS, NOW BENCHMARKING')
    print('trial: {}'.format(n_trial))
    samples = [measure_fp_bp_time(model, x, y) for _trial in range(n_trial)]
    forward_times = [fp for fp, _bp in samples]
    backward_times = [bp for _fp, bp in samples]
    return np.mean(forward_times), np.mean(backward_times)
##############################################################################
# Classes
##############################################################################
# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input
class GANLoss(nn.Module):
    """GAN objective that builds and caches its own constant target tensor.

    With ``use_lsgan=True`` the criterion is ``MSELoss`` (LSGAN); otherwise it
    is ``BCELoss``, which expects probabilities as input. Call it as
    ``loss(prediction, target_is_real)``.

    Args:
        use_lsgan: choose MSE (True) or BCE (False).
        target_real_label / target_fake_label: fill values of the targets.
        tensor: tensor type used to allocate the targets (e.g. a CUDA type
            when predictions live on the GPU).
    """
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    def _cached_target(self, cached, input, value):
        """Reuse ``cached`` if its shape matches ``input``; otherwise allocate a new target filled with ``value``."""
        # Compare full shapes, not numel: a (3, 2) input after a (2, 3) one
        # must not reuse the old target (that would break the loss).
        if cached is None or cached.size() != input.size():
            cached = self.Tensor(input.size()).fill_(value)
        return cached

    def get_target_tensor(self, input, target_is_real):
        """Return a tensor shaped like ``input`` filled with the real or fake label."""
        if target_is_real:
            self.real_label_var = self._cached_target(self.real_label_var, input, self.real_label)
            return self.real_label_var
        self.fake_label_var = self._cached_target(self.fake_label_var, input, self.fake_label)
        return self.fake_label_var

    def forward(self, input, target_is_real):
        """Loss between ``input`` and the real/fake target (invoked via ``module(...)``)."""
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
    """Image-to-image generator built from residual blocks.

    Layout: reflection-padded 7x7 conv stem, two stride-2 downsampling convs,
    ``n_blocks`` ResnetBlocks at ``4 * ngf`` channels, two transposed-conv
    upsamplings, and a reflection-padded 7x7 conv followed by Tanh.

    Args:
        input_nc / output_nc: input / output channel counts.
        ngf: filters in the stem conv (doubled at each downsampling).
        norm_layer: normalisation constructor (class or functools.partial).
        use_dropout: enable dropout inside the residual blocks.
        n_blocks: number of residual blocks (must be >= 0).
        gpu_ids: CUDA device ids used by ``forward`` for data parallelism;
            ``None`` (default) means none.
        padding_type: padding used inside the residual blocks.
    """
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=None, padding_type='reflect'):
        assert n_blocks >= 0
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        # None default: the old `[]` default was one list shared by all instances.
        self.gpu_ids = [] if gpu_ids is None else gpu_ids
        # Convs get a bias only when the norm layer is InstanceNorm2d
        # (given directly or wrapped in functools.partial).
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
                           bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, ngf * mult // 2,
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(ngf * mult // 2),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Run the generator; uses ``data_parallel`` over ``gpu_ids`` for CUDA float inputs."""
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)
# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block computing ``x + F(x)``.

    ``F`` is: [pad] 3x3 conv -> norm -> ReLU -> [Dropout(0.5)] -> [pad] 3x3 conv -> norm,
    where the padding is explicit (reflect / replicate layers) or implicit in
    the conv ('zero').
    """
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _padding(padding_type):
        """Return ``(explicit padding layers, conv padding)`` for ``padding_type``."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble ``F`` as an ``nn.Sequential`` (layer order defines the state_dict keys)."""
        layers = []
        for stage in range(2):
            pad_layers, conv_pad = self._padding(padding_type)
            layers.extend(pad_layers)
            layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad, bias=use_bias))
            layers.append(norm_layer(dim))
            if stage == 0:
                # only the first stage has an activation (and optional dropout)
                layers.append(nn.ReLU(True))
                if use_dropout:
                    layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + F(x)``."""
        residual = self.conv_block(x)
        return x + residual
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class UnetGenerator(nn.Module):
    """U-Net generator assembled from nested ``UnetSkipConnectionBlock``s.

    Blocks are built from the innermost outwards:
    - one innermost ngf*8 block;
    - ``num_downs - 5`` extra ngf*8 blocks (optionally with dropout);
    - three channel-reducing blocks;
    - the outermost block, which maps ``input_nc`` -> ``output_nc``.

    Each block downsamples once, so ``num_downs`` (>= 5) is the number of
    downsamplings. Smaller values still produce five levels.

    Args:
        gpu_ids: CUDA device ids used by ``forward`` for data parallelism;
            ``None`` (default) means none. It replaces a shared mutable ``[]`` default.
    """
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=None):
        super(UnetGenerator, self).__init__()
        self.gpu_ids = [] if gpu_ids is None else gpu_ids

        # construct unet structure (innermost first; each call wraps the previous block)
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)

        self.model = unet_block

    def forward(self, input):
        """Run the U-Net; uses ``data_parallel`` over ``gpu_ids`` for CUDA float inputs."""
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: downsample -> ``submodule`` -> upsample, with a skip connection.

    Blocks nest recursively (``submodule`` is the next-deeper level). Except at
    the outermost level, ``forward`` concatenates the block input with the
    block output along dim 1. That is why the up-convolution of a non-innermost
    level takes ``inner_nc * 2`` channels: its submodule returns that concatenation.
    """
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the layers of this level.

        Args:
            outer_nc: output channels of this level's up-convolution.
            inner_nc: output channels of this level's down-convolution.
            input_nc: input channels of the down-convolution; defaults to ``outer_nc``.
            submodule: the nested deeper block (``None`` for the innermost one).
            outermost: first level. Bare down-conv; up path ends in Tanh; no skip concat.
            innermost: deepest level. No submodule, no norm after the down-conv.
            norm_layer: normalisation constructor (class or functools.partial);
                convs get a bias only when it is InstanceNorm2d.
            use_dropout: append Dropout(0.5) after the up path (middle levels only).
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        # 4x4 stride-2 conv halves the spatial size
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        # NOTE(review): downnorm / upnorm are created at every level even where a
        # branch below does not use them. Creation order is kept as-is.
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # input is the concatenated skip output of the submodule -> inner_nc * 2
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # no submodule: the up-conv sees the down-conv output directly (inner_nc)
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Outermost level: return the model output; otherwise ``cat([x, model(x)], dim=1)``."""
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator producing a 1-channel map of per-patch scores.

    Layout: a stride-2 4x4 conv + LeakyReLU, ``n_layers - 1`` further stride-2
    conv/norm/LeakyReLU stages (filters doubling up to ``8 * ndf``), one
    stride-1 stage, a final stride-1 conv to 1 channel and an optional Sigmoid.

    Args:
        input_nc: input channel count.
        ndf: filters in the first conv layer.
        n_layers: number of stride-2 conv layers.
        norm_layer: normalisation constructor (class or functools.partial).
        use_sigmoid: append a Sigmoid to the output.
        gpu_ids: CUDA device ids used by ``forward`` for data parallelism;
            ``None`` (default) means none. It replaces a shared mutable ``[]`` default.
    """
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=None):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = [] if gpu_ids is None else gpu_ids
        # Convs get a bias only when the norm layer is InstanceNorm2d.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Score ``input``; uses ``data_parallel`` over ``gpu_ids`` for CUDA float inputs."""
        if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)
================================================
FILE: README.md
================================================
## CA-Net: Comprehensive Attention Convolutional Neural Networks for Explainable Medical Image Segmentation
This repository provides the code for "CA-Net: Comprehensive Attention Convolutional Neural Networks for Explainable Medical Image Segmentation". Our paper is available on [arXiv][paper_link] and has been accepted by [TMI][tmi_link].
[paper_link]:https://arxiv.org/pdf/2009.10549.pdf
[tmi_link]:https://ieeexplore.ieee.org/document/9246575

Fig. 1. Structure of CA-Net.

Fig. 2. Skin lesion segmentation.

Fig. 3. Placenta and fetal brain segmentation.
### Requirements
Some important required packages include:
* [PyTorch][torch_link] version >= 0.4.1.
* Visdom
* Python == 3.7
* Some basic Python packages such as NumPy.
Follow the official guidance to install [PyTorch][torch_link].
[torch_link]:https://pytorch.org/
## Usages
### For skin lesion segmentation
1. First, download the dataset from [ISIC 2018][data_link]. We only used the ISIC 2018 Task 1 training set. To preprocess the dataset and save it as ".npy" files, run:
[data_link]:https://challenge.isic-archive.com/data#2018
```
python isic_preprocess.py
```
2. To perform 5-fold cross-validation, split the preprocessed data into 5 folds and save their filenames by running:
```
python create_folder.py
```
3. To train CA-Net on ISIC 2018 (taking the 1st-fold validation as an example), run:
```
python main.py --data ISIC2018 --val_folder folder1 --id Comp_Atten_Unet
```
4. To evaluate the trained model on ISIC 2018 (we provide a test sample in folder0; here the 0th fold is used as an example), run:
```
python validation.py --data ISIC2018 --val_folder folder0 --id Comp_Atten_Unet
```
Our experimental results are shown in the table:

5. You can save the attention weight maps from the intermediate layers of the network to the '/result' folder. To visualize the attention weights overlaid on the original images, run:
```
python show_fused_heatmap.py
```
Visualization of the spatial attention weight maps:

Visualization of the scale attention weight maps:

## Citation
If you find our work helpful for your research, please consider citing:
```
@article{gu2020net,
title={CA-Net: Comprehensive Attention Convolutional Neural Networks for Explainable Medical Image Segmentation},
author={Gu, Ran and Wang, Guotai and Song, Tao and Huang, Rui and Aertsen, Michael and Deprest, Jan and Ourselin, S{\'e}bastien and Vercauteren, Tom and Zhang, Shaoting},
journal={IEEE Transactions on Medical Imaging},
year={2020},
publisher={IEEE}
}
```
## Acknowledgement
Part of the code is adapted from [Attention-Gated-Networks][AG].
[AG]:https://github.com/ozan-oktay/Attention-Gated-Networks
================================================
FILE: create_folder.py
================================================
import os
import numpy
from random import shuffle
PATH = './data/ISIC2018_Task1_npy_all/image'
SAVE_PATH = './Datasets'


def create_5_floder(folder, save_foler, fold_size=518, valid_size=260, n_folds=5, seed=None):
    """Split the entries of ``folder`` into ``n_folds`` cross-validation folds.

    For fold ``i`` the lists are written to ``<save_foler>/folder<i+1>/``:
      * test: the i-th chunk of ``fold_size`` shuffled files.
      * validation: the next ``valid_size`` files. For the last fold it is the
        files left over after the ``n_folds * fold_size`` test chunks, topped up
        from the start of the list.
      * train: every remaining file.
    The three lists of every fold are disjoint and together cover all files.

    The defaults reproduce the original split, which assumed exactly 2594
    files (5 x 518 test files + 4 leftovers). The leftover count used to be
    hard-coded as 4, so other dataset sizes produced overlapping or
    incomplete splits.

    Args:
        folder: directory whose entries are split.
        save_foler: output root for the ``folder<k>`` list directories.
        fold_size: number of test files per fold.
        valid_size: number of validation files per fold.
        n_folds: number of folds.
        seed: if given, shuffle with ``random.Random(seed)`` for a reproducible
            split; otherwise the global ``random`` state is used as before.
    """
    import random

    # sort first so a seeded shuffle does not depend on directory-listing order
    file_list = sorted(os.listdir(folder))
    if seed is None:
        shuffle(file_list)
    else:
        random.Random(seed).shuffle(file_list)

    leftover = file_list[n_folds * fold_size:]  # files not in any test chunk
    for i in range(n_folds):
        pre_test_list = file_list[:i * fold_size]
        test_list = file_list[i * fold_size:(i + 1) * fold_size]
        if i < n_folds - 1:
            valid_start = (i + 1) * fold_size
            valid_list = file_list[valid_start:valid_start + valid_size]
            train_list = file_list[valid_start + valid_size:] + pre_test_list
        else:
            n_head = max(valid_size - len(leftover), 0)
            valid_list = leftover + file_list[:n_head]
            train_list = file_list[n_head:i * fold_size]
        fold_name = 'folder' + str(i + 1)
        fold_dir = os.path.join(save_foler, fold_name)
        os.makedirs(fold_dir, exist_ok=True)
        text_save(os.path.join(fold_dir, fold_name + '_train.list'), train_list)
        text_save(os.path.join(fold_dir, fold_name + '_validation.list'), valid_list)
        text_save(os.path.join(fold_dir, fold_name + '_test.list'), test_list)


def text_save(filename, data):
    """Write each item of ``data`` on its own line of ``filename``.

    Brackets, single quotes and commas are stripped from each item's ``str()``
    so list-like items are written as space-separated text.
    """
    with open(filename, 'w') as file:
        for item in data:
            s = str(item).replace('[', '').replace(']', '')
            s = s.replace("'", '').replace(',', '') + '\n'
            file.write(s)
    print("Save {} successfully".format(os.path.basename(filename)))


if __name__ == "__main__":
    create_5_floder(PATH, SAVE_PATH)
================================================
FILE: isic_preprocess.py
================================================
#!/usr/bin/python3
# This code is for ISIC 2018: Skin Lesion Analysis Towards Melanoma Detection
# -*- coding: utf-8 -*-
# @Author : Ran Gu
import os
import random
import numpy as np
from skimage import io
from PIL import Image
root_dir = 'gr/Skin Segmentation'  # change this to the directory holding the original ISIC data
save_dir = './data/ISIC2018_Task1_npy_all'


def preprocess_isic(data_root, out_dir):
    """Resize ISIC 2018 Task 1 training images and masks to 342x256 and save them as .npy.

    Reads ``<data_root>/ISIC2018_Task1-2_Training_Input/*.jpg`` and the matching
    ``<data_root>/ISIC2018_Task1_Training_GroundTruth/<name>_segmentation.png``.
    It writes ``<out_dir>/image/<name>.npy`` and
    ``<out_dir>/label/<name>_segmentation.npy``. The output sub-directories
    are created even when ``out_dir`` already exists.

    Masks are resized with nearest-neighbour sampling. Pillow's default filter
    interpolates and would create non-label values along lesion borders.
    """
    img_dir = os.path.join(data_root, 'ISIC2018_Task1-2_Training_Input')
    lab_dir = os.path.join(data_root, 'ISIC2018_Task1_Training_GroundTruth')
    os.makedirs(os.path.join(out_dir, 'image'), exist_ok=True)
    os.makedirs(os.path.join(out_dir, 'label'), exist_ok=True)

    for img_name in sorted(os.listdir(img_dir)):
        if not img_name.endswith('.jpg'):
            continue
        stem = os.path.splitext(img_name)[0]
        lab_stem = stem + '_segmentation'
        # `with` closes the source files; resize() returns an independent copy
        with Image.open(os.path.join(img_dir, img_name)) as img:
            image = np.array(img.resize((342, 256)))
        with Image.open(os.path.join(lab_dir, lab_stem + '.png')) as lab:
            label = np.array(lab.resize((342, 256), Image.NEAREST))
        np.save(os.path.join(out_dir, 'image', stem), image)
        np.save(os.path.join(out_dir, 'label', lab_stem), label)
    print('Successfully saved preprocessed data')


if __name__ == '__main__':
    preprocess_isic(root_dir, save_dir)
================================================
FILE: main.py
================================================
#!/usr/bin/python3
# This code is for ISIC 2018: Skin Lesion Analysis Towards Melanoma Detection
# -*- coding: utf-8 -*-
# @Author : Ran Gu
import os
import torch
import math
import visdom
import torch.utils.data as Data
import argparse
import numpy as np
from tqdm import tqdm
from distutils.version import LooseVersion
from Datasets.ISIC2018 import ISIC2018_dataset
from utils.transform import ISIC2018_transform
from Models.networks.network import Comprehensive_Atten_Unet
from utils.dice_loss import SoftDiceLoss, get_soft_label, val_dice_fetus, val_dice_isic
from utils.dice_loss import Intersection_over_Union_fetus, Intersection_over_Union_isic
from utils.evaluation import AverageMeter
from utils.binary import assd
from torch.optim.lr_scheduler import StepLR
# Name -> class registries. Presumably looked up with the --id / --data
# command-line values shown in the README; confirm against the argument parsing.
Test_Model = dict(Comp_Atten_Unet=Comprehensive_Atten_Unet)
Test_Dataset = dict(ISIC2018=ISIC2018_dataset)
Test_Transform = dict(ISIC2018=ISIC2018_transform)
def train(train_loader, model, criterion, optimizer, args, epoch):
    """Run one training epoch and return the sample-weighted average loss.

    Args:
        train_loader: DataLoader yielding ``(image, label)`` batches.
        model: network to optimise (moved inputs go to CUDA).
        criterion: loss called as ``criterion(output, soft_target, num_classes)``.
        optimizer: optimiser stepped once per batch.
        args: needs ``num_classes`` and ``batch_size``.
        epoch: current epoch, used for logging only.
    """
    losses = AverageMeter()
    model.train()
    # len(dataset) / batch_size rounded up, i.e. the number of batches per epoch
    log_every = math.ceil(float(len(train_loader.dataset)) / args.batch_size)
    for step, (x, y) in tqdm(enumerate(train_loader), total=len(train_loader)):
        image, target = x.float().cuda(), y.float().cuda()
        output = model(image)
        loss = criterion(output, get_soft_label(target, args.num_classes), args.num_classes)
        losses.update(loss.data, image.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % log_every == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {losses.avg:.6f}'.format(
                epoch, step * len(image), len(train_loader.dataset),
                100. * step / len(train_loader), losses=losses))

    print('The average loss:{losses.avg:.4f}'.format(losses=losses))
    return losses.avg
def valid_fetus(valid_loader, model, criterion, optimizer, args, epoch, minloss):
    """Validate on the placenta / fetal-brain data and checkpoint on a new best loss.

    Args:
        valid_loader: DataLoader yielding ``(image, label)`` batches.
        model: network under evaluation (put in eval mode here).
        criterion: loss called as ``criterion(output, soft_target, num_classes)``.
        optimizer: only stored in the checkpoint.
        args: needs ``num_classes``, ``batch_size``, ``ckpt`` and ``data``.
        epoch: current epoch (logging and checkpoint metadata).
        minloss: list of best losses so far. Appended to in place, and the
            checkpoint is saved, when this epoch's mean loss beats ``min(minloss)``.

    Returns:
        (mean loss, mean placenta Dice, mean brain Dice).
    """
    val_losses = AverageMeter()
    val_placenta_dice = AverageMeter()
    val_brain_dice = AverageMeter()
    model.eval()
    # No gradients are needed for validation. Without no_grad every batch
    # builds an autograd graph, wasting GPU memory and time.
    with torch.no_grad():
        for step, (t, k) in tqdm(enumerate(valid_loader), total=len(valid_loader)):
            image = t.float().cuda()
            target = k.float().cuda()
            output = model(image)  # model output
            output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)  # hard argmax prediction
            output_soft = get_soft_label(output_dis, args.num_classes)  # get soft label
            target_soft = get_soft_label(target, args.num_classes)
            val_loss = criterion(output, target_soft, args.num_classes)  # the dice losses
            val_losses.update(val_loss.data, image.size(0))
            placenta, brain = val_dice_fetus(output_soft, target_soft, args.num_classes)  # the dice score
            val_placenta_dice.update(placenta.data, image.size(0))
            val_brain_dice.update(brain.data, image.size(0))
            if step % (math.ceil(float(len(valid_loader.dataset))/args.batch_size)) == 0:
                print('Valid Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {losses.avg:.6f}'.format(
                    epoch, step * len(image), len(valid_loader.dataset), 100. * step / len(valid_loader), losses=val_losses))
    print('The Placenta Mean Average Dice score: {placenta.avg: .4f}; '
          'The Brain Mean Average Dice score: {brain.avg: .4f}; '
          'The Average Loss score: {loss.avg: .4f}'.format(
           placenta=val_placenta_dice, brain=val_brain_dice, loss=val_losses))
    if val_losses.avg < min(minloss):
        minloss.append(val_losses.avg)
        print(minloss)
        modelname = args.ckpt + '/' + 'min_loss' + '_' + args.data + '_checkpoint.pth.tar'
        print('the best model will be saved at {}'.format(modelname))
        state = {'epoch': epoch, 'state_dict': model.state_dict(), 'opt_dict': optimizer.state_dict()}
        torch.save(state, modelname)
    return val_losses.avg, val_placenta_dice.avg, val_brain_dice.avg
def valid_isic(valid_loader, model, criterion, optimizer, args, epoch, minloss):
    """Validate on ISIC 2018 and checkpoint the model on a new best validation loss.

    Args:
        valid_loader: DataLoader yielding ``(image, label)`` batches.
        model: network under evaluation (put in eval mode here).
        criterion: loss called as ``criterion(output, soft_target, num_classes)``.
        optimizer: only stored in the checkpoint.
        args: needs ``num_classes``, ``batch_size``, ``ckpt`` and ``data``.
        epoch: current epoch (logging and checkpoint metadata).
        minloss: list of best losses so far. Appended to in place, and the
            checkpoint is saved, when this epoch's mean loss beats ``min(minloss)``.

    Returns:
        (mean loss, mean lesion Dice).
    """
    val_losses = AverageMeter()
    val_isic_dice = AverageMeter()
    model.eval()
    # No gradients are needed for validation; no_grad avoids building autograd graphs.
    with torch.no_grad():
        for step, (t, k) in tqdm(enumerate(valid_loader), total=len(valid_loader)):
            image = t.float().cuda()
            target = k.float().cuda()
            output = model(image)  # model output
            output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)  # hard argmax prediction
            output_soft = get_soft_label(output_dis, args.num_classes)
            target_soft = get_soft_label(target, args.num_classes)  # get soft label
            val_loss = criterion(output, target_soft, args.num_classes)  # the dice losses
            val_losses.update(val_loss.data, image.size(0))
            isic = val_dice_isic(output_soft, target_soft, args.num_classes)  # the dice score
            val_isic_dice.update(isic.data, image.size(0))
            if step % (math.ceil(float(len(valid_loader.dataset)) / args.batch_size)) == 0:
                print('Valid Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {losses.avg:.6f}'.format(
                    epoch, step * len(image), len(valid_loader.dataset), 100. * step / len(valid_loader),
                    losses=val_losses))
    print('The ISIC Mean Average Dice score: {isic.avg: .4f}; '
          'The Average Loss score: {loss.avg: .4f}'.format(
           isic=val_isic_dice, loss=val_losses))
    if val_losses.avg < min(minloss):
        minloss.append(val_losses.avg)
        print(minloss)
        modelname = args.ckpt + '/' + 'min_loss' + '_' + args.data + '_checkpoint.pth.tar'
        print('the best model will be saved at {}'.format(modelname))
        state = {'epoch': epoch, 'state_dict': model.state_dict(), 'opt_dict': optimizer.state_dict()}
        torch.save(state, modelname)
    return val_losses.avg, val_isic_dice.avg
def test_fetus(test_loader, model, args):
    """Evaluate the best saved checkpoint on the placenta / fetal-brain test set.

    Loads ``<args.ckpt>/min_loss_<args.data>_checkpoint.pth.tar`` when it exists.
    Otherwise the model's current weights are evaluated. It then prints the
    mean and standard deviation of the per-batch Dice, IoU and ASSD for class 1
    (placenta) and class 2 (brain).

    Args:
        test_loader: DataLoader yielding ``(image, label)`` batches.
        model: network to evaluate.
        args: needs ``ckpt``, ``data`` and ``num_classes``.
    """
    placenta_dice = []
    brain_dice = []
    placenta_iou = []
    brain_iou = []
    placenta_assd = []
    brain_assd = []

    modelname = args.ckpt + '/' + 'min_loss' + '_' + args.data + '_checkpoint.pth.tar'
    if os.path.isfile(modelname):
        print("=> Loading checkpoint '{}'".format(modelname))
        checkpoint = torch.load(modelname)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> Loaded saved the best model at (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> No checkpoint found at '{}'".format(modelname))

    model.eval()
    # Inference only: no_grad avoids building autograd graphs for every batch.
    with torch.no_grad():
        for step, (img, lab) in tqdm(enumerate(test_loader), total=len(test_loader)):
            image = img.float().cuda()
            target = lab.float().cuda()
            output = model(image)  # model output
            output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)  # hard argmax prediction
            output_soft = get_soft_label(output_dis, args.num_classes)
            target_soft = get_soft_label(target, args.num_classes)  # get soft label

            label_arr = np.squeeze(target_soft.cpu().numpy()).astype(np.uint8)
            output_arr = np.squeeze(output_soft.cpu().byte().numpy()).astype(np.uint8)

            placenta_b_dice, brain_b_dice = val_dice_fetus(output_soft, target_soft, args.num_classes)  # the dice accuracy
            placenta_b_iou, brain_b_iou = Intersection_over_Union_fetus(output_soft, target_soft, args.num_classes)  # the iou accuracy
            placenta_b_asd = assd(output_arr[:, :, :, 1], label_arr[:, :, :, 1])
            brain_b_asd = assd(output_arr[:, :, :, 2], label_arr[:, :, :, 2])

            placenta_dice.append(placenta_b_dice.data.cpu().numpy())
            brain_dice.append(brain_b_dice.data.cpu().numpy())
            placenta_iou.append(placenta_b_iou.data.cpu().numpy())
            brain_iou.append(brain_b_iou.data.cpu().numpy())
            placenta_assd.append(placenta_b_asd)
            brain_assd.append(brain_b_asd)

    placenta_dice_mean = np.average(placenta_dice)
    placenta_dice_std = np.std(placenta_dice)
    brain_dice_mean = np.average(brain_dice)
    brain_dice_std = np.std(brain_dice)
    placenta_iou_mean = np.average(placenta_iou)
    placenta_iou_std = np.std(placenta_iou)
    brain_iou_mean = np.average(brain_iou)
    brain_iou_std = np.std(brain_iou)
    placenta_assd_mean = np.average(placenta_assd)
    placenta_assd_std = np.std(placenta_assd)
    brain_assd_mean = np.average(brain_assd)
    brain_assd_std = np.std(brain_assd)

    print('The Placenta mean Accuracy: {placenta_dice_mean: .4f}; The Placenta Accuracy std: {placenta_dice_std: .4f}; '
          'The Brain mean Accuracy: {brain_dice_mean: .4f}; The Brain Accuracy std: {brain_dice_std: .4f}'.format(
           placenta_dice_mean=placenta_dice_mean, placenta_dice_std=placenta_dice_std,
           brain_dice_mean=brain_dice_mean, brain_dice_std=brain_dice_std))
    print('The Placenta mean IoU: {placenta_iou_mean: .4f}; The Placenta IoU std: {placenta_iou_std: .4f}; '
          'The Brain mean IoU: {brain_iou_mean: .4f}; The Brain IoU std: {brain_iou_std: .4f}'.format(
           placenta_iou_mean=placenta_iou_mean, placenta_iou_std=placenta_iou_std,
           brain_iou_mean=brain_iou_mean, brain_iou_std=brain_iou_std))
    print('The Placenta mean assd: {placenta_asd_mean: .4f}; The Placenta assd std: {placenta_asd_std: .4f}; '
          'The Brain mean assd: {brain_asd_mean: .4f}; The Brain assd std: {brain_asd_std: .4f}'.format(
           placenta_asd_mean=placenta_assd_mean, placenta_asd_std=placenta_assd_std,
           brain_asd_mean=brain_assd_mean, brain_asd_std=brain_assd_std))
def test_isic(test_loader, model, args):
    """Evaluate the minimum-validation-loss ISIC2018 checkpoint on the test set.

    Loads ``<args.ckpt>/min_loss_<args.data>_checkpoint.pth.tar`` into ``model`` when
    the file exists, runs inference over ``test_loader`` and prints the mean and
    standard deviation (over batches) of the Dice score, IoU and ASSD.

    Args:
        test_loader: iterable of ``(image, label)`` batches with a ``len()``.
        model: segmentation network; it is switched to eval mode.
        args: parsed arguments; reads ``ckpt``, ``data`` and ``num_classes``.
    """
    isic_dice = []
    isic_iou = []
    isic_assd = []
    modelname = args.ckpt + '/' + 'min_loss' + '_' + args.data + '_checkpoint.pth.tar'
    if os.path.isfile(modelname):
        print("=> Loading checkpoint '{}'".format(modelname))
        checkpoint = torch.load(modelname)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> Loaded saved the best model at (epoch {})".format(checkpoint['epoch']))
    else:
        # NOTE(review): evaluation continues with whatever weights the model currently holds.
        print("=> No checkpoint found at '{}'".format(modelname))
    model.eval()
    # Pure inference: disable autograd so no computation graph / activations are kept.
    with torch.no_grad():
        for img, lab in tqdm(test_loader, total=len(test_loader)):
            image = img.float().cuda()
            target = lab.float().cuda()
            output = model(image)  # model output (per-class scores along dim 1)
            # Hard prediction: arg-max over the class dimension, kept as a channel axis.
            output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)
            output_soft = get_soft_label(output_dis, args.num_classes)
            target_soft = get_soft_label(target, args.num_classes)  # get soft label
            label_arr = np.squeeze(target_soft.cpu().numpy()).astype(np.uint8)
            output_arr = np.squeeze(output_soft.cpu().byte().numpy()).astype(np.uint8)
            isic_b_dice = val_dice_isic(output_soft, target_soft, args.num_classes)  # the dice accuracy
            isic_b_iou = Intersection_over_Union_isic(output_soft, target_soft, args.num_classes)  # the iou accuracy
            # assd on the foreground channel (index 1); assumes a 4-D (batch, H, W, class)
            # array after squeeze, i.e. batch size > 1 -- TODO confirm for batch_size == 1.
            isic_b_asd = assd(output_arr[:, :, :, 1], label_arr[:, :, :, 1])
            isic_dice.append(isic_b_dice.data.cpu().numpy())
            isic_iou.append(isic_b_iou.data.cpu().numpy())
            isic_assd.append(isic_b_asd)
    isic_dice_mean = np.average(isic_dice)
    isic_dice_std = np.std(isic_dice)
    isic_iou_mean = np.average(isic_iou)
    isic_iou_std = np.std(isic_iou)
    isic_assd_mean = np.average(isic_assd)
    isic_assd_std = np.std(isic_assd)
    print('The ISIC mean Accuracy: {isic_dice_mean: .4f}; The ISIC Accuracy std: {isic_dice_std: .4f}'.format(
        isic_dice_mean=isic_dice_mean, isic_dice_std=isic_dice_std))
    print('The ISIC mean IoU: {isic_iou_mean: .4f}; The ISIC IoU std: {isic_iou_std: .4f}'.format(
        isic_iou_mean=isic_iou_mean, isic_iou_std=isic_iou_std))
    print('The ISIC mean assd: {isic_asd_mean: .4f}; The ISIC assd std: {isic_asd_std: .4f}'.format(
        isic_asd_mean=isic_assd_mean, isic_asd_std=isic_assd_std))
def _plot_epoch_value(vis, epoch, value, win, title, ylabel, name=None):
    """Append one ``(epoch, value)`` point to the visdom line plot ``win`` (trace ``name``)."""
    vis.line(X=torch.Tensor([epoch]), Y=torch.Tensor([value]),
             win=win,
             name=name,
             update='append',
             opts=dict(title=title, xlabel='Epochs', ylabel=ylabel))


def main(args):
    """Build data loaders and model, train with per-epoch validation, then test.

    Args:
        args: parsed command-line arguments (see the ``__main__`` block). ``num_input``,
            ``num_classes`` and ``out_size`` are overwritten for the 'Fetus' and
            'ISIC2018' datasets.
    """
    # Single-element list passed to valid_*; presumably updated in place with the best loss.
    minloss = [1.0]
    start_epoch = args.start_epoch
    # loading the dataset
    print('loading the {0},{1},{2} dataset ...'.format('train', 'validation', 'test'))
    trainset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='train',
                                       transform=Test_Transform[args.data])
    validset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='validation',
                                       transform=Test_Transform[args.data])
    testset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='test',
                                      transform=Test_Transform[args.data])
    trainloader = Data.DataLoader(dataset=trainset, batch_size=args.batch_size, shuffle=True, pin_memory=True)
    validloader = Data.DataLoader(dataset=validset, batch_size=args.batch_size, shuffle=True, pin_memory=True)
    testloader = Data.DataLoader(dataset=testset, batch_size=args.batch_size, shuffle=False, pin_memory=True)
    print('Loading is done\n')

    # Define model: dataset-specific input channels, number of classes and output size.
    if args.data == 'Fetus':
        args.num_input = 1
        args.num_classes = 3
        args.out_size = (256, 256)
    elif args.data == 'ISIC2018':
        args.num_input = 3
        args.num_classes = 2
        args.out_size = (224, 300)
    model = Test_Model[args.id](args, args.num_input, args.num_classes)
    if torch.cuda.is_available():
        print('We can use', torch.cuda.device_count(), 'GPUs to train the network')
        model = model.cuda()
        # model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))

    # collect the number of parameters in the network
    # NOTE(review): counts every parameter, including any with requires_grad=False.
    print("------------------------------------------")
    print("Network Architecture of Model {}:".format(args.id))
    num_para = sum(param.numel() for param in model.parameters())
    print(model)
    print("Number of trainable parameters {0} in Model {1}".format(num_para, args.id))
    print("------------------------------------------")

    # Define optimizers and loss function
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr_rate,
                                 weight_decay=args.weight_decay)  # optimize all model parameters
    criterion = SoftDiceLoss()
    scheduler = StepLR(optimizer, step_size=256, gamma=0.5)

    # resume
    # NOTE(review): the scheduler state is not restored, so its step count restarts at 0.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['opt_dict'])
            print("=> Loaded checkpoint (epoch {})".format(checkpoint['epoch']))
        else:
            print("=> No checkpoint found at '{}'".format(args.resume))

    # visualiser
    vis = visdom.Visdom(env='CA-net')
    plot_title = args.id + '_' + args.data
    valid_win = args.id + args.data + 'valid_avg'
    print("Start training ...")
    for epoch in range(start_epoch + 1, args.epochs + 1):
        train_avg_loss = train(trainloader, model, criterion, optimizer, args, epoch)
        # Since PyTorch 1.1 the scheduler must be stepped after the epoch's optimizer
        # updates; stepping it first skips the initial learning rate.
        scheduler.step()
        _plot_epoch_value(vis, epoch, train_avg_loss, win=args.id + args.data,
                          title=plot_title, ylabel='Train_avg_loss')
        if args.data == 'Fetus':
            val_avg_loss, val_placenta_dice, val_brain_dice = valid_fetus(validloader, model, criterion,
                                                                          optimizer, args, epoch, minloss)
            for trace, value in (('loss', val_avg_loss),
                                 ('placenta_dice', val_placenta_dice),
                                 ('brain_dice', val_brain_dice)):
                _plot_epoch_value(vis, epoch, value, win=valid_win, title=plot_title,
                                  ylabel='Dice&loss', name=trace)
        elif args.data == 'ISIC2018':
            val_avg_loss, val_isic_dice = valid_isic(validloader, model, criterion, optimizer, args, epoch, minloss)
            for trace, value in (('loss', val_avg_loss), ('isic_dice', val_isic_dice)):
                _plot_epoch_value(vis, epoch, value, win=valid_win, title=plot_title,
                                  ylabel='Dice&loss', name=trace)

        # save models: after `particular_epoch`, every `save_epochs_steps` epochs
        if epoch > args.particular_epoch:
            if epoch % args.save_epochs_steps == 0:
                filename = args.ckpt + '/' + str(epoch) + '_' + args.data + '_checkpoint.pth.tar'
                print('the model will be saved at {}'.format(filename))
                state = {'epoch': epoch, 'state_dict': model.state_dict(), 'opt_dict': optimizer.state_dict()}
                torch.save(state, filename)

    print('Training Done! Start testing')
    if args.data == 'Fetus':
        test_fetus(testloader, model, args)
    elif args.data == 'ISIC2018':
        test_isic(testloader, model, args)
    print('Testing Done!')
if __name__ == '__main__':
    # NOTE(review): LooseVersion comes from distutils, which was removed in Python 3.12.
    assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \
        'PyTorch>=0.4.0 is required'
    parser = argparse.ArgumentParser(description='Comprehensive attention network for biomedical Dataset')
    # Model related arguments
    parser.add_argument('--id', default='Comp_Atten_Unet',
                        help='a name for identifying the model (default: Comp_Atten_Unet)')
    # Path related arguments
    parser.add_argument('--root_path', default='./data/ISIC2018_Task1_npy_all',
                        help='root directory of data')
    parser.add_argument('--ckpt', default='./saved_models',
                        help='folder to output checkpoints')
    # optimization related arguments
    parser.add_argument('--epochs', type=int, default=300, metavar='N',
                        help='number of epochs to train (default: 300)')
    parser.add_argument('--start_epoch', default=0, type=int,
                        help='epoch to start training. useful if continue from a checkpoint')
    parser.add_argument('--batch_size', type=int, default=16, metavar='N',
                        help='input batch size for training (default: 16)')
    parser.add_argument('--lr_rate', type=float, default=1e-4, metavar='LR',
                        help='learning rate (default: 1e-4)')
    parser.add_argument('--num_classes', default=2, type=int,
                        help='number of classes')
    parser.add_argument('--num_input', default=3, type=int,
                        help='number of input image for each patient')
    parser.add_argument('--weight_decay', default=1e-8, type=float, help='weights regularizer')
    parser.add_argument('--particular_epoch', default=30, type=int,
                        help='after this number, we will save models more frequently')
    parser.add_argument('--save_epochs_steps', default=200, type=int,
                        help='frequency to save models after a particular number of epochs')
    parser.add_argument('--resume', default='',
                        help='the checkpoint that resumes from')
    # other arguments
    parser.add_argument('--data', default='ISIC2018', help='choose the dataset')
    # NOTE(review): no `type=`, so a value given on the command line arrives as a string;
    # main() overwrites it for the 'Fetus' and 'ISIC2018' datasets anyway.
    parser.add_argument('--out_size', default=(224, 300), help='the output image size')
    parser.add_argument('--val_folder', default='folder0', type=str,
                        help='which cross validation folder')
    args = parser.parse_args()

    print("Input arguments:")
    for key, value in vars(args).items():
        print("{:16} {}".format(key, value))

    # Checkpoints live in <ckpt>/<data>/<val_folder>/<id>.
    args.ckpt = os.path.join(args.ckpt, args.data, args.val_folder, args.id)
    print('Models are saved at %s' % (args.ckpt))
    os.makedirs(args.ckpt, exist_ok=True)

    # Resume automatically from the checkpoint saved at `start_epoch`.
    if args.start_epoch > 1:
        args.resume = args.ckpt + '/' + str(args.start_epoch) + '_' + args.data + '_checkpoint.pth.tar'
    main(args)
================================================
FILE: show_fused_heatmap.py
================================================
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
def map_scalar_to_color(x):
    """Map a scalar to an RGB colour on a blue-cyan-green-yellow-red ramp.

    The ramp is piecewise linear between the anchor points 0.0 (blue), 0.25 (cyan),
    0.5 (green), 0.75 (yellow) and 1.0 (red). Values outside [0, 1] are clamped to
    the nearest end of the ramp.

    Args:
        x: scalar value, nominally in [0, 1].

    Returns:
        ``(r, g, b)`` tuple of ints in [0, 255].

    Raises:
        ValueError: if ``x`` is NaN.
    """
    x_list = [0.0, 0.25, 0.5, 0.75, 1.0]
    c_list = [[0, 0, 255],
              [0, 255, 255],
              [0, 255, 0],
              [255, 255, 0],
              [255, 0, 0]]
    # Clamp so out-of-range values neither extrapolate nor index past the last anchor.
    x = min(max(x, x_list[0]), x_list[-1])
    # Iterate over segments [x_list[i], x_list[i + 1]]; there are len(x_list) - 1 of them.
    for i in range(len(x_list) - 1):
        if x <= x_list[i + 1]:
            x0, x1 = x_list[i], x_list[i + 1]
            c0, c1 = c_list[i], c_list[i + 1]
            alpha = (x - x0) / (x1 - x0)
            return tuple(int(c0[j] * (1 - alpha) + c1[j] * alpha) for j in range(3))
    # Only reachable when every comparison failed, i.e. x is NaN.
    raise ValueError('cannot map NaN to a colour')
def get_fused_heat_map(image, att):
    """Overlay a colour-coded scalar (attention) map on an image.

    Each value of ``att`` is turned into a colour with ``map_scalar_to_color`` and
    alpha-blended onto the matching pixel of ``image`` with opacity
    ``0.3 + 0.5 * value`` (0.3 for value 0 up to 0.8 for value 1).

    Args:
        image: PIL image. It is converted to RGB, so greyscale, palette and RGBA
            inputs work as well.
        att: single-band PIL image of the same size as ``image`` whose pixel
            values are expected in [0, 1] (e.g. a mode 'F' image).

    Returns:
        A new RGB PIL image containing the fused heat map.
    """
    rgb = image.convert('RGB')
    width, height = rgb.size  # PIL reports size as (width, height)
    fused = Image.new('RGB', rgb.size)
    for x in range(width):
        for y in range(height):
            base = rgb.getpixel((x, y))
            value = att.getpixel((x, y))
            overlay = map_scalar_to_color(value)
            opacity = 0.3 + value * 0.5
            blended = tuple(int(base[c] * (1 - opacity) + overlay[c] * opacity) for c in range(3))
            fused.putpixel((x, y), blended)
    return fused
if __name__ == "__main__":
    # Source image, raw attention map and (currently unused) output path.
    image_name = "./result/atten_map/ISIC_0015937.jpg"
    scalar_name = "./result/atten_map/25_2_8_wgt"
    save_name = "./result/atten_map/15937_wgt3_fused"

    img = Image.open(image_name)
    # img = np.load(image_name)
    # img = Image.fromarray(np.uint8(img*255))

    # Read the attention map as 8-bit greyscale, resize it to the image size with
    # nearest-neighbour interpolation, then rescale its intensities from 0..255 to 0..1.
    width, height = img.size
    scl = cv2.resize(np.asarray(Image.open(scalar_name).convert('L')),
                     dsize=(width, height),
                     interpolation=cv2.INTER_NEAREST)
    scl_norm = Image.fromarray(np.asarray(scl, np.float32) / 255)

    # Colour-code the scalar map and blend it onto the original image.
    img_scl = get_fused_heat_map(img, scl_norm)
    # img_scl.save(save_name, format='png')
    plt.imshow(img_scl)
    plt.title('fused result')
    # plt.colorbar()
    plt.show()
================================================
FILE: utils/binary.py
================================================
# Copyright (C) 2013 Oskar Maier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# author Oskar Maier
# version r0.1.1
# since 2014-03-13
# status Release
# build-in modules
# third-party modules
import numpy
from scipy.ndimage import _ni_support
from scipy.ndimage.morphology import distance_transform_edt, binary_erosion,\
generate_binary_structure
from scipy.ndimage.measurements import label, find_objects
from scipy.stats import pearsonr
# own modules
# code
def dc(result, reference):
    r"""
    Dice coefficient

    Computes the Dice coefficient (also known as Sorensen index) between the binary
    objects in two images.

    The metric is defined as

    .. math::

        DC=\frac{2|A\cap B|}{|A|+|B|}

    , where :math:`A` is the first and :math:`B` the second set of samples (here: binary objects).

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    dc : float
        The Dice coefficient between the object(s) in ```result``` and the
        object(s) in ```reference```. It ranges from 0 (no overlap) to 1 (perfect overlap).
        Returns 0.0 when both inputs are empty.

    Notes
    -----
    This is a real metric. The binary images can therefore be supplied in any order.
    """
    # Builtin `bool`: the `numpy.bool` alias was removed in NumPy 1.24.
    result = numpy.atleast_1d(result.astype(bool))
    reference = numpy.atleast_1d(reference.astype(bool))

    intersection = numpy.count_nonzero(result & reference)

    size_i1 = numpy.count_nonzero(result)
    size_i2 = numpy.count_nonzero(reference)

    try:
        dc = 2. * intersection / float(size_i1 + size_i2)
    except ZeroDivisionError:
        dc = 0.0

    return dc
def jc(result, reference):
    """
    Jaccard coefficient

    Computes the Jaccard coefficient between the binary objects in two images.

    Parameters
    ----------
    result: array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference: array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    jc: float
        The Jaccard coefficient between the object(s) in `result` and the
        object(s) in `reference`. It ranges from 0 (no overlap) to 1 (perfect overlap).
        Returns 0.0 when both inputs are empty, consistent with :func:`dc`.

    Notes
    -----
    This is a real metric. The binary images can therefore be supplied in any order.
    """
    # Builtin `bool`: the `numpy.bool` alias was removed in NumPy 1.24.
    result = numpy.atleast_1d(result.astype(bool))
    reference = numpy.atleast_1d(reference.astype(bool))

    intersection = numpy.count_nonzero(result & reference)
    union = numpy.count_nonzero(result | reference)

    try:
        jc = float(intersection) / float(union)
    except ZeroDivisionError:
        # Both inputs empty: mirror dc() instead of crashing.
        jc = 0.0

    return jc
def precision(result, reference):
    """
    Precision.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    precision : float
        The precision between two binary datasets, here mostly binary objects in images,
        which is defined as the fraction of retrieved instances that are relevant. The
        precision is not symmetric. Returns 0.0 when ``result`` is empty.

    See also
    --------
    :func:`recall`

    Notes
    -----
    Not symmetric. The inverse of the precision is :func:`recall`.
    High precision means that an algorithm returned substantially more relevant results than irrelevant.

    References
    ----------
    .. [1] http://en.wikipedia.org/wiki/Precision_and_recall
    .. [2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion
    """
    # Builtin `bool`: the `numpy.bool` alias was removed in NumPy 1.24.
    result = numpy.atleast_1d(result.astype(bool))
    reference = numpy.atleast_1d(reference.astype(bool))

    tp = numpy.count_nonzero(result & reference)
    fp = numpy.count_nonzero(result & ~reference)

    try:
        precision = tp / float(tp + fp)
    except ZeroDivisionError:
        precision = 0.0

    return precision
def recall(result, reference):
    """
    Recall.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    recall : float
        The recall between two binary datasets, here mostly binary objects in images,
        which is defined as the fraction of relevant instances that are retrieved. The
        recall is not symmetric. Returns 0.0 when ``reference`` is empty.

    See also
    --------
    :func:`precision`

    Notes
    -----
    Not symmetric. The inverse of the recall is :func:`precision`.
    High recall means that an algorithm returned most of the relevant results.

    References
    ----------
    .. [1] http://en.wikipedia.org/wiki/Precision_and_recall
    .. [2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion
    """
    # Builtin `bool`: the `numpy.bool` alias was removed in NumPy 1.24.
    result = numpy.atleast_1d(result.astype(bool))
    reference = numpy.atleast_1d(reference.astype(bool))

    tp = numpy.count_nonzero(result & reference)
    fn = numpy.count_nonzero(~result & reference)

    try:
        recall = tp / float(tp + fn)
    except ZeroDivisionError:
        recall = 0.0

    return recall
def sensitivity(result, reference):
    """
    Sensitivity (true positive rate).

    Thin alias of :func:`recall`; consult it for the parameters and return value.

    See also
    --------
    :func:`specificity`
    """
    return recall(result=result, reference=reference)
def specificity(result, reference):
    """
    Specificity.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    specificity : float
        The specificity between two binary datasets, here mostly binary objects in images,
        which denotes the fraction of correctly returned negatives. The
        specificity is not symmetric. Returns 0.0 when ``reference`` has no background.

    See also
    --------
    :func:`sensitivity`

    Notes
    -----
    Not symmetric. The complement of the specificity is :func:`sensitivity`.
    High specificity means that an algorithm correctly rejected most of the irrelevant results.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Sensitivity_and_specificity
    .. [2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion
    """
    # Builtin `bool`: the `numpy.bool` alias was removed in NumPy 1.24.
    result = numpy.atleast_1d(result.astype(bool))
    reference = numpy.atleast_1d(reference.astype(bool))

    tn = numpy.count_nonzero(~result & ~reference)
    fp = numpy.count_nonzero(result & ~reference)

    try:
        specificity = tn / float(tn + fp)
    except ZeroDivisionError:
        specificity = 0.0

    return specificity
def true_negative_rate(result, reference):
    """
    True negative rate.

    Synonym of :func:`specificity`, which documents the parameters and return value.

    See also
    --------
    :func:`true_positive_rate`
    :func:`positive_predictive_value`
    """
    return specificity(result=result, reference=reference)
def true_positive_rate(result, reference):
    """
    True positive rate.

    Synonym of :func:`recall` (and :func:`sensitivity`), which documents the
    parameters and return value.

    See also
    --------
    :func:`positive_predictive_value`
    :func:`true_negative_rate`
    """
    return recall(result=result, reference=reference)
def positive_predictive_value(result, reference):
    """
    Positive predictive value.

    Synonym of :func:`precision`, which documents the parameters and return value.

    See also
    --------
    :func:`true_positive_rate`
    :func:`true_negative_rate`
    """
    return precision(result=result, reference=reference)
def hd(result, reference, voxelspacing=None, connectivity=1):
    """
    Hausdorff Distance.

    Computes the (symmetric) Hausdorff Distance (HD) between the binary objects in two
    images: the larger of the two directed maximum surface distances.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension. If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        The neighbourhood/connectivity considered when determining the surface
        of the binary objects. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
        Note that the connectivity influences the result in the case of the Hausdorff distance.

    Returns
    -------
    hd : float
        The symmetric Hausdorff Distance between the object(s) in ```result``` and the
        object(s) in ```reference```. The distance unit is the same as for the spacing of
        elements along each dimension, which is usually given in mm.

    See also
    --------
    :func:`assd`
    :func:`asd`

    Notes
    -----
    This is a real metric. The binary images can therefore be supplied in any order.
    """
    forward = __surface_distances(result, reference, voxelspacing, connectivity)
    backward = __surface_distances(reference, result, voxelspacing, connectivity)
    return max(forward.max(), backward.max())
def hd95(result, reference, voxelspacing=None, connectivity=1):
    """
    95th percentile of the Hausdorff Distance.

    Computes the 95th percentile of the pooled directed surface distances in both
    directions, a variant of the (symmetric) Hausdorff Distance (HD) that is less
    sensitive to small outliers and is common in biomedical segmentation challenges.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension. If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        The neighbourhood/connectivity considered when determining the surface
        of the binary objects. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
        Note that the connectivity influences the result in the case of the Hausdorff distance.

    Returns
    -------
    hd : float
        The 95th-percentile symmetric Hausdorff Distance between the object(s) in
        ```result``` and the object(s) in ```reference```. The distance unit is the same
        as for the spacing of elements along each dimension, which is usually given in mm.

    See also
    --------
    :func:`hd`

    Notes
    -----
    This is a real metric. The binary images can therefore be supplied in any order.
    """
    forward = __surface_distances(result, reference, voxelspacing, connectivity)
    backward = __surface_distances(reference, result, voxelspacing, connectivity)
    pooled = numpy.hstack([forward, backward])
    return numpy.percentile(pooled, 95)
def assd(result, reference, voxelspacing=None, connectivity=1):
    """
    Average symmetric surface distance.

    Computes the average symmetric surface distance (ASSD) between the binary objects in
    two images as the mean of the two directed average surface distances.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension. If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        The neighbourhood/connectivity considered when determining the surface
        of the binary objects. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
        The decision on the connectivity is important, as it can influence the results
        strongly. If in doubt, leave it as it is.

    Returns
    -------
    assd : float
        The average symmetric surface distance between the object(s) in ``result`` and the
        object(s) in ``reference``. The distance unit is the same as for the spacing of
        elements along each dimension, which is usually given in mm.

    See also
    --------
    :func:`asd`
    :func:`hd`

    Notes
    -----
    This is a real metric, obtained by calling and averaging

    >>> asd(result, reference)

    and

    >>> asd(reference, result)

    The binary images can therefore be supplied in any order.
    """
    directed = [asd(result, reference, voxelspacing, connectivity),
                asd(reference, result, voxelspacing, connectivity)]
    return numpy.mean(directed)
def asd(result, reference, voxelspacing=None, connectivity=1):
    """
    Average surface distance metric.

    Computes the average surface distance (ASD) between the binary objects in two images,
    i.e. the mean of the directed surface distances from ``result`` to ``reference``.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension. If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        The neighbourhood/connectivity considered when determining the surface
        of the binary objects. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
        The decision on the connectivity is important, as it can influence the results
        strongly. If in doubt, leave it as it is.

    Returns
    -------
    asd : float
        The average surface distance between the object(s) in ``result`` and the
        object(s) in ``reference``. The distance unit is the same as for the spacing
        of elements along each dimension, which is usually given in mm.

    See also
    --------
    :func:`assd`
    :func:`hd`

    Notes
    -----
    This is not a real metric, as it is directed. See `assd` for a real metric of this.

    The method is implemented making use of distance images and simple binary morphology
    to achieve high computational speed.

    Examples
    --------
    The `connectivity` determines what pixels/voxels are considered the surface of a
    binary object. Take the following binary image showing a cross

    >>> from scipy.ndimage.morphology import generate_binary_structure
    >>> cross = generate_binary_structure(2, 1)
    >>> cross
    array([[0, 1, 0],
           [1, 1, 1],
           [0, 1, 0]])

    With `connectivity` set to `1` a 4-neighbourhood is considered when determining the
    object surface, resulting in the surface

    .. code-block:: python

        array([[0, 1, 0],
               [1, 0, 1],
               [0, 1, 0]])

    Changing `connectivity` to `2`, an 8-neighbourhood is considered and we get:

    .. code-block:: python

        array([[0, 1, 0],
               [1, 1, 1],
               [0, 1, 0]])

    , because the centre pixel now has background among its diagonal neighbours and is
    therefore treated as surface as well.

    This influences the results `asd` returns. Imagine we want to compute the surface
    distance of our cross to a cube-like object:

    >>> cube = generate_binary_structure(2, 2)
    >>> cube
    array([[1, 1, 1],
           [1, 1, 1],
           [1, 1, 1]])

    , whose surface is, independent of the `connectivity` value set, always

    .. code-block:: python

        array([[1, 1, 1],
               [1, 0, 1],
               [1, 1, 1]])

    Using a `connectivity` of `1` we get

    >>> asd(cross, cube, connectivity=1)
    0.0

    while a value of `2` returns us

    >>> asd(cross, cube, connectivity=2)
    0.2

    due to the center of the cross being considered surface as well.
    """
    return __surface_distances(result, reference, voxelspacing, connectivity).mean()
def ravd(result, reference):
    """
    Relative absolute volume difference.

    Compute the relative absolute volume difference between the (joined) binary objects
    in the two images.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    ravd : float
        The relative absolute volume difference between the object(s) in ``result``
        and the object(s) in ``reference``. This is a fraction in the range
        :math:`[-1.0, +inf]` for which a :math:`0` denotes an ideal score.

    Raises
    ------
    RuntimeError
        If the reference object is empty.

    See also
    --------
    :func:`dc`
    :func:`precision`
    :func:`recall`

    Notes
    -----
    This is not a real metric, as it is directed. Negative values denote a smaller
    and positive values a larger volume than the reference.
    This implementation does not check whether the two supplied arrays are of the same
    size.

    Examples
    --------
    Considering the following inputs

    >>> import numpy
    >>> arr1 = numpy.asarray([[0,1,0],[1,1,1],[0,1,0]])
    >>> arr1
    array([[0, 1, 0],
           [1, 1, 1],
           [0, 1, 0]])
    >>> arr2 = numpy.asarray([[0,1,0],[1,0,1],[0,1,0]])
    >>> arr2
    array([[0, 1, 0],
           [1, 0, 1],
           [0, 1, 0]])

    comparing `arr1` (volume 5) to `arr2` (volume 4) we get

    >>> ravd(arr1, arr2)
    0.25

    and reversing the inputs the directivness of the metric becomes evident

    >>> ravd(arr2, arr1)
    -0.2

    It is important to keep in mind that a perfect score of `0` does not mean that the
    binary objects fit exactly, as only the volumes are compared:

    >>> arr1 = numpy.asarray([1,0,0])
    >>> arr2 = numpy.asarray([0,0,1])
    >>> ravd(arr1, arr2)
    0.0
    """
    # Builtin `bool`: the `numpy.bool` alias was removed in NumPy 1.24.
    result = numpy.atleast_1d(result.astype(bool))
    reference = numpy.atleast_1d(reference.astype(bool))

    vol1 = numpy.count_nonzero(result)
    vol2 = numpy.count_nonzero(reference)

    if 0 == vol2:
        raise RuntimeError('The second supplied array does not contain any binary object.')

    return (vol1 - vol2) / float(vol2)
def volume_correlation(results, references):
    r"""
    Volume correlation.

    Computes the linear correlation in binary object volume between the
    contents of the successive binary images supplied. Measured through
    the Pearson product-moment correlation coefficient.

    Parameters
    ----------
    results : sequence of array_like
        Ordered list of input data containing objects. Each array_like will be
        converted into binary: background where 0, object everywhere else.
    references : sequence of array_like
        Ordered list of input data containing objects. Each array_like will be
        converted into binary: background where 0, object everywhere else.
        The order must be the same as for ``results``.

    Returns
    -------
    r : float
        The correlation coefficient between -1 and 1.
    p : float
        The two-side p value.
    """
    # Builtin `bool`: the `numpy.bool` alias was removed in NumPy 1.24.
    results = numpy.atleast_2d(numpy.array(results).astype(bool))
    references = numpy.atleast_2d(numpy.array(references).astype(bool))

    # One volume (foreground count) per image in each sequence.
    results_volumes = [numpy.count_nonzero(r) for r in results]
    references_volumes = [numpy.count_nonzero(r) for r in references]

    return pearsonr(results_volumes, references_volumes)  # returns (Pearson's correlation coefficient, 2-tailed p-value)
def volume_change_correlation(results, references):
    r"""
    Volume change correlation.

    Computes the linear correlation of change in binary object volume between
    the contents of the successive binary images supplied. Measured through
    the Pearson product-moment correlation coefficient.

    Parameters
    ----------
    results : sequence of array_like
        Ordered list of input data containing objects. Each array_like will be
        converted into binary: background where 0, object everywhere else.
    references : sequence of array_like
        Ordered list of input data containing objects. Each array_like will be
        converted into binary: background where 0, object everywhere else.
        The order must be the same as for ``results``.

    Returns
    -------
    r : float
        The correlation coefficient between -1 and 1.
    p : float
        The two-side p value.
    """
    # Builtin `bool`: the `numpy.bool` alias was removed in NumPy 1.24.
    results = numpy.atleast_2d(numpy.array(results).astype(bool))
    references = numpy.atleast_2d(numpy.array(references).astype(bool))

    results_volumes = numpy.asarray([numpy.count_nonzero(r) for r in results])
    references_volumes = numpy.asarray([numpy.count_nonzero(r) for r in references])

    # Differences between consecutive volumes (one fewer entry than images).
    results_volumes_changes = results_volumes[1:] - results_volumes[:-1]
    references_volumes_changes = references_volumes[1:] - references_volumes[:-1]

    return pearsonr(results_volumes_changes, references_volumes_changes)  # returns (Pearson's correlation coefficient, 2-tailed p-value)
def obj_assd(result, reference, voxelspacing=None, connectivity=1):
    """
    Average symmetric surface distance between corresponding objects.

    Symmetrized version of :func:`obj_asd`: the directed object-wise average
    surface distance is evaluated in both directions and the two values are
    averaged, so the argument order does not matter.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension. If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        The neighbourhood/connectivity considered when determining what accounts
        for a distinct binary object as well as when determining the surface
        of the binary objects. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure`. The choice can
        influence the results strongly; if in doubt, leave the default.

    Returns
    -------
    assd : float
        Mean of ``obj_asd(result, reference)`` and ``obj_asd(reference, result)``,
        in the same unit as ``voxelspacing``.

    See also
    --------
    :func:`obj_asd`
    """
    forward = obj_asd(result, reference, voxelspacing, connectivity)
    backward = obj_asd(reference, result, voxelspacing, connectivity)
    return numpy.mean((forward, backward))
def obj_asd(result, reference, voxelspacing=None, connectivity=1):
    """
    Average surface distance between corresponding objects.

    The distinct binary objects (connected components) of ``result`` and
    ``reference`` are first paired one-to-one, where a pair requires at least one
    voxel of overlap (see ``__distinct_binary_object_correspondences``). For each
    pair, both objects are cropped to the union of their bounding boxes. The
    surface distances between the two objects are then collected, and the mean
    over all pairs' distances is returned.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension. If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        The neighbourhood/connectivity considered when determining what accounts
        for a distinct binary object as well as when determining the surface
        of the binary objects. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure`. The choice can
        influence the results strongly (e.g. with connectivity 1 two diagonally
        touching blobs are two objects, with 2 they are one); if in doubt, leave
        the default.

    Returns
    -------
    asd : float
        The mean surface distance over all paired objects, in the unit of
        ``voxelspacing``. If no pair exists, ``numpy.mean`` of an empty list is
        returned (NaN, with a RuntimeWarning).

    See also
    --------
    :func:`obj_assd`
    :func:`obj_tpr`
    :func:`obj_fpr`

    Notes
    -----
    This is not a real metric, as it is directed; see :func:`obj_assd` for the
    symmetric version. As implemented, each pair's distances are measured from the
    surface voxels of the ``reference`` object to the nearest surface voxel of its
    paired ``result`` object.

    NOTE(review): the numeric examples inherited from MedPy's documentation
    (e.g. ``obj_asd(arr1, arr2) == 1.5`` for a full 3x3 block vs. its middle
    column) did not agree with a hand trace of this implementation and were
    dropped; verify against a run before relying on specific values.
    """
    distances = []
    labels_a, labels_b, _, _, pairs = __distinct_binary_object_correspondences(result, reference, connectivity)
    boxes_a = find_objects(labels_a)
    boxes_b = find_objects(labels_b)
    for id_b, id_a in list(pairs.items()):
        window = __combine_windows(boxes_a[id_a - 1], boxes_b[id_b - 1])
        object_a = labels_a[window] == id_a
        object_b = labels_b[window] == id_b
        distances.extend(__surface_distances(object_a, object_b, voxelspacing, connectivity))
    return numpy.mean(distances)
def obj_fpr(result, reference, connectivity=1):
    """
    The false positive rate of distinct binary object detection.

    The false positive rate gives a percentage measure of how many distinct binary
    objects in the first array (``result``) do not exist in the second array
    (``reference``). A partial overlap (of minimum one voxel) is here considered
    sufficient.

    In cases where two distinct binary objects in the first array overlap with a
    single distinct object in the second array, only one is considered to have been
    detected successfully and the other is added to the count of false positives.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    connectivity : int
        The neighbourhood/connectivity considered when determining what accounts
        for a distinct binary object. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure`. The decision on the
        connectivity is important, as it can influence the results strongly. If in
        doubt, leave it as it is.

    Returns
    -------
    fpr : float
        A percentage measure of how many distinct binary objects in ``result`` have no
        corresponding binary object in ``reference``. It has the range :math:`[0, 1]`,
        where a :math:`0` denotes an ideal score.

    Raises
    ------
    RuntimeError
        If ``result`` (the first array) contains no binary object.

    See also
    --------
    :func:`obj_tpr`

    Notes
    -----
    This is not a real metric, as it is directed. Whatever array is considered as
    reference should be passed second. A perfect score of :math:`0` tells that there
    are no distinct binary objects in the first array that do not also exist in the
    reference array, but does not reveal anything about objects in the reference
    array also existing in the first array (use :func:`obj_tpr` for this).

    Examples
    --------
    >>> arr2 = numpy.asarray([[1,0,0],[1,0,1],[0,0,1]])
    >>> arr1 = numpy.asarray([[0,0,1],[1,0,1],[0,0,1]])
    >>> obj_fpr(arr1, arr2)
    0.0
    >>> obj_fpr(arr2, arr1)
    0.0

    Example of directedness:

    >>> arr2 = numpy.asarray([1,0,1,0,1])
    >>> arr1 = numpy.asarray([1,0,1,0,0])
    >>> obj_fpr(arr1, arr2)
    0.0
    >>> obj_fpr(arr2, arr1)
    0.3333333333333333

    Examples of multiple overlap treatment:

    >>> arr2 = numpy.asarray([1,0,1,0,1,1,1])
    >>> arr1 = numpy.asarray([1,1,1,0,1,0,1])
    >>> obj_fpr(arr1, arr2)
    0.3333333333333333
    >>> obj_fpr(arr2, arr1)
    0.3333333333333333

    >>> arr2 = numpy.asarray([1,0,1,1,1,0,1])
    >>> arr1 = numpy.asarray([1,1,1,0,1,1,1])
    >>> obj_fpr(arr1, arr2)
    0.0
    >>> obj_fpr(arr2, arr1)
    0.3333333333333333

    >>> arr2 = numpy.asarray([[1,0,1,0,0],
    ...                       [1,0,0,0,0],
    ...                       [1,0,1,1,1],
    ...                       [0,0,0,0,0],
    ...                       [1,0,1,0,0]])
    >>> arr1 = numpy.asarray([[1,1,1,0,0],
    ...                       [0,0,0,0,0],
    ...                       [1,1,1,0,1],
    ...                       [0,0,0,0,0],
    ...                       [1,1,1,0,0]])
    >>> obj_fpr(arr1, arr2)
    0.0
    >>> obj_fpr(arr2, arr1)
    0.2
    """
    _, _, n_obj_result, _, mapping = __distinct_binary_object_correspondences(reference, result, connectivity)
    # The rate is relative to the objects of ``result``: each result object without a
    # partner in ``reference`` is a false positive. (The previous code divided by the
    # reference object count, which swapped the roles of the two arrays and
    # contradicted every example above.)
    if 0 == n_obj_result:
        raise RuntimeError('The first supplied array does not contain any binary object.')
    return (n_obj_result - len(mapping)) / float(n_obj_result)
def obj_tpr(result, reference, connectivity=1):
    """
    The true positive rate of distinct binary object detection.

    The true positive rate gives a percentage measure of how many distinct binary
    objects in the second array (``reference``) also exist in the first array
    (``result``). A partial overlap (of minimum one voxel) is here considered
    sufficient.

    In cases where two distinct binary objects in the second array overlap with a
    single distinct object in the first array, only one is considered to have been
    detected successfully.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    connectivity : int
        The neighbourhood/connectivity considered when determining what accounts
        for a distinct binary object. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure`. The decision on the
        connectivity is important, as it can influence the results strongly. If in
        doubt, leave it as it is.

    Returns
    -------
    tpr : float
        A percentage measure of how many distinct binary objects in ``reference``
        have a corresponding object in ``result``. It has the range :math:`[0, 1]`,
        where a :math:`1` denotes an ideal score.

    Raises
    ------
    RuntimeError
        If ``reference`` (the second array) contains no binary object.

    See also
    --------
    :func:`obj_fpr`

    Notes
    -----
    This is not a real metric, as it is directed. Whatever array is considered as
    reference should be passed second. A perfect score of :math:`1` tells that all
    distinct binary objects in the reference array also exist in the result array,
    but does not reveal anything about additional binary objects in the result
    array (use :func:`obj_fpr` for this).

    Examples
    --------
    >>> arr2 = numpy.asarray([[1,0,0],[1,0,1],[0,0,1]])
    >>> arr1 = numpy.asarray([[0,0,1],[1,0,1],[0,0,1]])
    >>> obj_tpr(arr1, arr2)
    1.0
    >>> obj_tpr(arr2, arr1)
    1.0

    Example of directedness:

    >>> arr2 = numpy.asarray([1,0,1,0,1])
    >>> arr1 = numpy.asarray([1,0,1,0,0])
    >>> obj_tpr(arr1, arr2)
    0.6666666666666666
    >>> obj_tpr(arr2, arr1)
    1.0

    Examples of multiple overlap treatment:

    >>> arr2 = numpy.asarray([1,0,1,0,1,1,1])
    >>> arr1 = numpy.asarray([1,1,1,0,1,0,1])
    >>> obj_tpr(arr1, arr2)
    0.6666666666666666
    >>> obj_tpr(arr2, arr1)
    0.6666666666666666

    >>> arr2 = numpy.asarray([1,0,1,1,1,0,1])
    >>> arr1 = numpy.asarray([1,1,1,0,1,1,1])
    >>> obj_tpr(arr1, arr2)
    0.6666666666666666
    >>> obj_tpr(arr2, arr1)
    1.0

    >>> arr2 = numpy.asarray([[1,0,1,0,0],
    ...                       [1,0,0,0,0],
    ...                       [1,0,1,1,1],
    ...                       [0,0,0,0,0],
    ...                       [1,0,1,0,0]])
    >>> arr1 = numpy.asarray([[1,1,1,0,0],
    ...                       [0,0,0,0,0],
    ...                       [1,1,1,0,1],
    ...                       [0,0,0,0,0],
    ...                       [1,1,1,0,0]])
    >>> obj_tpr(arr1, arr2)
    0.8
    >>> obj_tpr(arr2, arr1)
    1.0
    """
    _, _, _, n_obj_reference, mapping = __distinct_binary_object_correspondences(reference, result, connectivity)
    # The rate is relative to the objects of ``reference``: the fraction of them that
    # got paired with a result object. (The previous code divided by the result object
    # count, which swapped the roles of the two arrays and contradicted the examples.)
    if 0 == n_obj_reference:
        raise RuntimeError('The second supplied array does not contain any binary object.')
    return len(mapping) / float(n_obj_reference)
def __distinct_binary_object_correspondences(reference, result, connectivity=1):
    """
    Pair up the distinct binary objects of ``reference`` and ``result`` one-to-one.

    Objects are the connected components found by ``scipy.ndimage.label`` with the
    footprint ``generate_binary_structure(ndim, connectivity)``. A reference object
    and a result object correspond when they share at least one voxel. The overlap
    relation is in general many-to-many and non-surjective, so it is reduced to a
    one-to-one mapping in which every result object is used at most once:

    1. Reference objects overlapping exactly one result object are paired first, in
       reference-label order, unless that result object is already taken.
    2. Reference objects overlapping several result objects are then paired greedily.
       The one with the fewest still-free candidates is always served first, and an
       arbitrary free candidate is picked.

    Both inputs may be any array_like; they are converted to boolean arrays.

    Returns
    -------
    (labelmap1, labelmap2, n_obj_result, n_obj_reference, mapping)
        labelmap1, n_obj_result: the labelled ``result`` and its object count;
        labelmap2, n_obj_reference: the labelled ``reference`` and its object count;
        mapping: dict {label in labelmap2 (reference): label in labelmap1 (result)}.
    """
    # Builtin ``bool`` instead of ``numpy.bool`` (missing in NumPy 1.24-1.26);
    # ``asarray`` lets plain lists through as the public callers document array_like.
    result = numpy.atleast_1d(numpy.asarray(result).astype(bool))
    reference = numpy.atleast_1d(numpy.asarray(reference).astype(bool))

    # binary structure
    footprint = generate_binary_structure(result.ndim, connectivity)

    # label distinct binary objects
    labelmap1, n_obj_result = label(result, footprint)
    labelmap2, n_obj_reference = label(reference, footprint)

    # Walk the reference objects; collect unambiguous (one-to-one) pairs directly and
    # defer reference objects that overlap several result objects.
    slicers = find_objects(labelmap2)  # bounding box of every reference object
    mapping = dict()  # reference label (labelmap2) -> result label (labelmap1)
    used_labels = set()  # result labels already assigned
    one_to_many = list()  # (reference label, set of overlapping result labels)
    for ref_id, slicer in enumerate(slicers, start=1):  # labels start at 1
        ref_obj = labelmap2[slicer] == ref_id
        res_ids = numpy.unique(labelmap1[slicer][ref_obj])
        res_ids = res_ids[res_ids != 0]  # drop background
        if 1 == len(res_ids):
            res_id = res_ids[0]
            if res_id not in used_labels:
                mapping[ref_id] = res_id
                used_labels.add(res_id)
        elif 1 < len(res_ids):
            one_to_many.append((ref_id, set(res_ids)))

    # Resolve deferred reference objects, always handling the one with the fewest
    # still-unused candidates first.
    while True:
        one_to_many = [(ref_id, res_ids - used_labels) for ref_id, res_ids in one_to_many]
        one_to_many = [x for x in one_to_many if x[1]]  # drop exhausted candidate sets
        one_to_many = sorted(one_to_many, key=lambda x: len(x[1]))
        if 0 == len(one_to_many):
            break
        res_id = one_to_many[0][1].pop()  # arbitrary candidate from the smallest set
        mapping[one_to_many[0][0]] = res_id
        used_labels.add(res_id)
        one_to_many = one_to_many[1:]
    return labelmap1, labelmap2, n_obj_result, n_obj_reference, mapping
def __surface_distances(result, reference, voxelspacing=None, connectivity=1):
    """
    The distances between the surface voxels of binary objects in ``result`` and their
    nearest partner surface voxel of a binary object in ``reference``.

    A surface voxel is a foreground voxel removed by one binary erosion with the
    footprint ``generate_binary_structure(ndim, connectivity)``.

    Returns a 1-D float array with one distance per surface voxel of ``result``,
    scaled by ``voxelspacing`` (unit spacing when None).

    Raises RuntimeError if either input contains no foreground voxel.
    """
    # Builtin ``bool`` instead of ``numpy.bool`` (missing in NumPy 1.24-1.26).
    result = numpy.atleast_1d(numpy.asarray(result).astype(bool))
    reference = numpy.atleast_1d(numpy.asarray(reference).astype(bool))
    if voxelspacing is not None:
        voxelspacing = _ni_support._normalize_sequence(voxelspacing, result.ndim)
        voxelspacing = numpy.asarray(voxelspacing, dtype=numpy.float64)
        if not voxelspacing.flags.contiguous:
            voxelspacing = voxelspacing.copy()

    # binary structure
    footprint = generate_binary_structure(result.ndim, connectivity)

    # test for emptiness
    if 0 == numpy.count_nonzero(result):
        raise RuntimeError('The first supplied array does not contain any binary object.')
    if 0 == numpy.count_nonzero(reference):
        raise RuntimeError('The second supplied array does not contain any binary object.')

    # extract only 1-pixel border line of objects
    result_border = result ^ binary_erosion(result, structure=footprint, iterations=1)
    reference_border = reference ^ binary_erosion(reference, structure=footprint, iterations=1)

    # distance_transform_edt measures distance to the nearest zero element, so the
    # reference border is inverted to make its voxels the zeros.
    dt = distance_transform_edt(~reference_border, sampling=voxelspacing)
    sds = dt[result_border]
    return sds
def __combine_windows(w1, w2):
    """
    Merge two windows (tuples of slices, one per axis) into the smallest window
    covering both: per axis, the lower start and the higher stop are taken.
    """
    return tuple(
        slice(min(first.start, second.start), max(first.stop, second.stop))
        for first, second in zip(w1, w2)
    )
================================================
FILE: utils/dice_loss.py
================================================
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
class SoftDiceLoss(_Loss):
    '''
    Loss module wrapping :func:`soft_dice_loss`.

    The value returned by ``forward`` is exactly ``soft_dice_loss(prediction,
    soft_ground_truth, num_class, weight_map)``. Constructor arguments are
    accepted for call-site compatibility and ignored.
    '''

    def __init__(self, *args, **kwargs):
        super(SoftDiceLoss, self).__init__()

    def forward(self, prediction, soft_ground_truth, num_class=3, weight_map=None, eps=1e-8):
        # NOTE: ``eps`` is part of the signature but is not forwarded;
        # soft_dice_loss applies its own fixed smoothing constants.
        return soft_dice_loss(prediction, soft_ground_truth, num_class, weight_map)
def get_soft_label(input_tensor, num_class):
    """
    Turn a label map into a float one-hot ("soft") tensor with classes last.

    input_tensor: tensor of shape [N, C, H, W] holding class ids
                  (C is expected to be 1 -- TODO confirm with callers)
    num_class:    number of classes to encode
    returns:      float tensor of shape [N, H, W, C * num_class]; for C == 1 this is
                  [N, H, W, num_class], where slice i along the last axis is 1.0
                  wherever the label equals i and 0.0 elsewhere
    """
    channels_last = input_tensor.permute(0, 2, 3, 1)
    masks = [torch.eq(channels_last, cls * torch.ones_like(channels_last))
             for cls in range(num_class)]
    return torch.cat(masks, dim=-1).float()
def soft_dice_loss(prediction, soft_ground_truth, num_class, weight_map=None):
    """
    Class-averaged negative log soft Dice.

    prediction:        tensor [N, C, H, W] of per-class scores, used as given
                       (no softmax is applied here)
    soft_ground_truth: tensor whose elements reshape to [-1, num_class] in the voxel
                       order of ``prediction.permute(0, 2, 3, 1)``, e.g. the
                       [N, H, W, num_class] output of ``get_soft_label``
    num_class:         number of classes (C)
    weight_map:        optional tensor with one weight per voxel (N*H*W elements)
    returns:           scalar tensor  mean_c( -log(dice_c) )  with
                       dice_c = (2*I_c + 1e-5) / (R_c + S_c + 1.0 + 1e-5),
                       where I_c, R_c, S_c are the (weighted) sums over voxels of
                       ground*pred, ground and pred for class c
    """
    pred = prediction.permute(0, 2, 3, 1).reshape(-1, num_class)
    # pred = F.softmax(pred, dim=1)
    ground = soft_ground_truth.reshape(-1, num_class)
    if weight_map is not None:
        # One weight per voxel (row), broadcast across the class columns. The former
        # ``weight_map.repeat(num_class).view_as(pred)`` laid the repeated vector out
        # row-major, so row i received the weights of other voxels, not voxel i's own.
        voxel_weight = weight_map.reshape(-1, 1)
        ref_vol = torch.sum(voxel_weight * ground, 0)
        intersect = torch.sum(voxel_weight * ground * pred, 0)
        seg_vol = torch.sum(voxel_weight * pred, 0)
    else:
        ref_vol = torch.sum(ground, 0)
        intersect = torch.sum(ground * pred, 0)
        seg_vol = torch.sum(pred, 0)
    dice_score = (2.0 * intersect + 1e-5) / (ref_vol + seg_vol + 1.0 + 1e-5)
    # dice_loss = 1.0 - torch.mean(dice_score)
    return torch.mean(-torch.log(dice_score))
def val_dice_fetus(prediction, soft_ground_truth, num_class):
    """
    Per-class Dice for the fetal task.

    prediction, soft_ground_truth: tensors whose elements reshape to [-1, num_class]
        (one-hot / soft labels with classes on the last axis).
    Dice_c = 2 * I_c / (R_c + S_c + 1.0), with I_c, R_c, S_c the sums over voxels of
    ground*pred, ground and pred for class c.
    Returns (dice of class 1, dice of class 2), named placenta and brain by this code;
    requires num_class >= 3.
    """
    # reshape (not view) so that non-contiguous inputs are accepted as well
    pred = prediction.reshape(-1, num_class)
    ground = soft_ground_truth.reshape(-1, num_class)
    ref_vol = torch.sum(ground, 0)
    intersect = torch.sum(ground * pred, 0)
    seg_vol = torch.sum(pred, 0)
    dice_score = 2.0 * intersect / (ref_vol + seg_vol + 1.0)
    placenta_dice = dice_score[1]
    brain_dice = dice_score[2]
    return placenta_dice, brain_dice
def Intersection_over_Union_fetus(prediction, soft_ground_truth, num_class):
    """
    Per-class IoU for the fetal task.

    prediction, soft_ground_truth: tensors whose elements reshape to [-1, num_class]
        (one-hot / soft labels with classes on the last axis).
    IoU_c = I_c / (R_c + S_c - I_c + 1.0), with I_c, R_c, S_c the sums over voxels of
    ground*pred, ground and pred for class c.
    Returns (IoU of class 1, IoU of class 2), named placenta and brain by this code;
    requires num_class >= 3.
    """
    # reshape (not view) so that non-contiguous inputs are accepted as well
    pred = prediction.reshape(-1, num_class)
    ground = soft_ground_truth.reshape(-1, num_class)
    ref_vol = torch.sum(ground, 0)
    intersect = torch.sum(ground * pred, 0)
    seg_vol = torch.sum(pred, 0)
    iou_score = intersect / (ref_vol + seg_vol - intersect + 1.0)
    placenta_iou = iou_score[1]
    brain_iou = iou_score[2]
    return placenta_iou, brain_iou
def val_dice_isic(prediction, soft_ground_truth, num_class):
    """
    Class-averaged Dice score for ISIC.

    Both inputs are flattened to [-1, num_class] (classes on the last axis).
    For every class c: Dice_c = 2 * I_c / (R_c + S_c + 1.0), where I_c, R_c and S_c
    are the sums over voxels of ground*pred, ground and pred. Returns the mean over
    classes as a 0-d tensor.
    """
    flat_pred = prediction.contiguous().view(-1, num_class)
    flat_truth = soft_ground_truth.view(-1, num_class)
    overlap = (flat_truth * flat_pred).sum(dim=0)
    denominator = flat_truth.sum(dim=0) + flat_pred.sum(dim=0) + 1.0
    return (2.0 * overlap / denominator).mean()
def Intersection_over_Union_isic(prediction, soft_ground_truth, num_class):
    """
    Class-averaged IoU for ISIC.

    Both inputs are flattened to [-1, num_class] (classes on the last axis).
    For every class c: IoU_c = I_c / (R_c + S_c - I_c + 1.0), where I_c, R_c and S_c
    are the sums over voxels of ground*pred, ground and pred. Returns the mean over
    classes as a 0-d tensor.
    """
    flat_pred = prediction.contiguous().view(-1, num_class)
    flat_truth = soft_ground_truth.view(-1, num_class)
    overlap = (flat_truth * flat_pred).sum(dim=0)
    union = flat_truth.sum(dim=0) + flat_pred.sum(dim=0) - overlap + 1.0
    return (overlap / union).mean()
================================================
FILE: utils/evaluation.py
================================================
class AverageMeter(object):
    """Tracks the latest value plus a running, count-weighted mean of all values."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running sum, the sample count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Store ``val`` as the latest value, weight it as ``n`` samples and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
================================================
FILE: utils/transform.py
================================================
import torch
import random
import PIL
import numbers
import numpy as np
import torch.nn as nn
import collections
import matplotlib.pyplot as plt
import torchvision.transforms as ts
import torchvision.transforms.functional as TF
from PIL import Image, ImageDraw
# Printable names of the PIL resampling filters, looked up by ``resize.__repr__``.
_pil_interpolation_to_str = dict([
    (Image.NEAREST, 'PIL.Image.NEAREST'),
    (Image.BILINEAR, 'PIL.Image.BILINEAR'),
    (Image.BICUBIC, 'PIL.Image.BICUBIC'),
    (Image.LANCZOS, 'PIL.Image.LANCZOS'),
])
def ISIC2018_transform(sample, train_type):
    """
    Joint image/label preprocessing for one ISIC2018 sample.

    sample:     dict with 'image' (converted to an RGB PIL image) and 'label'
                (converted to a single-channel 'L' PIL image); both are cast to uint8 first
    train_type: 'train' -> random crop of size (224, 300), then random h/v flips
                (p=0.5 each) and a random rotation in [-30, 30] degrees;
                any other value -> deterministic resize to (224, 300)
    returns:    {'image': ToTensor + Normalize(mean=0.5, std=0.5 per channel),
                 'label': ToTensor of the label image}
    """
    image = Image.fromarray(np.uint8(sample['image']), mode='RGB')
    label = Image.fromarray(np.uint8(sample['label']), mode='L')

    if train_type != 'train':
        image, label = resize(size=(224, 300))(image, label)
    else:
        image, label = randomcrop(size=(224, 300))(image, label)
        image, label = randomflip_rotate(image, label, p=0.5, degrees=30)

    to_normalized_tensor = ts.Compose([ts.ToTensor(),
                                       ts.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    image_tensor = to_normalized_tensor(image)
    label_tensor = ts.ToTensor()(label)
    return {'image': image_tensor, 'label': label_tensor}
# functional helpers used by the transforms above
def randomflip_rotate(img, lab, p=0.5, degrees=0):
    """
    Apply identical random flips and a random rotation to an image and its label.

    A horizontal flip and then a vertical flip are each applied with probability ``p``.
    Both inputs are then rotated by one angle drawn with ``random.uniform`` from
    ``(-degrees, degrees)`` when ``degrees`` is a number, or from the two-element
    sequence ``degrees`` otherwise. The rotation is always performed, with angle 0
    when ``degrees`` is 0.

    Raises ValueError for a negative number or a sequence whose length is not 2.
    """
    if random.random() < p:
        img, lab = TF.hflip(img), TF.hflip(lab)
    if random.random() < p:
        img, lab = TF.vflip(img), TF.vflip(lab)

    if not isinstance(degrees, numbers.Number):
        if len(degrees) != 2:
            raise ValueError("If degrees is a sequence, it must be of len 2.")
        low, high = degrees[0], degrees[1]
    elif degrees >= 0:
        low, high = -degrees, degrees
    else:
        raise ValueError("If degrees is a single number, it must be positive.")

    angle = random.uniform(low, high)
    return TF.rotate(img, angle), TF.rotate(lab, angle)
class randomcrop(object):
    """Crop the given PIL Image and its mask at the same random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception.
    """

    def __init__(self, size, padding=0, pad_if_needed=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed

    @staticmethod
    def _has_padding(padding):
        """Return True when ``padding`` (a number, a sequence, or None) requests any padding."""
        if padding is None:
            return False
        if isinstance(padding, numbers.Number):
            return padding > 0
        # Sequence form documented above; ``sequence > 0`` would raise TypeError.
        return any(p > 0 for p in padding)

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped (only ``img.size`` = (w, h) is read).
            output_size (tuple): Expected output size of the crop, as (h, w).

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
            ``random.randint`` raises ValueError if the image is smaller than the crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img, lab):
        """
        Args:
            img (PIL Image): Image to be cropped.
            lab (PIL Image): Mask to be cropped with the same window.

        Returns:
            PIL Image: Cropped image and mask.
        """
        if self._has_padding(self.padding):
            img = TF.pad(img, self.padding)
            lab = TF.pad(lab, self.padding)

        # pad the width if needed
        if self.pad_if_needed and img.size[0] < self.size[1]:
            img = TF.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0))
            lab = TF.pad(lab, (int((1 + self.size[1] - lab.size[0]) / 2), 0))
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            img = TF.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2)))
            lab = TF.pad(lab, (0, int((1 + self.size[0] - lab.size[1]) / 2)))

        i, j, h, w = self.get_params(img, self.size)
        return TF.crop(img, i, j, h, w), TF.crop(lab, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
class resize(object):
    """Resize the input PIL Image and its mask to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Resampling filter for the image. Default is
            ``PIL.Image.BILINEAR``.
        label_interpolation (int, optional): Resampling filter for the mask. Default is
            ``PIL.Image.NEAREST``, so resizing never creates label values lying between
            classes (bilinear resampling of a mask blends class ids along edges).
    """

    def __init__(self, size, interpolation=Image.BILINEAR, label_interpolation=Image.NEAREST):
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in collections.abc.
        from collections.abc import Iterable
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation
        self.label_interpolation = label_interpolation

    def __call__(self, img, lab):
        """
        Args:
            img (PIL Image): Image to be scaled (with ``self.interpolation``).
            lab (PIL Image): Mask to be scaled (with ``self.label_interpolation``).

        Returns:
            PIL Image: Rescaled image and mask.
        """
        return (TF.resize(img, self.size, self.interpolation),
                TF.resize(lab, self.size, self.label_interpolation))

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
def itensity_normalize(volume):
    """
    Z-score normalize an nd volume, then replace its zero voxels with N(0, 1) noise.

    The mean and standard deviation are computed over the WHOLE volume (the
    commented-out non-zero-only variant below is not used). Every voxel that is
    exactly 0 in the input is afterwards overwritten with a sample from
    ``np.random.normal(0, 1)``. For a constant volume (std == 0), no division is
    performed and the non-zero voxels become 0 instead of NaN.

    inputs:
        volume: the input nd volume (numpy array)
    outputs:
        out: the normalized nd volume (same shape)
    """
    # pixels = volume[volume > 0]
    mean = volume.mean()
    std = volume.std()
    out = volume - mean
    if std > 0:
        out = out / std
    out_random = np.random.normal(0, 1, size=volume.shape)
    out[volume == 0] = out_random[volume == 0]
    return out
================================================
FILE: validation.py
================================================
import os
import torch
import argparse
import numpy as np
import pandas as pd
import torch.utils.data as Data
from utils.binary import assd
from distutils.version import LooseVersion
from Datasets.ISIC2018 import ISIC2018_dataset
from utils.transform import ISIC2018_transform
from Models.networks.network import Comprehensive_Atten_Unet
from utils.dice_loss import get_soft_label, val_dice_isic
from utils.dice_loss import Intersection_over_Union_isic
from time import *
# Lookup tables used by ``__main__``: --id selects the model class, --data selects
# the dataset class and its sample transform.
Test_Model = dict(Comp_Atten_Unet=Comprehensive_Atten_Unet)
Test_Dataset = dict(ISIC2018=ISIC2018_dataset)
Test_Transform = dict(ISIC2018=ISIC2018_transform)
def test_isic(test_loader, model, num_classes=None):
    """
    Evaluate ``model`` on ``test_loader`` and print ISIC segmentation metrics.

    For every batch, the arg-max prediction and the target are one-hot encoded
    (get_soft_label). They are then scored with Dice (val_dice_isic), IoU
    (Intersection_over_Union_isic) and ASSD on one-hot channel 1 (assd).
    Prints mean/std of each metric over batches and the summed forward-pass time.

    test_loader: iterable of (image, label) batches
    model:       network returning per-class scores along dim 1
    num_classes: number of classes; defaults to the module-level ``args.num_classes``
                 set up in ``__main__``

    Requires CUDA: every batch is moved with ``.cuda()``.
    NOTE(review): assd raises RuntimeError when a prediction or label is empty,
    which aborts the whole evaluation.
    """
    if num_classes is None:
        num_classes = args.num_classes

    isic_dice = []
    isic_iou = []
    isic_assd = []
    infer_time = []

    model.eval()
    # Inference only: without no_grad, eval() still records autograd graphs and
    # keeps every activation alive.
    with torch.no_grad():
        for img, lab in test_loader:
            image = img.float().cuda()
            target = lab.float().cuda()

            # CUDA kernels run asynchronously; synchronize around the forward pass so
            # the timer measures execution rather than just kernel launch.
            torch.cuda.synchronize()
            begin_time = time()
            output = model(image)
            torch.cuda.synchronize()
            end_time = time()
            infer_time.append(end_time - begin_time)

            output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)
            output_soft = get_soft_label(output_dis, num_classes)
            target_soft = get_soft_label(target, num_classes)

            label_arr = target_soft.cpu().numpy().astype(np.uint8)
            output_arr = output_soft.cpu().numpy().astype(np.uint8)

            isic_b_dice = val_dice_isic(output_soft, target_soft, num_classes)
            isic_b_iou = Intersection_over_Union_isic(output_soft, target_soft, num_classes)
            # channel 1 = pixels whose label is 1 (the foreground class)
            isic_b_asd = assd(output_arr[:, :, :, 1], label_arr[:, :, :, 1])

            isic_dice.append(isic_b_dice.cpu().numpy())
            isic_iou.append(isic_b_iou.cpu().numpy())
            isic_assd.append(isic_b_asd)

    isic_dice_mean = np.average(isic_dice)
    isic_dice_std = np.std(isic_dice)
    isic_iou_mean = np.average(isic_iou)
    isic_iou_std = np.std(isic_iou)
    isic_assd_mean = np.average(isic_assd)
    isic_assd_std = np.std(isic_assd)
    all_time = np.sum(infer_time)

    print('The ISIC mean Accuracy: {isic_dice_mean: .4f}; The ISIC Accuracy std: {isic_dice_std: .4f}'.format(
        isic_dice_mean=isic_dice_mean, isic_dice_std=isic_dice_std))
    print('The ISIC mean IoU: {isic_iou_mean: .4f}; The ISIC IoU std: {isic_iou_std: .4f}'.format(
        isic_iou_mean=isic_iou_mean, isic_iou_std=isic_iou_std))
    print('The ISIC mean assd: {isic_asd_mean: .4f}; The ISIC assd std: {isic_asd_std: .4f}'.format(
        isic_asd_mean=isic_assd_mean, isic_asd_std=isic_assd_std))
    print('The inference time: {time: .4f}'.format(time=all_time))
if __name__ == '__main__':
    assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), 'PyTorch>=0.4.0 is required'

    parser = argparse.ArgumentParser(description='U-net add Attention mechanism for biomedical Dataset')
    # Model related arguments
    parser.add_argument('--id', default='Comp_Atten_Unet',
                        help='a name for identifying the model. Choose from the following options: Comp_Atten_Unet')
    # Path related arguments
    parser.add_argument('--root_path', default='./data/ISIC2018_Task1_npy_all',
                        help='root directory of data')
    parser.add_argument('--ckpt', default='./saved_models',
                        help='folder to output checkpoints')
    parser.add_argument('--save', default='./result',
                        help='folder to output result')
    parser.add_argument('--batch_size', type=int, default=1, metavar='N',
                        help='input batch size for testing (default: 1)')
    parser.add_argument('--num_classes', default=2, type=int,
                        help='number of classes')
    parser.add_argument('--num_input', default=3, type=int,
                        help='number of input image for each patient')
    parser.add_argument('--epoch', type=int, default=300, metavar='N',
                        help='choose the specific epoch checkpoints')
    # other arguments
    parser.add_argument('--data', default='ISIC2018', help='choose the dataset')
    parser.add_argument('--out_size', default=(224, 300), help='the output image size')
    parser.add_argument('--att_pos', default='dec', type=str,
                        help='where attention to plug in (enc, dec, enc&dec)')
    parser.add_argument('--view', default='axial', type=str,
                        help='use what views data to test (for fetal MRI)')
    parser.add_argument('--val_folder', default='folder0', type=str,
                        help='which cross validation folder')
    args = parser.parse_args()
    args.ckpt = os.path.join(args.ckpt, args.data, args.val_folder, args.id)

    # loading the dataset
    print('loading the {0} dataset ...'.format('test'))
    testset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='test',
                                      transform=Test_Transform[args.data])
    testloader = Data.DataLoader(dataset=testset, batch_size=args.batch_size, shuffle=False)
    print('Loading is done\n')

    # Define model. test_isic() moves every batch to the GPU, so there is no CPU
    # fallback; fail with a clear message instead of an undefined ``model`` later.
    if not torch.cuda.is_available():
        raise SystemExit('CUDA is not available; this evaluation script requires a GPU.')
    print('We can use', torch.cuda.device_count(), 'GPUs to train the network')
    if args.data == 'Fetus':
        args.num_input = 1
        args.num_classes = 3
        args.out_size = (256, 256)
    elif args.data == 'ISIC2018':
        args.num_input = 3
        args.num_classes = 2
        args.out_size = (224, 300)
    model = Test_Model[args.id](args, args.num_input, args.num_classes).cuda()
    # model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))

    # Load the trained best model
    modelname = args.ckpt + '/' + 'min_loss' + '_' + args.data + '_checkpoint.pth.tar'
    if os.path.isfile(modelname):
        print("=> Loading checkpoint '{}'".format(modelname))
        checkpoint = torch.load(modelname)
        # start_epoch = checkpoint['epoch']
        # For a checkpoint saved from nn.DataParallel, strip the 'module.' prefix
        # (k[7:]) from every key of checkpoint['state_dict'] before loading.
        model.load_state_dict(checkpoint['state_dict'])
        # optimizer.load_state_dict(checkpoint['opt_dict'])
        print("=> Loaded saved the best model at (epoch {})".format(checkpoint['epoch']))
    else:
        # NOTE(review): evaluation still runs below with untrained weights.
        print("=> No checkpoint found at '{}'".format(modelname))

    test_isic(testloader, model)
gitextract_zlon6qmm/ ├── .idea/ │ ├── CA-Net.iml │ ├── encodings.xml │ ├── misc.xml │ ├── modules.xml │ ├── vcs.xml │ └── workspace.xml ├── Datasets/ │ ├── ISIC2018.py │ └── folder0/ │ └── folder0_test.list ├── Models/ │ ├── __init__.py │ ├── layers/ │ │ ├── __init__.py │ │ ├── channel_attention_layer.py │ │ ├── grid_attention_layer.py │ │ ├── modules.py │ │ ├── nonlocal_layer.py │ │ └── scale_attention_layer.py │ ├── networks/ │ │ └── network.py │ └── networks_other.py ├── README.md ├── create_folder.py ├── data/ │ └── ISIC2018_Task1_npy_all/ │ ├── image/ │ │ └── ISIC_0010854.npy │ └── label/ │ └── ISIC_0010854_segmentation.npy ├── isic_preprocess.py ├── main.py ├── result/ │ └── atten_map/ │ └── 25_2_8_wgt ├── show_fused_heatmap.py ├── utils/ │ ├── binary.py │ ├── dice_loss.py │ ├── evaluation.py │ └── transform.py └── validation.py
SYMBOL INDEX (182 symbols across 16 files)
FILE: Datasets/ISIC2018.py
class ISIC2018_dataset (line 15) | class ISIC2018_dataset(Dataset):
method __init__ (line 16) | def __init__(self, dataset_folder='/ISIC2018_Task1_npy_all',
method __getitem__ (line 39) | def __getitem__(self, item: int):
method __len__ (line 51) | def __len__(self):
FILE: Models/layers/channel_attention_layer.py
function conv3x3 (line 5) | def conv3x3(in_planes, out_planes, stride=1, bias=False, group=1):
class SE_Conv_Block (line 10) | class SE_Conv_Block(nn.Module):
method __init__ (line 13) | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_o...
method forward (line 51) | def forward(self, x):
FILE: Models/layers/grid_attention_layer.py
class _GridAttentionBlockND (line 7) | class _GridAttentionBlockND(nn.Module):
method __init__ (line 8) | def __init__(self, in_channels, gating_channels, inter_channels=None, ...
method forward (line 74) | def forward(self, x, g):
method _concatenation (line 84) | def _concatenation(self, x, g):
method _concatenation_debug (line 109) | def _concatenation_debug(self, x, g):
method _concatenation_residual (line 135) | def _concatenation_residual(self, x, g):
class GridAttentionBlock2D (line 162) | class GridAttentionBlock2D(_GridAttentionBlockND):
method __init__ (line 163) | def __init__(self, in_channels, gating_channels, inter_channels=None, ...
class GridAttentionBlock3D (line 173) | class GridAttentionBlock3D(_GridAttentionBlockND):
method __init__ (line 174) | def __init__(self, in_channels, gating_channels, inter_channels=None, ...
class _GridAttentionBlockND_TORR (line 183) | class _GridAttentionBlockND_TORR(nn.Module):
method __init__ (line 184) | def __init__(self, in_channels, gating_channels, inter_channels=None, ...
method forward (line 284) | def forward(self, x, g):
method _concatenation (line 294) | def _concatenation(self, x, g):
class GridAttentionBlock2D_TORR (line 359) | class GridAttentionBlock2D_TORR(_GridAttentionBlockND_TORR):
method __init__ (line 360) | def __init__(self, in_channels, gating_channels, inter_channels=None, ...
class GridAttentionBlock3D_TORR (line 377) | class GridAttentionBlock3D_TORR(_GridAttentionBlockND_TORR):
method __init__ (line 378) | def __init__(self, in_channels, gating_channels, inter_channels=None, ...
class MultiAttentionBlock (line 388) | class MultiAttentionBlock(nn.Module):
method __init__ (line 389) | def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_...
method forward (line 406) | def forward(self, input, gating_signal):
FILE: Models/layers/modules.py
function conv1x1 (line 5) | def conv1x1(in_planes, out_planes, stride=1, bias=False):
function conv3x3 (line 11) | def conv3x3(in_planes, out_planes, stride=1, bias=False, group=1):
class conv_block (line 17) | class conv_block(nn.Module):
method __init__ (line 18) | def __init__(self, ch_in, ch_out, drop_out=False):
method forward (line 30) | def forward(self, x):
class UpCat (line 38) | class UpCat(nn.Module):
method __init__ (line 39) | def __init__(self, in_feat, out_feat, is_deconv=True):
method forward (line 47) | def forward(self, inputs, down_outputs):
class UpCatconv (line 64) | class UpCatconv(nn.Module):
method __init__ (line 65) | def __init__(self, in_feat, out_feat, is_deconv=True, drop_out=False):
method forward (line 75) | def forward(self, inputs, down_outputs):
class UnetGridGatingSignal3 (line 92) | class UnetGridGatingSignal3(nn.Module):
method __init__ (line 93) | def __init__(self, in_size, out_size, kernel_size=(1, 1), is_batchnorm...
method forward (line 106) | def forward(self, inputs):
class UnetDsv3 (line 111) | class UnetDsv3(nn.Module):
method __init__ (line 112) | def __init__(self, in_size, out_size, scale_factor):
method forward (line 117) | def forward(self, input):
FILE: Models/layers/nonlocal_layer.py
class _NonLocalBlockND (line 7) | class _NonLocalBlockND(nn.Module):
method __init__ (line 8) | def __init__(self, in_channels, inter_channels=None, dimension=3, mode...
method forward (line 103) | def forward(self, x):
method _embedded_gaussian (line 112) | def _embedded_gaussian(self, x):
method _gaussian (line 137) | def _gaussian(self, x):
method _dot_product (line 161) | def _dot_product(self, x):
method _concatenation (line 182) | def _concatenation(self, x):
method _concatenation_proper (line 213) | def _concatenation_proper(self, x):
method _concatenation_proper_down (line 246) | def _concatenation_proper_down(self, x):
class NONLocalBlock1D (line 287) | class NONLocalBlock1D(_NonLocalBlockND):
method __init__ (line 288) | def __init__(self, in_channels, inter_channels=None, mode='embedded_ga...
class NONLocalBlock2D (line 296) | class NONLocalBlock2D(_NonLocalBlockND):
method __init__ (line 297) | def __init__(self, in_channels, inter_channels=None, mode='embedded_ga...
class NONLocalBlock3D (line 305) | class NONLocalBlock3D(_NonLocalBlockND):
method __init__ (line 306) | def __init__(self, in_channels, inter_channels=None, mode='embedded_ga...
FILE: Models/layers/scale_attention_layer.py
function conv1x1 (line 6) | def conv1x1(in_planes, out_planes, stride=1, bias=False):
function conv3x3 (line 13) | def conv3x3(in_planes, out_planes, stride=1, bias=False, group=1):
class BasicConv (line 19) | class BasicConv(nn.Module):
method __init__ (line 20) | def __init__(self, in_planes, out_planes, kernel_size, stride=1, paddi...
method forward (line 29) | def forward(self, x):
class Flatten (line 38) | class Flatten(nn.Module):
method forward (line 39) | def forward(self, x):
class ChannelGate (line 43) | class ChannelGate(nn.Module):
method __init__ (line 44) | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg...
method forward (line 55) | def forward(self, x):
function logsumexp_2d (line 86) | def logsumexp_2d(tensor):
class ChannelPool (line 93) | class ChannelPool(nn.Module):
method forward (line 94) | def forward(self, x):
class SpatialGate (line 98) | class SpatialGate(nn.Module):
method __init__ (line 99) | def __init__(self):
method forward (line 105) | def forward(self, x):
class SpatialAtten (line 113) | class SpatialAtten(nn.Module):
method __init__ (line 114) | def __init__(self, in_size, out_size, kernel_size=3, stride=1):
method forward (line 121) | def forward(self, x):
class Scale_atten_block (line 134) | class Scale_atten_block(nn.Module):
method __init__ (line 135) | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg...
method forward (line 142) | def forward(self, x):
class scale_atten_convblock (line 150) | class scale_atten_convblock(nn.Module):
method __init__ (line 151) | def __init__(self, in_size, out_size, stride=1, downsample=None, use_c...
method forward (line 173) | def forward(self, x):
FILE: Models/networks/network.py
class Comprehensive_Atten_Unet (line 11) | class Comprehensive_Atten_Unet(nn.Module):
method __init__ (line 12) | def __init__(self, args, in_ch=3, n_classes=2, feature_scale=4, is_dec...
method forward (line 70) | def forward(self, inputs):
FILE: Models/networks_other.py
function weights_init_normal (line 14) | def weights_init_normal(m):
function weights_init_xavier (line 26) | def weights_init_xavier(m):
function weights_init_kaiming (line 38) | def weights_init_kaiming(m):
function weights_init_orthogonal (line 50) | def weights_init_orthogonal(m):
function init_weights (line 62) | def init_weights(net, init_type='normal'):
function get_norm_layer (line 76) | def get_norm_layer(norm_type='instance'):
function adjust_learning_rate (line 88) | def adjust_learning_rate(optimizer, lr):
function get_scheduler (line 93) | def get_scheduler(optimizer, opt):
function define_G (line 141) | def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', u...
function define_D (line 165) | def define_D(input_nc, ndf, which_model_netD,
function print_network (line 186) | def print_network(net):
function get_n_parameters (line 194) | def get_n_parameters(net):
function measure_fp_bp_time (line 201) | def measure_fp_bp_time(model, x, y):
function benchmark_fp_bp_time (line 224) | def benchmark_fp_bp_time(model, x, y, n_trial=1000):
class GANLoss (line 258) | class GANLoss(nn.Module):
method __init__ (line 259) | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_...
method get_target_tensor (line 272) | def get_target_tensor(self, input, target_is_real):
method __call__ (line 290) | def __call__(self, input, target_is_real):
class ResnetGenerator (line 299) | class ResnetGenerator(nn.Module):
method __init__ (line 300) | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNor...
method forward (line 344) | def forward(self, input):
class ResnetBlock (line 352) | class ResnetBlock(nn.Module):
method __init__ (line 353) | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
method build_conv_block (line 357) | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout,...
method forward (line 389) | def forward(self, x):
class UnetGenerator (line 398) | class UnetGenerator(nn.Module):
method __init__ (line 399) | def __init__(self, input_nc, output_nc, num_downs, ngf=64,
method forward (line 415) | def forward(self, input):
class UnetSkipConnectionBlock (line 425) | class UnetSkipConnectionBlock(nn.Module):
method __init__ (line 426) | def __init__(self, outer_nc, inner_nc, input_nc=None,
method forward (line 471) | def forward(self, x):
class NLayerDiscriminator (line 479) | class NLayerDiscriminator(nn.Module):
method __init__ (line 480) | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNo...
method forward (line 523) | def forward(self, input):
FILE: create_folder.py
function create_5_floder (line 9) | def create_5_floder(folder, save_foler):
function text_save (line 35) | def text_save(filename, data): # filename: path to write CSV, data:...
FILE: main.py
function train (line 35) | def train(train_loader, model, criterion, optimizer, args, epoch):
function valid_fetus (line 63) | def valid_fetus(valid_loader, model, criterion, optimizer, args, epoch, ...
function valid_isic (line 105) | def valid_isic(valid_loader, model, criterion, optimizer, args, epoch, m...
function test_fetus (line 145) | def test_fetus(test_loader, model, args):
function test_isic (line 223) | def test_isic(test_loader, model, args):
function main (line 278) | def main(args):
FILE: show_fused_heatmap.py
function map_scalar_to_color (line 8) | def map_scalar_to_color(x):
function get_fused_heat_map (line 27) | def get_fused_heat_map(image, att):
FILE: utils/binary.py
function dc (line 34) | def dc(result, reference):
function jc (line 83) | def jc(result, reference):
function precision (line 118) | def precision(result, reference):
function recall (line 165) | def recall(result, reference):
function sensitivity (line 212) | def sensitivity(result, reference):
function specificity (line 223) | def specificity(result, reference):
function true_negative_rate (line 270) | def true_negative_rate(result, reference):
function true_positive_rate (line 282) | def true_positive_rate(result, reference):
function positive_predictive_value (line 294) | def positive_predictive_value(result, reference):
function hd (line 306) | def hd(result, reference, voxelspacing=None, connectivity=1):
function hd95 (line 354) | def hd95(result, reference, voxelspacing=None, connectivity=1):
function assd (line 402) | def assd(result, reference, voxelspacing=None, connectivity=1):
function asd (line 456) | def asd(result, reference, voxelspacing=None, connectivity=1):
function ravd (line 565) | def ravd(result, reference):
function volume_correlation (line 652) | def volume_correlation(results, references):
function volume_change_correlation (line 686) | def volume_change_correlation(results, references):
function obj_assd (line 723) | def obj_assd(result, reference, voxelspacing=None, connectivity=1):
function obj_asd (line 778) | def obj_asd(result, reference, voxelspacing=None, connectivity=1):
function obj_fpr (line 920) | def obj_fpr(result, reference, connectivity=1):
function obj_tpr (line 1031) | def obj_tpr(result, reference, connectivity=1):
function __distinct_binary_object_correspondences (line 1141) | def __distinct_binary_object_correspondences(reference, result, connecti...
function __surface_distances (line 1195) | def __surface_distances(result, reference, voxelspacing=None, connectivi...
function __combine_windows (line 1229) | def __combine_windows(w1, w2):
FILE: utils/dice_loss.py
class SoftDiceLoss (line 6) | class SoftDiceLoss(_Loss):
method __init__ (line 11) | def __init__(self, *args, **kwargs):
method forward (line 14) | def forward(self, prediction, soft_ground_truth, num_class=3, weight_m...
function get_soft_label (line 19) | def get_soft_label(input_tensor, num_class):
function soft_dice_loss (line 35) | def soft_dice_loss(prediction, soft_ground_truth, num_class, weight_map=...
function val_dice_fetus (line 58) | def val_dice_fetus(prediction, soft_ground_truth, num_class):
function Intersection_over_Union_fetus (line 74) | def Intersection_over_Union_fetus(prediction, soft_ground_truth, num_cla...
function val_dice_isic (line 90) | def val_dice_isic(prediction, soft_ground_truth, num_class):
function Intersection_over_Union_isic (line 104) | def Intersection_over_Union_isic(prediction, soft_ground_truth, num_class):
FILE: utils/evaluation.py
class AverageMeter (line 1) | class AverageMeter(object):
method __init__ (line 4) | def __init__(self):
method reset (line 7) | def reset(self):
method update (line 13) | def update(self, val, n=1):
FILE: utils/transform.py
function ISIC2018_transform (line 22) | def ISIC2018_transform(sample, train_type):
function randomflip_rotate (line 40) | def randomflip_rotate(img, lab, p=0.5, degrees=0):
class randomcrop (line 63) | class randomcrop(object):
method __init__ (line 78) | def __init__(self, size, padding=0, pad_if_needed=False):
method get_params (line 87) | def get_params(img, output_size):
method __call__ (line 106) | def __call__(self, img, lab):
method __repr__ (line 132) | def __repr__(self):
class resize (line 136) | class resize(object):
method __init__ (line 149) | def __init__(self, size, interpolation=Image.BILINEAR):
method __call__ (line 154) | def __call__(self, img, lab):
method __repr__ (line 165) | def __repr__(self):
function itensity_normalize (line 170) | def itensity_normalize(volume):
FILE: validation.py
function test_isic (line 27) | def test_isic(test_loader, model):
Condensed preview — 30 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (197K chars).
[
{
"path": ".idea/CA-Net.iml",
"chars": 398,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<module type=\"PYTHON_MODULE\" version=\"4\">\n <component name=\"NewModuleRootManager"
},
{
"path": ".idea/encodings.xml",
"chars": 135,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"Encoding\" addBOMForNewFiles=\"with NO BOM"
},
{
"path": ".idea/misc.xml",
"chars": 298,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"JavaScriptSettings\">\n <option name=\"l"
},
{
"path": ".idea/modules.xml",
"chars": 264,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"ProjectModuleManager\">\n <modules>\n "
},
{
"path": ".idea/vcs.xml",
"chars": 180,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"VcsDirectoryMappings\">\n <mapping dire"
},
{
"path": ".idea/workspace.xml",
"chars": 19193,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"ChangeListManager\">\n <list default=\"t"
},
{
"path": "Datasets/ISIC2018.py",
"chars": 2173,
"preview": "import os\nimport PIL\nimport torch\nimport numpy as np\nimport nibabel as nib\nimport matplotlib.pyplot as plt\n\nfrom os impo"
},
{
"path": "Datasets/folder0/folder0_test.list",
"chars": 16,
"preview": "ISIC_0010854.npy"
},
{
"path": "Models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "Models/layers/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "Models/layers/channel_attention_layer.py",
"chars": 3598,
"preview": "import torch.nn as nn\n\n\n# # SE block add to U-net\ndef conv3x3(in_planes, out_planes, stride=1, bias=False, group=1):\n "
},
{
"path": "Models/layers/grid_attention_layer.py",
"chars": 18095,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom Models.networks_other import init_weights\n\n\n"
},
{
"path": "Models/layers/modules.py",
"chars": 4618,
"preview": "import torch\nimport torch.nn as nn\n\n\ndef conv1x1(in_planes, out_planes, stride=1, bias=False):\n \"1x1 convolution\"\n "
},
{
"path": "Models/layers/nonlocal_layer.py",
"chars": 13546,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom Models.networks_other import init_weights\n\n\n"
},
{
"path": "Models/layers/scale_attention_layer.py",
"chars": 7694,
"preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\n\ndef conv1x1(in_planes, out_planes, stride=1, b"
},
{
"path": "Models/networks/network.py",
"chars": 5720,
"preview": "import torch\nimport torch.nn as nn\n\nfrom Models.layers.modules import conv_block, UpCat, UpCatconv, UnetDsv3, UnetGridGa"
},
{
"path": "Models/networks_other.py",
"chars": 20196,
"preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\nimport functools\nfrom torch.autograd import Variable\nfrom t"
},
{
"path": "README.md",
"chars": 2934,
"preview": "## CA-Net: Comprehensive Attention Convolutional Neural Networks for Explainable Medical Image Segmentation\nThis reposit"
},
{
"path": "create_folder.py",
"chars": 1565,
"preview": "import os\nimport numpy\nfrom random import shuffle\n\nPATH = './data/ISIC2018_Task1_npy_all/image'\nSAVE_PATH = './Datasets'"
},
{
"path": "isic_preprocess.py",
"chars": 1683,
"preview": "#!/usr/bin/python3\n# these code is for ISIC 2018: Skin Lesion Analysis Towards Melanoma Detection\n# -*- coding: utf-8 -*"
},
{
"path": "main.py",
"chars": 22068,
"preview": "#!/usr/bin/python3\n# these code is for ISIC 2018: Skin Lesion Analysis Towards Melanoma Detection\n# -*- coding: utf-8 -*"
},
{
"path": "show_fused_heatmap.py",
"chars": 2026,
"preview": "import os\nimport cv2\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom PIL import Image\n\n\ndef map_scalar_to_color("
},
{
"path": "utils/binary.py",
"chars": 44737,
"preview": "# Copyright (C) 2013 Oskar Maier\n# \n# This program is free software: you can redistribute it and/or modify\n# it under th"
},
{
"path": "utils/dice_loss.py",
"chars": 4243,
"preview": "import torch\nimport torch.nn.functional as F\nfrom torch.nn.modules.loss import _Loss\n\n\nclass SoftDiceLoss(_Loss):\n ''"
},
{
"path": "utils/evaluation.py",
"chars": 391,
"preview": "class AverageMeter(object):\n \"\"\"Computes and stores the average and current value\"\"\"\n\n def __init__(self):\n "
},
{
"path": "utils/transform.py",
"chars": 6433,
"preview": "import torch\nimport random\nimport PIL\nimport numbers\nimport numpy as np\nimport torch.nn as nn\nimport collections\nimport "
},
{
"path": "validation.py",
"chars": 7129,
"preview": "import os\nimport torch\nimport argparse\nimport numpy as np\nimport pandas as pd\nimport torch.utils.data as Data\nfrom utils"
}
]
// ... and 3 more files (download for full content)
About this extraction
This page contains the full source code of the HiLab-git/CA-Net GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 30 files (184.9 KB), approximately 49.4k tokens, and a symbol index with 182 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.