[
  {
    "path": ".idea/CA-Net.iml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<module type=\"PYTHON_MODULE\" version=\"4\">\n  <component name=\"NewModuleRootManager\">\n    <content url=\"file://$MODULE_DIR$\" />\n    <orderEntry type=\"inheritedJdk\" />\n    <orderEntry type=\"sourceFolder\" forTests=\"false\" />\n  </component>\n  <component name=\"TestRunnerService\">\n    <option name=\"PROJECT_TEST_RUNNER\" value=\"Unittests\" />\n  </component>\n</module>"
  },
  {
    "path": ".idea/encodings.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"Encoding\" addBOMForNewFiles=\"with NO BOM\" />\n</project>"
  },
  {
    "path": ".idea/misc.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"JavaScriptSettings\">\n    <option name=\"languageLevel\" value=\"ES6\" />\n  </component>\n  <component name=\"ProjectRootManager\" version=\"2\" project-jdk-name=\"Python 3.6 (pytorch)\" project-jdk-type=\"Python SDK\" />\n</project>"
  },
  {
    "path": ".idea/modules.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ProjectModuleManager\">\n    <modules>\n      <module fileurl=\"file://$PROJECT_DIR$/.idea/CA-Net.iml\" filepath=\"$PROJECT_DIR$/.idea/CA-Net.iml\" />\n    </modules>\n  </component>\n</project>"
  },
  {
    "path": ".idea/vcs.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"VcsDirectoryMappings\">\n    <mapping directory=\"$PROJECT_DIR$\" vcs=\"Git\" />\n  </component>\n</project>"
  },
  {
    "path": ".idea/workspace.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ChangeListManager\">\n    <list default=\"true\" id=\"2541e3bb-fbe2-4fc6-be8f-b6401cb16713\" name=\"Default Changelist\" comment=\"\">\n      <change beforePath=\"$PROJECT_DIR$/.idea/CA-Net.iml\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/.idea/CA-Net.iml\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/.idea/workspace.xml\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/.idea/workspace.xml\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/create_folder.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/create_folder.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/data/ISIC2018_Task1_npy_all/image/ISIC_0010854.npy\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/data/ISIC2018_Task1_npy_all/image/ISIC_0010854.npy\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/data/ISIC2018_Task1_npy_all/label/ISIC_0010854_segmentation.npy\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/data/ISIC2018_Task1_npy_all/label/ISIC_0010854_segmentation.npy\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/isic_preprocess.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/isic_preprocess.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/utils/transform.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/utils/transform.py\" afterDir=\"false\" />\n    </list>\n    <option name=\"EXCLUDED_CONVERTED_TO_IGNORED\" value=\"true\" />\n    <option name=\"SHOW_DIALOG\" value=\"false\" />\n    <option name=\"HIGHLIGHT_CONFLICTS\" value=\"true\" />\n    <option name=\"HIGHLIGHT_NON_ACTIVE_CHANGELIST\" value=\"false\" />\n    <option name=\"LAST_RESOLUTION\" value=\"IGNORE\" />\n  </component>\n  <component name=\"CoverageDataManager\">\n    <SUITE FILE_PATH=\"coverage/CA_Net$validation.coverage\" NAME=\"validation Coverage Results\" MODIFIED=\"1598537010616\" 
SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$\" />\n    <SUITE FILE_PATH=\"coverage/CA_Net$isic_preprocess.coverage\" NAME=\"isic_preprocess Coverage Results\" MODIFIED=\"1598536798821\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$\" />\n  </component>\n  <component name=\"FileEditorManager\">\n    <leaf SIDE_TABS_SIZE_LIMIT_KEY=\"300\">\n      <file pinned=\"false\" current-in-tab=\"false\">\n        <entry file=\"file://$PROJECT_DIR$/create_folder.py\">\n          <provider selected=\"true\" editor-type-id=\"text-editor\">\n            <state relative-caret-position=\"374\">\n              <caret line=\"17\" column=\"46\" lean-forward=\"true\" selection-start-line=\"17\" selection-start-column=\"46\" selection-end-line=\"17\" selection-end-column=\"46\" />\n              <folding>\n                <element signature=\"e#0#9#0\" expanded=\"true\" />\n              </folding>\n            </state>\n          </provider>\n        </entry>\n      </file>\n      <file pinned=\"false\" current-in-tab=\"false\">\n        <entry file=\"file://$PROJECT_DIR$/README.md\">\n          <provider selected=\"true\" editor-type-id=\"split-provider[text-editor;markdown-preview-editor]\">\n            <state split_layout=\"SPLIT\">\n              <first_editor relative-caret-position=\"1254\">\n                <caret line=\"57\" selection-start-line=\"57\" selection-end-line=\"57\" />\n              </first_editor>\n              <second_editor />\n            </state>\n          </provider>\n        </entry>\n      </file>\n      <file pinned=\"false\" current-in-tab=\"false\">\n        <entry file=\"file://$PROJECT_DIR$/validation.py\">\n          <provider selected=\"true\" 
editor-type-id=\"text-editor\">\n            <state relative-caret-position=\"743\">\n              <caret line=\"118\" selection-start-line=\"118\" selection-end-line=\"118\" />\n              <folding>\n                <element signature=\"e#0#9#0\" expanded=\"true\" />\n              </folding>\n            </state>\n          </provider>\n        </entry>\n      </file>\n      <file pinned=\"false\" current-in-tab=\"true\">\n        <entry file=\"file://$PROJECT_DIR$/isic_preprocess.py\">\n          <provider selected=\"true\" editor-type-id=\"text-editor\">\n            <state relative-caret-position=\"550\">\n              <caret line=\"25\" column=\"38\" lean-forward=\"true\" selection-start-line=\"25\" selection-start-column=\"38\" selection-end-line=\"25\" selection-end-column=\"38\" />\n              <folding>\n                <element signature=\"e#143#152#0\" expanded=\"true\" />\n              </folding>\n            </state>\n          </provider>\n        </entry>\n      </file>\n      <file pinned=\"false\" current-in-tab=\"false\">\n        <entry file=\"file://$PROJECT_DIR$/utils/transform.py\">\n          <provider selected=\"true\" editor-type-id=\"text-editor\">\n            <state relative-caret-position=\"479\">\n              <caret line=\"85\" column=\"17\" lean-forward=\"true\" selection-start-line=\"85\" selection-start-column=\"17\" selection-end-line=\"85\" selection-end-column=\"17\" />\n              <folding>\n                <element signature=\"e#0#12#0\" expanded=\"true\" />\n              </folding>\n            </state>\n          </provider>\n        </entry>\n      </file>\n      <file pinned=\"false\" current-in-tab=\"false\">\n        <entry file=\"file://$PROJECT_DIR$/Models/networks/network.py\">\n          <provider selected=\"true\" editor-type-id=\"text-editor\">\n            <state relative-caret-position=\"374\">\n              <caret line=\"17\" column=\"36\" selection-start-line=\"17\" selection-start-column=\"36\" 
selection-end-line=\"17\" selection-end-column=\"36\" />\n              <folding>\n                <element signature=\"e#0#12#0\" expanded=\"true\" />\n              </folding>\n            </state>\n          </provider>\n        </entry>\n      </file>\n      <file pinned=\"false\" current-in-tab=\"false\">\n        <entry file=\"file://$PROJECT_DIR$/Models/layers/nonlocal_layer.py\">\n          <provider selected=\"true\" editor-type-id=\"text-editor\">\n            <state>\n              <folding>\n                <element signature=\"e#0#12#0\" expanded=\"true\" />\n              </folding>\n            </state>\n          </provider>\n        </entry>\n      </file>\n      <file pinned=\"false\" current-in-tab=\"false\">\n        <entry file=\"file://$PROJECT_DIR$/Models/layers/scale_attention_layer.py\">\n          <provider selected=\"true\" editor-type-id=\"text-editor\">\n            <state relative-caret-position=\"2640\">\n              <caret line=\"120\" column=\"25\" selection-start-line=\"120\" selection-start-column=\"25\" selection-end-line=\"120\" selection-end-column=\"25\" />\n              <folding>\n                <element signature=\"e#0#12#0\" expanded=\"true\" />\n              </folding>\n            </state>\n          </provider>\n        </entry>\n      </file>\n      <file pinned=\"false\" current-in-tab=\"false\">\n        <entry file=\"file://$PROJECT_DIR$/main.py\">\n          <provider selected=\"true\" editor-type-id=\"text-editor\">\n            <state relative-caret-position=\"545\">\n              <caret line=\"428\" column=\"36\" selection-start-line=\"428\" selection-start-column=\"36\" selection-end-line=\"428\" selection-end-column=\"36\" />\n            </state>\n          </provider>\n        </entry>\n      </file>\n    </leaf>\n  </component>\n  <component name=\"Git.Settings\">\n    <option name=\"RECENT_GIT_ROOT_PATH\" value=\"$PROJECT_DIR$\" />\n  </component>\n  <component name=\"IdeDocumentHistory\">\n    
<option name=\"CHANGED_PATHS\">\n      <list>\n        <option value=\"$PROJECT_DIR$/Datasets/ISIC2018.py\" />\n        <option value=\"$PROJECT_DIR$/main.py\" />\n        <option value=\"$PROJECT_DIR$/README.md\" />\n        <option value=\"$PROJECT_DIR$/validation.py\" />\n        <option value=\"$PROJECT_DIR$/Models/networks/network.py\" />\n        <option value=\"$PROJECT_DIR$/utils/transform.py\" />\n        <option value=\"$PROJECT_DIR$/isic_preprocess.py\" />\n        <option value=\"$PROJECT_DIR$/create_folder.py\" />\n      </list>\n    </option>\n  </component>\n  <component name=\"ProjectFrameBounds\" extendedState=\"6\">\n    <option name=\"x\" value=\"65\" />\n    <option name=\"y\" value=\"-4\" />\n    <option name=\"width\" value=\"1855\" />\n    <option name=\"height\" value=\"1084\" />\n  </component>\n  <component name=\"ProjectView\">\n    <navigator proportions=\"\" version=\"1\">\n      <foldersAlwaysOnTop value=\"true\" />\n    </navigator>\n    <panes>\n      <pane id=\"Scope\" />\n      <pane id=\"ProjectPane\">\n        <subPane>\n          <expand>\n            <path>\n              <item name=\"CA-Net\" type=\"b2602c69:ProjectViewProjectNode\" />\n              <item name=\"CA-Net\" type=\"462c0819:PsiDirectoryNode\" />\n            </path>\n          </expand>\n          <select />\n        </subPane>\n      </pane>\n    </panes>\n  </component>\n  <component name=\"PropertiesComponent\">\n    <property name=\"WebServerToolWindowFactoryState\" value=\"false\" />\n    <property name=\"last_opened_file_path\" value=\"$PROJECT_DIR$\" />\n    <property name=\"nodejs_interpreter_path.stuck_in_default_project\" value=\"undefined stuck path\" />\n    <property name=\"nodejs_npm_path_reset_for_default_project\" value=\"true\" />\n    <property name=\"settings.editor.selected.configurable\" value=\"com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable\" />\n  </component>\n  <component name=\"RecentsManager\">\n    <key 
name=\"CopyFile.RECENT_KEYS\">\n      <recent name=\"$PROJECT_DIR$/data/11/ISIC2018_Task1-2_Training_Input\" />\n      <recent name=\"$PROJECT_DIR$/data/11/ISIC2018_Task1_Training_GroundTruth\" />\n      <recent name=\"$PROJECT_DIR$/Datasets\" />\n    </key>\n    <key name=\"MoveFile.RECENT_KEYS\">\n      <recent name=\"$PROJECT_DIR$/data/11/ISIC2018_Task1_Training_GroundTruth\" />\n      <recent name=\"$PROJECT_DIR$/Datasets\" />\n    </key>\n  </component>\n  <component name=\"RunDashboard\">\n    <option name=\"ruleStates\">\n      <list>\n        <RuleState>\n          <option name=\"name\" value=\"ConfigurationTypeDashboardGroupingRule\" />\n        </RuleState>\n        <RuleState>\n          <option name=\"name\" value=\"StatusDashboardGroupingRule\" />\n        </RuleState>\n      </list>\n    </option>\n  </component>\n  <component name=\"RunManager\">\n    <configuration name=\"validation\" type=\"PythonConfigurationType\" factoryName=\"Python\" temporary=\"true\">\n      <module name=\"CA-Net\" />\n      <option name=\"INTERPRETER_OPTIONS\" value=\"\" />\n      <option name=\"PARENT_ENVS\" value=\"true\" />\n      <envs>\n        <env name=\"PYTHONUNBUFFERED\" value=\"1\" />\n      </envs>\n      <option name=\"SDK_HOME\" value=\"\" />\n      <option name=\"WORKING_DIRECTORY\" value=\"$PROJECT_DIR$\" />\n      <option name=\"IS_MODULE_SDK\" value=\"true\" />\n      <option name=\"ADD_CONTENT_ROOTS\" value=\"true\" />\n      <option name=\"ADD_SOURCE_ROOTS\" value=\"true\" />\n      <EXTENSION ID=\"PythonCoverageRunConfigurationExtension\" runner=\"coverage.py\" />\n      <option name=\"SCRIPT_NAME\" value=\"$PROJECT_DIR$/validation.py\" />\n      <option name=\"PARAMETERS\" value=\"\" />\n      <option name=\"SHOW_COMMAND_LINE\" value=\"false\" />\n      <option name=\"EMULATE_TERMINAL\" value=\"false\" />\n      <option name=\"MODULE_MODE\" value=\"false\" />\n      <option name=\"REDIRECT_INPUT\" value=\"false\" />\n      <option name=\"INPUT_FILE\" 
value=\"\" />\n      <method v=\"2\" />\n    </configuration>\n    <recent_temporary>\n      <list>\n        <item itemvalue=\"Python.validation\" />\n      </list>\n    </recent_temporary>\n  </component>\n  <component name=\"SvnConfiguration\">\n    <configuration />\n  </component>\n  <component name=\"TaskManager\">\n    <task active=\"true\" id=\"Default\" summary=\"Default task\">\n      <changelist id=\"2541e3bb-fbe2-4fc6-be8f-b6401cb16713\" name=\"Default Changelist\" comment=\"\" />\n      <created>1598531699004</created>\n      <option name=\"number\" value=\"Default\" />\n      <option name=\"presentableId\" value=\"Default\" />\n      <updated>1598531699004</updated>\n      <workItem from=\"1598531700459\" duration=\"3016000\" />\n      <workItem from=\"1598535247980\" duration=\"2358000\" />\n      <workItem from=\"1598537682047\" duration=\"117000\" />\n      <workItem from=\"1598538259149\" duration=\"289000\" />\n    </task>\n    <servers />\n  </component>\n  <component name=\"TimeTrackingManager\">\n    <option name=\"totallyTimeSpent\" value=\"5780000\" />\n  </component>\n  <component name=\"ToolWindowManager\">\n    <frame x=\"65\" y=\"-4\" width=\"1855\" height=\"1084\" extended-state=\"6\" />\n    <editor active=\"true\" />\n    <layout>\n      <window_info active=\"true\" content_ui=\"combo\" id=\"Project\" order=\"0\" visible=\"true\" weight=\"0.17819591\" />\n      <window_info id=\"Structure\" order=\"1\" side_tool=\"true\" weight=\"0.25\" />\n      <window_info id=\"Favorites\" order=\"2\" side_tool=\"true\" />\n      <window_info anchor=\"bottom\" id=\"Message\" order=\"0\" />\n      <window_info anchor=\"bottom\" id=\"Find\" order=\"1\" />\n      <window_info anchor=\"bottom\" id=\"Run\" order=\"2\" />\n      <window_info anchor=\"bottom\" id=\"Debug\" order=\"3\" weight=\"0.2761506\" />\n      <window_info anchor=\"bottom\" id=\"Cvs\" order=\"4\" weight=\"0.25\" />\n      <window_info anchor=\"bottom\" id=\"Inspection\" order=\"5\" 
weight=\"0.4\" />\n      <window_info anchor=\"bottom\" id=\"TODO\" order=\"6\" />\n      <window_info anchor=\"bottom\" id=\"Docker\" order=\"7\" show_stripe_button=\"false\" />\n      <window_info anchor=\"bottom\" id=\"Version Control\" order=\"8\" />\n      <window_info anchor=\"bottom\" id=\"Database Changes\" order=\"9\" />\n      <window_info anchor=\"bottom\" id=\"Event Log\" order=\"10\" side_tool=\"true\" />\n      <window_info anchor=\"bottom\" id=\"Terminal\" order=\"11\" visible=\"true\" weight=\"0.3294979\" />\n      <window_info anchor=\"bottom\" id=\"Python Console\" order=\"12\" />\n      <window_info anchor=\"right\" id=\"Commander\" internal_type=\"SLIDING\" order=\"0\" type=\"SLIDING\" weight=\"0.4\" />\n      <window_info anchor=\"right\" id=\"Ant Build\" order=\"1\" weight=\"0.25\" />\n      <window_info anchor=\"right\" content_ui=\"combo\" id=\"Hierarchy\" order=\"2\" weight=\"0.25\" />\n      <window_info anchor=\"right\" id=\"SciView\" order=\"3\" visible=\"true\" weight=\"0.12562259\" />\n      <window_info anchor=\"right\" id=\"Database\" order=\"4\" />\n    </layout>\n  </component>\n  <component name=\"TypeScriptGeneratedFilesManager\">\n    <option name=\"version\" value=\"1\" />\n  </component>\n  <component name=\"editorHistoryManager\">\n    <entry file=\"file://$PROJECT_DIR$/Datasets/folder1/folder1_test.list\" />\n    <entry file=\"file://$USER_HOME$/anaconda3/envs/pytorch/lib/python3.7/site-packages/numpy/lib/npyio.py\">\n      <provider selected=\"true\" editor-type-id=\"text-editor\">\n        <state relative-caret-position=\"358\">\n          <caret line=\"293\" lean-forward=\"true\" selection-start-line=\"293\" selection-end-line=\"293\" />\n        </state>\n      </provider>\n    </entry>\n    <entry file=\"file://$PROJECT_DIR$/Datasets/ISIC2018.py\">\n      <provider selected=\"true\" editor-type-id=\"text-editor\">\n        <state relative-caret-position=\"286\">\n          <caret line=\"13\" selection-start-line=\"13\" 
selection-end-line=\"13\" />\n          <folding>\n            <element signature=\"e#0#9#0\" expanded=\"true\" />\n          </folding>\n        </state>\n      </provider>\n    </entry>\n    <entry file=\"file://$PROJECT_DIR$/README.md\">\n      <provider selected=\"true\" editor-type-id=\"split-provider[text-editor;markdown-preview-editor]\">\n        <state split_layout=\"SPLIT\">\n          <first_editor relative-caret-position=\"1254\">\n            <caret line=\"57\" selection-start-line=\"57\" selection-end-line=\"57\" />\n          </first_editor>\n          <second_editor />\n        </state>\n      </provider>\n    </entry>\n    <entry file=\"file://$PROJECT_DIR$/Models/networks/network.py\">\n      <provider selected=\"true\" editor-type-id=\"text-editor\">\n        <state relative-caret-position=\"374\">\n          <caret line=\"17\" column=\"36\" selection-start-line=\"17\" selection-start-column=\"36\" selection-end-line=\"17\" selection-end-column=\"36\" />\n          <folding>\n            <element signature=\"e#0#12#0\" expanded=\"true\" />\n          </folding>\n        </state>\n      </provider>\n    </entry>\n    <entry file=\"file://$PROJECT_DIR$/Models/layers/nonlocal_layer.py\">\n      <provider selected=\"true\" editor-type-id=\"text-editor\">\n        <state>\n          <folding>\n            <element signature=\"e#0#12#0\" expanded=\"true\" />\n          </folding>\n        </state>\n      </provider>\n    </entry>\n    <entry file=\"file://$PROJECT_DIR$/Models/layers/scale_attention_layer.py\">\n      <provider selected=\"true\" editor-type-id=\"text-editor\">\n        <state relative-caret-position=\"2640\">\n          <caret line=\"120\" column=\"25\" selection-start-line=\"120\" selection-start-column=\"25\" selection-end-line=\"120\" selection-end-column=\"25\" />\n          <folding>\n            <element signature=\"e#0#12#0\" expanded=\"true\" />\n          </folding>\n        </state>\n      </provider>\n    </entry>\n    <entry 
file=\"file://$PROJECT_DIR$/Datasets/folder0/folder0_test.list\">\n      <provider selected=\"true\" editor-type-id=\"text-editor\">\n        <state>\n          <caret column=\"16\" lean-forward=\"true\" selection-start-column=\"16\" selection-end-column=\"16\" />\n        </state>\n      </provider>\n    </entry>\n    <entry file=\"file://$PROJECT_DIR$/utils/transform.py\">\n      <provider selected=\"true\" editor-type-id=\"text-editor\">\n        <state relative-caret-position=\"479\">\n          <caret line=\"85\" column=\"17\" lean-forward=\"true\" selection-start-line=\"85\" selection-start-column=\"17\" selection-end-line=\"85\" selection-end-column=\"17\" />\n          <folding>\n            <element signature=\"e#0#12#0\" expanded=\"true\" />\n          </folding>\n        </state>\n      </provider>\n    </entry>\n    <entry file=\"file://$PROJECT_DIR$/validation.py\">\n      <provider selected=\"true\" editor-type-id=\"text-editor\">\n        <state relative-caret-position=\"743\">\n          <caret line=\"118\" selection-start-line=\"118\" selection-end-line=\"118\" />\n          <folding>\n            <element signature=\"e#0#9#0\" expanded=\"true\" />\n          </folding>\n        </state>\n      </provider>\n    </entry>\n    <entry file=\"file://$PROJECT_DIR$/main.py\">\n      <provider selected=\"true\" editor-type-id=\"text-editor\">\n        <state relative-caret-position=\"545\">\n          <caret line=\"428\" column=\"36\" selection-start-line=\"428\" selection-start-column=\"36\" selection-end-line=\"428\" selection-end-column=\"36\" />\n        </state>\n      </provider>\n    </entry>\n    <entry file=\"file://$PROJECT_DIR$/create_folder.py\">\n      <provider selected=\"true\" editor-type-id=\"text-editor\">\n        <state relative-caret-position=\"374\">\n          <caret line=\"17\" column=\"46\" lean-forward=\"true\" selection-start-line=\"17\" selection-start-column=\"46\" selection-end-line=\"17\" selection-end-column=\"46\" />\n     
     <folding>\n            <element signature=\"e#0#9#0\" expanded=\"true\" />\n          </folding>\n        </state>\n      </provider>\n    </entry>\n    <entry file=\"file://$PROJECT_DIR$/isic_preprocess.py\">\n      <provider selected=\"true\" editor-type-id=\"text-editor\">\n        <state relative-caret-position=\"550\">\n          <caret line=\"25\" column=\"38\" lean-forward=\"true\" selection-start-line=\"25\" selection-start-column=\"38\" selection-end-line=\"25\" selection-end-column=\"38\" />\n          <folding>\n            <element signature=\"e#143#152#0\" expanded=\"true\" />\n          </folding>\n        </state>\n      </provider>\n    </entry>\n  </component>\n</project>"
  },
  {
    "path": "Datasets/ISIC2018.py",
    "content": "import os\nimport PIL\nimport torch\nimport numpy as np\nimport nibabel as nib\nimport matplotlib.pyplot as plt\n\nfrom os import listdir\nfrom os.path import join\nfrom PIL import Image\nfrom utils.transform import itensity_normalize\nfrom torch.utils.data.dataset import Dataset\n\n\nclass ISIC2018_dataset(Dataset):\n    def __init__(self, dataset_folder='/ISIC2018_Task1_npy_all',\n                 folder='folder0', train_type='train', transform=None):\n        self.transform = transform\n        self.train_type = train_type\n        self.folder_file = './Datasets/' + folder\n\n        if self.train_type in ['train', 'validation', 'test']:\n            # this is for cross validation\n            with open(join(self.folder_file, self.folder_file.split('/')[-1] + '_' + self.train_type + '.list'),\n                      'r') as f:\n                self.image_list = f.readlines()\n            self.image_list = [item.replace('\\n', '') for item in self.image_list]\n            self.folder = [join(dataset_folder, 'image', x) for x in self.image_list]\n            self.mask = [join(dataset_folder, 'label', x.split('.')[0] + '_segmentation.npy') for x in self.image_list]\n            # self.folder = sorted([join(dataset_folder, self.train_type, 'image', x) for x in\n            #                       listdir(join(dataset_folder, self.train_type, 'image'))])\n            # self.mask = sorted([join(dataset_folder, self.train_type, 'label', x) for x in\n            #                     listdir(join(dataset_folder, self.train_type, 'label'))])\n        else:\n            print(\"Choosing type error, You have to choose the loading data type including: train, validation, test\")\n\n        assert len(self.folder) == len(self.mask)\n\n    def __getitem__(self, item: int):\n        image = np.load(self.folder[item])\n        label = np.load(self.mask[item])\n\n        sample = {'image': image, 'label': label}\n\n        if self.transform is not None:\n           
 # TODO: transformation to argument datasets\n            sample = self.transform(sample, self.train_type)\n\n        return sample['image'], sample['label']\n\n    def __len__(self):\n        return len(self.folder)\n\n# a = ISIC2018_dataset()\n"
  },
  {
    "path": "Datasets/folder0/folder0_test.list",
    "content": "ISIC_0010854.npy"
  },
  {
    "path": "Models/__init__.py",
    "content": ""
  },
  {
    "path": "Models/layers/__init__.py",
    "content": ""
  },
  {
    "path": "Models/layers/channel_attention_layer.py",
    "content": "import torch.nn as nn\n\n\n# # SE block add to U-net\ndef conv3x3(in_planes, out_planes, stride=1, bias=False, group=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=1, groups=group, bias=bias)\n\n\nclass SE_Conv_Block(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, drop_out=False):\n        super(SE_Conv_Block, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes * 2)\n        self.bn2 = nn.BatchNorm2d(planes * 2)\n        self.conv3 = conv3x3(planes * 2, planes)\n        self.bn3 = nn.BatchNorm2d(planes)\n        self.downsample = downsample\n        self.stride = stride\n        self.dropout = drop_out\n\n        if planes <= 16:\n            self.globalAvgPool = nn.AvgPool2d((224, 300), stride=1)  # (224, 300) for ISIC2018\n            self.globalMaxPool = nn.MaxPool2d((224, 300), stride=1)\n        elif planes == 32:\n            self.globalAvgPool = nn.AvgPool2d((112, 150), stride=1)  # (112, 150) for ISIC2018\n            self.globalMaxPool = nn.MaxPool2d((112, 150), stride=1)\n        elif planes == 64:\n            self.globalAvgPool = nn.AvgPool2d((56, 75), stride=1)    # (56, 75) for ISIC2018\n            self.globalMaxPool = nn.MaxPool2d((56, 75), stride=1)\n        elif planes == 128:\n            self.globalAvgPool = nn.AvgPool2d((28, 37), stride=1)    # (28, 37) for ISIC2018\n            self.globalMaxPool = nn.MaxPool2d((28, 37), stride=1)\n        elif planes == 256:\n            self.globalAvgPool = nn.AvgPool2d((14, 18), stride=1)    # (14, 18) for ISIC2018\n            self.globalMaxPool = nn.MaxPool2d((14, 18), stride=1)\n\n        self.fc1 = nn.Linear(in_features=planes * 2, out_features=round(planes / 2))\n        self.fc2 = 
nn.Linear(in_features=round(planes / 2), out_features=planes * 2)\n        self.sigmoid = nn.Sigmoid()\n\n        self.downchannel = None\n        if inplanes != planes:\n            self.downchannel = nn.Sequential(nn.Conv2d(inplanes, planes * 2, kernel_size=1, stride=stride, bias=False),\n                                             nn.BatchNorm2d(planes * 2),)\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downchannel is not None:\n            residual = self.downchannel(x)\n\n        original_out = out\n        out1 = out\n        # For global average pool\n        out = self.globalAvgPool(out)\n        out = out.view(out.size(0), -1)\n        out = self.fc1(out)\n        out = self.relu(out)\n        out = self.fc2(out)\n        out = self.sigmoid(out)\n        out = out.view(out.size(0), out.size(1), 1, 1)\n        avg_att = out\n        out = out * original_out\n        # For global maximum pool\n        out1 = self.globalMaxPool(out1)\n        out1 = out1.view(out1.size(0), -1)\n        out1 = self.fc1(out1)\n        out1 = self.relu(out1)\n        out1 = self.fc2(out1)\n        out1 = self.sigmoid(out1)\n        out1 = out1.view(out1.size(0), out1.size(1), 1, 1)\n        max_att = out1\n        out1 = out1 * original_out\n\n        att_weight = avg_att + max_att\n        out += out1\n        out += residual\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n        out = self.relu(out)\n        if self.dropout:\n            out = nn.Dropout2d(0.5)(out)\n\n        return out, att_weight\n"
  },
  {
    "path": "Models/layers/grid_attention_layer.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom Models.networks_other import init_weights\n\n\nclass _GridAttentionBlockND(nn.Module):\n    def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation',\n                 sub_sample_factor=(2,2,2)):\n        super(_GridAttentionBlockND, self).__init__()\n\n        assert dimension in [2, 3]\n        assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual']\n\n        # Downsampling rate for the input featuremap\n        if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor\n        elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor)\n        else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension\n\n        # Default parameter set\n        self.mode = mode\n        self.dimension = dimension\n        self.sub_sample_kernel_size = self.sub_sample_factor\n\n        # Number of channels (pixel dimensions)\n        self.in_channels = in_channels\n        self.gating_channels = gating_channels\n        self.inter_channels = inter_channels\n\n        if self.inter_channels is None:\n            self.inter_channels = in_channels // 2\n            if self.inter_channels == 0:\n                self.inter_channels = 1\n\n        if dimension == 3:\n            conv_nd = nn.Conv3d\n            bn = nn.BatchNorm3d\n            self.upsample_mode = 'trilinear'\n        elif dimension == 2:\n            conv_nd = nn.Conv2d\n            bn = nn.BatchNorm2d\n            self.upsample_mode = 'bilinear'\n        else:\n            raise NotImplemented\n\n        # Output transform\n        self.W = nn.Sequential(\n            conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0),\n            bn(self.in_channels),\n        )\n\n        # Theta^T * x_ij + Phi^T * gating_signal + bias\n       
 self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                             kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=True)\n        self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels,\n                           kernel_size=(1, 1), stride=1, padding=0, bias=True)\n        self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)\n\n        # Initialise weights\n        for m in self.children():\n            init_weights(m, init_type='kaiming')\n\n        # Define the operation\n        if mode == 'concatenation':\n            self.operation_function = self._concatenation\n        elif mode == 'concatenation_debug':\n            self.operation_function = self._concatenation_debug\n        elif mode == 'concatenation_residual':\n            self.operation_function = self._concatenation_residual\n        else:\n            raise NotImplementedError('Unknown operation function.')\n\n\n    def forward(self, x, g):\n        '''\n        :param x: (b, c, t, h, w)\n        :param g: (b, g_d)\n        :return:\n        '''\n\n        output = self.operation_function(x, g)\n        return output\n\n    def _concatenation(self, x, g):\n        input_size = x.size()\n        batch_size = input_size[0]\n        assert batch_size == g.size(0)\n\n        # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw)\n        # phi   => (b, g_d) -> (b, i_c)\n        theta_x = self.theta(x)\n        theta_x_size = theta_x.size()\n\n        # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w')\n        #  Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3)\n        phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode)\n        f = F.relu(theta_x + phi_g, inplace=True)\n\n        #  psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3)\n        sigm_psi_f = 
F.sigmoid(self.psi(f))\n\n        # upsample the attentions and multiply\n        sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode)\n        y = sigm_psi_f.expand_as(x) * x\n        W_y = self.W(y)\n\n        return W_y, sigm_psi_f\n\n    def _concatenation_debug(self, x, g):\n        input_size = x.size()\n        batch_size = input_size[0]\n        assert batch_size == g.size(0)\n\n        # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw)\n        # phi   => (b, g_d) -> (b, i_c)\n        theta_x = self.theta(x)\n        theta_x_size = theta_x.size()\n\n        # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w')\n        #  Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3)\n        phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode)\n        f = F.softplus(theta_x + phi_g)\n\n        #  psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3)\n        sigm_psi_f = F.sigmoid(self.psi(f))\n\n        # upsample the attentions and multiply\n        sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode)\n        y = sigm_psi_f.expand_as(x) * x\n        W_y = self.W(y)\n\n        return W_y, sigm_psi_f\n\n\n    def _concatenation_residual(self, x, g):\n        input_size = x.size()\n        batch_size = input_size[0]\n        assert batch_size == g.size(0)\n\n        # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw)\n        # phi   => (b, g_d) -> (b, i_c)\n        theta_x = self.theta(x)\n        theta_x_size = theta_x.size()\n\n        # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w')\n        #  Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3)\n        phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode)\n        f = F.relu(theta_x + phi_g, inplace=True)\n\n        #  psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3)\n        f = self.psi(f).view(batch_size, 1, -1)\n        
sigm_psi_f = F.softmax(f, dim=2).view(batch_size, 1, *theta_x.size()[2:])\n\n        # upsample the attentions and multiply\n        sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode)\n        y = sigm_psi_f.expand_as(x) * x\n        W_y = self.W(y)\n\n        return W_y, sigm_psi_f\n\n\nclass GridAttentionBlock2D(_GridAttentionBlockND):\n    def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation',\n                 sub_sample_factor=(2, 2)):\n        super(GridAttentionBlock2D, self).__init__(in_channels,\n                                                   inter_channels=inter_channels,\n                                                   gating_channels=gating_channels,\n                                                   dimension=2, mode=mode,\n                                                   sub_sample_factor=sub_sample_factor,\n                                                   )\n\n\nclass GridAttentionBlock3D(_GridAttentionBlockND):\n    def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation',\n                 sub_sample_factor=(2,2,2)):\n        super(GridAttentionBlock3D, self).__init__(in_channels,\n                                                   inter_channels=inter_channels,\n                                                   gating_channels=gating_channels,\n                                                   dimension=3, mode=mode,\n                                                   sub_sample_factor=sub_sample_factor,\n                                                   )\n\nclass _GridAttentionBlockND_TORR(nn.Module):\n    def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation',\n                 sub_sample_factor=(1,1,1), bn_layer=True, use_W=True, use_phi=True, use_theta=True, use_psi=True, nonlinearity1='relu'):\n        super(_GridAttentionBlockND_TORR, self).__init__()\n\n        assert 
dimension in [2, 3]\n        assert mode in ['concatenation', 'concatenation_softmax',\n                        'concatenation_sigmoid', 'concatenation_mean',\n                        'concatenation_range_normalise', 'concatenation_mean_flow']\n\n        # Default parameter set\n        self.mode = mode\n        self.dimension = dimension\n        self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, tuple) else tuple([sub_sample_factor])*dimension\n        self.sub_sample_kernel_size = self.sub_sample_factor\n\n        # Number of channels (pixel dimensions)\n        self.in_channels = in_channels\n        self.gating_channels = gating_channels\n        self.inter_channels = inter_channels\n\n        if self.inter_channels is None:\n            self.inter_channels = in_channels // 2\n            if self.inter_channels == 0:\n                self.inter_channels = 1\n\n        if dimension == 3:\n            conv_nd = nn.Conv3d\n            bn = nn.BatchNorm3d\n            self.upsample_mode = 'trilinear'\n        elif dimension == 2:\n            conv_nd = nn.Conv2d\n            bn = nn.BatchNorm2d\n            self.upsample_mode = 'bilinear'\n        else:\n            raise NotImplemented\n\n        # initialise id functions\n        # Theta^T * x_ij + Phi^T * gating_signal + bias\n        self.W = lambda x: x\n        self.theta = lambda x: x\n        self.psi = lambda x: x\n        self.phi = lambda x: x\n        self.nl1 = lambda x: x\n\n        if use_W:\n            if bn_layer:\n                self.W = nn.Sequential(\n                    conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0),\n                    bn(self.in_channels),\n                )\n            else:\n                self.W = conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0)\n\n        if use_theta:\n            self.theta = 
conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                                 kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False)\n\n\n        if use_phi:\n            self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels,\n                               kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False)\n\n\n        if use_psi:\n            self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)\n\n\n        if nonlinearity1:\n            if nonlinearity1 == 'relu':\n                self.nl1 = lambda x: F.relu(x, inplace=True)\n\n        if 'concatenation' in mode:\n            self.operation_function = self._concatenation\n        else:\n            raise NotImplementedError('Unknown operation function.')\n\n        # Initialise weights\n        for m in self.children():\n            init_weights(m, init_type='kaiming')\n\n\n        if use_psi and self.mode == 'concatenation_sigmoid':\n            nn.init.constant(self.psi.bias.data, 3.0)\n\n        if use_psi and self.mode == 'concatenation_softmax':\n            nn.init.constant(self.psi.bias.data, 10.0)\n\n        # if use_psi and self.mode == 'concatenation_mean':\n        #     nn.init.constant(self.psi.bias.data, 3.0)\n\n        # if use_psi and self.mode == 'concatenation_range_normalise':\n        #     nn.init.constant(self.psi.bias.data, 3.0)\n\n        parallel = False\n        if parallel:\n            if use_W: self.W = nn.DataParallel(self.W)\n            if use_phi: self.phi = nn.DataParallel(self.phi)\n            if use_psi: self.psi = nn.DataParallel(self.psi)\n            if use_theta: self.theta = nn.DataParallel(self.theta)\n\n    def forward(self, x, g):\n        '''\n        :param x: (b, c, t, h, w)\n        :param g: (b, g_d)\n        :return:\n        '''\n\n        output = 
self.operation_function(x, g)\n        return output\n\n    def _concatenation(self, x, g):\n        input_size = x.size()\n        batch_size = input_size[0]\n        assert batch_size == g.size(0)\n\n        #############################\n        # compute compatibility score\n\n        # theta => (b, c, t, h, w) -> (b, i_c, t, h, w)\n        # phi   => (b, c, t, h, w) -> (b, i_c, t, h, w)\n        theta_x = self.theta(x)\n        theta_x_size = theta_x.size()\n\n        #  nl(theta.x + phi.g + bias) -> f = (b, i_c, t/s1, h/s2, w/s3)\n        phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode)\n\n        f = theta_x + phi_g\n        f = self.nl1(f)\n\n        psi_f = self.psi(f)\n\n        ############################################\n        # normalisation -- scale compatibility score\n        #  psi^T . f -> (b, 1, t/s1, h/s2, w/s3)\n        if self.mode == 'concatenation_softmax':\n            sigm_psi_f = F.softmax(psi_f.view(batch_size, 1, -1), dim=2)\n            sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:])\n        elif self.mode == 'concatenation_mean':\n            psi_f_flat = psi_f.view(batch_size, 1, -1)\n            psi_f_sum = torch.sum(psi_f_flat, dim=2)#clamp(1e-6)\n            psi_f_sum = psi_f_sum[:,:,None].expand_as(psi_f_flat)\n\n            sigm_psi_f = psi_f_flat / psi_f_sum\n            sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:])\n        elif self.mode == 'concatenation_mean_flow':\n            psi_f_flat = psi_f.view(batch_size, 1, -1)\n            ss = psi_f_flat.shape\n            psi_f_min = psi_f_flat.min(dim=2)[0].view(ss[0],ss[1],1)\n            psi_f_flat = psi_f_flat - psi_f_min\n            psi_f_sum = torch.sum(psi_f_flat, dim=2).view(ss[0],ss[1],1).expand_as(psi_f_flat)\n\n            sigm_psi_f = psi_f_flat / psi_f_sum\n            sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:])\n        elif self.mode == 'concatenation_range_normalise':\n         
   psi_f_flat = psi_f.view(batch_size, 1, -1)\n            ss = psi_f_flat.shape\n            psi_f_max = torch.max(psi_f_flat, dim=2)[0].view(ss[0], ss[1], 1)\n            psi_f_min = torch.min(psi_f_flat, dim=2)[0].view(ss[0], ss[1], 1)\n\n            sigm_psi_f = (psi_f_flat - psi_f_min) / (psi_f_max - psi_f_min).expand_as(psi_f_flat)\n            sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:])\n\n        elif self.mode == 'concatenation_sigmoid':\n            sigm_psi_f = F.sigmoid(psi_f)\n        else:\n            raise NotImplementedError\n\n        # sigm_psi_f is attention map! upsample the attentions and multiply\n        sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode)\n        y = sigm_psi_f.expand_as(x) * x\n        W_y = self.W(y)\n\n        return W_y, sigm_psi_f\n\n\nclass GridAttentionBlock2D_TORR(_GridAttentionBlockND_TORR):\n    def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation',\n                 sub_sample_factor=(1,1), bn_layer=True,\n                 use_W=True, use_phi=True, use_theta=True, use_psi=True,\n                 nonlinearity1='relu'):\n        super(GridAttentionBlock2D_TORR, self).__init__(in_channels,\n                                               inter_channels=inter_channels,\n                                               gating_channels=gating_channels,\n                                               dimension=2, mode=mode,\n                                               sub_sample_factor=sub_sample_factor,\n                                               bn_layer=bn_layer,\n                                               use_W=use_W,\n                                               use_phi=use_phi,\n                                               use_theta=use_theta,\n                                               use_psi=use_psi,\n                                               nonlinearity1=nonlinearity1)\n\n\nclass 
GridAttentionBlock3D_TORR(_GridAttentionBlockND_TORR):\n    def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation',\n                 sub_sample_factor=(1,1,1), bn_layer=True):\n        super(GridAttentionBlock3D_TORR, self).__init__(in_channels,\n                                                   inter_channels=inter_channels,\n                                                   gating_channels=gating_channels,\n                                                   dimension=3, mode=mode,\n                                                   sub_sample_factor=sub_sample_factor,\n                                                   bn_layer=bn_layer)\n\n\nclass MultiAttentionBlock(nn.Module):\n    def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor):\n        super(MultiAttentionBlock, self).__init__()\n        self.gate_block_1 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size,\n                                                 inter_channels=inter_size, mode=nonlocal_mode,\n                                                 sub_sample_factor=sub_sample_factor)\n        self.gate_block_2 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size,\n                                                 inter_channels=inter_size, mode=nonlocal_mode,\n                                                 sub_sample_factor=sub_sample_factor)\n        self.combine_gates = nn.Sequential(nn.Conv2d(in_size*2, in_size, kernel_size=1, stride=1, padding=0),\n                                           nn.BatchNorm2d(in_size),\n                                           nn.ReLU(inplace=True))\n\n        # initialise the blocks\n        for m in self.children():\n            if m.__class__.__name__.find('GridAttentionBlock2D') != -1: continue\n            init_weights(m, init_type='kaiming')\n\n    def forward(self, input, gating_signal):\n        gate_1, attention_1 = self.gate_block_1(input, 
gating_signal)\n        gate_2, attention_2 = self.gate_block_2(input, gating_signal)\n\n        return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1)\n\n\nif __name__ == '__main__':\n    from torch.autograd import Variable\n\n    mode_list = ['concatenation']\n\n    for mode in mode_list:\n\n        img = Variable(torch.rand(2, 16, 10, 10, 10))\n        gat = Variable(torch.rand(2, 64, 4, 4, 4))\n        net = GridAttentionBlock3D(in_channels=16, inter_channels=16, gating_channels=64, mode=mode, sub_sample_factor=(2,2,2))\n        out, sigma = net(img, gat)\n        print(out.size())\n"
  },
  {
    "path": "Models/layers/modules.py",
    "content": "import torch\nimport torch.nn as nn\n\n\ndef conv1x1(in_planes, out_planes, stride=1, bias=False):\n    \"1x1 convolution\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,\n                     padding=0, bias=bias)\n\n\ndef conv3x3(in_planes, out_planes, stride=1, bias=False, group=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=1, groups=group, bias=bias)\n\n\n# conv_block(nn.Module) for U-net convolution block\nclass conv_block(nn.Module):\n    def __init__(self, ch_in, ch_out, drop_out=False):\n        super(conv_block, self).__init__()\n        self.conv = nn.Sequential(\n            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),\n            nn.BatchNorm2d(ch_out),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),\n            nn.BatchNorm2d(ch_out),\n            nn.ReLU(inplace=True),\n        )\n        self.dropout = drop_out\n\n    def forward(self, x):\n        x = self.conv(x)\n        if self.dropout:\n            x = nn.Dropout2d(0.5)(x)\n        return x\n\n\n# # UpCat(nn.Module) for U-net UP convolution\nclass UpCat(nn.Module):\n    def __init__(self, in_feat, out_feat, is_deconv=True):\n        super(UpCat, self).__init__()\n\n        if is_deconv:\n            self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2)\n        else:\n            self.up = nn.Upsample(scale_factor=2, mode='bilinear')\n\n    def forward(self, inputs, down_outputs):\n        # TODO: Upsampling required after deconv?\n        outputs = self.up(down_outputs)\n        offset = inputs.size()[3] - outputs.size()[3]\n        if offset == 1:\n            addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze(\n                3).cuda()\n            outputs = torch.cat([outputs, addition], 
dim=3)\n        elif offset > 1:\n            addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda()\n            outputs = torch.cat([outputs, addition], dim=3)\n        out = torch.cat([inputs, outputs], dim=1)\n\n        return out\n\n\n# # UpCatconv(nn.Module) for up convolution\nclass UpCatconv(nn.Module):\n    def __init__(self, in_feat, out_feat, is_deconv=True, drop_out=False):\n        super(UpCatconv, self).__init__()\n\n        if is_deconv:\n            self.conv = conv_block(in_feat, out_feat, drop_out=drop_out)\n            self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2)\n        else:\n            self.conv = conv_block(in_feat + out_feat, out_feat, drop_out=drop_out)\n            self.up = nn.Upsample(scale_factor=2, mode='bilinear')\n\n    def forward(self, inputs, down_outputs):\n        # TODO: Upsampling required after deconv\n        outputs = self.up(down_outputs)\n        offset = inputs.size()[3] - outputs.size()[3]\n        if offset == 1:\n            addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze(\n                3).cuda()\n            outputs = torch.cat([outputs, addition], dim=3)\n        elif offset > 1:\n            addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda()\n            outputs = torch.cat([outputs, addition], dim=3)\n        out = self.conv(torch.cat([inputs, outputs], dim=1))\n\n        return out\n\n\n# # UnetGridGatingSignal3(nn.Module)\nclass UnetGridGatingSignal3(nn.Module):\n    def __init__(self, in_size, out_size, kernel_size=(1, 1), is_batchnorm=True):\n        super(UnetGridGatingSignal3, self).__init__()\n\n        if is_batchnorm:\n            self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, kernel_size, (1, 1), (0, 0)),\n                                       nn.BatchNorm2d(out_size),\n                                     
  nn.ReLU(inplace=True),\n                                       )\n        else:\n            self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, kernel_size, (1, 1), (0, 0)),\n                                       nn.ReLU(inplace=True),\n                                       )\n\n    def forward(self, inputs):\n        outputs = self.conv1(inputs)\n        return outputs\n\n\nclass UnetDsv3(nn.Module):\n    def __init__(self, in_size, out_size, scale_factor):\n        super(UnetDsv3, self).__init__()\n        self.dsv = nn.Sequential(nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0),\n                                 nn.Upsample(size=scale_factor, mode='bilinear'), )\n\n    def forward(self, input):\n        return self.dsv(input)\n"
  },
  {
    "path": "Models/layers/nonlocal_layer.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom Models.networks_other import init_weights\n\n\nclass _NonLocalBlockND(nn.Module):\n    def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian',\n                 sub_sample_factor=4, bn_layer=True):\n        super(_NonLocalBlockND, self).__init__()\n\n        assert dimension in [1, 2, 3]\n        assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']\n\n        # print('Dimension: %d, mode: %s' % (dimension, mode))\n\n        self.mode = mode\n        self.dimension = dimension\n        self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, list) else [sub_sample_factor]\n\n        self.in_channels = in_channels\n        self.inter_channels = inter_channels\n\n        if self.inter_channels is None:\n            self.inter_channels = in_channels // 2\n            if self.inter_channels == 0:\n                self.inter_channels = 1\n\n        if dimension == 3:\n            conv_nd = nn.Conv3d\n            max_pool = nn.MaxPool3d\n            bn = nn.BatchNorm3d\n        elif dimension == 2:\n            conv_nd = nn.Conv2d\n            max_pool = nn.MaxPool2d\n            bn = nn.BatchNorm2d\n        else:\n            conv_nd = nn.Conv1d\n            max_pool = nn.MaxPool1d\n            bn = nn.BatchNorm1d\n\n        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                         kernel_size=1, stride=1, padding=0)\n\n        if bn_layer:\n            self.W = nn.Sequential(\n                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                        kernel_size=1, stride=1, padding=0),\n                bn(self.in_channels)\n            )\n            nn.init.constant(self.W[1].weight, 0)\n            nn.init.constant(self.W[1].bias, 0)\n        else:\n            self.W 
= conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,\n                             kernel_size=1, stride=1, padding=0)\n            nn.init.constant(self.W.weight, 0)\n            nn.init.constant(self.W.bias, 0)\n\n        self.theta = None\n        self.phi = None\n\n        if mode in ['embedded_gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']:\n            self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                                 kernel_size=1, stride=1, padding=0)\n            self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,\n                               kernel_size=1, stride=1, padding=0)\n\n            if mode in ['concatenation']:\n                self.wf_phi = nn.Linear(self.inter_channels, 1, bias=False)\n                self.wf_theta = nn.Linear(self.inter_channels, 1, bias=False)\n            elif mode in ['concat_proper', 'concat_proper_down']:\n                self.psi = nn.Conv2d(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1,\n                                     padding=0, bias=True)\n\n        if mode == 'embedded_gaussian':\n            self.operation_function = self._embedded_gaussian\n        elif mode == 'dot_product':\n            self.operation_function = self._dot_product\n        elif mode == 'gaussian':\n            self.operation_function = self._gaussian\n        elif mode == 'concatenation':\n            self.operation_function = self._concatenation\n        elif mode == 'concat_proper':\n            self.operation_function = self._concatenation_proper\n        elif mode == 'concat_proper_down':\n            self.operation_function = self._concatenation_proper_down\n        else:\n            raise NotImplementedError('Unknown operation function.')\n\n        if any(ss > 1 for ss in self.sub_sample_factor):\n            self.g = nn.Sequential(self.g, 
max_pool(kernel_size=sub_sample_factor))\n            if self.phi is None:\n                self.phi = max_pool(kernel_size=sub_sample_factor)\n            else:\n                self.phi = nn.Sequential(self.phi, max_pool(kernel_size=sub_sample_factor))\n            if mode == 'concat_proper_down':\n                self.theta = nn.Sequential(self.theta, max_pool(kernel_size=sub_sample_factor))\n\n        # Initialise weights\n        for m in self.children():\n            init_weights(m, init_type='kaiming')\n\n    def forward(self, x):\n        '''\n        :param x: (b, c, t, h, w)\n        :return:\n        '''\n\n        output = self.operation_function(x)\n        return output\n\n    def _embedded_gaussian(self, x):\n        batch_size = x.size(0)\n\n        # g=>(b, c, t, h, w)->(b, 0.5c, t, h, w)->(b, thw, 0.5c)\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n        g_x = g_x.permute(0, 2, 1)\n\n        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c)\n        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)\n        # f=>(b, thw, 0.5c)dot(b, 0.5c, twh) = (b, thw, thw)\n        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)\n        theta_x = theta_x.permute(0, 2, 1)\n        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)\n        f = torch.matmul(theta_x, phi_x)\n        f_div_C = F.softmax(f, dim=-1)\n\n        # (b, thw, thw)dot(b, thw, 0.5c) = (b, thw, 0.5c)->(b, 0.5c, t, h, w)->(b, c, t, h, w)\n        y = torch.matmul(f_div_C, g_x)\n        y = y.permute(0, 2, 1).contiguous()\n        y = y.view(batch_size, self.inter_channels, *x.size()[2:])\n        W_y = self.W(y)\n        z = W_y + x\n\n        return z\n\n    def _gaussian(self, x):\n        batch_size = x.size(0)\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n        g_x = g_x.permute(0, 2, 1)\n\n        theta_x = x.view(batch_size, self.in_channels, -1)\n        theta_x = theta_x.permute(0, 
2, 1)\n\n        if self.sub_sample_factor > 1:\n            phi_x = self.phi(x).view(batch_size, self.in_channels, -1)\n        else:\n            phi_x = x.view(batch_size, self.in_channels, -1)\n\n        f = torch.matmul(theta_x, phi_x)\n        f_div_C = F.softmax(f, dim=-1)\n\n        y = torch.matmul(f_div_C, g_x)\n        y = y.permute(0, 2, 1).contiguous()\n        y = y.view(batch_size, self.inter_channels, *x.size()[2:])\n        W_y = self.W(y)\n        z = W_y + x\n\n        return z\n\n    def _dot_product(self, x):\n        batch_size = x.size(0)\n\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n        g_x = g_x.permute(0, 2, 1)\n\n        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)\n        theta_x = theta_x.permute(0, 2, 1)\n        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)\n        f = torch.matmul(theta_x, phi_x)\n        N = f.size(-1)\n        f_div_C = f / N\n\n        y = torch.matmul(f_div_C, g_x)\n        y = y.permute(0, 2, 1).contiguous()\n        y = y.view(batch_size, self.inter_channels, *x.size()[2:])\n        W_y = self.W(y)\n        z = W_y + x\n\n        return z\n\n    def _concatenation(self, x):\n        batch_size = x.size(0)\n\n        # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n\n        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c)\n        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw/s**2, 0.5c)\n        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1)\n        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1)\n\n        # theta => (b, thw, 0.5c) -> (b, thw, 1) -> (b, 1, thw) -> (expand) (b, thw/s**2, thw)\n        # phi => (b, thw/s**2, 0.5c) -> (b, thw/s**2, 1) -> (expand) (b, thw/s**2, thw)\n        # f=> RELU[(b, thw/s**2, thw) + (b, thw/s**2, thw)] = (b, thw/s**2, thw)\n        f = 
self.wf_theta(theta_x).permute(0, 2, 1).repeat(1, phi_x.size(1), 1) + \\\n            self.wf_phi(phi_x).repeat(1, 1, theta_x.size(1))\n        f = F.relu(f, inplace=True)\n\n        # Normalise the relations\n        N = f.size(-1)\n        f_div_c = f / N\n\n        # g(x_j) * f(x_j, x_i)\n        # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)\n        y = torch.matmul(g_x, f_div_c)\n        y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:])\n        W_y = self.W(y)\n        z = W_y + x\n\n        return z\n\n    def _concatenation_proper(self, x):\n        batch_size = x.size(0)\n\n        # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n\n        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)\n        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2)\n        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)\n        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)\n\n        # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw)\n        # phi => (b, 0.5c, thw/s**2) ->  (expand) (b, 0.5c, thw/s**2, thw)\n        # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw)\n        f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \\\n            phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2))\n        f = F.relu(f, inplace=True)\n\n        # psi -> W_psi^t * f -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw)\n        f = torch.squeeze(self.psi(f), dim=1)\n\n        # Normalise the relations\n        f_div_c = F.softmax(f, dim=1)\n\n        # g(x_j) * f(x_j, x_i)\n        # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)\n        y = torch.matmul(g_x, f_div_c)\n        y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:])\n        W_y = self.W(y)\n        z = W_y + x\n\n        return z\n\n    def 
_concatenation_proper_down(self, x):\n        batch_size = x.size(0)\n\n        # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n\n        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)\n        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2)\n        theta_x = self.theta(x)\n        downsampled_size = theta_x.size()\n        theta_x = theta_x.view(batch_size, self.inter_channels, -1)\n        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)\n\n        # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw)\n        # phi => (b, 0.5, thw/s**2) ->  (expand) (b, 0.5c, thw/s**2, thw)\n        # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw)\n        f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \\\n            phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2))\n        f = F.relu(f, inplace=True)\n\n        # psi -> W_psi^t * f -> (b, 0.5c, thw/s**2, thw) -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw)\n        f = torch.squeeze(self.psi(f), dim=1)\n\n        # Normalise the relations\n        f_div_c = F.softmax(f, dim=1)\n\n        # g(x_j) * f(x_j, x_i)\n        # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)\n        y = torch.matmul(g_x, f_div_c)\n        y = y.contiguous().view(batch_size, self.inter_channels, *downsampled_size[2:])\n\n        # upsample the final featuremaps # (b,0.5c,t/s1,h/s2,w/s3)\n        y = F.upsample(y, size=x.size()[2:], mode='trilinear')\n\n        # attention block output\n        W_y = self.W(y)\n        z = W_y + x\n\n        return z\n\n\nclass NONLocalBlock1D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True):\n        super(NONLocalBlock1D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                
                              dimension=1, mode=mode,\n                                              sub_sample_factor=sub_sample_factor,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock2D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True):\n        super(NONLocalBlock2D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=2, mode=mode,\n                                              sub_sample_factor=sub_sample_factor,\n                                              bn_layer=bn_layer)\n\n\nclass NONLocalBlock3D(_NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True):\n        super(NONLocalBlock3D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=3, mode=mode,\n                                              sub_sample_factor=sub_sample_factor,\n                                              bn_layer=bn_layer)\n\n\nif __name__ == '__main__':\n    from torch.autograd import Variable\n\n    mode_list = ['concatenation']\n    #mode_list = ['embedded_gaussian', 'gaussian', 'dot_product', ]\n\n    for mode in mode_list:\n        print(mode)\n        img = Variable(torch.zeros(2, 4, 5))\n        net = NONLocalBlock1D(4, mode=mode, sub_sample_factor=2)\n        out = net(img)\n        print(out.size())\n\n        img = Variable(torch.zeros(2, 4, 5, 3))\n        net = NONLocalBlock2D(4, mode=mode, sub_sample_factor=1, bn_layer=False)\n        out = net(img)\n        print(out.size())\n\n        img = Variable(torch.zeros(2, 4, 5, 4, 5))\n        net = NONLocalBlock3D(4, mode=mode)\n        out = net(img)\n        print(out.size())\n"
  },
  {
    "path": "Models/layers/scale_attention_layer.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\n\ndef conv1x1(in_planes, out_planes, stride=1, bias=False):\n    \"1x1 convolution\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,\n                     padding=0, bias=bias)\n\n\n# # SE block add to U-net\ndef conv3x3(in_planes, out_planes, stride=1, bias=False, group=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=1, groups=group, bias=bias)\n\n\n# # CBAM Convolutional block attention module\nclass BasicConv(nn.Module):\n    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1,\n                 relu=True, bn=True, bias=False):\n        super(BasicConv, self).__init__()\n        self.out_channels = out_planes\n        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,\n                              dilation=dilation, groups=groups, bias=bias)\n        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None\n        self.relu = nn.ReLU() if relu else None\n\n    def forward(self, x):\n        x = self.conv(x)\n        if self.bn is not None:\n            x = self.bn(x)\n        if self.relu is not None:\n            x = self.relu(x)\n        return x\n\n\nclass Flatten(nn.Module):\n    def forward(self, x):\n        return x.view(x.size(0), -1)\n\n\nclass ChannelGate(nn.Module):\n    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):\n        super(ChannelGate, self).__init__()\n        self.gate_channels = gate_channels\n        self.mlp = nn.Sequential(\n            Flatten(),\n            nn.Linear(gate_channels, gate_channels // reduction_ratio),\n            nn.ReLU(),\n            nn.Linear(gate_channels // reduction_ratio, gate_channels)\n            )\n        self.pool_types = pool_types\n\n    
def forward(self, x):\n        channel_att_sum = None\n        for pool_type in self.pool_types:\n            if pool_type == 'avg':\n                avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n                channel_att_raw = self.mlp(avg_pool)\n            elif pool_type == 'max':\n                max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n                channel_att_raw = self.mlp(max_pool)\n            elif pool_type == 'lp':\n                lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n                channel_att_raw = self.mlp(lp_pool)\n            elif pool_type == 'lse':\n                # LSE pool only\n                lse_pool = logsumexp_2d(x)\n                channel_att_raw = self.mlp(lse_pool)\n\n            if channel_att_sum is None:\n                channel_att_sum = channel_att_raw\n            else:\n                channel_att_sum = channel_att_sum + channel_att_raw\n\n        # scalecoe = F.sigmoid(channel_att_sum)\n        channel_att_sum = channel_att_sum.reshape(channel_att_sum.shape[0], 4, 4)\n        avg_weight = torch.mean(channel_att_sum, dim=2).unsqueeze(2)\n        avg_weight = avg_weight.expand(channel_att_sum.shape[0], 4, 4).reshape(channel_att_sum.shape[0], 16)\n        scale = F.sigmoid(avg_weight).unsqueeze(2).unsqueeze(3).expand_as(x)\n\n        return x * scale, scale\n\n\ndef logsumexp_2d(tensor):\n    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)\n    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)\n    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()\n    return outputs\n\n\nclass ChannelPool(nn.Module):\n    def forward(self, x):\n        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n\n\nclass SpatialGate(nn.Module):\n    def __init__(self):\n        super(SpatialGate, self).__init__()\n        kernel_size = 7\n      
  self.compress = ChannelPool()\n        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)\n\n    def forward(self, x):\n        x_compress = self.compress(x)\n        x_out = self.spatial(x_compress)\n        scale = F.sigmoid(x_out)    # broadcasting\n        # spa_scale = scale.expand_as(x)\n        # print(spa_scale.shape)\n        return x * scale, scale\n\nclass SpatialAtten(nn.Module):\n    def __init__(self, in_size, out_size, kernel_size=3, stride=1):\n        super(SpatialAtten, self).__init__()\n        self.conv1 = BasicConv(in_size, out_size, kernel_size, stride=stride,\n                               padding=(kernel_size-1) // 2, relu=True)\n        self.conv2 = BasicConv(out_size, out_size, kernel_size=1, stride=stride,\n                               padding=0, relu=True, bn=False)\n\n    def forward(self, x):\n        residual = x\n        x_out = self.conv1(x)\n        x_out = self.conv2(x_out)\n        spatial_att = F.sigmoid(x_out).unsqueeze(4).permute(0, 1, 4, 2, 3)\n        spatial_att = spatial_att.expand(spatial_att.shape[0], 4, 4, spatial_att.shape[3], spatial_att.shape[4]).reshape(\n                                        spatial_att.shape[0], 16, spatial_att.shape[3], spatial_att.shape[4])\n        x_out = residual * spatial_att\n\n        x_out += residual\n\n        return x_out, spatial_att\n\nclass Scale_atten_block(nn.Module):\n    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):\n        super(Scale_atten_block, self).__init__()\n        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)\n        self.no_spatial = no_spatial\n        if not no_spatial:\n            self.SpatialGate = SpatialAtten(gate_channels, gate_channels //reduction_ratio)\n\n    def forward(self, x):\n        x_out, ca_atten = self.ChannelGate(x)\n        if not self.no_spatial:\n            x_out, sa_atten = self.SpatialGate(x_out)\n\n       
 return x_out, ca_atten, sa_atten\n\n\nclass scale_atten_convblock(nn.Module):\n    def __init__(self, in_size, out_size, stride=1, downsample=None, use_cbam=True, no_spatial=False, drop_out=False):\n        super(scale_atten_convblock, self).__init__()\n        # if stride != 1 or in_size != out_size:\n        #     downsample = nn.Sequential(\n        #         nn.Conv2d(in_size, out_size,\n        #                   kernel_size=1, stride=stride, bias=False),\n        #         nn.BatchNorm2d(out_size),\n        #     )\n        self.downsample = downsample\n        self.stride = stride\n        self.no_spatial = no_spatial\n        self.dropout = drop_out\n\n        self.relu = nn.ReLU(inplace=True)\n        self.conv3 = conv3x3(in_size, out_size)\n        self.bn3 = nn.BatchNorm2d(out_size)\n\n        if use_cbam:\n            self.cbam = Scale_atten_block(in_size, reduction_ratio=4, no_spatial=self.no_spatial)  # out_size\n        else:\n            self.cbam = None\n\n    def forward(self, x):\n        residual = x\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        if not self.cbam is None:\n            out, scale_c_atten, scale_s_atten = self.cbam(x)\n\n            # scale_c_atten = nn.Sigmoid()(scale_c_atten)\n            # scale_s_atten = nn.Sigmoid()(scale_s_atten)\n            # scale_atten = channel_atten_c * spatial_atten_s\n\n        # scale_max = torch.argmax(scale_atten, dim=1, keepdim=True)\n        # scale_max_soft = get_soft_label(input_tensor=scale_max, num_class=8)\n        # scale_max_soft = scale_max_soft.permute(0, 3, 1, 2)\n        # scale_atten_soft = scale_atten * scale_max_soft\n\n        out += residual\n        out = self.relu(out)\n        out = self.conv3(out)\n        out = self.bn3(out)\n        out = self.relu(out)\n\n        if self.dropout:\n            out = nn.Dropout2d(0.5)(out)\n\n        return out"
  },
  {
    "path": "Models/networks/network.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom Models.layers.modules import conv_block, UpCat, UpCatconv, UnetDsv3, UnetGridGatingSignal3\nfrom Models.layers.grid_attention_layer import GridAttentionBlock2D, MultiAttentionBlock\nfrom Models.layers.channel_attention_layer import SE_Conv_Block\nfrom Models.layers.scale_attention_layer import scale_atten_convblock\nfrom Models.layers.nonlocal_layer import NONLocalBlock2D\n\n\nclass Comprehensive_Atten_Unet(nn.Module):\n    def __init__(self, args, in_ch=3, n_classes=2, feature_scale=4, is_deconv=True, is_batchnorm=True,\n                 nonlocal_mode='concatenation', attention_dsample=(1, 1)):\n        super(Comprehensive_Atten_Unet, self).__init__()\n        self.args = args\n        self.is_deconv = is_deconv\n        self.in_channels = in_ch\n        self.num_classes = n_classes\n        self.is_batchnorm = is_batchnorm\n        self.feature_scale = feature_scale\n        self.out_size = args.out_size\n\n        filters = [64, 128, 256, 512, 1024]\n        filters = [int(x / self.feature_scale) for x in filters]\n\n        # downsampling\n        self.conv1 = conv_block(self.in_channels, filters[0])\n        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2))\n\n        self.conv2 = conv_block(filters[0], filters[1])\n        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2))\n\n        self.conv3 = conv_block(filters[1], filters[2])\n        self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 2))\n\n        self.conv4 = conv_block(filters[2], filters[3], drop_out=True)\n        self.maxpool4 = nn.MaxPool2d(kernel_size=(2, 2))\n\n        self.center = conv_block(filters[3], filters[4], drop_out=True)\n\n        # attention blocks\n        # self.attentionblock1 = GridAttentionBlock2D(in_channels=filters[0], gating_channels=filters[1],\n        #                                             inter_channels=filters[0])\n        self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], 
inter_size=filters[1],\n                                                   nonlocal_mode=nonlocal_mode, sub_sample_factor=attention_dsample)\n        self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2],\n                                                   nonlocal_mode=nonlocal_mode, sub_sample_factor=attention_dsample)\n        self.nonlocal4_2 = NONLocalBlock2D(in_channels=filters[4], inter_channels=filters[4] // 4)\n\n        # upsampling\n        self.up_concat4 = UpCat(filters[4], filters[3], self.is_deconv)\n        self.up_concat3 = UpCat(filters[3], filters[2], self.is_deconv)\n        self.up_concat2 = UpCat(filters[2], filters[1], self.is_deconv)\n        self.up_concat1 = UpCat(filters[1], filters[0], self.is_deconv)\n        self.up4 = SE_Conv_Block(filters[4], filters[3], drop_out=True)\n        self.up3 = SE_Conv_Block(filters[3], filters[2])\n        self.up2 = SE_Conv_Block(filters[2], filters[1])\n        self.up1 = SE_Conv_Block(filters[1], filters[0])\n\n        # deep supervision\n        self.dsv4 = UnetDsv3(in_size=filters[3], out_size=4, scale_factor=self.out_size)\n        self.dsv3 = UnetDsv3(in_size=filters[2], out_size=4, scale_factor=self.out_size)\n        self.dsv2 = UnetDsv3(in_size=filters[1], out_size=4, scale_factor=self.out_size)\n        self.dsv1 = nn.Conv2d(in_channels=filters[0], out_channels=4, kernel_size=1)\n\n        self.scale_att = scale_atten_convblock(in_size=16, out_size=4)\n        # final conv (without any concat)\n        self.final = nn.Sequential(nn.Conv2d(4, n_classes, kernel_size=1), nn.Softmax2d())\n\n    def forward(self, inputs):\n        # Feature Extraction\n        conv1 = self.conv1(inputs)\n        maxpool1 = self.maxpool1(conv1)\n\n        conv2 = self.conv2(maxpool1)\n        maxpool2 = self.maxpool2(conv2)\n\n        conv3 = self.conv3(maxpool2)\n        maxpool3 = self.maxpool3(conv3)\n\n        conv4 = self.conv4(maxpool3)\n        maxpool4 = 
self.maxpool4(conv4)\n\n        # Gating Signal Generation\n        center = self.center(maxpool4)\n\n        # Attention Mechanism\n        # Upscaling Part (Decoder)\n        up4 = self.up_concat4(conv4, center)\n        g_conv4 = self.nonlocal4_2(up4)\n\n        up4, att_weight4 = self.up4(g_conv4)\n        g_conv3, att3 = self.attentionblock3(conv3, up4)\n\n        # atten3_map = att3.cpu().detach().numpy().astype(np.float)\n        # atten3_map = ndimage.interpolation.zoom(atten3_map, [1.0, 1.0, 224 / atten3_map.shape[2],\n        #                                                      300 / atten3_map.shape[3]], order=0)\n\n        up3 = self.up_concat3(g_conv3, up4)\n        up3, att_weight3 = self.up3(up3)\n        g_conv2, att2 = self.attentionblock2(conv2, up3)\n\n        # atten2_map = att2.cpu().detach().numpy().astype(np.float)\n        # atten2_map = ndimage.interpolation.zoom(atten2_map, [1.0, 1.0, 224 / atten2_map.shape[2],\n        #                                                      300 / atten2_map.shape[3]], order=0)\n\n        up2 = self.up_concat2(g_conv2, up3)\n        up2, att_weight2 = self.up2(up2)\n        # g_conv1, att1 = self.attentionblock1(conv1, up2)\n\n        # atten1_map = att1.cpu().detach().numpy().astype(np.float)\n        # atten1_map = ndimage.interpolation.zoom(atten1_map, [1.0, 1.0, 224 / atten1_map.shape[2],\n        #                                                      300 / atten1_map.shape[3]], order=0)\n        up1 = self.up_concat1(conv1, up2)\n        up1, att_weight1 = self.up1(up1)\n\n        # Deep Supervision\n        dsv4 = self.dsv4(up4)\n        dsv3 = self.dsv3(up3)\n        dsv2 = self.dsv2(up2)\n        dsv1 = self.dsv1(up1)\n        dsv_cat = torch.cat([dsv1, dsv2, dsv3, dsv4], dim=1)\n        out = self.scale_att(dsv_cat)\n\n        out = self.final(out)\n\n        return out\n"
  },
  {
    "path": "Models/networks_other.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\nimport functools\nfrom torch.autograd import Variable\nfrom torch.optim import lr_scheduler\nimport time\nimport numpy as np\n###############################################################################\n# Functions\n###############################################################################\n\n\ndef weights_init_normal(m):\n    classname = m.__class__.__name__\n    #print(classname)\n    if classname.find('Conv') != -1:\n        init.normal(m.weight.data, 0.0, 0.02)\n    elif classname.find('Linear') != -1:\n        init.normal(m.weight.data, 0.0, 0.02)\n    elif classname.find('BatchNorm') != -1:\n        init.normal(m.weight.data, 1.0, 0.02)\n        init.constant(m.bias.data, 0.0)\n\n\ndef weights_init_xavier(m):\n    classname = m.__class__.__name__\n    #print(classname)\n    if classname.find('Conv') != -1:\n        init.xavier_normal(m.weight.data, gain=1)\n    elif classname.find('Linear') != -1:\n        init.xavier_normal(m.weight.data, gain=1)\n    elif classname.find('BatchNorm') != -1:\n        init.normal(m.weight.data, 1.0, 0.02)\n        init.constant(m.bias.data, 0.0)\n\n\ndef weights_init_kaiming(m):\n    classname = m.__class__.__name__\n    #print(classname)\n    if classname.find('Conv') != -1:\n        init.kaiming_normal(m.weight.data, a=0, mode='fan_in')\n    elif classname.find('Linear') != -1:\n        init.kaiming_normal(m.weight.data, a=0, mode='fan_in')\n    elif classname.find('BatchNorm') != -1:\n        init.normal(m.weight.data, 1.0, 0.02)\n        init.constant(m.bias.data, 0.0)\n\n\ndef weights_init_orthogonal(m):\n    classname = m.__class__.__name__\n    #print(classname)\n    if classname.find('Conv') != -1:\n        init.orthogonal(m.weight.data, gain=1)\n    elif classname.find('Linear') != -1:\n        init.orthogonal(m.weight.data, gain=1)\n    elif classname.find('BatchNorm') != -1:\n        init.normal(m.weight.data, 1.0, 0.02)\n       
 init.constant(m.bias.data, 0.0)\n\n\ndef init_weights(net, init_type='normal'):\n    #print('initialization method [%s]' % init_type)\n    if init_type == 'normal':\n        net.apply(weights_init_normal)\n    elif init_type == 'xavier':\n        net.apply(weights_init_xavier)\n    elif init_type == 'kaiming':\n        net.apply(weights_init_kaiming)\n    elif init_type == 'orthogonal':\n        net.apply(weights_init_orthogonal)\n    else:\n        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n\n\ndef get_norm_layer(norm_type='instance'):\n    if norm_type == 'batch':\n        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)\n    elif norm_type == 'instance':\n        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)\n    elif norm_type == 'none':\n        norm_layer = None\n    else:\n        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)\n    return norm_layer\n\n\ndef adjust_learning_rate(optimizer, lr):\n    \"\"\"Sets the learning rate to a fixed number\"\"\"\n    for param_group in optimizer.param_groups:\n        param_group['lr'] = lr\n\ndef get_scheduler(optimizer, opt):\n    print('opt.lr_policy = [{}]'.format(opt.lr_policy))\n    if opt.lr_policy == 'lambda':\n        def lambda_rule(epoch):\n            lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)\n            return lr_l\n        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)\n    elif opt.lr_policy == 'step':\n        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.5)\n    elif opt.lr_policy == 'step2':\n        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)\n    elif opt.lr_policy == 'plateau':\n        print('schedular=plateau')\n        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, threshold=0.01, patience=5)\n    elif opt.lr_policy 
== 'plateau2':\n        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)\n    elif opt.lr_policy == 'step_warmstart':\n        def lambda_rule(epoch):\n            #print(epoch)\n            if epoch < 5:\n                lr_l = 0.1\n            elif 5 <= epoch < 100:\n                lr_l = 1\n            elif 100 <= epoch < 200:\n                lr_l = 0.1\n            elif 200 <= epoch:\n                lr_l = 0.01\n            return lr_l\n        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)\n    elif opt.lr_policy == 'step_warmstart2':\n        def lambda_rule(epoch):\n            #print(epoch)\n            if epoch < 5:\n                lr_l = 0.1\n            elif 5 <= epoch < 50:\n                lr_l = 1\n            elif 50 <= epoch < 100:\n                lr_l = 0.1\n            elif 100 <= epoch:\n                lr_l = 0.01\n            return lr_l\n        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)\n    else:\n\n        return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)\n    return scheduler\n\n\ndef define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', gpu_ids=[]):\n    netG = None\n    use_gpu = len(gpu_ids) > 0\n    norm_layer = get_norm_layer(norm_type=norm)\n\n    if use_gpu:\n        assert(torch.cuda.is_available())\n\n    if which_model_netG == 'resnet_9blocks':\n        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)\n    elif which_model_netG == 'resnet_6blocks':\n        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)\n    elif which_model_netG == 'unet_128':\n        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)\n    elif 
which_model_netG == 'unet_256':\n        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)\n    else:\n        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)\n    if len(gpu_ids) > 0:\n        netG.cuda(gpu_ids[0])\n    init_weights(netG, init_type=init_type)\n    return netG\n\n\ndef define_D(input_nc, ndf, which_model_netD,\n             n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[]):\n    netD = None\n    use_gpu = len(gpu_ids) > 0\n    norm_layer = get_norm_layer(norm_type=norm)\n\n    if use_gpu:\n        assert(torch.cuda.is_available())\n    if which_model_netD == 'basic':\n        netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)\n    elif which_model_netD == 'n_layers':\n        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)\n    else:\n        raise NotImplementedError('Discriminator model name [%s] is not recognized' %\n                                  which_model_netD)\n    if use_gpu:\n        netD.cuda(gpu_ids[0])\n    init_weights(netD, init_type=init_type)\n    return netD\n\n\ndef print_network(net):\n    num_params = 0\n    for param in net.parameters():\n        num_params += param.numel()\n    print(net)\n    print('Total number of parameters: %d' % num_params)\n\n\ndef get_n_parameters(net):\n    num_params = 0\n    for param in net.parameters():\n        num_params += param.numel()\n    return num_params\n\n\ndef measure_fp_bp_time(model, x, y):\n    # synchronize gpu time and measure fp\n    torch.cuda.synchronize()\n    t0 = time.time()\n    y_pred = model(x)\n    torch.cuda.synchronize()\n    elapsed_fp = time.time() - t0\n\n    if isinstance(y_pred, tuple):\n        y_pred = sum(y_p.sum() for y_p in y_pred)\n    else:\n        y_pred = y_pred.sum()\n\n    # 
zero gradients, synchronize time and measure\n    model.zero_grad()\n    t0 = time.time()\n    #y_pred.backward(y)\n    y_pred.backward()\n    torch.cuda.synchronize()\n    elapsed_bp = time.time() - t0\n    return elapsed_fp, elapsed_bp\n\n\ndef benchmark_fp_bp_time(model, x, y, n_trial=1000):\n    # transfer the model on GPU\n    model.cuda()\n\n    # DRY RUNS\n    for i in range(10):\n        _, _ = measure_fp_bp_time(model, x, y)\n\n    print('DONE WITH DRY RUNS, NOW BENCHMARKING')\n    \n    # START BENCHMARKING\n    t_forward = []\n    t_backward = []\n    \n    print('trial: {}'.format(n_trial))\n    for i in range(n_trial):\n        t_fp, t_bp = measure_fp_bp_time(model, x, y)\n        t_forward.append(t_fp)\n        t_backward.append(t_bp)\n\n    # free memory\n    del model\n\n    return np.mean(t_forward), np.mean(t_backward)\n\n##############################################################################\n# Classes\n##############################################################################\n\n\n# Defines the GAN loss which uses either LSGAN or the regular GAN.\n# When LSGAN is used, it is basically same as MSELoss,\n# but it abstracts away the need to create the target label tensor\n# that has the same size as the input\nclass GANLoss(nn.Module):\n    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,\n                 tensor=torch.FloatTensor):\n        super(GANLoss, self).__init__()\n        self.real_label = target_real_label\n        self.fake_label = target_fake_label\n        self.real_label_var = None\n        self.fake_label_var = None\n        self.Tensor = tensor\n        if use_lsgan:\n            self.loss = nn.MSELoss()\n        else:\n            self.loss = nn.BCELoss()\n\n    def get_target_tensor(self, input, target_is_real):\n        target_tensor = None\n        if target_is_real:\n            create_label = ((self.real_label_var is None) or\n                            (self.real_label_var.numel() 
!= input.numel()))\n            if create_label:\n                real_tensor = self.Tensor(input.size()).fill_(self.real_label)\n                self.real_label_var = Variable(real_tensor, requires_grad=False)\n            target_tensor = self.real_label_var\n        else:\n            create_label = ((self.fake_label_var is None) or\n                            (self.fake_label_var.numel() != input.numel()))\n            if create_label:\n                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)\n                self.fake_label_var = Variable(fake_tensor, requires_grad=False)\n            target_tensor = self.fake_label_var\n        return target_tensor\n\n    def __call__(self, input, target_is_real):\n        target_tensor = self.get_target_tensor(input, target_is_real)\n        return self.loss(input, target_tensor)\n\n\n# Defines the generator that consists of Resnet blocks between a few\n# downsampling/upsampling operations.\n# Code and idea originally from Justin Johnson's architecture.\n# https://github.com/jcjohnson/fast-neural-style/\nclass ResnetGenerator(nn.Module):\n    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[], padding_type='reflect'):\n        assert(n_blocks >= 0)\n        super(ResnetGenerator, self).__init__()\n        self.input_nc = input_nc\n        self.output_nc = output_nc\n        self.ngf = ngf\n        self.gpu_ids = gpu_ids\n        if type(norm_layer) == functools.partial:\n            use_bias = norm_layer.func == nn.InstanceNorm2d\n        else:\n            use_bias = norm_layer == nn.InstanceNorm2d\n\n        model = [nn.ReflectionPad2d(3),\n                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,\n                           bias=use_bias),\n                 norm_layer(ngf),\n                 nn.ReLU(True)]\n\n        n_downsampling = 2\n        for i in range(n_downsampling):\n            mult = 2**i\n            model += 
[nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,\n                                stride=2, padding=1, bias=use_bias),\n                      norm_layer(ngf * mult * 2),\n                      nn.ReLU(True)]\n\n        mult = 2**n_downsampling\n        for i in range(n_blocks):\n            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]\n\n        for i in range(n_downsampling):\n            mult = 2**(n_downsampling - i)\n            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),\n                                         kernel_size=3, stride=2,\n                                         padding=1, output_padding=1,\n                                         bias=use_bias),\n                      norm_layer(int(ngf * mult / 2)),\n                      nn.ReLU(True)]\n        model += [nn.ReflectionPad2d(3)]\n        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]\n        model += [nn.Tanh()]\n\n        self.model = nn.Sequential(*model)\n\n    def forward(self, input):\n        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):\n            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)\n        else:\n            return self.model(input)\n\n\n# Define a resnet block\nclass ResnetBlock(nn.Module):\n    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):\n        super(ResnetBlock, self).__init__()\n        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)\n\n    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):\n        conv_block = []\n        p = 0\n        if padding_type == 'reflect':\n            conv_block += [nn.ReflectionPad2d(1)]\n        elif padding_type == 'replicate':\n            conv_block += [nn.ReplicationPad2d(1)]\n        elif padding_type == 'zero':\n            p = 1\n        else:\n            
raise NotImplementedError('padding [%s] is not implemented' % padding_type)\n\n        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),\n                       norm_layer(dim),\n                       nn.ReLU(True)]\n        if use_dropout:\n            conv_block += [nn.Dropout(0.5)]\n\n        p = 0\n        if padding_type == 'reflect':\n            conv_block += [nn.ReflectionPad2d(1)]\n        elif padding_type == 'replicate':\n            conv_block += [nn.ReplicationPad2d(1)]\n        elif padding_type == 'zero':\n            p = 1\n        else:\n            raise NotImplementedError('padding [%s] is not implemented' % padding_type)\n        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),\n                       norm_layer(dim)]\n\n        return nn.Sequential(*conv_block)\n\n    def forward(self, x):\n        out = x + self.conv_block(x)\n        return out\n\n\n# Defines the Unet generator.\n# |num_downs|: number of downsamplings in UNet. 
For example,\n# if |num_downs| == 7, image of size 128x128 will become of size 1x1\n# at the bottleneck\nclass UnetGenerator(nn.Module):\n    def __init__(self, input_nc, output_nc, num_downs, ngf=64,\n                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[]):\n        super(UnetGenerator, self).__init__()\n        self.gpu_ids = gpu_ids\n\n        # construct unet structure\n        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)\n        for i in range(num_downs - 5):\n            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)\n        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)\n        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)\n        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)\n        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)\n\n        self.model = unet_block\n\n    def forward(self, input):\n        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):\n            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)\n        else:\n            return self.model(input)\n\n\n# Defines the submodule with skip connection.\n# X -------------------identity---------------------- X\n#   |-- downsampling -- |submodule| -- upsampling --|\nclass UnetSkipConnectionBlock(nn.Module):\n    def __init__(self, outer_nc, inner_nc, input_nc=None,\n                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):\n        super(UnetSkipConnectionBlock, self).__init__()\n        self.outermost = outermost\n        if 
type(norm_layer) == functools.partial:\n            use_bias = norm_layer.func == nn.InstanceNorm2d\n        else:\n            use_bias = norm_layer == nn.InstanceNorm2d\n        if input_nc is None:\n            input_nc = outer_nc\n        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,\n                             stride=2, padding=1, bias=use_bias)\n        downrelu = nn.LeakyReLU(0.2, True)\n        downnorm = norm_layer(inner_nc)\n        uprelu = nn.ReLU(True)\n        upnorm = norm_layer(outer_nc)\n\n        if outermost:\n            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,\n                                        kernel_size=4, stride=2,\n                                        padding=1)\n            down = [downconv]\n            up = [uprelu, upconv, nn.Tanh()]\n            model = down + [submodule] + up\n        elif innermost:\n            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,\n                                        kernel_size=4, stride=2,\n                                        padding=1, bias=use_bias)\n            down = [downrelu, downconv]\n            up = [uprelu, upconv, upnorm]\n            model = down + up\n        else:\n            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,\n                                        kernel_size=4, stride=2,\n                                        padding=1, bias=use_bias)\n            down = [downrelu, downconv, downnorm]\n            up = [uprelu, upconv, upnorm]\n\n            if use_dropout:\n                model = down + [submodule] + up + [nn.Dropout(0.5)]\n            else:\n                model = down + [submodule] + up\n\n        self.model = nn.Sequential(*model)\n\n    def forward(self, x):\n        if self.outermost:\n            return self.model(x)\n        else:\n            return torch.cat([x, self.model(x)], 1)\n\n\n# Defines the PatchGAN discriminator with the specified arguments.\nclass NLayerDiscriminator(nn.Module):\n    def __init__(self, 
input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):\n        super(NLayerDiscriminator, self).__init__()\n        self.gpu_ids = gpu_ids\n        if type(norm_layer) == functools.partial:\n            use_bias = norm_layer.func == nn.InstanceNorm2d\n        else:\n            use_bias = norm_layer == nn.InstanceNorm2d\n\n        kw = 4\n        padw = 1\n        sequence = [\n            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),\n            nn.LeakyReLU(0.2, True)\n        ]\n\n        nf_mult = 1\n        nf_mult_prev = 1\n        for n in range(1, n_layers):\n            nf_mult_prev = nf_mult\n            nf_mult = min(2**n, 8)\n            sequence += [\n                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,\n                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),\n                norm_layer(ndf * nf_mult),\n                nn.LeakyReLU(0.2, True)\n            ]\n\n        nf_mult_prev = nf_mult\n        nf_mult = min(2**n_layers, 8)\n        sequence += [\n            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,\n                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),\n            norm_layer(ndf * nf_mult),\n            nn.LeakyReLU(0.2, True)\n        ]\n\n        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]\n\n        if use_sigmoid:\n            sequence += [nn.Sigmoid()]\n\n        self.model = nn.Sequential(*sequence)\n\n    def forward(self, input):\n        if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):\n            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)\n        else:\n            return self.model(input)\n"
  },
  {
    "path": "README.md",
    "content": "## CA-Net: Comprehensive Attention Convolutional Neural Networks for Explainable Medical Image Segmentation\nThis repository provides the code for \"CA-Net: Comprehensive Attention Convolutional Neural Networks for Explainable Medical Image Segmentation\". The paper is available on [arXiv][paper_link] and has been accepted by [IEEE TMI][tmi_link].\n\n[paper_link]:https://arxiv.org/pdf/2009.10549.pdf\n\n[tmi_link]:https://ieeexplore.ieee.org/document/9246575\n\n![mg_net](./pictures/canet_framework.png)\nFig. 1. Structure of CA-Net.\n\n![uncertainty](./pictures/skin_results.png)\nFig. 2. Skin lesion segmentation.\n\n![refinement](./pictures/fetal_mri_results.png)\n\nFig. 3. Placenta and fetal brain segmentation.\n\n### Requirements\nSome important required packages include:\n* [PyTorch][torch_link] version >=0.4.1.\n* Visdom\n* Python == 3.7 \n* Some basic Python packages such as NumPy.\n\nFollow the official guidance to install [PyTorch][torch_link].\n\n[torch_link]:https://pytorch.org/\n\n## Usage\n### For skin lesion segmentation\n1. First, download the dataset from [ISIC 2018][data_link]. We only used the ISIC 2018 Task 1 training dataset. To preprocess the dataset and save it as \".npy\" files, run:\n\n[data_link]:https://challenge.isic-archive.com/data#2018\n\n```\npython isic_preprocess.py \n```\n2. To prepare 5-fold cross-validation, split the preprocessed data into 5 folds and save their filenames by running:\n```\npython create_folder.py \n```\n\n\n3. To train CA-Net on ISIC 2018 (taking the 1st-fold validation as an example), run:\n```\npython main.py --data ISIC2018 --val_folder folder1 --id Comp_Atten_Unet\n```\n\n4. To evaluate the trained model on ISIC 2018 (we added test data in folder0; this example tests the 0th-fold validation), run:\n```\npython validation.py --data ISIC2018 --val_folder folder0 --id Comp_Atten_Unet\n```\nOur experimental results are shown in the table:\n![refinement](./pictures/skin_segmentation_results_table.png)\n\n5. 
You can save the attention weight maps from the intermediate layers of the network to the './result' folder. To overlay the attention weights on the original images, run:\n```\npython show_fused_heatmap.py\n```\nVisualization of the spatial attention weight map:\n![refinement](./pictures/spatial_atten_weight.png)\n\nVisualization of the scale attention weight map:\n![refinement](./pictures/scale_atten_weight.png)\n\n## Citation\nIf you find our work helpful for your research, please consider citing:\n```\n@article{gu2020net,\n  title={CA-Net: Comprehensive Attention Convolutional Neural Networks for Explainable Medical Image Segmentation},\n  author={Gu, Ran and Wang, Guotai and Song, Tao and Huang, Rui and Aertsen, Michael and Deprest, Jan and Ourselin, S{\\'e}bastien and Vercauteren, Tom and Zhang, Shaoting},\n  journal={IEEE Transactions on Medical Imaging},\n  year={2020},\n  publisher={IEEE}\n}\n```\n## Acknowledgement\nPart of the code is adapted from [Attention-Gated-Networks][AG].\n\n[AG]:https://github.com/ozan-oktay/Attention-Gated-Networks\n"
  },
  {
    "path": "create_folder.py",
    "content": "import os\nimport numpy\nfrom random import shuffle\n\nPATH = './data/ISIC2018_Task1_npy_all/image'\nSAVE_PATH = './Datasets'\n\n\ndef create_5_floder(folder, save_foler):\n    file_list = os.listdir(folder)\n    shuffle(file_list)\n\n    for i in range(5):\n        if i != 0:\n            pre_test_list = file_list[0:i*518]\n        else:\n            pre_test_list = []\n        test_list = file_list[i*518:(i+1)*518]\n\n        if i < 4:\n            valid_list = file_list[(i+1)*518:(i+1)*518+260]\n            train_list = file_list[(i+1)*518+260:] + pre_test_list\n        else:\n            valid_list = file_list[-4:] + file_list[:256]\n            train_list = file_list[256:i*518]\n\n        if not os.path.isdir(save_foler + '/folder'+str(i+1)):\n            os.makedirs(save_foler + '/folder'+str(i+1))\n\n        text_save(os.path.join(save_foler, 'folder'+str(i+1), 'folder'+str(i+1)+'_train.list'), train_list)\n        text_save(os.path.join(save_foler, 'folder'+str(i+1), 'folder'+str(i+1)+'_validation.list'), valid_list)\n        text_save(os.path.join(save_foler, 'folder'+str(i+1), 'folder'+str(i+1)+'_test.list'), test_list)\n\n\ndef text_save(filename, data):      # filename: path to write CSV, data: data list to be written.\n    file = open(filename, 'w+')\n    for i in range(len(data)):\n        s = str(data[i]).replace('[', '').replace(']', '')\n        s = s.replace(\"'\", '').replace(',', '') + '\\n'\n        file.write(s)\n    file.close()\n    print(\"Save {} successfully\".format(filename.split('/')[-1]))\n\n\nif __name__ == \"__main__\":\n    create_5_floder(PATH, SAVE_PATH)\n"
  },
  {
    "path": "isic_preprocess.py",
    "content": "#!/usr/bin/python3\n# these code is for ISIC 2018: Skin Lesion Analysis Towards Melanoma Detection\n# -*- coding: utf-8 -*-\n# @Author  : Ran Gu\n\nimport os\nimport random\nimport numpy as np\nfrom skimage import io\nfrom PIL import Image\n\nroot_dir = 'gr/Skin Segmentation'                # change it in your saved original data path\nsave_dir = './data/ISIC2018_Task1_npy_all'\n\n\nif __name__ == '__main__':\n    imgfile = os.path.join(root_dir, 'ISIC2018_Task1-2_Training_Input')\n    labfile = os.path.join(root_dir, 'ISIC2018_Task1_Training_GroundTruth')\n    filename = sorted([os.path.join(imgfile, x) for x in os.listdir(imgfile) if x.endswith('.jpg')])\n    random.shuffle(filename)\n    labname = [filename[x].replace('ISIC2018_Task1-2_Training_Input', 'ISIC2018_Task1_Training_GroundTruth'\n                                   ).replace('.jpg', '_segmentation.png') for x in range(len(filename))]\n\n    if not os.path.isdir(save_dir):\n        os.makedirs(save_dir+'/image')\n        os.makedirs(save_dir+'/label')\n\n    for i in range(len(filename)):\n        fname = filename[i].rsplit('/', maxsplit=1)[-1].split('.')[0]\n        lname = labname[i].rsplit('/', maxsplit=1)[-1].split('.')[0]\n\n        image = Image.open(filename[i])\n        label = Image.open(labname[i])\n\n        image = image.resize((342, 256))\n        label = label.resize((342, 256))\n        image = np.array(image)\n        label = np.array(label)\n\n        images_img_filename = os.path.join(save_dir, 'image', fname)\n        labels_img_filename = os.path.join(save_dir, 'label', lname)\n        np.save(images_img_filename, image)\n        np.save(labels_img_filename, label)\n    print('Successfully saved preprocessed data')\n"
  },
  {
    "path": "main.py",
    "content": "#!/usr/bin/python3\n# these code is for ISIC 2018: Skin Lesion Analysis Towards Melanoma Detection\n# -*- coding: utf-8 -*-\n# @Author  : Ran Gu\nimport os\nimport torch\nimport math\nimport visdom\nimport torch.utils.data as Data\nimport argparse\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom distutils.version import LooseVersion\nfrom Datasets.ISIC2018 import ISIC2018_dataset\nfrom utils.transform import ISIC2018_transform\n\nfrom Models.networks.network import Comprehensive_Atten_Unet\n\nfrom utils.dice_loss import SoftDiceLoss, get_soft_label, val_dice_fetus, val_dice_isic\nfrom utils.dice_loss import Intersection_over_Union_fetus, Intersection_over_Union_isic\n\nfrom utils.evaluation import AverageMeter\nfrom utils.binary import assd\nfrom torch.optim.lr_scheduler import StepLR\n\n\nTest_Model = {'Comp_Atten_Unet': Comprehensive_Atten_Unet}\n\nTest_Dataset = {'ISIC2018': ISIC2018_dataset}\n\nTest_Transform = {'ISIC2018': ISIC2018_transform}\n\n\ndef train(train_loader, model, criterion, optimizer, args, epoch):\n    losses = AverageMeter()\n\n    model.train()\n    for step, (x, y) in tqdm(enumerate(train_loader), total=len(train_loader)):\n        image = x.float().cuda()\n        target = y.float().cuda()\n\n        output = model(image)                                      # model output\n\n        target_soft = get_soft_label(target, args.num_classes)     # get soft label\n        loss = criterion(output, target_soft, args.num_classes)    # the dice losses\n        losses.update(loss.data, image.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if step % (math.ceil(float(len(train_loader.dataset))/args.batch_size)) == 0:\n            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {losses.avg:.6f}'.format(\n                epoch, step * len(image), len(train_loader.dataset),\n                100. 
* step / len(train_loader), losses=losses))\n\n    print('The average loss:{losses.avg:.4f}'.format(losses=losses))\n    return losses.avg\n\n\ndef valid_fetus(valid_loader, model, criterion, optimizer, args, epoch, minloss):\n    val_losses = AverageMeter()\n    val_placenta_dice = AverageMeter()\n    val_brain_dice = AverageMeter()\n\n    model.eval()\n    for step, (t, k) in tqdm(enumerate(valid_loader), total=len(valid_loader)):\n        image = t.float().cuda()\n        target = k.float().cuda()\n\n        output = model(image)                                               # model output\n        output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)\n        output_soft = get_soft_label(output_dis, args.num_classes)          # get soft label\n        target_soft = get_soft_label(target, args.num_classes)\n\n        val_loss = criterion(output, target_soft, args.num_classes)                       # the dice losses\n        val_losses.update(val_loss.data, image.size(0))\n\n        placenta, brain = val_dice_fetus(output_soft, target_soft, args.num_classes)      # the dice score\n        val_placenta_dice.update(placenta.data, image.size(0))\n        val_brain_dice.update(brain.data, image.size(0))\n\n        if step % (math.ceil(float(len(valid_loader.dataset))/args.batch_size)) == 0:\n            print('Valid Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {losses.avg:.6f}'.format(\n                epoch, step * len(image), len(valid_loader.dataset), 100. 
* step / len(valid_loader), losses=val_losses))\n\n    print('The Placenta Mean Average Dice score: {placenta.avg: .4f}; '\n          'The Brain Mean Average Dice score: {brain.avg: .4f}; '\n          'The Average Loss score: {loss.avg: .4f}'.format(\n           placenta=val_placenta_dice, brain=val_brain_dice, loss=val_losses))\n\n    if val_losses.avg < min(minloss):\n        minloss.append(val_losses.avg)\n        print(minloss)\n        modelname = args.ckpt + '/' + 'min_loss' + '_' + args.data + '_checkpoint.pth.tar'\n        print('the best model will be saved at {}'.format(modelname))\n        state = {'epoch': epoch, 'state_dict': model.state_dict(), 'opt_dict': optimizer.state_dict()}\n        torch.save(state, modelname)\n\n    return val_losses.avg, val_placenta_dice.avg, val_brain_dice.avg\n\n\ndef valid_isic(valid_loader, model, criterion, optimizer, args, epoch, minloss):\n    val_losses = AverageMeter()\n    val_isic_dice = AverageMeter()\n\n    model.eval()\n    for step, (t, k) in tqdm(enumerate(valid_loader), total=len(valid_loader)):\n        image = t.float().cuda()\n        target = k.float().cuda()\n\n        output = model(image)                                             # model output\n        output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)\n        output_soft = get_soft_label(output_dis, args.num_classes)\n        target_soft = get_soft_label(target, args.num_classes)            # get soft label\n\n        val_loss = criterion(output, target_soft, args.num_classes)       # the dice losses\n        val_losses.update(val_loss.data, image.size(0))\n\n        isic = val_dice_isic(output_soft, target_soft, args.num_classes)  # the dice score\n        val_isic_dice.update(isic.data, image.size(0))\n\n        if step % (math.ceil(float(len(valid_loader.dataset)) / args.batch_size)) == 0:\n            print('Valid Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {losses.avg:.6f}'.format(\n                epoch, step * len(image), 
len(valid_loader.dataset), 100. * step / len(valid_loader),\n                losses=val_losses))\n\n    print('The ISIC Mean Average Dice score: {isic.avg: .4f}; '\n          'The Average Loss score: {loss.avg: .4f}'.format(\n           isic=val_isic_dice, loss=val_losses))\n\n    if val_losses.avg < min(minloss):\n        minloss.append(val_losses.avg)\n        print(minloss)\n        modelname = args.ckpt + '/' + 'min_loss' + '_' + args.data + '_checkpoint.pth.tar'\n        print('the best model will be saved at {}'.format(modelname))\n        state = {'epoch': epoch, 'state_dict': model.state_dict(), 'opt_dict': optimizer.state_dict()}\n        torch.save(state, modelname)\n\n    return val_losses.avg, val_isic_dice.avg\n\n\ndef test_fetus(test_loader, model, args):\n    placenta_dice = []\n    brain_dice = []\n    placenta_iou = []\n    brain_iou = []\n    placenta_assd = []\n    brain_assd = []\n\n    modelname = args.ckpt + '/' + 'min_loss' + '_' + args.data + '_checkpoint.pth.tar'\n    if os.path.isfile(modelname):\n        print(\"=> Loading checkpoint '{}'\".format(modelname))\n        checkpoint = torch.load(modelname)\n        # start_epoch = checkpoint['epoch']\n        model.load_state_dict(checkpoint['state_dict'])\n        # optimizer.load_state_dict(checkpoint['opt_dict'])\n        print(\"=> Loaded saved the best model at (epoch {})\".format(checkpoint['epoch']))\n    else:\n        print(\"=> No checkpoint found at '{}'\".format(modelname))\n\n    model.eval()\n    for step, (img, lab) in tqdm(enumerate(test_loader), total=len(test_loader)):\n        image = img.float().cuda()\n        target = lab.float().cuda()\n\n        output = model(image)                                   # model output\n        output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)\n        output_soft = get_soft_label(output_dis, args.num_classes)\n        target_soft = get_soft_label(target, args.num_classes)  # get soft label\n\n        # input_arr = 
np.squeeze(image.cpu().numpy()).astype(np.float32)\n        label_arr = np.squeeze(target_soft.cpu().numpy()).astype(np.uint8)\n        output_arr = np.squeeze(output_soft.cpu().byte().numpy()).astype(np.uint8)\n\n        placenta_b_dice, brain_b_dice = val_dice_fetus(output_soft, target_soft, args.num_classes)               # the dice accuracy\n        placenta_b_iou, brain_b_iou = Intersection_over_Union_fetus(output_soft, target_soft, args.num_classes)  # the iou accuracy\n        placenta_b_asd = assd(output_arr[:, :, :, 1], label_arr[:, :, :, 1])\n        brain_b_asd = assd(output_arr[:, :, :, 2], label_arr[:, :, :, 2])\n\n        pla_dice_np = placenta_b_dice.data.cpu().numpy()\n        bra_iou_np = brain_b_iou.data.cpu().numpy()\n        bra_dice_np = brain_b_dice.data.cpu().numpy()\n        pla_iou_np = placenta_b_iou.data.cpu().numpy()\n        placenta_dice.append(pla_dice_np)\n        brain_dice.append(bra_dice_np)\n        placenta_iou.append(pla_iou_np)\n        brain_iou.append(bra_iou_np)\n        placenta_assd.append(placenta_b_asd)\n        brain_assd.append(brain_b_asd)\n\n    placenta_dice_mean = np.average(placenta_dice)\n    placenta_dice_std = np.std(placenta_dice)\n    brain_dice_mean = np.average(brain_dice)\n    brain_dice_std = np.std(brain_dice)\n\n    placenta_iou_mean = np.average(placenta_iou)\n    placenta_iou_std = np.std(placenta_iou)\n    brain_iou_mean = np.average(brain_iou)\n    brain_iou_std = np.std(brain_iou)\n\n    placenta_assd_mean = np.average(placenta_assd)\n    placenta_assd_std = np.std(placenta_assd)\n    brain_assd_mean = np.average(brain_assd)\n    brain_assd_std = np.std(brain_assd)\n\n    print('The Placenta mean Accuracy: {placenta_dice_mean: .4f}; The Placenta Accuracy std: {placenta_dice_std: .4f}; '\n          'The Brain mean Accuracy: {brain_dice_mean: .4f}; The Brain Accuracy std: {brain_dice_std: .4f}'.format(\n           placenta_dice_mean=placenta_dice_mean, placenta_dice_std=placenta_dice_std,\n          
 brain_dice_mean=brain_dice_mean, brain_dice_std=brain_dice_std))\n    print('The Placenta mean IoU: {placenta_iou_mean: .4f}; The Placenta IoU std: {placenta_iou_std: .4f}; '\n          'The Brain mean IoU: {brain_iou_mean: .4f}; The Brain IoU std: {brain_iou_std: .4f}'.format(\n           placenta_iou_mean=placenta_iou_mean, placenta_iou_std=placenta_iou_std,\n           brain_iou_mean=brain_iou_mean, brain_iou_std=brain_iou_std))\n    print('The Placenta mean assd: {placenta_asd_mean: .4f}; The Placenta assd std: {placenta_asd_std: .4f}; '\n          'The Brain mean assd: {brain_asd_mean: .4f}; The Brain assd std: {brain_asd_std: .4f}'.format(\n           placenta_asd_mean=placenta_assd_mean, placenta_asd_std=placenta_assd_std,\n           brain_asd_mean=brain_assd_mean, brain_asd_std=brain_assd_std))\n\n\ndef test_isic(test_loader, model, args):\n    isic_dice = []\n    isic_iou = []\n    isic_assd = []\n\n    modelname = args.ckpt + '/' + 'min_loss' + '_' + args.data + '_checkpoint.pth.tar'\n    if os.path.isfile(modelname):\n        print(\"=> Loading checkpoint '{}'\".format(modelname))\n        checkpoint = torch.load(modelname)\n        # start_epoch = checkpoint['epoch']\n        model.load_state_dict(checkpoint['state_dict'])\n        # optimizer.load_state_dict(checkpoint['opt_dict'])\n        print(\"=> Loaded saved the best model at (epoch {})\".format(checkpoint['epoch']))\n    else:\n        print(\"=> No checkpoint found at '{}'\".format(modelname))\n\n    model.eval()\n    for step, (img, lab) in tqdm(enumerate(test_loader), total=len(test_loader)):\n        image = img.float().cuda()\n        target = lab.float().cuda()\n\n        output = model(image)                                   # model output\n        output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)\n        output_soft = get_soft_label(output_dis, args.num_classes)\n        target_soft = get_soft_label(target, args.num_classes)  # get soft label\n\n        label_arr = 
np.squeeze(target_soft.cpu().numpy()).astype(np.uint8)\n        output_arr = np.squeeze(output_soft.cpu().byte().numpy()).astype(np.uint8)\n\n        isic_b_dice = val_dice_isic(output_soft, target_soft, args.num_classes)                # the dice accuracy\n        isic_b_iou = Intersection_over_Union_isic(output_soft, target_soft, args.num_classes)  # the iou accuracy\n        isic_b_asd = assd(output_arr[:, :, :, 1], label_arr[:, :, :, 1])                       # the assd\n\n        dice_np = isic_b_dice.data.cpu().numpy()\n        iou_np = isic_b_iou.data.cpu().numpy()\n        isic_dice.append(dice_np)\n        isic_iou.append(iou_np)\n        isic_assd.append(isic_b_asd)\n\n    isic_dice_mean = np.average(isic_dice)\n    isic_dice_std = np.std(isic_dice)\n\n    isic_iou_mean = np.average(isic_iou)\n    isic_iou_std = np.std(isic_iou)\n\n    isic_assd_mean = np.average(isic_assd)\n    isic_assd_std = np.std(isic_assd)\n    print('The ISIC mean Accuracy: {isic_dice_mean: .4f}; The Placenta Accuracy std: {isic_dice_std: .4f}'.format(\n           isic_dice_mean=isic_dice_mean, isic_dice_std=isic_dice_std))\n    print('The ISIC mean IoU: {isic_iou_mean: .4f}; The ISIC IoU std: {isic_iou_std: .4f}'.format(\n           isic_iou_mean=isic_iou_mean, isic_iou_std=isic_iou_std))\n    print('The ISIC mean assd: {isic_asd_mean: .4f}; The ISIC assd std: {isic_asd_std: .4f}'.format(\n           isic_asd_mean=isic_assd_mean, isic_asd_std=isic_assd_std))\n\n\ndef main(args):\n    minloss = [1.0]\n    start_epoch = args.start_epoch\n\n    # loading the dataset\n    print('loading the {0},{1},{2} dataset ...'.format('train', 'validation', 'test'))\n    trainset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='train',\n                                       transform=Test_Transform[args.data])\n    validset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='validation',\n                         
              transform=Test_Transform[args.data])\n    testset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='test',\n                                      transform=Test_Transform[args.data])\n\n    trainloader = Data.DataLoader(dataset=trainset, batch_size=args.batch_size, shuffle=True, pin_memory=True)\n    validloader = Data.DataLoader(dataset=validset, batch_size=args.batch_size, shuffle=True, pin_memory=True)\n    testloader = Data.DataLoader(dataset=testset, batch_size=args.batch_size, shuffle=False, pin_memory=True)\n    print('Loading is done\\n')\n\n    # Define model\n    if args.data == 'Fetus':\n        args.num_input = 1\n        args.num_classes = 3\n        args.out_size = (256, 256)\n    elif args.data == 'ISIC2018':\n        args.num_input = 3\n        args.num_classes = 2\n        args.out_size = (224, 300)\n    model = Test_Model[args.id](args, args.num_input, args.num_classes)\n\n    if torch.cuda.is_available():\n        print('We can use', torch.cuda.device_count(), 'GPUs to train the network')\n        model = model.cuda()\n        # model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))\n\n    # collect the number of parameters in the network\n    print(\"------------------------------------------\")\n    print(\"Network Architecture of Model AttU_Net:\")\n    num_para = 0\n    for name, param in model.named_parameters():\n        num_mul = 1\n        for x in param.size():\n            num_mul *= x\n        num_para += num_mul\n    print(model)\n    print(\"Number of trainable parameters {0} in Model {1}\".format(num_para, args.id))\n    print(\"------------------------------------------\")\n\n    # Define optimizers and loss function\n    optimizer = torch.optim.Adam(model.parameters(),\n                                 lr=args.lr_rate,\n                                 weight_decay=args.weight_decay)    # optimize all model parameters\n    criterion = 
SoftDiceLoss()\n    scheduler = StepLR(optimizer, step_size=256, gamma=0.5)\n\n    # resume\n    if args.resume:\n        if os.path.isfile(args.resume):\n            print(\"=> Loading checkpoint '{}'\".format(args.resume))\n            checkpoint = torch.load(args.resume)\n            start_epoch = checkpoint['epoch']\n            model.load_state_dict(checkpoint['state_dict'])\n            optimizer.load_state_dict(checkpoint['opt_dict'])\n            print(\"=> Loaded checkpoint (epoch {})\".format(checkpoint['epoch']))\n        else:\n            print(\"=> No checkpoint found at '{}'\".format(args.resume))\n    # visualiser\n    vis = visdom.Visdom(env='CA-net')\n\n    print(\"Start training ...\")\n    for epoch in range(start_epoch + 1, args.epochs + 1):\n        scheduler.step()\n        train_avg_loss = train(trainloader, model, criterion, optimizer, args, epoch)\n        vis.line(X=torch.Tensor([epoch]), Y=torch.Tensor([train_avg_loss]),\n                 win=args.id + args.data,\n                 update='append',\n                 opts=dict(title=args.id+'_'+args.data,\n                           xlabel='Epochs',\n                           ylabel='Train_avg_loss'))\n\n        if args.data == 'Fetus':\n            val_avg_loss, val_placenta_dice, val_brain_dice = valid_fetus(validloader, model, criterion,\n                                                                          optimizer, args, epoch, minloss)\n            vis.line(X=torch.Tensor([epoch]), Y=torch.Tensor([val_avg_loss]),\n                     win=args.id + args.data + 'valid_avg',\n                     name='loss',\n                     update='append',\n                     opts=dict(title=args.id + '_' + args.data,\n                               xlabel='Epochs',\n                               ylabel='Dice&loss'))\n            vis.line(X=torch.Tensor([epoch]), Y=torch.Tensor([val_placenta_dice]),\n                     win=args.id + args.data + 'valid_avg',\n                     
name='placenta_dice',\n                     update='append',\n                     opts=dict(title=args.id + '_' + args.data,\n                               xlabel='Epochs',\n                               ylabel='Dice&loss'))\n            vis.line(X=torch.Tensor([epoch]), Y=torch.Tensor([val_brain_dice]),\n                     win=args.id + args.data + 'valid_avg',\n                     name='brain_dice',\n                     update='append',\n                     opts=dict(title=args.id + '_' + args.data,\n                               xlabel='Epochs',\n                               ylabel='Dice&loss'))\n\n        elif args.data == 'ISIC2018':\n            val_avg_loss, val_isic_dice = valid_isic(validloader, model, criterion, optimizer, args, epoch, minloss)\n            vis.line(X=torch.Tensor([epoch]), Y=torch.Tensor([val_avg_loss]),\n                     win=args.id + args.data + 'valid_avg',\n                     name='loss',\n                     update='append',\n                     opts=dict(title=args.id + '_' + args.data + '_',\n                               xlabel='Epochs',\n                               ylabel='Dice&loss'))\n            vis.line(X=torch.Tensor([epoch]), Y=torch.Tensor([val_isic_dice]),\n                     win=args.id + args.data + 'valid_avg',\n                     name='isic_dice',\n                     update='append',\n                     opts=dict(title=args.id + '_' + args.data,\n                               xlabel='Epochs',\n                               ylabel='Dice&loss'))\n        # save models\n        if epoch > args.particular_epoch:\n            if epoch % args.save_epochs_steps == 0:\n                filename = args.ckpt + '/' + str(epoch) + '_' + args.data + '_checkpoint.pth.tar'\n                print('the model will be saved at {}'.format(filename))\n                state = {'epoch': epoch, 'state_dict': model.state_dict(), 'opt_dict': optimizer.state_dict()}\n                torch.save(state, 
filename)\n\n    print('Training Done! Start testing')\n    if args.data == 'Fetus':\n        test_fetus(testloader, model, args)\n    elif args.data == 'ISIC2018':\n        test_isic(testloader, model, args)\n    print('Testing Done!')\n\n\nif __name__ == '__main__':\n    assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \\\n        'PyTorch>=0.4.0 is required'\n\n    parser = argparse.ArgumentParser(description='Comprehensive attention network for biomedical Dataset')\n    # Model related arguments\n    parser.add_argument('--id', default='Comp_Atten_Unet',\n                        help='a name for identitying the model. Choose from the following options: Unet')\n\n    # Path related arguments\n    parser.add_argument('--root_path', default='./data/ISIC2018_Task1_npy_all',\n                        help='root directory of data')\n    parser.add_argument('--ckpt', default='./saved_models',\n                        help='folder to output checkpoints')\n\n    # optimization related arguments\n    parser.add_argument('--epochs', type=int, default=300, metavar='N',\n                        help='number of epochs to train (default: 10)')\n    parser.add_argument('--start_epoch', default=0, type=int,\n                        help='epoch to start training. 
useful if continue from a checkpoint')\n    parser.add_argument('--batch_size', type=int, default=16, metavar='N',\n                        help='input batch size for training (default: 12)')\n    parser.add_argument('--lr_rate', type=float, default=1e-4, metavar='LR',\n                        help='learning rate (default: 0.001)')\n    parser.add_argument('--num_classes', default=2, type=int,\n                        help='number of classes')\n    parser.add_argument('--num_input', default=3, type=int,\n                        help='number of input image for each patient')\n    parser.add_argument('--weight_decay', default=1e-8, type=float, help='weights regularizer')\n    parser.add_argument('--particular_epoch', default=30, type=int,\n                        help='after this number, we will save models more frequently')\n    parser.add_argument('--save_epochs_steps', default=200, type=int,\n                        help='frequency to save models after a particular number of epochs')\n    parser.add_argument('--resume', default='',\n                        help='the checkpoint that resumes from')\n\n    # other arguments\n    parser.add_argument('--data', default='ISIC2018', help='choose the dataset')\n    parser.add_argument('--out_size', default=(224, 300), help='the output image size')\n    parser.add_argument('--val_folder', default='folder0', type=str,\n                        help='which cross validation folder')\n\n    args = parser.parse_args()\n    print(\"Input arguments:\")\n    for key, value in vars(args).items():\n        print(\"{:16} {}\".format(key, value))\n\n    args.ckpt = os.path.join(args.ckpt, args.data, args.val_folder, args.id)\n    print('Models are saved at %s' % (args.ckpt))\n\n    if not os.path.isdir(args.ckpt):\n        os.makedirs(args.ckpt)\n\n    if args.start_epoch > 1:\n        args.resume = args.ckpt + '/' + str(args.start_epoch) + '_' + args.data + '_checkpoint.pth.tar'\n\n    main(args)\n"
  },
  {
    "path": "show_fused_heatmap.py",
    "content": "import os\nimport cv2\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom PIL import Image\n\n\ndef map_scalar_to_color(x):\n    x_list = [0.0, 0.25, 0.5, 0.75, 1.0]\n    c_list = [[0, 0, 255],\n              [0, 255, 255],\n              [0, 255, 0],\n              [255, 255, 0],\n              [255, 0, 0]]\n    for i in range(len(x_list)):\n        if(x <= x_list[i + 1]):\n            x0 = x_list[i]\n            x1 = x_list[i + 1]\n            c0 = c_list[i]\n            c1 = c_list[i + 1]\n            alpha = (x - x0)/(x1 - x0)\n            c = [c0[j]*(1 - alpha) + c1[j] * alpha for j in range(3)]\n            c = [int(item) for item in c]\n            return tuple(c)\n\n\ndef get_fused_heat_map(image, att):\n    [H, W] = image.size\n    img = Image.new('RGB', image.size, (255, 0, 0))\n    \n    for i in range(H):\n        for j in range(W):\n            p0 = image.getpixel((i,j))\n            alpha = att.getpixel((i,j))\n            p1 = map_scalar_to_color(alpha)\n            alpha = 0.3 + alpha*0.5\n            p = [int(p0[c] * (1 - alpha) + p1[c]*alpha) for c in range(3)]\n            p = tuple(p)\n            img.putpixel((i, j), p)\n    return img\n\n\nif __name__ == \"__main__\":\n    image_name = \"./result/atten_map/ISIC_0015937.jpg\"\n    scalar_name = \"./result/atten_map/25_2_8_wgt\"\n    save_name = \"./result/atten_map/15937_wgt3_fused\"\n\n    img = Image.open(image_name)\n    # img = np.load(image_name)\n    # img = Image.fromarray(np.uint8(img*255))\n    # load the scalar map, and normalize the inteinsty to  0 - 1\n    scl = Image.open(scalar_name).convert('L')\n    scl = np.asarray(scl)\n    scl = cv2.resize(scl, dsize=(img.size[0], img.size[1]), interpolation=cv2.INTER_NEAREST)\n    scl_norm = np.asarray(scl, np.float32)/255\n    scl_norm = Image.fromarray(scl_norm)\n    \n    # convert the scalar map to heat map, and fuse it with the original image\n    img_scl = get_fused_heat_map(img, scl_norm)\n    # 
img_scl.save(save_name, format='png')\n\n    plt.imshow(img_scl), plt.title('fused result')\n    # plt.colorbar()\n    plt.show()\n"
  },
  {
    "path": "utils/binary.py",
    "content": "# Copyright (C) 2013 Oskar Maier\n# \n# This program is free software: you can redistribute it and/or modify\n# it under the terms of the GNU General Public License as published by\n# the Free Software Foundation, either version 3 of the License, or\n# (at your option) any later version.\n# \n# This program is distributed in the hope that it will be useful,\n# but WITHOUT ANY WARRANTY; without even the implied warranty of\n# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n# GNU General Public License for more details.\n# \n# You should have received a copy of the GNU General Public License\n# along with this program.  If not, see <http://www.gnu.org/licenses/>.\n#\n# author Oskar Maier\n# version r0.1.1\n# since 2014-03-13\n# status Release\n\n# build-in modules\n\n# third-party modules\nimport numpy\nfrom scipy.ndimage import _ni_support\nfrom scipy.ndimage.morphology import distance_transform_edt, binary_erosion,\\\n    generate_binary_structure\nfrom scipy.ndimage.measurements import label, find_objects\nfrom scipy.stats import pearsonr\n\n# own modules\n\n# code\ndef dc(result, reference):\n    r\"\"\"\n    Dice coefficient\n    \n    Computes the Dice coefficient (also known as Sorensen index) between the binary\n    objects in two images.\n    \n    The metric is defined as\n    \n    .. math::\n        \n        DC=\\frac{2|A\\cap B|}{|A|+|B|}\n        \n    , where :math:`A` is the first and :math:`B` the second set of samples (here: binary objects).\n    \n    Parameters\n    ----------\n    result : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    reference : array_like\n        Input data containing objects. 
Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    \n    Returns\n    -------\n    dc : float\n        The Dice coefficient between the object(s) in ```result``` and the\n        object(s) in ```reference```. It ranges from 0 (no overlap) to 1 (perfect overlap).\n        \n    Notes\n    -----\n    This is a real metric. The binary images can therefore be supplied in any order.\n    \"\"\"\n    result = numpy.atleast_1d(result.astype(numpy.bool))\n    reference = numpy.atleast_1d(reference.astype(numpy.bool))\n    \n    intersection = numpy.count_nonzero(result & reference)\n    \n    size_i1 = numpy.count_nonzero(result)\n    size_i2 = numpy.count_nonzero(reference)\n    \n    try:\n        dc = 2. * intersection / float(size_i1 + size_i2)\n    except ZeroDivisionError:\n        dc = 0.0\n    \n    return dc\n\ndef jc(result, reference):\n    \"\"\"\n    Jaccard coefficient\n    \n    Computes the Jaccard coefficient between the binary objects in two images.\n    \n    Parameters\n    ----------\n    result: array_like\n            Input data containing objects. Can be any type but will be converted\n            into binary: background where 0, object everywhere else.\n    reference: array_like\n            Input data containing objects. Can be any type but will be converted\n            into binary: background where 0, object everywhere else.\n\n    Returns\n    -------\n    jc: float\n        The Jaccard coefficient between the object(s) in `result` and the\n        object(s) in `reference`. It ranges from 0 (no overlap) to 1 (perfect overlap).\n    \n    Notes\n    -----\n    This is a real metric. 
The binary images can therefore be supplied in any order.\n    \"\"\"\n    result = numpy.atleast_1d(result.astype(numpy.bool))\n    reference = numpy.atleast_1d(reference.astype(numpy.bool))\n    \n    intersection = numpy.count_nonzero(result & reference)\n    union = numpy.count_nonzero(result | reference)\n    \n    jc = float(intersection) / float(union)\n    \n    return jc\n\ndef precision(result, reference):\n    \"\"\"\n    Precison.\n    \n    Parameters\n    ----------\n    result : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    reference : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    \n    Returns\n    -------\n    precision : float\n        The precision between two binary datasets, here mostly binary objects in images,\n        which is defined as the fraction of retrieved instances that are relevant. The\n        precision is not symmetric.\n    \n    See also\n    --------\n    :func:`recall`\n    \n    Notes\n    -----\n    Not symmetric. The inverse of the precision is :func:`recall`.\n    High precision means that an algorithm returned substantially more relevant results than irrelevant.\n    \n    References\n    ----------\n    .. [1] http://en.wikipedia.org/wiki/Precision_and_recall\n    .. 
[2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion\n    \"\"\"\n    result = numpy.atleast_1d(result.astype(numpy.bool))\n    reference = numpy.atleast_1d(reference.astype(numpy.bool))\n        \n    tp = numpy.count_nonzero(result & reference)\n    fp = numpy.count_nonzero(result & ~reference)\n    \n    try:\n        precision = tp / float(tp + fp)\n    except ZeroDivisionError:\n        precision = 0.0\n    \n    return precision\n\ndef recall(result, reference):\n    \"\"\"\n    Recall.\n    \n    Parameters\n    ----------\n    result : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    reference : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    \n    Returns\n    -------\n    recall : float\n        The recall between two binary datasets, here mostly binary objects in images,\n        which is defined as the fraction of relevant instances that are retrieved. The\n        recall is not symmetric.\n    \n    See also\n    --------\n    :func:`precision`\n    \n    Notes\n    -----\n    Not symmetric. The inverse of the recall is :func:`precision`.\n    High recall means that an algorithm returned most of the relevant results.\n    \n    References\n    ----------\n    .. [1] http://en.wikipedia.org/wiki/Precision_and_recall\n    .. 
[2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion\n    \"\"\"\n    result = numpy.atleast_1d(result.astype(numpy.bool))\n    reference = numpy.atleast_1d(reference.astype(numpy.bool))\n        \n    tp = numpy.count_nonzero(result & reference)\n    fn = numpy.count_nonzero(~result & reference)\n\n    try:\n        recall = tp / float(tp + fn)\n    except ZeroDivisionError:\n        recall = 0.0\n    \n    return recall\n\ndef sensitivity(result, reference):\n    \"\"\"\n    Sensitivity.\n    Same as :func:`recall`, see there for a detailed description.\n    \n    See also\n    --------\n    :func:`specificity` \n    \"\"\"\n    return recall(result, reference)\n\ndef specificity(result, reference):\n    \"\"\"\n    Specificity.\n    \n    Parameters\n    ----------\n    result : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    reference : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    \n    Returns\n    -------\n    specificity : float\n        The specificity between two binary datasets, here mostly binary objects in images,\n        which denotes the fraction of correctly returned negatives. The\n        specificity is not symmetric.\n    \n    See also\n    --------\n    :func:`sensitivity`\n    \n    Notes\n    -----\n    Not symmetric. The completment of the specificity is :func:`sensitivity`.\n    High recall means that an algorithm returned most of the irrelevant results.\n    \n    References\n    ----------\n    .. [1] https://en.wikipedia.org/wiki/Sensitivity_and_specificity\n    .. 
[2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion\n    \"\"\"\n    result = numpy.atleast_1d(result.astype(numpy.bool))\n    reference = numpy.atleast_1d(reference.astype(numpy.bool))\n       \n    tn = numpy.count_nonzero(~result & ~reference)\n    fp = numpy.count_nonzero(result & ~reference)\n\n    try:\n        specificity = tn / float(tn + fp)\n    except ZeroDivisionError:\n        specificity = 0.0\n    \n    return specificity\n\ndef true_negative_rate(result, reference):\n    \"\"\"\n    True negative rate.\n    Same as :func:`specificity`, see there for a detailed description.\n    \n    See also\n    --------\n    :func:`true_positive_rate` \n    :func:`positive_predictive_value`\n    \"\"\"\n    return specificity(result, reference)\n\ndef true_positive_rate(result, reference):\n    \"\"\"\n    True positive rate.\n    Same as :func:`recall` and :func:`sensitivity`, see there for a detailed description.\n    \n    See also\n    --------\n    :func:`positive_predictive_value` \n    :func:`true_negative_rate`\n    \"\"\"\n    return recall(result, reference)\n\ndef positive_predictive_value(result, reference):\n    \"\"\"\n    Positive predictive value.\n    Same as :func:`precision`, see there for a detailed description.\n    \n    See also\n    --------\n    :func:`true_positive_rate`\n    :func:`true_negative_rate`\n    \"\"\"\n    return precision(result, reference)\n\ndef hd(result, reference, voxelspacing=None, connectivity=1):\n    \"\"\"\n    Hausdorff Distance.\n    \n    Computes the (symmetric) Hausdorff Distance (HD) between the binary objects in two\n    images. It is defined as the maximum surface distance between the objects.\n    \n    Parameters\n    ----------\n    result : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    reference : array_like\n        Input data containing objects. 
Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    voxelspacing : float or sequence of floats, optional\n        The voxelspacing in a distance unit i.e. spacing of elements\n        along each dimension. If a sequence, must be of length equal to\n        the input rank; if a single number, this is used for all axes. If\n        not specified, a grid spacing of unity is implied.\n    connectivity : int\n        The neighbourhood/connectivity considered when determining the surface\n        of the binary objects. This value is passed to\n        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.\n        Note that the connectivity influences the result in the case of the Hausdorff distance.\n        \n    Returns\n    -------\n    hd : float\n        The symmetric Hausdorff Distance between the object(s) in ```result``` and the\n        object(s) in ```reference```. The distance unit is the same as for the spacing of \n        elements along each dimension, which is usually given in mm.\n        \n    See also\n    --------\n    :func:`assd`\n    :func:`asd`\n    \n    Notes\n    -----\n    This is a real metric. The binary images can therefore be supplied in any order.\n    \"\"\"\n    hd1 = __surface_distances(result, reference, voxelspacing, connectivity).max()\n    hd2 = __surface_distances(reference, result, voxelspacing, connectivity).max()\n    hd = max(hd1, hd2)\n    return hd\n\n\ndef hd95(result, reference, voxelspacing=None, connectivity=1):\n    \"\"\"\n    95th percentile of the Hausdorff Distance.\n\n    Computes the 95th percentile of the (symmetric) Hausdorff Distance (HD) between the binary objects in two\n    images. 
Compared to the Hausdorff Distance, this metric is slightly more stable to small outliers and is\n    commonly used in Biomedical Segmentation challenges.\n\n    Parameters\n    ----------\n    result : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    reference : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    voxelspacing : float or sequence of floats, optional\n        The voxelspacing in a distance unit i.e. spacing of elements\n        along each dimension. If a sequence, must be of length equal to\n        the input rank; if a single number, this is used for all axes. If\n        not specified, a grid spacing of unity is implied.\n    connectivity : int\n        The neighbourhood/connectivity considered when determining the surface\n        of the binary objects. This value is passed to\n        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.\n        Note that the connectivity influences the result in the case of the Hausdorff distance.\n\n    Returns\n    -------\n    hd : float\n        The symmetric Hausdorff Distance between the object(s) in ```result``` and the\n        object(s) in ```reference```. The distance unit is the same as for the spacing of \n        elements along each dimension, which is usually given in mm.\n\n    See also\n    --------\n    :func:`hd`\n\n    Notes\n    -----\n    This is a real metric. 
The binary images can therefore be supplied in any order.\n    \"\"\"\n    hd1 = __surface_distances(result, reference, voxelspacing, connectivity)\n    hd2 = __surface_distances(reference, result, voxelspacing, connectivity)\n    hd95 = numpy.percentile(numpy.hstack((hd1, hd2)), 95)\n    return hd95\n\n\ndef assd(result, reference, voxelspacing=None, connectivity=1):\n    \"\"\"\n    Average symmetric surface distance.\n    \n    Computes the average symmetric surface distance (ASD) between the binary objects in\n    two images.\n    \n    Parameters\n    ----------\n    result : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    reference : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    voxelspacing : float or sequence of floats, optional\n        The voxelspacing in a distance unit i.e. spacing of elements\n        along each dimension. If a sequence, must be of length equal to\n        the input rank; if a single number, this is used for all axes. If\n        not specified, a grid spacing of unity is implied.\n    connectivity : int\n        The neighbourhood/connectivity considered when determining the surface\n        of the binary objects. This value is passed to\n        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.\n        The decision on the connectivity is important, as it can influence the results\n        strongly. If in doubt, leave it as it is.         \n        \n    Returns\n    -------\n    assd : float\n        The average symmetric surface distance between the object(s) in ``result`` and the\n        object(s) in ``reference``. 
The distance unit is the same as for the spacing of \n        elements along each dimension, which is usually given in mm.\n        \n    See also\n    --------\n    :func:`asd`\n    :func:`hd`\n    \n    Notes\n    -----\n    This is a real metric, obtained by calling and averaging\n    \n    >>> asd(result, reference)\n    \n    and\n    \n    >>> asd(reference, result)\n    \n    The binary images can therefore be supplied in any order.\n    \"\"\"\n    assd = numpy.mean( (asd(result, reference, voxelspacing, connectivity), asd(reference, result, voxelspacing, connectivity)) )\n    return assd\n\ndef asd(result, reference, voxelspacing=None, connectivity=1):\n    \"\"\"\n    Average surface distance metric.\n    \n    Computes the average surface distance (ASD) between the binary objects in two images.\n    \n    Parameters\n    ----------\n    result : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    reference : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    voxelspacing : float or sequence of floats, optional\n        The voxelspacing in a distance unit i.e. spacing of elements\n        along each dimension. If a sequence, must be of length equal to\n        the input rank; if a single number, this is used for all axes. If\n        not specified, a grid spacing of unity is implied.\n    connectivity : int\n        The neighbourhood/connectivity considered when determining the surface\n        of the binary objects. This value is passed to\n        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.\n        The decision on the connectivity is important, as it can influence the results\n        strongly. 
If in doubt, leave it as it is.\n    \n    Returns\n    -------\n    asd : float\n        The average surface distance between the object(s) in ``result`` and the\n        object(s) in ``reference``. The distance unit is the same as for the spacing\n        of elements along each dimension, which is usually given in mm.\n        \n    See also\n    --------\n    :func:`assd`\n    :func:`hd`\n    \n    \n    Notes\n    -----\n    This is not a real metric, as it is directed. See `assd` for a real metric of this.\n    \n    The method is implemented making use of distance images and simple binary morphology\n    to achieve high computational speed.\n    \n    Examples\n    --------\n    The `connectivity` determines what pixels/voxels are considered the surface of a\n    binary object. Take the following binary image showing a cross\n    \n    >>> from scipy.ndimage.morphology import generate_binary_structure\n    >>> cross = generate_binary_structure(2, 1)\n    array([[0, 1, 0],\n           [1, 1, 1],\n           [0, 1, 0]])\n           \n    With `connectivity` set to `1` a 4-neighbourhood is considered when determining the\n    object surface, resulting in the surface\n    \n    .. code-block:: python\n    \n        array([[0, 1, 0],\n               [1, 0, 1],\n               [0, 1, 0]])\n           \n    Changing `connectivity` to `2`, a 8-neighbourhood is considered and we get:\n    \n    .. code-block:: python\n    \n        array([[0, 1, 0],\n               [1, 1, 1],\n               [0, 1, 0]])\n           \n    , as a diagonal connection does no longer qualifies as valid object surface.\n    \n    This influences the  results `asd` returns. Imagine we want to compute the surface\n    distance of our cross to a cube-like object:\n    \n    >>> cube = generate_binary_structure(2, 1)\n    array([[1, 1, 1],\n           [1, 1, 1],\n           [1, 1, 1]])\n           \n    , which surface is, independent of the `connectivity` value set, always\n    \n    .. 
code-block:: python\n    \n        array([[1, 1, 1],\n               [1, 0, 1],\n               [1, 1, 1]])\n           \n    Using a `connectivity` of `1` we get\n    \n    >>> asd(cross, cube, connectivity=1)\n    0.0\n    \n    while a value of `2` returns us\n    \n    >>> asd(cross, cube, connectivity=2)\n    0.20000000000000001\n    \n    due to the center of the cross being considered surface as well.\n    \n    \"\"\"\n    sds = __surface_distances(result, reference, voxelspacing, connectivity)\n    asd = sds.mean()\n    return asd\n\ndef ravd(result, reference):\n    \"\"\"\n    Relative absolute volume difference.\n    \n    Compute the relative absolute volume difference between the (joined) binary objects\n    in the two images.\n    \n    Parameters\n    ----------\n    result : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    reference : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n        \n    Returns\n    -------\n    ravd : float\n        The relative absolute volume difference between the object(s) in ``result``\n        and the object(s) in ``reference``. This is a percentage value in the range\n        :math:`[-1.0, +inf]` for which a :math:`0` denotes an ideal score.\n        \n    Raises\n    ------\n    RuntimeError\n        If the reference object is empty.\n        \n    See also\n    --------\n    :func:`dc`\n    :func:`precision`\n    :func:`recall`\n    \n    Notes\n    -----\n    This is not a real metric, as it is directed. 
Negative values denote a smaller\n    and positive values a larger volume than the reference.\n    This implementation does not check, whether the two supplied arrays are of the same\n    size.\n    \n    Examples\n    --------\n    Considering the following inputs\n    \n    >>> import numpy\n    >>> arr1 = numpy.asarray([[0,1,0],[1,1,1],[0,1,0]])\n    >>> arr1\n    array([[0, 1, 0],\n           [1, 1, 1],\n           [0, 1, 0]])\n    >>> arr2 = numpy.asarray([[0,1,0],[1,0,1],[0,1,0]])\n    >>> arr2\n    array([[0, 1, 0],\n           [1, 0, 1],\n           [0, 1, 0]])\n           \n    comparing `arr1` to `arr2` we get\n    \n    >>> ravd(arr1, arr2)\n    -0.2\n    \n    and reversing the inputs the directivness of the metric becomes evident\n    \n    >>> ravd(arr2, arr1)\n    0.25\n    \n    It is important to keep in mind that a perfect score of `0` does not mean that the\n    binary objects fit exactely, as only the volumes are compared:\n    \n    >>> arr1 = numpy.asarray([1,0,0])\n    >>> arr2 = numpy.asarray([0,0,1])\n    >>> ravd(arr1, arr2)\n    0.0\n    \n    \"\"\"\n    result = numpy.atleast_1d(result.astype(numpy.bool))\n    reference = numpy.atleast_1d(reference.astype(numpy.bool))\n        \n    vol1 = numpy.count_nonzero(result)\n    vol2 = numpy.count_nonzero(reference)\n    \n    if 0 == vol2:\n        raise RuntimeError('The second supplied array does not contain any binary object.')\n    \n    return (vol1 - vol2) / float(vol2)\n\ndef volume_correlation(results, references):\n    r\"\"\"\n    Volume correlation.\n    \n    Computes the linear correlation in binary object volume between the\n    contents of the successive binary images supplied. Measured through\n    the Pearson product-moment correlation coefficient. \n    \n    Parameters\n    ----------\n    results : sequence of array_like\n        Ordered list of input data containing objects. 
Each array_like will be\n        converted into binary: background where 0, object everywhere else.\n    references : sequence of array_like\n        Ordered list of input data containing objects. Each array_like will be\n        converted into binary: background where 0, object everywhere else.\n        The order must be the same as for ``results``.\n    \n    Returns\n    -------\n    r : float\n        The correlation coefficient between -1 and 1.\n    p : float\n        The two-side p value.\n        \n    \"\"\"\n    results = numpy.atleast_2d(numpy.array(results).astype(numpy.bool))\n    references = numpy.atleast_2d(numpy.array(references).astype(numpy.bool))\n    \n    results_volumes = [numpy.count_nonzero(r) for r in results]\n    references_volumes = [numpy.count_nonzero(r) for r in references]\n    \n    return pearsonr(results_volumes, references_volumes) # returns (Pearson'\n\ndef volume_change_correlation(results, references):\n    r\"\"\"\n    Volume change correlation.\n    \n    Computes the linear correlation of change in binary object volume between\n    the contents of the successive binary images supplied. Measured through\n    the Pearson product-moment correlation coefficient. \n    \n    Parameters\n    ----------\n    results : sequence of array_like\n        Ordered list of input data containing objects. Each array_like will be\n        converted into binary: background where 0, object everywhere else.\n    references : sequence of array_like\n        Ordered list of input data containing objects. 
Each array_like will be\n        converted into binary: background where 0, object everywhere else.\n        The order must be the same as for ``results``.\n    \n    Returns\n    -------\n    r : float\n        The correlation coefficient between -1 and 1.\n    p : float\n        The two-side p value.\n        \n    \"\"\"\n    results = numpy.atleast_2d(numpy.array(results).astype(numpy.bool))\n    references = numpy.atleast_2d(numpy.array(references).astype(numpy.bool))\n    \n    results_volumes = numpy.asarray([numpy.count_nonzero(r) for r in results])\n    references_volumes = numpy.asarray([numpy.count_nonzero(r) for r in references])\n    \n    results_volumes_changes = results_volumes[1:] - results_volumes[:-1]\n    references_volumes_changes = references_volumes[1:] - references_volumes[:-1] \n    \n    return pearsonr(results_volumes_changes, references_volumes_changes) # returns (Pearson's correlation coefficient, 2-tailed p-value)\n    \ndef obj_assd(result, reference, voxelspacing=None, connectivity=1):\n    \"\"\"\n    Average symmetric surface distance.\n    \n    Computes the average symmetric surface distance (ASSD) between the binary objects in\n    two images.\n    \n    Parameters\n    ----------\n    result : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    reference : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    voxelspacing : float or sequence of floats, optional\n        The voxelspacing in a distance unit i.e. spacing of elements\n        along each dimension. If a sequence, must be of length equal to\n        the input rank; if a single number, this is used for all axes. 
If\n        not specified, a grid spacing of unity is implied.\n    connectivity : int\n        The neighbourhood/connectivity considered when determining what accounts\n        for a distinct binary object as well as when determining the surface\n        of the binary objects. This value is passed to\n        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.\n        The decision on the connectivity is important, as it can influence the results\n        strongly. If in doubt, leave it as it is.\n        \n    Returns\n    -------\n    assd : float\n        The average symmetric surface distance between all mutually existing distinct\n        binary object(s) in ``result`` and ``reference``. The distance unit is the same as for\n        the spacing of elements along each dimension, which is usually given in mm.\n        \n    See also\n    --------\n    :func:`obj_asd`\n    \n    Notes\n    -----\n    This is a real metric, obtained by calling and averaging\n    \n    >>> obj_asd(result, reference)\n    \n    and\n    \n    >>> obj_asd(reference, result)\n    \n    The binary images can therefore be supplied in any order.\n    \"\"\"\n    assd = numpy.mean( (obj_asd(result, reference, voxelspacing, connectivity), obj_asd(reference, result, voxelspacing, connectivity)) )\n    return assd\n    \n    \ndef obj_asd(result, reference, voxelspacing=None, connectivity=1):\n    \"\"\"\n    Average surface distance between objects.\n    \n    First correspondences between distinct binary objects in reference and result are\n    established. Then the average surface distance is only computed between corresponding\n    objects. Correspondence is defined as unique and at least one voxel overlap.\n    \n    Parameters\n    ----------\n    result : array_like\n        Input data containing objects. 
Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    reference : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    voxelspacing : float or sequence of floats, optional\n        The voxelspacing in a distance unit i.e. spacing of elements\n        along each dimension. If a sequence, must be of length equal to\n        the input rank; if a single number, this is used for all axes. If\n        not specified, a grid spacing of unity is implied.\n    connectivity : int\n        The neighbourhood/connectivity considered when determining what accounts\n        for a distinct binary object as well as when determining the surface\n        of the binary objects. This value is passed to\n        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.\n        The decision on the connectivity is important, as it can influence the results\n        strongly. If in doubt, leave it as it is.\n        \n    Returns\n    -------\n    asd : float\n        The average surface distance between all mutually existing distinct binary\n        object(s) in ``result`` and ``reference``. The distance unit is the same as for the\n        spacing of elements along each dimension, which is usually given in mm.\n        \n    See also\n    --------\n    :func:`obj_assd`\n    :func:`obj_tpr`\n    :func:`obj_fpr`\n        \n    Notes\n    -----\n    This is not a real metric, as it is directed. See `obj_assd` for a real metric of this.\n    \n    For the understanding of this metric, both the notions of connectedness and surface\n    distance are essential. 
Please see :func:`obj_tpr` and :func:`obj_fpr` for more\n    information on the first and :func:`asd` on the second.\n        \n    Examples\n    --------\n    >>> arr1 = numpy.asarray([[1,1,1],[1,1,1],[1,1,1]])\n    >>> arr2 = numpy.asarray([[0,1,0],[0,1,0],[0,1,0]])\n    >>> arr1\n    array([[1, 1, 1],\n           [1, 1, 1],\n           [1, 1, 1]])\n    >>> arr2\n    array([[0, 1, 0],\n           [0, 1, 0],\n           [0, 1, 0]])\n    >>> obj_asd(arr1, arr2)\n    1.5\n    >>> obj_asd(arr2, arr1)\n    0.333333333333\n    \n    With the `voxelspacing` parameter, the distances between the voxels can be set for\n    each dimension separately:\n    \n    >>> obj_asd(arr1, arr2, voxelspacing=(1,2))\n    1.5\n    >>> obj_asd(arr2, arr1, voxelspacing=(1,2))\n    0.333333333333    \n    \n    More examples depicting the notion of object connectedness:\n    \n    >>> arr1 = numpy.asarray([[1,0,1],[1,0,0],[0,0,0]])\n    >>> arr2 = numpy.asarray([[1,0,1],[1,0,0],[0,0,1]])\n    >>> arr1\n    array([[1, 0, 1],\n           [1, 0, 0],\n           [0, 0, 0]])\n    >>> arr2\n    array([[1, 0, 1],\n           [1, 0, 0],\n           [0, 0, 1]])\n    >>> obj_asd(arr1, arr2)\n    0.0\n    >>> obj_asd(arr2, arr1)\n    0.0\n    \n    >>> arr1 = numpy.asarray([[1,0,1],[1,0,1],[0,0,1]])\n    >>> arr2 = numpy.asarray([[1,0,1],[1,0,0],[0,0,1]])\n    >>> arr1\n    array([[1, 0, 1],\n           [1, 0, 1],\n           [0, 0, 1]])\n    >>> arr2\n    array([[1, 0, 1],\n           [1, 0, 0],\n           [0, 0, 1]])\n    >>> obj_asd(arr1, arr2)\n    0.6\n    >>> obj_asd(arr2, arr1)\n    0.0\n    \n    Influence of `connectivity` parameter can be seen in the following example, where\n    with the (default) connectivity of `1` the first array is considered to contain two\n    objects, while with an increase connectivity of `2`, just one large object is\n    detected.  
\n    \n    >>> arr1 = numpy.asarray([[1,0,0],[0,1,1],[0,1,1]])\n    >>> arr2 = numpy.asarray([[1,0,0],[0,0,0],[0,0,0]])\n    >>> arr1\n    array([[1, 0, 0],\n           [0, 1, 1],\n           [0, 1, 1]])\n    >>> arr2\n    array([[1, 0, 0],\n           [0, 0, 0],\n           [0, 0, 0]])\n    >>> obj_asd(arr1, arr2)\n    0.0\n    >>> obj_asd(arr1, arr2, connectivity=2)\n    1.742955328\n    \n    Note that the connectivity also influence the notion of what is considered an object\n    surface voxels.\n    \"\"\"\n    sds = list()\n    labelmap1, labelmap2, _a, _b, mapping = __distinct_binary_object_correspondences(result, reference, connectivity)\n    slicers1 = find_objects(labelmap1)\n    slicers2 = find_objects(labelmap2)\n    for lid2, lid1 in list(mapping.items()):\n        window = __combine_windows(slicers1[lid1 - 1], slicers2[lid2 - 1])\n        object1 = labelmap1[window] == lid1\n        object2 = labelmap2[window] == lid2\n        sds.extend(__surface_distances(object1, object2, voxelspacing, connectivity))\n    asd = numpy.mean(sds)\n    return asd\n    \ndef obj_fpr(result, reference, connectivity=1):\n    \"\"\"\n    The false positive rate of distinct binary object detection.\n    \n    The false positive rates gives a percentage measure of how many distinct binary\n    objects in the second array do not exists in the first array. A partial overlap\n    (of minimum one voxel) is here considered sufficient.\n    \n    In cases where two distinct binary object in the second array overlap with a single\n    distinct object in the first array, only one is considered to have been detected\n    successfully and the other is added to the count of false positives.\n    \n    Parameters\n    ----------\n    result : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    reference : array_like\n        Input data containing objects. 
Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    connectivity : int\n        The neighbourhood/connectivity considered when determining what accounts\n        for a distinct binary object. This value is passed to\n        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.\n        The decision on the connectivity is important, as it can influence the results\n        strongly. If in doubt, leave it as it is.\n        \n    Returns\n    -------\n    tpr : float\n        A percentage measure of how many distinct binary objects in ``results`` have no\n        corresponding binary object in ``reference``. It has the range :math:`[0, 1]`, where a :math:`0`\n        denotes an ideal score.\n        \n    Raises\n    ------\n    RuntimeError\n        If the second array is empty.\n    \n    See also\n    --------\n    :func:`obj_tpr`\n    \n    Notes\n    -----\n    This is not a real metric, as it is directed. Whatever array is considered as\n    reference should be passed second. 
A perfect score of :math:`0` tells that there are no\n    distinct binary objects in the second array that do not exists also in the reference\n    array, but does not reveal anything about objects in the reference array also\n    existing in the second array (use :func:`obj_tpr` for this).\n    \n    Examples\n    --------\n    >>> arr2 = numpy.asarray([[1,0,0],[1,0,1],[0,0,1]])\n    >>> arr1 = numpy.asarray([[0,0,1],[1,0,1],[0,0,1]])\n    >>> arr2\n    array([[1, 0, 0],\n           [1, 0, 1],\n           [0, 0, 1]])\n    >>> arr1\n    array([[0, 0, 1],\n           [1, 0, 1],\n           [0, 0, 1]])\n    >>> obj_fpr(arr1, arr2)\n    0.0\n    >>> obj_fpr(arr2, arr1)\n    0.0\n    \n    Example of directedness:\n    \n    >>> arr2 = numpy.asarray([1,0,1,0,1])\n    >>> arr1 = numpy.asarray([1,0,1,0,0])\n    >>> obj_fpr(arr1, arr2)\n    0.0\n    >>> obj_fpr(arr2, arr1)\n    0.3333333333333333\n    \n    Examples of multiple overlap treatment:\n    \n    >>> arr2 = numpy.asarray([1,0,1,0,1,1,1])\n    >>> arr1 = numpy.asarray([1,1,1,0,1,0,1])\n    >>> obj_fpr(arr1, arr2)\n    0.3333333333333333\n    >>> obj_fpr(arr2, arr1)\n    0.3333333333333333\n    \n    >>> arr2 = numpy.asarray([1,0,1,1,1,0,1])\n    >>> arr1 = numpy.asarray([1,1,1,0,1,1,1])\n    >>> obj_fpr(arr1, arr2)\n    0.0\n    >>> obj_fpr(arr2, arr1)\n    0.3333333333333333\n    \n    >>> arr2 = numpy.asarray([[1,0,1,0,0],\n                              [1,0,0,0,0],\n                              [1,0,1,1,1],\n                              [0,0,0,0,0],\n                              [1,0,1,0,0]])\n    >>> arr1 = numpy.asarray([[1,1,1,0,0],\n                              [0,0,0,0,0],\n                              [1,1,1,0,1],\n                              [0,0,0,0,0],\n                              [1,1,1,0,0]])\n    >>> obj_fpr(arr1, arr2)\n    0.0\n    >>> obj_fpr(arr2, arr1)\n    0.2    \n    \"\"\"\n    _, _, _, n_obj_reference, mapping = __distinct_binary_object_correspondences(reference, result, 
connectivity)\n    return (n_obj_reference - len(mapping)) / float(n_obj_reference)\n    \ndef obj_tpr(result, reference, connectivity=1):\n    \"\"\"\n    The true positive rate of distinct binary object detection.\n    \n    The true positive rates gives a percentage measure of how many distinct binary\n    objects in the first array also exists in the second array. A partial overlap\n    (of minimum one voxel) is here considered sufficient.\n    \n    In cases where two distinct binary object in the first array overlaps with a single\n    distinct object in the second array, only one is considered to have been detected\n    successfully.  \n    \n    Parameters\n    ----------\n    result : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    reference : array_like\n        Input data containing objects. Can be any type but will be converted\n        into binary: background where 0, object everywhere else.\n    connectivity : int\n        The neighbourhood/connectivity considered when determining what accounts\n        for a distinct binary object. This value is passed to\n        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.\n        The decision on the connectivity is important, as it can influence the results\n        strongly. If in doubt, leave it as it is.\n        \n    Returns\n    -------\n    tpr : float\n        A percentage measure of how many distinct binary objects in ``result`` also exists\n        in ``reference``. It has the range :math:`[0, 1]`, where a :math:`1` denotes an ideal score.\n        \n    Raises\n    ------\n    RuntimeError\n        If the reference object is empty.\n    \n    See also\n    --------\n    :func:`obj_fpr`\n    \n    Notes\n    -----\n    This is not a real metric, as it is directed. Whatever array is considered as\n    reference should be passed second. 
A perfect score of :math:`1` tells that all distinct\n    binary objects in the reference array also exist in the result array, but does not\n    reveal anything about additional binary objects in the result array\n    (use :func:`obj_fpr` for this).\n    \n    Examples\n    --------\n    >>> arr2 = numpy.asarray([[1,0,0],[1,0,1],[0,0,1]])\n    >>> arr1 = numpy.asarray([[0,0,1],[1,0,1],[0,0,1]])\n    >>> arr2\n    array([[1, 0, 0],\n           [1, 0, 1],\n           [0, 0, 1]])\n    >>> arr1\n    array([[0, 0, 1],\n           [1, 0, 1],\n           [0, 0, 1]])\n    >>> obj_tpr(arr1, arr2)\n    1.0\n    >>> obj_tpr(arr2, arr1)\n    1.0\n    \n    Example of directedness:\n    \n    >>> arr2 = numpy.asarray([1,0,1,0,1])\n    >>> arr1 = numpy.asarray([1,0,1,0,0])\n    >>> obj_tpr(arr1, arr2)\n    0.6666666666666666\n    >>> obj_tpr(arr2, arr1)\n    1.0\n    \n    Examples of multiple overlap treatment:\n    \n    >>> arr2 = numpy.asarray([1,0,1,0,1,1,1])\n    >>> arr1 = numpy.asarray([1,1,1,0,1,0,1])\n    >>> obj_tpr(arr1, arr2)\n    0.6666666666666666\n    >>> obj_tpr(arr2, arr1)\n    0.6666666666666666\n    \n    >>> arr2 = numpy.asarray([1,0,1,1,1,0,1])\n    >>> arr1 = numpy.asarray([1,1,1,0,1,1,1])\n    >>> obj_tpr(arr1, arr2)\n    0.6666666666666666\n    >>> obj_tpr(arr2, arr1)\n    1.0\n    \n    >>> arr2 = numpy.asarray([[1,0,1,0,0],\n                              [1,0,0,0,0],\n                              [1,0,1,1,1],\n                              [0,0,0,0,0],\n                              [1,0,1,0,0]])\n    >>> arr1 = numpy.asarray([[1,1,1,0,0],\n                              [0,0,0,0,0],\n                              [1,1,1,0,1],\n                              [0,0,0,0,0],\n                              [1,1,1,0,0]])\n    >>> obj_tpr(arr1, arr2)\n    0.8\n    >>> obj_tpr(arr2, arr1)\n    1.0    \n    \"\"\"\n    _, _, n_obj_result, _, mapping = __distinct_binary_object_correspondences(reference, result, connectivity)\n    return len(mapping) / 
float(n_obj_result)\n\ndef __distinct_binary_object_correspondences(reference, result, connectivity=1):\n    \"\"\"\n    Determines all distinct (where connectivity is defined by the connectivity parameter\n    passed to scipy's `generate_binary_structure`) binary objects in both of the input\n    parameters and returns a 1to1 mapping from the labelled objects in reference to the\n    corresponding (whereas a one-voxel overlap suffices for correspondence) objects in\n    result.\n    \n    All stems from the problem, that the relationship is non-surjective many-to-many.\n    \n    @return (labelmap1, labelmap2, n_lables1, n_labels2, labelmapping2to1)\n    \"\"\"\n    result = numpy.atleast_1d(result.astype(numpy.bool))\n    reference = numpy.atleast_1d(reference.astype(numpy.bool))\n    \n    # binary structure\n    footprint = generate_binary_structure(result.ndim, connectivity)\n    \n    # label distinct binary objects\n    labelmap1, n_obj_result = label(result, footprint)\n    labelmap2, n_obj_reference = label(reference, footprint)\n    \n    # find all overlaps from labelmap2 to labelmap1; collect one-to-one relationships and store all one-two-many for later processing\n    slicers = find_objects(labelmap2) # get windows of labelled objects\n    mapping = dict() # mappings from labels in labelmap2 to corresponding object labels in labelmap1\n    used_labels = set() # set to collect all already used labels from labelmap2\n    one_to_many = list() # list to collect all one-to-many mappings\n    for l1id, slicer in enumerate(slicers): # iterate over object in labelmap2 and their windows\n        l1id += 1 # labelled objects have ids sarting from 1\n        bobj = (l1id) == labelmap2[slicer] # find binary object corresponding to the label1 id in the segmentation\n        l2ids = numpy.unique(labelmap1[slicer][bobj]) # extract all unique object identifiers at the corresponding positions in the reference (i.e. 
the mapping)\n        l2ids = l2ids[0 != l2ids] # remove background identifiers (=0)\n        if 1 == len(l2ids): # one-to-one mapping: if target label not already used, add to final list of object-to-object mappings and mark target label as used\n            l2id = l2ids[0]\n            if not l2id in used_labels:\n                mapping[l1id] = l2id\n                used_labels.add(l2id)\n        elif 1 < len(l2ids): # one-to-many mapping: store relationship for later processing\n            one_to_many.append((l1id, set(l2ids)))\n            \n    # process one-to-many mappings, always choosing the one with the least labelmap2 correspondences first\n    while True:\n        one_to_many = [(l1id, l2ids - used_labels) for l1id, l2ids in one_to_many] # remove already used ids from all sets\n        one_to_many = [x for x in one_to_many if x[1]] # remove empty sets\n        one_to_many = sorted(one_to_many, key=lambda x: len(x[1])) # sort by set length\n        if 0 == len(one_to_many):\n            break\n        l2id = one_to_many[0][1].pop() # select an arbitrary target label id from the shortest set\n        mapping[one_to_many[0][0]] = l2id # add to one-to-one mappings \n        used_labels.add(l2id) # mark target label as used\n        one_to_many = one_to_many[1:] # delete the processed set from all sets\n    \n    return labelmap1, labelmap2, n_obj_result, n_obj_reference, mapping\n    \ndef __surface_distances(result, reference, voxelspacing=None, connectivity=1):\n    \"\"\"\n    The distances between the surface voxel of binary objects in result and their\n    nearest partner surface voxel of a binary object in reference.\n    \"\"\"\n    result = numpy.atleast_1d(result.astype(numpy.bool))\n    reference = numpy.atleast_1d(reference.astype(numpy.bool))\n    if voxelspacing is not None:\n        voxelspacing = _ni_support._normalize_sequence(voxelspacing, result.ndim)\n        voxelspacing = numpy.asarray(voxelspacing, dtype=numpy.float64)\n        if 
not voxelspacing.flags.contiguous:\n            voxelspacing = voxelspacing.copy()\n            \n    # binary structure\n    footprint = generate_binary_structure(result.ndim, connectivity)\n    \n    # test for emptiness\n    if 0 == numpy.count_nonzero(result): \n        raise RuntimeError('The first supplied array does not contain any binary object.')\n    if 0 == numpy.count_nonzero(reference): \n        raise RuntimeError('The second supplied array does not contain any binary object.')    \n            \n    # extract only 1-pixel border line of objects\n    result_border = result ^ binary_erosion(result, structure=footprint, iterations=1)\n    reference_border = reference ^ binary_erosion(reference, structure=footprint, iterations=1)\n    \n    # compute the per-voxel surface distances (averaging is left to the caller)\n    # Note: scipy's distance_transform_edt measures, for every non-zero element, the\n    #       distance to the nearest zero element; the reference border mask is therefore\n    #       inverted so that its border voxels become the zeros that distances are measured to\n    dt = distance_transform_edt(~reference_border, sampling=voxelspacing)\n    sds = dt[result_border]\n    \n    return sds\n\ndef __combine_windows(w1, w2):\n    \"\"\"\n    Joins two windows (defined by tuple of slices) such that their maximum\n    combined extent is covered by the new returned window.\n    \"\"\"\n    res = []\n    for s1, s2 in zip(w1, w2):\n        res.append(slice(min(s1.start, s2.start), max(s1.stop, s2.stop)))\n    return tuple(res)\n"
  },
  {
    "path": "utils/dice_loss.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch.nn.modules.loss import _Loss\n\n\nclass SoftDiceLoss(_Loss):\n    '''\n    Soft_Dice = 2*|dot(A, B)| / (|dot(A, A)| + |dot(B, B)| + eps)\n    eps is a small constant to avoid zero division,\n    '''\n    def __init__(self, *args, **kwargs):\n        super(SoftDiceLoss, self).__init__()\n\n    def forward(self, prediction, soft_ground_truth, num_class=3, weight_map=None, eps=1e-8):\n        dice_loss = soft_dice_loss(prediction, soft_ground_truth, num_class, weight_map)\n        return dice_loss\n\n\ndef get_soft_label(input_tensor, num_class):\n    \"\"\"\n        convert a label tensor to soft label\n        input_tensor: tensor with shape [N, C, H, W]\n        output_tensor: shape [N, H, W, num_class]\n    \"\"\"\n    tensor_list = []\n    input_tensor = input_tensor.permute(0, 2, 3, 1)\n    for i in range(num_class):\n        temp_prob = torch.eq(input_tensor, i * torch.ones_like(input_tensor))\n        tensor_list.append(temp_prob)\n    output_tensor = torch.cat(tensor_list, dim=-1)\n    output_tensor = output_tensor.float()\n    return output_tensor\n\n\ndef soft_dice_loss(prediction, soft_ground_truth, num_class, weight_map=None):\n    predict = prediction.permute(0, 2, 3, 1)\n    pred = predict.contiguous().view(-1, num_class)\n    # pred = F.softmax(pred, dim=1)\n    ground = soft_ground_truth.view(-1, num_class)\n    n_voxels = ground.size(0)\n    if weight_map is not None:\n        weight_map = weight_map.view(-1)\n        weight_map_nclass = weight_map.repeat(num_class).view_as(pred)\n        ref_vol = torch.sum(weight_map_nclass * ground, 0)\n        intersect = torch.sum(weight_map_nclass * ground * pred, 0)\n        seg_vol = torch.sum(weight_map_nclass * pred, 0)\n    else:\n        ref_vol = torch.sum(ground, 0)\n        intersect = torch.sum(ground * pred, 0)\n        seg_vol = torch.sum(pred, 0)\n    dice_score = (2.0 * intersect + 1e-5) / (ref_vol + seg_vol + 1.0 + 1e-5)\n  
  # dice_loss = 1.0 - torch.mean(dice_score)\n    # return dice_loss\n    dice_score = torch.mean(-torch.log(dice_score))\n    return dice_score\n\n\ndef val_dice_fetus(prediction, soft_ground_truth, num_class):\n    # predict = prediction.permute(0, 2, 3, 1)\n    pred = prediction.contiguous().view(-1, num_class)\n    # pred = F.softmax(pred, dim=1)\n    ground = soft_ground_truth.view(-1, num_class)\n    ref_vol = torch.sum(ground, 0)\n    intersect = torch.sum(ground * pred, 0)\n    seg_vol = torch.sum(pred, 0)\n    dice_score = 2.0 * intersect / (ref_vol + seg_vol + 1.0)\n    dice_mean_score = torch.mean(dice_score)\n    placenta_dice = dice_score[1]\n    brain_dice = dice_score[2]\n\n    return placenta_dice, brain_dice\n\n\ndef Intersection_over_Union_fetus(prediction, soft_ground_truth, num_class):\n    # predict = prediction.permute(0, 2, 3, 1)\n    pred = prediction.contiguous().view(-1, num_class)\n    # pred = F.softmax(pred, dim=1)\n    ground = soft_ground_truth.view(-1, num_class)\n    ref_vol = torch.sum(ground, 0)\n    intersect = torch.sum(ground * pred, 0)\n    seg_vol = torch.sum(pred, 0)\n    iou_score = intersect / (ref_vol + seg_vol - intersect + 1.0)\n    dice_mean_score = torch.mean(iou_score)\n    placenta_iou = iou_score[1]\n    brain_iou = iou_score[2]\n\n    return placenta_iou, brain_iou\n\n\ndef val_dice_isic(prediction, soft_ground_truth, num_class):\n    # predict = prediction.permute(0, 2, 3, 1)\n    pred = prediction.contiguous().view(-1, num_class)\n    # pred = F.softmax(pred, dim=1)\n    ground = soft_ground_truth.view(-1, num_class)\n    ref_vol = torch.sum(ground, 0)\n    intersect = torch.sum(ground * pred, 0)\n    seg_vol = torch.sum(pred, 0)\n    dice_score = 2.0 * intersect / (ref_vol + seg_vol + 1.0)\n    dice_mean_score = torch.mean(dice_score)\n\n    return dice_mean_score\n\n\ndef Intersection_over_Union_isic(prediction, soft_ground_truth, num_class):\n    # predict = prediction.permute(0, 2, 3, 1)\n    pred = 
prediction.contiguous().view(-1, num_class)\n    # pred = F.softmax(pred, dim=1)\n    ground = soft_ground_truth.view(-1, num_class)\n    ref_vol = torch.sum(ground, 0)\n    intersect = torch.sum(ground * pred, 0)\n    seg_vol = torch.sum(pred, 0)\n    iou_score = intersect / (ref_vol + seg_vol - intersect + 1.0)\n    iou_mean_score = torch.mean(iou_score)\n\n    return iou_mean_score\n"
  },
  {
    "path": "utils/evaluation.py",
    "content": "class AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n"
  },
  {
    "path": "utils/transform.py",
    "content": "import torch\nimport random\nimport PIL\nimport numbers\nimport numpy as np\nimport torch.nn as nn\nimport collections\nimport matplotlib.pyplot as plt\nimport torchvision.transforms as ts\nimport torchvision.transforms.functional as TF\nfrom PIL import Image, ImageDraw\n\n\n_pil_interpolation_to_str = {\n    Image.NEAREST: 'PIL.Image.NEAREST',\n    Image.BILINEAR: 'PIL.Image.BILINEAR',\n    Image.BICUBIC: 'PIL.Image.BICUBIC',\n    Image.LANCZOS: 'PIL.Image.LANCZOS',\n}\n\n\ndef ISIC2018_transform(sample, train_type):\n    image, label = Image.fromarray(np.uint8(sample['image']), mode='RGB'),\\\n                   Image.fromarray(np.uint8(sample['label']), mode='L')\n\n    if train_type == 'train':\n        image, label = randomcrop(size=(224, 300))(image, label)\n        image, label = randomflip_rotate(image, label, p=0.5, degrees=30)\n    else:\n        image, label = resize(size=(224, 300))(image, label)\n\n    image = ts.Compose([ts.ToTensor(),\n                        ts.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])(image)\n    label = ts.ToTensor()(label)\n\n    return {'image': image, 'label': label}\n\n\n# these are founctional function for transform\ndef randomflip_rotate(img, lab, p=0.5, degrees=0):\n    if random.random() < p:\n        img = TF.hflip(img)\n        lab = TF.hflip(lab)\n    if random.random() < p:\n        img = TF.vflip(img)\n        lab = TF.vflip(lab)\n\n    if isinstance(degrees, numbers.Number):\n        if degrees < 0:\n            raise ValueError(\"If degrees is a single number, it must be positive.\")\n        degrees = (-degrees, degrees)\n    else:\n        if len(degrees) != 2:\n            raise ValueError(\"If degrees is a sequence, it must be of len 2.\")\n        degrees = degrees\n    angle = random.uniform(degrees[0], degrees[1])\n    img = TF.rotate(img, angle)\n    lab = TF.rotate(lab, angle)\n\n    return img, lab\n\n\nclass randomcrop(object):\n    \"\"\"Crop the given PIL Image and mask at a 
random location.\n\n    Args:\n        size (sequence or int): Desired output size of the crop. If size is an\n            int instead of sequence like (h, w), a square crop (size, size) is\n            made.\n        padding (int or sequence, optional): Optional padding on each border\n            of the image. Default is 0, i.e no padding. If a sequence of length\n            4 is provided, it is used to pad left, top, right, bottom borders\n            respectively.\n        pad_if_needed (boolean): It will pad the image if smaller than the\n            desired size to avoid raising an exception.\n    \"\"\"\n\n    def __init__(self, size, padding=0, pad_if_needed=False):\n        if isinstance(size, numbers.Number):\n            self.size = (int(size), int(size))\n        else:\n            self.size = size\n        self.padding = padding\n        self.pad_if_needed = pad_if_needed\n\n    @staticmethod\n    def get_params(img, output_size):\n        \"\"\"Get parameters for ``crop`` for a random crop.\n\n        Args:\n            img (PIL Image): Image to be cropped.\n            output_size (tuple): Expected output size of the crop.\n\n        Returns:\n            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.\n        \"\"\"\n        w, h = img.size\n        th, tw = output_size\n        if w == tw and h == th:\n            return 0, 0, h, w\n\n        i = random.randint(0, h - th)\n        j = random.randint(0, w - tw)\n        return i, j, th, tw\n\n    def __call__(self, img, lab):\n        \"\"\"\n        Args:\n            img (PIL Image): Image to be cropped.\n            lab (PIL Image): Image to be cropped.\n\n        Returns:\n            PIL Image: Cropped image and mask.\n        \"\"\"\n        if self.padding > 0:\n            img = TF.pad(img, self.padding)\n            lab = TF.pad(lab, self.padding)\n\n        # pad the width if needed\n        if self.pad_if_needed and img.size[0] < self.size[1]:\n            img = 
TF.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0))\n            lab = TF.pad(lab, (int((1 + self.size[1] - lab.size[0]) / 2), 0))\n        # pad the height if needed\n        if self.pad_if_needed and img.size[1] < self.size[0]:\n            img = TF.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2)))\n            lab = TF.pad(lab, (0, int((1 + self.size[0] - lab.size[1]) / 2)))\n\n        i, j, h, w = self.get_params(img, self.size)\n\n        return TF.crop(img, i, j, h, w), TF.crop(lab, i, j, h, w)\n\n    def __repr__(self):\n        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)\n\n\nclass resize(object):\n    \"\"\"Resize the input PIL Image and mask to the given size.\n\n    Args:\n        size (sequence or int): Desired output size. If size is a sequence like\n            (h, w), output size will be matched to this. If size is an int,\n            smaller edge of the image will be matched to this number.\n            i.e, if height > width, then image will be rescaled to\n            (size * height / width, size)\n        interpolation (int, optional): Desired interpolation. 
Default is\n            ``PIL.Image.BILINEAR``\n    \"\"\"\n\n    def __init__(self, size, interpolation=Image.BILINEAR):\n        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)\n        self.size = size\n        self.interpolation = interpolation\n\n    def __call__(self, img, lab):\n        \"\"\"\n        Args:\n            img (PIL Image): Image to be scaled.\n            lab (PIL Image): Image to be scaled.\n\n        Returns:\n            PIL Image: Rescaled image and mask.\n        \"\"\"\n        return TF.resize(img, self.size, self.interpolation), TF.resize(lab, self.size, self.interpolation)\n\n    def __repr__(self):\n        interpolate_str = _pil_interpolation_to_str[self.interpolation]\n        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)\n\n\ndef itensity_normalize(volume):\n    \"\"\"\n    Normalize the intensity of an nd volume using the mean and std of the whole volume,\n    then replace every zero-valued voxel with a sample drawn from N(0, 1).\n    inputs:\n        volume: the input nd volume\n    outputs:\n        out: the normalized nd volume\n    \"\"\"\n\n    # pixels = volume[volume > 0]\n    mean = volume.mean()\n    std = volume.std()\n    out = (volume - mean) / std\n    out_random = np.random.normal(0, 1, size=volume.shape)\n    out[volume == 0] = out_random[volume == 0]\n\n    return out"
  },
  {
    "path": "validation.py",
    "content": "import os\nimport torch\nimport argparse\nimport numpy as np\nimport pandas as pd\nimport torch.utils.data as Data\nfrom utils.binary import assd\nfrom distutils.version import LooseVersion\n\nfrom Datasets.ISIC2018 import ISIC2018_dataset\nfrom utils.transform import ISIC2018_transform\n\nfrom Models.networks.network import Comprehensive_Atten_Unet\n\nfrom utils.dice_loss import get_soft_label, val_dice_isic\nfrom utils.dice_loss import Intersection_over_Union_isic\n\nfrom time import *\n\nTest_Model = {'Comp_Atten_Unet': Comprehensive_Atten_Unet}\n\nTest_Dataset = {'ISIC2018': ISIC2018_dataset}\n\nTest_Transform = {'ISIC2018': ISIC2018_transform}\n\n\ndef test_isic(test_loader, model):\n    isic_dice = []\n    isic_iou = []\n    isic_assd = []\n    infer_time = []\n\n    model.eval()\n    for step, (img, lab) in enumerate(test_loader):\n        image = img.float().cuda()\n        target = lab.float().cuda()\n\n        # output, atten2_map, atten3_map = model(image)  # model output\n        begin_time = time()\n        output = model(image)\n        end_time = time()\n        pred_time = end_time - begin_time\n        infer_time.append(pred_time)\n\n        output_dis = torch.max(output, 1)[1].unsqueeze(dim=1)\n        output_soft = get_soft_label(output_dis, args.num_classes)\n        target_soft = get_soft_label(target, args.num_classes)  # get soft label\n\n        # input_arr = np.squeeze(image.cpu().numpy()).astype(np.float32)\n        label_arr = target_soft.cpu().numpy().astype(np.uint8)\n        # label_shw = np.squeeze(target.cpu().numpy()).astype(np.uint8)\n        output_arr = output_soft.cpu().byte().numpy().astype(np.uint8)\n\n        isic_b_dice = val_dice_isic(output_soft, target_soft, args.num_classes)  # the dice accuracy\n        isic_b_iou = Intersection_over_Union_isic(output_soft, target_soft, args.num_classes)  # the iou accuracy\n        isic_b_asd = assd(output_arr[:, :, :, 1], label_arr[:, :, :, 1])\n\n        dice_np = 
isic_b_dice.data.cpu().numpy()\n        iou_np = isic_b_iou.data.cpu().numpy()\n        isic_dice.append(dice_np)\n        isic_iou.append(iou_np)\n        isic_assd.append(isic_b_asd)\n\n    # df = pd.DataFrame(data=dice_np)\n    # df.to_csv(args.ckpt + '/refine_result.csv')\n    isic_dice_mean = np.average(isic_dice)\n    isic_dice_std = np.std(isic_dice)\n\n    isic_iou_mean = np.average(isic_iou)\n    isic_iou_std = np.std(isic_iou)\n\n    isic_assd_mean = np.average(isic_assd)\n    isic_assd_std = np.std(isic_assd)\n\n    all_time = np.sum(infer_time)\n    print('The ISIC mean Accuracy: {isic_dice_mean: .4f}; The ISIC Accuracy std: {isic_dice_std: .4f}'.format(\n        isic_dice_mean=isic_dice_mean, isic_dice_std=isic_dice_std))\n    print('The ISIC mean IoU: {isic_iou_mean: .4f}; The ISIC IoU std: {isic_iou_std: .4f}'.format(\n        isic_iou_mean=isic_iou_mean, isic_iou_std=isic_iou_std))\n    print('The ISIC mean assd: {isic_asd_mean: .4f}; The ISIC assd std: {isic_asd_std: .4f}'.format(\n        isic_asd_mean=isic_assd_mean, isic_asd_std=isic_assd_std))\n    print('The inference time: {time: .4f}'.format(time=all_time))\n\nif __name__ == '__main__':\n    assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), 'PyTorch>=0.4.0 is required'\n\n    parser = argparse.ArgumentParser(description='U-net add Attention mechanism for biomedical Dataset')\n    # Model related arguments\n    parser.add_argument('--id', default='Comp_Atten_Unet',\n                        help='a name for identitying the model. 
Choose from the following options: Unet_fetus')\n    # Path related arguments\n    parser.add_argument('--root_path', default='./data/ISIC2018_Task1_npy_all',\n                        help='root directory of data')\n    parser.add_argument('--ckpt', default='./saved_models',\n                        help='folder to output checkpoints')\n    parser.add_argument('--save', default='./result',\n                        help='folder to outoput result')\n    parser.add_argument('--batch_size', type=int, default=1, metavar='N',\n                        help='input batch size for training (default: 16)')\n    parser.add_argument('--num_classes', default=2, type=int,\n                        help='number of classes')\n    parser.add_argument('--num_input', default=3, type=int,\n                        help='number of input image for each patient')\n    parser.add_argument('--epoch', type=int, default=300, metavar='N',\n                        help='choose the specific epoch checkpoints')\n\n    # other arguments\n    parser.add_argument('--data', default='ISIC2018', help='choose the dataset')\n    parser.add_argument('--out_size', default=(224, 300), help='the output image size')\n    parser.add_argument('--att_pos', default='dec', type=str,\n                        help='where attention to plug in (enc, dec, enc\\&dec)')\n    parser.add_argument('--view', default='axial', type=str,\n                        help='use what views data to test (for fetal MRI)')\n    parser.add_argument('--val_folder', default='folder0', type=str,\n                        help='which cross validation folder')\n\n    args = parser.parse_args()\n    args.ckpt = os.path.join(args.ckpt, args.data, args.val_folder, args.id)\n\n    # loading the dataset\n    print('loading the {0} dataset ...'.format('test'))\n    testset = Test_Dataset[args.data](dataset_folder=args.root_path, folder=args.val_folder, train_type='test', transform=Test_Transform[args.data])\n    testloader = 
Data.DataLoader(dataset=testset, batch_size=args.batch_size, shuffle=False)\n    print('Loading is done\\n')\n\n    # Define model\n    if torch.cuda.is_available():\n        print('We can use', torch.cuda.device_count(), 'GPUs to train the network')\n        if args.data == 'Fetus':\n            args.num_input = 1\n            args.num_classes = 3\n            args.out_size = (256, 256)\n        elif args.data == 'ISIC2018':\n            args.num_input = 3\n            args.num_classes = 2\n            args.out_size = (224, 300)\n        model = Test_Model[args.id](args, args.num_input, args.num_classes).cuda()\n        # model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))\n\n    # Load the trained best model\n    modelname = args.ckpt + '/' + 'min_loss' + '_' + args.data + '_checkpoint.pth.tar'\n    if os.path.isfile(modelname):\n        print(\"=> Loading checkpoint '{}'\".format(modelname))\n        checkpoint = torch.load(modelname)\n        # start_epoch = checkpoint['epoch']\n\n        # multi-GPU transfer to one GPU\n        # model_dict = model.state_dict()\n        # pretrained_dict = checkpoint['state_dict']\n        # from collections import OrderedDict\n        # new_state_dict = OrderedDict()\n        # for k, v in pretrained_dict.items():\n        #     name = k[7:]\n        #     new_state_dict[name] = v\n        #\n        # model_dict.update(new_state_dict)\n        # model.load_state_dict(model_dict)\n        model.load_state_dict(checkpoint['state_dict'])\n        # optimizer.load_state_dict(checkpoint['opt_dict'])\n        print(\"=> Loaded saved the best model at (epoch {})\".format(checkpoint['epoch']))\n    else:\n        print(\"=> No checkpoint found at '{}'\".format(modelname))\n\n    test_isic(testloader, model)\n"
  }
]