[
  {
    "path": ".idea/XNet.iml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<module type=\"PYTHON_MODULE\" version=\"4\">\n  <component name=\"NewModuleRootManager\">\n    <content url=\"file://$MODULE_DIR$\" />\n    <orderEntry type=\"jdk\" jdkName=\"Python 3.7\" jdkType=\"Python SDK\" />\n    <orderEntry type=\"sourceFolder\" forTests=\"false\" />\n  </component>\n  <component name=\"TestRunnerService\">\n    <option name=\"PROJECT_TEST_RUNNER\" value=\"Unittests\" />\n  </component>\n</module>"
  },
  {
    "path": ".idea/deployment.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"PublishConfigData\">\n    <serverData>\n      <paths name=\"EM1\">\n        <serverdata>\n          <mappings>\n            <mapping deploy=\"/XNet\" local=\"$PROJECT_DIR$\" web=\"/\" />\n          </mappings>\n        </serverdata>\n      </paths>\n      <paths name=\"EM2\">\n        <serverdata>\n          <mappings>\n            <mapping deploy=\"/XNet\" local=\"$PROJECT_DIR$\" web=\"/\" />\n          </mappings>\n        </serverdata>\n      </paths>\n      <paths name=\"GPU0\">\n        <serverdata>\n          <mappings>\n            <mapping deploy=\"/XNet\" local=\"$PROJECT_DIR$\" web=\"/\" />\n          </mappings>\n        </serverdata>\n      </paths>\n      <paths name=\"GPU4\">\n        <serverdata>\n          <mappings>\n            <mapping deploy=\"/XNet\" local=\"$PROJECT_DIR$\" web=\"/\" />\n          </mappings>\n        </serverdata>\n      </paths>\n      <paths name=\"GPU5\">\n        <serverdata>\n          <mappings>\n            <mapping deploy=\"/XNet\" local=\"$PROJECT_DIR$\" web=\"/\" />\n          </mappings>\n        </serverdata>\n      </paths>\n      <paths name=\"N22\">\n        <serverdata>\n          <mappings>\n            <mapping local=\"$PROJECT_DIR$\" web=\"/\" />\n          </mappings>\n        </serverdata>\n      </paths>\n      <paths name=\"N30\">\n        <serverdata>\n          <mappings>\n            <mapping deploy=\"/run/XNet\" local=\"$PROJECT_DIR$\" web=\"/\" />\n          </mappings>\n        </serverdata>\n      </paths>\n    </serverData>\n  </component>\n</project>"
  },
  {
    "path": ".idea/inspectionProfiles/profiles_settings.xml",
    "content": "<component name=\"InspectionProjectProfileManager\">\n  <settings>\n    <option name=\"USE_PROJECT_PROFILE\" value=\"false\" />\n    <version value=\"1.0\" />\n  </settings>\n</component>"
  },
  {
    "path": ".idea/misc.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"JavaScriptSettings\">\n    <option name=\"languageLevel\" value=\"ES6\" />\n  </component>\n  <component name=\"ProjectRootManager\" version=\"2\" project-jdk-name=\"Python 3.7\" project-jdk-type=\"Python SDK\" />\n</project>"
  },
  {
    "path": ".idea/modules.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ProjectModuleManager\">\n    <modules>\n      <module fileurl=\"file://$PROJECT_DIR$/.idea/XNet.iml\" filepath=\"$PROJECT_DIR$/.idea/XNet.iml\" />\n    </modules>\n  </component>\n</project>"
  },
  {
    "path": ".idea/vcs.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"VcsDirectoryMappings\">\n    <mapping directory=\"$PROJECT_DIR$\" vcs=\"Git\" />\n  </component>\n</project>"
  },
  {
    "path": ".idea/workspace.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ChangeListManager\">\n    <list default=\"true\" id=\"a6084e43-8a3d-4f65-bc86-b2baf8115879\" name=\"Default Changelist\" comment=\"\">\n      <change afterPath=\"$PROJECT_DIR$/.idea/vcs.xml\" afterDir=\"false\" />\n      <change afterPath=\"$PROJECT_DIR$/train_sup_alnet.py\" afterDir=\"false\" />\n      <change afterPath=\"$PROJECT_DIR$/train_sup_wds.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/.idea/workspace.xml\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/.idea/workspace.xml\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/README.md\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/README.md\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/config/augmentation/online_aug.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/config/augmentation/online_aug.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/dataload/dataset_2d.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/dataload/dataset_2d.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/figure/Figure 3 v11.png\" beforeDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/figure/figure 1 v2.png\" beforeDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/figure/figure 2.png\" beforeDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/models/__init__.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/models/__init__.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/models/getnetwork.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/models/getnetwork.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/test_xnet.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/test_xnet.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/test_xnet3d.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/test_xnet3d.py\" afterDir=\"false\" />\n      <change 
beforePath=\"$PROJECT_DIR$/tools/eval.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/tools/eval.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/tools/wavelet2D.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/tools/wavelet2D.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/tools/wavelet3D.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/tools/wavelet3D.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_semi_CCT.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_CCT.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_semi_CCT_3d.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_CCT_3d.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_semi_CPS.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_CPS.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_semi_CPS_3d.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_CPS_3d.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_semi_CT.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_CT.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_semi_CT_3d.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_CT_3d.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_semi_DTC.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_DTC.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_semi_EM.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_EM.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_semi_EM_3d.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_EM_3d.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_semi_MT.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_MT.py\" afterDir=\"false\" />\n      <change 
beforePath=\"$PROJECT_DIR$/train_semi_MT_3d.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_MT_3d.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_semi_UAMT.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_UAMT.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_semi_UAMT_3d.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_UAMT_3d.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_semi_URPC.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_URPC.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_semi_URPC_3d.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_URPC_3d.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_semi_XNet.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_XNet.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_semi_XNet3d.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_semi_XNet3d.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_sup.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_sup.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_sup_3d.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_sup_3d.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_sup_ConResNet.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_sup_ConResNet.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_sup_XNet.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_sup_XNet.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_sup_XNet3d.py\" beforeDir=\"false\" afterPath=\"$PROJECT_DIR$/train_sup_XNet3d.py\" afterDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_sup_XNet_EM.py\" beforeDir=\"false\" />\n      <change beforePath=\"$PROJECT_DIR$/train_sup_XNet_sb.py\" beforeDir=\"false\" 
afterPath=\"$PROJECT_DIR$/train_sup_XNet_sb.py\" afterDir=\"false\" />\n    </list>\n    <option name=\"EXCLUDED_CONVERTED_TO_IGNORED\" value=\"true\" />\n    <option name=\"SHOW_DIALOG\" value=\"false\" />\n    <option name=\"HIGHLIGHT_CONFLICTS\" value=\"true\" />\n    <option name=\"HIGHLIGHT_NON_ACTIVE_CHANGELIST\" value=\"false\" />\n    <option name=\"LAST_RESOLUTION\" value=\"IGNORE\" />\n  </component>\n  <component name=\"FileTemplateManagerImpl\">\n    <option name=\"RECENT_TEMPLATES\">\n      <list>\n        <option value=\"Python Script\" />\n      </list>\n    </option>\n  </component>\n  <component name=\"Git.Settings\">\n    <option name=\"RECENT_GIT_ROOT_PATH\" value=\"$PROJECT_DIR$\" />\n  </component>\n  <component name=\"ProjectId\" id=\"2ASvhrldQmnvrIuEBF6ewe1Fctz\" />\n  <component name=\"ProjectLevelVcsManager\" settingsEditedManually=\"true\" />\n  <component name=\"PropertiesComponent\">\n    <property name=\"SHARE_PROJECT_CONFIGURATION_FILES\" value=\"true\" />\n    <property name=\"WebServerToolWindowFactoryState\" value=\"false\" />\n    <property name=\"last_opened_file_path\" value=\"$PROJECT_DIR$\" />\n    <property name=\"settings.editor.selected.configurable\" value=\"com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable\" />\n  </component>\n  <component name=\"RecentsManager\">\n    <key name=\"MoveFile.RECENT_KEYS\">\n      <recent name=\"D:\\Desktop\\XNet\\tools\\LiTS\" />\n      <recent name=\"D:\\Desktop\\XNet\\tools\\Atrial\" />\n      <recent name=\"D:\\Desktop\\XNet\\tools\" />\n      <recent name=\"D:\\Desktop\\XNet\\models\\2d_networks\" />\n      <recent name=\"D:\\Desktop\\XNet\\tools\\CREMI\" />\n    </key>\n  </component>\n  <component name=\"RunDashboard\">\n    <option name=\"ruleStates\">\n      <list>\n        <RuleState>\n          <option name=\"name\" value=\"ConfigurationTypeDashboardGroupingRule\" />\n        </RuleState>\n        <RuleState>\n          <option name=\"name\" 
value=\"StatusDashboardGroupingRule\" />\n        </RuleState>\n      </list>\n    </option>\n  </component>\n  <component name=\"RunManager\" selected=\"Python.demo\">\n    <configuration name=\"demo\" type=\"PythonConfigurationType\" factoryName=\"Python\" temporary=\"true\">\n      <module name=\"XNet\" />\n      <option name=\"INTERPRETER_OPTIONS\" value=\"\" />\n      <option name=\"PARENT_ENVS\" value=\"true\" />\n      <envs>\n        <env name=\"PYTHONUNBUFFERED\" value=\"1\" />\n      </envs>\n      <option name=\"SDK_HOME\" value=\"\" />\n      <option name=\"WORKING_DIRECTORY\" value=\"$PROJECT_DIR$\" />\n      <option name=\"IS_MODULE_SDK\" value=\"true\" />\n      <option name=\"ADD_CONTENT_ROOTS\" value=\"true\" />\n      <option name=\"ADD_SOURCE_ROOTS\" value=\"true\" />\n      <EXTENSION ID=\"PythonCoverageRunConfigurationExtension\" runner=\"coverage.py\" />\n      <option name=\"SCRIPT_NAME\" value=\"$PROJECT_DIR$/demo.py\" />\n      <option name=\"PARAMETERS\" value=\"\" />\n      <option name=\"SHOW_COMMAND_LINE\" value=\"false\" />\n      <option name=\"EMULATE_TERMINAL\" value=\"false\" />\n      <option name=\"MODULE_MODE\" value=\"false\" />\n      <option name=\"REDIRECT_INPUT\" value=\"false\" />\n      <option name=\"INPUT_FILE\" value=\"\" />\n      <method v=\"2\" />\n    </configuration>\n    <configuration name=\"demo1\" type=\"PythonConfigurationType\" factoryName=\"Python\" temporary=\"true\">\n      <module name=\"XNet\" />\n      <option name=\"INTERPRETER_OPTIONS\" value=\"\" />\n      <option name=\"PARENT_ENVS\" value=\"true\" />\n      <envs>\n        <env name=\"PYTHONUNBUFFERED\" value=\"1\" />\n      </envs>\n      <option name=\"SDK_HOME\" value=\"\" />\n      <option name=\"WORKING_DIRECTORY\" value=\"$PROJECT_DIR$\" />\n      <option name=\"IS_MODULE_SDK\" value=\"true\" />\n      <option name=\"ADD_CONTENT_ROOTS\" value=\"true\" />\n      <option name=\"ADD_SOURCE_ROOTS\" value=\"true\" />\n      <EXTENSION 
ID=\"PythonCoverageRunConfigurationExtension\" runner=\"coverage.py\" />\n      <option name=\"SCRIPT_NAME\" value=\"$PROJECT_DIR$/demo1.py\" />\n      <option name=\"PARAMETERS\" value=\"\" />\n      <option name=\"SHOW_COMMAND_LINE\" value=\"false\" />\n      <option name=\"EMULATE_TERMINAL\" value=\"false\" />\n      <option name=\"MODULE_MODE\" value=\"false\" />\n      <option name=\"REDIRECT_INPUT\" value=\"false\" />\n      <option name=\"INPUT_FILE\" value=\"\" />\n      <method v=\"2\" />\n    </configuration>\n    <configuration name=\"hrnet\" type=\"PythonConfigurationType\" factoryName=\"Python\" temporary=\"true\">\n      <module name=\"XNet\" />\n      <option name=\"INTERPRETER_OPTIONS\" value=\"\" />\n      <option name=\"PARENT_ENVS\" value=\"true\" />\n      <envs>\n        <env name=\"PYTHONUNBUFFERED\" value=\"1\" />\n      </envs>\n      <option name=\"SDK_HOME\" value=\"\" />\n      <option name=\"WORKING_DIRECTORY\" value=\"$PROJECT_DIR$/models/networks_2d\" />\n      <option name=\"IS_MODULE_SDK\" value=\"true\" />\n      <option name=\"ADD_CONTENT_ROOTS\" value=\"true\" />\n      <option name=\"ADD_SOURCE_ROOTS\" value=\"true\" />\n      <EXTENSION ID=\"PythonCoverageRunConfigurationExtension\" runner=\"coverage.py\" />\n      <option name=\"SCRIPT_NAME\" value=\"$PROJECT_DIR$/models/networks_2d/hrnet.py\" />\n      <option name=\"PARAMETERS\" value=\"\" />\n      <option name=\"SHOW_COMMAND_LINE\" value=\"false\" />\n      <option name=\"EMULATE_TERMINAL\" value=\"false\" />\n      <option name=\"MODULE_MODE\" value=\"false\" />\n      <option name=\"REDIRECT_INPUT\" value=\"false\" />\n      <option name=\"INPUT_FILE\" value=\"\" />\n      <method v=\"2\" />\n    </configuration>\n    <configuration name=\"vis\" type=\"PythonConfigurationType\" factoryName=\"Python\" temporary=\"true\">\n      <module name=\"XNet\" />\n      <option name=\"INTERPRETER_OPTIONS\" value=\"\" />\n      <option name=\"PARENT_ENVS\" value=\"true\" />\n      
<envs>\n        <env name=\"PYTHONUNBUFFERED\" value=\"1\" />\n      </envs>\n      <option name=\"SDK_HOME\" value=\"\" />\n      <option name=\"WORKING_DIRECTORY\" value=\"$PROJECT_DIR$/tools/Atrial\" />\n      <option name=\"IS_MODULE_SDK\" value=\"true\" />\n      <option name=\"ADD_CONTENT_ROOTS\" value=\"true\" />\n      <option name=\"ADD_SOURCE_ROOTS\" value=\"true\" />\n      <EXTENSION ID=\"PythonCoverageRunConfigurationExtension\" runner=\"coverage.py\" />\n      <option name=\"SCRIPT_NAME\" value=\"$PROJECT_DIR$/tools/Atrial/vis.py\" />\n      <option name=\"PARAMETERS\" value=\"\" />\n      <option name=\"SHOW_COMMAND_LINE\" value=\"false\" />\n      <option name=\"EMULATE_TERMINAL\" value=\"false\" />\n      <option name=\"MODULE_MODE\" value=\"false\" />\n      <option name=\"REDIRECT_INPUT\" value=\"false\" />\n      <option name=\"INPUT_FILE\" value=\"\" />\n      <method v=\"2\" />\n    </configuration>\n    <configuration name=\"xnet\" type=\"PythonConfigurationType\" factoryName=\"Python\" temporary=\"true\">\n      <module name=\"XNet\" />\n      <option name=\"INTERPRETER_OPTIONS\" value=\"\" />\n      <option name=\"PARENT_ENVS\" value=\"true\" />\n      <envs>\n        <env name=\"PYTHONUNBUFFERED\" value=\"1\" />\n      </envs>\n      <option name=\"SDK_HOME\" value=\"\" />\n      <option name=\"WORKING_DIRECTORY\" value=\"$PROJECT_DIR$/models/networks_2d\" />\n      <option name=\"IS_MODULE_SDK\" value=\"true\" />\n      <option name=\"ADD_CONTENT_ROOTS\" value=\"true\" />\n      <option name=\"ADD_SOURCE_ROOTS\" value=\"true\" />\n      <EXTENSION ID=\"PythonCoverageRunConfigurationExtension\" runner=\"coverage.py\" />\n      <option name=\"SCRIPT_NAME\" value=\"$PROJECT_DIR$/models/networks_2d/xnet.py\" />\n      <option name=\"PARAMETERS\" value=\"\" />\n      <option name=\"SHOW_COMMAND_LINE\" value=\"false\" />\n      <option name=\"EMULATE_TERMINAL\" value=\"false\" />\n      <option name=\"MODULE_MODE\" value=\"false\" />\n      
<option name=\"REDIRECT_INPUT\" value=\"false\" />\n      <option name=\"INPUT_FILE\" value=\"\" />\n      <method v=\"2\" />\n    </configuration>\n    <recent_temporary>\n      <list>\n        <item itemvalue=\"Python.demo\" />\n        <item itemvalue=\"Python.hrnet\" />\n        <item itemvalue=\"Python.xnet\" />\n        <item itemvalue=\"Python.demo1\" />\n        <item itemvalue=\"Python.vis\" />\n      </list>\n    </recent_temporary>\n  </component>\n  <component name=\"SvnConfiguration\">\n    <configuration />\n  </component>\n  <component name=\"TaskManager\">\n    <task active=\"true\" id=\"Default\" summary=\"Default task\">\n      <changelist id=\"a6084e43-8a3d-4f65-bc86-b2baf8115879\" name=\"Default Changelist\" comment=\"\" />\n      <created>1655015922020</created>\n      <option name=\"number\" value=\"Default\" />\n      <option name=\"presentableId\" value=\"Default\" />\n      <updated>1655015922020</updated>\n      <workItem from=\"1655015923192\" duration=\"37990000\" />\n      <workItem from=\"1655341600741\" duration=\"188507000\" />\n      <workItem from=\"1657943138956\" duration=\"168310000\" />\n      <workItem from=\"1660184258941\" duration=\"265576000\" />\n      <workItem from=\"1661862792193\" duration=\"353353000\" />\n      <workItem from=\"1664521887562\" duration=\"3255000\" />\n      <workItem from=\"1665197785788\" duration=\"606000\" />\n      <workItem from=\"1665400082786\" duration=\"606000\" />\n      <workItem from=\"1665932248031\" duration=\"4039000\" />\n      <workItem from=\"1665987707938\" duration=\"6000\" />\n      <workItem from=\"1665987723627\" duration=\"1392000\" />\n      <workItem from=\"1665996697302\" duration=\"1393000\" />\n      <workItem from=\"1666072441496\" duration=\"745000\" />\n      <workItem from=\"1676863559005\" duration=\"498000\" />\n      <workItem from=\"1690253128423\" duration=\"1477000\" />\n      <workItem from=\"1690260775935\" duration=\"683000\" />\n      <workItem 
from=\"1690263220286\" duration=\"1342000\" />\n    </task>\n    <servers />\n  </component>\n  <component name=\"TypeScriptGeneratedFilesManager\">\n    <option name=\"version\" value=\"1\" />\n  </component>\n  <component name=\"Vcs.Log.Tabs.Properties\">\n    <option name=\"TAB_STATES\">\n      <map>\n        <entry key=\"MAIN\">\n          <value>\n            <State>\n              <option name=\"COLUMN_ORDER\" />\n            </State>\n          </value>\n        </entry>\n      </map>\n    </option>\n  </component>\n  <component name=\"com.intellij.coverage.CoverageDataManagerImpl\">\n    <SUITE FILE_PATH=\"coverage/XNet$wavelet_trans.coverage\" NAME=\"wavelet_trans Coverage Results\" MODIFIED=\"1656748230352\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/GlaS\" />\n    <SUITE FILE_PATH=\"coverage/XNet$res_image_mask.coverage\" NAME=\"res_image_mask Coverage Results\" MODIFIED=\"1663396533063\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools\" />\n    <SUITE FILE_PATH=\"coverage/XNet$wavelet_trans__2_.coverage\" NAME=\"wavelet_trans (2) Coverage Results\" MODIFIED=\"1657178663899\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/ISIC-2017\" />\n    <SUITE FILE_PATH=\"coverage/XNet$postprocess.coverage\" NAME=\"postprocess Coverage Results\" MODIFIED=\"1663946921242\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/Atrial\" />\n    <SUITE 
FILE_PATH=\"coverage/XNet$demo1.coverage\" NAME=\"demo1 Coverage Results\" MODIFIED=\"1664027395517\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$\" />\n    <SUITE FILE_PATH=\"coverage/XNet$postprocess__1_.coverage\" NAME=\"postprocess (1) Coverage Results\" MODIFIED=\"1663998882582\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/LiTS\" />\n    <SUITE FILE_PATH=\"coverage/XNet$xnet3d.coverage\" NAME=\"xnet3d Coverage Results\" MODIFIED=\"1662386984840\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_3d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$window_crop.coverage\" NAME=\"window_crop Coverage Results\" MODIFIED=\"1656508185058\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/CREMI\" />\n    <SUITE FILE_PATH=\"coverage/XNet$mask2sdf.coverage\" NAME=\"mask2sdf Coverage Results\" MODIFIED=\"1663901733624\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools\" />\n    <SUITE FILE_PATH=\"coverage/XNet$preprocess__2_.coverage\" NAME=\"preprocess (2) Coverage Results\" MODIFIED=\"1661672037231\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/CBMI-CA\" />\n 
   <SUITE FILE_PATH=\"coverage/XNet$wavelet3D.coverage\" NAME=\"wavelet3D Coverage Results\" MODIFIED=\"1662713970023\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools\" />\n    <SUITE FILE_PATH=\"coverage/XNet$split_train_val.coverage\" NAME=\"split_train_val Coverage Results\" MODIFIED=\"1662220162317\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/CBMI-CA\" />\n    <SUITE FILE_PATH=\"coverage/XNet$resunet_plusplus.coverage\" NAME=\"resunet_plusplus Coverage Results\" MODIFIED=\"1660993443784\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_2d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$unetr.coverage\" NAME=\"unetr Coverage Results\" MODIFIED=\"1661761635389\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_3d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$vnet_cct.coverage\" NAME=\"vnet_cct Coverage Results\" MODIFIED=\"1663554055273\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_3d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$mask_color.coverage\" NAME=\"mask_color Coverage Results\" MODIFIED=\"1655357975942\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" 
WORKING_DIRECTORY=\"$PROJECT_DIR$/tools\" />\n    <SUITE FILE_PATH=\"coverage/XNet$mean_std.coverage\" NAME=\"mean_std Coverage Results\" MODIFIED=\"1657175540046\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/CREMI\" />\n    <SUITE FILE_PATH=\"coverage/XNet$split_semi__1_.coverage\" NAME=\"split_semi (1) Coverage Results\" MODIFIED=\"1661613383049\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/Atrial\" />\n    <SUITE FILE_PATH=\"coverage/XNet$split_semi.coverage\" NAME=\"split_semi Coverage Results\" MODIFIED=\"1661940512952\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/ISIC-2017\" />\n    <SUITE FILE_PATH=\"coverage/XNet$16bitto8bit.coverage\" NAME=\"16bitto8bit Coverage Results\" MODIFIED=\"1655346328619\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools\" />\n    <SUITE FILE_PATH=\"coverage/XNet$demo.coverage\" NAME=\"demo Coverage Results\" MODIFIED=\"1664615715686\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$\" />\n    <SUITE FILE_PATH=\"coverage/XNet$unet3d_dtc.coverage\" NAME=\"unet3d_dtc Coverage Results\" MODIFIED=\"1663903842933\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" 
WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_3d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$nnunet.coverage\" NAME=\"nnunet Coverage Results\" MODIFIED=\"1661924361071\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_2d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$kiunet__1_.coverage\" NAME=\"kiunet (1) Coverage Results\" MODIFIED=\"1661915451342\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_2d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$mean_std__2_.coverage\" NAME=\"mean_std (2) Coverage Results\" MODIFIED=\"1656504773540\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/ISIC-2017\" />\n    <SUITE FILE_PATH=\"coverage/XNet$preprocess__1_.coverage\" NAME=\"preprocess (1) Coverage Results\" MODIFIED=\"1662224730108\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/LiTS\" />\n    <SUITE FILE_PATH=\"coverage/XNet$mask_pro__1_.coverage\" NAME=\"mask_pro (1) Coverage Results\" MODIFIED=\"1656640935024\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/ISIC-2017\" />\n    <SUITE FILE_PATH=\"coverage/XNet$bound.coverage\" NAME=\"bound Coverage Results\" MODIFIED=\"1661946279003\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" 
COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools\" />\n    <SUITE FILE_PATH=\"coverage/XNet$cotr.coverage\" NAME=\"cotr Coverage Results\" MODIFIED=\"1661845197933\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_3d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$vis.coverage\" NAME=\"vis Coverage Results\" MODIFIED=\"1664027158868\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/Atrial\" />\n    <SUITE FILE_PATH=\"coverage/XNet$vnet_dtc.coverage\" NAME=\"vnet_dtc Coverage Results\" MODIFIED=\"1663903484888\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_3d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$wavelet_trans__1_.coverage\" NAME=\"wavelet_trans (1) Coverage Results\" MODIFIED=\"1656899007333\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/CREMI\" />\n    <SUITE FILE_PATH=\"coverage/XNet$unet_3plus.coverage\" NAME=\"unet_3plus Coverage Results\" MODIFIED=\"1661000368219\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_2d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$nnunet3d.coverage\" NAME=\"nnunet3d Coverage Results\" MODIFIED=\"1661925089667\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" 
RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_3d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$split_train_val__1_.coverage\" NAME=\"split_train_val (1) Coverage Results\" MODIFIED=\"1662225861524\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/LiTS\" />\n    <SUITE FILE_PATH=\"coverage/XNet$unet_urpc.coverage\" NAME=\"unet_urpc Coverage Results\" MODIFIED=\"1659324102938\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models\" />\n    <SUITE FILE_PATH=\"coverage/XNet$u2net.coverage\" NAME=\"u2net Coverage Results\" MODIFIED=\"1661000154317\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_2d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$espnet3d.coverage\" NAME=\"espnet3d Coverage Results\" MODIFIED=\"1661921901877\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_3d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$transbts.coverage\" NAME=\"transbts Coverage Results\" MODIFIED=\"1661848844623\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_3d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$kiunet3d.coverage\" NAME=\"kiunet3d Coverage Results\" MODIFIED=\"1662030942646\" 
SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_3d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$mask_pro.coverage\" NAME=\"mask_pro Coverage Results\" MODIFIED=\"1656516186672\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/GlaS\" />\n    <SUITE FILE_PATH=\"coverage/XNet$unsup_split.coverage\" NAME=\"unsup_split Coverage Results\" MODIFIED=\"1655464356868\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools\" />\n    <SUITE FILE_PATH=\"coverage/XNet$dfmnet.coverage\" NAME=\"dfmnet Coverage Results\" MODIFIED=\"1661845257224\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_3d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$xnet_single_branch.coverage\" NAME=\"xnet_single_branch Coverage Results\" MODIFIED=\"1656514918760\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models\" />\n    <SUITE FILE_PATH=\"coverage/XNet$resunet.coverage\" NAME=\"resunet Coverage Results\" MODIFIED=\"1660993328964\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_2d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$conresnet.coverage\" NAME=\"conresnet Coverage 
Results\" MODIFIED=\"1662087140386\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_3d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$swinunet.coverage\" NAME=\"swinunet Coverage Results\" MODIFIED=\"1661924117865\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_2d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$xnet.coverage\" NAME=\"xnet Coverage Results\" MODIFIED=\"1664431291705\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_2d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$vae_seg.coverage\" NAME=\"vae_seg Coverage Results\" MODIFIED=\"1663923536933\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_3d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$split_dataset.coverage\" NAME=\"split_dataset Coverage Results\" MODIFIED=\"1662226197150\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools\" />\n    <SUITE FILE_PATH=\"coverage/XNet$vis__1_.coverage\" NAME=\"vis (1) Coverage Results\" MODIFIED=\"1663742524885\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/LiTS\" />\n    <SUITE FILE_PATH=\"coverage/XNet$preprocess.coverage\" 
NAME=\"preprocess Coverage Results\" MODIFIED=\"1662220017369\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/CBMI-CA\" />\n    <SUITE FILE_PATH=\"coverage/XNet$mean_std__1_.coverage\" NAME=\"mean_std (1) Coverage Results\" MODIFIED=\"1657175481052\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/tools/ISIC-2017\" />\n    <SUITE FILE_PATH=\"coverage/XNet$kiunet.coverage\" NAME=\"kiunet Coverage Results\" MODIFIED=\"1661933613759\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_2d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$hrnet.coverage\" NAME=\"hrnet Coverage Results\" MODIFIED=\"1664431322196\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models/networks_2d\" />\n    <SUITE FILE_PATH=\"coverage/XNet$unet_cct.coverage\" NAME=\"unet_cct Coverage Results\" MODIFIED=\"1659324460896\" SOURCE_PROVIDER=\"com.intellij.coverage.DefaultCoverageFileProvider\" RUNNER=\"coverage.py\" COVERAGE_BY_TEST_ENABLED=\"true\" COVERAGE_TRACING_ENABLED=\"false\" WORKING_DIRECTORY=\"$PROJECT_DIR$/models\" />\n  </component>\n</project>"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 Yanfeng Zhou\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "﻿\n# XNet: Wavelet-Based Low and High Frequency Merging Networks for Semi- and Supervised Semantic Segmentation of Biomedical Images\n\nThis is the official code of [XNet: Wavelet-Based Low and High Frequency Merging Networks for Semi- and Supervised Semantic Segmentation of Biomedical Images](https://openaccess.thecvf.com/content/ICCV2023/html/Zhou_XNet_Wavelet-Based_Low_and_High_Frequency_Fusion_Networks_for_Fully-_ICCV_2023_paper.html) (ICCV 2023).\n\n## Overview\n<p align=\"center\">\n<img src=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/figure/Architecture%20of%20XNet.png\" width=\"100%\" ></img>\n<br>Architecture of XNet.\n</p>\n<p align=\"center\">\n<img src=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/figure/visualize%20LF%20and%20HF%20images.png\" width=\"100%\" ></img>\n<br>Visualize dual-branch inputs. (a) Raw image. (b) Wavelet transform results. (c) Low frequency image. (d) High frequency image.\n</p>\n\n<p align=\"center\">\n<img src=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/figure/Architecture%20of%20LF%20and%20HF%20fusion%20module.png\" width=\"50%\" ></img>\n<br>Architecture of LF and HF fusion module.\n</p>\n\n\n## Quantitative Comparison\n\nComparison with fully- and semi-supervised state-of-the-art models on GlaS and CREMI test set. Semi-supervised models are based on UNet. DS indicates deep supervision. * indicates lightweight models. ‡ indicates training for 1000 epochs. - indicates training failed. <font color=\"Red\">**Red**</font> and **bold** indicate the best and second best performance.\n\n<p align=\"center\">\n<img src=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/figure/Comparison%20results%20on%20GlaS%20and%20CREMI.png\" width=\"100%\" >\n</p>\n\nComparison with fully- and semi-supervised state-of-the-art models on LA and LiTS test set. 
Due to GPU memory limitations, some semi-supervised models use smaller architectures; ✝ and * indicate models are based on lightweight 3D UNet (half of channels) and VNet, respectively. ‡ indicates training for 1000 epochs. - indicates training failed. <font color=\"Red\">**Red**</font> and **bold** indicate the best and second best performance.\n\n<p align=\"center\">\n<img src=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/figure/Comparison%20results%20on%20LA%20and%20P-CT.png\" width=\"100%\" >\n</p>\n\n## Qualitative Comparison\n\n<p align=\"center\">\n<img src=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/figure/Qualitative%20results.png\" width=\"100%\" >\n<br>Qualitative results on GlaS, CREMI, LA and LiTS. (a) Raw images. (b) Ground truth. (c) MT. (d) Semi-supervised XNet (3D XNet). (e) UNet (3D UNet). (f) Fully-Supervised XNet (3D XNet). The orange arrows highlight the differences among the results.\n</p>\n\n## Reimplemented Architecture\nWe have reimplemented some 2D and 3D models in semi- and supervised semantic segmentation.\n<table>\n<tr><th align=\"left\">Method</th> <th align=\"left\">Dimension</th><th align=\"left\">Model</th><th align=\"left\">Code</th></tr>\n<tr><td rowspan=\"23\">Supervised</td> <td rowspan=\"13\">2D</td><td>UNet</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/unet.py\">models/networks_2d/unet.py</a></td></tr>\n<tr><td>UNet++</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/unet_plusplus.py\">models/networks_2d/unet_plusplus.py</a></td></tr>\n<tr><td>Att-UNet</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/unet.py\">models/networks_2d/unet.py</a></td></tr>\n<tr><td>Aerial LaneNet</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/aerial_lanenet.py\">models/networks_2d/aerial_lanenet.py</a></td></tr>\n<tr><td>MWCNN</td><td><a 
href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/mwcnn.py\">models/networks_2d/mwcnn.py</a></td></tr>\n<tr><td>HRNet</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/hrnet.py\">models/networks_2d/hrnet.py</a></td></tr>\n<tr><td>Res-UNet</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/resunet.py\">models/networks_2d/resunet.py</a></td></tr>\n<tr><td>WDS</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/wds.py\">models/networks_2d/wds.py</a></td></tr>\n<tr><td>U<sup>2</sup>-Net</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/u2net.py\">models/networks_2d/u2net.py</a></td></tr>\n<tr><td>UNet 3+</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/unet_3plus.py\">models/networks_2d/unet_3plus.py</a></td></tr>\n<tr><td>SwinUNet</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/swinunet.py\">models/networks_2d/swinunet.py</a></td></tr>\n<tr><td>WaveSNet</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/wavesnet.py\">models/networks_2d/wavesnet.py</a></td></tr>\n<tr><td>XNet (Ours)</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/xnet.py\">models/networks_2d/xnet.py</a></td></tr>\n<tr><td rowspan=\"10\">3D</td><td>VNet</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/vnet.py\">models/networks_3d/vnet.py</a></td></tr>\n<tr><td>UNet 3D</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/unet3d.py\">models/networks_3d/unet3d.py</a></td></tr>\n<tr><td>Res-UNet 3D</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/res_unet3d.py\">models/networks_3d/res_unet3d.py</a></td></tr>\n<tr><td>ESPNet 3D</td><td><a 
href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/espnet3d.py\">models/networks_3d/espnet3d.py</a></td></tr>\n<tr><td>DMFNet 3D</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/dmfnet.py\">models/networks_3d/dmfnet.py</a></td></tr>\n<tr><td>ConResNet</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/conresnet.py\">models/networks_3d/conresnet.py</a></td></tr>\n<tr><td>CoTr</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/cotr.py\">models/networks_3d/cotr.py</a></td></tr>\n<tr><td>TransBTS</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/transbts.py\">models/networks_3d/transbts.py</a></td></tr>\n<tr><td>UNETR</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/unetr.py\">models/networks_3d/unetr.py</a></td></tr>\n<tr><td>XNet 3D (Ours)</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/xnet3d.py\">models/networks_3d/xnet3d.py</a></td></tr>\n<tr><td rowspan=\"17\">Semi-Supervised</td> <td rowspan=\"8\">2D</td><td>MT</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_MT.py\">train_semi_MT.py</a></td></tr>\n<tr><td>EM</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_EM.py\">train_semi_EM.py</a></td></tr>\n<tr><td>UAMT</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_UAMT.py\">train_semi_UAMT.py</a></td></tr>\n<tr><td>CCT</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_CCT.py\">train_semi_CCT.py</a></td></tr>\n<tr><td>CPS</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_CPS.py\">train_semi_CPS.py</a></td></tr>\n<tr><td>URPC</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_URPC.py\">train_semi_URPC.py</a></td></tr>\n<tr><td>CT</td><td><a 
href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_CT.py\">train_semi_CT.py</a></td></tr>\n<tr><td>XNet (Ours)</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_XNet.py\">train_semi_XNet.py</a></td></tr>\n<tr><td rowspan=\"9\">3D</td><td>MT</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_MT_3d.py\">train_semi_MT_3d.py</a></td></tr>\n<tr><td>EM</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_EM_3d.py\">train_semi_EM_3d.py</a></td></tr>\n<tr><td>UAMT</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_UAMT_3d.py\">train_semi_UAMT_3d.py</a></td></tr>\n<tr><td>CCT</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_CCT_3d.py\">train_semi_CCT_3d.py</a></td></tr>\n<tr><td>CPS</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_CPS_3d.py\">train_semi_CPS_3d.py</a></td></tr>\n<tr><td>URPC</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_URPC_3d.py\">train_semi_URPC_3d.py</a></td></tr>\n<tr><td>CT</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_CT_3d.py\">train_semi_CT_3d.py</a></td></tr>\n<tr><td>DTC</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_DTC.py\">train_semi_DTC.py</a></td></tr>\n<tr><td>XNet 3D (Ours)</td><td><a href=\"https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_XNet3d.py\">train_semi_XNet3d.py</a></td></tr>\n</table>\n\n## Requirements\n```\nalbumentations==0.5.2\neinops==0.4.1\nMedPy==0.4.0\nnumpy==1.20.2\nopencv_python==4.2.0.34\nopencv_python_headless==4.5.1.48\nPillow==8.0.0\nPyWavelets==1.1.1\nscikit_image==0.18.1\nscikit_learn==1.0.1\nscipy==1.4.1\nSimpleITK==2.1.0\ntimm==0.6.7\ntorch==1.8.0+cu111\ntorchio==0.18.53\ntorchvision==0.9.0+cu111\ntqdm==4.65.0\nvisdom==0.1.8.9\n```\n\n## Usage\n**Data preparation**\nYour datasets directory tree should look like this:\n>see 
[tools/wavelet2D.py](https://github.com/Yanfeng-Zhou/XNet/blob/main/tools/wavelet2D.py) and  [tools/wavelet3D.py](https://github.com/Yanfeng-Zhou/XNet/blob/main/tools/wavelet3D.py) for **L** and **H**\n```\ndataset\n├── train_sup_100\n    ├── L\n        ├── 1.tif\n        ├── 2.tif\n        └── ...\n    ├── H\n        ├── 1.tif\n        ├── 2.tif\n        └── ...\n    └── mask\n        ├── 1.tif\n        ├── 2.tif\n        └── ...\n├── train_sup_20\n    ├── L\n    ├── H\n    └── mask\n├── train_unsup_80\n    └── L\n    ├── H\n└── val\n    ├── L\n    ├── H\n    └── mask\n```\n**Supervised training**\n```\npython -m torch.distributed.launch --nproc_per_node=4 train_sup_XNet.py\n```\n**Semi-supervised training**\n```\npython -m torch.distributed.launch --nproc_per_node=4 train_semi_XNet.py\n```\n**Testing**\n```\npython -m torch.distributed.launch --nproc_per_node=4 test.py\n```\n\n## Citation\nIf our work is useful for your research, please cite our paper:\n```\n@InProceedings{Zhou_2023_ICCV,\n  author = {Zhou, Yanfeng and Huang, Jiaxing and Wang, Chenlong and Song, Le and Yang, Ge}, \n  title = {XNet: Wavelet-Based Low and High Frequency Fusion Networks for Fully- and Semi-Supervised Semantic Segmentation of Biomedical Images}, \n  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, \n  month = {October}, \n  year = {2023}, \n  pages = {21085-21096}\n  }\n```\n\n\n\n"
  },
  {
    "path": "config/__init__.py",
    "content": ""
  },
  {
    "path": "config/augmentation/__init__.py",
    "content": ""
  },
  {
    "path": "config/augmentation/online_aug.py",
    "content": "import albumentations as A\nfrom albumentations.pytorch import ToTensorV2\nfrom torchio import transforms as T\nimport torchio as tio\n\ndef data_transform_2d():\n    data_transforms = {\n        'train': A.Compose([\n            A.Resize(128, 128, p=1),\n            A.Flip(p=0.75),\n            A.Transpose(p=0.5),\n            A.RandomRotate90(p=1),\n        ],\n            additional_targets={'image2': 'image', 'mask2': 'mask'}\n        ),\n        'val': A.Compose([\n            A.Resize(128, 128, p=1),\n        ],\n            additional_targets={'image2': 'image', 'mask2': 'mask'}\n        ),\n        'test': A.Compose([\n            A.Resize(128, 128, p=1),\n        ],\n            additional_targets={'image2': 'image', 'mask2': 'mask'}\n        )\n    }\n    return data_transforms\n\n\ndef data_normalize_2d(mean, std):\n    data_normalize = A.Compose([\n            A.Normalize(mean, std),\n            ToTensorV2()\n        ],\n            additional_targets={'image2': 'image', 'mask2': 'mask'}\n    )\n    return data_normalize\n\ndef data_transform_aerial_lanenet(H, W):\n    data_transforms = A.Compose([\n            A.Resize(H, W, p=1),\n            ToTensorV2()\n        ])\n    return data_transforms\n\n\ndef data_transform_3d(normalization):\n    data_transform = {\n        'train': T.Compose([\n            T.RandomFlip(),\n            T.RandomBiasField(coefficients=(0.12, 0.15), order=2, p=0.2),\n            T.OneOf({\n               T.RandomNoise(): 0.5,\n               T.RandomBlur(std=1): 0.5,\n            }, p=0.2),\n            T.ZNormalization(masking_method=normalization),\n        ]),\n        'val': T.Compose([\n            # T.CropOrPad(pad_size),\n            T.ZNormalization(masking_method=normalization),\n            # T.Resize(target_shape=(512, 512, 512), p=1)\n        ]),\n        'test': T.Compose([\n            # T.CropOrPad(pad_size),\n            T.ZNormalization(masking_method=normalization),\n            # 
T.Resize(target_shape=(512, 512, 512), p=1)\n        ])\n    }\n\n    return data_transform"
  },
  {
    "path": "config/dataset_config/__init__.py",
    "content": ""
  },
  {
    "path": "config/dataset_config/dataset_cfg.py",
    "content": "import numpy as np\nimport torchio as tio\n\ndef dataset_cfg(dataet_name):\n\n    config = {\n        'CREMI':\n            {\n                'IN_CHANNELS': 1,\n                'NUM_CLASSES': 2,\n                'MEAN': [0.503902],\n                'STD': [0.110739],\n                'MEAN_DB2_H': [0.505787],\n                'STD_DB2_H': [0.115504],\n                'PALETTE': list(np.array([\n                    [255, 255, 255],\n                    [0, 0, 0],\n                ]).flatten())\n            },\n        'GlaS':\n            {\n                'IN_CHANNELS': 3,\n                'NUM_CLASSES': 2,\n                'MEAN': [0.787803, 0.512017, 0.784938],\n                'STD': [0.428206, 0.507778, 0.426366],\n                'MEAN_HAAR_H': [0.528318],\n                'STD_HAAR_H': [0.076766],\n                'MEAN_HAAR_L': [0.579144],\n                'STD_HAAR_L': [0.227451],\n                'MEAN_HAAR_HHL': [0.542428],\n                'STD_HAAR_HHL': [0.142663],\n                'MEAN_HAAR_HLL': [0.569150],\n                'STD_HAAR_HLL': [0.220854],\n                'MEAN_BIOR1.5_H': [0.525711],\n                'STD_BIOR1.5_H': [0.076606],\n                'MEAN_BIOR2.4_H': [0.516579],\n                'STD_BIOR2.4_H': [0.078798],\n                'MEAN_COIF1_H': [0.523858],\n                'STD_COIF1_H': [0.081001],\n                'MEAN_DB2_H': [0.505234],\n                'STD_DB2_H': [0.080919],\n                'MEAN_DMEY_H': [0.502698],\n                'STD_DMEY_H': [0.078861],\n                'PALETTE': list(np.array([\n                    [0, 0, 0],\n                    [255, 255, 255],\n                ]).flatten())\n            },\n        'ISIC-2017':\n            {\n                'IN_CHANNELS': 3,\n                'NUM_CLASSES': 2,\n                'MEAN': [0.699002, 0.556046, 0.512134],\n                'STD': [0.365650, 0.317347, 0.339400],\n                'MEAN_DB2_H': [0.489676],\n                
'STD_DB2_H': [0.081749],\n                'PALETTE': list(np.array([\n                    [0, 0, 0],\n                    [255, 255, 255],\n                ]).flatten())\n            },\n        'LiTS':\n            {\n                'IN_CHANNELS': 1,\n                'NUM_CLASSES': 3,\n                'NORMALIZE': tio.ZNormalization.mean,\n                'PATCH_SIZE': (112, 112, 32),\n                'FORMAT': '.nii',\n                'NUM_SAMPLE_TRAIN': 8,\n                'NUM_SAMPLE_VAL': 12\n            },\n        'Atrial':\n            {\n                'IN_CHANNELS': 1,\n                'NUM_CLASSES': 2,\n                'NORMALIZE': tio.ZNormalization.mean,\n                'PATCH_SIZE': (96, 96, 80),\n                'FORMAT': '.nrrd',\n                'NUM_SAMPLE_TRAIN': 4,\n                'NUM_SAMPLE_VAL': 8\n            },\n    }\n\n    return config[dataet_name]\n"
  },
  {
    "path": "config/eval_config/__init__.py",
    "content": ""
  },
  {
    "path": "config/eval_config/eval.py",
    "content": "import numpy as np\nfrom sklearn.metrics import confusion_matrix\nfrom scipy.spatial.distance import directed_hausdorff\nimport torch\n\n\ndef evaluate(y_scores, y_true, interval=0.02):\n\n    y_scores = torch.softmax(y_scores, dim=1)\n    y_scores = y_scores[:, 1, ...].cpu().detach().numpy().flatten()\n    y_true = y_true.data.cpu().numpy().flatten()\n\n    thresholds = np.arange(0, 0.9, interval)\n    jaccard = np.zeros(len(thresholds))\n    dice = np.zeros(len(thresholds))\n    y_true.astype(np.int8)\n\n    for indy in range(len(thresholds)):\n        threshold = thresholds[indy]\n        y_pred = (y_scores > threshold).astype(np.int8)\n\n        sum_area = (y_pred + y_true)\n        tp = float(np.sum(sum_area == 2))\n        union = np.sum(sum_area == 1)\n        jaccard[indy] = tp / float(union + tp)\n        dice[indy] = 2 * tp / float(union + 2 * tp)\n\n    thred_indx = np.argmax(jaccard)\n    m_jaccard = jaccard[thred_indx]\n    m_dice = dice[thred_indx]\n\n    return thresholds[thred_indx], m_jaccard, m_dice\n\n\n\ndef evaluate_multi(y_scores, y_true):\n\n    y_scores = torch.softmax(y_scores, dim=1)\n    y_pred = torch.max(y_scores, 1)[1]\n    y_pred = y_pred.data.cpu().numpy().flatten()\n    y_true = y_true.data.cpu().numpy().flatten()\n\n    hist = confusion_matrix(y_true, y_pred)\n\n    hist_diag = np.diag(hist)\n    hist_sum_0 = hist.sum(axis=0)\n    hist_sum_1 = hist.sum(axis=1)\n\n    jaccard = hist_diag / (hist_sum_1 + hist_sum_0 - hist_diag)\n    m_jaccard = np.nanmean(jaccard)\n    dice = 2 * hist_diag / (hist_sum_1 + hist_sum_0)\n    m_dice = np.nanmean(dice)\n\n    return jaccard, m_jaccard, dice, m_dice\n\n\n\n\n"
  },
  {
    "path": "config/ramps/__init__.py",
    "content": ""
  },
  {
    "path": "config/ramps/ramps.py",
    "content": "import numpy as np\n\n\ndef sigmoid_rampup(current, rampup_length):\n    \"\"\"Exponential rampup from https://arxiv.org/abs/1610.02242\"\"\"\n    if rampup_length == 0:\n        return 1.0\n    else:\n        current = np.clip(current, 0.0, rampup_length)\n        phase = 1.0 - current / rampup_length\n        return float(np.exp(-5.0 * phase * phase))\n\n\ndef linear_rampup(current, rampup_length):\n    \"\"\"Linear rampup\"\"\"\n    assert current >= 0 and rampup_length >= 0\n    if current >= rampup_length:\n        return 1.0\n    else:\n        return current / rampup_length\n\n\ndef cosine_rampdown(current, rampdown_length):\n    \"\"\"Cosine rampdown from https://arxiv.org/abs/1608.03983\"\"\"\n    assert 0 <= current <= rampdown_length\n    return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))\n"
  },
  {
    "path": "config/train_test_config/__init__.py",
    "content": ""
  },
  {
    "path": "config/train_test_config/train_test_config.py",
    "content": "import numpy as np\nfrom config.eval_config.eval import evaluate, evaluate_multi\nimport torch\nimport os\nfrom PIL import Image\nimport torchio as tio\n\ndef print_train_loss_sup(train_loss, num_batches, print_num, print_num_minus):\n    train_epoch_loss = train_loss / num_batches['train_sup']\n    print('-' * print_num)\n    print('| Train Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_minus, ' '), '|')\n    print('-' * print_num)\n    return train_epoch_loss\n\ndef print_train_loss_MT(train_loss_sup_1, train_loss_cps, train_loss, num_batches, print_num, print_num_half, print_num_minus):\n    train_epoch_loss_sup1 = train_loss_sup_1 / num_batches['train_sup']\n    train_epoch_loss_cps = train_loss_cps / num_batches['train_sup']\n    train_epoch_loss = train_loss / num_batches['train_sup']\n    print('-' * print_num)\n    print('| Train  Sup  Loss: {:.4f}'.format(train_epoch_loss_sup1).ljust(print_num_half, ' '), '| Train Unsup Loss: {:.4f}'.format(train_epoch_loss_cps).ljust(print_num_half, ' '), '|')\n    print('| Train Total Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_minus, ' '), '|')\n    print('-' * print_num)\n    return train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss\n\ndef print_train_loss_ConResNet(train_loss_seg, train_loss_res, train_loss, num_batches, print_num, print_num_half, print_num_minus):\n    train_epoch_loss_seg = train_loss_seg / num_batches['train_sup']\n    train_epoch_loss_res = train_loss_res / num_batches['train_sup']\n    train_epoch_loss = train_loss / num_batches['train_sup']\n    print('-' * print_num)\n    print('| Train  Seg  Loss: {:.4f}'.format(train_epoch_loss_seg).ljust(print_num_half, ' '), '| Train Res Loss: {:.4f}'.format(train_epoch_loss_res).ljust(print_num_half, ' '), '|')\n    print('| Train Total Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_minus, ' '), '|')\n    print('-' * print_num)\n    return train_epoch_loss_seg, train_epoch_loss_res, 
train_epoch_loss\n\n\ndef print_train_loss_EM(train_loss_sup_1, train_loss_cps, train_loss, num_batches, print_num, print_num_minus):\n    train_epoch_loss_sup1 = train_loss_sup_1 / num_batches['train_sup']\n    train_epoch_loss_cps = train_loss_cps / num_batches['train_sup']\n    train_epoch_loss = train_loss / num_batches['train_sup']\n    print('-' * print_num)\n    print('| Train  Sup  Loss: {:.4f}'.format(train_epoch_loss_sup1).ljust(print_num_minus, ' '), '|')\n    print('| Train Unsup Loss: {:.4f}'.format(train_epoch_loss_cps).ljust(print_num_minus, ' '), '|')\n    print('| Train Total Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_minus, ' '), '|')\n    print('-' * print_num)\n    return train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss\n\n\ndef print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_cps, train_loss, num_batches, print_num, print_num_half):\n    train_epoch_loss_sup1 = train_loss_sup_1 / num_batches['train_sup']\n    train_epoch_loss_sup2 = train_loss_sup_2 / num_batches['train_sup']\n    train_epoch_loss_cps = train_loss_cps / num_batches['train_sup']\n    train_epoch_loss = train_loss / num_batches['train_sup']\n    print('-' * print_num)\n    print('| Train Sup Loss 1: {:.4f}'.format(train_epoch_loss_sup1).ljust(print_num_half, ' '), '| Train SUP Loss 2: {:.4f}'.format(train_epoch_loss_sup2).ljust(print_num_half, ' '), '|')\n    print('| Train Unsup Loss: {:.4f}'.format(train_epoch_loss_cps).ljust(print_num_half, ' '), '| Train Total Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_half, ' '), '|')\n    print('-' * print_num)\n    return train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss\n\ndef print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus):\n    val_epoch_loss = val_loss / num_batches['val']\n    print('-' * print_num)\n    print('| Val Loss: {:.4f}'.format(val_epoch_loss).ljust(print_num_minus, ' '), '|')\n    print('-' * print_num)\n    
return val_epoch_loss\n\ndef print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half):\n    val_epoch_loss_sup1 = val_loss_sup_1 / num_batches['val']\n    val_epoch_loss_sup2 = val_loss_sup_2 / num_batches['val']\n    print('-' * print_num)\n    print('| Val Sup Loss 1: {:.4f}'.format(val_epoch_loss_sup1).ljust(print_num_half, ' '), '| Val Sup Loss 2: {:.4f}'.format(val_epoch_loss_sup2).ljust(print_num_half, ' '), '|')\n    print('-' * print_num)\n    return val_epoch_loss_sup1, val_epoch_loss_sup2\n\ndef print_val_loss_ConResNet(val_loss_seg, val_loss_res, num_batches, print_num, print_num_half):\n    val_epoch_loss_seg = val_loss_seg / num_batches['val']\n    val_epoch_loss_res = val_loss_res / num_batches['val']\n    print('-' * print_num)\n    print('| Val Seg Loss: {:.4f}'.format(val_epoch_loss_seg).ljust(print_num_half, ' '), '| Val Res Loss: {:.4f}'.format(val_epoch_loss_res).ljust(print_num_half, ' '), '|')\n    print('-' * print_num)\n    return val_epoch_loss_seg, val_epoch_loss_res\n\ndef print_train_eval_sup(num_classes, score_list_train, mask_list_train, print_num):\n\n    if num_classes == 2:\n        eval_list = evaluate(score_list_train, mask_list_train)\n        print('| Train Thr: {:.4f}'.format(eval_list[0]).ljust(print_num, ' '), '|')\n        print('| Train  Jc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')\n        print('| Train  Dc: {:.4f}'.format(eval_list[2]).ljust(print_num, ' '), '|')\n        train_m_jc = eval_list[1]\n\n    else:\n        eval_list = evaluate_multi(score_list_train, mask_list_train)\n\n        np.set_printoptions(precision=4, suppress=True)\n        print('| Train  Jc: {}'.format(eval_list[0]).ljust(print_num, ' '), '|')\n        print('| Train  Dc: {}'.format(eval_list[2]).ljust(print_num, ' '), '|')\n        print('| Train mJc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')\n        print('| Train mDc: {:.4f}'.format(eval_list[3]).ljust(print_num, ' '), '|')\n        
train_m_jc = eval_list[1]\n\n    return eval_list, train_m_jc\n\ndef print_train_eval_XNet(num_classes, score_list_train1, score_list_train2, mask_list_train, print_num):\n\n    if num_classes == 2:\n        eval_list1 = evaluate(score_list_train1, mask_list_train)\n        eval_list2 = evaluate(score_list_train2, mask_list_train)\n        print('| Train Thr 1: {:.4f}'.format(eval_list1[0]).ljust(print_num, ' '), '| Train Thr 2: {:.4f}'.format(eval_list2[0]).ljust(print_num, ' '), '|')\n        print('| Train  Jc 1: {:.4f}'.format(eval_list1[1]).ljust(print_num, ' '), '| Train  Jc 2: {:.4f}'.format(eval_list2[1]).ljust(print_num, ' '), '|')\n        print('| Train  Dc 1: {:.4f}'.format(eval_list1[2]).ljust(print_num, ' '), '| Train  Dc 2: {:.4f}'.format(eval_list2[2]).ljust(print_num, ' '), '|')\n        train_m_jc1 = eval_list1[1]\n        train_m_jc2 = eval_list2[1]\n    else:\n        eval_list1 = evaluate_multi(score_list_train1, mask_list_train)\n        eval_list2 = evaluate_multi(score_list_train2, mask_list_train)\n        np.set_printoptions(precision=4, suppress=True)\n        print('| Train  Jc 1: {}'.format(eval_list1[0]).ljust(print_num, ' '), '| Train  Jc 2: {}'.format(eval_list2[0]).ljust(print_num, ' '), '|')\n        print('| Train  Dc 1: {}'.format(eval_list1[2]).ljust(print_num, ' '), '| Train  Dc 2: {}'.format(eval_list2[2]).ljust(print_num, ' '), '|')\n        print('| Train mJc 1: {:.4f}'.format(eval_list1[1]).ljust(print_num, ' '), '| Train mJc 2: {:.4f}'.format(eval_list2[1]).ljust(print_num, ' '), '|')\n        print('| Train mDc 1: {:.4f}'.format(eval_list1[3]).ljust(print_num, ' '), '| Train mDc 2: {:.4f}'.format(eval_list2[3]).ljust(print_num, ' '), '|')\n        train_m_jc1 = eval_list1[1]\n        train_m_jc2 = eval_list2[1]\n\n    return eval_list1, eval_list2, train_m_jc1, train_m_jc2\n\ndef print_val_eval_sup(num_classes, score_list_val, mask_list_val, print_num):\n    if num_classes == 2:\n        eval_list = 
evaluate(score_list_val, mask_list_val)\n        print('| Val Thr: {:.4f}'.format(eval_list[0]).ljust(print_num, ' '), '|')\n        print('| Val  Jc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')\n        print('| Val  Dc: {:.4f}'.format(eval_list[2]).ljust(print_num, ' '), '|')\n        val_m_jc = eval_list[1]\n    else:\n        eval_list = evaluate_multi(score_list_val, mask_list_val)\n        np.set_printoptions(precision=4, suppress=True)\n        print('| Val  Jc: {}  '.format(eval_list[0]).ljust(print_num, ' '), '|')\n        print('| Val  Dc: {}  '.format(eval_list[2]).ljust(print_num, ' '), '|')\n        print('| Val mJc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')\n        print('| Val mDc: {:.4f}'.format(eval_list[3]).ljust(print_num, ' '), '|')\n        val_m_jc = eval_list[1]\n    return eval_list, val_m_jc\n\ndef print_val_eval(num_classes, score_list_val1, score_list_val2, mask_list_val, print_num):\n    if num_classes == 2:\n        eval_list1 = evaluate(score_list_val1, mask_list_val)\n        eval_list2 = evaluate(score_list_val2, mask_list_val)\n        print('| Val Thr 1: {:.4f}'.format(eval_list1[0]).ljust(print_num, ' '), '| Val Thr 2: {:.4f}'.format(eval_list2[0]).ljust(print_num, ' '), '|')\n        print('| Val  Jc 1: {:.4f}'.format(eval_list1[1]).ljust(print_num, ' '), '| Val  Jc 2: {:.4f}'.format(eval_list2[1]).ljust(print_num, ' '), '|')\n        print('| Val  Dc 1: {:.4f}'.format(eval_list1[2]).ljust(print_num, ' '), '| Val  Dc 2: {:.4f}'.format(eval_list2[2]).ljust(print_num, ' '), '|')\n        val_m_jc1 = eval_list1[1]\n        val_m_jc2 = eval_list2[1]\n    else:\n        eval_list1 = evaluate_multi(score_list_val1, mask_list_val)\n        eval_list2 = evaluate_multi(score_list_val2, mask_list_val)\n        np.set_printoptions(precision=4, suppress=True)\n        print('| Val  Jc 1: {}  '.format(eval_list1[0]).ljust(print_num, ' '), '| Val  Jc 2: {}'.format(eval_list2[0]).ljust(print_num, ' '), '|')\n        
print('| Val  Dc 1: {}  '.format(eval_list1[2]).ljust(print_num, ' '), '| Val  Dc 2: {}'.format(eval_list2[2]).ljust(print_num, ' '), '|')\n        print('| Val mJc 1: {:.4f}'.format(eval_list1[1]).ljust(print_num, ' '), '| Val mJc 2: {:.4f}'.format(eval_list2[1]).ljust(print_num, ' '), '|')\n        print('| Val mDc 1: {:.4f}'.format(eval_list1[3]).ljust(print_num, ' '), '| Val mDc 2: {:.4f}'.format(eval_list2[3]).ljust(print_num, ' '), '|')\n        val_m_jc1 = eval_list1[1]\n        val_m_jc2 = eval_list2[1]\n    return eval_list1, eval_list2, val_m_jc1, val_m_jc2\n\ndef save_val_best_sup_2d(num_classes, best_list, model, score_list_val, name_list_val, eval_list, path_trained_model, path_seg_results, palette, model_name):\n\n    if num_classes == 2:\n        if best_list[1] < eval_list[1]:\n            best_list = eval_list\n\n            torch.save(model.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format(model_name, best_list[1])))\n\n            score_list_val = torch.softmax(score_list_val, dim=1)\n            pred_results = score_list_val[:, 1, :, :].cpu().numpy()\n            pred_results[pred_results > eval_list[0]] = 1\n            pred_results[pred_results <= eval_list[0]] = 0\n\n            assert len(name_list_val) == pred_results.shape[0]\n            for i in range(len(name_list_val)):\n                color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')\n                color_results.putpalette(palette)\n                color_results.save(os.path.join(path_seg_results, name_list_val[i]))\n\n    else:\n        if best_list[1] < eval_list[1]:\n            best_list = eval_list\n\n            torch.save(model.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format(model_name, best_list[1])))\n\n            pred_results = torch.max(score_list_val, 1)[1]\n            pred_results = pred_results.cpu().numpy()\n\n            assert len(name_list_val) == pred_results.shape[0]\n          
  for i in range(len(name_list_val)):\n                color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')\n                color_results.putpalette(palette)\n                color_results.save(os.path.join(path_seg_results, name_list_val[i]))\n\n    return best_list\n\ndef save_val_best_sup_3d(num_classes, best_list, model, score_list_val, mask_list_val, eval_list, path_trained_model, path_seg_results, path_mask_results, model_name, format):\n\n    if num_classes == 2:\n        if best_list[1] < eval_list[1]:\n            best_list = eval_list\n\n            torch.save(model.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format(model_name, best_list[1])))\n\n    else:\n        if best_list[1] < eval_list[1]:\n            best_list = eval_list\n\n            torch.save(model.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format(model_name, best_list[1])))\n\n    return best_list\n\ndef save_val_best_2d(num_classes, best_model, best_list, best_result, model1, model2, score_list_val_1, score_list_val_2, name_list_val, eval_list_1, eval_list_2, path_trained_model, path_seg_results, palette):\n\n    if eval_list_1[1] < eval_list_2[1]:\n        if best_list[1] < eval_list_2[1]:\n\n            best_model = model2\n            best_list = eval_list_2\n            best_result = 'Result2'\n\n            torch.save(model2.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format('result2', best_list[1])))\n\n            if num_classes == 2:\n                score_list_val_2 = torch.softmax(score_list_val_2, dim=1)\n                pred_results = score_list_val_2[:, 1, ...].cpu().numpy()\n                pred_results[pred_results > eval_list_2[0]] = 1\n                pred_results[pred_results <= eval_list_2[0]] = 0\n            else:\n                pred_results = torch.max(score_list_val_2, 1)[1]\n                pred_results = pred_results.cpu().numpy()\n\n            assert 
len(name_list_val) == pred_results.shape[0]\n            for i in range(len(name_list_val)):\n                color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')\n                color_results.putpalette(palette)\n                color_results.save(os.path.join(path_seg_results, name_list_val[i]))\n        else:\n            best_model = best_model\n            best_list = best_list\n            best_result = best_result\n\n    else:\n        if best_list[1] < eval_list_1[1]:\n\n            best_model = model1\n            best_list = eval_list_1\n            best_result = 'Result1'\n\n            torch.save(model1.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format('result1', best_list[1])))\n\n            if num_classes == 2:\n                score_list_val_1 = torch.softmax(score_list_val_1, dim=1)\n                pred_results = score_list_val_1[:, 1, ...].cpu().numpy()\n                pred_results[pred_results > eval_list_1[0]] = 1\n                pred_results[pred_results <= eval_list_1[0]] = 0\n            else:\n                pred_results = torch.max(score_list_val_1, 1)[1]\n                pred_results = pred_results.cpu().numpy()\n\n            assert len(name_list_val) == pred_results.shape[0]\n            for i in range(len(name_list_val)):\n                color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')\n                color_results.putpalette(palette)\n                color_results.save(os.path.join(path_seg_results, name_list_val[i]))\n        else:\n            best_model = best_model\n            best_list = best_list\n            best_result = best_result\n\n\n    return best_list, best_model, best_result\n\n\ndef save_val_best_3d(num_classes, best_model, best_list, best_result, model1, model2, score_list_val_1, score_list_val_2, mask_list_val, eval_list_1, eval_list_2, path_trained_model, path_seg_results, path_mask_results, format):\n\n    if eval_list_1[1] < 
eval_list_2[1]:\n        if best_list[1] < eval_list_2[1]:\n\n            best_model = model2\n            best_list = eval_list_2\n            best_result = 'Result2'\n\n            torch.save(model2.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format('result2', best_list[1])))\n\n        else:\n            best_model = best_model\n            best_list = best_list\n            best_result = best_result\n\n    else:\n        if best_list[1] < eval_list_1[1]:\n\n            best_model = model1\n            best_list = eval_list_1\n            best_result = 'Result1'\n\n            torch.save(model1.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format('result1', best_list[1])))\n\n        else:\n            best_model = best_model\n            best_list = best_list\n            best_result = best_result\n\n    return best_list, best_model, best_result\n\ndef draw_pred_sup(num_classes, mask_train_sup, mask_val, pred_train_sup, outputs_val, train_eval_list, val_eval_list):\n\n\n    mask_image_train_sup = mask_train_sup[0, :, :].data.cpu().numpy()\n    mask_image_val = mask_val[0, :, :].data.cpu().numpy()\n\n    if num_classes == 2:\n        pred_image_train_sup = pred_train_sup[0, 1, :, :].data.cpu().numpy()\n        pred_image_train_sup[pred_image_train_sup > train_eval_list[0]] = 1\n        pred_image_train_sup[pred_image_train_sup <= train_eval_list[0]] = 0\n\n        pred_image_val = outputs_val[0, 1, :, :].data.cpu().numpy()\n        pred_image_val[pred_image_val > val_eval_list[0]] = 1\n        pred_image_val[pred_image_val <= val_eval_list[0]] = 0\n\n    else:\n        pred_image_train_sup = torch.max(pred_train_sup, 1)[1]\n        pred_image_train_sup = pred_image_train_sup[0, :, :].cpu().numpy()\n\n        pred_image_val = torch.max(outputs_val, 1)[1]\n        pred_image_val = pred_image_val[0, :, :].cpu().numpy()\n\n    return mask_image_train_sup, pred_image_train_sup, mask_image_val, 
pred_image_val\n\n\ndef draw_pred_XNet(num_classes, mask_train, mask_val, pred_train_sup1, pred_train_sup2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2):\n\n\n    mask_image_train_sup = mask_train[0, :, :].data.cpu().numpy()\n    mask_image_val = mask_val[0, :, :].data.cpu().numpy()\n\n    if num_classes == 2:\n\n        pred_image_train_sup1 = pred_train_sup1[0, 1, :, :].data.cpu().numpy()\n        pred_image_train_sup1[pred_image_train_sup1 > train_eval_list1[0]] = 1\n        pred_image_train_sup1[pred_image_train_sup1 <= train_eval_list1[0]] = 0\n\n        pred_image_train_sup2 = pred_train_sup2[0, 1, :, :].data.cpu().numpy()\n        pred_image_train_sup2[pred_image_train_sup2 > train_eval_list2[0]] = 1\n        pred_image_train_sup2[pred_image_train_sup2 <= train_eval_list2[0]] = 0\n\n        pred_image_val1 = outputs_val1[0, 1, :, :].data.cpu().numpy()\n        pred_image_val1[pred_image_val1 > val_eval_list1[0]] = 1\n        pred_image_val1[pred_image_val1 <= val_eval_list1[0]] = 0\n\n        pred_image_val2 = outputs_val2[0, 1, :, :].data.cpu().numpy()\n        pred_image_val2[pred_image_val2 > val_eval_list2[0]] = 1\n        pred_image_val2[pred_image_val2 <= val_eval_list2[0]] = 0\n    else:\n\n        pred_image_train_sup1 = torch.max(pred_train_sup1, 1)[1]\n        pred_image_train_sup1 = pred_image_train_sup1[0, :, :].cpu().numpy()\n\n        pred_image_train_sup2 = torch.max(pred_train_sup2, 1)[1]\n        pred_image_train_sup2 = pred_image_train_sup2[0, :, :].cpu().numpy()\n\n        pred_image_val1 = torch.max(outputs_val1, 1)[1]\n        pred_image_val1 = pred_image_val1[0, :, :].cpu().numpy()\n\n        pred_image_val2 = torch.max(outputs_val2, 1)[1]\n        pred_image_val2 = pred_image_val2[0, :, :].cpu().numpy()\n\n    return mask_image_train_sup, pred_image_train_sup1, pred_image_train_sup2, mask_image_val, pred_image_val1, pred_image_val2\n\ndef draw_pred_MT(num_classes, mask_train, mask_val, 
pred_train_sup1, outputs_val1, outputs_val2, train_eval_list1, val_eval_list1, val_eval_list2):\n\n\n    mask_image_train_sup = mask_train[0, :, :].data.cpu().numpy()\n    mask_image_val = mask_val[0, :, :].data.cpu().numpy()\n\n    if num_classes == 2:\n\n        pred_image_train_sup1 = pred_train_sup1[0, 1, :, :].data.cpu().numpy()\n        pred_image_train_sup1[pred_image_train_sup1 > train_eval_list1[0]] = 1\n        pred_image_train_sup1[pred_image_train_sup1 <= train_eval_list1[0]] = 0\n\n        pred_image_val1 = outputs_val1[0, 1, :, :].data.cpu().numpy()\n        pred_image_val1[pred_image_val1 > val_eval_list1[0]] = 1\n        pred_image_val1[pred_image_val1 <= val_eval_list1[0]] = 0\n\n        pred_image_val2 = outputs_val2[0, 1, :, :].data.cpu().numpy()\n        pred_image_val2[pred_image_val2 > val_eval_list2[0]] = 1\n        pred_image_val2[pred_image_val2 <= val_eval_list2[0]] = 0\n    else:\n\n        pred_image_train_sup1 = torch.max(pred_train_sup1, 1)[1]\n        pred_image_train_sup1 = pred_image_train_sup1[0, :, :].cpu().numpy()\n\n        pred_image_val1 = torch.max(outputs_val1, 1)[1]\n        pred_image_val1 = pred_image_val1[0, :, :].cpu().numpy()\n\n        pred_image_val2 = torch.max(outputs_val2, 1)[1]\n        pred_image_val2 = pred_image_val2[0, :, :].cpu().numpy()\n\n    return mask_image_train_sup, pred_image_train_sup1, mask_image_val, pred_image_val1, pred_image_val2\n\n\ndef print_best_sup(num_classes, best_val_list, print_num):\n    if num_classes == 2:\n        print('| Best Val Thr: {:.4f}'.format(best_val_list[0]).ljust(print_num, ' '), '|')\n        print('| Best Val  Jc: {:.4f}'.format(best_val_list[1]).ljust(print_num, ' '), '|')\n        print('| Best Val  Dc: {:.4f}'.format(best_val_list[2]).ljust(print_num, ' '), '|')\n    else:\n        np.set_printoptions(precision=4, suppress=True)\n        print('| Best Val  Jc: {}'.format(best_val_list[0]).ljust(print_num, ' '), '|')\n        print('| Best Val  Dc: 
{}'.format(best_val_list[2]).ljust(print_num, ' '), '|')\n        print('| Best Val mJc: {:.4f}'.format(best_val_list[1]).ljust(print_num, ' '), '|')\n        print('| Best Val mDc: {:.4f}'.format(best_val_list[3]).ljust(print_num, ' '), '|')\n\ndef print_best(num_classes, best_val_list, best_model, best_result, path_trained_model, print_num):\n    if num_classes == 2:\n\n        torch.save(best_model.state_dict(), os.path.join(path_trained_model, 'best_Jc_{:.4f}.pth'.format(best_val_list[1])))\n\n        print('| Best  Result: {}'.format(best_result).ljust(print_num, ' '), '|')\n        print('| Best Val Thr: {:.4f}'.format(best_val_list[0]).ljust(print_num, ' '), '|')\n        print('| Best Val  Jc: {:.4f}'.format(best_val_list[1]).ljust(print_num, ' '), '|')\n        print('| Best Val  Dc: {:.4f}'.format(best_val_list[2]).ljust(print_num, ' '), '|')\n    else:\n\n        torch.save(best_model.state_dict(), os.path.join(path_trained_model, 'best_Jc_{:.4f}.pth'.format(best_val_list[1])))\n\n        np.set_printoptions(precision=4, suppress=True)\n        print('| Best  Result: {}'.format(best_result).ljust(print_num, ' '), '|')\n        print('| Best Val  Jc: {}'.format(best_val_list[0]).ljust(print_num, ' '), '|')\n        print('| Best Val  Dc: {}'.format(best_val_list[2]).ljust(print_num, ' '), '|')\n        print('| Best Val mJc: {:.4f}'.format(best_val_list[1]).ljust(print_num, ' '), '|')\n        print('| Best Val mDc: {:.4f}'.format(best_val_list[3]).ljust(print_num, ' '), '|')\n\ndef print_test_eval(num_classes, score_list_test, mask_list_test, print_num):\n    if num_classes == 2:\n        eval_list = evaluate(score_list_test, mask_list_test)\n        print('| Test Thr: {:.4f}'.format(eval_list[0]).ljust(print_num, ' '), '|')\n        print('| Test  Jc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')\n        print('| Test  Dc: {:.4f}'.format(eval_list[2]).ljust(print_num, ' '), '|')\n    else:\n        eval_list = evaluate_multi(score_list_test, 
mask_list_test)\n        np.set_printoptions(precision=4, suppress=True)\n        print('| Test  Jc: {}  '.format(eval_list[0]).ljust(print_num, ' '), '|')\n        print('| Test  Dc: {}  '.format(eval_list[2]).ljust(print_num, ' '), '|')\n        print('| Test mJc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')\n        print('| Test mDc: {:.4f}'.format(eval_list[3]).ljust(print_num, ' '), '|')\n\n    return eval_list\n\n\ndef save_test_2d(num_classes, score_list_test, name_list_test, threshold, path_seg_results, palette):\n\n    if num_classes == 2:\n        score_list_test = torch.softmax(score_list_test, dim=1)\n        pred_results = score_list_test[:, 1, ...].cpu().numpy()\n        pred_results[pred_results > threshold] = 1\n        pred_results[pred_results <= threshold] = 0\n\n        assert len(name_list_test) == pred_results.shape[0]\n\n        for i in range(len(name_list_test)):\n            color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')\n            color_results.putpalette(palette)\n            color_results.save(os.path.join(path_seg_results, name_list_test[i]))\n\n    else:\n        pred_results = torch.max(score_list_test, 1)[1]\n        pred_results = pred_results.cpu().numpy()\n\n        assert len(name_list_test) == pred_results.shape[0]\n\n        for i in range(len(name_list_test)):\n            color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')\n            color_results.putpalette(palette)\n            color_results.save(os.path.join(path_seg_results, name_list_test[i]))\n\ndef save_test_3d(num_classes, score_test, name_test, threshold, path_seg_results, affine):\n\n    if num_classes == 2:\n        score_list_test = torch.softmax(score_test, dim=0)\n        pred_results = score_list_test[1, ...].cpu()\n        pred_results[pred_results > threshold] = 1\n        pred_results[pred_results <= threshold] = 0\n\n        pred_results = pred_results.type(torch.uint8)\n\n        
output_image = tio.ScalarImage(tensor=pred_results.unsqueeze(0), affine=affine)\n        output_image.save(os.path.join(path_seg_results, name_test))\n\n    else:\n        pred_results = torch.max(score_test, 0)[1]\n        pred_results = pred_results.cpu()\n        pred_results = pred_results.type(torch.uint8)\n\n        output_image = tio.ScalarImage(tensor=pred_results.unsqueeze(0), affine=affine)\n        output_image.save(os.path.join(path_seg_results, name_test))\n\n\n\n"
  },
  {
    "path": "config/visdom_config/__init__.py",
    "content": ""
  },
  {
    "path": "config/visdom_config/visual_visdom.py",
    "content": "from visdom import Visdom\nimport os\n\ndef visdom_initialization_sup(env, port):\n    visdom = Visdom(env=env, port=port)\n    visdom.line([0.], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss'], width=550, height=350))\n    visdom.line([0.], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc'], width=550, height=350))\n    visdom.line([0.], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Loss'], width=550, height=350))\n    visdom.line([0.], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc'], width=550, height=350))\n    return visdom\n\ndef visualization_sup(vis, epoch, train_loss, train_m_jc, val_loss, val_m_jc):\n    vis.line([train_loss], [epoch], win='train_loss', update='append')\n    vis.line([train_m_jc], [epoch], win='train_jc', update='append')\n    vis.line([val_loss], [epoch], win='val_loss', update='append')\n    vis.line([val_m_jc], [epoch], win='val_jc', update='append')\n\ndef visual_image_sup(vis, mask_train, pred_train, mask_val, pred_val):\n\n    vis.heatmap(mask_train, win='train_mask', opts=dict(title='Train Mask', colormap='Viridis'))\n    vis.heatmap(pred_train, win='train_pred1', opts=dict(title='Train Pred', colormap='Viridis'))\n    vis.heatmap(mask_val, win='val_mask', opts=dict(title='Val Mask', colormap='Viridis'))\n    vis.heatmap(pred_val, win='val_pred1', opts=dict(title='Val Pred', colormap='Viridis'))\n\n\ndef visdom_initialization_XNet(env, port):\n    visdom = Visdom(env=env, port=port)\n    visdom.line([[0., 0., 0., 0.]], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss', 'Train Sup1', 'Train Sup2', 'Train Unsup'], width=550, height=350))\n    visdom.line([[0., 0.]], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', 
ylabel='Train Jc', legend=['Train Jc1', 'Train Jc2'], width=550, height=350))\n    visdom.line([[0., 0.]], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Sup1', 'Val Sup2'], width=550, height=350))\n    visdom.line([[0., 0.]], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc1', 'Val Jc2'], width=550, height=350))\n    return visdom\n\ndef visualization_XNet(vis, epoch, train_loss, train_loss_sup1, train_loss_sup2, train_loss_cps, train_m_jc1, train_m_jc2, val_loss_sup1, val_loss_sup2, val_m_jc1, val_m_jc2):\n    vis.line([[train_loss, train_loss_sup1, train_loss_sup2, train_loss_cps]], [epoch], win='train_loss', update='append')\n    vis.line([[train_m_jc1, train_m_jc2]], [epoch], win='train_jc', update='append')\n    vis.line([[val_loss_sup1, val_loss_sup2]], [epoch], win='val_loss', update='append')\n    vis.line([[val_m_jc1, val_m_jc2]], [epoch], win='val_jc', update='append')\n\ndef visual_image_XNet(vis, mask_train, pred_train1, pred_train2, mask_val, pred_val1, pred_val2):\n\n    vis.heatmap(mask_train, win='train_mask', opts=dict(title='Train Mask', colormap='Viridis'))\n    vis.heatmap(pred_train1, win='train_pred1', opts=dict(title='Train Pred1', colormap='Viridis'))\n    vis.heatmap(pred_train2, win='train_pred2', opts=dict(title='Train pred2', colormap='Viridis'))\n\n    vis.heatmap(mask_val, win='val_mask', opts=dict(title='Val Mask', colormap='Viridis'))\n    vis.heatmap(pred_val1, win='val_pred1', opts=dict(title='Val Pred1', colormap='Viridis'))\n    vis.heatmap(pred_val2, win='val_pred2', opts=dict(title='Val Pred2', colormap='Viridis'))\n\n\ndef visdom_initialization_MT(env, port):\n    visdom = Visdom(env=env, port=port)\n    visdom.line([[0., 0., 0.]], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss', 'Train Sup', 'Train Unsup'], width=550, height=350))\n    visdom.line([0.], [0.], 
win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc'], width=550, height=350))\n    visdom.line([[0., 0.]], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Sup1', 'Val Sup2'], width=550, height=350))\n    visdom.line([[0., 0.]], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc1', 'Val Jc2'], width=550, height=350))\n    return visdom\n\ndef visualization_MT(vis, epoch, train_loss, train_loss_sup1, train_loss_cps, train_m_jc1, val_loss_sup1, val_loss_sup2, val_m_jc1, val_m_jc2):\n    vis.line([[train_loss, train_loss_sup1, train_loss_cps]], [epoch], win='train_loss', update='append')\n    vis.line([train_m_jc1], [epoch], win='train_jc', update='append')\n    vis.line([[val_loss_sup1, val_loss_sup2]], [epoch], win='val_loss', update='append')\n    vis.line([[val_m_jc1, val_m_jc2]], [epoch], win='val_jc', update='append')\n\ndef visual_image_MT(vis, mask_train, pred_train1, mask_val, pred_val1, pred_val2):\n\n    vis.heatmap(mask_train, win='train_mask', opts=dict(title='Train Mask', colormap='Viridis'))\n    vis.heatmap(pred_train1, win='train_pred1', opts=dict(title='Train Pred', colormap='Viridis'))\n    vis.heatmap(mask_val, win='val_mask', opts=dict(title='Val Mask', colormap='Viridis'))\n    vis.heatmap(pred_val1, win='val_pred1', opts=dict(title='Val Pred1', colormap='Viridis'))\n    vis.heatmap(pred_val2, win='val_pred2', opts=dict(title='Val Pred2', colormap='Viridis'))\n\n\ndef visdom_initialization_EM(env, port):\n    visdom = Visdom(env=env, port=port)\n    visdom.line([[0., 0., 0.]], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss', 'Train Sup', 'Train Unsup'], width=550, height=350))\n    visdom.line([0.], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc'], width=550, height=350))\n    
visdom.line([0.], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Sup'], width=550, height=350))\n    visdom.line([0.], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc'], width=550, height=350))\n    return visdom\n\ndef visualization_EM(vis, epoch, train_loss, train_loss_sup1, train_loss_cps, train_m_jc1, val_loss_sup1, val_m_jc1):\n    vis.line([[train_loss, train_loss_sup1, train_loss_cps]], [epoch], win='train_loss', update='append')\n    vis.line([train_m_jc1], [epoch], win='train_jc', update='append')\n    vis.line([val_loss_sup1], [epoch], win='val_loss', update='append')\n    vis.line([val_m_jc1], [epoch], win='val_jc', update='append')\n\n\ndef visdom_initialization_ConResNet(env, port):\n    visdom = Visdom(env=env, port=port)\n    visdom.line([[0., 0., 0.]], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss', 'Train Seg', 'Train Res'], width=550, height=350))\n    visdom.line([0.], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc'], width=550, height=350))\n    visdom.line([[0., 0.]], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Seg', 'Val Res'], width=550, height=350))\n    visdom.line([0.], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc'], width=550, height=350))\n    return visdom\n\ndef visualization_ConResNet(vis, epoch, train_loss, train_loss_seg, train_loss_res, train_m_jc1, val_loss_seg, val_loss_res, val_m_jc1):\n    vis.line([[train_loss, train_loss_seg, train_loss_res]], [epoch], win='train_loss', update='append')\n    vis.line([train_m_jc1], [epoch], win='train_jc', update='append')\n    vis.line([[val_loss_seg, val_loss_res]], [epoch], win='val_loss', update='append')\n    vis.line([val_m_jc1], [epoch], win='val_jc', update='append')"
  },
  {
    "path": "config/warmup_config/__init__.py",
    "content": ""
  },
  {
    "path": "config/warmup_config/warmup.py",
    "content": "from torch.optim.lr_scheduler import _LRScheduler\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\n\n\nclass GradualWarmupScheduler(_LRScheduler):\n    \"\"\" Gradually warm-up(increasing) learning rate in optimizer.\n    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.\n    Args:\n        optimizer (Optimizer): Wrapped optimizer.\n        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.\n        total_epoch: target learning rate is reached at total_epoch, gradually\n        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)\n    \"\"\"\n\n    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):\n        self.multiplier = multiplier\n        if self.multiplier < 1.:\n            raise ValueError('multiplier should be greater thant or equal to 1.')\n        self.total_epoch = total_epoch\n        self.after_scheduler = after_scheduler\n        self.finished = False\n        super(GradualWarmupScheduler, self).__init__(optimizer)\n\n    def get_lr(self):\n        if self.last_epoch > self.total_epoch:\n            if self.after_scheduler:\n                if not self.finished:\n                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]\n                    self.finished = True\n                return self.after_scheduler.get_last_lr()\n            return [base_lr * self.multiplier for base_lr in self.base_lrs]\n\n        if self.multiplier == 1.0:\n            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]\n        else:\n            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) 
for base_lr in self.base_lrs]\n\n    def step_ReduceLROnPlateau(self, metrics, epoch=None):\n        if epoch is None:\n            epoch = self.last_epoch + 1\n        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning\n        if self.last_epoch <= self.total_epoch:\n            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]\n            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):\n                param_group['lr'] = lr\n        else:\n            if epoch is None:\n                self.after_scheduler.step(metrics, None)\n            else:\n                self.after_scheduler.step(metrics, epoch - self.total_epoch)\n\n    def step(self, epoch=None, metrics=None):\n        if type(self.after_scheduler) != ReduceLROnPlateau:\n            if self.finished and self.after_scheduler:\n                if epoch is None:\n                    self.after_scheduler.step(None)\n                else:\n                    self.after_scheduler.step(epoch - self.total_epoch)\n                self._last_lr = self.after_scheduler.get_last_lr()\n            else:\n                return super(GradualWarmupScheduler, self).step(epoch)\n        else:\n            self.step_ReduceLROnPlateau(metrics, epoch)"
  },
  {
    "path": "dataload/__init__.py",
    "content": ""
  },
  {
    "path": "dataload/dataset_2d.py",
    "content": "import os\nimport torch\nfrom torch.utils.data import Dataset, DataLoader\nfrom PIL import Image\nimport cv2\nimport numpy as np\nimport pywt\n\nclass dataset_itn(Dataset):\n    def __init__(self, data_dir, input1, augmentation_1, normalize_1, sup=True, num_images=None, **kwargs):\n        super(dataset_itn, self).__init__()\n\n        img_paths_1 = []\n        mask_paths = []\n\n        image_dir_1 = data_dir + '/' + input1\n        if sup:\n            mask_dir = data_dir + '/mask'\n\n        for image in os.listdir(image_dir_1):\n\n            image_path_1 = os.path.join(image_dir_1, image)\n            img_paths_1.append(image_path_1)\n\n            if sup:\n                mask_path = os.path.join(mask_dir, image)\n                mask_paths.append(mask_path)\n\n        if sup:\n            assert len(img_paths_1) == len(mask_paths)\n\n        if num_images is not None:\n            len_img_paths = len(img_paths_1)\n            quotient = num_images // len_img_paths\n            remainder = num_images % len_img_paths\n\n            if num_images <= len_img_paths:\n                img_paths_1 = img_paths_1[:num_images]\n            else:\n                rand_indices = torch.randperm(len_img_paths).tolist()\n                new_indices = rand_indices[:remainder]\n\n                img_paths_1 = img_paths_1 * quotient\n                img_paths_1 += [img_paths_1[i] for i in new_indices]\n\n                if sup:\n                    mask_paths = mask_paths * quotient\n                    mask_paths += [mask_paths[i] for i in new_indices]\n\n        self.img_paths_1 = img_paths_1\n        self.mask_paths = mask_paths\n        self.augmentation_1 = augmentation_1\n        self.normalize_1 = normalize_1\n        self.sup = sup\n        self.kwargs = kwargs\n\n    def __getitem__(self, index):\n\n        img_path_1 = self.img_paths_1[index]\n        img_1 = Image.open(img_path_1)\n        img_1 = np.array(img_1)\n\n        if self.sup:\n            
mask_path = self.mask_paths[index]\n            mask = Image.open(mask_path)\n            mask = np.array(mask)\n\n            augment_1 = self.augmentation_1(image=img_1, mask=mask)\n            img_1 = augment_1['image']\n            mask_1 = augment_1['mask']\n\n            normalize_1 = self.normalize_1(image=img_1, mask=mask_1)\n            img_1 = normalize_1['image']\n            mask_1 = normalize_1['mask']\n            mask_1 = mask_1.long()\n\n            sampel = {'image': img_1, 'mask': mask_1, 'ID': os.path.split(mask_path)[1]}\n\n        else:\n            augment_1 = self.augmentation_1(image=img_1)\n            img_1 = augment_1['image']\n            normalize_1 = self.normalize_1(image=img_1)\n            img_1 = normalize_1['image']\n\n            sampel = {'image': img_1, 'ID': os.path.split(img_path_1)[1]}\n\n        return sampel\n\n    def __len__(self):\n        return len(self.img_paths_1)\n\n\ndef imagefloder_itn(data_dir, input1, data_transform_1, data_normalize_1, sup=True, num_images=None, **kwargs):\n    dataset = dataset_itn(data_dir=data_dir,\n                           input1=input1,\n                           augmentation_1=data_transform_1,\n                           normalize_1=data_normalize_1,\n                           sup=sup,\n                           num_images=num_images,\n                           **kwargs\n                           )\n    return dataset\n\n\nclass dataset_iitnn(Dataset):\n    def __init__(self, data_dir, input1, input2, augmentation1, normalize_1, normalize_2, sup=True,\n                 num_images=None, **kwargs):\n        super(dataset_iitnn, self).__init__()\n\n        img_paths_1 = []\n        img_paths_2 = []\n        mask_paths = []\n\n        image_dir_1 = data_dir + '/' + input1\n        image_dir_2 = data_dir + '/' + input2\n        if sup:\n            mask_dir = data_dir + '/mask'\n\n        for image in os.listdir(image_dir_1):\n\n            image_path_1 = os.path.join(image_dir_1, 
image)\n            img_paths_1.append(image_path_1)\n\n            image_path_2 = os.path.join(image_dir_2, image)\n            img_paths_2.append(image_path_2)\n\n            if sup:\n                mask_path = os.path.join(mask_dir, image)\n                mask_paths.append(mask_path)\n\n        assert len(img_paths_1) == len(img_paths_2)\n        if sup:\n            assert len(img_paths_1) == len(mask_paths)\n\n        if num_images is not None:\n            len_img_paths = len(img_paths_1)\n            quotient = num_images // len_img_paths\n            remainder = num_images % len_img_paths\n\n            if num_images <= len_img_paths:\n                img_paths_1 = img_paths_1[:num_images]\n                img_paths_2 = img_paths_2[:num_images]\n            else:\n                rand_indices = torch.randperm(len_img_paths).tolist()\n                new_indices = rand_indices[:remainder]\n\n                img_paths_1 = img_paths_1 * quotient\n                img_paths_1 += [img_paths_1[i] for i in new_indices]\n                img_paths_2 = img_paths_2 * quotient\n                img_paths_2 += [img_paths_2[i] for i in new_indices]\n\n                if sup:\n                    mask_paths = mask_paths * quotient\n                    mask_paths += [mask_paths[i] for i in new_indices]\n\n        self.img_paths_1 = img_paths_1\n        self.img_paths_2 = img_paths_2\n        self.mask_paths = mask_paths\n        self.augmentation_1 = augmentation1\n        self.normalize_1 = normalize_1\n        self.normalize_2 = normalize_2\n        self.sup = sup\n        self.kwargs = kwargs\n\n    def __getitem__(self, index):\n\n        img_path_1 = self.img_paths_1[index]\n        img_1 = Image.open(img_path_1)\n        img_1 = np.array(img_1)\n\n        img_path_2 = self.img_paths_2[index]\n        img_2 = Image.open(img_path_2)\n        img_2 = np.array(img_2)\n\n        if self.sup:\n            mask_path = self.mask_paths[index]\n            mask = 
Image.open(mask_path)\n            mask = np.array(mask)\n\n            augment_1 = self.augmentation_1(image=img_1, image2=img_2, mask=mask)\n            img_1 = augment_1['image']\n            img_2 = augment_1['image2']\n            mask = augment_1['mask']\n\n            normalize_1 = self.normalize_1(image=img_1, mask=mask)\n            img_1 = normalize_1['image']\n            mask = normalize_1['mask']\n            mask = mask.long()\n\n            normalize_2 = self.normalize_2(image=img_2)\n            img_2 = normalize_2['image']\n\n            sampel = {'image': img_1, 'image_2': img_2, 'mask': mask, 'ID': os.path.split(mask_path)[1]}\n\n        else:\n            augment_1 = self.augmentation_1(image=img_1, image2=img_2)\n            img_1 = augment_1['image']\n            img_2 = augment_1['image2']\n\n            normalize_1 = self.normalize_1(image=img_1)\n            img_1 = normalize_1['image']\n\n            normalize_2 = self.normalize_2(image=img_2)\n            img_2 = normalize_2['image']\n\n            sampel = {'image': img_1, 'image_2': img_2, 'ID': os.path.split(img_path_1)[1]}\n\n        return sampel\n\n    def __len__(self):\n        return len(self.img_paths_1)\n\n\ndef imagefloder_iitnn(data_dir, input1, input2, data_transform_1, data_normalize_1, data_normalize_2, sup=True, num_images=None, **kwargs):\n    dataset = dataset_iitnn(data_dir=data_dir,\n                           input1=input1,\n                           input2=input2,\n                           augmentation1=data_transform_1,\n                           normalize_1=data_normalize_1,\n                           normalize_2=data_normalize_2,\n                           sup=sup,\n                           num_images=num_images,\n                           **kwargs\n                           )\n    return dataset\n\nclass dataset_wds(Dataset):\n    def __init__(self, data_dir, augmentation1, normalize_LL, normalize_LH, normalize_HL, normalize_HH, **kwargs):\n        
super(dataset_wds, self).__init__()\n\n        img_paths_LL = []\n        img_paths_LH = []\n        img_paths_HL = []\n        img_paths_HH = []\n        mask_paths = []\n        image_dir_LL = data_dir + '/LL'\n        image_dir_LH = data_dir + '/LH'\n        image_dir_HL = data_dir + '/HL'\n        image_dir_HH = data_dir + '/HH'\n        mask_dir = data_dir + '/mask'\n\n        for image in os.listdir(image_dir_LL):\n\n            image_path_LL = os.path.join(image_dir_LL, image)\n            img_paths_LL.append(image_path_LL)\n            image_path_LH = os.path.join(image_dir_LH, image)\n            img_paths_LH.append(image_path_LH)\n            image_path_HL = os.path.join(image_dir_HL, image)\n            img_paths_HL.append(image_path_HL)\n            image_path_HH = os.path.join(image_dir_HH, image)\n            img_paths_HH.append(image_path_HH)\n\n            mask_path = os.path.join(mask_dir, image)\n            mask_paths.append(mask_path)\n\n        self.img_paths_LL = img_paths_LL\n        self.img_paths_LH = img_paths_LH\n        self.img_paths_HL = img_paths_HL\n        self.img_paths_HH = img_paths_HH\n        self.mask_paths = mask_paths\n        self.augmentation_1 = augmentation1\n        self.normalize_LL = normalize_LL\n        self.normalize_LH = normalize_LH\n        self.normalize_HL = normalize_HL\n        self.normalize_HH = normalize_HH\n        self.kwargs = kwargs\n\n    def __getitem__(self, index):\n\n        img_path_LL = self.img_paths_LL[index]\n        img_LL = Image.open(img_path_LL)\n        img_LL = np.array(img_LL)\n\n        img_path_LH = self.img_paths_LH[index]\n        img_LH = Image.open(img_path_LH)\n        img_LH = np.array(img_LH)\n\n        img_path_HL = self.img_paths_HL[index]\n        img_HL = Image.open(img_path_HL)\n        img_HL = np.array(img_HL)\n\n        img_path_HH = self.img_paths_HH[index]\n        img_HH = Image.open(img_path_HH)\n        img_HH = np.array(img_HH)\n\n        mask_path = 
self.mask_paths[index]\n        mask = Image.open(mask_path)\n        mask = np.array(mask)\n\n        augment_1 = self.augmentation_1(image=img_LL, mask=mask, imageLH=img_LH, imageHL=img_HL, imageHH=img_HH)\n        img_LL = augment_1['image']\n        img_LH = augment_1['imageLH']\n        img_HL = augment_1['imageHL']\n        img_HH = augment_1['imageHH']\n        mask_1 = augment_1['mask']\n\n        normalize_LL = self.normalize_LL(image=img_LL, mask=mask_1)\n        img_LL = normalize_LL['image']\n        mask_1 = normalize_LL['mask']\n        mask_1 = mask_1.long()\n\n        normalize_LH = self.normalize_LH(image=img_LH)\n        img_LH = normalize_LH['image']\n\n        normalize_HL = self.normalize_HL(image=img_HL)\n        img_HL = normalize_HL['image']\n\n        normalize_HH = self.normalize_HH(image=img_HH)\n        img_HH = normalize_HH['image']\n\n        sampel = {'image_LL': img_LL, 'image_LH': img_LH, 'image_HL': img_HL,  'image_HH': img_HH, 'mask': mask_1, 'ID': os.path.split(mask_path)[1]}\n\n        return sampel\n\n    def __len__(self):\n        return len(self.img_paths_LL)\n\n\ndef imagefloder_wds(data_dir, data_transform_1, data_normalize_LL, data_normalize_LH, data_normalize_HL, data_normalize_HH, **kwargs):\n    dataset = dataset_wds(data_dir=data_dir,\n                           augmentation1=data_transform_1,\n                           normalize_LL=data_normalize_LL,\n                           normalize_LH=data_normalize_LH,\n                           normalize_HL=data_normalize_HL,\n                           normalize_HH=data_normalize_HH,\n                           **kwargs\n                           )\n    return dataset\n\nclass dataset_aerial_lanenet(Dataset):\n    def __init__(self, data_dir, augmentation1, normalize_1, normalize_l1, normalize_l2, normalize_l3, normalize_l4, **kwargs):\n        super(dataset_aerial_lanenet, self).__init__()\n\n        img_paths = []\n        mask_paths = []\n        image_dir = data_dir + 
'/image'\n        mask_dir = data_dir + '/mask'\n\n        for image in os.listdir(image_dir):\n\n            image_path = os.path.join(image_dir, image)\n            img_paths.append(image_path)\n\n            mask_path = os.path.join(mask_dir, image)\n            mask_paths.append(mask_path)\n\n        self.img_paths = img_paths\n        self.mask_paths = mask_paths\n        self.augmentation_1 = augmentation1\n        self.normalize_1 = normalize_1\n        self.normalize_l4 = normalize_l4\n        self.normalize_l3 = normalize_l3\n        self.normalize_l2 = normalize_l2\n        self.normalize_l1 = normalize_l1\n        self.kwargs = kwargs\n\n    def __getitem__(self, index):\n\n        img_path = self.img_paths[index]\n        img = Image.open(img_path)\n        img = np.array(img)\n\n        mask_path = self.mask_paths[index]\n        mask = Image.open(mask_path)\n        mask = np.array(mask)\n\n        augment_1 = self.augmentation_1(image=img, mask=mask)\n        img = augment_1['image']\n        mask = augment_1['mask']\n\n        img_ = np.array(Image.fromarray(img).convert('L'))\n        _, l4, l3, l2, l1 = pywt.wavedec2(img_, 'db2', level=4)\n\n        l4 = np.array(l4).transpose(1, 2, 0)\n        l3 = np.array(l3).transpose(1, 2, 0)\n        l2 = np.array(l2).transpose(1, 2, 0)\n        l1 = np.array(l1).transpose(1, 2, 0)\n        normalize_l4 = self.normalize_l4(image=l4)\n        l4 = normalize_l4['image'].float()\n        normalize_l3 = self.normalize_l3(image=l3)\n        l3 = normalize_l3['image'].float()\n        normalize_l2 = self.normalize_l2(image=l2)\n        l2 = normalize_l2['image'].float()\n        normalize_l1 = self.normalize_l1(image=l1)\n        l1 = normalize_l1['image'].float()\n\n        normalize_1 = self.normalize_1(image=img, mask=mask)\n        img = normalize_1['image']\n        mask = normalize_1['mask'].long()\n\n        sampel = {'image': img, 'image_l1': l1, 'image_l2': l2, 'image_l3': l3, 'image_l4': l4, 'mask': 
mask, 'ID': os.path.split(mask_path)[1]}\n\n        return sampel\n\n    def __len__(self):\n        return len(self.img_paths)\n\n\ndef imagefloder_aerial_lanenet(data_dir, data_transform, data_normalize, data_normalize_l1, data_normalize_l2, data_normalize_l3, data_normalize_l4, **kwargs):\n    dataset = dataset_aerial_lanenet(data_dir=data_dir,\n                           augmentation1=data_transform,\n                           normalize_1=data_normalize,\n                           normalize_l1=data_normalize_l1,\n                           normalize_l2=data_normalize_l2,\n                           normalize_l3=data_normalize_l3,\n                           normalize_l4=data_normalize_l4,\n                           **kwargs\n                           )\n    return dataset"
  },
  {
    "path": "dataload/dataset_3d.py",
    "content": "import os\nimport torch\nfrom torch.utils.data import Dataset, DataLoader\nfrom PIL import Image\nimport cv2\nimport numpy as np\nimport torchio as tio\nimport SimpleITK as sitk\nfrom torchio.data import UniformSampler, LabelSampler\n\n\nclass dataset_it(Dataset):\n    def __init__(self, data_dir, input1, transform_1, queue_length=20, samples_per_volume=5, patch_size=128, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=None):\n        super(dataset_it, self).__init__()\n\n        self.subjects_1 = []\n\n        image_dir_1 = data_dir + '/' + input1\n        if sup:\n            mask_dir = data_dir + '/mask'\n\n        for i in os.listdir(image_dir_1):\n            image_path_1 = os.path.join(image_dir_1, i)\n            if sup:\n                mask_path = os.path.join(mask_dir, i)\n                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), mask=tio.LabelMap(mask_path), ID=i)\n            else:\n                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), ID=i)\n\n            self.subjects_1.append(subject_1)\n\n        if num_images is not None:\n            len_img_paths = len(self.subjects_1)\n            quotient = num_images // len_img_paths\n            remainder = num_images % len_img_paths\n\n            if num_images <= len_img_paths:\n                self.subjects_1 = self.subjects_1[:num_images]\n            else:\n                rand_indices = torch.randperm(len_img_paths).tolist()\n                new_indices = rand_indices[:remainder]\n\n                self.subjects_1 = self.subjects_1 * quotient\n                self.subjects_1 += [self.subjects_1[i] for i in new_indices]\n\n        self.dataset_1 = tio.SubjectsDataset(self.subjects_1, transform=transform_1)\n\n        self.queue_train_set_1 = tio.Queue(\n            subjects_dataset=self.dataset_1,\n            max_length=queue_length,\n            samples_per_volume=samples_per_volume,\n            
sampler=UniformSampler(patch_size),\n            # sampler=LabelSampler(patch_size),\n            num_workers=num_workers,\n            shuffle_subjects=shuffle_subjects,\n            shuffle_patches=shuffle_patches\n        )\n\n\nclass dataset_it_dtc(Dataset):\n    def __init__(self, data_dir, input1, num_classes, transform_1, queue_length=20, samples_per_volume=5, patch_size=128, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=None):\n        super(dataset_it_dtc, self).__init__()\n\n        self.subjects_1 = []\n\n        image_dir_1 = data_dir + '/' + input1\n        if sup:\n            mask_dir_1 = data_dir + '/mask'\n            mask_dir_2 = data_dir + '/mask_sdf1'\n            if num_classes == 3:\n                mask_dir_3 = data_dir + '/mask_sdf2'\n\n        for i in os.listdir(image_dir_1):\n            image_path_1 = os.path.join(image_dir_1, i)\n            if sup:\n                mask_path_1 = os.path.join(mask_dir_1, i)\n                mask_path_2 = os.path.join(mask_dir_2, i)\n                if num_classes == 3:\n                    mask_path_3 = os.path.join(mask_dir_3, i)\n                    subject_1 = tio.Subject(\n                        image=tio.ScalarImage(image_path_1),\n                        mask=tio.LabelMap(mask_path_1),\n                        mask2=tio.LabelMap(mask_path_2),\n                        mask3=tio.LabelMap(mask_path_3),\n                        ID=i)\n                else:\n                    subject_1 = tio.Subject(\n                        image=tio.ScalarImage(image_path_1),\n                        mask=tio.LabelMap(mask_path_1),\n                        mask2=tio.LabelMap(mask_path_2),\n                        ID=i)\n            else:\n                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), ID=i)\n\n            self.subjects_1.append(subject_1)\n\n        if num_images is not None:\n            len_img_paths = len(self.subjects_1)\n            quotient = 
num_images // len_img_paths\n            remainder = num_images % len_img_paths\n\n            if num_images <= len_img_paths:\n                self.subjects_1 = self.subjects_1[:num_images]\n            else:\n                rand_indices = torch.randperm(len_img_paths).tolist()\n                new_indices = rand_indices[:remainder]\n\n                self.subjects_1 = self.subjects_1 * quotient\n                self.subjects_1 += [self.subjects_1[i] for i in new_indices]\n\n        self.dataset_1 = tio.SubjectsDataset(self.subjects_1, transform=transform_1)\n\n        self.queue_train_set_1 = tio.Queue(\n            subjects_dataset=self.dataset_1,\n            max_length=queue_length,\n            samples_per_volume=samples_per_volume,\n            sampler=UniformSampler(patch_size),\n            # sampler=LabelSampler(patch_size),\n            num_workers=num_workers,\n            shuffle_subjects=shuffle_subjects,\n            shuffle_patches=shuffle_patches\n        )\n\nclass dataset_iit(Dataset):\n    def __init__(self, data_dir, input1, input2, transform_1, queue_length=20, samples_per_volume=5, patch_size=128, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=None):\n        super(dataset_iit, self).__init__()\n\n        self.subjects_1 = []\n\n        image_dir_1 = data_dir + '/' + input1\n        image_dir_2 = data_dir + '/' + input2\n\n        if sup:\n            mask_dir_1 = data_dir + '/mask'\n\n        for i in os.listdir(image_dir_1):\n            image_path_1 = os.path.join(image_dir_1, i)\n            image_path_2 = os.path.join(image_dir_2, i)\n            if sup:\n                mask_path_1 = os.path.join(mask_dir_1, i)\n                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), image2=tio.ScalarImage(image_path_2), mask=tio.LabelMap(mask_path_1), ID=i)\n            else:\n                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), image2=tio.ScalarImage(image_path_2), 
ID=i)\n\n            self.subjects_1.append(subject_1)\n\n        if num_images is not None:\n            len_img_paths = len(self.subjects_1)\n            quotient = num_images // len_img_paths\n            remainder = num_images % len_img_paths\n\n            if num_images <= len_img_paths:\n                self.subjects_1 = self.subjects_1[:num_images]\n            else:\n                rand_indices = torch.randperm(len_img_paths).tolist()\n                new_indices = rand_indices[:remainder]\n\n                self.subjects_1 = self.subjects_1 * quotient\n                self.subjects_1 += [self.subjects_1[i] for i in new_indices]\n\n        self.dataset_1 = tio.SubjectsDataset(self.subjects_1, transform=transform_1)\n\n        self.queue_train_set_1 = tio.Queue(\n            subjects_dataset=self.dataset_1,\n            max_length=queue_length,\n            samples_per_volume=samples_per_volume,\n            sampler=UniformSampler(patch_size),\n            # sampler=LabelSampler(patch_size),\n            num_workers=num_workers,\n            shuffle_subjects=shuffle_subjects,\n            shuffle_patches=shuffle_patches\n        )\n\n\nclass dataset_iit_conresnet(Dataset):\n    def __init__(self, data_dir, input1, input2, transform_1, queue_length=20, samples_per_volume=5, patch_size=128, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=None):\n        super(dataset_iit_conresnet, self).__init__()\n\n        self.subjects_1 = []\n\n        image_dir_1 = data_dir + '/' + input1\n        image_dir_2 = data_dir + '/' + input2\n\n        if sup:\n            mask_dir_1 = data_dir + '/mask'\n            mask_dir_2 = data_dir + '/mask_res'\n\n        for i in os.listdir(image_dir_1):\n            image_path_1 = os.path.join(image_dir_1, i)\n            image_path_2 = os.path.join(image_dir_2, i)\n            if sup:\n                mask_path_1 = os.path.join(mask_dir_1, i)\n                mask_path_2 = os.path.join(mask_dir_2, 
i)\n                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), image2=tio.ScalarImage(image_path_2), mask=tio.LabelMap(mask_path_1), mask2=tio.LabelMap(mask_path_2), ID=i)\n            else:\n                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), image2=tio.ScalarImage(image_path_2), ID=i)\n\n            self.subjects_1.append(subject_1)\n\n        if num_images is not None:\n            len_img_paths = len(self.subjects_1)\n            quotient = num_images // len_img_paths\n            remainder = num_images % len_img_paths\n\n            if num_images <= len_img_paths:\n                self.subjects_1 = self.subjects_1[:num_images]\n            else:\n                rand_indices = torch.randperm(len_img_paths).tolist()\n                new_indices = rand_indices[:remainder]\n\n                self.subjects_1 = self.subjects_1 * quotient\n                self.subjects_1 += [self.subjects_1[i] for i in new_indices]\n\n        self.dataset_1 = tio.SubjectsDataset(self.subjects_1, transform=transform_1)\n\n        self.queue_train_set_1 = tio.Queue(\n            subjects_dataset=self.dataset_1,\n            max_length=queue_length,\n            samples_per_volume=samples_per_volume,\n            sampler=UniformSampler(patch_size),\n            # sampler=LabelSampler(patch_size),\n            num_workers=num_workers,\n            shuffle_subjects=shuffle_subjects,\n            shuffle_patches=shuffle_patches\n        )\n\n"
  },
  {
    "path": "loss/__init__.py",
    "content": ""
  },
  {
    "path": "loss/loss_function.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom torch.autograd import Variable\nimport sys\nfrom torch.nn.modules.loss import _Loss\n\nclass MixSoftmaxCrossEntropyLoss(nn.CrossEntropyLoss):\n    def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs):\n        super(MixSoftmaxCrossEntropyLoss, self).__init__(ignore_index=ignore_index)\n        self.aux = aux\n        self.aux_weight = aux_weight\n\n    def _aux_forward(self, output, target, **kwargs):\n        # *preds, target = tuple(inputs)\n\n        loss = super(MixSoftmaxCrossEntropyLoss, self).forward(output[0], target)\n        for i in range(1, len(output)):\n            aux_loss = super(MixSoftmaxCrossEntropyLoss, self).forward(output[i], target)\n            loss += self.aux_weight * aux_loss\n        return loss\n\n    def forward(self, output, target):\n        # preds, target = tuple(inputs)\n        # inputs = tuple(list(preds) + [target])\n        if self.aux:\n            return self._aux_forward(output, target)\n        else:\n            return super(MixSoftmaxCrossEntropyLoss, self).forward(output, target)\n\nclass BinaryDiceLoss(nn.Module):\n    \"\"\"Dice loss of binary class\n    Args:\n        smooth: A float number to smooth loss, and avoid NaN error, default: 1\n        p: Denominator value: \\sum{x^p} + \\sum{y^p}, default: 2\n        predict: A tensor of shape [N, *]\n        target: A tensor of shape same with predict\n        reduction: Reduction method to apply, return mean over batch if 'mean',\n            return sum if 'sum', return a tensor of shape [N,] if 'none'\n    Returns:\n        Loss tensor according to arg reduction\n    Raise:\n        Exception if unexpected reduction\n    \"\"\"\n\n    def __init__(self, smooth=1, p=2, reduction='mean'):\n        super(BinaryDiceLoss, self).__init__()\n        self.smooth = smooth\n        self.p = p\n        self.reduction = reduction\n\n    def 
forward(self, predict, target, valid_mask):\n        assert predict.shape[0] == target.shape[0], \"predict & target batch size don't match\"\n        predict = predict.contiguous().view(predict.shape[0], -1)\n        target = target.contiguous().view(target.shape[0], -1).float()\n        valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1).float()\n\n        num = torch.sum(torch.mul(predict, target) * valid_mask, dim=1) * 2 + self.smooth\n        den = torch.sum((predict.pow(self.p) + target.pow(self.p)) * valid_mask, dim=1) + self.smooth\n\n        loss = 1 - num / den\n\n        if self.reduction == 'mean':\n            return loss.mean()\n        elif self.reduction == 'sum':\n            return loss.sum()\n        elif self.reduction == 'none':\n            return loss\n        else:\n            raise Exception('Unexpected reduction {}'.format(self.reduction))\n\n\nclass DiceLoss(nn.Module):\n    \"\"\"Dice loss, need one hot encode input\"\"\"\n\n    def __init__(self, weight=None, aux=False, aux_weight=0.4, ignore_index=-1, **kwargs):\n        super(DiceLoss, self).__init__()\n        self.kwargs = kwargs\n        self.weight = weight\n        self.ignore_index = ignore_index\n        self.aux = aux\n        self.aux_weight = aux_weight\n\n    def _base_forward(self, predict, target, valid_mask):\n\n        dice = BinaryDiceLoss(**self.kwargs)\n        total_loss = 0\n        predict = F.softmax(predict, dim=1)\n\n        for i in range(target.shape[-1]):\n            if i != self.ignore_index:\n                dice_loss = dice(predict[:, i], target[..., i], valid_mask)\n                if self.weight is not None:\n                    assert self.weight.shape[0] == target.shape[1], \\\n                        'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0])\n                    dice_loss *= self.weight[i]\n                total_loss += dice_loss\n\n        return total_loss / target.shape[-1]\n\n    def 
_aux_forward(self, output, target, **kwargs):\n        # *preds, target = tuple(inputs)\n        valid_mask = (target != self.ignore_index).long()\n        target_one_hot = F.one_hot(torch.clamp_min(target, 0))\n        loss = self._base_forward(output[0], target_one_hot, valid_mask)\n        for i in range(1, len(output)):\n            aux_loss = self._base_forward(output[i], target_one_hot, valid_mask)\n            loss += self.aux_weight * aux_loss\n        return loss\n\n    def forward(self, output, target):\n        # preds, target = tuple(inputs)\n        # inputs = tuple(list(preds) + [target])\n        if self.aux:\n            return self._aux_forward(output, target)\n        else:\n            valid_mask = (target != self.ignore_index).long()\n            target_one_hot = F.one_hot(torch.clamp_min(target, 0))\n            return self._base_forward(output, target_one_hot, valid_mask)\n\n\ndef softmax_mse_loss(input_logits, target_logits, sigmoid=False):\n    \"\"\"Takes softmax on both sides and returns MSE loss\n    Note:\n    - Returns the sum over all examples. 
Divide by the batch size afterwards\n      if you want the mean.\n    - Sends gradients to inputs but not the targets.\n    \"\"\"\n    assert input_logits.size() == target_logits.size()\n    if sigmoid:\n        input_softmax = torch.sigmoid(input_logits)\n        target_softmax = torch.sigmoid(target_logits)\n    else:\n        input_softmax = F.softmax(input_logits, dim=1)\n        target_softmax = F.softmax(target_logits, dim=1)\n\n    mse_loss = (input_softmax-target_softmax)**2\n    return mse_loss\n\n\ndef entropy_loss(p, C=2):\n    # p N*C*W*H*D\n    y1 = -1*torch.sum(p*torch.log(p+1e-6), dim=1) / torch.tensor(np.log(C)).cuda()\n    ent = torch.mean(y1)\n\n    return ent\n\nclass BCELossBoud(nn.Module):\n    def __init__(self, num_classes, weight=None, ignore_index=None, **kwargs):\n        super(BCELossBoud, self).__init__()\n        self.kwargs = kwargs\n        self.weight = weight\n        self.ignore_index = ignore_index\n        self.num_classes = num_classes\n        self.criterion = nn.BCEWithLogitsLoss()\n\n    def weighted_BCE_cross_entropy(self, output, target, weights = None):\n        if weights is not None:\n            assert len(weights) == 2\n            output = torch.clamp(output, min=1e-3, max=1-1e-3)\n            bce = weights[1] * (target * torch.log(output)) + weights[0] * ((1-target) * torch.log((1-output)))\n        else:\n            output = torch.clamp(output, min=1e-3, max=1 - 1e-3)\n            bce = target * torch.log(output) + (1-target) * torch.log((1-output))\n        return torch.neg(torch.mean(bce))\n\n    def forward(self, predict, target):\n\n        target_one_hot = F.one_hot(torch.clamp_min(target, 0), num_classes=self.num_classes).permute(0, 4, 1, 2, 3)\n        predict = torch.softmax(predict, 1)\n\n        bs, category, depth, width, heigt = target_one_hot.shape\n        bce_loss = []\n        for i in range(predict.shape[1]):\n            pred_i = predict[:,i]\n            targ_i = target_one_hot[:,i]\n            
tt = np.log(depth * width * heigt / (target_one_hot[:, i].cpu().data.numpy().sum()+1))\n            bce_i = self.weighted_BCE_cross_entropy(pred_i, targ_i, weights=[1, tt])\n            bce_loss.append(bce_i)\n\n        bce_loss = torch.stack(bce_loss)\n        total_loss = bce_loss.mean()\n        return total_loss\n\n\nclass CustomKLLoss(_Loss):\n    '''\n    KL_Loss = (|dot(mean , mean)| + |dot(std, std)| - |log(dot(std, std))| - 1) / N\n    N is the total number of image voxels\n    '''\n\n    def __init__(self, *args, **kwargs):\n        super(CustomKLLoss, self).__init__()\n\n    def forward(self, mean, std):\n        return torch.mean(torch.mul(mean, mean)) + torch.mean(torch.mul(std, std)) - torch.mean(\n            torch.log(torch.mul(std, std))) - 1\n\n\ndef segmentation_loss(loss='CE', aux=False, **kwargs):\n\n    if loss == 'dice' or loss == 'DICE':\n        seg_loss = DiceLoss(aux=aux)\n    elif loss == 'crossentropy' or loss == 'CE':\n        seg_loss = MixSoftmaxCrossEntropyLoss(aux=aux)\n    elif loss == 'bce':\n        seg_loss = nn.BCELoss(size_average=True)\n    elif loss == 'bcebound':\n        seg_loss = BCELossBoud(num_classes=kwargs['num_classes'])\n    else:\n        print('sorry, the loss you input is not supported yet')\n        sys.exit()\n\n    return seg_loss\n\n\n# if __name__ == '__main__':\n#     from models import *\n#     criterion = segmentation_loss(loss='LOVASZ')\n#     # criterion = nn.CrossEntropyLoss()\n#\n#     model = unet(1, 2)\n#     model.eval()\n#     input = torch.rand(3, 1, 128, 128)\n#     mask = torch.zeros(3, 128, 128).long()\n#\n#     mask[:, 40:100, 30:60] = 1\n#     output = model(input)\n#\n#     loss = criterion(output, mask)\n#     print(loss)\n#     # loss.requires_grad_(True)\n#     # loss.backward()\n\n"
  },
  {
    "path": "models/__init__.py",
    "content": "# 2d\nfrom .networks_2d.xnet import XNet, XNet_1_1_m, XNet_1_2_m, XNet_2_1_m, XNet_3_2_m, XNet_2_3_m, XNet_3_3_m, XNet_sb\nfrom .networks_2d.unet import unet, r2_unet, attention_unet\nfrom .networks_2d.unet_plusplus import unet_plusplus\nfrom .networks_2d.hrnet import hrnet18, hrnet32, hrnet48, hrnet64\nfrom .networks_2d.swinunet import swinunet\nfrom .networks_2d.unet_urpc import unet_urpc\nfrom .networks_2d.unet_cct import unet_cct\nfrom .networks_2d.resunet import res_unet\nfrom .networks_2d.resunet_plusplus import res_unet_plusplus\nfrom .networks_2d.u2net import u2net, u2net_small\nfrom .networks_2d.unet_3plus import unet_3plus, unet_3plus_ds, unet_3plus_ds_cgm\nfrom .networks_2d.wavesnet import wsegnet_vgg16_bn\nfrom .networks_2d.mwcnn import mwcnn\nfrom .networks_2d.aerial_lanenet import Aerial_LaneNet\nfrom .networks_2d.wds import WDS\n\n# 3d\nfrom .networks_3d.unet3d import unet3d, unet3d_min\nfrom .networks_3d.vnet import vnet\nfrom .networks_3d.res_unet3d import res_unet3d\nfrom .networks_3d.transbts import transbts\nfrom .networks_3d.cotr import cotr\nfrom .networks_3d.dmfnet import dmfnet\nfrom .networks_3d.conresnet import conresnet\nfrom .networks_3d.espnet3d import espnet3d\nfrom .networks_3d.unetr import unertr\nfrom .networks_3d.unet3d_urpc import unet3d_urpc\nfrom .networks_3d.unet3d_cct import unet3d_cct, unet3d_cct_min\nfrom .networks_3d.unet3d_dtc import unet3d_dtc\nfrom .networks_3d.xnet3d import xnet3d\nfrom .networks_3d.vnet_cct import vnet_cct\nfrom .networks_3d.vnet_dtc import vnet_dtc"
  },
  {
    "path": "models/getnetwork.py",
    "content": "import sys\nfrom models import *\nimport torch.nn as nn\n\ndef get_network(network, in_channels, num_classes, **kwargs):\n\n    # 2d networks\n    if network == 'xnet':\n        net = XNet(in_channels, num_classes)\n    elif network == 'xnet_sb':\n        net = XNet_sb(in_channels, num_classes)\n    elif network == 'xnet_1_1_m':\n        net = XNet_1_1_m(in_channels, num_classes)\n    elif network == 'xnet_1_2_m':\n        net = XNet_1_2_m(in_channels, num_classes)\n    elif network == 'xnet_2_1_m':\n        net = XNet_2_1_m(in_channels, num_classes)\n    elif network == 'xnet_3_2_m':\n        net = XNet_3_2_m(in_channels, num_classes)\n    elif network == 'xnet_2_3_m':\n        net = XNet_2_3_m(in_channels, num_classes)\n    elif network == 'xnet_3_3_m':\n        net = XNet_3_3_m(in_channels, num_classes)\n    elif network == 'unet':\n        net = unet(in_channels, num_classes)\n    elif network == 'unet_plusplus' or network == 'unet++':\n        net = unet_plusplus(in_channels, num_classes)\n    elif network == 'r2unet':\n        net = r2_unet(in_channels, num_classes)\n    elif network == 'attunet':\n        net = attention_unet(in_channels, num_classes)\n    elif network == 'hrnet18':\n        net = hrnet18(in_channels, num_classes)\n    elif network == 'hrnet48':\n        net = hrnet48(in_channels, num_classes)\n    elif network == 'resunet':\n        net = res_unet(in_channels, num_classes)\n    elif network == 'resunet++':\n        net = res_unet_plusplus(in_channels, num_classes)\n    elif network == 'u2net':\n        net = u2net(in_channels, num_classes)\n    elif network == 'u2net_s':\n        net = u2net_small(in_channels, num_classes)\n    elif network == 'unet3+':\n        net = unet_3plus(in_channels, num_classes)\n    elif network == 'unet3+_ds':\n        net = unet_3plus_ds(in_channels, num_classes)\n    elif network == 'unet3+_ds_cgm':\n        net = unet_3plus_ds_cgm(in_channels, num_classes)\n    elif network == 'swinunet':\n     
   net = swinunet(num_classes, 224)  # img_size = 224\n    elif network == 'unet_urpc':\n        net = unet_urpc(in_channels, num_classes)\n    elif network == 'unet_cct':\n        net = unet_cct(in_channels, num_classes)\n    elif network == 'wavesnet':\n        net = wsegnet_vgg16_bn(in_channels, num_classes)\n    elif network == 'mwcnn':\n        net = mwcnn(in_channels, num_classes)\n    elif network == 'alnet':\n        net = Aerial_LaneNet(in_channels, num_classes)\n    elif network == 'wds':\n        net = WDS(in_channels, num_classes)\n\n    # 3d networks\n    elif network == 'xnet3d':\n        net = xnet3d(in_channels, num_classes)\n    elif network == 'unet3d':\n        net = unet3d(in_channels, num_classes)\n    elif network == 'unet3d_min':\n        net = unet3d_min(in_channels, num_classes)\n    elif network == 'unet3d_urpc':\n        net = unet3d_urpc(in_channels, num_classes)\n    elif network == 'unet3d_cct':\n        net = unet3d_cct(in_channels, num_classes)\n    elif network == 'unet3d_cct_min':\n        net = unet3d_cct_min(in_channels, num_classes)\n    elif network == 'unet3d_dtc':\n        net = unet3d_dtc(in_channels, num_classes)\n    elif network == 'vnet':\n        net = vnet(in_channels, num_classes)\n    elif network == 'vnet_cct':\n        net = vnet_cct(in_channels, num_classes)\n    elif network == 'vnet_dtc':\n        net = vnet_dtc(in_channels, num_classes)\n    elif network == 'resunet3d':\n        net = res_unet3d(in_channels, num_classes)\n    elif network == 'conresnet':\n        net = conresnet(in_channels, num_classes, img_shape=kwargs['img_shape'])\n    elif network == 'espnet3d':\n        net = espnet3d(in_channels, num_classes)\n    elif network == 'dmfnet':\n        net = dmfnet(in_channels, num_classes)\n    elif network == 'transbts':\n        net = transbts(in_channels, num_classes, img_shape=kwargs['img_shape'])\n    elif network == 'cotr':\n        net = cotr(in_channels, num_classes)\n    elif network == 'unertr':\n 
       net = unertr(in_channels, num_classes, img_shape=kwargs['img_shape'])\n    else:\n        print('the network you have entered is not supported yet')\n        sys.exit()\n    return net\n"
  },
  {
    "path": "models/networks_2d/__init__.py",
    "content": ""
  },
  {
    "path": "models/networks_2d/aerial_lanenet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\nimport functools\nfrom torch.distributions.uniform import Uniform\nimport numpy as np\n\n\nclass basic_block(nn.Module):\n    def __init__(self, ch_in, ch_out):\n        super(basic_block, self).__init__()\n        self.block = nn.Sequential(\n            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False),\n            nn.ReLU(inplace=True))\n    def forward(self, x):\n        x = self.block(x)\n        return x\n\nclass Aerial_LaneNet(nn.Module):\n    def __init__(self, in_channels, num_classes):\n        super(Aerial_LaneNet, self).__init__()\n\n        l1, l2, l3, l4, l5 = 64, 128, 256, 512, 512\n        dropout = 0.2\n\n        # e1\n        self.conv1_1 = basic_block(in_channels, l1)\n        self.conv1_2 = basic_block(l1, l1)\n        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n\n        # e2\n        self.conv2_1 = basic_block(l1+3, l2)\n        self.conv2_2 = basic_block(l2, l2)\n        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n\n        # e3\n        self.conv3_1 = basic_block(l2+3, l3)\n        self.conv3_2 = basic_block(l3, l3)\n        self.conv3_3 = basic_block(l3, l3)\n        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n\n        # e4\n        self.conv4_1 = basic_block(l3+3, l4)\n        self.conv4_2 = basic_block(l4, l4)\n        self.conv4_3 = basic_block(l4, l4)\n        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n\n        # e5\n        self.conv5_1 = basic_block(l4+3, l5)\n        self.conv5_2 = basic_block(l5, l5)\n        self.conv5_3 = basic_block(l5, l5)\n        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n\n        # e6\n        self.conv6_1 = basic_block(l5, 4096)\n        self.drop6_1 = nn.Dropout2d(dropout)\n        self.conv6_2 = basic_block(4096, 4096)\n        self.drop6_2 = nn.Dropout2d(dropout)\n        self.conv6_3 = 
nn.ConvTranspose2d(4096, l5, kernel_size=4, stride=2, padding=1, bias=False)\n\n        # d4\n        self.conv4_4 = basic_block(2*l5, l5)\n        self.drop4_4 = nn.Dropout2d(dropout)\n        self.conv4_5 = nn.ConvTranspose2d(l5, l3, kernel_size=4, stride=2, padding=1, bias=False)\n\n        # d3\n        self.conv3_4 = basic_block(2*l3, l3)\n        self.drop3_4 = nn.Dropout2d(dropout)\n        self.conv3_5 = nn.ConvTranspose2d(l3, l2, kernel_size=4, stride=2, padding=1, bias=False)\n\n        # d2\n        self.conv2_4 = basic_block(2*l2, l2)\n        self.drop2_4 = nn.Dropout2d(dropout)\n        self.conv2_5 = nn.ConvTranspose2d(l2, l1, kernel_size=4, stride=2, padding=1, bias=False)\n\n        # d1\n        self.conv1_3 = basic_block(2*l1, l1)\n        self.drop1_3 = nn.Dropout2d(dropout)\n        self.conv1_4 = nn.ConvTranspose2d(l1, num_classes, kernel_size=4, stride=2, padding=1, bias=False)\n\n        # initialization\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def forward(self, x, x_wavelet_1, x_wavelet_2, x_wavelet_3, x_wavelet_4):\n\n        x1 = self.conv1_1(x)\n        x1 = self.conv1_2(x1)\n        x1 = self.pool1(x1)\n\n        x2 = torch.cat((x1, x_wavelet_1), dim=1)\n        x2 = self.conv2_1(x2)\n        x2 = self.conv2_2(x2)\n        x2 = self.pool2(x2)\n\n        x3 = torch.cat((x2, x_wavelet_2), dim=1)\n        x3 = self.conv3_1(x3)\n        x3 = self.conv3_2(x3)\n        x3 = self.conv3_3(x3)\n        x3 
= self.pool3(x3)\n\n        x4 = torch.cat((x3, x_wavelet_3), dim=1)\n        x4 = self.conv4_1(x4)\n        x4 = self.conv4_2(x4)\n        x4 = self.conv4_3(x4)\n        x4 = self.pool4(x4)\n\n        x5 = torch.cat((x4, x_wavelet_4), dim=1)\n        x5 = self.conv5_1(x5)\n        x5 = self.conv5_2(x5)\n        x5 = self.conv5_3(x5)\n        x5 = self.pool5(x5)\n\n        x6 = self.conv6_1(x5)\n        x6 = self.drop6_1(x6)\n        x6 = self.conv6_2(x6)\n        x6 = self.drop6_2(x6)\n        x6 = self.conv6_3(x6)\n\n        x5 = torch.cat((x6, x4), dim=1)\n        x5 = self.conv4_4(x5)\n        x5 = self.drop4_4(x5)\n        x5 = self.conv4_5(x5)\n\n        x4 = torch.cat((x5, x3), dim=1)\n        x4 = self.conv3_4(x4)\n        x4 = self.drop3_4(x4)\n        x4 = self.conv3_5(x4)\n\n        x3 = torch.cat((x4, x2), dim=1)\n        x3 = self.conv2_4(x3)\n        x3 = self.drop2_4(x3)\n        x3 = self.conv2_5(x3)\n\n        x2 = torch.cat((x3, x1), dim=1)\n        x2 = self.conv1_3(x2)\n        x2 = self.drop1_3(x2)\n        x2 = self.conv1_4(x2)\n\n        return x2\n\n# if __name__ == '__main__':\n#     from loss.loss_function import segmentation_loss\n#     criterion = segmentation_loss('dice', False)\n#     mask = torch.ones(2, 128, 128).long()\n#     model = Aerial_LaneNet(1, 5)\n#     model.train()\n#     input1 = torch.rand(2, 1, 128, 128)\n#     input2 = torch.rand(2, 3, 64, 64)\n#     input3 = torch.rand(2, 3, 32, 32)\n#     input4 = torch.rand(2, 3, 16, 16)\n#     input5 = torch.rand(2, 3, 8, 8)\n#\n#     y = model(input1, input2, input3, input4, input5)\n#     loss_train = criterion(y, mask)\n#     loss_train.backward()\n#     # print(output)\n#     print(y.data.cpu().numpy().shape)\n#     print(loss_train)"
  },
  {
    "path": "models/networks_2d/hrnet.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport logging\nimport functools\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch._utils\nimport torch.nn.functional as F\nfrom torch.nn import init\n\ntry:\n    from .sync_bn.inplace_abn.bn import InPlaceABNSync\n    BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')\nexcept:\n    BatchNorm2d = nn.BatchNorm2d\n\nBN_MOMENTUM = 0.01\nlogger = logging.getLogger(__name__)\n\n\nmodel_urls = {\n    'hrnetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/hrnetv2_w48-imagenet.pth',\n}\n\nimport sys\ntry:\n    from urllib import urlretrieve\nexcept ImportError:\n    from urllib.request import urlretrieve\n\n\ndef load_url(url, model_dir='./pretrained', map_location=None):\n    if not os.path.exists(model_dir):\n        os.makedirs(model_dir)\n    filename = url.split('/')[-1]\n    cached_file = os.path.join(model_dir, filename)\n    if not os.path.exists(cached_file):\n        sys.stderr.write('Downloading: \"{}\" to {}\\n'.format(url, cached_file))\n        urlretrieve(url, cached_file)\n    return torch.load(cached_file, map_location=map_location)\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=1, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)\n        self.relu = nn.ReLU(inplace=False)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, 
x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out = out + residual\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,\n                               bias=False)\n        self.bn3 = BatchNorm2d(planes * self.expansion,\n                               momentum=BN_MOMENTUM)\n        self.relu = nn.ReLU(inplace=False)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out = out + residual\n        out = self.relu(out)\n\n        return out\n\n\nclass HighResolutionModule(nn.Module):\n    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,\n                 num_channels, fuse_method, multi_scale_output=True):\n        super(HighResolutionModule, self).__init__()\n        self._check_branches(\n            num_branches, blocks, num_blocks, num_inchannels, 
num_channels)\n\n        self.num_inchannels = num_inchannels\n        self.fuse_method = fuse_method\n        self.num_branches = num_branches\n\n        self.multi_scale_output = multi_scale_output\n\n        self.branches = self._make_branches(\n            num_branches, blocks, num_blocks, num_channels)\n        self.fuse_layers = self._make_fuse_layers()\n        self.relu = nn.ReLU(inplace=False)\n\n    def _check_branches(self, num_branches, blocks, num_blocks,\n                        num_inchannels, num_channels):\n        if num_branches != len(num_blocks):\n            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(\n                num_branches, len(num_blocks))\n            logger.error(error_msg)\n            raise ValueError(error_msg)\n\n        if num_branches != len(num_channels):\n            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(\n                num_branches, len(num_channels))\n            logger.error(error_msg)\n            raise ValueError(error_msg)\n\n        if num_branches != len(num_inchannels):\n            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(\n                num_branches, len(num_inchannels))\n            logger.error(error_msg)\n            raise ValueError(error_msg)\n\n    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,\n                         stride=1):\n        downsample = None\n        if stride != 1 or \\\n                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.num_inchannels[branch_index],\n                          num_channels[branch_index] * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                BatchNorm2d(num_channels[branch_index] * block.expansion,\n                            momentum=BN_MOMENTUM),\n            )\n\n        layers = []\n        
layers.append(block(self.num_inchannels[branch_index],\n                            num_channels[branch_index], stride, downsample))\n        self.num_inchannels[branch_index] = \\\n            num_channels[branch_index] * block.expansion\n        for i in range(1, num_blocks[branch_index]):\n            layers.append(block(self.num_inchannels[branch_index],\n                                num_channels[branch_index]))\n\n        return nn.Sequential(*layers)\n\n    def _make_branches(self, num_branches, block, num_blocks, num_channels):\n        branches = []\n\n        for i in range(num_branches):\n            branches.append(\n                self._make_one_branch(i, block, num_blocks, num_channels))\n\n        return nn.ModuleList(branches)\n\n    def _make_fuse_layers(self):\n        if self.num_branches == 1:\n            return None\n\n        num_branches = self.num_branches\n        num_inchannels = self.num_inchannels\n        fuse_layers = []\n        for i in range(num_branches if self.multi_scale_output else 1):\n            fuse_layer = []\n            for j in range(num_branches):\n                if j > i:\n                    fuse_layer.append(nn.Sequential(\n                        nn.Conv2d(num_inchannels[j],\n                                  num_inchannels[i],\n                                  1,\n                                  1,\n                                  0,\n                                  bias=False),\n                        BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))\n                elif j == i:\n                    fuse_layer.append(None)\n                else:\n                    conv3x3s = []\n                    for k in range(i - j):\n                        if k == i - j - 1:\n                            num_outchannels_conv3x3 = num_inchannels[i]\n                            conv3x3s.append(nn.Sequential(\n                                nn.Conv2d(num_inchannels[j],\n                                      
    num_outchannels_conv3x3,\n                                          3, 2, 1, bias=False),\n                                BatchNorm2d(num_outchannels_conv3x3,\n                                            momentum=BN_MOMENTUM)))\n                        else:\n                            num_outchannels_conv3x3 = num_inchannels[j]\n                            conv3x3s.append(nn.Sequential(\n                                nn.Conv2d(num_inchannels[j],\n                                          num_outchannels_conv3x3,\n                                          3, 2, 1, bias=False),\n                                BatchNorm2d(num_outchannels_conv3x3,\n                                            momentum=BN_MOMENTUM),\n                                nn.ReLU(inplace=False)))\n                    fuse_layer.append(nn.Sequential(*conv3x3s))\n            fuse_layers.append(nn.ModuleList(fuse_layer))\n\n        return nn.ModuleList(fuse_layers)\n\n    def get_num_inchannels(self):\n        return self.num_inchannels\n\n    def forward(self, x):\n        if self.num_branches == 1:\n            return [self.branches[0](x[0])]\n\n        for i in range(self.num_branches):\n            x[i] = self.branches[i](x[i])\n\n        x_fuse = []\n        for i in range(len(self.fuse_layers)):\n            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])\n            for j in range(1, self.num_branches):\n                if i == j:\n                    y = y + x[j]\n                elif j > i:\n                    width_output = x[i].shape[-1]\n                    height_output = x[i].shape[-2]\n                    y = y + F.interpolate(\n                        self.fuse_layers[i][j](x[j]),\n                        size=[height_output, width_output],\n                        mode='bilinear',align_corners=False)\n                else:\n                    y = y + self.fuse_layers[i][j](x[j])\n            x_fuse.append(self.relu(y))\n\n        return x_fuse\n\n\nblocks_dict = 
{\n    'BASIC': BasicBlock,\n    'BOTTLENECK': Bottleneck\n}\n\n\nclass HighResolutionNet(nn.Module):\n\n    def __init__(self, in_channels, extra, num_classes,**kwargs):\n        super(HighResolutionNet, self).__init__()\n\n        # stem net\n        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1,\n                               bias=False)\n        self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM)\n        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1,\n                               bias=False)\n        self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM)\n        self.relu = nn.ReLU(inplace=False)\n\n        self.stage1_cfg = extra['STAGE1']\n        num_channels = self.stage1_cfg['NUM_CHANNELS']\n        block = blocks_dict[self.stage1_cfg['BLOCK']]\n        num_blocks = self.stage1_cfg['NUM_BLOCKS']\n        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)\n        stage1_out_channel = block.expansion * num_channels\n\n        self.stage2_cfg = extra['STAGE2']\n        num_channels = self.stage2_cfg['NUM_CHANNELS']\n        block = blocks_dict[self.stage2_cfg['BLOCK']]\n        num_channels = [\n            num_channels[i] * block.expansion for i in range(len(num_channels))]\n        self.transition1 = self._make_transition_layer(\n            [stage1_out_channel], num_channels)\n        self.stage2, pre_stage_channels = self._make_stage(\n            self.stage2_cfg, num_channels)\n\n        self.stage3_cfg = extra['STAGE3']\n        num_channels = self.stage3_cfg['NUM_CHANNELS']\n        block = blocks_dict[self.stage3_cfg['BLOCK']]\n        num_channels = [\n            num_channels[i] * block.expansion for i in range(len(num_channels))]\n        self.transition2 = self._make_transition_layer(\n            pre_stage_channels, num_channels)\n        self.stage3, pre_stage_channels = self._make_stage(\n            self.stage3_cfg, num_channels)\n\n        self.stage4_cfg = extra['STAGE4']\n       
 num_channels = self.stage4_cfg['NUM_CHANNELS']\n        block = blocks_dict[self.stage4_cfg['BLOCK']]\n        num_channels = [\n            num_channels[i] * block.expansion for i in range(len(num_channels))]\n        self.transition3 = self._make_transition_layer(\n            pre_stage_channels, num_channels)\n        self.stage4, pre_stage_channels = self._make_stage(\n            self.stage4_cfg, num_channels, multi_scale_output=True)\n\n        last_inp_channels = int(np.sum(pre_stage_channels))\n\n        self.last_layer = nn.Sequential(\n            nn.Conv2d(\n                in_channels=last_inp_channels,\n                out_channels=last_inp_channels,\n                kernel_size=1,\n                stride=1,\n                padding=0),\n            BatchNorm2d(last_inp_channels, momentum=BN_MOMENTUM),\n            nn.ReLU(inplace=False),\n            nn.Conv2d(\n                in_channels=last_inp_channels,\n                out_channels=num_classes,\n                kernel_size=extra['FINAL_CONV_KERNEL'],\n                stride=1,\n                padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0)\n        )\n\n    def _make_transition_layer(\n            self, num_channels_pre_layer, num_channels_cur_layer):\n        num_branches_cur = len(num_channels_cur_layer)\n        num_branches_pre = len(num_channels_pre_layer)\n\n        transition_layers = []\n        for i in range(num_branches_cur):\n            if i < num_branches_pre:\n                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:\n                    transition_layers.append(nn.Sequential(\n                        nn.Conv2d(num_channels_pre_layer[i],\n                                  num_channels_cur_layer[i],\n                                  3,\n                                  1,\n                                  1,\n                                  bias=False),\n                        BatchNorm2d(\n                            num_channels_cur_layer[i], 
momentum=BN_MOMENTUM),\n                        nn.ReLU(inplace=False)))\n                else:\n                    transition_layers.append(None)\n            else:\n                conv3x3s = []\n                for j in range(i + 1 - num_branches_pre):\n                    inchannels = num_channels_pre_layer[-1]\n                    outchannels = num_channels_cur_layer[i] \\\n                        if j == i - num_branches_pre else inchannels\n                    conv3x3s.append(nn.Sequential(\n                        nn.Conv2d(\n                            inchannels, outchannels, 3, 2, 1, bias=False),\n                        BatchNorm2d(outchannels, momentum=BN_MOMENTUM),\n                        nn.ReLU(inplace=False)))\n                transition_layers.append(nn.Sequential(*conv3x3s))\n\n        return nn.ModuleList(transition_layers)\n\n    def _make_layer(self, block, inplanes, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),\n            )\n\n        layers = []\n        layers.append(block(inplanes, planes, stride, downsample))\n        inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def _make_stage(self, layer_config, num_inchannels,\n                    multi_scale_output=True):\n        num_modules = layer_config['NUM_MODULES']\n        num_branches = layer_config['NUM_BRANCHES']\n        num_blocks = layer_config['NUM_BLOCKS']\n        num_channels = layer_config['NUM_CHANNELS']\n        block = blocks_dict[layer_config['BLOCK']]\n        fuse_method = layer_config['FUSE_METHOD']\n\n        modules = []\n     
   for i in range(num_modules):\n            # multi_scale_output is only used last module\n            if not multi_scale_output and i == num_modules - 1:\n                reset_multi_scale_output = False\n            else:\n                reset_multi_scale_output = True\n            modules.append(\n                HighResolutionModule(num_branches,\n                                     block,\n                                     num_blocks,\n                                     num_inchannels,\n                                     num_channels,\n                                     fuse_method,\n                                     reset_multi_scale_output)\n            )\n            num_inchannels = modules[-1].get_num_inchannels()\n\n        return nn.Sequential(*modules), num_inchannels\n\n    def forward(self, x):\n        size =x.shape[2:]\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.conv2(x)\n        x = self.bn2(x)\n        x = self.relu(x)\n        x = self.layer1(x)\n\n        x_list = []\n        for i in range(self.stage2_cfg['NUM_BRANCHES']):\n            if self.transition1[i] is not None:\n                x_list.append(self.transition1[i](x))\n            else:\n                x_list.append(x)\n        y_list = self.stage2(x_list)\n\n        x_list = []\n        for i in range(self.stage3_cfg['NUM_BRANCHES']):\n            if self.transition2[i] is not None:\n                if i < self.stage2_cfg['NUM_BRANCHES']:\n                    x_list.append(self.transition2[i](y_list[i]))\n                else:\n                    x_list.append(self.transition2[i](y_list[-1]))\n            else:\n                x_list.append(y_list[i])\n        y_list = self.stage3(x_list)\n\n        x_list = []\n        for i in range(self.stage4_cfg['NUM_BRANCHES']):\n            if self.transition3[i] is not None:\n                if i < self.stage3_cfg['NUM_BRANCHES']:\n                    
x_list.append(self.transition3[i](y_list[i]))\n                else:\n                    x_list.append(self.transition3[i](y_list[-1]))\n            else:\n                x_list.append(y_list[i])\n        x = self.stage4(x_list)\n\n        # Upsampling\n        x0_h, x0_w = x[0].size(2), x[0].size(3)\n        x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=False)\n        x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=False)\n        x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=False)\n\n        x = torch.cat([x[0], x1, x2, x3], 1)\n\n        x = self.last_layer(x)\n        x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)\n        # outputs = []\n        # outputs.append(x)\n        # return outputs\n        return x\n    def init_weights(self, pretrained='', ):\n        logger.info('=> init weights from normal distribution')\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.normal_(m.weight, std=0.001)\n            elif isinstance(m, InPlaceABNSync):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n        if os.path.isfile(pretrained):\n            pretrained_dict = torch.load(pretrained)\n            logger.info('=> loading pretrained model {}'.format(pretrained))\n            model_dict = self.state_dict()\n            pretrained_dict = {k: v for k, v in pretrained_dict.items()\n                               if k in model_dict.keys()}\n            for k, _ in pretrained_dict.items():\n                logger.info(\n                    '=> loading {} pretrained model {}'.format(k, pretrained))\n            model_dict.update(pretrained_dict)\n            self.load_state_dict(model_dict)\n\n# class HRNet(nn.Module):\n#     def __init__(self, in_channels, extra, num_classes, **kwargs):\n#         super(HRNet, self).__init__()\n#         self.branch1 = 
HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra)\n#         self.branch2 = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra)\n#\n#     def forward(self, data, step=1):\n#         if not self.training:\n#             pred1 = self.branch1(data)\n#             return pred1\n#\n#         if step == 1:\n#             return self.branch1(data)\n#         elif step == 2:\n#             return self.branch2(data)\n\nextra_18 = {\n            'STAGE1': {'NUM_MODULES': 1, 'NUM_BRANCHES': 1, 'BLOCK': 'BOTTLENECK', 'NUM_BLOCKS': (4), 'NUM_CHANNELS': (64), 'FUSE_METHOD': 'SUM'},\n            'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (18, 36), 'FUSE_METHOD': 'SUM'},\n            'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (18, 36, 72), 'FUSE_METHOD': 'SUM'},\n            'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (18, 36, 72, 144), 'FUSE_METHOD': 'SUM'},\n            'FINAL_CONV_KERNEL': 1\n            }\n\nextra_32 = {\n            'STAGE1': {'NUM_MODULES': 1, 'NUM_BRANCHES': 1, 'BLOCK': 'BOTTLENECK', 'NUM_BLOCKS': (4), 'NUM_CHANNELS': (64), 'FUSE_METHOD': 'SUM'},\n            'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (32, 64), 'FUSE_METHOD': 'SUM'},\n            'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (32, 64, 128), 'FUSE_METHOD': 'SUM'},\n            'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (32, 64, 128, 256), 'FUSE_METHOD': 'SUM'},\n            'FINAL_CONV_KERNEL': 1\n            }\n\nextra_48 = {\n            'STAGE1': {'NUM_MODULES': 1, 'NUM_BRANCHES': 1, 'BLOCK': 'BOTTLENECK', 'NUM_BLOCKS': (4), 'NUM_CHANNELS': (64), 
'FUSE_METHOD': 'SUM'},\n            'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (48, 96), 'FUSE_METHOD': 'SUM'},\n            'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (48, 96, 192), 'FUSE_METHOD': 'SUM'},\n            'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (48, 96, 192, 384), 'FUSE_METHOD': 'SUM'},\n            'FINAL_CONV_KERNEL': 1\n            }\n\nextra_64 = {\n            'STAGE1': {'NUM_MODULES': 1, 'NUM_BRANCHES': 1, 'BLOCK': 'BOTTLENECK', 'NUM_BLOCKS': (4), 'NUM_CHANNELS': (64), 'FUSE_METHOD': 'SUM'},\n            'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (64, 128), 'FUSE_METHOD': 'SUM'},\n            'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (64, 128, 256), 'FUSE_METHOD': 'SUM'},\n            'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (64, 128, 256, 512), 'FUSE_METHOD': 'SUM'},\n            'FINAL_CONV_KERNEL': 1\n            }\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % 
init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\n# def hrnet18(in_channels, num_classes):\n#     model = HRNet(in_channels=in_channels, num_classes=num_classes, extra=extra_18)\n#     return model\n#\n# def hrnet32(in_channels, num_classes):\n#     model = HRNet(in_channels=in_channels, num_classes=num_classes, extra=extra_32)\n#     return model\n#\n# def hrnet48(in_channels, num_classes):\n#     model = HRNet(in_channels=in_channels, num_classes=num_classes, extra=extra_48)\n#     return model\n#\n# def hrnet64(in_channels, num_classes):\n#     model = HRNet(in_channels=in_channels, num_classes=num_classes, extra=extra_64)\n#     return model\n\ndef hrnet18(in_channels, num_classes):\n    model = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra_18)\n    init_weights(model, 'kaiming')\n    return model\n\ndef hrnet32(in_channels, num_classes):\n    model = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra_32)\n    init_weights(model, 'kaiming')\n    return model\n\ndef hrnet48(in_channels, num_classes):\n    model = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra_48)\n    init_weights(model, 'kaiming')\n    return model\n\ndef hrnet64(in_channels, num_classes):\n    model = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra_64)\n    init_weights(model, 'kaiming')\n    return model\n\n\n# if __name__ == '__main__':\n#     model = hrnet48(1,10)\n\n    # total = sum([param.nelement() for param in model.parameters()])\n    # from thop import profile,clever_format\n    #\n    # input = torch.randn(1, 1, 128, 128)\n    # flops, params = profile(model, 
inputs=(input, ))\n    # macs, params = clever_format([flops, params], \"%.3f\")\n    # print(macs)\n    # print(params)\n    # print(total)\n    # model.eval()\n    # input = torch.rand(1,1,256,256)\n    # output = model(input)\n    # output = output[0].data.cpu().numpy()\n    # print(output)\n    # print(output.shape)\n\n"
  },
  {
    "path": "models/networks_2d/mwcnn.py",
    "content": "import torch\nimport torch.nn as nn\nimport scipy.io as sio\nimport math\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\n\n\ndef default_conv(in_channels, out_channels, kernel_size, bias=True, dilation=1):\n    return nn.Conv2d(\n        in_channels, out_channels, kernel_size,\n        padding=(kernel_size // 2) + dilation - 1, bias=bias, dilation=dilation)\n\n\ndef default_conv1(in_channels, out_channels, kernel_size, bias=True, groups=3):\n    return nn.Conv2d(\n        in_channels, out_channels, kernel_size,\n        padding=(kernel_size // 2), bias=bias, groups=groups)\n\n\n# def shuffle_channel()\n\ndef channel_shuffle(x, groups):\n    batchsize, num_channels, height, width = x.size()\n\n    channels_per_group = num_channels // groups\n\n    # reshape\n    x = x.view(batchsize, groups,\n               channels_per_group, height, width)\n\n    x = torch.transpose(x, 1, 2).contiguous()\n\n    # flatten\n    x = x.view(batchsize, -1, height, width)\n\n    return x\n\n\ndef pixel_down_shuffle(x, downsacale_factor):\n    batchsize, num_channels, height, width = x.size()\n\n    out_height = height // downsacale_factor\n    out_width = width // downsacale_factor\n    input_view = x.contiguous().view(batchsize, num_channels, out_height, downsacale_factor, out_width,\n                                     downsacale_factor)\n\n    num_channels *= downsacale_factor ** 2\n    unshuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()\n\n    return unshuffle_out.view(batchsize, num_channels, out_height, out_width)\n\n\ndef sp_init(x):\n    x01 = x[:, :, 0::2, :]\n    x02 = x[:, :, 1::2, :]\n    x_LL = x01[:, :, :, 0::2]\n    x_HL = x02[:, :, :, 0::2]\n    x_LH = x01[:, :, :, 1::2]\n    x_HH = x02[:, :, :, 1::2]\n\n    return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)\n\n\ndef dwt_init(x):\n    x01 = x[:, :, 0::2, :] / 2\n    x02 = x[:, :, 1::2, :] / 2\n    x1 = x01[:, :, :, 0::2]\n    x2 = x02[:, :, :, 0::2]\n    x3 = x01[:, :, :, 
1::2]\n    x4 = x02[:, :, :, 1::2]\n    x_LL = x1 + x2 + x3 + x4\n    x_HL = -x1 - x2 + x3 + x4\n    x_LH = -x1 + x2 - x3 + x4\n    x_HH = x1 - x2 - x3 + x4\n\n    return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)\n\n\ndef iwt_init(x):\n    r = 2\n    in_batch, in_channel, in_height, in_width = x.size()\n    # print([in_batch, in_channel, in_height, in_width])\n    out_batch, out_channel, out_height, out_width = in_batch, int(\n        in_channel / (r ** 2)), r * in_height, r * in_width\n    x1 = x[:, 0:out_channel, :, :] / 2\n    x2 = x[:, out_channel:out_channel * 2, :, :] / 2\n    x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2\n    x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2\n\n    h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda()\n    # h = torch.zeros([out_batch, out_channel, out_height, out_width]).float()\n\n    h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4\n    h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4\n    h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4\n    h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4\n\n    return h\n\n\nclass Channel_Shuffle(nn.Module):\n    def __init__(self, conv_groups):\n        super(Channel_Shuffle, self).__init__()\n        self.conv_groups = conv_groups\n        self.requires_grad = False\n\n    def forward(self, x):\n        return channel_shuffle(x, self.conv_groups)\n\n\nclass SP(nn.Module):\n    def __init__(self):\n        super(SP, self).__init__()\n        self.requires_grad = False\n\n    def forward(self, x):\n        return sp_init(x)\n\n\nclass Pixel_Down_Shuffle(nn.Module):\n    def __init__(self):\n        super(Pixel_Down_Shuffle, self).__init__()\n        self.requires_grad = False\n\n    def forward(self, x):\n        return pixel_down_shuffle(x, 2)\n\n\nclass DWT(nn.Module):\n    def __init__(self):\n        super(DWT, self).__init__()\n        self.requires_grad = False\n\n    def forward(self, x):\n        return dwt_init(x)\n\n\nclass IWT(nn.Module):\n    def __init__(self):\n        
super(IWT, self).__init__()\n        self.requires_grad = False\n\n    def forward(self, x):\n        return iwt_init(x)\n\n\nclass MeanShift(nn.Conv2d):\n    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):\n        super(MeanShift, self).__init__(3, 3, kernel_size=1)\n        std = torch.Tensor(rgb_std)\n        self.weight.data = torch.eye(3).view(3, 3, 1, 1)\n        self.weight.data.div_(std.view(3, 1, 1, 1))\n        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)\n        self.bias.data.div_(std)\n        self.requires_grad = False\n        if sign == -1:\n            self.create_graph = False\n            self.volatile = True\n\n\nclass MeanShift2(nn.Conv2d):\n    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):\n        super(MeanShift2, self).__init__(4, 4, kernel_size=1)\n        std = torch.Tensor(rgb_std)\n        self.weight.data = torch.eye(4).view(4, 4, 1, 1)\n        self.weight.data.div_(std.view(4, 1, 1, 1))\n        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)\n        self.bias.data.div_(std)\n        self.requires_grad = False\n        if sign == -1:\n            self.volatile = True\n\n\nclass BasicBlock(nn.Sequential):\n    def __init__(\n            self, in_channels, out_channels, kernel_size, stride=1, bias=False,\n            bn=False, act=nn.ReLU(True)):\n\n        m = [nn.Conv2d(\n            in_channels, out_channels, kernel_size,\n            padding=(kernel_size // 2), stride=stride, bias=bias)\n        ]\n        if bn: m.append(nn.BatchNorm2d(out_channels))\n        if act is not None: m.append(act)\n        super(BasicBlock, self).__init__(*m)\n\n\nclass BBlock(nn.Module):\n    def __init__(\n            self, conv, in_channels, out_channels, kernel_size,\n            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n        super(BBlock, self).__init__()\n        m = []\n        m.append(conv(in_channels, out_channels, kernel_size, bias=bias))\n        if bn: 
m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))\n        m.append(act)\n\n        self.body = nn.Sequential(*m)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        x = self.body(x).mul(self.res_scale)\n        return x\n\n\nclass DBlock_com(nn.Module):\n    def __init__(\n            self, conv, in_channels, out_channels, kernel_size,\n            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(DBlock_com, self).__init__()\n        m = []\n\n        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))\n        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))\n        m.append(act)\n        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=3))\n        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))\n        m.append(act)\n\n        self.body = nn.Sequential(*m)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        x = self.body(x)\n        return x\n\n\nclass DBlock_inv(nn.Module):\n    def __init__(\n            self, conv, in_channels, out_channels, kernel_size,\n            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(DBlock_inv, self).__init__()\n        m = []\n\n        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=3))\n        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))\n        m.append(act)\n        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))\n        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))\n        m.append(act)\n\n        self.body = nn.Sequential(*m)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        x = self.body(x)\n        return x\n\n\nclass DBlock_com1(nn.Module):\n    def __init__(\n            self, conv, in_channels, out_channels, kernel_size,\n            bias=True, bn=False, act=nn.ReLU(True), 
res_scale=1):\n\n        super(DBlock_com1, self).__init__()\n        m = []\n\n        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))\n        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))\n        m.append(act)\n        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=1))\n        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))\n        m.append(act)\n\n        self.body = nn.Sequential(*m)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        x = self.body(x)\n        return x\n\n\nclass DBlock_inv1(nn.Module):\n    def __init__(\n            self, conv, in_channels, out_channels, kernel_size,\n            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(DBlock_inv1, self).__init__()\n        m = []\n\n        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))\n        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))\n        m.append(act)\n        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=1))\n        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))\n        m.append(act)\n\n        self.body = nn.Sequential(*m)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        x = self.body(x)\n        return x\n\n\nclass DBlock_com2(nn.Module):\n    def __init__(\n            self, conv, in_channels, out_channels, kernel_size,\n            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(DBlock_com2, self).__init__()\n        m = []\n\n        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))\n        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))\n        m.append(act)\n        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))\n        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))\n 
       m.append(act)\n\n        self.body = nn.Sequential(*m)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        x = self.body(x)\n        return x\n\n\nclass DBlock_inv2(nn.Module):\n    def __init__(\n            self, conv, in_channels, out_channels, kernel_size,\n            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(DBlock_inv2, self).__init__()\n        m = []\n\n        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))\n        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))\n        m.append(act)\n        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))\n        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))\n        m.append(act)\n\n        self.body = nn.Sequential(*m)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        x = self.body(x)\n        return x\n\n\nclass ShuffleBlock(nn.Module):\n    def __init__(\n            self, conv, in_channels, out_channels, kernel_size,\n            bias=True, bn=False, act=nn.ReLU(True), res_scale=1, conv_groups=1):\n        super(ShuffleBlock, self).__init__()\n        m = []\n        m.append(conv(in_channels, out_channels, kernel_size, bias=bias))\n        m.append(Channel_Shuffle(conv_groups))\n        if bn: m.append(nn.BatchNorm2d(out_channels))\n        m.append(act)\n\n        self.body = nn.Sequential(*m)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        x = self.body(x).mul(self.res_scale)\n        return x\n\n\nclass DWBlock(nn.Module):\n    def __init__(\n            self, conv, conv1, in_channels, out_channels, kernel_size,\n            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(DWBlock, self).__init__()\n        m = []\n        m.append(conv(in_channels, out_channels, kernel_size, bias=bias))\n        if bn: m.append(nn.BatchNorm2d(out_channels))\n        m.append(act)\n\n        
m.append(conv1(in_channels, out_channels, 1, bias=bias))\n        if bn: m.append(nn.BatchNorm2d(out_channels))\n        m.append(act)\n\n        self.body = nn.Sequential(*m)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        x = self.body(x).mul(self.res_scale)\n        return x\n\n\nclass ResBlock(nn.Module):\n    def __init__(\n            self, conv, n_feat, kernel_size,\n            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(ResBlock, self).__init__()\n        m = []\n        for i in range(2):\n            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))\n            if bn: m.append(nn.BatchNorm2d(n_feat))\n            if i == 0: m.append(act)\n\n        self.body = nn.Sequential(*m)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x).mul(self.res_scale)\n        res += x\n\n        return res\n\n\nclass Block(nn.Module):\n    def __init__(\n            self, conv, n_feat, kernel_size,\n            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(Block, self).__init__()\n        m = []\n        for i in range(4):\n            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))\n            if bn: m.append(nn.BatchNorm2d(n_feat))\n            if i == 0: m.append(act)\n\n        self.body = nn.Sequential(*m)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x).mul(self.res_scale)\n        # res += x\n\n        return res\n\n\nclass Upsampler(nn.Sequential):\n    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):\n\n        m = []\n        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?\n            for _ in range(int(math.log(scale, 2))):\n                m.append(conv(n_feat, 4 * n_feat, 3, bias))\n                m.append(nn.PixelShuffle(2))\n                if bn: m.append(nn.BatchNorm2d(n_feat))\n                if act: m.append(act())\n        elif scale == 3:\n           
 m.append(conv(n_feat, 9 * n_feat, 3, bias))\n            m.append(nn.PixelShuffle(3))\n            if bn: m.append(nn.BatchNorm2d(n_feat))\n            if act: m.append(act())\n        else:\n            raise NotImplementedError\n\n        super(Upsampler, self).__init__(*m)\n\n\nclass MWCNN(nn.Module):\n    def __init__(self, in_channels, num_classes, conv=default_conv):\n        super(MWCNN, self).__init__()\n        kernel_size = 3\n        self.scale_idx = 0\n        n_feats = 64\n        act = nn.ReLU(True)\n\n        self.DWT = DWT()\n        self.IWT = IWT()\n\n        n = 1\n        m_head = [BBlock(conv, in_channels, n_feats, kernel_size, act=act)]\n        d_l0 = []\n        d_l0.append(DBlock_com1(conv, n_feats, n_feats, kernel_size, act=act, bn=False))\n\n\n        d_l1 = [BBlock(conv, n_feats * 4, n_feats * 2, kernel_size, act=act, bn=False)]\n        d_l1.append(DBlock_com1(conv, n_feats * 2, n_feats * 2, kernel_size, act=act, bn=False))\n\n        d_l2 = []\n        d_l2.append(BBlock(conv, n_feats * 8, n_feats * 4, kernel_size, act=act, bn=False))\n        d_l2.append(DBlock_com1(conv, n_feats * 4, n_feats * 4, kernel_size, act=act, bn=False))\n        pro_l3 = []\n        pro_l3.append(BBlock(conv, n_feats * 16, n_feats * 8, kernel_size, act=act, bn=False))\n        pro_l3.append(DBlock_com(conv, n_feats * 8, n_feats * 8, kernel_size, act=act, bn=False))\n        pro_l3.append(DBlock_inv(conv, n_feats * 8, n_feats * 8, kernel_size, act=act, bn=False))\n        pro_l3.append(BBlock(conv, n_feats * 8, n_feats * 16, kernel_size, act=act, bn=False))\n\n        i_l2 = [DBlock_inv1(conv, n_feats * 4, n_feats * 4, kernel_size, act=act, bn=False)]\n        i_l2.append(BBlock(conv, n_feats * 4, n_feats * 8, kernel_size, act=act, bn=False))\n\n        i_l1 = [DBlock_inv1(conv, n_feats * 2, n_feats * 2, kernel_size, act=act, bn=False)]\n        i_l1.append(BBlock(conv, n_feats * 2, n_feats * 4, kernel_size, act=act, bn=False))\n\n        i_l0 = 
[DBlock_inv1(conv, n_feats, n_feats, kernel_size, act=act, bn=False)]\n\n        m_tail = [conv(n_feats, num_classes, kernel_size)]\n\n        self.head = nn.Sequential(*m_head)\n        self.d_l2 = nn.Sequential(*d_l2)\n        self.d_l1 = nn.Sequential(*d_l1)\n        self.d_l0 = nn.Sequential(*d_l0)\n        self.pro_l3 = nn.Sequential(*pro_l3)\n        self.i_l2 = nn.Sequential(*i_l2)\n        self.i_l1 = nn.Sequential(*i_l1)\n        self.i_l0 = nn.Sequential(*i_l0)\n        self.tail = nn.Sequential(*m_tail)\n\n    def forward(self, x):\n        x0 = self.d_l0(self.head(x))\n        x1 = self.d_l1(self.DWT(x0))\n        x2 = self.d_l2(self.DWT(x1))\n        x_ = self.IWT(self.pro_l3(self.DWT(x2))) + x2\n        x_ = self.IWT(self.i_l2(x_)) + x1\n        x_ = self.IWT(self.i_l1(x_)) + x0\n        x_ = self.tail(self.i_l0(x_))\n\n        return x_\n\n    def set_scale(self, scale_idx):\n        self.scale_idx = scale_idx\n\n\ndef mwcnn(in_channels, num_classes):\n    model = MWCNN(in_channels, num_classes)\n    return model\n\n# if __name__ == '__main__':\n#     from loss.loss_function import segmentation_loss\n#     criterion = segmentation_loss('dice', False)\n#     mask = torch.ones(2, 128, 128).long()\n#     model = mwcnn(1, 2)\n#     model.train()\n#     input1 = torch.rand(2, 1, 128, 128)\n#     y = model(input1)\n#     loss_train = criterion(y, mask)\n#     loss_train.backward()\n#     # print(output)\n#     print(y.data.cpu().numpy().shape)\n#     print(loss_train)"
  },
  {
    "path": "models/networks_2d/resunet.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\nclass ResidualConv(nn.Module):\n    def __init__(self, input_dim, output_dim, stride, padding):\n        super(ResidualConv, self).__init__()\n\n        self.conv_block = nn.Sequential(\n            nn.BatchNorm2d(input_dim),\n            nn.ReLU(),\n            nn.Conv2d(\n                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding\n            ),\n            nn.BatchNorm2d(output_dim),\n            nn.ReLU(),\n            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),\n        )\n        self.conv_skip = nn.Sequential(\n            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),\n            nn.BatchNorm2d(output_dim),\n        )\n\n    def forward(self, x):\n\n        return self.conv_block(x) + self.conv_skip(x)\n\n\nclass 
Upsample(nn.Module):\n    def __init__(self, input_dim, output_dim, kernel, stride):\n        super(Upsample, self).__init__()\n\n        self.upsample = nn.ConvTranspose2d(\n            input_dim, output_dim, kernel_size=kernel, stride=stride\n        )\n\n    def forward(self, x):\n        return self.upsample(x)\n\n\nclass ResUnet(nn.Module):\n    def __init__(self, in_channels, num_classes, filters=[64, 128, 256, 512]):\n        super(ResUnet, self).__init__()\n\n        self.input_layer = nn.Sequential(\n            nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1),\n            nn.BatchNorm2d(filters[0]),\n            nn.ReLU(),\n            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),\n        )\n        self.input_skip = nn.Sequential(\n            nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1)\n        )\n\n        self.residual_conv_1 = ResidualConv(filters[0], filters[1], 2, 1)\n        self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1)\n\n        self.bridge = ResidualConv(filters[2], filters[3], 2, 1)\n\n        self.upsample_1 = Upsample(filters[3], filters[3], 2, 2)\n        self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1)\n\n        self.upsample_2 = Upsample(filters[2], filters[2], 2, 2)\n        self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1)\n\n        self.upsample_3 = Upsample(filters[1], filters[1], 2, 2)\n        self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], 1, 1)\n\n        self.output_layer = nn.Conv2d(filters[0], num_classes, 1, 1)\n\n    def forward(self, x):\n        # Encode\n        x1 = self.input_layer(x) + self.input_skip(x)\n        x2 = self.residual_conv_1(x1)\n        x3 = self.residual_conv_2(x2)\n        # Bridge\n        x4 = self.bridge(x3)\n        # Decode\n        x4 = self.upsample_1(x4)\n        x5 = torch.cat([x4, x3], dim=1)\n\n        x6 = self.up_residual_conv1(x5)\n\n 
       x6 = self.upsample_2(x6)\n        x7 = torch.cat([x6, x2], dim=1)\n\n        x8 = self.up_residual_conv2(x7)\n\n        x8 = self.upsample_3(x8)\n        x9 = torch.cat([x8, x1], dim=1)\n\n        x10 = self.up_residual_conv3(x9)\n\n        output = self.output_layer(x10)\n\n        return output\n\n\ndef res_unet(in_channels, num_classes):\n    model = ResUnet(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\n\n# if __name__ == '__main__':\n#     model = res_unet(1,10)\n#     model.eval()\n#     input = torch.rand(2,1,128,128)\n#     output = model(input)\n#     output = output.data.cpu().numpy()\n#     # print(output)\n#     print(output.shape)"
  },
  {
    "path": "models/networks_2d/resunet_plusplus.py",
    "content": "import torch.nn as nn\nimport torch\nfrom torch.nn import init\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\nclass ResidualConv(nn.Module):\n    def __init__(self, input_dim, output_dim, stride, padding):\n        super(ResidualConv, self).__init__()\n\n        self.conv_block = nn.Sequential(\n            nn.BatchNorm2d(input_dim),\n            nn.ReLU(),\n            nn.Conv2d(\n                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding\n            ),\n            nn.BatchNorm2d(output_dim),\n            nn.ReLU(),\n            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),\n        )\n        self.conv_skip = nn.Sequential(\n            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),\n            nn.BatchNorm2d(output_dim),\n        )\n\n    def forward(self, x):\n\n        return self.conv_block(x) + self.conv_skip(x)\n\n\nclass 
Upsample(nn.Module):\n    def __init__(self, input_dim, output_dim, kernel, stride):\n        super(Upsample, self).__init__()\n\n        self.upsample = nn.ConvTranspose2d(\n            input_dim, output_dim, kernel_size=kernel, stride=stride\n        )\n\n    def forward(self, x):\n        return self.upsample(x)\n\nclass Squeeze_Excite_Block(nn.Module):\n    def __init__(self, channel, reduction=16):\n        super(Squeeze_Excite_Block, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Sequential(\n            nn.Linear(channel, channel // reduction, bias=False),\n            nn.ReLU(inplace=True),\n            nn.Linear(channel // reduction, channel, bias=False),\n            nn.Sigmoid(),\n        )\n\n    def forward(self, x):\n        b, c, _, _ = x.size()\n        y = self.avg_pool(x).view(b, c)\n        y = self.fc(y).view(b, c, 1, 1)\n        return x * y.expand_as(x)\n\nclass ASPP(nn.Module):\n    def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):\n        super(ASPP, self).__init__()\n\n        self.aspp_block1 = nn.Sequential(\n            nn.Conv2d(\n                in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]\n            ),\n            nn.ReLU(inplace=True),\n            nn.BatchNorm2d(out_dims),\n        )\n        self.aspp_block2 = nn.Sequential(\n            nn.Conv2d(\n                in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]\n            ),\n            nn.ReLU(inplace=True),\n            nn.BatchNorm2d(out_dims),\n        )\n        self.aspp_block3 = nn.Sequential(\n            nn.Conv2d(\n                in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]\n            ),\n            nn.ReLU(inplace=True),\n            nn.BatchNorm2d(out_dims),\n        )\n\n        self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)\n        self._init_weights()\n\n    def forward(self, x):\n        x1 = self.aspp_block1(x)\n        x2 = 
self.aspp_block2(x)\n        x3 = self.aspp_block3(x)\n        out = torch.cat([x1, x2, x3], dim=1)\n        return self.output(out)\n\n    def _init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight)\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\nclass Upsample_(nn.Module):\n    def __init__(self, scale=2):\n        super(Upsample_, self).__init__()\n\n        self.upsample = nn.Upsample(mode=\"bilinear\", scale_factor=scale, align_corners=True)\n\n    def forward(self, x):\n        return self.upsample(x)\n\nclass AttentionBlock(nn.Module):\n    def __init__(self, input_encoder, input_decoder, output_dim):\n        super(AttentionBlock, self).__init__()\n\n        self.conv_encoder = nn.Sequential(\n            nn.BatchNorm2d(input_encoder),\n            nn.ReLU(),\n            nn.Conv2d(input_encoder, output_dim, 3, padding=1),\n            nn.MaxPool2d(2, 2),\n        )\n\n        self.conv_decoder = nn.Sequential(\n            nn.BatchNorm2d(input_decoder),\n            nn.ReLU(),\n            nn.Conv2d(input_decoder, output_dim, 3, padding=1),\n        )\n\n        self.conv_attn = nn.Sequential(\n            nn.BatchNorm2d(output_dim),\n            nn.ReLU(),\n            nn.Conv2d(output_dim, 1, 1),\n        )\n\n    def forward(self, x1, x2):\n        out = self.conv_encoder(x1) + self.conv_decoder(x2)\n        out = self.conv_attn(out)\n        return out * x2\n\n\n\nclass ResUnetPlusPlus(nn.Module):\n    def __init__(self, in_channels, num_classes, filters=[32, 64, 128, 256, 512]):\n        super(ResUnetPlusPlus, self).__init__()\n\n        self.input_layer = nn.Sequential(\n            nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1),\n            nn.BatchNorm2d(filters[0]),\n            nn.ReLU(),\n            nn.Conv2d(filters[0], filters[0], kernel_size=3, 
padding=1),\n        )\n        self.input_skip = nn.Sequential(\n            nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1)\n        )\n\n        self.squeeze_excite1 = Squeeze_Excite_Block(filters[0])\n\n        self.residual_conv1 = ResidualConv(filters[0], filters[1], 2, 1)\n\n        self.squeeze_excite2 = Squeeze_Excite_Block(filters[1])\n\n        self.residual_conv2 = ResidualConv(filters[1], filters[2], 2, 1)\n\n        self.squeeze_excite3 = Squeeze_Excite_Block(filters[2])\n\n        self.residual_conv3 = ResidualConv(filters[2], filters[3], 2, 1)\n\n        self.aspp_bridge = ASPP(filters[3], filters[4])\n\n        self.attn1 = AttentionBlock(filters[2], filters[4], filters[4])\n        self.upsample1 = Upsample_(2)\n        self.up_residual_conv1 = ResidualConv(filters[4] + filters[2], filters[3], 1, 1)\n\n        self.attn2 = AttentionBlock(filters[1], filters[3], filters[3])\n        self.upsample2 = Upsample_(2)\n        self.up_residual_conv2 = ResidualConv(filters[3] + filters[1], filters[2], 1, 1)\n\n        self.attn3 = AttentionBlock(filters[0], filters[2], filters[2])\n        self.upsample3 = Upsample_(2)\n        self.up_residual_conv3 = ResidualConv(filters[2] + filters[0], filters[1], 1, 1)\n\n        self.aspp_out = ASPP(filters[1], filters[0])\n\n        self.output_layer = nn.Conv2d(filters[0], num_classes, 1)\n\n    def forward(self, x):\n        x1 = self.input_layer(x) + self.input_skip(x)\n\n        x2 = self.squeeze_excite1(x1)\n        x2 = self.residual_conv1(x2)\n\n        x3 = self.squeeze_excite2(x2)\n        x3 = self.residual_conv2(x3)\n\n        x4 = self.squeeze_excite3(x3)\n        x4 = self.residual_conv3(x4)\n\n        x5 = self.aspp_bridge(x4)\n\n        x6 = self.attn1(x3, x5)\n        x6 = self.upsample1(x6)\n        x6 = torch.cat([x6, x3], dim=1)\n        x6 = self.up_residual_conv1(x6)\n\n        x7 = self.attn2(x2, x6)\n        x7 = self.upsample2(x7)\n        x7 = torch.cat([x7, x2], dim=1)\n       
 x7 = self.up_residual_conv2(x7)\n\n        x8 = self.attn3(x1, x7)\n        x8 = self.upsample3(x8)\n        x8 = torch.cat([x8, x1], dim=1)\n        x8 = self.up_residual_conv3(x8)\n\n        x9 = self.aspp_out(x8)\n        out = self.output_layer(x9)\n\n        return out\n\ndef res_unet_plusplus(in_channels, num_classes):\n    model = ResUnetPlusPlus(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\n\n# if __name__ == '__main__':\n#     model = res_unet_plusplus(1,10)\n#     model.eval()\n#     input = torch.rand(2,1,128,128)\n#     output = model(input)\n#     output = output.data.cpu().numpy()\n#     # print(output)\n#     print(output.shape)"
  },
  {
    "path": "models/networks_2d/swinunet.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport copy\nimport logging\nimport math\n\nfrom os.path import join as pjoin\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\n\nfrom torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm\nfrom torch.nn.modules.utils import _pair\nfrom scipy import ndimage\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nfrom einops import rearrange\nimport torch.utils.checkpoint as checkpoint\n\n\nlogger = logging.getLogger(__name__)\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Args:\n        x: (B, H, W, C)\n        window_size (int): window size\n    Returns:\n        windows: (num_windows*B, window_size, window_size, C)\n    \"\"\"\n    B, H, W, C = x.shape\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous(\n    ).view(-1, window_size, window_size, C)\n    return windows\n\n\ndef window_reverse(windows, window_size, H, W):\n    \"\"\"\n    Args:\n        windows: (num_windows*B, window_size, window_size, C)\n        window_size (int): Window size\n        H (int): Height of image\n        W (int): Width of image\n    Returns:\n        x: (B, H, W, C)\n    \"\"\"\n    B = 
int(windows.shape[0] / (H * W / window_size / window_size))\n    x = windows.view(B, H // window_size, W // window_size,\n                     window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n    return x\n\n\nclass WindowAttention(nn.Module):\n    r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple[int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size  # Wh, Ww\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        # define a parameter table of relative position bias\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = coords_flatten[:, :, None] - \\\n            
coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(\n            1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - \\\n            1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n        self.register_buffer(\"relative_position_index\",\n                             relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        trunc_normal_(self.relative_position_bias_table, std=.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C //\n                                  self.num_heads).permute(2, 0, 3, 1, 4).contiguous()\n        # make torchscript happy (cannot use tensor as tuple)\n        q, k, v = qkv[0], qkv[1], qkv[2]\n\n        q = q * self.scale\n        attn = (q @ k.transpose(-2, -1).contiguous())\n\n        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n        relative_position_bias = relative_position_bias.permute(\n            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        if mask is not None:\n            nW = mask.shape[0]\n            attn = attn.view(B_ 
// nW, nW, self.num_heads, N,\n                             N) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).contiguous().reshape(B_, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    def extra_repr(self) -> str:\n        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'\n\n    def flops(self, N):\n        # calculate flops for 1 window with token length of N\n        flops = 0\n        # qkv = self.qkv(x)\n        flops += N * self.dim * 3 * self.dim\n        # attn = (q @ k.transpose(-2, -1))\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n        #  x = (attn @ v)\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n        # x = self.proj(x)\n        flops += N * self.dim * self.dim\n        return flops\n\n\nclass SwinTransformerBlock(nn.Module):\n    r\"\"\" Swin Transformer Block.\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resulotion.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. 
Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        if min(self.input_resolution) <= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must in 0-window_size\"\n\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n\n        self.drop_path = DropPath(\n            drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,\n                       act_layer=act_layer, drop=drop)\n\n        if self.shift_size > 0:\n            # calculate attention mask for SW-MSA\n            H, W = self.input_resolution\n            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n            h_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            w_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            cnt = 0\n            for h in h_slices:\n                for w in w_slices:\n                    img_mask[:, h, w, :] = cnt\n                    cnt += 1\n\n            # nW, window_size, window_size, 1\n            mask_windows = window_partition(img_mask, self.window_size)\n            mask_windows = mask_windows.view(-1,\n                                             self.window_size * self.window_size)\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n            attn_mask = attn_mask.masked_fill(\n                attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n        else:\n            attn_mask = None\n\n        self.register_buffer(\"attn_mask\", attn_mask)\n\n    def forward(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # cyclic shift\n        if self.shift_size > 0:\n            shifted_x = torch.roll(\n                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n        else:\n            shifted_x = x\n\n        # partition 
windows\n        # nW*B, window_size, window_size, C\n        x_windows = window_partition(shifted_x, self.window_size)\n        # nW*B, window_size*window_size, C\n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)\n\n        # W-MSA/SW-MSA\n        # nW*B, window_size*window_size, C\n        attn_windows = self.attn(x_windows, mask=self.attn_mask)\n\n        # merge windows\n        attn_windows = attn_windows.view(-1,\n                                         self.window_size, self.window_size, C)\n        shifted_x = window_reverse(\n            attn_windows, self.window_size, H, W)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size > 0:\n            x = torch.roll(shifted_x, shifts=(\n                self.shift_size, self.shift_size), dims=(1, 2))\n        else:\n            x = shifted_x\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n               f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # W-MSA/SW-MSA\n        nW = H * W / self.window_size / self.window_size\n        flops += nW * self.attn.flops(self.window_size * self.window_size)\n        # mlp\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return flops\n\n\nclass PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n    Args:\n        input_resolution (tuple[int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  
Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\"\n\n        x = x.view(B, H, W, C)\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n    def extra_repr(self) -> str:\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n    def flops(self):\n        H, W = self.input_resolution\n        flops = H * W * self.dim\n        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n        return flops\n\n\nclass PatchExpand(nn.Module):\n    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.expand = nn.Linear(\n            dim, 2*dim, bias=False) if dim_scale == 2 else nn.Identity()\n        self.norm = norm_layer(dim // dim_scale)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        x = self.expand(x)\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        x = x.view(B, H, W, C)\n        x = 
rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c',\n                      p1=2, p2=2, c=C//4)\n        x = x.view(B, -1, C//4)\n        x = self.norm(x)\n\n        return x\n\n\nclass FinalPatchExpand_X4(nn.Module):\n    def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.dim_scale = dim_scale\n        self.expand = nn.Linear(dim, 16*dim, bias=False)\n        self.output_dim = dim\n        self.norm = norm_layer(self.output_dim)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        x = self.expand(x)\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        x = x.view(B, H, W, C)\n        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c',\n                      p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))\n        x = x.view(B, -1, self.output_dim)\n        x = self.norm(x)\n\n        return x\n\nclass BasicLayer(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n                                 num_heads=num_heads, window_size=window_size,\n                                 shift_size=0 if (\n                                     i % 2 == 0) else window_size // 2,\n                                 mlp_ratio=mlp_ratio,\n                                 qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                 drop=drop, attn_drop=attn_drop,\n                                 drop_path=drop_path[i] if isinstance(\n                                     drop_path, list) else drop_path,\n                                 norm_layer=norm_layer)\n            for i in range(depth)])\n\n        # patch merging layer\n        if downsample is not None:\n            self.downsample = downsample(\n                input_resolution, dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n    def 
extra_repr(self) -> str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops\n\n\nclass BasicLayer_up(nn.Module):\n    \"\"\" A basic Swin Transformer layer for one stage.\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple[int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        # build blocks\n        self.blocks = nn.ModuleList([\n            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n                                 num_heads=num_heads, window_size=window_size,\n                                 shift_size=0 if (\n                                     i % 2 == 0) else window_size // 2,\n                                 mlp_ratio=mlp_ratio,\n                                 qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                 drop=drop, attn_drop=attn_drop,\n                                 drop_path=drop_path[i] if isinstance(\n                                     drop_path, list) else drop_path,\n                                 norm_layer=norm_layer)\n            for i in range(depth)])\n\n        # patch merging layer\n        if upsample is not None:\n            self.upsample = PatchExpand(\n                input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)\n        else:\n            self.upsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        if self.upsample is not None:\n            x = self.upsample(x)\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. 
Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = [img_size[0] //\n                              patch_size[0], img_size[1] // patch_size[1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim,\n                              kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2).contiguous()  # B Ph*Pw C\n        if self.norm is not None:\n            x = self.norm(x)\n        return x\n\n    def flops(self):\n        Ho, Wo = self.patches_resolution\n        flops = Ho * Wo * self.embed_dim * self.in_chans * \\\n            (self.patch_size[0] * self.patch_size[1])\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\nclass SwinTransformerSys(nn.Module):\n    r\"\"\" Swin Transformer\n        A PyTorch impl of : `Swin Transformer: Hierarchical Vision 
Transformer using Shifted Windows`  -\n          https://arxiv.org/pdf/2103.14030\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each Swin Transformer layer.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None\n        drop_rate (float): Dropout rate. Default: 0\n        attn_drop_rate (float): Attention dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,\n                 embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],\n                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n                 use_checkpoint=False, final_upsample=\"expand_first\", **kwargs):\n        super().__init__()\n\n        print(\"SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}\".format(depths,\n                                                                                                                         depths_decoder, drop_path_rate, num_classes))\n\n        self.num_classes = num_classes\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.num_features_up = int(embed_dim * 2)\n        self.mlp_ratio = mlp_ratio\n        self.final_upsample = final_upsample\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(\n                torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = 
[x.item() for x in torch.linspace(0, drop_path_rate,\n                                                sum(depths))]  # stochastic depth decay rule\n\n        # build encoder and bottleneck layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),\n                               input_resolution=(patches_resolution[0] // (2 ** i_layer),\n                                                 patches_resolution[1] // (2 ** i_layer)),\n                               depth=depths[i_layer],\n                               num_heads=num_heads[i_layer],\n                               window_size=window_size,\n                               mlp_ratio=self.mlp_ratio,\n                               qkv_bias=qkv_bias, qk_scale=qk_scale,\n                               drop=drop_rate, attn_drop=attn_drop_rate,\n                               drop_path=dpr[sum(depths[:i_layer]):sum(\n                                   depths[:i_layer + 1])],\n                               norm_layer=norm_layer,\n                               downsample=PatchMerging if (\n                                   i_layer < self.num_layers - 1) else None,\n                               use_checkpoint=use_checkpoint)\n            self.layers.append(layer)\n\n        # build decoder layers\n        self.layers_up = nn.ModuleList()\n        self.concat_back_dim = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)),\n                                      int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity()\n            if i_layer == 0:\n                layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),\n                                                         patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), 
dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer)\n            else:\n                layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)),\n                                         input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),\n                                                           patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))),\n                                         depth=depths[(\n                                             self.num_layers-1-i_layer)],\n                                         num_heads=num_heads[(\n                                             self.num_layers-1-i_layer)],\n                                         window_size=window_size,\n                                         mlp_ratio=self.mlp_ratio,\n                                         qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                         drop=drop_rate, attn_drop=attn_drop_rate,\n                                         drop_path=dpr[sum(depths[:(\n                                             self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])],\n                                         norm_layer=norm_layer,\n                                         upsample=PatchExpand if (\n                                             i_layer < self.num_layers - 1) else None,\n                                         use_checkpoint=use_checkpoint)\n            self.layers_up.append(layer_up)\n            self.concat_back_dim.append(concat_linear)\n\n        self.norm = norm_layer(self.num_features)\n        self.norm_up = norm_layer(self.embed_dim)\n\n        if self.final_upsample == \"expand_first\":\n            print(\"---final upsample expand_first---\")\n            self.up = FinalPatchExpand_X4(input_resolution=(\n                img_size//patch_size, img_size//patch_size), dim_scale=4, dim=embed_dim)\n            
self.output = nn.Conv2d(\n                in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {'relative_position_bias_table'}\n\n    #Encoder and Bottleneck\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n        x_downsample = []\n\n        for layer in self.layers:\n            x_downsample.append(x)\n            x = layer(x)\n\n        x = self.norm(x)  # B L C\n\n        return x, x_downsample\n\n    # Dencoder and Skip connection\n    def forward_up_features(self, x, x_downsample):\n        for inx, layer_up in enumerate(self.layers_up):\n            if inx == 0:\n                x = layer_up(x)\n            else:\n                x = torch.cat([x, x_downsample[3-inx]], -1)\n                x = self.concat_back_dim[inx](x)\n                x = layer_up(x)\n\n        x = self.norm_up(x)  # B L C\n\n        return x\n\n    def up_x4(self, x):\n        H, W = self.patches_resolution\n        B, L, C = x.shape\n        assert L == H*W, \"input features has wrong size\"\n\n        if self.final_upsample == \"expand_first\":\n            x = self.up(x)\n            x = x.view(B, 4*H, 4*W, -1)\n            x = x.permute(0, 3, 1, 2).contiguous()  # B,C,H,W\n            x = self.output(x)\n\n        return x\n\n    def forward(self, x):\n        x, 
x_downsample = self.forward_features(x)\n        x = self.forward_up_features(x, x_downsample)\n        x = self.up_x4(x)\n\n        return x\n\n    def flops(self):\n        flops = 0\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n        flops += self.num_features * \\\n            self.patches_resolution[0] * \\\n            self.patches_resolution[1] // (2 ** self.num_layers)\n        flops += self.num_features * self.num_classes\n        return flops\n\nclass SwinUnet(nn.Module):\n    def __init__(self, num_classes, img_size, zero_head=False, vis=False):\n        super(SwinUnet, self).__init__()\n        self.num_classes = num_classes\n        self.zero_head = zero_head\n\n        self.swin_unet = SwinTransformerSys(img_size=img_size,\n                                patch_size=4,\n                                num_classes=num_classes,\n                                embed_dim=96,\n                                depths=[2, 2, 6, 2],\n                                num_heads=[3, 6, 12, 24],\n                                window_size=7,\n                                mlp_ratio=4,\n                                qkv_bias=True,\n                                qk_scale=False,\n                                drop_rate=0.0,\n                                drop_path_rate=0.1,\n                                ape=False,\n                                patch_norm=True,\n                                use_checkpoint=False)\n\n    def forward(self, x):\n        if x.size()[1] == 1:\n            x = x.repeat(1,3,1,1)\n        logits = self.swin_unet(x)\n        return logits\n\n    def load_from(self, config):\n        pretrained_path = config.MODEL.PRETRAIN_CKPT\n        if pretrained_path is not None:\n            print(\"pretrained_path:{}\".format(pretrained_path))\n            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n            
pretrained_dict = torch.load(pretrained_path, map_location=device)\n            if \"model\"  not in pretrained_dict:\n                print(\"---start load pretrained modle by splitting---\")\n                pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()}\n                for k in list(pretrained_dict.keys()):\n                    if \"output\" in k:\n                        print(\"delete key:{}\".format(k))\n                        del pretrained_dict[k]\n                msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False)\n                # print(msg)\n                return\n            pretrained_dict = pretrained_dict['model']\n            print(\"---start load pretrained modle of swin encoder---\")\n\n            model_dict = self.swin_unet.state_dict()\n            full_dict = copy.deepcopy(pretrained_dict)\n            for k, v in pretrained_dict.items():\n                if \"layers.\" in k:\n                    current_layer_num = 3-int(k[7:8])\n                    current_k = \"layers_up.\" + str(current_layer_num) + k[8:]\n                    full_dict.update({current_k:v})\n            for k in list(full_dict.keys()):\n                if k in model_dict:\n                    if full_dict[k].shape != model_dict[k].shape:\n                        print(\"delete:{};shape pretrain:{};shape model:{}\".format(k,v.shape,model_dict[k].shape))\n                        del full_dict[k]\n\n            msg = self.swin_unet.load_state_dict(full_dict, strict=False)\n            # print(msg)\n        else:\n            print(\"none pretrain\")\n\n\ndef swinunet(num_classes, img_size):\n    model = SwinUnet(num_classes, img_size=img_size)\n    return model\n\n\nif __name__ == '__main__':\n    model = swinunet(10, 224)\n    model.eval()\n    input = torch.rand(2,1,224,224)\n    output = model(input)\n    output = output.data.cpu().numpy()\n    # print(output)\n    print(output.shape)"
  },
  {
    "path": "models/networks_2d/u2net.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\n\nclass REBNCONV(nn.Module):\n    def __init__(self,in_ch=3,out_ch=3,dirate=1):\n        super(REBNCONV,self).__init__()\n\n        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)\n        self.bn_s1 = nn.BatchNorm2d(out_ch)\n        self.relu_s1 = nn.ReLU(inplace=True)\n\n    def forward(self,x):\n\n        hx = x\n        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))\n\n        return xout\n\n## upsample tensor 'src' to have the same spatial size with tensor 'tar'\ndef _upsample_like(src,tar):\n\n    src = F.interpolate(src,size=tar.shape[2:],mode='bilinear', align_corners=True)\n\n    return src\n\n\n### RSU-7 ###\nclass RSU7(nn.Module):#UNet07DRES(nn.Module):\n\n    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):\n        
super(RSU7,self).__init__()\n\n        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)\n\n        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)\n        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)\n        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)\n        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)\n        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)\n        self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)\n\n        self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)\n\n        self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)\n        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)\n        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)\n        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)\n        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)\n        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)\n\n    def forward(self,x):\n\n        hx = x\n        hxin = self.rebnconvin(hx)\n\n        hx1 = self.rebnconv1(hxin)\n        hx = self.pool1(hx1)\n\n        hx2 = self.rebnconv2(hx)\n        hx = self.pool2(hx2)\n\n        hx3 = self.rebnconv3(hx)\n        hx = self.pool3(hx3)\n\n        hx4 = self.rebnconv4(hx)\n        hx = self.pool4(hx4)\n\n        hx5 = self.rebnconv5(hx)\n        hx = self.pool5(hx5)\n\n        hx6 = self.rebnconv6(hx)\n\n        hx7 = self.rebnconv7(hx6)\n\n        hx6d =  self.rebnconv6d(torch.cat((hx7,hx6),1))\n        hx6dup = _upsample_like(hx6d,hx5)\n\n        hx5d =  self.rebnconv5d(torch.cat((hx6dup,hx5),1))\n        hx5dup = _upsample_like(hx5d,hx4)\n\n        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))\n        hx4dup = 
_upsample_like(hx4d,hx3)\n\n        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))\n        hx3dup = _upsample_like(hx3d,hx2)\n\n        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))\n        hx2dup = _upsample_like(hx2d,hx1)\n\n        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))\n\n        return hx1d + hxin\n\n### RSU-6 ###\nclass RSU6(nn.Module):#UNet06DRES(nn.Module):\n\n    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):\n        super(RSU6,self).__init__()\n\n        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)\n\n        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)\n        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)\n        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)\n        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)\n        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)\n\n        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)\n\n        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)\n        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)\n        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)\n        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)\n        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)\n\n    def forward(self,x):\n\n        hx = x\n\n        hxin = self.rebnconvin(hx)\n\n        hx1 = self.rebnconv1(hxin)\n        hx = self.pool1(hx1)\n\n        hx2 = self.rebnconv2(hx)\n        hx = self.pool2(hx2)\n\n        hx3 = self.rebnconv3(hx)\n        hx = self.pool3(hx3)\n\n        hx4 = self.rebnconv4(hx)\n        hx = self.pool4(hx4)\n\n        hx5 = self.rebnconv5(hx)\n\n        hx6 = self.rebnconv6(hx5)\n\n\n        hx5d =  self.rebnconv5d(torch.cat((hx6,hx5),1))\n        hx5dup = _upsample_like(hx5d,hx4)\n\n 
       hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))\n        hx4dup = _upsample_like(hx4d,hx3)\n\n        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))\n        hx3dup = _upsample_like(hx3d,hx2)\n\n        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))\n        hx2dup = _upsample_like(hx2d,hx1)\n\n        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))\n\n        return hx1d + hxin\n\n### RSU-5 ###\nclass RSU5(nn.Module):#UNet05DRES(nn.Module):\n\n    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):\n        super(RSU5,self).__init__()\n\n        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)\n\n        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)\n        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)\n        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)\n        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)\n\n        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)\n\n        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)\n        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)\n        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)\n        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)\n\n    def forward(self,x):\n\n        hx = x\n\n        hxin = self.rebnconvin(hx)\n\n        hx1 = self.rebnconv1(hxin)\n        hx = self.pool1(hx1)\n\n        hx2 = self.rebnconv2(hx)\n        hx = self.pool2(hx2)\n\n        hx3 = self.rebnconv3(hx)\n        hx = self.pool3(hx3)\n\n        hx4 = self.rebnconv4(hx)\n\n        hx5 = self.rebnconv5(hx4)\n\n        hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))\n        hx4dup = _upsample_like(hx4d,hx3)\n\n        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))\n        hx3dup = _upsample_like(hx3d,hx2)\n\n        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))\n        hx2dup = 
_upsample_like(hx2d,hx1)\n\n        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))\n\n        return hx1d + hxin\n\n### RSU-4 ###\nclass RSU4(nn.Module):#UNet04DRES(nn.Module):\n\n    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):\n        super(RSU4,self).__init__()\n\n        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)\n\n        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)\n        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)\n        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)\n\n        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)\n\n        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)\n        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)\n        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)\n\n    def forward(self,x):\n\n        hx = x\n\n        hxin = self.rebnconvin(hx)\n\n        hx1 = self.rebnconv1(hxin)\n        hx = self.pool1(hx1)\n\n        hx2 = self.rebnconv2(hx)\n        hx = self.pool2(hx2)\n\n        hx3 = self.rebnconv3(hx)\n\n        hx4 = self.rebnconv4(hx3)\n\n        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))\n        hx3dup = _upsample_like(hx3d,hx2)\n\n        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))\n        hx2dup = _upsample_like(hx2d,hx1)\n\n        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))\n\n        return hx1d + hxin\n\n### RSU-4F ###\nclass RSU4F(nn.Module):#UNet04FRES(nn.Module):\n\n    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):\n        super(RSU4F,self).__init__()\n\n        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)\n\n        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)\n        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)\n        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)\n\n        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)\n\n        self.rebnconv3d = 
REBNCONV(mid_ch*2,mid_ch,dirate=4)\n        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)\n        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)\n\n    def forward(self,x):\n\n        hx = x\n\n        hxin = self.rebnconvin(hx)\n\n        hx1 = self.rebnconv1(hxin)\n        hx2 = self.rebnconv2(hx1)\n        hx3 = self.rebnconv3(hx2)\n\n        hx4 = self.rebnconv4(hx3)\n\n        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))\n        hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))\n        hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))\n\n        return hx1d + hxin\n\n\n##### U^2-Net ####\nclass U2NET(nn.Module):\n\n    def __init__(self,in_ch=3,out_ch=1):\n        super(U2NET,self).__init__()\n\n        self.stage1 = RSU7(in_ch,32,64)\n        self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.stage2 = RSU6(64,32,128)\n        self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.stage3 = RSU5(128,64,256)\n        self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.stage4 = RSU4(256,128,512)\n        self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.stage5 = RSU4F(512,256,512)\n        self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.stage6 = RSU4F(512,256,512)\n\n        # decoder\n        self.stage5d = RSU4F(1024,256,512)\n        self.stage4d = RSU4(1024,128,256)\n        self.stage3d = RSU5(512,64,128)\n        self.stage2d = RSU6(256,32,64)\n        self.stage1d = RSU7(128,16,64)\n\n        self.side1 = nn.Conv2d(64,out_ch,3,padding=1)\n        self.side2 = nn.Conv2d(64,out_ch,3,padding=1)\n        self.side3 = nn.Conv2d(128,out_ch,3,padding=1)\n        self.side4 = nn.Conv2d(256,out_ch,3,padding=1)\n        self.side5 = nn.Conv2d(512,out_ch,3,padding=1)\n        self.side6 = nn.Conv2d(512,out_ch,3,padding=1)\n\n        self.outconv = nn.Conv2d(6*out_ch,out_ch,1)\n\n    def forward(self,x):\n\n        hx = x\n\n        #stage 1\n        hx1 = 
self.stage1(hx)\n        hx = self.pool12(hx1)\n\n        #stage 2\n        hx2 = self.stage2(hx)\n        hx = self.pool23(hx2)\n\n        #stage 3\n        hx3 = self.stage3(hx)\n        hx = self.pool34(hx3)\n\n        #stage 4\n        hx4 = self.stage4(hx)\n        hx = self.pool45(hx4)\n\n        #stage 5\n        hx5 = self.stage5(hx)\n        hx = self.pool56(hx5)\n\n        #stage 6\n        hx6 = self.stage6(hx)\n        hx6up = _upsample_like(hx6,hx5)\n\n        #-------------------- decoder --------------------\n        hx5d = self.stage5d(torch.cat((hx6up,hx5),1))\n        hx5dup = _upsample_like(hx5d,hx4)\n\n        hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))\n        hx4dup = _upsample_like(hx4d,hx3)\n\n        hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))\n        hx3dup = _upsample_like(hx3d,hx2)\n\n        hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))\n        hx2dup = _upsample_like(hx2d,hx1)\n\n        hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))\n\n\n        #side output\n        d1 = self.side1(hx1d)\n\n        d2 = self.side2(hx2d)\n        d2 = _upsample_like(d2,d1)\n\n        d3 = self.side3(hx3d)\n        d3 = _upsample_like(d3,d1)\n\n        d4 = self.side4(hx4d)\n        d4 = _upsample_like(d4,d1)\n\n        d5 = self.side5(hx5d)\n        d5 = _upsample_like(d5,d1)\n\n        d6 = self.side6(hx6)\n        d6 = _upsample_like(d6,d1)\n\n        d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))\n\n        return d0, d1, d2, d3, d4, d5, d6\n\n### U^2-Net small ###\nclass U2NETP(nn.Module):\n\n    def __init__(self,in_ch=3,out_ch=1):\n        super(U2NETP,self).__init__()\n\n        self.stage1 = RSU7(in_ch,16,64)\n        self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.stage2 = RSU6(64,16,64)\n        self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.stage3 = RSU5(64,16,64)\n        self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.stage4 = RSU4(64,16,64)\n        
self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.stage5 = RSU4F(64,16,64)\n        self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)\n\n        self.stage6 = RSU4F(64,16,64)\n\n        # decoder\n        self.stage5d = RSU4F(128,16,64)\n        self.stage4d = RSU4(128,16,64)\n        self.stage3d = RSU5(128,16,64)\n        self.stage2d = RSU6(128,16,64)\n        self.stage1d = RSU7(128,16,64)\n\n        self.side1 = nn.Conv2d(64,out_ch,3,padding=1)\n        self.side2 = nn.Conv2d(64,out_ch,3,padding=1)\n        self.side3 = nn.Conv2d(64,out_ch,3,padding=1)\n        self.side4 = nn.Conv2d(64,out_ch,3,padding=1)\n        self.side5 = nn.Conv2d(64,out_ch,3,padding=1)\n        self.side6 = nn.Conv2d(64,out_ch,3,padding=1)\n\n        self.outconv = nn.Conv2d(6*out_ch,out_ch,1)\n\n    def forward(self,x):\n\n        hx = x\n\n        #stage 1\n        hx1 = self.stage1(hx)\n        hx = self.pool12(hx1)\n\n        #stage 2\n        hx2 = self.stage2(hx)\n        hx = self.pool23(hx2)\n\n        #stage 3\n        hx3 = self.stage3(hx)\n        hx = self.pool34(hx3)\n\n        #stage 4\n        hx4 = self.stage4(hx)\n        hx = self.pool45(hx4)\n\n        #stage 5\n        hx5 = self.stage5(hx)\n        hx = self.pool56(hx5)\n\n        #stage 6\n        hx6 = self.stage6(hx)\n        hx6up = _upsample_like(hx6,hx5)\n\n        #decoder\n        hx5d = self.stage5d(torch.cat((hx6up,hx5),1))\n        hx5dup = _upsample_like(hx5d,hx4)\n\n        hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))\n        hx4dup = _upsample_like(hx4d,hx3)\n\n        hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))\n        hx3dup = _upsample_like(hx3d,hx2)\n\n        hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))\n        hx2dup = _upsample_like(hx2d,hx1)\n\n        hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))\n\n\n        #side output\n        d1 = self.side1(hx1d)\n\n        d2 = self.side2(hx2d)\n        d2 = _upsample_like(d2,d1)\n\n        d3 = self.side3(hx3d)\n     
   d3 = _upsample_like(d3,d1)\n\n        d4 = self.side4(hx4d)\n        d4 = _upsample_like(d4,d1)\n\n        d5 = self.side5(hx5d)\n        d5 = _upsample_like(d5,d1)\n\n        d6 = self.side6(hx6)\n        d6 = _upsample_like(d6,d1)\n\n        d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))\n\n        return d0, d1, d2, d3, d4, d5, d6\n\ndef u2net(in_channels, num_classes):\n    model = U2NET(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\ndef u2net_small(in_channels, num_classes):\n    model = U2NETP(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\n\n# if __name__ == '__main__':\n#     model = u2net(1,10)\n#     model.eval()\n#     input = torch.rand(2,1,128,128)\n#     output = model(input)\n#     output = output[1].data.cpu().numpy()\n#     # print(output)\n#     print(output.shape)\n\n"
  },
  {
    "path": "models/networks_2d/unet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\n\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\n\nclass conv_block(nn.Module):\n    def __init__(self, ch_in, ch_out):\n        super(conv_block, self).__init__()\n        self.conv = nn.Sequential(\n            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),\n            nn.BatchNorm2d(ch_out),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),\n            nn.BatchNorm2d(ch_out),\n            nn.ReLU(inplace=True)\n        )\n\n    def forward(self, x):\n        x = self.conv(x)\n        return x\n\n\nclass up_conv(nn.Module):\n    def __init__(self, ch_in, ch_out):\n        super(up_conv, self).__init__()\n        self.up = nn.Sequential(\n            nn.Upsample(scale_factor=2),\n            nn.Conv2d(ch_in, 
ch_out, kernel_size=3, stride=1, padding=1, bias=True),\n            nn.BatchNorm2d(ch_out),\n            nn.ReLU(inplace=True)\n        )\n\n    def forward(self, x):\n        x = self.up(x)\n        return x\n\n\nclass Recurrent_block(nn.Module):\n    def __init__(self, ch_out, t=2):\n        super(Recurrent_block, self).__init__()\n        self.t = t\n        self.ch_out = ch_out\n        self.conv = nn.Sequential(\n            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),\n            nn.BatchNorm2d(ch_out),\n            nn.ReLU(inplace=True)\n        )\n\n    def forward(self, x):\n        for i in range(self.t):\n\n            if i == 0:\n                x1 = self.conv(x)\n\n            x1 = self.conv(x + x1)\n        return x1\n\n\nclass RRCNN_block(nn.Module):\n    def __init__(self, ch_in, ch_out, t=2):\n        super(RRCNN_block, self).__init__()\n        self.RCNN = nn.Sequential(\n            Recurrent_block(ch_out, t=t),\n            Recurrent_block(ch_out, t=t)\n        )\n        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)\n\n    def forward(self, x):\n        x = self.Conv_1x1(x)\n        x1 = self.RCNN(x)\n        return x + x1\n\n\nclass single_conv(nn.Module):\n    def __init__(self, ch_in, ch_out):\n        super(single_conv, self).__init__()\n        self.conv = nn.Sequential(\n            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),\n            nn.BatchNorm2d(ch_out),\n            nn.ReLU(inplace=True)\n        )\n\n    def forward(self, x):\n        x = self.conv(x)\n        return x\n\n\nclass Attention_block(nn.Module):\n    def __init__(self, F_g, F_l, F_int):\n        super(Attention_block, self).__init__()\n        self.W_g = nn.Sequential(\n            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),\n            nn.BatchNorm2d(F_int)\n        )\n\n        self.W_x = nn.Sequential(\n            nn.Conv2d(F_l, F_int, 
kernel_size=1, stride=1, padding=0, bias=True),\n            nn.BatchNorm2d(F_int)\n        )\n\n        self.psi = nn.Sequential(\n            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),\n            nn.BatchNorm2d(1),\n            nn.Sigmoid()\n        )\n\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, g, x):\n        g1 = self.W_g(g)\n        x1 = self.W_x(x)\n        psi = self.relu(g1 + x1)\n        psi = self.psi(psi)\n\n        return x * psi\n\n\nclass U_Net(nn.Module):\n    def __init__(self, in_channels=3, num_classes=1):\n        super(U_Net, self).__init__()\n\n        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)\n\n        self.Conv1 = conv_block(ch_in=in_channels, ch_out=64)\n        self.Conv2 = conv_block(ch_in=64, ch_out=128)\n        self.Conv3 = conv_block(ch_in=128, ch_out=256)\n        self.Conv4 = conv_block(ch_in=256, ch_out=512)\n        self.Conv5 = conv_block(ch_in=512, ch_out=1024)\n\n        self.Up5 = up_conv(ch_in=1024, ch_out=512)\n        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)\n\n        self.Up4 = up_conv(ch_in=512, ch_out=256)\n        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)\n\n        self.Up3 = up_conv(ch_in=256, ch_out=128)\n        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)\n\n        self.Up2 = up_conv(ch_in=128, ch_out=64)\n        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)\n\n        self.Conv_1x1 = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding=0)\n\n    def forward(self, x):\n        # encoding path\n        x1 = self.Conv1(x)\n\n        x2 = self.Maxpool(x1)\n        x2 = self.Conv2(x2)\n\n        x3 = self.Maxpool(x2)\n        x3 = self.Conv3(x3)\n\n        x4 = self.Maxpool(x3)\n        x4 = self.Conv4(x4)\n\n        x5 = self.Maxpool(x4)\n        x5 = self.Conv5(x5)\n\n        # decoding + concat path\n        d5 = self.Up5(x5)\n        d5 = torch.cat((x4, d5), dim=1)\n\n        d5 = self.Up_conv5(d5)\n\n        d4 = 
self.Up4(d5)\n        d4 = torch.cat((x3, d4), dim=1)\n        d4 = self.Up_conv4(d4)\n\n        d3 = self.Up3(d4)\n        d3 = torch.cat((x2, d3), dim=1)\n        d3 = self.Up_conv3(d3)\n\n        d2 = self.Up2(d3)\n        d2 = torch.cat((x1, d2), dim=1)\n        d2 = self.Up_conv2(d2)\n\n        d1 = self.Conv_1x1(d2)\n\n        # outputs = []\n        # outputs.append(d1)\n        # return outputs\n        return d1\n\nclass R2U_Net(nn.Module):\n    def __init__(self, in_channels=3, num_classes=1, t=2):\n        super(R2U_Net, self).__init__()\n\n        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)\n        self.Upsample = nn.Upsample(scale_factor=2)\n\n        self.RRCNN1 = RRCNN_block(ch_in=in_channels, ch_out=64, t=t)\n\n        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)\n\n        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)\n\n        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)\n\n        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)\n\n        self.Up5 = up_conv(ch_in=1024, ch_out=512)\n        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)\n\n        self.Up4 = up_conv(ch_in=512, ch_out=256)\n        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)\n\n        self.Up3 = up_conv(ch_in=256, ch_out=128)\n        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)\n\n        self.Up2 = up_conv(ch_in=128, ch_out=64)\n        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)\n\n        self.Conv_1x1 = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding=0)\n\n    def forward(self, x):\n        # encoding path\n        x1 = self.RRCNN1(x)\n\n        x2 = self.Maxpool(x1)\n        x2 = self.RRCNN2(x2)\n\n        x3 = self.Maxpool(x2)\n        x3 = self.RRCNN3(x3)\n\n        x4 = self.Maxpool(x3)\n        x4 = self.RRCNN4(x4)\n\n        x5 = self.Maxpool(x4)\n        x5 = self.RRCNN5(x5)\n\n        # decoding + concat path\n        d5 = self.Up5(x5)\n        d5 = torch.cat((x4, 
d5), dim=1)\n        d5 = self.Up_RRCNN5(d5)\n\n        d4 = self.Up4(d5)\n        d4 = torch.cat((x3, d4), dim=1)\n        d4 = self.Up_RRCNN4(d4)\n\n        d3 = self.Up3(d4)\n        d3 = torch.cat((x2, d3), dim=1)\n        d3 = self.Up_RRCNN3(d3)\n\n        d2 = self.Up2(d3)\n        d2 = torch.cat((x1, d2), dim=1)\n        d2 = self.Up_RRCNN2(d2)\n\n        d1 = self.Conv_1x1(d2)\n\n        # outputs = []\n        # outputs.append(d1)\n        # return outputs\n        return d1\n\nclass AttU_Net(nn.Module):\n    def __init__(self, in_channels=3, num_classes=1):\n        super(AttU_Net, self).__init__()\n\n        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)\n\n        self.Conv1 = conv_block(ch_in=in_channels, ch_out=64)\n        self.Conv2 = conv_block(ch_in=64, ch_out=128)\n        self.Conv3 = conv_block(ch_in=128, ch_out=256)\n        self.Conv4 = conv_block(ch_in=256, ch_out=512)\n        self.Conv5 = conv_block(ch_in=512, ch_out=1024)\n\n        self.Up5 = up_conv(ch_in=1024, ch_out=512)\n        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)\n        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)\n\n        self.Up4 = up_conv(ch_in=512, ch_out=256)\n        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)\n        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)\n\n        self.Up3 = up_conv(ch_in=256, ch_out=128)\n        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)\n        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)\n\n        self.Up2 = up_conv(ch_in=128, ch_out=64)\n        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)\n        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)\n\n        self.Conv_1x1 = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding=0)\n\n    def forward(self, x):\n        # encoding path\n        x1 = self.Conv1(x)\n\n        x2 = self.Maxpool(x1)\n        x2 = self.Conv2(x2)\n\n        x3 = self.Maxpool(x2)\n        x3 = self.Conv3(x3)\n\n        x4 = 
self.Maxpool(x3)\n        x4 = self.Conv4(x4)\n\n        x5 = self.Maxpool(x4)\n        x5 = self.Conv5(x5)\n\n        # decoding + concat path\n        d5 = self.Up5(x5)\n        x4 = self.Att5(g=d5, x=x4)\n        d5 = torch.cat((x4, d5), dim=1)\n        d5 = self.Up_conv5(d5)\n\n        d4 = self.Up4(d5)\n        x3 = self.Att4(g=d4, x=x3)\n        d4 = torch.cat((x3, d4), dim=1)\n        d4 = self.Up_conv4(d4)\n\n        d3 = self.Up3(d4)\n        x2 = self.Att3(g=d3, x=x2)\n        d3 = torch.cat((x2, d3), dim=1)\n        d3 = self.Up_conv3(d3)\n\n        d2 = self.Up2(d3)\n        x1 = self.Att2(g=d2, x=x1)\n        d2 = torch.cat((x1, d2), dim=1)\n        d2 = self.Up_conv2(d2)\n\n        d1 = self.Conv_1x1(d2)\n        # outputs = []\n        # outputs.append(d1)\n        # return outputs\n        return d1\n\n\nclass R2AttU_Net(nn.Module):\n    def __init__(self, in_channels=3, num_classes=1, t=2):\n        super(R2AttU_Net, self).__init__()\n\n        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)\n        self.Upsample = nn.Upsample(scale_factor=2)\n\n        self.RRCNN1 = RRCNN_block(ch_in=in_channels, ch_out=64, t=t)\n\n        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)\n\n        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)\n\n        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)\n\n        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)\n\n        self.Up5 = up_conv(ch_in=1024, ch_out=512)\n        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)\n        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)\n\n        self.Up4 = up_conv(ch_in=512, ch_out=256)\n        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)\n        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)\n\n        self.Up3 = up_conv(ch_in=256, ch_out=128)\n        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)\n        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)\n\n        self.Up2 = 
up_conv(ch_in=128, ch_out=64)\n        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)\n        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)\n\n        self.Conv_1x1 = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding=0)\n\n    def forward(self, x):\n        # encoding path\n        x1 = self.RRCNN1(x)\n\n        x2 = self.Maxpool(x1)\n        x2 = self.RRCNN2(x2)\n\n        x3 = self.Maxpool(x2)\n        x3 = self.RRCNN3(x3)\n\n        x4 = self.Maxpool(x3)\n        x4 = self.RRCNN4(x4)\n\n        x5 = self.Maxpool(x4)\n        x5 = self.RRCNN5(x5)\n\n        # decoding + concat path\n        d5 = self.Up5(x5)\n        x4 = self.Att5(g=d5, x=x4)\n        d5 = torch.cat((x4, d5), dim=1)\n        d5 = self.Up_RRCNN5(d5)\n\n        d4 = self.Up4(d5)\n        x3 = self.Att4(g=d4, x=x3)\n        d4 = torch.cat((x3, d4), dim=1)\n        d4 = self.Up_RRCNN4(d4)\n\n        d3 = self.Up3(d4)\n        x2 = self.Att3(g=d3, x=x2)\n        d3 = torch.cat((x2, d3), dim=1)\n        d3 = self.Up_RRCNN3(d3)\n\n        d2 = self.Up2(d3)\n        x1 = self.Att2(g=d2, x=x1)\n        d2 = torch.cat((x1, d2), dim=1)\n        d2 = self.Up_RRCNN2(d2)\n\n        d1 = self.Conv_1x1(d2)\n        # outputs = []\n        # outputs.append(d1)\n        # return outputs\n        return d1\n\n\ndef unet(in_channels, num_classes):\n    model = U_Net(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\ndef r2_unet(in_channels, num_classes):\n    model = R2U_Net(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\ndef attention_unet(in_channels, num_classes):\n    model = AttU_Net(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\ndef r2_attention_unet(in_channels, num_classes):\n    model = R2AttU_Net(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\n\n# if __name__ == '__main__':\n#     model = U_Net(1,10)\n#     model.eval()\n#     input = 
torch.rand(2,1,128,128)\n#     output = model(input)\n#     output = output[0].data.cpu().numpy()\n#     # print(output)\n#     print(output.shape)"
  },
  {
    "path": "models/networks_2d/unet_3plus.py",
    "content": "# -*- coding: utf-8 -*-\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\nimport numpy as np\n\ndef weights_init_normal(m):\n    classname = m.__class__.__name__\n    #print(classname)\n    if classname.find('Conv') != -1:\n        init.normal_(m.weight.data, 0.0, 0.02)\n    elif classname.find('Linear') != -1:\n        init.normal_(m.weight.data, 0.0, 0.02)\n    elif classname.find('BatchNorm') != -1:\n        init.normal_(m.weight.data, 1.0, 0.02)\n        init.constant_(m.bias.data, 0.0)\n\n\ndef weights_init_xavier(m):\n    classname = m.__class__.__name__\n    #print(classname)\n    if classname.find('Conv') != -1:\n        init.xavier_normal_(m.weight.data, gain=1)\n    elif classname.find('Linear') != -1:\n        init.xavier_normal_(m.weight.data, gain=1)\n    elif classname.find('BatchNorm') != -1:\n        init.normal_(m.weight.data, 1.0, 0.02)\n        init.constant_(m.bias.data, 0.0)\n\n\ndef weights_init_kaiming(m):\n    classname = m.__class__.__name__\n    #print(classname)\n    if classname.find('Conv') != -1:\n        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n    elif classname.find('Linear') != -1:\n        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n    elif classname.find('BatchNorm') != -1:\n        init.normal_(m.weight.data, 1.0, 0.02)\n        init.constant_(m.bias.data, 0.0)\n\n\ndef weights_init_orthogonal(m):\n    classname = m.__class__.__name__\n    #print(classname)\n    if classname.find('Conv') != -1:\n        init.orthogonal_(m.weight.data, gain=1)\n    elif classname.find('Linear') != -1:\n        init.orthogonal_(m.weight.data, gain=1)\n    elif classname.find('BatchNorm') != -1:\n        init.normal_(m.weight.data, 1.0, 0.02)\n        init.constant_(m.bias.data, 0.0)\n\n\ndef init_weights(net, init_type='normal'):\n    #print('initialization method [%s]' % init_type)\n    if init_type == 'normal':\n        net.apply(weights_init_normal)\n  
  elif init_type == 'xavier':\n        net.apply(weights_init_xavier)\n    elif init_type == 'kaiming':\n        net.apply(weights_init_kaiming)\n    elif init_type == 'orthogonal':\n        net.apply(weights_init_orthogonal)\n    else:\n        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n\n\nclass unetConv2(nn.Module):\n    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):\n        super(unetConv2, self).__init__()\n        self.n = n\n        self.ks = ks\n        self.stride = stride\n        self.padding = padding\n        s = stride\n        p = padding\n        if is_batchnorm:\n            for i in range(1, n + 1):\n                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),\n                                     nn.BatchNorm2d(out_size),\n                                     nn.ReLU(inplace=True), )\n                setattr(self, 'conv%d' % i, conv)\n                in_size = out_size\n\n        else:\n            for i in range(1, n + 1):\n                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),\n                                     nn.ReLU(inplace=True), )\n                setattr(self, 'conv%d' % i, conv)\n                in_size = out_size\n\n        # initialise the blocks\n        for m in self.children():\n            init_weights(m, init_type='kaiming')\n\n    def forward(self, inputs):\n        x = inputs\n        for i in range(1, self.n + 1):\n            conv = getattr(self, 'conv%d' % i)\n            x = conv(x)\n\n        return x\n\n\n'''\n    UNet 3+\n'''\nclass UNet_3Plus(nn.Module):\n\n    def __init__(self, in_channels, num_classes):\n        super(UNet_3Plus, self).__init__()\n        feature_scale = 4\n        is_deconv = True\n        is_batchnorm = True\n        self.is_deconv = is_deconv\n        self.in_channels = in_channels\n        self.is_batchnorm = is_batchnorm\n        self.feature_scale = feature_scale\n\n       
 filters = [16, 32, 64, 128, 256]\n\n        ## -------------Encoder--------------\n        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)\n        self.maxpool1 = nn.MaxPool2d(kernel_size=2)\n\n        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)\n        self.maxpool2 = nn.MaxPool2d(kernel_size=2)\n\n        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)\n        self.maxpool3 = nn.MaxPool2d(kernel_size=2)\n\n        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)\n        self.maxpool4 = nn.MaxPool2d(kernel_size=2)\n\n        self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)\n\n        ## -------------Decoder--------------\n        self.CatChannels = filters[0]\n        self.CatBlocks = 5\n        self.UpChannels = self.CatChannels * self.CatBlocks\n\n        '''stage 4d'''\n        # h1->320*320, hd4->40*40, Pooling 8 times\n        self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)\n        self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)\n        self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h1_PT_hd4_relu = nn.ReLU(inplace=True)\n\n        # h2->160*160, hd4->40*40, Pooling 4 times\n        self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True)\n        self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)\n        self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h2_PT_hd4_relu = nn.ReLU(inplace=True)\n\n        # h3->80*80, hd4->40*40, Pooling 2 times\n        self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True)\n        self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)\n        self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h3_PT_hd4_relu = nn.ReLU(inplace=True)\n\n        # h4->40*40, hd4->40*40, Concatenation\n        self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1)\n        self.h4_Cat_hd4_bn = 
nn.BatchNorm2d(self.CatChannels)\n        self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)\n\n        # hd5->20*20, hd4->40*40, Upsample 2 times\n        self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # 14*14\n        self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)\n        self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)\n\n        # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)\n        self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16\n        self.bn4d_1 = nn.BatchNorm2d(self.UpChannels)\n        self.relu4d_1 = nn.ReLU(inplace=True)\n\n        '''stage 3d'''\n        # h1->320*320, hd3->80*80, Pooling 4 times\n        self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)\n        self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)\n        self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h1_PT_hd3_relu = nn.ReLU(inplace=True)\n\n        # h2->160*160, hd3->80*80, Pooling 2 times\n        self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True)\n        self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)\n        self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h2_PT_hd3_relu = nn.ReLU(inplace=True)\n\n        # h3->80*80, hd3->80*80, Concatenation\n        self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)\n        self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)\n\n        # hd4->40*40, hd4->80*80, Upsample 2 times\n        self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # 14*14\n        self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)\n\n        # 
hd5->20*20, hd4->80*80, Upsample 4 times\n        self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)  # 14*14\n        self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)\n        self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)\n\n        # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)\n        self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16\n        self.bn3d_1 = nn.BatchNorm2d(self.UpChannels)\n        self.relu3d_1 = nn.ReLU(inplace=True)\n\n        '''stage 2d '''\n        # h1->320*320, hd2->160*160, Pooling 2 times\n        self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)\n        self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)\n        self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h1_PT_hd2_relu = nn.ReLU(inplace=True)\n\n        # h2->160*160, hd2->160*160, Concatenation\n        self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)\n        self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)\n\n        # hd3->80*80, hd2->160*160, Upsample 2 times\n        self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # 14*14\n        self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)\n\n        # hd4->40*40, hd2->160*160, Upsample 4 times\n        self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)  # 14*14\n        self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)\n\n        # hd5->20*20, hd2->160*160, Upsample 8 
times\n        self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)  # 14*14\n        self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)\n        self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)\n\n        # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)\n        self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16\n        self.bn2d_1 = nn.BatchNorm2d(self.UpChannels)\n        self.relu2d_1 = nn.ReLU(inplace=True)\n\n        '''stage 1d'''\n        # h1->320*320, hd1->320*320, Concatenation\n        self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)\n        self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)\n\n        # hd2->160*160, hd1->320*320, Upsample 2 times\n        self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # 14*14\n        self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)\n\n        # hd3->80*80, hd1->320*320, Upsample 4 times\n        self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)  # 14*14\n        self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)\n\n        # hd4->40*40, hd1->320*320, Upsample 8 times\n        self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)  # 14*14\n        self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)\n\n        # hd5->20*20, hd1->320*320, 
Upsample 16 times\n        self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)  # 14*14\n        self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)\n        self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)\n\n        # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)\n        self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16\n        self.bn1d_1 = nn.BatchNorm2d(self.UpChannels)\n        self.relu1d_1 = nn.ReLU(inplace=True)\n\n        # output\n        self.outconv1 = nn.Conv2d(self.UpChannels, num_classes, 3, padding=1)\n\n        # initialise weights\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init_weights(m, init_type='kaiming')\n            elif isinstance(m, nn.BatchNorm2d):\n                init_weights(m, init_type='kaiming')\n\n    def forward(self, inputs):\n        ## -------------Encoder-------------\n        h1 = self.conv1(inputs)  # h1->320*320*64\n\n        h2 = self.maxpool1(h1)\n        h2 = self.conv2(h2)  # h2->160*160*128\n\n        h3 = self.maxpool2(h2)\n        h3 = self.conv3(h3)  # h3->80*80*256\n\n        h4 = self.maxpool3(h3)\n        h4 = self.conv4(h4)  # h4->40*40*512\n\n        h5 = self.maxpool4(h4)\n        hd5 = self.conv5(h5)  # h5->20*20*1024\n\n        ## -------------Decoder-------------\n        h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))\n        h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))\n        h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))\n        h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))\n        hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))\n        hd4 = 
self.relu4d_1(self.bn4d_1(self.conv4d_1(\n            torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1))))  # hd4->40*40*UpChannels\n\n        h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))\n        h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))\n        h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))\n        hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))\n        hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))\n        hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(\n            torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1))))  # hd3->80*80*UpChannels\n\n        h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))\n        h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))\n        hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))\n        hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))\n        hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))\n        hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(\n            torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1))))  # hd2->160*160*UpChannels\n\n        h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))\n        hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))\n        hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))\n        hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))\n        hd5_UT_hd1 = 
self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))\n        hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(\n            torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1))))  # hd1->320*320*UpChannels\n\n        d1 = self.outconv1(hd1)  # d1->320*320*n_classes\n        return d1\n\n\n'''\n    UNet 3+ with deep supervision\n'''\n\n\nclass UNet_3Plus_DeepSup(nn.Module):\n    def __init__(self, in_channels=3, num_classes=1, feature_scale=4, is_deconv=True, is_batchnorm=True):\n        super(UNet_3Plus_DeepSup, self).__init__()\n        self.is_deconv = is_deconv\n        self.in_channels = in_channels\n        self.is_batchnorm = is_batchnorm\n        self.feature_scale = feature_scale\n\n        filters = [32, 64, 128, 256, 512]\n\n        ## -------------Encoder--------------\n        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)\n        self.maxpool1 = nn.MaxPool2d(kernel_size=2)\n\n        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)\n        self.maxpool2 = nn.MaxPool2d(kernel_size=2)\n\n        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)\n        self.maxpool3 = nn.MaxPool2d(kernel_size=2)\n\n        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)\n        self.maxpool4 = nn.MaxPool2d(kernel_size=2)\n\n        self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)\n\n        ## -------------Decoder--------------\n        self.CatChannels = filters[0]\n        self.CatBlocks = 5\n        self.UpChannels = self.CatChannels * self.CatBlocks\n\n        '''stage 4d'''\n        # h1->320*320, hd4->40*40, Pooling 8 times\n        self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)\n        self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)\n        self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h1_PT_hd4_relu = nn.ReLU(inplace=True)\n\n        # h2->160*160, hd4->40*40, 
Pooling 4 times\n        self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True)\n        self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)\n        self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h2_PT_hd4_relu = nn.ReLU(inplace=True)\n\n        # h3->80*80, hd4->40*40, Pooling 2 times\n        self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True)\n        self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)\n        self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h3_PT_hd4_relu = nn.ReLU(inplace=True)\n\n        # h4->40*40, hd4->40*40, Concatenation\n        self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1)\n        self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)\n\n        # hd5->20*20, hd4->40*40, Upsample 2 times\n        self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # 14*14\n        self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)\n        self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)\n\n        # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)\n        self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16\n        self.bn4d_1 = nn.BatchNorm2d(self.UpChannels)\n        self.relu4d_1 = nn.ReLU(inplace=True)\n\n        '''stage 3d'''\n        # h1->320*320, hd3->80*80, Pooling 4 times\n        self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)\n        self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)\n        self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h1_PT_hd3_relu = nn.ReLU(inplace=True)\n\n        # h2->160*160, hd3->80*80, Pooling 2 times\n        self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True)\n        self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, 
padding=1)\n        self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h2_PT_hd3_relu = nn.ReLU(inplace=True)\n\n        # h3->80*80, hd3->80*80, Concatenation\n        self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)\n        self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)\n\n        # hd4->40*40, hd4->80*80, Upsample 2 times\n        self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # 14*14\n        self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)\n\n        # hd5->20*20, hd4->80*80, Upsample 4 times\n        self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)  # 14*14\n        self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)\n        self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)\n\n        # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)\n        self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16\n        self.bn3d_1 = nn.BatchNorm2d(self.UpChannels)\n        self.relu3d_1 = nn.ReLU(inplace=True)\n\n        '''stage 2d '''\n        # h1->320*320, hd2->160*160, Pooling 2 times\n        self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)\n        self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)\n        self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h1_PT_hd2_relu = nn.ReLU(inplace=True)\n\n        # h2->160*160, hd2->160*160, Concatenation\n        self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)\n        self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)\n\n        # hd3->80*80, 
hd2->160*160, Upsample 2 times\n        self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # 14*14\n        self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)\n\n        # hd4->40*40, hd2->160*160, Upsample 4 times\n        self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)  # 14*14\n        self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)\n\n        # hd5->20*20, hd2->160*160, Upsample 8 times\n        self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)  # 14*14\n        self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)\n        self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)\n\n        # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)\n        self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16\n        self.bn2d_1 = nn.BatchNorm2d(self.UpChannels)\n        self.relu2d_1 = nn.ReLU(inplace=True)\n\n        '''stage 1d'''\n        # h1->320*320, hd1->320*320, Concatenation\n        self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)\n        self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)\n\n        # hd2->160*160, hd1->320*320, Upsample 2 times\n        self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # 14*14\n        self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)\n\n        # 
hd3->80*80, hd1->320*320, Upsample 4 times\n        self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)  # 14*14\n        self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)\n\n        # hd4->40*40, hd1->320*320, Upsample 8 times\n        self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)  # 14*14\n        self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)\n\n        # hd5->20*20, hd1->320*320, Upsample 16 times\n        self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)  # 14*14\n        self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)\n        self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)\n\n        # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)\n        self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16\n        self.bn1d_1 = nn.BatchNorm2d(self.UpChannels)\n        self.relu1d_1 = nn.ReLU(inplace=True)\n\n        # -------------Bilinear Upsampling--------------\n        self.upscore6 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True)  ###\n        self.upscore5 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)\n        self.upscore4 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)\n        self.upscore3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)\n        self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)\n\n        # DeepSup\n        self.outconv1 = nn.Conv2d(self.UpChannels, num_classes, 3, padding=1)\n        self.outconv2 = 
nn.Conv2d(self.UpChannels, num_classes, 3, padding=1)\n        self.outconv3 = nn.Conv2d(self.UpChannels, num_classes, 3, padding=1)\n        self.outconv4 = nn.Conv2d(self.UpChannels, num_classes, 3, padding=1)\n        self.outconv5 = nn.Conv2d(filters[4], num_classes, 3, padding=1)\n\n        # initialise weights\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init_weights(m, init_type='kaiming')\n            elif isinstance(m, nn.BatchNorm2d):\n                init_weights(m, init_type='kaiming')\n\n    def forward(self, inputs):\n        ## -------------Encoder-------------\n        h1 = self.conv1(inputs)  # h1->320*320*64\n\n        h2 = self.maxpool1(h1)\n        h2 = self.conv2(h2)  # h2->160*160*128\n\n        h3 = self.maxpool2(h2)\n        h3 = self.conv3(h3)  # h3->80*80*256\n\n        h4 = self.maxpool3(h3)\n        h4 = self.conv4(h4)  # h4->40*40*512\n\n        h5 = self.maxpool4(h4)\n        hd5 = self.conv5(h5)  # h5->20*20*1024\n\n        ## -------------Decoder-------------\n        h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))\n        h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))\n        h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))\n        h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))\n        hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))\n        hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(\n            torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1))))  # hd4->40*40*UpChannels\n\n        h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))\n        h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))\n        h3_Cat_hd3 = 
self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))\n        hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))\n        hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))\n        hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(\n            torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1))))  # hd3->80*80*UpChannels\n\n        h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))\n        h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))\n        hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))\n        hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))\n        hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))\n        hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(\n            torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1))))  # hd2->160*160*UpChannels\n\n        h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))\n        hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))\n        hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))\n        hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))\n        hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))\n        hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(\n            torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1))))  # hd1->320*320*UpChannels\n\n        d5 = self.outconv5(hd5)\n        d5 = self.upscore5(d5)  # 16->256\n\n        d4 = self.outconv4(hd4)\n        d4 = self.upscore4(d4)  # 32->256\n\n        d3 
= self.outconv3(hd3)\n        d3 = self.upscore3(d3)  # 64->256\n\n        d2 = self.outconv2(hd2)\n        d2 = self.upscore2(d2)  # 128->256\n\n        d1 = self.outconv1(hd1)  # 256\n        return d1, d2, d3, d4, d5\n\n\n'''\n    UNet 3+ with deep supervision and class-guided module\n'''\n\n\nclass UNet_3Plus_DeepSup_CGM(nn.Module):\n\n    def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True, is_batchnorm=True):\n        super(UNet_3Plus_DeepSup_CGM, self).__init__()\n        self.is_deconv = is_deconv\n        self.in_channels = in_channels\n        self.is_batchnorm = is_batchnorm\n        self.feature_scale = feature_scale\n\n        filters = [64, 128, 256, 512, 1024]\n\n        ## -------------Encoder--------------\n        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)\n        self.maxpool1 = nn.MaxPool2d(kernel_size=2)\n\n        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)\n        self.maxpool2 = nn.MaxPool2d(kernel_size=2)\n\n        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)\n        self.maxpool3 = nn.MaxPool2d(kernel_size=2)\n\n        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)\n        self.maxpool4 = nn.MaxPool2d(kernel_size=2)\n\n        self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)\n\n        ## -------------Decoder--------------\n        self.CatChannels = filters[0]\n        self.CatBlocks = 5\n        self.UpChannels = self.CatChannels * self.CatBlocks\n\n        '''stage 4d'''\n        # h1->320*320, hd4->40*40, Pooling 8 times\n        self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)\n        self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)\n        self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h1_PT_hd4_relu = nn.ReLU(inplace=True)\n\n        # h2->160*160, hd4->40*40, Pooling 4 times\n        self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True)\n       
 self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)\n        self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h2_PT_hd4_relu = nn.ReLU(inplace=True)\n\n        # h3->80*80, hd4->40*40, Pooling 2 times\n        self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True)\n        self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)\n        self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h3_PT_hd4_relu = nn.ReLU(inplace=True)\n\n        # h4->40*40, hd4->40*40, Concatenation\n        self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1)\n        self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)\n\n        # hd5->20*20, hd4->40*40, Upsample 2 times\n        self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # 14*14\n        self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)\n        self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)\n\n        # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)\n        self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16\n        self.bn4d_1 = nn.BatchNorm2d(self.UpChannels)\n        self.relu4d_1 = nn.ReLU(inplace=True)\n\n        '''stage 3d'''\n        # h1->320*320, hd3->80*80, Pooling 4 times\n        self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)\n        self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)\n        self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h1_PT_hd3_relu = nn.ReLU(inplace=True)\n\n        # h2->160*160, hd3->80*80, Pooling 2 times\n        self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True)\n        self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)\n        self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)\n        
self.h2_PT_hd3_relu = nn.ReLU(inplace=True)\n\n        # h3->80*80, hd3->80*80, Concatenation\n        self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)\n        self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)\n\n        # hd4->40*40, hd3->80*80, Upsample 2 times\n        self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # 14*14\n        self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)\n\n        # hd5->20*20, hd3->80*80, Upsample 4 times\n        self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)  # 14*14\n        self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)\n        self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)\n\n        # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)\n        self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16\n        self.bn3d_1 = nn.BatchNorm2d(self.UpChannels)\n        self.relu3d_1 = nn.ReLU(inplace=True)\n\n        '''stage 2d '''\n        # h1->320*320, hd2->160*160, Pooling 2 times\n        self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)\n        self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)\n        self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h1_PT_hd2_relu = nn.ReLU(inplace=True)\n\n        # h2->160*160, hd2->160*160, Concatenation\n        self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)\n        self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)\n\n        # hd3->80*80, hd2->160*160, Upsample 2 times\n        self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, 
mode='bilinear', align_corners=True)  # 14*14\n        self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)\n\n        # hd4->40*40, hd2->160*160, Upsample 4 times\n        self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)  # 14*14\n        self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)\n\n        # hd5->20*20, hd2->160*160, Upsample 8 times\n        self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)  # 14*14\n        self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)\n        self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)\n\n        # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)\n        self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16\n        self.bn2d_1 = nn.BatchNorm2d(self.UpChannels)\n        self.relu2d_1 = nn.ReLU(inplace=True)\n\n        '''stage 1d'''\n        # h1->320*320, hd1->320*320, Concatenation\n        self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)\n        self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels)\n        self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)\n\n        # hd2->160*160, hd1->320*320, Upsample 2 times\n        self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # 14*14\n        self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)\n\n        # hd3->80*80, hd1->320*320, Upsample 4 times\n        self.hd3_UT_hd1 = 
nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)  # 14*14\n        self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)\n\n        # hd4->40*40, hd1->320*320, Upsample 8 times\n        self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)  # 14*14\n        self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)\n        self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)\n\n        # hd5->20*20, hd1->320*320, Upsample 16 times\n        self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)  # 14*14\n        self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)\n        self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)\n        self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)\n\n        # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)\n        self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16\n        self.bn1d_1 = nn.BatchNorm2d(self.UpChannels)\n        self.relu1d_1 = nn.ReLU(inplace=True)\n\n        # -------------Bilinear Upsampling--------------\n        self.upscore6 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True)  ###\n        self.upscore5 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)\n        self.upscore4 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)\n        self.upscore3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)\n        self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)\n\n        # DeepSup\n        self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)\n        self.outconv2 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)\n        self.outconv3 = 
nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)\n        self.outconv4 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)\n        self.outconv5 = nn.Conv2d(filters[4], n_classes, 3, padding=1)\n\n        self.cls = nn.Sequential(\n            nn.Dropout(p=0.5),\n            nn.Conv2d(filters[4], 2, 1),\n            nn.AdaptiveMaxPool2d(1),\n            nn.Sigmoid())\n\n        # initialise weights\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init_weights(m, init_type='kaiming')\n            elif isinstance(m, nn.BatchNorm2d):\n                init_weights(m, init_type='kaiming')\n\n    def dotProduct(self, seg, cls):\n        B, N, H, W = seg.size()\n        seg = seg.view(B, N, H * W)\n        final = torch.einsum(\"ijk,ij->ijk\", [seg, cls])\n        final = final.view(B, N, H, W)\n        return final\n\n    def forward(self, inputs):\n        ## -------------Encoder-------------\n        h1 = self.conv1(inputs)  # h1->320*320*64\n\n        h2 = self.maxpool1(h1)\n        h2 = self.conv2(h2)  # h2->160*160*128\n\n        h3 = self.maxpool2(h2)\n        h3 = self.conv3(h3)  # h3->80*80*256\n\n        h4 = self.maxpool3(h3)\n        h4 = self.conv4(h4)  # h4->40*40*512\n\n        h5 = self.maxpool4(h4)\n        hd5 = self.conv5(h5)  # h5->20*20*1024\n\n        # -------------Classification-------------\n        cls_branch = self.cls(hd5).squeeze(3).squeeze(2)  # (B,N,1,1)->(B,N)\n        cls_branch_max = cls_branch.argmax(dim=1)\n        cls_branch_max = cls_branch_max[:, np.newaxis].float()\n\n        ## -------------Decoder-------------\n        h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))\n        h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))\n        h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))\n        h4_Cat_hd4 = 
self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))\n        hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))\n        hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1))))  # hd4->40*40*UpChannels\n\n        h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))\n        h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))\n        h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))\n        hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))\n        hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))\n        hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1))))  # hd3->80*80*UpChannels\n\n        h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))\n        h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))\n        hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))\n        hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))\n        hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))\n        hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1))))  # hd2->160*160*UpChannels\n\n        h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))\n        hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))\n        hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))\n        hd4_UT_hd1 = 
self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))\n        hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))\n        hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1))))  # hd1->320*320*UpChannels\n\n        d5 = self.outconv5(hd5)\n        d5 = self.upscore5(d5)  # 16->256\n\n        d4 = self.outconv4(hd4)\n        d4 = self.upscore4(d4)  # 32->256\n\n        d3 = self.outconv3(hd3)\n        d3 = self.upscore3(d3)  # 64->256\n\n        d2 = self.outconv2(hd2)\n        d2 = self.upscore2(d2)  # 128->256\n\n        d1 = self.outconv1(hd1)  # 256\n\n        d1 = self.dotProduct(d1, cls_branch_max)\n        d2 = self.dotProduct(d2, cls_branch_max)\n        d3 = self.dotProduct(d3, cls_branch_max)\n        d4 = self.dotProduct(d4, cls_branch_max)\n        d5 = self.dotProduct(d5, cls_branch_max)\n\n        return d1, d2, d3, d4, d5,\n\n\ndef unet_3plus(in_channels, num_classes):\n    model = UNet_3Plus(in_channels, num_classes)\n    return model\n\ndef unet_3plus_ds(in_channels, num_classes):\n    model = UNet_3Plus_DeepSup(in_channels, num_classes)\n    return model\n\ndef unet_3plus_ds_cgm(in_channels, num_classes):\n    model = UNet_3Plus_DeepSup_CGM(in_channels, num_classes)\n    return model\n\n\nif __name__ == '__main__':\n    model = unet_3plus_ds_cgm(1,10)\n    model.eval()\n    input = torch.rand(2,1,128,128)\n    output = model(input)\n    output = output[0].data.cpu().numpy()\n    # print(output)\n    print(output.shape)\n\n"
  },
  {
    "path": "models/networks_2d/unet_cct.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.distributions.uniform import Uniform\nfrom torch.nn import init\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\nclass ConvBlock(nn.Module):\n    \"\"\"two convolution layers with batch norm and leaky relu\"\"\"\n\n    def __init__(self, in_channels, out_channels, dropout_p):\n        super(ConvBlock, self).__init__()\n        self.conv_conv = nn.Sequential(\n            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),\n            nn.BatchNorm2d(out_channels),\n            nn.LeakyReLU(),\n            nn.Dropout(dropout_p),\n            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),\n            nn.BatchNorm2d(out_channels),\n            nn.LeakyReLU()\n        )\n\n    def forward(self, x):\n        return self.conv_conv(x)\n\n\nclass DownBlock(nn.Module):\n    \"\"\"Downsampling followed by 
ConvBlock\"\"\"\n\n    def __init__(self, in_channels, out_channels, dropout_p):\n        super(DownBlock, self).__init__()\n        self.maxpool_conv = nn.Sequential(\n            nn.MaxPool2d(2),\n            ConvBlock(in_channels, out_channels, dropout_p)\n\n        )\n\n    def forward(self, x):\n        return self.maxpool_conv(x)\n\n\nclass UpBlock(nn.Module):\n    \"\"\"Upsampling followed by ConvBlock\"\"\"\n\n    def __init__(self, in_channels1, in_channels2, out_channels, dropout_p,\n                 bilinear=True):\n        super(UpBlock, self).__init__()\n        self.bilinear = bilinear\n        if bilinear:\n            self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)\n            self.up = nn.Upsample(\n                scale_factor=2, mode='bilinear', align_corners=True)\n        else:\n            self.up = nn.ConvTranspose2d(\n                in_channels1, in_channels2, kernel_size=2, stride=2)\n        self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p)\n\n    def forward(self, x1, x2):\n        if self.bilinear:\n            x1 = self.conv1x1(x1)\n        x1 = self.up(x1)\n        x = torch.cat([x2, x1], dim=1)\n        return self.conv(x)\n\nclass Encoder(nn.Module):\n    def __init__(self, params):\n        super(Encoder, self).__init__()\n        self.params = params\n        self.in_chns = self.params['in_chns']\n        self.ft_chns = self.params['feature_chns']\n        self.n_class = self.params['class_num']\n        self.bilinear = self.params['bilinear']\n        self.dropout = self.params['dropout']\n        assert (len(self.ft_chns) == 5)\n        self.in_conv = ConvBlock(\n            self.in_chns, self.ft_chns[0], self.dropout[0])\n        self.down1 = DownBlock(\n            self.ft_chns[0], self.ft_chns[1], self.dropout[1])\n        self.down2 = DownBlock(\n            self.ft_chns[1], self.ft_chns[2], self.dropout[2])\n        self.down3 = DownBlock(\n            self.ft_chns[2], self.ft_chns[3], 
self.dropout[3])\n        self.down4 = DownBlock(\n            self.ft_chns[3], self.ft_chns[4], self.dropout[4])\n\n    def forward(self, x):\n        x0 = self.in_conv(x)\n        x1 = self.down1(x0)\n        x2 = self.down2(x1)\n        x3 = self.down3(x2)\n        x4 = self.down4(x3)\n        return [x0, x1, x2, x3, x4]\n\nclass Decoder(nn.Module):\n    def __init__(self, params):\n        super(Decoder, self).__init__()\n        self.params = params\n        self.in_chns = self.params['in_chns']\n        self.ft_chns = self.params['feature_chns']\n        self.n_class = self.params['class_num']\n        self.bilinear = self.params['bilinear']\n        assert (len(self.ft_chns) == 5)\n\n        self.up1 = UpBlock(\n            self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0)\n        self.up2 = UpBlock(\n            self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0)\n        self.up3 = UpBlock(\n            self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0)\n        self.up4 = UpBlock(\n            self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0)\n\n        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,\n                                  kernel_size=3, padding=1)\n\n    def forward(self, feature):\n        x0 = feature[0]\n        x1 = feature[1]\n        x2 = feature[2]\n        x3 = feature[3]\n        x4 = feature[4]\n\n        x = self.up1(x4, x3)\n        x = self.up2(x, x2)\n        x = self.up3(x, x1)\n        x = self.up4(x, x0)\n        output = self.out_conv(x)\n        return output\n\n\ndef Dropout(x, p=0.3):\n    x = torch.nn.functional.dropout(x, p)\n    return x\n\n\ndef FeatureDropout(x):\n    attention = torch.mean(x, dim=1, keepdim=True)\n    max_val, _ = torch.max(attention.view(\n        x.size(0), -1), dim=1, keepdim=True)\n    threshold = max_val * np.random.uniform(0.7, 0.9)\n    threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention)\n    drop_mask = 
(attention < threshold).float()\n    x = x.mul(drop_mask)\n    return x\n\n\nclass FeatureNoise(nn.Module):\n    def __init__(self, uniform_range=0.3):\n        super(FeatureNoise, self).__init__()\n        self.uni_dist = Uniform(-uniform_range, uniform_range)\n\n    def feature_based_noise(self, x):\n        noise_vector = self.uni_dist.sample(\n            x.shape[1:]).to(x.device).unsqueeze(0)\n        x_noise = x.mul(noise_vector) + x\n        return x_noise\n\n    def forward(self, x):\n        x = self.feature_based_noise(x)\n        return x\n\nclass UNet_CCT(nn.Module):\n    def __init__(self, in_chns, class_num):\n        super(UNet_CCT, self).__init__()\n\n        params = {'in_chns': in_chns,\n                  'feature_chns': [16, 32, 64, 128, 256],\n                  'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],\n                  'class_num': class_num,\n                  'bilinear': False,\n                  'acti_func': 'relu'}\n        self.encoder = Encoder(params)\n        self.main_decoder = Decoder(params)\n        self.aux_decoder1 = Decoder(params)\n        self.aux_decoder2 = Decoder(params)\n        self.aux_decoder3 = Decoder(params)\n\n    def forward(self, x):\n        feature = self.encoder(x)\n        main_seg = self.main_decoder(feature)\n        aux1_feature = [FeatureNoise()(i) for i in feature]\n        aux_seg1 = self.aux_decoder1(aux1_feature)\n        aux2_feature = [Dropout(i) for i in feature]\n        aux_seg2 = self.aux_decoder2(aux2_feature)\n        aux3_feature = [FeatureDropout(i) for i in feature]\n        aux_seg3 = self.aux_decoder3(aux3_feature)\n        return main_seg, aux_seg1, aux_seg2, aux_seg3\n\ndef unet_cct(in_channels, num_classes):\n    model = UNet_CCT(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\n\n# if __name__ == '__main__':\n#     model = unet_cct(1,10)\n#     model.eval()\n#     input = torch.rand(2,1,128,128)\n#     output, output1, output2, output3 = model(input)\n#    
 output = output.data.cpu().numpy()\n#     # print(output)\n#     print(output.shape)"
  },
  {
    "path": "models/networks_2d/unet_plusplus.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import init\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\n\nclass VGGBlock(nn.Module):\n    def __init__(self, in_channels, middle_channels, out_channels):\n        super().__init__()\n        self.relu = nn.ReLU(inplace=True)\n        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)\n        self.bn1 = nn.BatchNorm2d(middle_channels)\n        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)\n        self.bn2 = nn.BatchNorm2d(out_channels)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        return out\n\n\nclass NestedUNet(nn.Module):\n    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):\n        super().__init__()\n\n    
    nb_filter = [32, 64, 128, 256, 512]\n\n        self.deep_supervision = deep_supervision\n\n        self.pool = nn.MaxPool2d(2, 2)\n        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)\n\n        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])\n        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])\n        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])\n        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])\n        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])\n\n        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])\n        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])\n        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])\n        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])\n\n        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])\n        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])\n        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])\n\n        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])\n        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])\n\n        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])\n\n        if self.deep_supervision:\n            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)\n            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)\n            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)\n            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)\n        else:\n            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)\n\n\n    def forward(self, input):\n        x0_0 = 
self.conv0_0(input)\n        x1_0 = self.conv1_0(self.pool(x0_0))\n        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))\n\n        x2_0 = self.conv2_0(self.pool(x1_0))\n        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))\n        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))\n\n        x3_0 = self.conv3_0(self.pool(x2_0))\n        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))\n        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))\n        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))\n\n        x4_0 = self.conv4_0(self.pool(x3_0))\n        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))\n        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))\n        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))\n        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))\n\n        outputs = []\n\n        if self.deep_supervision:\n            output1 = self.final1(x0_1)\n            output2 = self.final2(x0_2)\n            output3 = self.final3(x0_3)\n            output4 = self.final4(x0_4)\n            outputs.append(output1)\n            outputs.append(output2)\n            outputs.append(output3)\n            outputs.append(output4)\n            return outputs\n\n        else:\n            output = self.final(x0_4)\n            # outputs.append(output)\n            # return outputs\n            return output\n\ndef unet_plusplus(in_channels, num_classes):\n    model = NestedUNet(num_classes=num_classes, input_channels=in_channels)\n    init_weights(model, 'kaiming')\n    return model\n\n\n# if __name__ == '__main__':\n#     model = unet_plusplus(3,10, True)\n#     model.eval()\n#     input = torch.rand(1,3,128,128)\n#     output = model(input)\n#     output = output.data.cpu().numpy()\n#     print(output)\n#     print(output.shape)"
  },
  {
    "path": "models/networks_2d/unet_urpc.py",
    "content": "from __future__ import division, print_function\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.distributions.uniform import Uniform\nfrom torch.nn import init\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\nclass ConvBlock(nn.Module):\n    \"\"\"two convolution layers with batch norm and leaky relu\"\"\"\n\n    def __init__(self, in_channels, out_channels, dropout_p):\n        super(ConvBlock, self).__init__()\n        self.conv_conv = nn.Sequential(\n            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),\n            nn.BatchNorm2d(out_channels),\n            nn.LeakyReLU(),\n            nn.Dropout(dropout_p),\n            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),\n            nn.BatchNorm2d(out_channels),\n            nn.LeakyReLU()\n        )\n\n    def forward(self, x):\n        return self.conv_conv(x)\n\nclass Encoder(nn.Module):\n    def 
__init__(self, params):\n        super(Encoder, self).__init__()\n        self.params = params\n        self.in_chns = self.params['in_chns']\n        self.ft_chns = self.params['feature_chns']\n        self.n_class = self.params['class_num']\n        self.bilinear = self.params['bilinear']\n        self.dropout = self.params['dropout']\n        assert (len(self.ft_chns) == 5)\n        self.in_conv = ConvBlock(\n            self.in_chns, self.ft_chns[0], self.dropout[0])\n        self.down1 = DownBlock(\n            self.ft_chns[0], self.ft_chns[1], self.dropout[1])\n        self.down2 = DownBlock(\n            self.ft_chns[1], self.ft_chns[2], self.dropout[2])\n        self.down3 = DownBlock(\n            self.ft_chns[2], self.ft_chns[3], self.dropout[3])\n        self.down4 = DownBlock(\n            self.ft_chns[3], self.ft_chns[4], self.dropout[4])\n\n    def forward(self, x):\n        x0 = self.in_conv(x)\n        x1 = self.down1(x0)\n        x2 = self.down2(x1)\n        x3 = self.down3(x2)\n        x4 = self.down4(x3)\n        return [x0, x1, x2, x3, x4]\n\nclass DownBlock(nn.Module):\n    \"\"\"Downsampling followed by ConvBlock\"\"\"\n\n    def __init__(self, in_channels, out_channels, dropout_p):\n        super(DownBlock, self).__init__()\n        self.maxpool_conv = nn.Sequential(\n            nn.MaxPool2d(2),\n            ConvBlock(in_channels, out_channels, dropout_p)\n\n        )\n\n    def forward(self, x):\n        return self.maxpool_conv(x)\n\n\nclass UpBlock(nn.Module):\n    \"\"\"Upssampling followed by ConvBlock\"\"\"\n\n    def __init__(self, in_channels1, in_channels2, out_channels, dropout_p,\n                 bilinear=True):\n        super(UpBlock, self).__init__()\n        self.bilinear = bilinear\n        if bilinear:\n            self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)\n            self.up = nn.Upsample(\n                scale_factor=2, mode='bilinear', align_corners=True)\n        else:\n            self.up = 
nn.ConvTranspose2d(\n                in_channels1, in_channels2, kernel_size=2, stride=2)\n        self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p)\n\n    def forward(self, x1, x2):\n        if self.bilinear:\n            x1 = self.conv1x1(x1)\n        x1 = self.up(x1)\n        x = torch.cat([x2, x1], dim=1)\n        return self.conv(x)\n\nclass FeatureNoise(nn.Module):\n    def __init__(self, uniform_range=0.3):\n        super(FeatureNoise, self).__init__()\n        self.uni_dist = Uniform(-uniform_range, uniform_range)\n\n    def feature_based_noise(self, x):\n        noise_vector = self.uni_dist.sample(\n            x.shape[1:]).to(x.device).unsqueeze(0)\n        x_noise = x.mul(noise_vector) + x\n        return x_noise\n\n    def forward(self, x):\n        x = self.feature_based_noise(x)\n        return x\n\ndef Dropout(x, p=0.3):\n    x = torch.nn.functional.dropout(x, p)\n    return x\n\n\ndef FeatureDropout(x):\n    attention = torch.mean(x, dim=1, keepdim=True)\n    max_val, _ = torch.max(attention.view(\n        x.size(0), -1), dim=1, keepdim=True)\n    threshold = max_val * np.random.uniform(0.7, 0.9)\n    threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention)\n    drop_mask = (attention < threshold).float()\n    x = x.mul(drop_mask)\n    return x\n\nclass Decoder_URPC(nn.Module):\n    def __init__(self, params):\n        super(Decoder_URPC, self).__init__()\n        self.params = params\n        self.in_chns = self.params['in_chns']\n        self.ft_chns = self.params['feature_chns']\n        self.n_class = self.params['class_num']\n        self.bilinear = self.params['bilinear']\n        assert (len(self.ft_chns) == 5)\n\n        self.up1 = UpBlock(\n            self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0)\n        self.up2 = UpBlock(\n            self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0)\n        self.up3 = UpBlock(\n            self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 
dropout_p=0.0)\n        self.up4 = UpBlock(\n            self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0)\n\n        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,\n                                  kernel_size=3, padding=1)\n        # self.out_conv_dp4 = nn.Conv2d(self.ft_chns[4], self.n_class,\n        #                               kernel_size=3, padding=1)\n        self.out_conv_dp3 = nn.Conv2d(self.ft_chns[3], self.n_class,\n                                      kernel_size=3, padding=1)\n        self.out_conv_dp2 = nn.Conv2d(self.ft_chns[2], self.n_class,\n                                      kernel_size=3, padding=1)\n        self.out_conv_dp1 = nn.Conv2d(self.ft_chns[1], self.n_class,\n                                      kernel_size=3, padding=1)\n        # self.feature_noise = FeatureNoise()\n\n    def forward(self, feature, shape):\n        x0 = feature[0]\n        x1 = feature[1]\n        x2 = feature[2]\n        x3 = feature[3]\n        x4 = feature[4]\n        x = self.up1(x4, x3)\n        if self.training:\n            # dp3_out_seg = self.out_conv_dp3(Dropout(x, p=0.5))\n            dp3_out_seg = self.out_conv_dp3(x)\n        else:\n            dp3_out_seg = self.out_conv_dp3(x)\n        dp3_out_seg = torch.nn.functional.interpolate(dp3_out_seg, shape)\n\n        x = self.up2(x, x2)\n        if self.training:\n            # dp2_out_seg = self.out_conv_dp2(FeatureDropout(x))\n            dp2_out_seg = self.out_conv_dp2(x)\n        else:\n            dp2_out_seg = self.out_conv_dp2(x)\n        dp2_out_seg = torch.nn.functional.interpolate(dp2_out_seg, shape)\n\n        x = self.up3(x, x1)\n        if self.training:\n            # dp1_out_seg = self.out_conv_dp1(self.feature_noise(x))\n            dp1_out_seg = self.out_conv_dp1(x)\n        else:\n            dp1_out_seg = self.out_conv_dp1(x)\n        dp1_out_seg = torch.nn.functional.interpolate(dp1_out_seg, shape)\n\n        x = self.up4(x, x0)\n        dp0_out_seg = 
self.out_conv(x)\n        return dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg\n\n\n\nclass UNet_URPC(nn.Module):\n    def __init__(self, in_chns, class_num):\n        super(UNet_URPC, self).__init__()\n\n        params = {'in_chns': in_chns,\n                  'feature_chns': [16, 32, 64, 128, 256],\n                  'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],\n                  'class_num': class_num,\n                  'bilinear': False,\n                  'acti_func': 'relu'}\n        self.encoder = Encoder(params)\n        self.decoder = Decoder_URPC(params)\n\n    def forward(self, x):\n        shape = x.shape[2:]\n        feature = self.encoder(x)\n        dp1_out_seg, dp2_out_seg, dp3_out_seg, dp4_out_seg = self.decoder(\n            feature, shape)\n        return dp1_out_seg, dp2_out_seg, dp3_out_seg, dp4_out_seg\n\ndef unet_urpc(in_channels, num_classes):\n    model = UNet_URPC(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\n\n# if __name__ == '__main__':\n#     model = unet_urpc(1,10)\n#     model.eval()\n#     input = torch.rand(2,1,128,128)\n#     output, output1, output2, output3 = model(input)\n#     output = output1.data.cpu().numpy()\n#     # print(output)\n#     print(output.shape)"
  },
  {
    "path": "models/networks_2d/wavesnet.py",
    "content": "import numpy as np\nimport math, pywt\nimport torch\nimport torch.nn as nn\nfrom torch.nn import Module\nfrom torch.autograd import Function\nfrom collections import OrderedDict\nfrom itertools import islice\nimport operator\n\nclass My_DownSampling_SC(nn.Module):\n    def __init__(self, in_channel, out_channel, kernel_size = (1,1), stride = 2, padding = (0,0)):\n        super(My_DownSampling_SC, self).__init__()\n        self.conv = nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = stride, padding = padding)\n\n    def forward(self, input):\n        return self.conv(input), input\n\n\nclass My_DownSampling_MP(nn.Module):\n    def __init__(self, stride = 2, kernel_size = 2):\n        super(My_DownSampling_MP, self).__init__()\n        self.maxp = nn.MaxPool2d(kernel_size = kernel_size, stride = stride, return_indices = False)\n\n    def forward(self, input):\n        return self.maxp(input), input\n\n\nclass My_UpSampling_SC(nn.Module):\n    def __init__(self, in_channel, out_channel, kernel_size = (1,1), stride = 2, padding = (0,0)):\n        super(My_UpSampling_SC, self).__init__()\n        self.conv = nn.ConvTranspose2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = stride, padding = padding)\n\n    def forward(self, input, feature_map):\n        return torch.cat((self.conv(input), feature_map), dim = 1)\n\n\nclass My_DownSampling_DWT(nn.Module):\n    def __init__(self, wavename = 'haar'):\n        super(My_DownSampling_DWT, self).__init__()\n        self.dwt = DWT_2D(wavename = wavename)\n\n    def forward(self, input):\n        LL, LH, HL, HH = self.dwt(input)\n        return LL, LH, HL, HH, input\n\n\nclass My_UpSampling_IDWT(nn.Module):\n    def __init__(self, wavename = 'haar'):\n        super(My_UpSampling_IDWT, self).__init__()\n        self.idwt = IDWT_2D(wavename = wavename)\n\n    def forward(self, LL, LH, HL, HH, feature_map):\n        return 
torch.cat((self.idwt(LL, LH, HL, HH), feature_map), dim = 1)\n\n\n\nclass My_Sequential(Module):\n    r\"\"\"A sequential container.\n    Modules will be added to it in the order they are passed in the constructor.\n    Alternatively, an ordered dict of modules can also be passed in.\n    若某个模块输出多个数据，只将第一个数据往下传\n    \"\"\"\n\n    def __init__(self, *args):\n        super(My_Sequential, self).__init__()\n        if len(args) == 1 and isinstance(args[0], OrderedDict):\n            for key, module in args[0].items():\n                self.add_module(key, module)\n        else:\n            for idx, module in enumerate(args):\n                self.add_module(str(idx), module)\n\n    def _get_item_by_idx(self, iterator, idx):\n        \"\"\"Get the idx-th item of the iterator\"\"\"\n        size = len(self)\n        idx = operator.index(idx)\n        if not -size <= idx < size:\n            raise IndexError('index {} is out of range'.format(idx))\n        idx %= size\n        return next(islice(iterator, idx, None))\n\n    def __getitem__(self, idx):\n        if isinstance(idx, slice):\n            return self.__class__(OrderedDict(list(self._modules.items())[idx]))\n        else:\n            return self._get_item_by_idx(self._modules.values(), idx)\n\n    def __setitem__(self, idx, module):\n        key = self._get_item_by_idx(self._modules.keys(), idx)\n        return setattr(self, key, module)\n\n    def __delitem__(self, idx):\n        if isinstance(idx, slice):\n            for key in list(self._modules.keys())[idx]:\n                delattr(self, key)\n        else:\n            key = self._get_item_by_idx(self._modules.keys(), idx)\n            delattr(self, key)\n\n    def __len__(self):\n        return len(self._modules)\n\n    def __dir__(self):\n        keys = super(My_Sequential, self).__dir__()\n        keys = [key for key in keys if not key.isdigit()]\n        return keys\n\n    def forward(self, input):\n        self.output = []\n        for module in 
self._modules.values():\n            input = module(input)\n            if isinstance(input, tuple):\n                assert len(input) == 4 or len(input) == 2 or len(input) == 5\n                self.output.append(input[1:])\n                input = input[0]\n        if self.output != []:\n            return input, self.output\n        else:\n            return input\n\n\nclass My_Sequential_re(Module):\n    r\"\"\"A sequential container.\n    Modules will be added to it in the order they are passed in the constructor.\n    Alternatively, an ordered dict of modules can also be passed in.\n    若某个模块输出多个数据，只将第一个数据往下传\n    \"\"\"\n\n    def __init__(self, *args):\n        super(My_Sequential_re, self).__init__()\n        if len(args) == 1 and isinstance(args[0], OrderedDict):\n            for key, module in args[0].items():\n                self.add_module(key, module)\n        else:\n            for idx, module in enumerate(args):\n                self.add_module(str(idx), module)\n        self.output = []\n\n    def _get_item_by_idx(self, iterator, idx):\n        \"\"\"Get the idx-th item of the iterator\"\"\"\n        size = len(self)\n        idx = operator.index(idx)\n        if not -size <= idx < size:\n            raise IndexError('index {} is out of range'.format(idx))\n        idx %= size\n        return next(islice(iterator, idx, None))\n\n    def __getitem__(self, idx):\n        if isinstance(idx, slice):\n            return self.__class__(OrderedDict(list(self._modules.items())[idx]))\n        else:\n            return self._get_item_by_idx(self._modules.values(), idx)\n\n    def __setitem__(self, idx, module):\n        key = self._get_item_by_idx(self._modules.keys(), idx)\n        return setattr(self, key, module)\n\n    def __delitem__(self, idx):\n        if isinstance(idx, slice):\n            for key in list(self._modules.keys())[idx]:\n                delattr(self, key)\n        else:\n            key = self._get_item_by_idx(self._modules.keys(), 
idx)\n            delattr(self, key)\n\n    def __len__(self):\n        return len(self._modules)\n\n    def __dir__(self):\n        keys = super(My_Sequential_re, self).__dir__()\n        keys = [key for key in keys if not key.isdigit()]\n        return keys\n\n    def forward(self, *input):\n        LL = input[0]\n        index = 1\n        for module in self._modules.values():\n            if isinstance(module, My_UpSampling_IDWT):\n                LH = input[index]\n                HL = input[index + 1]\n                HH = input[index + 2]\n                feature_map = input[index + 3]\n                LL = module(LL, LH, HL, HH, feature_map = feature_map)\n                index += 4\n            elif isinstance(module, IDWT_2D) or 'idwt' in dir(module):\n                LH = input[index]\n                HL = input[index + 1]\n                HH = input[index + 2]\n                LL = module(LL, LH, HL, HH)\n                index += 3\n            elif isinstance(module, nn.MaxUnpool2d):\n                indices = input[index]\n                LL = module(input = LL, indices = indices)\n                #_, _, h, w = LL.size()\n                #LL = F.interpolate(LL, size = (2*h, 2*w), mode = 'bilinear', align_corners = True)\n                index += 1\n            elif isinstance(module, My_UpSampling_SC):\n                feature_map = input[index]\n                LL = module(input = LL, feature_map = feature_map)\n                index += 1\n            else:\n                LL = module(LL)\n        return LL\n\n\nclass DWTFunction_1D(Function):\n    @staticmethod\n    def forward(ctx, input, matrix_Low, matrix_High):\n        ctx.save_for_backward(matrix_Low, matrix_High)\n        L = torch.matmul(input, matrix_Low.t())\n        H = torch.matmul(input, matrix_High.t())\n        return L, H\n    @staticmethod\n    def backward(ctx, grad_L, grad_H):\n        matrix_L, matrix_H = ctx.saved_variables\n        grad_input = torch.add(torch.matmul(grad_L, 
matrix_L), torch.matmul(grad_H, matrix_H))\n        return grad_input, None, None\n\n\nclass IDWTFunction_1D(Function):\n    @staticmethod\n    def forward(ctx, input_L, input_H, matrix_L, matrix_H):\n        ctx.save_for_backward(matrix_L, matrix_H)\n        output = torch.add(torch.matmul(input_L, matrix_L), torch.matmul(input_H, matrix_H))\n        return output\n    @staticmethod\n    def backward(ctx, grad_output):\n        matrix_L, matrix_H = ctx.saved_variables\n        grad_L = torch.matmul(grad_output, matrix_L.t())\n        grad_H = torch.matmul(grad_output, matrix_H.t())\n        return grad_L, grad_H, None, None\n\n\nclass DWTFunction_2D(Function):\n    @staticmethod\n    def forward(ctx, input, matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1):\n        ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1)\n        L = torch.matmul(matrix_Low_0, input)\n        H = torch.matmul(matrix_High_0, input)\n        LL = torch.matmul(L, matrix_Low_1)\n        LH = torch.matmul(L, matrix_High_1)\n        HL = torch.matmul(H, matrix_Low_1)\n        HH = torch.matmul(H, matrix_High_1)\n        return LL, LH, HL, HH\n    @staticmethod\n    def backward(ctx, grad_LL, grad_LH, grad_HL, grad_HH):\n        matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1 = ctx.saved_tensors\n        grad_L = torch.add(torch.matmul(grad_LL, matrix_Low_1.t()), torch.matmul(grad_LH, matrix_High_1.t()))\n        grad_H = torch.add(torch.matmul(grad_HL, matrix_Low_1.t()), torch.matmul(grad_HH, matrix_High_1.t()))\n        grad_input = torch.add(torch.matmul(matrix_Low_0.t(), grad_L), torch.matmul(matrix_High_0.t(), grad_H))\n        return grad_input, None, None, None, None\n\n\nclass DWTFunction_2D_tiny(Function):\n    @staticmethod\n    def forward(ctx, input, matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1):\n        ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1)\n        L = torch.matmul(matrix_Low_0, 
input)\n        LL = torch.matmul(L, matrix_Low_1)\n        return LL\n    @staticmethod\n    def backward(ctx, grad_LL):\n        matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1 = ctx.saved_variables\n        grad_L = torch.matmul(grad_LL, matrix_Low_1.t())\n        grad_input = torch.matmul(matrix_Low_0.t(), grad_L)\n        return grad_input, None, None, None, None\n\n\nclass IDWTFunction_2D(Function):\n    @staticmethod\n    def forward(ctx, input_LL, input_LH, input_HL, input_HH,\n                matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1):\n        ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1)\n        L = torch.add(torch.matmul(input_LL, matrix_Low_1.t()), torch.matmul(input_LH, matrix_High_1.t()))\n        H = torch.add(torch.matmul(input_HL, matrix_Low_1.t()), torch.matmul(input_HH, matrix_High_1.t()))\n        output = torch.add(torch.matmul(matrix_Low_0.t(), L), torch.matmul(matrix_High_0.t(), H))\n        return output\n    @staticmethod\n    def backward(ctx, grad_output):\n        matrix_Low_0, matrix_Low_1, matrix_High_0, matrix_High_1 = ctx.saved_tensors\n        grad_L = torch.matmul(matrix_Low_0, grad_output)\n        grad_H = torch.matmul(matrix_High_0, grad_output)\n        grad_LL = torch.matmul(grad_L, matrix_Low_1)\n        grad_LH = torch.matmul(grad_L, matrix_High_1)\n        grad_HL = torch.matmul(grad_H, matrix_Low_1)\n        grad_HH = torch.matmul(grad_H, matrix_High_1)\n        return grad_LL, grad_LH, grad_HL, grad_HH, None, None, None, None\n\n\nclass DWTFunction_3D(Function):\n    @staticmethod\n    def forward(ctx, input,\n                matrix_Low_0, matrix_Low_1, matrix_Low_2,\n                matrix_High_0, matrix_High_1, matrix_High_2):\n        ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_Low_2,\n                              matrix_High_0, matrix_High_1, matrix_High_2)\n        L = torch.matmul(matrix_Low_0, input)\n        H = torch.matmul(matrix_High_0, 
input)\n        LL = torch.matmul(L, matrix_Low_1).transpose(dim0 = 2, dim1 = 3)\n        LH = torch.matmul(L, matrix_High_1).transpose(dim0 = 2, dim1 = 3)\n        HL = torch.matmul(H, matrix_Low_1).transpose(dim0 = 2, dim1 = 3)\n        HH = torch.matmul(H, matrix_High_1).transpose(dim0 = 2, dim1 = 3)\n        LLL = torch.matmul(matrix_Low_2, LL).transpose(dim0 = 2, dim1 = 3)\n        LLH = torch.matmul(matrix_Low_2, LH).transpose(dim0 = 2, dim1 = 3)\n        LHL = torch.matmul(matrix_Low_2, HL).transpose(dim0 = 2, dim1 = 3)\n        LHH = torch.matmul(matrix_Low_2, HH).transpose(dim0 = 2, dim1 = 3)\n        HLL = torch.matmul(matrix_High_2, LL).transpose(dim0 = 2, dim1 = 3)\n        HLH = torch.matmul(matrix_High_2, LH).transpose(dim0 = 2, dim1 = 3)\n        HHL = torch.matmul(matrix_High_2, HL).transpose(dim0 = 2, dim1 = 3)\n        HHH = torch.matmul(matrix_High_2, HH).transpose(dim0 = 2, dim1 = 3)\n        return LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH\n\n    @staticmethod\n    def backward(ctx, grad_LLL, grad_LLH, grad_LHL, grad_LHH,\n                      grad_HLL, grad_HLH, grad_HHL, grad_HHH):\n        matrix_Low_0, matrix_Low_1, matrix_Low_2, matrix_High_0, matrix_High_1, matrix_High_2 = ctx.saved_variables\n        grad_LL = torch.add(torch.matmul(matrix_Low_2.t(), grad_LLL.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), grad_HLL.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3)\n        grad_LH = torch.add(torch.matmul(matrix_Low_2.t(), grad_LLH.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), grad_HLH.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3)\n        grad_HL = torch.add(torch.matmul(matrix_Low_2.t(), grad_LHL.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), grad_HHL.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3)\n        grad_HH = torch.add(torch.matmul(matrix_Low_2.t(), grad_LHH.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), 
grad_HHH.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3)\n        grad_L = torch.add(torch.matmul(grad_LL, matrix_Low_1.t()), torch.matmul(grad_LH, matrix_High_1.t()))\n        grad_H = torch.add(torch.matmul(grad_HL, matrix_Low_1.t()), torch.matmul(grad_HH, matrix_High_1.t()))\n        grad_input = torch.add(torch.matmul(matrix_Low_0.t(), grad_L), torch.matmul(matrix_High_0.t(), grad_H))\n        return grad_input, None, None, None, None, None, None, None, None\n\n\nclass IDWTFunction_3D(Function):\n    @staticmethod\n    def forward(ctx, input_LLL, input_LLH, input_LHL, input_LHH,\n                     input_HLL, input_HLH, input_HHL, input_HHH,\n                     matrix_Low_0, matrix_Low_1, matrix_Low_2,\n                     matrix_High_0, matrix_High_1, matrix_High_2):\n        ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_Low_2,\n                              matrix_High_0, matrix_High_1, matrix_High_2)\n        input_LL = torch.add(torch.matmul(matrix_Low_2.t(), input_LLL.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), input_HLL.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3)\n        input_LH = torch.add(torch.matmul(matrix_Low_2.t(), input_LLH.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), input_HLH.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3)\n        input_HL = torch.add(torch.matmul(matrix_Low_2.t(), input_LHL.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), input_HHL.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3)\n        input_HH = torch.add(torch.matmul(matrix_Low_2.t(), input_LHH.transpose(dim0 = 2, dim1 = 3)), torch.matmul(matrix_High_2.t(), input_HHH.transpose(dim0 = 2, dim1 = 3))).transpose(dim0 = 2, dim1 = 3)\n        input_L = torch.add(torch.matmul(input_LL, matrix_Low_1.t()), torch.matmul(input_LH, matrix_High_1.t()))\n        input_H = torch.add(torch.matmul(input_HL, matrix_Low_1.t()), torch.matmul(input_HH, 
matrix_High_1.t()))\n        output = torch.add(torch.matmul(matrix_Low_0.t(), input_L), torch.matmul(matrix_High_0.t(), input_H))\n        return output\n    @staticmethod\n    def backward(ctx, grad_output):\n        matrix_Low_0, matrix_Low_1, matrix_Low_2, matrix_High_0, matrix_High_1, matrix_High_2 = ctx.saved_variables\n        grad_L = torch.matmul(matrix_Low_0, grad_output)\n        grad_H = torch.matmul(matrix_High_0, grad_output)\n        grad_LL = torch.matmul(grad_L, matrix_Low_1).transpose(dim0 = 2, dim1 = 3)\n        grad_LH = torch.matmul(grad_L, matrix_High_1).transpose(dim0 = 2, dim1 = 3)\n        grad_HL = torch.matmul(grad_H, matrix_Low_1).transpose(dim0 = 2, dim1 = 3)\n        grad_HH = torch.matmul(grad_H, matrix_High_1).transpose(dim0 = 2, dim1 = 3)\n        grad_LLL = torch.matmul(matrix_Low_2, grad_LL).transpose(dim0 = 2, dim1 = 3)\n        grad_LLH = torch.matmul(matrix_Low_2, grad_LH).transpose(dim0 = 2, dim1 = 3)\n        grad_LHL = torch.matmul(matrix_Low_2, grad_HL).transpose(dim0 = 2, dim1 = 3)\n        grad_LHH = torch.matmul(matrix_Low_2, grad_HH).transpose(dim0 = 2, dim1 = 3)\n        grad_HLL = torch.matmul(matrix_High_2, grad_LL).transpose(dim0 = 2, dim1 = 3)\n        grad_HLH = torch.matmul(matrix_High_2, grad_LH).transpose(dim0 = 2, dim1 = 3)\n        grad_HHL = torch.matmul(matrix_High_2, grad_HL).transpose(dim0 = 2, dim1 = 3)\n        grad_HHH = torch.matmul(matrix_High_2, grad_HH).transpose(dim0 = 2, dim1 = 3)\n        return grad_LLL, grad_LLH, grad_LHL, grad_LHH, grad_HLL, grad_HLH, grad_HHL, grad_HHH, None, None, None, None, None, None\n\n\nclass DWT_1D(Module):\n    \"\"\"\n    input: (N, C, L)\n    output: L -- (N, C, L/2)\n            H -- (N, C, L/2)\n    \"\"\"\n    def __init__(self, wavename):\n        \"\"\"\n        :param band_low: 小波分解所用低频滤波器组\n        :param band_high: 小波分解所用高频滤波器组\n        \"\"\"\n        super(DWT_1D, self).__init__()\n        wavelet = pywt.Wavelet(wavename)\n        self.band_low = 
wavelet.rec_lo\n        self.band_high = wavelet.rec_hi\n        assert len(self.band_low) == len(self.band_high)\n        self.band_length = len(self.band_low)\n        assert self.band_length % 2 == 0\n        self.band_length_half = math.floor(self.band_length / 2)\n\n    def get_matrix(self):\n        \"\"\"\n        生成变换矩阵\n        :return:\n        \"\"\"\n        L1 = self.input_height\n        L = math.floor(L1 / 2)\n        matrix_h = np.zeros( ( L,      L1 + self.band_length - 2 ) )\n        matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) )\n        end = None if self.band_length_half == 1 else (-self.band_length_half+1)\n        index = 0\n        for i in range(L):\n            for j in range(self.band_length):\n                matrix_h[i, index+j] = self.band_low[j]\n            index += 2\n        index = 0\n        for i in range(L1 - L):\n            for j in range(self.band_length):\n                matrix_g[i, index+j] = self.band_high[j]\n            index += 2\n        matrix_h = matrix_h[:,(self.band_length_half-1):end]\n        matrix_g = matrix_g[:,(self.band_length_half-1):end]\n        if torch.cuda.is_available():\n            self.matrix_low = torch.tensor(matrix_h).cuda()\n            self.matrix_high = torch.tensor(matrix_g).cuda()\n        else:\n            self.matrix_low = torch.tensor(matrix_h)\n            self.matrix_high = torch.tensor(matrix_g)\n\n    def forward(self, input):\n        assert len(input.size()) == 3\n        self.input_height = input.size()[-1]\n        #assert self.input_height > self.band_length\n        self.get_matrix()\n        return DWTFunction_1D.apply(input, self.matrix_low, self.matrix_high)\n\n\nclass IDWT_1D(Module):\n    \"\"\"\n    input:  L -- (N, C, L/2)\n            H -- (N, C, L/2)\n    output: (N, C, L)\n    \"\"\"\n    def __init__(self, wavename):\n        \"\"\"\n        :param band_low: 小波重建所需低频滤波器组\n        :param band_high: 小波重建所需高频滤波器组\n        \"\"\"\n        super(IDWT_1D, 
self).__init__()\n        wavelet = pywt.Wavelet(wavename)\n        self.band_low = wavelet.dec_lo\n        self.band_high = wavelet.dec_hi\n        self.band_low.reverse()\n        self.band_high.reverse()\n        assert len(self.band_low) == len(self.band_high)\n        self.band_length = len(self.band_low)\n        assert self.band_length % 2 == 0\n        self.band_length_half = math.floor(self.band_length / 2)\n\n    def get_matrix(self):\n        \"\"\"\n        生成变换矩阵\n        :return:\n        \"\"\"\n        L1 = self.input_height\n        L = math.floor(L1 / 2)\n        matrix_h = np.zeros( ( L,      L1 + self.band_length - 2 ) )\n        matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) )\n        end = None if self.band_length_half == 1 else (-self.band_length_half+1)\n        index = 0\n        for i in range(L):\n            for j in range(self.band_length):\n                matrix_h[i, index+j] = self.band_low[j]\n            index += 2\n        index = 0\n        for i in range(L1 - L):\n            for j in range(self.band_length):\n                matrix_g[i, index+j] = self.band_high[j]\n            index += 2\n        matrix_h = matrix_h[:,(self.band_length_half-1):end]\n        matrix_g = matrix_g[:,(self.band_length_half-1):end]\n        if torch.cuda.is_available():\n            self.matrix_low = torch.tensor(matrix_h).cuda()\n            self.matrix_high = torch.tensor(matrix_g).cuda()\n        else:\n            self.matrix_low = torch.tensor(matrix_h)\n            self.matrix_high = torch.tensor(matrix_g)\n\n    def forward(self, L, H):\n        assert len(L.size()) == len(H.size()) == 3\n        self.input_height = L.size()[-1] + H.size()[-1]\n        #assert self.input_height > self.band_length\n        self.get_matrix()\n        return IDWTFunction_1D.apply(L, H, self.matrix_low, self.matrix_high)\n\n\nclass DWT_2D(Module):\n    \"\"\"\n    input: (N, C, H, W)\n    output -- LL: (N, C, H/2, W/2)\n              LH: (N, C, H/2, 
W/2)\n              HL: (N, C, H/2, W/2)\n              HH: (N, C, H/2, W/2)\n    \"\"\"\n    def __init__(self, wavename):\n        \"\"\"\n        :param band_low: 小波分解所用低频滤波器组\n        :param band_high: 小波分解所用高频滤波器组\n        \"\"\"\n        super(DWT_2D, self).__init__()\n        wavelet = pywt.Wavelet(wavename)\n        self.band_low = wavelet.rec_lo\n        self.band_high = wavelet.rec_hi\n        assert len(self.band_low) == len(self.band_high)\n        self.band_length = len(self.band_low)\n        assert self.band_length % 2 == 0\n        self.band_length_half = math.floor(self.band_length / 2)\n\n    def get_matrix(self):\n        \"\"\"\n        生成变换矩阵\n        :return:\n        \"\"\"\n        L1 = np.max((self.input_height, self.input_width))\n        L = math.floor(L1 / 2)\n        matrix_h = np.zeros( ( L,      L1 + self.band_length - 2 ) )\n        matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) )\n        end = None if self.band_length_half == 1 else (-self.band_length_half+1)\n\n        index = 0\n        for i in range(L):\n            for j in range(self.band_length):\n                matrix_h[i, index+j] = self.band_low[j]\n            index += 2\n        matrix_h_0 = matrix_h[0:(math.floor(self.input_height / 2)), 0:(self.input_height + self.band_length - 2)]\n        matrix_h_1 = matrix_h[0:(math.floor(self.input_width / 2)), 0:(self.input_width + self.band_length - 2)]\n\n        index = 0\n        for i in range(L1 - L):\n            for j in range(self.band_length):\n                matrix_g[i, index+j] = self.band_high[j]\n            index += 2\n        matrix_g_0 = matrix_g[0:(self.input_height - math.floor(self.input_height / 2)),0:(self.input_height + self.band_length - 2)]\n        matrix_g_1 = matrix_g[0:(self.input_width - math.floor(self.input_width / 2)),0:(self.input_width + self.band_length - 2)]\n\n        matrix_h_0 = matrix_h_0[:,(self.band_length_half-1):end]\n        matrix_h_1 = 
matrix_h_1[:,(self.band_length_half-1):end]\n        matrix_h_1 = np.transpose(matrix_h_1)\n        matrix_g_0 = matrix_g_0[:,(self.band_length_half-1):end]\n        matrix_g_1 = matrix_g_1[:,(self.band_length_half-1):end]\n        matrix_g_1 = np.transpose(matrix_g_1)\n\n        if torch.cuda.is_available():\n            self.matrix_low_0 = torch.Tensor(matrix_h_0).cuda()\n            self.matrix_low_1 = torch.Tensor(matrix_h_1).cuda()\n            self.matrix_high_0 = torch.Tensor(matrix_g_0).cuda()\n            self.matrix_high_1 = torch.Tensor(matrix_g_1).cuda()\n        else:\n            self.matrix_low_0 = torch.Tensor(matrix_h_0)\n            self.matrix_low_1 = torch.Tensor(matrix_h_1)\n            self.matrix_high_0 = torch.Tensor(matrix_g_0)\n            self.matrix_high_1 = torch.Tensor(matrix_g_1)\n\n    def forward(self, input):\n        assert isinstance(input, torch.Tensor)\n        assert len(input.size()) == 4\n        self.input_height = input.size()[-2]\n        self.input_width = input.size()[-1]\n        #assert self.input_height > self.band_length and self.input_width > self.band_length\n        self.get_matrix()\n        return DWTFunction_2D.apply(input, self.matrix_low_0, self.matrix_low_1, self.matrix_high_0, self.matrix_high_1)\n\n\nclass DWT_2D_tiny(Module):\n    \"\"\"\n    input: (N, C, H, W)\n    output -- LL: (N, C, H/2, W/2)\n    \"\"\"\n    def __init__(self, wavename):\n        \"\"\"\n        :param band_low: 小波分解所用低频滤波器组\n        :param band_high: 小波分解所用高频滤波器组\n        \"\"\"\n        super(DWT_2D_tiny, self).__init__()\n        wavelet = pywt.Wavelet(wavename)\n        self.band_low = wavelet.rec_lo\n        self.band_high = wavelet.rec_hi\n        assert len(self.band_low) == len(self.band_high)\n        self.band_length = len(self.band_low)\n        assert self.band_length % 2 == 0\n        self.band_length_half = math.floor(self.band_length / 2)\n\n    def get_matrix(self):\n        \"\"\"\n        生成变换矩阵\n        
:return:\n        \"\"\"\n        L1 = np.max((self.input_height, self.input_width))\n        L = math.floor(L1 / 2)\n        matrix_h = np.zeros( ( L,      L1 + self.band_length - 2 ) )\n        matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) )\n        end = None if self.band_length_half == 1 else (-self.band_length_half+1)\n\n        index = 0\n        for i in range(L):\n            for j in range(self.band_length):\n                matrix_h[i, index+j] = self.band_low[j]\n            index += 2\n        matrix_h_0 = matrix_h[0:(math.floor(self.input_height / 2)), 0:(self.input_height + self.band_length - 2)]\n        matrix_h_1 = matrix_h[0:(math.floor(self.input_width / 2)), 0:(self.input_width + self.band_length - 2)]\n\n        index = 0\n        for i in range(L1 - L):\n            for j in range(self.band_length):\n                matrix_g[i, index+j] = self.band_high[j]\n            index += 2\n        matrix_g_0 = matrix_g[0:(self.input_height - math.floor(self.input_height / 2)),0:(self.input_height + self.band_length - 2)]\n        matrix_g_1 = matrix_g[0:(self.input_width - math.floor(self.input_width / 2)),0:(self.input_width + self.band_length - 2)]\n\n        matrix_h_0 = matrix_h_0[:,(self.band_length_half-1):end]\n        matrix_h_1 = matrix_h_1[:,(self.band_length_half-1):end]\n        matrix_h_1 = np.transpose(matrix_h_1)\n        matrix_g_0 = matrix_g_0[:,(self.band_length_half-1):end]\n        matrix_g_1 = matrix_g_1[:,(self.band_length_half-1):end]\n        matrix_g_1 = np.transpose(matrix_g_1)\n\n        if torch.cuda.is_available():\n            self.matrix_low_0 = torch.Tensor(matrix_h_0).cuda()\n            self.matrix_low_1 = torch.Tensor(matrix_h_1).cuda()\n            self.matrix_high_0 = torch.Tensor(matrix_g_0).cuda()\n            self.matrix_high_1 = torch.Tensor(matrix_g_1).cuda()\n        else:\n            self.matrix_low_0 = torch.Tensor(matrix_h_0)\n            self.matrix_low_1 = torch.Tensor(matrix_h_1)\n         
   self.matrix_high_0 = torch.Tensor(matrix_g_0)\n            self.matrix_high_1 = torch.Tensor(matrix_g_1)\n\n    def forward(self, input):\n        assert isinstance(input, torch.Tensor)\n        assert len(input.size()) == 4\n        self.input_height = input.size()[-2]\n        self.input_width = input.size()[-1]\n        self.get_matrix()\n        return DWTFunction_2D_tiny.apply(input, self.matrix_low_0, self.matrix_low_1, self.matrix_high_0, self.matrix_high_1)\n\n\nclass IDWT_2D(Module):\n    \"\"\"\n    input -- LL: (N, C, H/2, W/2)\n             LH: (N, C, H/2, W/2)\n             HL: (N, C, H/2, W/2)\n             HH: (N, C, H/2, W/2)\n    output: (N, C, H, W)\n    \"\"\"\n    def __init__(self, wavename):\n        \"\"\"\n        :param band_low: 小波重建所需低频滤波器组\n        :param band_high: 小波重建所需高频滤波器组\n        \"\"\"\n        super(IDWT_2D, self).__init__()\n        wavelet = pywt.Wavelet(wavename)\n        self.band_low = wavelet.dec_lo\n        self.band_low.reverse()\n        self.band_high = wavelet.dec_hi\n        self.band_high.reverse()\n        assert len(self.band_low) == len(self.band_high)\n        self.band_length = len(self.band_low)\n        assert self.band_length % 2 == 0\n        self.band_length_half = math.floor(self.band_length / 2)\n\n    def get_matrix(self):\n        \"\"\"\n        生成变换矩阵\n        :return:\n        \"\"\"\n        L1 = np.max((self.input_height, self.input_width))\n        L = math.floor(L1 / 2)\n        matrix_h = np.zeros( ( L,      L1 + self.band_length - 2 ) )\n        matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) )\n        end = None if self.band_length_half == 1 else (-self.band_length_half+1)\n\n        index = 0\n        for i in range(L):\n            for j in range(self.band_length):\n                matrix_h[i, index+j] = self.band_low[j]\n            index += 2\n        matrix_h_0 = matrix_h[0:(math.floor(self.input_height / 2)), 0:(self.input_height + self.band_length - 2)]\n        
matrix_h_1 = matrix_h[0:(math.floor(self.input_width / 2)), 0:(self.input_width + self.band_length - 2)]\n\n        index = 0\n        for i in range(L1 - L):\n            for j in range(self.band_length):\n                matrix_g[i, index+j] = self.band_high[j]\n            index += 2\n        matrix_g_0 = matrix_g[0:(self.input_height - math.floor(self.input_height / 2)),0:(self.input_height + self.band_length - 2)]\n        matrix_g_1 = matrix_g[0:(self.input_width - math.floor(self.input_width / 2)),0:(self.input_width + self.band_length - 2)]\n\n        matrix_h_0 = matrix_h_0[:,(self.band_length_half-1):end]\n        matrix_h_1 = matrix_h_1[:,(self.band_length_half-1):end]\n        matrix_h_1 = np.transpose(matrix_h_1)\n        matrix_g_0 = matrix_g_0[:,(self.band_length_half-1):end]\n        matrix_g_1 = matrix_g_1[:,(self.band_length_half-1):end]\n        matrix_g_1 = np.transpose(matrix_g_1)\n        if torch.cuda.is_available():\n            self.matrix_low_0 = torch.Tensor(matrix_h_0).cuda()\n            self.matrix_low_1 = torch.Tensor(matrix_h_1).cuda()\n            self.matrix_high_0 = torch.Tensor(matrix_g_0).cuda()\n            self.matrix_high_1 = torch.Tensor(matrix_g_1).cuda()\n        else:\n            self.matrix_low_0 = torch.Tensor(matrix_h_0)\n            self.matrix_low_1 = torch.Tensor(matrix_h_1)\n            self.matrix_high_0 = torch.Tensor(matrix_g_0)\n            self.matrix_high_1 = torch.Tensor(matrix_g_1)\n\n    def forward(self, LL, LH, HL, HH):\n        assert len(LL.size()) == len(LH.size()) == len(HL.size()) == len(HH.size()) == 4\n        self.input_height = LL.size()[-2] + HH.size()[-2]\n        self.input_width = LL.size()[-1] + HH.size()[-1]\n        #assert self.input_height > self.band_length and self.input_width > self.band_length\n        self.get_matrix()\n        return IDWTFunction_2D.apply(LL, LH, HL, HH, self.matrix_low_0, self.matrix_low_1, self.matrix_high_0, self.matrix_high_1)\n\n\nclass DWT_3D(Module):\n    
\"\"\"\n    input: (N, C, D, H, W)\n    output: -- LLL (N, C, D/2, H/2, W/2)\n            -- LLH (N, C, D/2, H/2, W/2)\n            -- LHL (N, C, D/2, H/2, W/2)\n            -- LHH (N, C, D/2, H/2, W/2)\n            -- HLL (N, C, D/2, H/2, W/2)\n            -- HLH (N, C, D/2, H/2, W/2)\n            -- HHL (N, C, D/2, H/2, W/2)\n            -- HHH (N, C, D/2, H/2, W/2)\n    \"\"\"\n    def __init__(self, wavename):\n        \"\"\"\n        :param band_low: 小波分解所用低频滤波器组\n        :param band_high: 小波分解所用高频滤波器组\n        \"\"\"\n        super(DWT_3D, self).__init__()\n        wavelet = pywt.Wavelet(wavename)\n        self.band_low = wavelet.rec_lo\n        self.band_high = wavelet.rec_hi\n        assert len(self.band_low) == len(self.band_high)\n        self.band_length = len(self.band_low)\n        assert self.band_length % 2 == 0\n        self.band_length_half = math.floor(self.band_length / 2)\n\n    def get_matrix(self):\n        \"\"\"\n        生成变换矩阵\n        :return:\n        \"\"\"\n        L1 = np.max((self.input_height, self.input_width))\n        L = math.floor(L1 / 2)\n        matrix_h = np.zeros( ( L,      L1 + self.band_length - 2 ) )\n        matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) )\n        end = None if self.band_length_half == 1 else (-self.band_length_half+1)\n\n        index = 0\n        for i in range(L):\n            for j in range(self.band_length):\n                matrix_h[i, index+j] = self.band_low[j]\n            index += 2\n        matrix_h_0 = matrix_h[0:(math.floor(self.input_height / 2)), 0:(self.input_height + self.band_length - 2)]\n        matrix_h_1 = matrix_h[0:(math.floor(self.input_width / 2)), 0:(self.input_width + self.band_length - 2)]\n        matrix_h_2 = matrix_h[0:(math.floor(self.input_depth / 2)), 0:(self.input_depth + self.band_length - 2)]\n\n        index = 0\n        for i in range(L1 - L):\n            for j in range(self.band_length):\n                matrix_g[i, index+j] = self.band_high[j]\n     
       index += 2\n        matrix_g_0 = matrix_g[0:(self.input_height - math.floor(self.input_height / 2)),0:(self.input_height + self.band_length - 2)]\n        matrix_g_1 = matrix_g[0:(self.input_width - math.floor(self.input_width / 2)),0:(self.input_width + self.band_length - 2)]\n        matrix_g_2 = matrix_g[0:(self.input_depth - math.floor(self.input_depth / 2)),0:(self.input_depth + self.band_length - 2)]\n\n        matrix_h_0 = matrix_h_0[:,(self.band_length_half-1):end]\n        matrix_h_1 = matrix_h_1[:,(self.band_length_half-1):end]\n        matrix_h_1 = np.transpose(matrix_h_1)\n        matrix_h_2 = matrix_h_2[:,(self.band_length_half-1):end]\n\n        matrix_g_0 = matrix_g_0[:,(self.band_length_half-1):end]\n        matrix_g_1 = matrix_g_1[:,(self.band_length_half-1):end]\n        matrix_g_1 = np.transpose(matrix_g_1)\n        matrix_g_2 = matrix_g_2[:,(self.band_length_half-1):end]\n        if torch.cuda.is_available():\n            self.matrix_low_0 = torch.tensor(matrix_h_0).cuda()\n            self.matrix_low_1 = torch.tensor(matrix_h_1).cuda()\n            self.matrix_low_2 = torch.tensor(matrix_h_2).cuda()\n            self.matrix_high_0 = torch.tensor(matrix_g_0).cuda()\n            self.matrix_high_1 = torch.tensor(matrix_g_1).cuda()\n            self.matrix_high_2 = torch.tensor(matrix_g_2).cuda()\n        else:\n            self.matrix_low_0 = torch.tensor(matrix_h_0)\n            self.matrix_low_1 = torch.tensor(matrix_h_1)\n            self.matrix_low_2 = torch.tensor(matrix_h_2)\n            self.matrix_high_0 = torch.tensor(matrix_g_0)\n            self.matrix_high_1 = torch.tensor(matrix_g_1)\n            self.matrix_high_2 = torch.tensor(matrix_g_2)\n\n    def forward(self, input):\n        assert len(input.size()) == 5\n        self.input_depth = input.size()[-3]\n        self.input_height = input.size()[-2]\n        self.input_width = input.size()[-1]\n        #assert self.input_height > self.band_length and self.input_width > 
self.band_length and self.input_depth > self.band_length\n        self.get_matrix()\n        return DWTFunction_3D.apply(input, self.matrix_low_0, self.matrix_low_1, self.matrix_low_2,\n                                           self.matrix_high_0, self.matrix_high_1, self.matrix_high_2)\n\n\nclass IDWT_3D(Module):\n    \"\"\"\n    input:  -- LLL (N, C, D/2, H/2, W/2)\n            -- LLH (N, C, D/2, H/2, W/2)\n            -- LHL (N, C, D/2, H/2, W/2)\n            -- LHH (N, C, D/2, H/2, W/2)\n            -- HLL (N, C, D/2, H/2, W/2)\n            -- HLH (N, C, D/2, H/2, W/2)\n            -- HHL (N, C, D/2, H/2, W/2)\n            -- HHH (N, C, D/2, H/2, W/2)\n    output: (N, C, D, H, W)\n    \"\"\"\n    def __init__(self, wavename):\n        \"\"\"\n        :param band_low: 小波重构所用低频滤波器组\n        :param band_high: 小波重构所用高频滤波器组\n        \"\"\"\n        super(IDWT_3D, self).__init__()\n        wavelet = pywt.Wavelet(wavename)\n        self.band_low = wavelet.dec_lo\n        self.band_high = wavelet.dec_hi\n        self.band_low.reverse()\n        self.band_high.reverse()\n        assert len(self.band_low) == len(self.band_high)\n        self.band_length = len(self.band_low)\n        assert self.band_length % 2 == 0\n        self.band_length_half = math.floor(self.band_length / 2)\n\n    def get_matrix(self):\n        \"\"\"\n        生成变换矩阵\n        :return:\n        \"\"\"\n        L1 = np.max((self.input_height, self.input_width))\n        L = math.floor(L1 / 2)\n        matrix_h = np.zeros( ( L,      L1 + self.band_length - 2 ) )\n        matrix_g = np.zeros( ( L1 - L, L1 + self.band_length - 2 ) )\n        end = None if self.band_length_half == 1 else (-self.band_length_half+1)\n\n        index = 0\n        for i in range(L):\n            for j in range(self.band_length):\n                matrix_h[i, index+j] = self.band_low[j]\n            index += 2\n        matrix_h_0 = matrix_h[0:(math.floor(self.input_height / 2)), 0:(self.input_height + self.band_length - 2)]\n 
       matrix_h_1 = matrix_h[0:(math.floor(self.input_width / 2)), 0:(self.input_width + self.band_length - 2)]\n        matrix_h_2 = matrix_h[0:(math.floor(self.input_depth / 2)), 0:(self.input_depth + self.band_length - 2)]\n\n        index = 0\n        for i in range(L1 - L):\n            for j in range(self.band_length):\n                matrix_g[i, index+j] = self.band_high[j]\n            index += 2\n        matrix_g_0 = matrix_g[0:(self.input_height - math.floor(self.input_height / 2)),0:(self.input_height + self.band_length - 2)]\n        matrix_g_1 = matrix_g[0:(self.input_width - math.floor(self.input_width / 2)),0:(self.input_width + self.band_length - 2)]\n        matrix_g_2 = matrix_g[0:(self.input_depth - math.floor(self.input_depth / 2)),0:(self.input_depth + self.band_length - 2)]\n\n        matrix_h_0 = matrix_h_0[:,(self.band_length_half-1):end]\n        matrix_h_1 = matrix_h_1[:,(self.band_length_half-1):end]\n        matrix_h_1 = np.transpose(matrix_h_1)\n        matrix_h_2 = matrix_h_2[:,(self.band_length_half-1):end]\n\n        matrix_g_0 = matrix_g_0[:,(self.band_length_half-1):end]\n        matrix_g_1 = matrix_g_1[:,(self.band_length_half-1):end]\n        matrix_g_1 = np.transpose(matrix_g_1)\n        matrix_g_2 = matrix_g_2[:,(self.band_length_half-1):end]\n        if torch.cuda.is_available():\n            self.matrix_low_0 = torch.tensor(matrix_h_0).cuda()\n            self.matrix_low_1 = torch.tensor(matrix_h_1).cuda()\n            self.matrix_low_2 = torch.tensor(matrix_h_2).cuda()\n            self.matrix_high_0 = torch.tensor(matrix_g_0).cuda()\n            self.matrix_high_1 = torch.tensor(matrix_g_1).cuda()\n            self.matrix_high_2 = torch.tensor(matrix_g_2).cuda()\n        else:\n            self.matrix_low_0 = torch.tensor(matrix_h_0)\n            self.matrix_low_1 = torch.tensor(matrix_h_1)\n            self.matrix_low_2 = torch.tensor(matrix_h_2)\n            self.matrix_high_0 = torch.tensor(matrix_g_0)\n            
self.matrix_high_1 = torch.tensor(matrix_g_1)\n            self.matrix_high_2 = torch.tensor(matrix_g_2)\n\n    def forward(self, LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH):\n        assert len(LLL.size()) == len(LLH.size()) == len(LHL.size()) == len(LHH.size()) == 5\n        assert len(HLL.size()) == len(HLH.size()) == len(HHL.size()) == len(HHH.size()) == 5\n        self.input_depth = LLL.size()[-3] + HHH.size()[-3]\n        self.input_height = LLL.size()[-2] + HHH.size()[-2]\n        self.input_width = LLL.size()[-1] + HHH.size()[-1]\n        #assert self.input_height > self.band_length and self.input_width > self.band_length and self.input_depth > self.band_length\n        self.get_matrix()\n        return IDWTFunction_3D.apply(LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH,\n                                     self.matrix_low_0, self.matrix_low_1, self.matrix_low_2,\n                                     self.matrix_high_0, self.matrix_high_1, self.matrix_high_2)\n\n\n# if __name__ == '__main__':\n#     import pywt, cv2\n#     from datetime import datetime\n#\n#     wavelet = pywt.Wavelet('bior1.1')\n#     h = wavelet.rec_lo\n#     g = wavelet.rec_hi\n#     h_ = wavelet.dec_lo\n#     g_ = wavelet.dec_hi\n#     h_.reverse()\n#     g_.reverse()\n#\n#     #\"\"\"\n#     image_full_name = '/home/liqiufu/Pictures/standard_test_images/lena_color_512.tif'\n#     image = cv2.imread(image_full_name, flags = 1)\n#     image = image[0:512,0:512,:]\n#     print(image.shape)\n#     height, width, channel = image.shape\n#     #image = image.reshape((1,height,width))\n#     t0 = datetime.now()\n#     for index in range(1):\n#         m0 = DWT_2D(wavename = 'haar')\n#\n#         m1 = IDWT_2D(wavename = 'haar')\n#     print(isinstance(m1, IDWT_2D))\n#     t1 = datetime.now()\n\nclass SegNet_VGG(nn.Module):\n    def __init__(self, features, num_classes = 21, init_weights = True, wavename = None):\n        super(SegNet_VGG, self).__init__()\n        self.features = features[0]\n        
self.decoders = features[1]\n        self.classifier_seg = nn.Sequential(\n            #nn.Conv2d(64, 64, kernel_size = 3, padding = 1),\n            #nn.ReLU(True),\n            nn.Conv2d(64, num_classes, kernel_size = 1, padding = 0),\n        )\n        if init_weights:\n            self._initialize_weights()\n\n    def forward(self, x):\n        xx = self.features(x)\n        x, [(indices_1,), (indices_2,), (indices_3,), (indices_4,), (indices_5,)] = xx\n        x = self.decoders(x, indices_5, indices_4, indices_3, indices_2, indices_1)\n        x = self.classifier_seg(x)\n        return x\n\n    def _initialize_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                if(m.in_channels != m.out_channels or m.out_channels != m.groups or m.bias is not None):\n                    # don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics\n                    nn.init.kaiming_normal_(m.weight, mode = 'fan_out', nonlinearity = 'relu')\n                    if m.bias is not None:\n                        nn.init.constant_(m.bias, 0)\n                else:\n                    print('Not initializing')\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, 0, 0.01)\n                nn.init.constant_(m.bias, 0)\n\n    def __str__(self):\n        return 'SegNet_VGG'\n\n\nclass WSegNet_VGG(nn.Module):\n    def __init__(self, features, num_classes, init_weights = True, wavename = None):\n        super(WSegNet_VGG, self).__init__()\n        self.features = features[0]\n        self.decoders = features[1]\n        self.classifier_seg = nn.Sequential(\n            #nn.Conv2d(64, 64, kernel_size = 3, padding = 1),\n            #nn.ReLU(True),\n            nn.Conv2d(64, num_classes, kernel_size = 1, 
padding = 0),\n        )\n        if init_weights:\n            self._initialize_weights()\n\n    def forward(self, x):\n        xx = self.features(x)\n        x, [(LH1,HL1,HH1), (LH2,HL2,HH2,), (LH3,HL3,HH3,), (LH4,HL4,HH4,), (LH5,HL5,HH5,)] = xx\n        x = self.decoders(x, LH5,HL5,HH5, LH4,HL4,HH4, LH3,HL3,HH3, LH2,HL2,HH2, LH1,HL1,HH1)\n        x = self.classifier_seg(x)\n        return x\n\n    def _initialize_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                if(m.in_channels != m.out_channels or m.out_channels != m.groups or m.bias is not None):\n                    # don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics\n                    nn.init.kaiming_normal_(m.weight, mode = 'fan_out', nonlinearity = 'relu')\n                    if m.bias is not None:\n                        nn.init.constant_(m.bias, 0)\n                else:\n                    print('Not initializing')\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, 0, 0.01)\n                nn.init.constant_(m.bias, 0)\n\n    def __str__(self):\n        return 'WSegNet_VGG'\n\n\ndef make_layers(cfg, batch_norm = False):\n    encoder = []\n    in_channels = 3\n    for v in cfg:\n        if v != 'M':\n            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)\n            if batch_norm:\n                encoder += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]\n            else:\n                encoder += [conv2d, nn.ReLU(inplace=True)]\n            in_channels = v\n        elif v == 'M':\n            encoder += [nn.MaxPool2d(kernel_size = 2, stride = 2, return_indices = True)]\n    encoder = My_Sequential(*encoder)\n\n    decoder = []\n    cfg.reverse()\n    out_channels_final = 64\n 
   for index, v in enumerate(cfg):\n        if index != len(cfg) - 1:\n            out_channels = cfg[index + 1]\n        else:\n            out_channels = out_channels_final\n        if out_channels == 'M':\n            out_channels = cfg[index + 2]\n        if v == 'M':\n            decoder += [nn.MaxUnpool2d(kernel_size = 2, stride = 2)]\n        else:\n            conv2d = nn.Conv2d(v, out_channels, kernel_size = 3, padding = 1)\n            if batch_norm:\n                decoder += [conv2d, nn.BatchNorm2d(out_channels), nn.ReLU(inplace = True)]\n            else:\n                decoder += [conv2d, nn.ReLU(inplace = True)]\n    decoder = My_Sequential_re(*decoder)\n    return encoder, decoder\n\n\ndef make_w_layers(cfg, in_channels, batch_norm = False, wavename = 'haar'):\n    encoder = []\n    for v in cfg:\n        if v != 'M':\n            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)\n            if batch_norm:\n                encoder += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]\n            else:\n                encoder += [conv2d, nn.ReLU(inplace=True)]\n            in_channels = v\n        elif v == 'M':\n            encoder += [DWT_2D(wavename = wavename)]\n    encoder = My_Sequential(*encoder)\n\n    decoder = []\n    cfg.reverse()\n    out_channels_final = 64\n    for index, v in enumerate(cfg):\n        if index != len(cfg) - 1:\n            out_channels = cfg[index + 1]\n        else:\n            out_channels = out_channels_final\n        if out_channels == 'M':\n            out_channels = cfg[index + 2]\n        if v == 'M':\n            decoder += [IDWT_2D(wavename = wavename)]\n        else:\n            conv2d = nn.Conv2d(v, out_channels, kernel_size = 3, padding = 1)\n            if batch_norm:\n                decoder += [conv2d, nn.BatchNorm2d(out_channels), nn.ReLU(inplace = True)]\n            else:\n                decoder += [conv2d, nn.ReLU(inplace = True)]\n    decoder = My_Sequential_re(*decoder)\n    
return encoder, decoder\n\n\ncfg = {\n    'A': [64,     'M', 128,      'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],   # 11 layers\n    'B': [64, 64, 'M', 128, 128, 'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],   # 13 layers\n    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256,      'M', 512, 512, 512,      'M', 512, 512, 512,      'M'],   # 16 layers out_channels for encoder, input_channels for decoder\n    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],   # 19 layers\n}\n\ndef segnet_vgg11(pretrained = False, **kwargs):\n    \"\"\"VGG 11-layer model (configuration \"A\")\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = SegNet_VGG(make_layers(cfg['A']), **kwargs)\n    return model\n\n\ndef segnet_vgg11_bn(pretrained=False, **kwargs):\n    \"\"\"VGG 11-layer model (configuration \"A\") with batch normalization\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = SegNet_VGG(make_layers(cfg['A'], batch_norm = True), **kwargs)\n    return model\n\n\ndef segnet_vgg13(pretrained=False, **kwargs):\n    \"\"\"VGG 13-layer model (configuration \"B\")\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = SegNet_VGG(make_layers(cfg['B']), **kwargs)\n    return model\n\n\ndef segnet_vgg13_bn(pretrained=False, **kwargs):\n    \"\"\"VGG 13-layer model (configuration \"B\") with batch normalization\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = 
SegNet_VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)\n    return model\n\n\ndef segnet_vgg16(pretrained=False, **kwargs):\n    \"\"\"VGG 16-layer model (configuration \"D\")\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = SegNet_VGG(make_layers(cfg['D']), **kwargs)\n    return model\n\n\ndef segnet_vgg16_bn(pretrained=False, **kwargs):\n    \"\"\"VGG 16-layer model (configuration \"D\") with batch normalization\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = SegNet_VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)\n    return model\n\n\ndef segnet_vgg19(pretrained=False, **kwargs):\n    \"\"\"VGG 19-layer model (configuration \"E\")\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = SegNet_VGG(make_layers(cfg['E']), **kwargs)\n    return model\n\n\ndef segnet_vgg19_bn(pretrained=False, **kwargs):\n    \"\"\"VGG 19-layer model (configuration 'E') with batch normalization\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = SegNet_VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)\n    return model\n\n\"\"\"=================================================================================\"\"\"\n\ndef wsegnet_vgg11(pretrained = False, wavename = 'haar', **kwargs):\n    \"\"\"VGG 11-layer model (configuration \"A\")\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = WSegNet_VGG(make_w_layers(cfg['A'], wavename = wavename), **kwargs)\n    return 
model\n\n\ndef wsegnet_vgg11_bn(pretrained=False, wavename = 'haar', **kwargs):\n    \"\"\"VGG 11-layer model (configuration \"A\") with batch normalization\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = WSegNet_VGG(make_w_layers(cfg['A'], batch_norm = True, wavename = wavename), **kwargs)\n    return model\n\n\ndef wsegnet_vgg13(pretrained=False, wavename = 'haar', **kwargs):\n    \"\"\"VGG 13-layer model (configuration \"B\")\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = WSegNet_VGG(make_w_layers(cfg['B'], wavename = wavename), **kwargs)\n    return model\n\n\ndef wsegnet_vgg13_bn(pretrained=False, wavename = 'haar', **kwargs):\n    \"\"\"VGG 13-layer model (configuration \"B\") with batch normalization\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = WSegNet_VGG(make_w_layers(cfg['B'], batch_norm=True, wavename = wavename), **kwargs)\n    return model\n\n\ndef wsegnet_vgg16(pretrained=False, wavename = 'haar', **kwargs):\n    \"\"\"VGG 16-layer model (configuration \"D\")\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = WSegNet_VGG(make_w_layers(cfg['D'], wavename = wavename), **kwargs)\n    return model\n\n\ndef wsegnet_vgg16_bn(in_channels, num_classes, pretrained=False, wavename = 'haar', **kwargs):\n    \"\"\"VGG 16-layer model (configuration \"D\") with batch normalization\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = 
WSegNet_VGG(make_w_layers(cfg['D'], in_channels, batch_norm=True, wavename = wavename), num_classes, **kwargs)\n    return model\n\n\ndef wsegnet_vgg19(pretrained=False, wavename = 'haar', **kwargs):\n    \"\"\"VGG 19-layer model (configuration \"E\")\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = WSegNet_VGG(make_w_layers(cfg['E'], wavename = wavename), **kwargs)\n    return model\n\n\ndef wsegnet_vgg19_bn(pretrained=False, wavename = 'haar', **kwargs):\n    \"\"\"VGG 19-layer model (configuration 'E') with batch normalization\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    if pretrained:\n        kwargs['init_weights'] = False\n    model = WSegNet_VGG(make_w_layers(cfg['E'], batch_norm=True, wavename = wavename), **kwargs)\n    return model\n\n\n# if __name__ == '__main__':\n#     from loss.loss_function import segmentation_loss\n#     criterion = segmentation_loss('dice', False)\n#     mask = torch.ones(2, 128, 128).long()\n#     model = wsegnet_vgg16_bn(1, 5)\n#     model.train()\n#     input1 = torch.rand(2, 1, 128, 128)\n#     y = model(input1)\n#     loss_train = criterion(y, mask)\n#     loss_train.backward()\n#     # print(output)\n#     print(y.data.cpu().numpy().shape)\n#     print(loss_train)"
  },
  {
    "path": "models/networks_2d/wds.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\nimport functools\nfrom torch.distributions.uniform import Uniform\nimport numpy as np\n\nclass basic_block(nn.Module):\n    def __init__(self, ch_in, ch_out):\n        super(basic_block, self).__init__()\n        self.block = nn.Sequential(\n            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False),\n            nn.ReLU(inplace=True))\n    def forward(self, x):\n        x = self.block(x)\n        return x\n\nclass WDS(nn.Module):\n    def __init__(self, in_channels, num_classes):\n        super(WDS, self).__init__()\n\n        # branch1\n        self.b1_1 = basic_block(in_channels, 64)\n        self.b1_2 = basic_block(64, 64)\n        self.b1_3 = basic_block(64, 64)\n        self.b1_4 = basic_block(64, 64)\n        self.b1_5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n        self.b1_6 = basic_block(64, 128)\n        self.b1_7 = basic_block(128, 128)\n        self.b1_8 = basic_block(128, 128)\n        self.b1_9 = basic_block(128, 128)\n        self.b1_10 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n\n        # branch2\n        self.b2_1 = basic_block(in_channels, 64)\n        self.b2_2 = basic_block(64, 64)\n        self.b2_3 = basic_block(64, 64)\n        self.b2_4 = basic_block(64, 64)\n        self.b2_5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n        self.b2_6 = basic_block(64, 128)\n        self.b2_7 = basic_block(128, 128)\n        self.b2_8 = basic_block(128, 128)\n        self.b2_9 = basic_block(128, 128)\n        self.b2_10 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n\n        # branch3\n        self.b3_1 = basic_block(in_channels, 64)\n        self.b3_2 = basic_block(64, 64)\n        self.b3_3 = basic_block(64, 64)\n        self.b3_4 = basic_block(64, 64)\n        self.b3_5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n        self.b3_6 = basic_block(64, 128)\n        self.b3_7 = basic_block(128, 128)\n        
self.b3_8 = basic_block(128, 128)\n        self.b3_9 = basic_block(128, 128)\n        self.b3_10 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n\n        # branch4\n        self.b4_1 = basic_block(in_channels, 64)\n        self.b4_2 = basic_block(64, 64)\n        self.b4_3 = basic_block(64, 64)\n        self.b4_4 = basic_block(64, 64)\n        self.b4_5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n        self.b4_6 = basic_block(64, 128)\n        self.b4_7 = basic_block(128, 128)\n        self.b4_8 = basic_block(128, 128)\n        self.b4_9 = basic_block(128, 128)\n        self.b4_10 = nn.MaxPool2d(2, stride=2, ceil_mode=True)\n\n        # output\n        self.output_layer = nn.Sequential(\n            nn.Conv2d(128*4, 128, kernel_size=3, stride=1, padding=1, bias=False),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(128, num_classes, kernel_size=3, stride=1, padding=1, bias=False),\n        )\n\n        # initialization\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def forward(self, LL, LH, HL, HH):\n\n        # H, W = 2*LL.shape[2], 2*LL.shape[3]\n        H, W = LL.shape[2], LL.shape[3]\n\n        LL = self.b1_1(LL)\n        LL = self.b1_2(LL)\n        LL = self.b1_3(LL)\n        LL = self.b1_4(LL)\n        LL = self.b1_5(LL)\n        LL = self.b1_6(LL)\n        LL = self.b1_7(LL)\n        LL = self.b1_8(LL)\n        LL = self.b1_9(LL)\n        LL = self.b1_10(LL)\n\n        LH = self.b2_1(LH)\n        LH = 
self.b2_2(LH)\n        LH = self.b2_3(LH)\n        LH = self.b2_4(LH)\n        LH = self.b2_5(LH)\n        LH = self.b2_6(LH)\n        LH = self.b2_7(LH)\n        LH = self.b2_8(LH)\n        LH = self.b2_9(LH)\n        LH = self.b2_10(LH)\n\n        HL = self.b3_1(HL)\n        HL = self.b3_2(HL)\n        HL = self.b3_3(HL)\n        HL = self.b3_4(HL)\n        HL = self.b3_5(HL)\n        HL = self.b3_6(HL)\n        HL = self.b3_7(HL)\n        HL = self.b3_8(HL)\n        HL = self.b3_9(HL)\n        HL = self.b3_10(HL)\n\n        HH = self.b4_1(HH)\n        HH = self.b4_2(HH)\n        HH = self.b4_3(HH)\n        HH = self.b4_4(HH)\n        HH = self.b4_5(HH)\n        HH = self.b4_6(HH)\n        HH = self.b4_7(HH)\n        HH = self.b4_8(HH)\n        HH = self.b4_9(HH)\n        HH = self.b4_10(HH)\n\n        x = torch.cat((LL, LH, HL, HH), dim=1)\n        x = self.output_layer(x)\n        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)\n        return x\n\nif __name__ == '__main__':\n    from loss.loss_function import segmentation_loss\n    criterion = segmentation_loss('dice', False)\n    mask = torch.ones(2, 128, 128).long()\n    model = WDS(1, 5)\n    model.train()\n    input1 = torch.rand(2, 1, 128, 128)\n    y = model(input1, input1, input1, input1)\n    loss_train = criterion(y, mask)\n    loss_train.backward()\n    # print(output)\n    print(y.data.cpu().numpy().shape)\n    print(loss_train)\n"
  },
  {
    "path": "models/networks_2d/xnet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\nimport functools\nfrom torch.distributions.uniform import Uniform\nimport numpy as np\nBatchNorm2d = nn.BatchNorm2d\nrelu_inplace = True\n\nBN_MOMENTUM = 0.1\n# BN_MOMENTUM = 0.01\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)\n\nclass up_conv(nn.Module):\n    def __init__(self, ch_in, ch_out):\n        super(up_conv, self).__init__()\n        self.up = nn.Sequential(\n            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),\n            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),\n            BatchNorm2d(ch_out, momentum=BN_MOMENTUM),\n            nn.ReLU(inplace=relu_inplace)\n        )\n\n    def forward(self, x):\n        x = self.up(x)\n        return x\n\nclass down_conv(nn.Module):\n    def __init__(self, ch_in, ch_out):\n        super(down_conv, self).__init__()\n        self.down = nn.Sequential(\n            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=2, padding=1, bias=False),\n            BatchNorm2d(ch_out, momentum=BN_MOMENTUM),\n            nn.ReLU(inplace=relu_inplace)\n        )\n    def forward(self, x):\n        x = self.down(x)\n        return x\n\nclass same_conv(nn.Module):\n    def __init__(self, ch_in, ch_out):\n        super(same_conv, self).__init__()\n        self.same = nn.Sequential(\n            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False),\n            BatchNorm2d(ch_out, momentum=BN_MOMENTUM),\n            nn.ReLU(inplace=relu_inplace))\n    def forward(self, x):\n       
 x = self.same(x)\n        return x\n\nclass transition_conv(nn.Module):\n    def __init__(self, ch_in, ch_out):\n        super(transition_conv, self).__init__()\n        self.transition = nn.Sequential(\n            nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0, bias=False),\n            BatchNorm2d(ch_out, momentum=BN_MOMENTUM),\n            nn.ReLU(inplace=relu_inplace))\n    def forward(self, x):\n        x = self.transition(x)\n        return x\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n                 base_width=64, dilation=1, norm_layer=None):\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = BatchNorm2d\n        if groups != 1 or base_width != 64:\n            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes, momentum=BN_MOMENTUM)\n        self.relu = nn.ReLU(inplace=relu_inplace)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = norm_layer(planes, momentum=BN_MOMENTUM)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        # out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out = self.bn2(out) + identity\n        out = self.relu(out)\n\n        return out\n\nclass DoubleBasicBlock(nn.Module):\n    def __init__(self, inplanes, planes, downsample=None):\n        super(DoubleBasicBlock, self).__init__()\n\n        
self.DBB = nn.Sequential(\n            BasicBlock(inplanes=inplanes, planes=planes, downsample=downsample),\n            BasicBlock(inplanes=planes, planes=planes)\n        )\n\n    def forward(self, x):\n        out = self.DBB(x)\n        return out\n\n\nclass XNet(nn.Module):\n    def __init__(self, in_channels, num_classes):\n        super(XNet, self).__init__()\n\n        l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024\n\n        # branch1\n        # branch1_layer1\n        self.b1_1_1 = nn.Sequential(\n            conv3x3(in_channels, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b1_1_2_down = down_conv(l1c, l2c)\n        self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))\n        self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch1_layer2\n        self.b1_2_1 = DoubleBasicBlock(l2c, l2c)\n        self.b1_2_2_down = down_conv(l2c, l3c)\n        self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))\n        self.b1_2_4_up = up_conv(l2c, l1c)\n        # branch1_layer3\n        self.b1_3_1 = DoubleBasicBlock(l3c, l3c)\n        self.b1_3_2_down = down_conv(l3c, l4c)\n        self.b1_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))\n        self.b1_3_4_up = up_conv(l3c, l2c)\n        # branch1_layer4\n        self.b1_4_1 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_2_down = down_conv(l4c, l5c)\n        self.b1_4_2 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_3_down = down_conv(l4c, l4c)\n        self.b1_4_3_same = same_conv(l4c, l4c)\n        self.b1_4_4_transition = transition_conv(l4c+l5c+l4c, l4c)\n        self.b1_4_5 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_6 = DoubleBasicBlock(l4c+l4c, 
l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))\n        self.b1_4_7_up = up_conv(l4c, l3c)\n        # branch1_layer5\n        self.b1_5_1 = DoubleBasicBlock(l5c, l5c)\n        self.b1_5_2_up = up_conv(l5c, l5c)\n        self.b1_5_2_same = same_conv(l5c, l5c)\n        self.b1_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)\n        self.b1_5_4 = DoubleBasicBlock(l5c, l5c)\n        self.b1_5_5_up = up_conv(l5c, l4c)\n\n        # branch2\n        # branch2_layer1\n        self.b2_1_1 = nn.Sequential(\n            conv3x3(1, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b2_1_2_down = down_conv(l1c, l2c)\n        self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))\n        self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch2_layer2\n        self.b2_2_1 = DoubleBasicBlock(l2c, l2c)\n        self.b2_2_2_down = down_conv(l2c, l3c)\n        self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))\n        self.b2_2_4_up = up_conv(l2c, l1c)\n        # branch2_layer3\n        self.b2_3_1 = DoubleBasicBlock(l3c, l3c)\n        self.b2_3_2_down = down_conv(l3c, l4c)\n        self.b2_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))\n        self.b2_3_4_up = up_conv(l3c, l2c)\n        # branch2_layer4\n        self.b2_4_1 = DoubleBasicBlock(l4c, l4c)\n        self.b2_4_2_down = down_conv(l4c, l5c)\n        self.b2_4_2 = DoubleBasicBlock(l4c, l4c)\n        self.b2_4_3_down = down_conv(l4c, l4c)\n        self.b2_4_3_same = same_conv(l4c, l4c)\n        self.b2_4_4_transition = transition_conv(l4c+l5c+l4c, l4c)\n        self.b2_4_5 = DoubleBasicBlock(l4c, l4c)\n        
self.b2_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))\n        self.b2_4_7_up = up_conv(l4c, l3c)\n        # branch2_layer5\n        self.b2_5_1 = DoubleBasicBlock(l5c, l5c)\n        self.b2_5_2_up = up_conv(l5c, l5c)\n        self.b2_5_2_same = same_conv(l5c, l5c)\n        self.b2_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)\n        self.b2_5_4 = DoubleBasicBlock(l5c, l5c)\n        self.b2_5_5_up = up_conv(l5c, l4c)\n\n        # initialization\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABNSync):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABN):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def forward(self, input1, input2):\n        # code\n        # branch1\n        x1_1 = self.b1_1_1(input1)\n\n        x1_2 = self.b1_1_2_down(x1_1)\n        x1_2 = self.b1_2_1(x1_2)\n\n        x1_3 = self.b1_2_2_down(x1_2)\n        x1_3 = self.b1_3_1(x1_3)\n\n        x1_4_1 = self.b1_3_2_down(x1_3)\n        x1_4_1 = self.b1_4_1(x1_4_1)\n        x1_4_2 = self.b1_4_2(x1_4_1)\n        x1_4_3_down = self.b1_4_3_down(x1_4_2)\n        x1_4_3_same = self.b1_4_3_same(x1_4_2)\n\n        x1_5_1 = self.b1_4_2_down(x1_4_1)\n        x1_5_1 = self.b1_5_1(x1_5_1)\n        
x1_5_2_up = self.b1_5_2_up(x1_5_1)\n        x1_5_2_same = self.b1_5_2_same(x1_5_1)\n        # branch2\n        x2_1 = self.b2_1_1(input2)\n\n        x2_2 = self.b2_1_2_down(x2_1)\n        x2_2 = self.b2_2_1(x2_2)\n\n        x2_3 = self.b2_2_2_down(x2_2)\n        x2_3 = self.b2_3_1(x2_3)\n\n        x2_4_1 = self.b2_3_2_down(x2_3)\n        x2_4_1 = self.b2_4_1(x2_4_1)\n        x2_4_2 = self.b2_4_2(x2_4_1)\n        x2_4_3_down = self.b2_4_3_down(x2_4_2)\n        x2_4_3_same = self.b2_4_3_same(x2_4_2)\n\n        x2_5_1 = self.b2_4_2_down(x2_4_1)\n        x2_5_1 = self.b2_5_1(x2_5_1)\n        x2_5_2_up = self.b2_5_2_up(x2_5_1)\n        x2_5_2_same = self.b2_5_2_same(x2_5_1)\n\n        # merge\n        # branch1\n        x1_5_3 = torch.cat((x1_5_2_same, x2_5_2_same, x2_4_3_down), dim=1)\n        x1_5_3 = self.b1_5_3_transition(x1_5_3)\n        x1_5_3 = self.b1_5_4(x1_5_3)\n        x1_5_3 = self.b1_5_5_up(x1_5_3)\n\n        x1_4_4 = torch.cat((x1_4_3_same, x2_4_3_same, x2_5_2_up), dim=1)\n        x1_4_4 = self.b1_4_4_transition(x1_4_4)\n        x1_4_4 = self.b1_4_5(x1_4_4)\n        x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1)\n        x1_4_4 = self.b1_4_6(x1_4_4)\n        x1_4_4 = self.b1_4_7_up(x1_4_4)\n        # branch2\n        x2_5_3 = torch.cat((x2_5_2_same, x1_5_2_same, x1_4_3_down), dim=1)\n        x2_5_3 = self.b2_5_3_transition(x2_5_3)\n        x2_5_3 = self.b2_5_4(x2_5_3)\n        x2_5_3 = self.b2_5_5_up(x2_5_3)\n\n        x2_4_4 = torch.cat((x2_4_3_same, x1_4_3_same, x1_5_2_up), dim=1)\n        x2_4_4 = self.b2_4_4_transition(x2_4_4)\n        x2_4_4 = self.b2_4_5(x2_4_4)\n        x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1)\n        x2_4_4 = self.b2_4_6(x2_4_4)\n        x2_4_4 = self.b2_4_7_up(x2_4_4)\n\n        # decode\n        # branch1\n        x1_3 = torch.cat((x1_3, x1_4_4), dim=1)\n        x1_3 = self.b1_3_3(x1_3)\n        x1_3 = self.b1_3_4_up(x1_3)\n\n        x1_2 = torch.cat((x1_2, x1_3), dim=1)\n        x1_2 = self.b1_2_3(x1_2)\n        x1_2 = 
self.b1_2_4_up(x1_2)\n\n        x1_1 = torch.cat((x1_1, x1_2), dim=1)\n        x1_1 = self.b1_1_3(x1_1)\n        x1_1 = self.b1_1_4(x1_1)\n        # branch2\n        x2_3 = torch.cat((x2_3, x2_4_4), dim=1)\n        x2_3 = self.b2_3_3(x2_3)\n        x2_3 = self.b2_3_4_up(x2_3)\n\n        x2_2 = torch.cat((x2_2, x2_3), dim=1)\n        x2_2 = self.b2_2_3(x2_2)\n        x2_2 = self.b2_2_4_up(x2_2)\n\n        x2_1 = torch.cat((x2_1, x2_2), dim=1)\n        x2_1 = self.b2_1_3(x2_1)\n        x2_1 = self.b2_1_4(x2_1)\n\n        return x1_1, x2_1\n\n\nclass XNet_1_1_m(nn.Module):\n    def __init__(self, in_channels, num_classes):\n        super(XNet_1_1_m, self).__init__()\n\n        l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024\n\n        # branch1\n        # branch1_layer1\n        self.b1_1_1 = nn.Sequential(\n            conv3x3(in_channels, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b1_1_2_down = down_conv(l1c, l2c)\n        self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))\n        self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch1_layer2\n        self.b1_2_1 = DoubleBasicBlock(l2c, l2c)\n        self.b1_2_2_down = down_conv(l2c, l3c)\n        self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))\n        self.b1_2_4_up = up_conv(l2c, l1c)\n        # branch1_layer3\n        self.b1_3_1 = DoubleBasicBlock(l3c, l3c)\n        self.b1_3_2_down = down_conv(l3c, l4c)\n        self.b1_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))\n        self.b1_3_4_up = up_conv(l3c, l2c)\n        # branch1_layer4\n        self.b1_4_1 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_2_down = down_conv(l4c, l5c)\n        
self.b1_4_3 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))\n        self.b1_4_4_up = up_conv(l4c, l3c)\n        # branch1_layer5\n        self.b1_5_1 = DoubleBasicBlock(l5c, l5c)\n        self.b1_5_2_same = same_conv(l5c, l5c)\n        self.b1_5_3_transition = transition_conv(l5c+l5c, l5c)\n        self.b1_5_4 = DoubleBasicBlock(l5c, l5c)\n        self.b1_5_5_up = up_conv(l5c, l4c)\n\n        # branch2\n        # branch2_layer1\n        self.b2_1_1 = nn.Sequential(\n            conv3x3(1, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b2_1_2_down = down_conv(l1c, l2c)\n        self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))\n        self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch2_layer2\n        self.b2_2_1 = DoubleBasicBlock(l2c, l2c)\n        self.b2_2_2_down = down_conv(l2c, l3c)\n        self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))\n        self.b2_2_4_up = up_conv(l2c, l1c)\n        # branch2_layer3\n        self.b2_3_1 = DoubleBasicBlock(l3c, l3c)\n        self.b2_3_2_down = down_conv(l3c, l4c)\n        self.b2_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))\n        self.b2_3_4_up = up_conv(l3c, l2c)\n        # branch2_layer4\n        self.b2_4_1 = DoubleBasicBlock(l4c, l4c)\n        self.b2_4_2_down = down_conv(l4c, l5c)\n        self.b2_4_3 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))\n        self.b2_4_4_up = up_conv(l4c, l3c)\n        # branch2_layer5\n        self.b2_5_1 = DoubleBasicBlock(l5c, l5c)\n        
self.b2_5_2_same = same_conv(l5c, l5c)\n        self.b2_5_3_transition = transition_conv(l5c+l5c, l5c)\n        self.b2_5_4 = DoubleBasicBlock(l5c, l5c)\n        self.b2_5_5_up = up_conv(l5c, l4c)\n\n        # initialization\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABNSync):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABN):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def forward(self, input1, input2):\n        # code\n        # branch1\n        x1_1 = self.b1_1_1(input1)\n\n        x1_2 = self.b1_1_2_down(x1_1)\n        x1_2 = self.b1_2_1(x1_2)\n\n        x1_3 = self.b1_2_2_down(x1_2)\n        x1_3 = self.b1_3_1(x1_3)\n\n        x1_4 = self.b1_3_2_down(x1_3)\n        x1_4 = self.b1_4_1(x1_4)\n\n        x1_5_1 = self.b1_4_2_down(x1_4)\n        x1_5_1 = self.b1_5_1(x1_5_1)\n        x1_5_2_same = self.b1_5_2_same(x1_5_1)\n        # branch2\n        x2_1 = self.b2_1_1(input2)\n\n        x2_2 = self.b2_1_2_down(x2_1)\n        x2_2 = self.b2_2_1(x2_2)\n\n        x2_3 = self.b2_2_2_down(x2_2)\n        x2_3 = self.b2_3_1(x2_3)\n\n        x2_4 = self.b2_3_2_down(x2_3)\n        x2_4 = self.b2_4_1(x2_4)\n\n        x2_5_1 = self.b2_4_2_down(x2_4)\n        x2_5_1 = self.b2_5_1(x2_5_1)\n        x2_5_2_same = self.b2_5_2_same(x2_5_1)\n\n        # 
merge\n        # branch1\n        x1_5_3 = torch.cat((x1_5_2_same, x2_5_2_same), dim=1)\n        x1_5_3 = self.b1_5_3_transition(x1_5_3)\n        x1_5_3 = self.b1_5_4(x1_5_3)\n        x1_5_3 = self.b1_5_5_up(x1_5_3)\n\n        # branch2\n        x2_5_3 = torch.cat((x2_5_2_same, x1_5_2_same), dim=1)\n        x2_5_3 = self.b2_5_3_transition(x2_5_3)\n        x2_5_3 = self.b2_5_4(x2_5_3)\n        x2_5_3 = self.b2_5_5_up(x2_5_3)\n\n        # decode\n        # branch1\n        x1_4 = torch.cat((x1_4, x1_5_3), dim=1)\n        x1_4 = self.b1_4_3(x1_4)\n        x1_4 = self.b1_4_4_up(x1_4)\n\n        x1_3 = torch.cat((x1_3, x1_4), dim=1)\n        x1_3 = self.b1_3_3(x1_3)\n        x1_3 = self.b1_3_4_up(x1_3)\n\n        x1_2 = torch.cat((x1_2, x1_3), dim=1)\n        x1_2 = self.b1_2_3(x1_2)\n        x1_2 = self.b1_2_4_up(x1_2)\n\n        x1_1 = torch.cat((x1_1, x1_2), dim=1)\n        x1_1 = self.b1_1_3(x1_1)\n        x1_1 = self.b1_1_4(x1_1)\n        # branch2\n        x2_4 = torch.cat((x2_4, x2_5_3), dim=1)\n        x2_4 = self.b2_4_3(x2_4)\n        x2_4 = self.b2_4_4_up(x2_4)\n\n        x2_3 = torch.cat((x2_3, x2_4), dim=1)\n        x2_3 = self.b2_3_3(x2_3)\n        x2_3 = self.b2_3_4_up(x2_3)\n\n        x2_2 = torch.cat((x2_2, x2_3), dim=1)\n        x2_2 = self.b2_2_3(x2_2)\n        x2_2 = self.b2_2_4_up(x2_2)\n\n        x2_1 = torch.cat((x2_1, x2_2), dim=1)\n        x2_1 = self.b2_1_3(x2_1)\n        x2_1 = self.b2_1_4(x2_1)\n\n        return x1_1, x2_1\n\nclass XNet_1_2_m(nn.Module):\n    def __init__(self, in_channels, num_classes):\n        super(XNet_1_2_m, self).__init__()\n\n        l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024\n\n        # branch1\n        # branch1_layer1\n        self.b1_1_1 = nn.Sequential(\n            conv3x3(in_channels, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b1_1_2_down = down_conv(l1c, l2c)\n        self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, 
nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))\n        self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch1_layer2\n        self.b1_2_1 = DoubleBasicBlock(l2c, l2c)\n        self.b1_2_2_down = down_conv(l2c, l3c)\n        self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))\n        self.b1_2_4_up = up_conv(l2c, l1c)\n        # branch1_layer3\n        self.b1_3_1 = DoubleBasicBlock(l3c, l3c)\n        self.b1_3_2_down = down_conv(l3c, l4c)\n        self.b1_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))\n        self.b1_3_4_up = up_conv(l3c, l2c)\n        # branch1_layer4\n        self.b1_4_1 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_2_down = down_conv(l4c, l5c)\n        self.b1_4_3 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))\n        self.b1_4_4_up = up_conv(l4c, l3c)\n        # branch1_layer5\n        self.b1_5_1 = DoubleBasicBlock(l5c, l5c)\n        self.b1_5_2_up = up_conv(l5c, l5c)\n        self.b1_5_2_same = same_conv(l5c, l5c)\n        self.b1_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)\n        self.b1_5_4 = DoubleBasicBlock(l5c, l5c)\n        self.b1_5_5_up = up_conv(l5c, l4c)\n\n        # branch2\n        # branch2_layer1\n        self.b2_1_1 = nn.Sequential(\n            conv3x3(1, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b2_1_2_down = down_conv(l1c, l2c)\n        self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))\n        self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch2_layer2\n        
self.b2_2_1 = DoubleBasicBlock(l2c, l2c)\n        self.b2_2_2_down = down_conv(l2c, l3c)\n        self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))\n        self.b2_2_4_up = up_conv(l2c, l1c)\n        # branch2_layer3\n        self.b2_3_1 = DoubleBasicBlock(l3c, l3c)\n        self.b2_3_2_down = down_conv(l3c, l4c)\n        self.b2_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))\n        self.b2_3_4_up = up_conv(l3c, l2c)\n        # branch2_layer4\n        self.b2_4_1 = DoubleBasicBlock(l4c, l4c)\n        self.b2_4_2_down = down_conv(l4c, l5c)\n        self.b2_4_2 = DoubleBasicBlock(l4c, l4c)\n        self.b2_4_3_down = down_conv(l4c, l4c)\n        self.b2_4_3_same = same_conv(l4c, l4c)\n        self.b2_4_4_transition = transition_conv(l4c+l5c, l4c)\n        self.b2_4_5 = DoubleBasicBlock(l4c, l4c)\n        self.b2_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))\n        self.b2_4_7_up = up_conv(l4c, l3c)\n        # branch2_layer5\n        self.b2_5_1 = DoubleBasicBlock(l5c, l5c)\n        self.b2_5_2_same = same_conv(l5c, l5c)\n        self.b2_5_3_transition = transition_conv(l5c+l5c, l5c)\n        self.b2_5_4 = DoubleBasicBlock(l5c, l5c)\n        self.b2_5_5_up = up_conv(l5c, l4c)\n\n        # initialization\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABNSync):\n            #     nn.init.constant_(m.weight, 1)\n   
         #     nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABN):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def forward(self, input1, input2):\n        # code\n        # branch1\n        x1_1 = self.b1_1_1(input1)\n\n        x1_2 = self.b1_1_2_down(x1_1)\n        x1_2 = self.b1_2_1(x1_2)\n\n        x1_3 = self.b1_2_2_down(x1_2)\n        x1_3 = self.b1_3_1(x1_3)\n\n        x1_4 = self.b1_3_2_down(x1_3)\n        x1_4 = self.b1_4_1(x1_4)\n\n        x1_5_1 = self.b1_4_2_down(x1_4)\n        x1_5_1 = self.b1_5_1(x1_5_1)\n        x1_5_2_up = self.b1_5_2_up(x1_5_1)\n        x1_5_2_same = self.b1_5_2_same(x1_5_1)\n        # branch2\n        x2_1 = self.b2_1_1(input2)\n\n        x2_2 = self.b2_1_2_down(x2_1)\n        x2_2 = self.b2_2_1(x2_2)\n\n        x2_3 = self.b2_2_2_down(x2_2)\n        x2_3 = self.b2_3_1(x2_3)\n\n        x2_4_1 = self.b2_3_2_down(x2_3)\n        x2_4_1 = self.b2_4_1(x2_4_1)\n        x2_4_2 = self.b2_4_2(x2_4_1)\n        x2_4_3_down = self.b2_4_3_down(x2_4_2)\n        x2_4_3_same = self.b2_4_3_same(x2_4_2)\n\n        x2_5_1 = self.b2_4_2_down(x2_4_1)\n        x2_5_1 = self.b2_5_1(x2_5_1)\n        x2_5_2_same = self.b2_5_2_same(x2_5_1)\n\n        # merge\n        # branch1\n        x1_5_3 = torch.cat((x1_5_2_same, x2_5_2_same, x2_4_3_down), dim=1)\n        x1_5_3 = self.b1_5_3_transition(x1_5_3)\n        x1_5_3 = self.b1_5_4(x1_5_3)\n        x1_5_3 = self.b1_5_5_up(x1_5_3)\n\n        # branch2\n        x2_5_3 = torch.cat((x2_5_2_same, x1_5_2_same), dim=1)\n        x2_5_3 = self.b2_5_3_transition(x2_5_3)\n        x2_5_3 = self.b2_5_4(x2_5_3)\n        x2_5_3 = self.b2_5_5_up(x2_5_3)\n\n        x2_4_4 = torch.cat((x2_4_3_same, x1_5_2_up), dim=1)\n        x2_4_4 = 
self.b2_4_4_transition(x2_4_4)\n        x2_4_4 = self.b2_4_5(x2_4_4)\n        x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1)\n        x2_4_4 = self.b2_4_6(x2_4_4)\n        x2_4_4 = self.b2_4_7_up(x2_4_4)\n\n        # decode\n        # branch1\n        x1_4 = torch.cat((x1_4, x1_5_3), dim=1)\n        x1_4 = self.b1_4_3(x1_4)\n        x1_4 = self.b1_4_4_up(x1_4)\n\n        x1_3 = torch.cat((x1_3, x1_4), dim=1)\n        x1_3 = self.b1_3_3(x1_3)\n        x1_3 = self.b1_3_4_up(x1_3)\n\n        x1_2 = torch.cat((x1_2, x1_3), dim=1)\n        x1_2 = self.b1_2_3(x1_2)\n        x1_2 = self.b1_2_4_up(x1_2)\n\n        x1_1 = torch.cat((x1_1, x1_2), dim=1)\n        x1_1 = self.b1_1_3(x1_1)\n        x1_1 = self.b1_1_4(x1_1)\n        # branch2\n        x2_3 = torch.cat((x2_3, x2_4_4), dim=1)\n        x2_3 = self.b2_3_3(x2_3)\n        x2_3 = self.b2_3_4_up(x2_3)\n\n        x2_2 = torch.cat((x2_2, x2_3), dim=1)\n        x2_2 = self.b2_2_3(x2_2)\n        x2_2 = self.b2_2_4_up(x2_2)\n\n        x2_1 = torch.cat((x2_1, x2_2), dim=1)\n        x2_1 = self.b2_1_3(x2_1)\n        x2_1 = self.b2_1_4(x2_1)\n\n        return x1_1, x2_1\n\n\nclass XNet_2_1_m(nn.Module):\n    def __init__(self, in_channels, num_classes):\n        super(XNet_2_1_m, self).__init__()\n\n        l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024\n\n        # branch1\n        # branch1_layer1\n        self.b1_1_1 = nn.Sequential(\n            conv3x3(in_channels, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b1_1_2_down = down_conv(l1c, l2c)\n        self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))\n        self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch1_layer2\n        self.b1_2_1 = DoubleBasicBlock(l2c, l2c)\n        self.b1_2_2_down = down_conv(l2c, l3c)\n        self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, 
nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))\n        self.b1_2_4_up = up_conv(l2c, l1c)\n        # branch1_layer3\n        self.b1_3_1 = DoubleBasicBlock(l3c, l3c)\n        self.b1_3_2_down = down_conv(l3c, l4c)\n        self.b1_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))\n        self.b1_3_4_up = up_conv(l3c, l2c)\n        # branch1_layer4\n        self.b1_4_1 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_2_down = down_conv(l4c, l5c)\n        self.b1_4_2 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_3_down = down_conv(l4c, l4c)\n        self.b1_4_3_same = same_conv(l4c, l4c)\n        self.b1_4_4_transition = transition_conv(l4c+l5c, l4c)\n        self.b1_4_5 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))\n        self.b1_4_7_up = up_conv(l4c, l3c)\n        # branch1_layer5\n        self.b1_5_1 = DoubleBasicBlock(l5c, l5c)\n        self.b1_5_2_same = same_conv(l5c, l5c)\n        self.b1_5_3_transition = transition_conv(l5c+l5c, l5c)\n        self.b1_5_4 = DoubleBasicBlock(l5c, l5c)\n        self.b1_5_5_up = up_conv(l5c, l4c)\n\n        # branch2\n        # branch2_layer1\n        self.b2_1_1 = nn.Sequential(\n            conv3x3(1, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b2_1_2_down = down_conv(l1c, l2c)\n        self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))\n        self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch2_layer2\n        self.b2_2_1 = DoubleBasicBlock(l2c, l2c)\n        self.b2_2_2_down = down_conv(l2c, l3c)\n        self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, 
nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))\n        self.b2_2_4_up = up_conv(l2c, l1c)\n        # branch2_layer3\n        self.b2_3_1 = DoubleBasicBlock(l3c, l3c)\n        self.b2_3_2_down = down_conv(l3c, l4c)\n        self.b2_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))\n        self.b2_3_4_up = up_conv(l3c, l2c)\n        # branch2_layer4\n        self.b2_4_1 = DoubleBasicBlock(l4c, l4c)\n        self.b2_4_2_down = down_conv(l4c, l5c)\n        self.b2_4_3 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))\n        self.b2_4_4_up = up_conv(l4c, l3c)\n        # branch2_layer5\n        self.b2_5_1 = DoubleBasicBlock(l5c, l5c)\n        self.b2_5_2_up = up_conv(l5c, l5c)\n        self.b2_5_2_same = same_conv(l5c, l5c)\n        self.b2_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)\n        self.b2_5_4 = DoubleBasicBlock(l5c, l5c)\n        self.b2_5_5_up = up_conv(l5c, l4c)\n\n        # initialization\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABNSync):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABN):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    
nn.init.constant_(m.bias, 0)\n\n    def forward(self, input1, input2):\n        # code\n        # branch1\n        x1_1 = self.b1_1_1(input1)\n\n        x1_2 = self.b1_1_2_down(x1_1)\n        x1_2 = self.b1_2_1(x1_2)\n\n        x1_3 = self.b1_2_2_down(x1_2)\n        x1_3 = self.b1_3_1(x1_3)\n\n        x1_4_1 = self.b1_3_2_down(x1_3)\n        x1_4_1 = self.b1_4_1(x1_4_1)\n        x1_4_2 = self.b1_4_2(x1_4_1)\n        x1_4_3_down = self.b1_4_3_down(x1_4_2)\n        x1_4_3_same = self.b1_4_3_same(x1_4_2)\n\n        x1_5_1 = self.b1_4_2_down(x1_4_1)\n        x1_5_1 = self.b1_5_1(x1_5_1)\n        x1_5_2_same = self.b1_5_2_same(x1_5_1)\n        # branch2\n        x2_1 = self.b2_1_1(input2)\n\n        x2_2 = self.b2_1_2_down(x2_1)\n        x2_2 = self.b2_2_1(x2_2)\n\n        x2_3 = self.b2_2_2_down(x2_2)\n        x2_3 = self.b2_3_1(x2_3)\n\n        x2_4 = self.b2_3_2_down(x2_3)\n        x2_4 = self.b2_4_1(x2_4)\n\n        x2_5_1 = self.b2_4_2_down(x2_4)\n        x2_5_1 = self.b2_5_1(x2_5_1)\n        x2_5_2_up = self.b2_5_2_up(x2_5_1)\n        x2_5_2_same = self.b2_5_2_same(x2_5_1)\n\n        # merge\n        # branch1\n        x1_5_3 = torch.cat((x1_5_2_same, x2_5_2_same), dim=1)\n        x1_5_3 = self.b1_5_3_transition(x1_5_3)\n        x1_5_3 = self.b1_5_4(x1_5_3)\n        x1_5_3 = self.b1_5_5_up(x1_5_3)\n\n        x1_4_4 = torch.cat((x1_4_3_same, x2_5_2_up), dim=1)\n        x1_4_4 = self.b1_4_4_transition(x1_4_4)\n        x1_4_4 = self.b1_4_5(x1_4_4)\n        x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1)\n        x1_4_4 = self.b1_4_6(x1_4_4)\n        x1_4_4 = self.b1_4_7_up(x1_4_4)\n        # branch2\n        x2_5_3 = torch.cat((x2_5_2_same, x1_5_2_same, x1_4_3_down), dim=1)\n        x2_5_3 = self.b2_5_3_transition(x2_5_3)\n        x2_5_3 = self.b2_5_4(x2_5_3)\n        x2_5_3 = self.b2_5_5_up(x2_5_3)\n\n        # decode\n        # branch1\n        x1_3 = torch.cat((x1_3, x1_4_4), dim=1)\n        x1_3 = self.b1_3_3(x1_3)\n        x1_3 = self.b1_3_4_up(x1_3)\n\n        x1_2 
= torch.cat((x1_2, x1_3), dim=1)\n        x1_2 = self.b1_2_3(x1_2)\n        x1_2 = self.b1_2_4_up(x1_2)\n\n        x1_1 = torch.cat((x1_1, x1_2), dim=1)\n        x1_1 = self.b1_1_3(x1_1)\n        x1_1 = self.b1_1_4(x1_1)\n        # branch2\n        x2_4 = torch.cat((x2_4, x2_5_3), dim=1)\n        x2_4 = self.b2_4_3(x2_4)\n        x2_4 = self.b2_4_4_up(x2_4)\n\n        x2_3 = torch.cat((x2_3, x2_4), dim=1)\n        x2_3 = self.b2_3_3(x2_3)\n        x2_3 = self.b2_3_4_up(x2_3)\n\n        x2_2 = torch.cat((x2_2, x2_3), dim=1)\n        x2_2 = self.b2_2_3(x2_2)\n        x2_2 = self.b2_2_4_up(x2_2)\n\n        x2_1 = torch.cat((x2_1, x2_2), dim=1)\n        x2_1 = self.b2_1_3(x2_1)\n        x2_1 = self.b2_1_4(x2_1)\n\n        return x1_1, x2_1\n\n\nclass XNet_2_3_m(nn.Module):\n    def __init__(self, in_channels, num_classes):\n        super(XNet_2_3_m, self).__init__()\n\n        l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024\n\n        # branch1\n        # branch1_layer1\n        self.b1_1_1 = nn.Sequential(\n            conv3x3(in_channels, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b1_1_2_down = down_conv(l1c, l2c)\n        self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))\n        self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch1_layer2\n        self.b1_2_1 = DoubleBasicBlock(l2c, l2c)\n        self.b1_2_2_down = down_conv(l2c, l3c)\n        self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))\n        self.b1_2_4_up = up_conv(l2c, l1c)\n        # branch1_layer3\n        self.b1_3_1 = DoubleBasicBlock(l3c, l3c)\n        self.b1_3_2_down = down_conv(l3c, l4c)\n        self.b1_3_3 = DoubleBasicBlock(l3c + l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c + l3c, out_planes=l3c), 
BatchNorm2d(l3c, momentum=BN_MOMENTUM)))\n        self.b1_3_4_up = up_conv(l3c, l2c)\n        # branch1_layer4\n        self.b1_4_1 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_2_down = down_conv(l4c, l5c)\n        self.b1_4_2 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_3_down = down_conv(l4c, l4c)\n        self.b1_4_3_same = same_conv(l4c, l4c)\n        self.b1_4_3_up = up_conv(l4c, l4c)\n        self.b1_4_4_transition = transition_conv(l4c+l5c+l4c+l3c, l4c)\n        self.b1_4_5 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))\n        self.b1_4_7_up = up_conv(l4c, l3c)\n        # branch1_layer5\n        self.b1_5_1 = DoubleBasicBlock(l5c, l5c)\n        self.b1_5_2_up = up_conv(l5c, l5c)\n        self.b1_5_2_up_up = up_conv(l5c, l5c)\n        self.b1_5_2_same = same_conv(l5c, l5c)\n        self.b1_5_3_transition = transition_conv(l5c+l5c+l4c+l3c, l5c)\n        self.b1_5_4 = DoubleBasicBlock(l5c, l5c)\n        self.b1_5_5_up = up_conv(l5c, l4c)\n\n        # branch2\n        # branch2_layer1\n        self.b2_1_1 = nn.Sequential(\n            conv3x3(1, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b2_1_2_down = down_conv(l1c, l2c)\n        self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))\n        self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch2_layer2\n        self.b2_2_1 = DoubleBasicBlock(l2c, l2c)\n        self.b2_2_2_down = down_conv(l2c, l3c)\n        self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))\n        self.b2_2_4_up = up_conv(l2c, l1c)\n        # branch2_layer3\n        self.b2_3_1 = DoubleBasicBlock(l3c, l3c)\n        
self.b2_3_2_down = down_conv(l3c, l4c)\n        self.b2_3_2 = DoubleBasicBlock(l3c, l3c)\n        self.b2_3_3_down = down_conv(l3c, l3c)\n        self.b2_3_3_down_down = down_conv(l3c, l3c)\n        self.b2_3_3_same = same_conv(l3c, l3c)\n        self.b2_3_4_transition = transition_conv(l3c+l5c+l4c, l3c)\n        self.b2_3_5 = DoubleBasicBlock(l3c, l3c)\n        self.b2_3_6 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))\n        self.b2_3_7_up = up_conv(l3c, l2c)\n        # branch2_layer4\n        self.b2_4_1 = DoubleBasicBlock(l4c, l4c)\n        self.b2_4_2_down = down_conv(l4c, l5c)\n        self.b2_4_2 = DoubleBasicBlock(l4c, l4c)\n        self.b2_4_3_down = down_conv(l4c, l4c)\n        self.b2_4_3_same = same_conv(l4c, l4c)\n        self.b2_4_4_transition = transition_conv(l4c+l5c+l4c, l4c)\n        self.b2_4_5 = DoubleBasicBlock(l4c, l4c)\n        self.b2_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))\n        self.b2_4_7_up = up_conv(l4c, l3c)\n        # branch2_layer5\n        self.b2_5_1 = DoubleBasicBlock(l5c, l5c)\n        self.b2_5_2_up = up_conv(l5c, l5c)\n        self.b2_5_2_same = same_conv(l5c, l5c)\n        self.b2_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)\n        self.b2_5_4 = DoubleBasicBlock(l5c, l5c)\n        self.b2_5_5_up = up_conv(l5c, l4c)\n\n        # initialization\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABNSync):\n            #     nn.init.constant_(m.weight, 1)\n     
       #     nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABN):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def forward(self, input1, input2):\n        # code\n        # branch1\n        x1_1 = self.b1_1_1(input1)\n\n        x1_2 = self.b1_1_2_down(x1_1)\n        x1_2 = self.b1_2_1(x1_2)\n\n        x1_3 = self.b1_2_2_down(x1_2)\n        x1_3 = self.b1_3_1(x1_3)\n\n        x1_4_1 = self.b1_3_2_down(x1_3)\n        x1_4_1 = self.b1_4_1(x1_4_1)\n        x1_4_2 = self.b1_4_2(x1_4_1)\n        x1_4_3_down = self.b1_4_3_down(x1_4_2)\n        x1_4_3_same = self.b1_4_3_same(x1_4_2)\n        x1_4_3_up = self.b1_4_3_up(x1_4_2)\n\n        x1_5_1 = self.b1_4_2_down(x1_4_1)\n        x1_5_1 = self.b1_5_1(x1_5_1)\n        x1_5_2_up = self.b1_5_2_up(x1_5_1)\n        x1_5_2_up_up = self.b1_5_2_up_up(x1_5_2_up)\n        x1_5_2_same = self.b1_5_2_same(x1_5_1)\n\n        # branch2\n        x2_1 = self.b2_1_1(input2)\n\n        x2_2 = self.b2_1_2_down(x2_1)\n        x2_2 = self.b2_2_1(x2_2)\n\n        x2_3_1 = self.b2_2_2_down(x2_2)\n        x2_3_1 = self.b2_3_1(x2_3_1)\n        x2_3_2 = self.b2_3_2(x2_3_1)\n        x2_3_3_down = self.b2_3_3_down(x2_3_2)\n        x2_3_3_down_down = self.b2_3_3_down_down(x2_3_3_down)\n        x2_3_3_same = self.b2_3_3_same(x2_3_2)\n\n        x2_4_1 = self.b2_3_2_down(x2_3_1)\n        x2_4_1 = self.b2_4_1(x2_4_1)\n        x2_4_2 = self.b2_4_2(x2_4_1)\n        x2_4_3_down = self.b2_4_3_down(x2_4_2)\n        x2_4_3_same = self.b2_4_3_same(x2_4_2)\n\n        x2_5_1 = self.b2_4_2_down(x2_4_1)\n        x2_5_1 = self.b2_5_1(x2_5_1)\n        x2_5_2_up = self.b2_5_2_up(x2_5_1)\n        x2_5_2_same = self.b2_5_2_same(x2_5_1)\n\n        # merge\n        # branch1\n        x1_5_3 = 
torch.cat((x1_5_2_same, x2_3_3_down_down, x2_4_3_down, x2_5_2_same), dim=1)\n        x1_5_3 = self.b1_5_3_transition(x1_5_3)\n        x1_5_3 = self.b1_5_4(x1_5_3)\n        x1_5_3 = self.b1_5_5_up(x1_5_3)\n\n        x1_4_4 = torch.cat((x1_4_3_same, x2_3_3_down, x2_4_3_same, x2_5_2_up), dim=1)\n        x1_4_4 = self.b1_4_4_transition(x1_4_4)\n        x1_4_4 = self.b1_4_5(x1_4_4)\n        x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1)\n        x1_4_4 = self.b1_4_6(x1_4_4)\n        x1_4_4 = self.b1_4_7_up(x1_4_4)\n\n        # branch2\n        x2_5_3 = torch.cat((x2_5_2_same, x1_4_3_down, x1_5_2_same), dim=1)\n        x2_5_3 = self.b2_5_3_transition(x2_5_3)\n        x2_5_3 = self.b2_5_4(x2_5_3)\n        x2_5_3 = self.b2_5_5_up(x2_5_3)\n\n        x2_4_4 = torch.cat((x2_4_3_same, x1_4_3_same, x1_5_2_up), dim=1)\n        x2_4_4 = self.b2_4_4_transition(x2_4_4)\n        x2_4_4 = self.b2_4_5(x2_4_4)\n        x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1)\n        x2_4_4 = self.b2_4_6(x2_4_4)\n        x2_4_4 = self.b2_4_7_up(x2_4_4)\n\n        x2_3_4 = torch.cat((x2_3_3_same, x1_4_3_up, x1_5_2_up_up), dim=1)\n        x2_3_4 = self.b2_3_4_transition(x2_3_4)\n        x2_3_4 = self.b2_3_5(x2_3_4)\n        x2_3_4 = torch.cat((x2_3_4, x2_4_4), dim=1)\n        x2_3_4 = self.b2_3_6(x2_3_4)\n        x2_3_4 = self.b2_3_7_up(x2_3_4)\n\n        # decode\n        # branch1\n        x1_3 = torch.cat((x1_3, x1_4_4), dim=1)\n        x1_3 = self.b1_3_3(x1_3)\n        x1_3 = self.b1_3_4_up(x1_3)\n\n        x1_2 = torch.cat((x1_2, x1_3), dim=1)\n        x1_2 = self.b1_2_3(x1_2)\n        x1_2 = self.b1_2_4_up(x1_2)\n\n        x1_1 = torch.cat((x1_1, x1_2), dim=1)\n        x1_1 = self.b1_1_3(x1_1)\n        x1_1 = self.b1_1_4(x1_1)\n        # branch2\n        x2_2 = torch.cat((x2_2, x2_3_4), dim=1)\n        x2_2 = self.b2_2_3(x2_2)\n        x2_2 = self.b2_2_4_up(x2_2)\n\n        x2_1 = torch.cat((x2_1, x2_2), dim=1)\n        x2_1 = self.b2_1_3(x2_1)\n        x2_1 = self.b2_1_4(x2_1)\n\n        return 
x1_1, x2_1\n\n\nclass XNet_3_2_m(nn.Module):\n    def __init__(self, in_channels, num_classes):\n        super(XNet_3_2_m, self).__init__()\n\n        l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024\n\n        # branch1\n        # branch1_layer1\n        self.b1_1_1 = nn.Sequential(\n            conv3x3(in_channels, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b1_1_2_down = down_conv(l1c, l2c)\n        self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))\n        self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch1_layer2\n        self.b1_2_1 = DoubleBasicBlock(l2c, l2c)\n        self.b1_2_2_down = down_conv(l2c, l3c)\n        self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))\n        self.b1_2_4_up = up_conv(l2c, l1c)\n        # branch1_layer3\n        self.b1_3_1 = DoubleBasicBlock(l3c, l3c)\n        self.b1_3_2_down = down_conv(l3c, l4c)\n        self.b1_3_2 = DoubleBasicBlock(l3c, l3c)\n        self.b1_3_3_down = down_conv(l3c, l3c)\n        self.b1_3_3_down_down = down_conv(l3c, l3c)\n        self.b1_3_3_same = same_conv(l3c, l3c)\n        self.b1_3_4_transition = transition_conv(l3c+l5c+l4c, l3c)\n        self.b1_3_5 = DoubleBasicBlock(l3c, l3c)\n        self.b1_3_6 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))\n        self.b1_3_7_up = up_conv(l3c, l2c)\n        # branch1_layer4\n        self.b1_4_1 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_2_down = down_conv(l4c, l5c)\n        self.b1_4_2 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_3_down = down_conv(l4c, l4c)\n        self.b1_4_3_same = same_conv(l4c, l4c)\n        self.b1_4_4_transition = transition_conv(l4c+l5c+l4c, l4c)\n        
self.b1_4_5 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))\n        self.b1_4_7_up = up_conv(l4c, l3c)\n        # branch1_layer5\n        self.b1_5_1 = DoubleBasicBlock(l5c, l5c)\n        self.b1_5_2_up = up_conv(l5c, l5c)\n        self.b1_5_2_same = same_conv(l5c, l5c)\n        self.b1_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)\n        self.b1_5_4 = DoubleBasicBlock(l5c, l5c)\n        self.b1_5_5_up = up_conv(l5c, l4c)\n\n        # branch2\n        # branch2_layer1\n        self.b2_1_1 = nn.Sequential(\n            conv3x3(1, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b2_1_2_down = down_conv(l1c, l2c)\n        self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))\n        self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch2_layer2\n        self.b2_2_1 = DoubleBasicBlock(l2c, l2c)\n        self.b2_2_2_down = down_conv(l2c, l3c)\n        self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))\n        self.b2_2_4_up = up_conv(l2c, l1c)\n        # branch2_layer3\n        self.b2_3_1 = DoubleBasicBlock(l3c, l3c)\n        self.b2_3_2_down = down_conv(l3c, l4c)\n        self.b2_3_3 = DoubleBasicBlock(l3c + l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c + l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))\n        self.b2_3_4_up = up_conv(l3c, l2c)\n        # branch2_layer4\n        self.b2_4_1 = DoubleBasicBlock(l4c, l4c)\n        self.b2_4_2_down = down_conv(l4c, l5c)\n        self.b2_4_2 = DoubleBasicBlock(l4c, l4c)\n        self.b2_4_3_down = down_conv(l4c, l4c)\n        self.b2_4_3_same = same_conv(l4c, l4c)\n        self.b2_4_3_up = 
up_conv(l4c, l4c)\n        self.b2_4_4_transition = transition_conv(l4c+l5c+l4c+l3c, l4c)\n        self.b2_4_5 = DoubleBasicBlock(l4c, l4c)\n        self.b2_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))\n        self.b2_4_7_up = up_conv(l4c, l3c)\n        # branch2_layer5\n        self.b2_5_1 = DoubleBasicBlock(l5c, l5c)\n        self.b2_5_2_up = up_conv(l5c, l5c)\n        self.b2_5_2_up_up = up_conv(l5c, l5c)\n        self.b2_5_2_same = same_conv(l5c, l5c)\n        self.b2_5_3_transition = transition_conv(l5c+l5c+l4c+l3c, l5c)\n        self.b2_5_4 = DoubleBasicBlock(l5c, l5c)\n        self.b2_5_5_up = up_conv(l5c, l4c)\n\n        # initialization\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABNSync):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABN):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def forward(self, input1, input2):\n        # code\n        # branch1\n        x1_1 = self.b1_1_1(input1)\n\n        x1_2 = self.b1_1_2_down(x1_1)\n        x1_2 = self.b1_2_1(x1_2)\n\n        x1_3_1 = self.b1_2_2_down(x1_2)\n        x1_3_1 = self.b1_3_1(x1_3_1)\n        x1_3_2 = self.b1_3_2(x1_3_1)\n        x1_3_3_down = self.b1_3_3_down(x1_3_2)\n        
x1_3_3_down_down = self.b1_3_3_down_down(x1_3_3_down)\n        x1_3_3_same = self.b1_3_3_same(x1_3_2)\n\n        x1_4_1 = self.b1_3_2_down(x1_3_1)\n        x1_4_1 = self.b1_4_1(x1_4_1)\n        x1_4_2 = self.b1_4_2(x1_4_1)\n        x1_4_3_down = self.b1_4_3_down(x1_4_2)\n        x1_4_3_same = self.b1_4_3_same(x1_4_2)\n\n        x1_5_1 = self.b1_4_2_down(x1_4_1)\n        x1_5_1 = self.b1_5_1(x1_5_1)\n        x1_5_2_up = self.b1_5_2_up(x1_5_1)\n        x1_5_2_same = self.b1_5_2_same(x1_5_1)\n\n        # branch2\n        x2_1 = self.b2_1_1(input2)\n\n        x2_2 = self.b2_1_2_down(x2_1)\n        x2_2 = self.b2_2_1(x2_2)\n\n        x2_3 = self.b2_2_2_down(x2_2)\n        x2_3 = self.b2_3_1(x2_3)\n\n        x2_4_1 = self.b2_3_2_down(x2_3)\n        x2_4_1 = self.b2_4_1(x2_4_1)\n        x2_4_2 = self.b2_4_2(x2_4_1)\n        x2_4_3_down = self.b2_4_3_down(x2_4_2)\n        x2_4_3_same = self.b2_4_3_same(x2_4_2)\n        x2_4_3_up = self.b2_4_3_up(x2_4_2)\n\n        x2_5_1 = self.b2_4_2_down(x2_4_1)\n        x2_5_1 = self.b2_5_1(x2_5_1)\n        x2_5_2_up = self.b2_5_2_up(x2_5_1)\n        x2_5_2_up_up = self.b2_5_2_up_up(x2_5_2_up)\n        x2_5_2_same = self.b2_5_2_same(x2_5_1)\n\n        # merge\n        # branch1\n        x1_5_3 = torch.cat((x1_5_2_same, x2_4_3_down, x2_5_2_same), dim=1)\n        x1_5_3 = self.b1_5_3_transition(x1_5_3)\n        x1_5_3 = self.b1_5_4(x1_5_3)\n        x1_5_3 = self.b1_5_5_up(x1_5_3)\n\n        x1_4_4 = torch.cat((x1_4_3_same, x2_4_3_same, x2_5_2_up), dim=1)\n        x1_4_4 = self.b1_4_4_transition(x1_4_4)\n        x1_4_4 = self.b1_4_5(x1_4_4)\n        x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1)\n        x1_4_4 = self.b1_4_6(x1_4_4)\n        x1_4_4 = self.b1_4_7_up(x1_4_4)\n\n        x1_3_4 = torch.cat((x1_3_3_same, x2_4_3_up, x2_5_2_up_up), dim=1)\n        x1_3_4 = self.b1_3_4_transition(x1_3_4)\n        x1_3_4 = self.b1_3_5(x1_3_4)\n        x1_3_4 = torch.cat((x1_3_4, x1_4_4), dim=1)\n        x1_3_4 = self.b1_3_6(x1_3_4)\n        x1_3_4 = 
self.b1_3_7_up(x1_3_4)\n\n        # branch2\n        x2_5_3 = torch.cat((x2_5_2_same, x1_3_3_down_down, x1_4_3_down, x1_5_2_same), dim=1)\n        x2_5_3 = self.b2_5_3_transition(x2_5_3)\n        x2_5_3 = self.b2_5_4(x2_5_3)\n        x2_5_3 = self.b2_5_5_up(x2_5_3)\n\n        x2_4_4 = torch.cat((x2_4_3_same, x1_3_3_down, x1_4_3_same, x1_5_2_up), dim=1)\n        x2_4_4 = self.b2_4_4_transition(x2_4_4)\n        x2_4_4 = self.b2_4_5(x2_4_4)\n        x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1)\n        x2_4_4 = self.b2_4_6(x2_4_4)\n        x2_4_4 = self.b2_4_7_up(x2_4_4)\n\n        # decode\n        # branch1\n        x1_2 = torch.cat((x1_2, x1_3_4), dim=1)\n        x1_2 = self.b1_2_3(x1_2)\n        x1_2 = self.b1_2_4_up(x1_2)\n\n        x1_1 = torch.cat((x1_1, x1_2), dim=1)\n        x1_1 = self.b1_1_3(x1_1)\n        x1_1 = self.b1_1_4(x1_1)\n        # branch2\n        x2_3 = torch.cat((x2_3, x2_4_4), dim=1)\n        x2_3 = self.b2_3_3(x2_3)\n        x2_3 = self.b2_3_4_up(x2_3)\n\n        x2_2 = torch.cat((x2_2, x2_3), dim=1)\n        x2_2 = self.b2_2_3(x2_2)\n        x2_2 = self.b2_2_4_up(x2_2)\n\n        x2_1 = torch.cat((x2_1, x2_2), dim=1)\n        x2_1 = self.b2_1_3(x2_1)\n        x2_1 = self.b2_1_4(x2_1)\n\n        return x1_1, x2_1\n\n\n\nclass XNet_3_3_m(nn.Module):\n    def __init__(self, in_channels, num_classes):\n        super(XNet_3_3_m, self).__init__()\n\n        l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024\n\n        # branch1\n        # branch1_layer1\n        self.b1_1_1 = nn.Sequential(\n            conv3x3(in_channels, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b1_1_2_down = down_conv(l1c, l2c)\n        self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))\n        self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch1_layer2\n        self.b1_2_1 = 
DoubleBasicBlock(l2c, l2c)\n        self.b1_2_2_down = down_conv(l2c, l3c)\n        self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))\n        self.b1_2_4_up = up_conv(l2c, l1c)\n        # branch1_layer3\n        self.b1_3_1 = DoubleBasicBlock(l3c, l3c)\n        self.b1_3_2_down = down_conv(l3c, l4c)\n        self.b1_3_2 = DoubleBasicBlock(l3c, l3c)\n        self.b1_3_3_down = down_conv(l3c, l3c)\n        self.b1_3_3_down_down = down_conv(l3c, l3c)\n        self.b1_3_3_same = same_conv(l3c, l3c)\n        self.b1_3_4_transition = transition_conv(l3c+l5c+l4c+l3c, l3c)\n        self.b1_3_5 = DoubleBasicBlock(l3c, l3c)\n        self.b1_3_6 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))\n        self.b1_3_7_up = up_conv(l3c, l2c)\n        # branch1_layer4\n        self.b1_4_1 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_2_down = down_conv(l4c, l5c)\n        self.b1_4_2 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_3_down = down_conv(l4c, l4c)\n        self.b1_4_3_same = same_conv(l4c, l4c)\n        self.b1_4_3_up = up_conv(l4c, l4c)\n        self.b1_4_4_transition = transition_conv(l4c+l5c+l4c+l3c, l4c)\n        self.b1_4_5 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))\n        self.b1_4_7_up = up_conv(l4c, l3c)\n        # branch1_layer5\n        self.b1_5_1 = DoubleBasicBlock(l5c, l5c)\n        self.b1_5_2_up = up_conv(l5c, l5c)\n        self.b1_5_2_up_up = up_conv(l5c, l5c)\n        self.b1_5_2_same = same_conv(l5c, l5c)\n        self.b1_5_3_transition = transition_conv(l5c+l5c+l4c+l3c, l5c)\n        self.b1_5_4 = DoubleBasicBlock(l5c, l5c)\n        self.b1_5_5_up = up_conv(l5c, l4c)\n\n        # branch2\n        # branch2_layer1\n        self.b2_1_1 = 
nn.Sequential(\n            conv3x3(1, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b2_1_2_down = down_conv(l1c, l2c)\n        self.b2_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))\n        self.b2_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch2_layer2\n        self.b2_2_1 = DoubleBasicBlock(l2c, l2c)\n        self.b2_2_2_down = down_conv(l2c, l3c)\n        self.b2_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))\n        self.b2_2_4_up = up_conv(l2c, l1c)\n        # branch2_layer3\n        self.b2_3_1 = DoubleBasicBlock(l3c, l3c)\n        self.b2_3_2_down = down_conv(l3c, l4c)\n        self.b2_3_2 = DoubleBasicBlock(l3c, l3c)\n        self.b2_3_3_down = down_conv(l3c, l3c)\n        self.b2_3_3_down_down = down_conv(l3c, l3c)\n        self.b2_3_3_same = same_conv(l3c, l3c)\n        self.b2_3_4_transition = transition_conv(l3c+l5c+l4c+l3c, l3c)\n        self.b2_3_5 = DoubleBasicBlock(l3c, l3c)\n        self.b2_3_6 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))\n        self.b2_3_7_up = up_conv(l3c, l2c)\n        # branch2_layer4\n        self.b2_4_1 = DoubleBasicBlock(l4c, l4c)\n        self.b2_4_2_down = down_conv(l4c, l5c)\n        self.b2_4_2 = DoubleBasicBlock(l4c, l4c)\n        self.b2_4_3_down = down_conv(l4c, l4c)\n        self.b2_4_3_same = same_conv(l4c, l4c)\n        self.b2_4_3_up = up_conv(l4c, l4c)\n        self.b2_4_4_transition = transition_conv(l4c+l5c+l4c+l3c, l4c)\n        self.b2_4_5 = DoubleBasicBlock(l4c, l4c)\n        self.b2_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))\n        self.b2_4_7_up = 
up_conv(l4c, l3c)\n        # branch2_layer5\n        self.b2_5_1 = DoubleBasicBlock(l5c, l5c)\n        self.b2_5_2_up = up_conv(l5c, l5c)\n        self.b2_5_2_up_up = up_conv(l5c, l5c)\n        self.b2_5_2_same = same_conv(l5c, l5c)\n        self.b2_5_3_transition = transition_conv(l5c+l5c+l4c+l3c, l5c)\n        self.b2_5_4 = DoubleBasicBlock(l5c, l5c)\n        self.b2_5_5_up = up_conv(l5c, l4c)\n\n        # initialization\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABNSync):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABN):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def forward(self, input1, input2):\n        # code\n        # branch1\n        x1_1 = self.b1_1_1(input1)\n\n        x1_2 = self.b1_1_2_down(x1_1)\n        x1_2 = self.b1_2_1(x1_2)\n\n        x1_3_1 = self.b1_2_2_down(x1_2)\n        x1_3_1 = self.b1_3_1(x1_3_1)\n        x1_3_2 = self.b1_3_2(x1_3_1)\n        x1_3_3_down = self.b1_3_3_down(x1_3_2)\n        x1_3_3_down_down = self.b1_3_3_down_down(x1_3_3_down)\n        x1_3_3_same = self.b1_3_3_same(x1_3_2)\n\n        x1_4_1 = self.b1_3_2_down(x1_3_1)\n        x1_4_1 = self.b1_4_1(x1_4_1)\n        x1_4_2 = self.b1_4_2(x1_4_1)\n        x1_4_3_down = self.b1_4_3_down(x1_4_2)\n        x1_4_3_same = self.b1_4_3_same(x1_4_2)\n        
x1_4_3_up = self.b1_4_3_up(x1_4_2)\n\n        x1_5_1 = self.b1_4_2_down(x1_4_1)\n        x1_5_1 = self.b1_5_1(x1_5_1)\n        x1_5_2_up = self.b1_5_2_up(x1_5_1)\n        x1_5_2_up_up = self.b1_5_2_up_up(x1_5_2_up)\n        x1_5_2_same = self.b1_5_2_same(x1_5_1)\n\n        # branch2\n        x2_1 = self.b2_1_1(input2)\n\n        x2_2 = self.b2_1_2_down(x2_1)\n        x2_2 = self.b2_2_1(x2_2)\n\n        x2_3_1 = self.b2_2_2_down(x2_2)\n        x2_3_1 = self.b2_3_1(x2_3_1)\n        x2_3_2 = self.b2_3_2(x2_3_1)\n        x2_3_3_down = self.b2_3_3_down(x2_3_2)\n        x2_3_3_down_down = self.b2_3_3_down_down(x2_3_3_down)\n        x2_3_3_same = self.b2_3_3_same(x2_3_2)\n\n        x2_4_1 = self.b2_3_2_down(x2_3_1)\n        x2_4_1 = self.b2_4_1(x2_4_1)\n        x2_4_2 = self.b2_4_2(x2_4_1)\n        x2_4_3_down = self.b2_4_3_down(x2_4_2)\n        x2_4_3_same = self.b2_4_3_same(x2_4_2)\n        x2_4_3_up = self.b2_4_3_up(x2_4_2)\n\n        x2_5_1 = self.b2_4_2_down(x2_4_1)\n        x2_5_1 = self.b2_5_1(x2_5_1)\n        x2_5_2_up = self.b2_5_2_up(x2_5_1)\n        x2_5_2_up_up = self.b2_5_2_up_up(x2_5_2_up)\n        x2_5_2_same = self.b2_5_2_same(x2_5_1)\n\n        # merge\n        # branch1\n        x1_5_3 = torch.cat((x1_5_2_same, x2_3_3_down_down, x2_4_3_down, x2_5_2_same), dim=1)\n        x1_5_3 = self.b1_5_3_transition(x1_5_3)\n        x1_5_3 = self.b1_5_4(x1_5_3)\n        x1_5_3 = self.b1_5_5_up(x1_5_3)\n\n        x1_4_4 = torch.cat((x1_4_3_same, x2_3_3_down, x2_4_3_same, x2_5_2_up), dim=1)\n        x1_4_4 = self.b1_4_4_transition(x1_4_4)\n        x1_4_4 = self.b1_4_5(x1_4_4)\n        x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1)\n        x1_4_4 = self.b1_4_6(x1_4_4)\n        x1_4_4 = self.b1_4_7_up(x1_4_4)\n\n        x1_3_4 = torch.cat((x1_3_3_same, x2_3_3_same, x2_4_3_up, x2_5_2_up_up), dim=1)\n        x1_3_4 = self.b1_3_4_transition(x1_3_4)\n        x1_3_4 = self.b1_3_5(x1_3_4)\n        x1_3_4 = torch.cat((x1_3_4, x1_4_4), dim=1)\n        x1_3_4 = self.b1_3_6(x1_3_4)\n 
       x1_3_4 = self.b1_3_7_up(x1_3_4)\n\n        # branch2\n        x2_5_3 = torch.cat((x2_5_2_same, x1_3_3_down_down, x1_4_3_down, x1_5_2_same), dim=1)\n        x2_5_3 = self.b2_5_3_transition(x2_5_3)\n        x2_5_3 = self.b2_5_4(x2_5_3)\n        x2_5_3 = self.b2_5_5_up(x2_5_3)\n\n        x2_4_4 = torch.cat((x2_4_3_same, x1_3_3_down, x1_4_3_same, x1_5_2_up), dim=1)\n        x2_4_4 = self.b2_4_4_transition(x2_4_4)\n        x2_4_4 = self.b2_4_5(x2_4_4)\n        x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1)\n        x2_4_4 = self.b2_4_6(x2_4_4)\n        x2_4_4 = self.b2_4_7_up(x2_4_4)\n\n        x2_3_4 = torch.cat((x2_3_3_same, x1_3_3_same, x1_4_3_up, x1_5_2_up_up), dim=1)\n        x2_3_4 = self.b2_3_4_transition(x2_3_4)\n        x2_3_4 = self.b2_3_5(x2_3_4)\n        x2_3_4 = torch.cat((x2_3_4, x2_4_4), dim=1)\n        x2_3_4 = self.b2_3_6(x2_3_4)\n        x2_3_4 = self.b2_3_7_up(x2_3_4)\n\n        # decode\n        # branch1\n        x1_2 = torch.cat((x1_2, x1_3_4), dim=1)\n        x1_2 = self.b1_2_3(x1_2)\n        x1_2 = self.b1_2_4_up(x1_2)\n\n        x1_1 = torch.cat((x1_1, x1_2), dim=1)\n        x1_1 = self.b1_1_3(x1_1)\n        x1_1 = self.b1_1_4(x1_1)\n        # branch2\n        x2_2 = torch.cat((x2_2, x2_3_4), dim=1)\n        x2_2 = self.b2_2_3(x2_2)\n        x2_2 = self.b2_2_4_up(x2_2)\n\n        x2_1 = torch.cat((x2_1, x2_2), dim=1)\n        x2_1 = self.b2_1_3(x2_1)\n        x2_1 = self.b2_1_4(x2_1)\n\n        return x1_1, x2_1\n\nclass XNet_sb(nn.Module):\n    def __init__(self, in_channels, num_classes):\n        super(XNet_sb, self).__init__()\n\n        l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024\n\n        # branch1\n        # branch1_layer1\n        self.b1_1_1 = nn.Sequential(\n            conv3x3(in_channels, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b1_1_2_down = down_conv(l1c, l2c)\n        self.b1_1_3 = DoubleBasicBlock(l1c+l1c, l1c, nn.Sequential(conv1x1(in_planes=l1c+l1c, 
out_planes=l1c), BatchNorm2d(l1c, momentum=BN_MOMENTUM)))\n        self.b1_1_4 = nn.Conv2d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch1_layer2\n        self.b1_2_1 = DoubleBasicBlock(l2c, l2c)\n        self.b1_2_2_down = down_conv(l2c, l3c)\n        self.b1_2_3 = DoubleBasicBlock(l2c+l2c, l2c, nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm2d(l2c, momentum=BN_MOMENTUM)))\n        self.b1_2_4_up = up_conv(l2c, l1c)\n        # branch1_layer3\n        self.b1_3_1 = DoubleBasicBlock(l3c, l3c)\n        self.b1_3_2_down = down_conv(l3c, l4c)\n        self.b1_3_3 = DoubleBasicBlock(l3c+l3c, l3c, nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm2d(l3c, momentum=BN_MOMENTUM)))\n        self.b1_3_4_up = up_conv(l3c, l2c)\n        # branch1_layer4\n        self.b1_4_1 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_2_down = down_conv(l4c, l5c)\n        self.b1_4_2 = DoubleBasicBlock(l4c, l4c)\n        # self.b1_4_3_down = down_conv(l4c, l4c)\n        # self.b1_4_3_same = same_conv(l4c, l4c)\n        # self.b1_4_4_transition = transition_conv(l4c, l4c)\n        self.b1_4_5 = DoubleBasicBlock(l4c, l4c)\n        self.b1_4_6 = DoubleBasicBlock(l4c+l4c, l4c, nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm2d(l4c, momentum=BN_MOMENTUM)))\n        self.b1_4_7_up = up_conv(l4c, l3c)\n        # branch1_layer5\n        self.b1_5_1 = DoubleBasicBlock(l5c, l5c)\n        # self.b1_5_2_up = up_conv(l5c, l5c)\n        # self.b1_5_2_same = same_conv(l5c, l5c)\n        # self.b1_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)\n        self.b1_5_4 = DoubleBasicBlock(l5c, l5c)\n        self.b1_5_5_up = up_conv(l5c, l4c)\n\n        # initialization\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n           
 elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABNSync):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABN):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def forward(self, input1):\n        # code\n        # branch1\n        x1_1 = self.b1_1_1(input1)\n\n        x1_2 = self.b1_1_2_down(x1_1)\n        x1_2 = self.b1_2_1(x1_2)\n\n        x1_3 = self.b1_2_2_down(x1_2)\n        x1_3 = self.b1_3_1(x1_3)\n\n        x1_4_1 = self.b1_3_2_down(x1_3)\n        x1_4_1 = self.b1_4_1(x1_4_1)\n        x1_4_2 = self.b1_4_2(x1_4_1)\n        x1_4_2 = self.b1_4_5(x1_4_2)\n        # x1_4_3_down = self.b1_4_3_down(x1_4_2)\n        # x1_4_3_same = self.b1_4_3_same(x1_4_2)\n\n        x1_5_1 = self.b1_4_2_down(x1_4_1)\n        x1_5_1 = self.b1_5_1(x1_5_1)\n        x1_5_1 = self.b1_5_4(x1_5_1)\n        x1_5_1 = self.b1_5_5_up(x1_5_1)\n\n        # x1_5_2_up = self.b1_5_2_up(x1_5_1)\n        # x1_5_2_same = self.b1_5_2_same(x1_5_1)\n\n        # decode\n        # branch1\n        x1_4_2 = torch.cat((x1_4_2, x1_5_1), dim=1)\n        x1_4_2 = self.b1_4_6(x1_4_2)\n        x1_4_2 = self.b1_4_7_up(x1_4_2)\n\n        x1_3 = torch.cat((x1_3, x1_4_2), dim=1)\n        x1_3 = self.b1_3_3(x1_3)\n        x1_3 = self.b1_3_4_up(x1_3)\n\n        x1_2 = torch.cat((x1_2, x1_3), dim=1)\n        x1_2 = self.b1_2_3(x1_2)\n        x1_2 = self.b1_2_4_up(x1_2)\n\n        x1_1 = torch.cat((x1_1, x1_2), dim=1)\n        x1_1 = self.b1_1_3(x1_1)\n        x1_1 = self.b1_1_4(x1_1)\n\n        return x1_1\n\n\n# if __name__ == '__main__':\n#     model = XNet(1, 
10)\n    # total = sum([param.nelement() for param in model.parameters()])\n    # from thop import profile, clever_format\n    #\n    # input = torch.randn(1, 1, 128, 128)\n    # flops, params = profile(model, inputs=(input, input, ))\n    # macs, params = clever_format([flops, params], \"%.3f\")\n    # print(macs)\n    # print(params)\n    # print(total)\n    # model.eval()\n    # input1 = torch.rand(2,3,256,256)\n    # input2 = torch.rand(2,1,256,256)\n    # x1_1, x2_1 = model(input1, input2)\n    # output1 = x1_1.data.cpu().numpy()\n    # output2 = x2_1.data.cpu().numpy()\n    # # print(output)\n    # print(output1.shape)\n    # print(output2.shape)\n"
  },
  {
    "path": "models/networks_3d/__init__.py",
    "content": ""
  },
  {
    "path": "models/networks_3d/conresnet.py",
    "content": "import torch.nn as nn\nfrom torch.nn import functional as F\nimport torch\nimport numpy as np\nfrom torch.nn import init\n# from loss.loss_function import segmentation_loss\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\n\nclass Conv3d(nn.Conv3d):\n\n    def __init__(self, in_channels, out_channels, kernel_size, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), groups=1, bias=False):\n        super(Conv3d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)\n\n    def forward(self, x):\n        weight = self.weight\n        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=4, keepdim=True)\n        weight = weight - weight_mean\n        std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1, 1)\n        weight = weight / std.expand_as(weight)\n        return F.conv3d(x, 
weight, self.bias, self.stride, self.padding, self.dilation, self.groups)\n\ndef conv3x3x3(in_planes, out_planes, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1,1,1), dilation=(1,1,1), bias=False,\n              weight_std=False):\n    \"3x3x3 convolution with padding\"\n    if weight_std:\n        return Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,\n                      bias=bias)\n    else:\n        return nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,\n                         dilation=dilation, bias=bias)\n\n\nclass ConResAtt(nn.Module):\n    def __init__(self, in_channels, in_planes, out_planes, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1),\n                 dilation=(1, 1, 1), bias=False, weight_std=False, first_layer=False):\n        super(ConResAtt, self).__init__()\n        self.weight_std = weight_std\n        self.stride = stride\n        self.in_planes = in_planes\n        self.out_planes = out_planes\n        self.first_layer = first_layer\n\n        self.relu = nn.ReLU(inplace=True)\n\n        self.gn_seg = nn.GroupNorm(8, in_planes)\n        self.conv_seg = conv3x3x3(in_planes, out_planes, kernel_size=(kernel_size[0], kernel_size[1], kernel_size[2]),\n                               stride=(stride[0], stride[1], stride[2]), padding=(padding[0], padding[1], padding[2]),\n                               dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)\n\n        self.gn_res = nn.GroupNorm(8, out_planes)\n        self.conv_res = conv3x3x3(out_planes, out_planes, kernel_size=(1,1,1),\n                               stride=(1, 1, 1), padding=(0,0,0),\n                               dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)\n\n        self.gn_res1 = nn.GroupNorm(8, out_planes)\n        self.conv_res1 = conv3x3x3(out_planes, out_planes, 
kernel_size=(kernel_size[0], kernel_size[1], kernel_size[2]),\n                                stride=(1, 1, 1), padding=(padding[0], padding[1], padding[2]),\n                                dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)\n        self.gn_res2 = nn.GroupNorm(8, out_planes)\n        self.conv_res2 = conv3x3x3(out_planes, out_planes, kernel_size=(kernel_size[0], kernel_size[1], kernel_size[2]),\n                                stride=(1, 1, 1), padding=(padding[0], padding[1], padding[2]),\n                                dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)\n\n        self.gn_mp = nn.GroupNorm(8, in_planes)\n        self.conv_mp_first = conv3x3x3(in_channels, out_planes, kernel_size=(kernel_size[0], kernel_size[1], kernel_size[2]),\n                              stride=(stride[0], stride[1], stride[2]), padding=(padding[0], padding[1], padding[2]),\n                              dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)\n        self.conv_mp = conv3x3x3(in_planes, out_planes, kernel_size=(kernel_size[0], kernel_size[1], kernel_size[2]),\n                               stride=(stride[0], stride[1], stride[2]), padding=(padding[0], padding[1], padding[2]),\n                               dilation=(dilation[0], dilation[1], dilation[2]), bias=bias, weight_std=self.weight_std)\n\n    def _res(self, x):  # bs, channel, D, W, H\n\n        bs, channel, depth, heigt, width = x.shape\n        # x_copy = torch.zeros_like(x).cuda()\n        x_copy = torch.zeros_like(x)\n        x_copy[:, :, 1:, :, :] = x[:, :, 0: depth - 1, :, :]\n        res = x - x_copy\n        res[:, :, 0, :, :] = 0\n        res = torch.abs(res)\n        return res\n\n    def forward(self, input):\n        x1, x2 = input\n        if self.first_layer:\n            x1 = self.gn_seg(x1)\n            x1 = self.relu(x1)\n            x1 = 
self.conv_seg(x1)\n\n            res = torch.sigmoid(x1)\n            res = self._res(res)\n            res = self.conv_res(res)\n\n            x2 = self.conv_mp_first(x2)\n            x2 = x2 + res\n\n        else:\n            x1 = self.gn_seg(x1)\n            x1 = self.relu(x1)\n            x1 = self.conv_seg(x1)\n\n            res = torch.sigmoid(x1)\n            res = self._res(res)\n            res = self.conv_res(res)\n\n\n            if self.in_planes != self.out_planes:\n                x2 = self.gn_mp(x2)\n                x2 = self.relu(x2)\n                x2 = self.conv_mp(x2)\n\n            x2 = x2 + res\n\n        x2 = self.gn_res1(x2)\n        x2 = self.relu(x2)\n        x2 = self.conv_res1(x2)\n\n        x1 = x1*(1 + torch.sigmoid(x2))\n\n        return [x1, x2]\n\n\nclass NoBottleneck(nn.Module):\n    def __init__(self, inplanes, planes, stride=(1, 1, 1), dilation=(1, 1, 1), downsample=None, fist_dilation=1,\n                 multi_grid=1, weight_std=False):\n        super(NoBottleneck, self).__init__()\n        self.weight_std = weight_std\n        self.relu = nn.ReLU(inplace=True)\n\n        self.gn1 = nn.GroupNorm(8, inplanes)\n        self.conv1 = conv3x3x3(inplanes, planes, kernel_size=(3, 3, 3), stride=stride, padding=dilation * multi_grid,\n                                 dilation=dilation * multi_grid, bias=False, weight_std=self.weight_std)\n\n        self.gn2 = nn.GroupNorm(8, planes)\n        self.conv2 = conv3x3x3(planes, planes, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=dilation * multi_grid,\n                                 dilation=dilation * multi_grid, bias=False, weight_std=self.weight_std)\n\n        self.downsample = downsample\n        self.dilation = dilation\n        self.stride = stride\n\n    def forward(self, x):\n        skip = x\n\n        seg = self.gn1(x)\n        seg = self.relu(seg)\n        seg = self.conv1(seg)\n\n        seg = self.gn2(seg)\n        seg = self.relu(seg)\n        seg = self.conv2(seg)\n\n  
      if self.downsample is not None:\n            skip = self.downsample(x)\n\n        seg = seg + skip\n        return seg\n\n\nclass ConResNet(nn.Module):\n    def __init__(self, in_channels, num_classes, shape, block, layers, weight_std=False):\n        self.shape = shape\n        self.weight_std = weight_std\n        super(ConResNet, self).__init__()\n\n        self.conv_4_32 = nn.Sequential(\n            conv3x3x3(in_channels, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), weight_std=self.weight_std))\n\n        self.conv_32_64 = nn.Sequential(\n            nn.GroupNorm(8, 32),\n            nn.ReLU(inplace=True),\n            conv3x3x3(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), weight_std=self.weight_std))\n\n        self.conv_64_128 = nn.Sequential(\n            nn.GroupNorm(8, 64),\n            nn.ReLU(inplace=True),\n            conv3x3x3(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), weight_std=self.weight_std))\n\n        self.conv_128_256 = nn.Sequential(\n            nn.GroupNorm(8, 128),\n            nn.ReLU(inplace=True),\n            conv3x3x3(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), weight_std=self.weight_std))\n\n        self.layer0 = self._make_layer(block, 32, 32, layers[0], stride=(1, 1, 1))\n        self.layer1 = self._make_layer(block, 64, 64, layers[1], stride=(1, 1, 1))\n        self.layer2 = self._make_layer(block, 128, 128, layers[2], stride=(1, 1, 1))\n        self.layer3 = self._make_layer(block, 256, 256, layers[3], stride=(1, 1, 1))\n        self.layer4 = self._make_layer(block, 256, 256, layers[4], stride=(1, 1, 1), dilation=(2,2,2))\n\n        self.fusionConv = nn.Sequential(\n            nn.GroupNorm(8, 256),\n            nn.ReLU(inplace=True),\n            nn.Dropout3d(0.1),\n            conv3x3x3(256, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1), weight_std=self.weight_std)\n        )\n\n        self.seg_x4 = nn.Sequential(\n            ConResAtt(in_channels, 128, 64, 
kernel_size=(3, 3, 3), padding=(1, 1, 1), weight_std=self.weight_std, first_layer=True))\n        self.seg_x2 = nn.Sequential(\n            ConResAtt(in_channels, 64, 32, kernel_size=(3, 3, 3), padding=(1, 1, 1), weight_std=self.weight_std))\n        self.seg_x1 = nn.Sequential(\n            ConResAtt(in_channels, 32, 32, kernel_size=(3, 3, 3), padding=(1, 1, 1), weight_std=self.weight_std))\n\n        self.seg_cls = nn.Sequential(\n            nn.Conv3d(32, num_classes, kernel_size=1)\n        )\n        self.res_cls = nn.Sequential(\n            nn.Conv3d(32, num_classes, kernel_size=1)\n        )\n        self.resx2_cls = nn.Sequential(\n            nn.Conv3d(32, num_classes, kernel_size=1)\n        )\n        self.resx4_cls = nn.Sequential(\n            nn.Conv3d(64, num_classes, kernel_size=1)\n        )\n\n    def _make_layer(self, block, inplanes, outplanes, blocks, stride=(1, 1, 1), dilation=(1, 1, 1), multi_grid=1):\n        downsample = None\n        if stride[0] != 1 or stride[1] != 1 or stride[2] != 1 or inplanes != outplanes:\n            downsample = nn.Sequential(\n                nn.GroupNorm(8, inplanes),\n                nn.ReLU(inplace=True),\n                conv3x3x3(inplanes, outplanes, kernel_size=(1, 1, 1), stride=stride, padding=(0, 0, 0),\n                            weight_std=self.weight_std)\n            )\n\n        layers = []\n        generate_multi_grid = lambda index, grids: grids[index % len(grids)] if isinstance(grids, tuple) else 1\n        layers.append(block(inplanes, outplanes, stride, dilation=dilation, downsample=downsample,\n                            multi_grid=generate_multi_grid(0, multi_grid), weight_std=self.weight_std))\n        for i in range(1, blocks):\n            layers.append(\n                block(inplanes, outplanes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid),\n                      weight_std=self.weight_std))\n        return nn.Sequential(*layers)\n\n\n    def forward(self, x, 
x_res):\n\n        ## encoder\n        x = self.conv_4_32(x)\n        x = self.layer0(x)\n        skip1 = x\n\n        x = self.conv_32_64(x)\n        x = self.layer1(x)\n        skip2 = x\n\n        x = self.conv_64_128(x)\n        x = self.layer2(x)\n        skip3 = x\n\n        x = self.conv_128_256(x)\n        x = self.layer3(x)\n\n        x = self.layer4(x)\n\n        x = self.fusionConv(x)\n\n        ## decoder\n        res_x4 = F.interpolate(x_res, size=(int(self.shape[0] / 4), int(self.shape[1] / 4), int(self.shape[2] / 4)), mode='trilinear', align_corners=True)\n        seg_x4 = F.interpolate(x, size=(int(self.shape[0] / 4), int(self.shape[1] / 4), int(self.shape[2] / 4)), mode='trilinear', align_corners=True)\n        seg_x4 = seg_x4 + skip3\n        seg_x4, res_x4 = self.seg_x4([seg_x4, res_x4])\n\n        res_x2 = F.interpolate(res_x4, size=(int(self.shape[0] / 2), int(self.shape[1] / 2), int(self.shape[2] / 2)), mode='trilinear', align_corners=True)\n        seg_x2 = F.interpolate(seg_x4, size=(int(self.shape[0] / 2), int(self.shape[1] / 2), int(self.shape[2] / 2)), mode='trilinear', align_corners=True)\n        seg_x2 = seg_x2 + skip2\n        seg_x2, res_x2 = self.seg_x2([seg_x2, res_x2])\n\n        res_x1 = F.interpolate(res_x2, size=(int(self.shape[0] / 1), int(self.shape[1] / 1), int(self.shape[2] / 1)), mode='trilinear', align_corners=True)\n        seg_x1 = F.interpolate(seg_x2, size=(int(self.shape[0] / 1), int(self.shape[1] / 1), int(self.shape[2] / 1)), mode='trilinear', align_corners=True)\n        seg_x1 = seg_x1 + skip1\n        seg_x1, res_x1 = self.seg_x1([seg_x1, res_x1])\n\n        seg = self.seg_cls(seg_x1)\n        res = self.res_cls(res_x1)\n        resx2 = self.resx2_cls(res_x2)\n        resx4 = self.resx4_cls(res_x4)\n\n        resx2 = F.interpolate(resx2, size=(int(self.shape[0] / 1), int(self.shape[1] / 1), int(self.shape[2] / 1)), mode='trilinear', align_corners=True)\n        resx4 = F.interpolate(resx4, 
size=(int(self.shape[0] / 1), int(self.shape[1] / 1), int(self.shape[2] / 1)), mode='trilinear', align_corners=True)\n\n        return [seg, res, resx2, resx4]\n\n\ndef conresnet(in_channels, num_classes, **kwargs):\n    model = ConResNet(in_channels, num_classes, kwargs['img_shape'], NoBottleneck, [1, 2, 2, 2, 2])\n    init_weights(model, 'kaiming')\n    return model\n\n# if __name__ == '__main__':\n#\n#     criterion = segmentation_loss('dice', False)\n#     mask = torch.ones(5, 64, 64, 64).long()\n#     model = conresnet(1, 10, img_shape=(64, 64, 64))\n#     model.train()\n#     output = model(torch.rand(5, 1, 64, 64, 64), torch.rand(5, 1, 64, 64, 64))\n#\n#     loss_train_1 = criterion(output[0], mask)\n#     loss_train_2 = criterion(output[1], mask)\n#     loss_train_3 = criterion(output[2], mask)\n#     loss_train_4 = criterion(output[3], mask)\n#\n#     loss_train_1.backward()\n#\n#     print(output[0].data.cpu().numpy().shape)\n#     print(output[1].data.cpu().numpy().shape)\n#     print(output[2].data.cpu().numpy().shape)\n#     print(output[3].data.cpu().numpy().shape)\n#     print(loss_train_1)\n#     print(loss_train_2)\n#     print(loss_train_3)\n#     print(loss_train_4)\n\n\n"
  },
  {
    "path": "models/networks_3d/cotr.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom torch.nn.init import xavier_uniform_, constant_, normal_\nimport copy\nimport math\n# from loss.loss_function import segmentation_loss\n\nclass PositionEmbeddingSine(nn.Module):\n    \"\"\"\n    This is a more standard version of the position embedding, very similar to the one\n    used by the Attention is all you need paper, generalized to work on images.\n    \"\"\"\n\n    def __init__(self, num_pos_feats=[64, 64, 64], temperature=10000, normalize=False, scale=None):\n        super().__init__()\n        self.num_pos_feats = num_pos_feats\n        self.temperature = temperature\n        self.normalize = normalize\n        if scale is not None and normalize is False:\n            raise ValueError(\"normalize should be True if scale is passed\")\n        if scale is None:\n            scale = 2 * math.pi\n        self.scale = scale\n\n    def forward(self, x):\n        bs, c, d, h, w = x.shape\n        mask = torch.zeros(bs, d, h, w, dtype=torch.bool).cuda()\n        # mask = torch.zeros(bs, d, h, w, dtype=torch.bool)\n        assert mask is not None\n        not_mask = ~mask\n        d_embed = not_mask.cumsum(1, dtype=torch.float32)\n        y_embed = not_mask.cumsum(2, dtype=torch.float32)\n        x_embed = not_mask.cumsum(3, dtype=torch.float32)\n        if self.normalize:\n            eps = 1e-6\n            d_embed = (d_embed - 0.5) / (d_embed[:, -1:, :, :] + eps) * self.scale\n            y_embed = (y_embed - 0.5) / (y_embed[:, :, -1:, :] + eps) * self.scale\n            x_embed = (x_embed - 0.5) / (x_embed[:, :, :, -1:] + eps) * self.scale\n\n        dim_tx = torch.arange(self.num_pos_feats[0], dtype=torch.float32, device=x.device)\n        dim_tx = self.temperature ** (3 * (dim_tx // 3) / self.num_pos_feats[0])\n\n        dim_ty = torch.arange(self.num_pos_feats[1], dtype=torch.float32, device=x.device)\n        dim_ty = self.temperature ** (3 * 
(dim_ty // 3) / self.num_pos_feats[1])\n\n        dim_td = torch.arange(self.num_pos_feats[2], dtype=torch.float32, device=x.device)\n        dim_td = self.temperature ** (3 * (dim_td // 3) / self.num_pos_feats[2])\n\n        pos_x = x_embed[:, :, :, :, None] / dim_tx\n        pos_y = y_embed[:, :, :, :, None] / dim_ty\n        pos_d = d_embed[:, :, :, :, None] / dim_td\n\n        pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4)\n        pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4)\n        pos_d = torch.stack((pos_d[:, :, :, :, 0::2].sin(), pos_d[:, :, :, :, 1::2].cos()), dim=5).flatten(4)\n\n        pos = torch.cat((pos_d, pos_y, pos_x), dim=4).permute(0, 4, 1, 2, 3)\n        return pos\n\n\ndef build_position_encoding(mode, hidden_dim):\n    N_steps = hidden_dim // 3\n    if (hidden_dim % 3) != 0:\n        N_steps = [N_steps, N_steps, N_steps + hidden_dim % 3]\n    else:\n        N_steps = [N_steps, N_steps, N_steps]\n\n    if mode in ('v2', 'sine'):\n        position_embedding = PositionEmbeddingSine(num_pos_feats=N_steps, normalize=True)\n    else:\n        raise ValueError(f\"not supported {mode}\")\n\n    return position_embedding\n\ndef ms_deform_attn_core_pytorch_3D(value, value_spatial_shapes, sampling_locations, attention_weights):\n    N_, S_, M_, D_ = value.shape\n    _, Lq_, M_, L_, P_, _ = sampling_locations.shape\n    value_list = value.split([T_ * H_ * W_ for T_, H_, W_ in value_spatial_shapes], dim=1)\n    sampling_grids = 2 * sampling_locations - 1\n    # sampling_grids = 3 * sampling_locations - 1\n    sampling_value_list = []\n    for lid_, (T_, H_, W_) in enumerate(value_spatial_shapes):\n        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, T_, H_, W_)\n        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)[:,None,:,:,:]\n        sampling_value_l_ = F.grid_sample(value_l_, 
sampling_grid_l_.to(dtype=value_l_.dtype), mode='bilinear', padding_mode='zeros', align_corners=False)[:,:,0]\n        sampling_value_list.append(sampling_value_l_)\n    attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)\n    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)\n    return output.transpose(1, 2).contiguous()\n\nclass MSDeformAttn(nn.Module):\n    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):\n        \"\"\"\n        Multi-Scale Deformable Attention Module\n        :param d_model      hidden dimension\n        :param n_levels     number of feature levels\n        :param n_heads      number of attention heads\n        :param n_points     number of sampling points per attention head per feature level\n        \"\"\"\n        super().__init__()\n        if d_model % n_heads != 0:\n            raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))\n        _d_per_head = d_model // n_heads\n\n        self.im2col_step = 64\n\n        self.d_model = d_model\n        self.n_levels = n_levels\n        self.n_heads = n_heads\n        self.n_points = n_points\n\n        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 3)\n        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)\n        self.value_proj = nn.Linear(d_model, d_model)\n        self.output_proj = nn.Linear(d_model, d_model)\n\n        self._reset_parameters()\n\n    def _reset_parameters(self):\n        constant_(self.sampling_offsets.weight.data, 0.)\n        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)\n        grid_init = torch.stack([thetas.cos(), thetas.sin()*thetas.cos(), thetas.sin()*thetas.sin()], -1)\n        grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 3).repeat(1, self.n_levels, 
self.n_points, 1)\n        for i in range(self.n_points):\n            grid_init[:, :, i, :] *= i + 1\n        with torch.no_grad():\n            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))\n        constant_(self.attention_weights.weight.data, 0.)\n        constant_(self.attention_weights.bias.data, 0.)\n        xavier_uniform_(self.value_proj.weight.data)\n        constant_(self.value_proj.bias.data, 0.)\n        xavier_uniform_(self.output_proj.weight.data)\n        constant_(self.output_proj.bias.data, 0.)\n\n    def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):\n        \"\"\"\n        :param query                       (N, Length_{query}, C)\n        :param reference_points            (N, Length_{query}, n_levels, 3)\n        :param input_flatten               (N, \\sum_{l=0}^{L-1} D_l \\cdot H_l \\cdot W_l, C)\n        :param input_spatial_shapes        (n_levels, 3), [(D_0, H_0, W_0), (D_1, H_1, W_1), ..., (D_{L-1}, H_{L-1}, W_{L-1})]\n        :param input_level_start_index     (n_levels, ), [0, D_0*H_0*W_0, D_0*H_0*W_0+D_1*H_1*W_1, D_0*H_0*W_0+D_1*H_1*W_1+D_2*H_2*W_2, ..., D_0*H_0*W_0+D_1*H_1*W_1+...+D_{L-1}*H_{L-1}*W_{L-1}]\n        :param input_padding_mask          (N, \\sum_{l=0}^{L-1} D_l \\cdot H_l \\cdot W_l), True for padding elements, False for non-padding elements\n        :return output                     (N, Length_{query}, C)\n        \"\"\"\n        N, Len_q, _ = query.shape\n        N, Len_in, _ = input_flatten.shape\n        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1] * input_spatial_shapes[:, 2]).sum() == Len_in\n\n        value = self.value_proj(input_flatten)\n        if input_padding_mask is not None:\n            value = value.masked_fill(input_padding_mask[..., None], float(0))\n        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)\n        sampling_offsets = 
self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3)\n        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)\n        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)\n\n        if reference_points.shape[-1] == 3:\n            offset_normalizer = torch.stack([input_spatial_shapes[..., 0], input_spatial_shapes[..., 2], input_spatial_shapes[..., 1]], -1)\n            sampling_locations = reference_points[:, :, None, :, None, :] + sampling_offsets / offset_normalizer[None, None, None, :, None, :]\n\n        output = ms_deform_attn_core_pytorch_3D(value, input_spatial_shapes, sampling_locations, attention_weights)\n\n        output = self.output_proj(output)\n        return output\n\ndef _get_clones(module, N):\n    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])\n\n\ndef _get_activation_fn(activation):\n    \"\"\"Return an activation function given a string\"\"\"\n    if activation == \"relu\":\n        return F.relu\n    if activation == \"gelu\":\n        return F.gelu\n    if activation == \"glu\":\n        return F.glu\n    raise RuntimeError(F\"activation should be relu/gelu, not {activation}.\")\n\nclass DeformableTransformerEncoderLayer(nn.Module):\n    def __init__(self,\n                 d_model=256, d_ffn=1024,\n                 dropout=0.1, activation=\"relu\",\n                 n_levels=4, n_heads=8, n_points=4):\n        super().__init__()\n\n        # self attention\n        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)\n        self.dropout1 = nn.Dropout(dropout)\n        self.norm1 = nn.LayerNorm(d_model)\n\n        # ffn\n        self.linear1 = nn.Linear(d_model, d_ffn)\n        self.activation = _get_activation_fn(activation)\n        self.dropout2 = nn.Dropout(dropout)\n        self.linear2 = nn.Linear(d_ffn, d_model)\n        self.dropout3 = 
nn.Dropout(dropout)\n        self.norm2 = nn.LayerNorm(d_model)\n\n    @staticmethod\n    def with_pos_embed(tensor, pos):\n        return tensor if pos is None else tensor + pos\n\n    def forward_ffn(self, src):\n        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))\n        src = src + self.dropout3(src2)\n        src = self.norm2(src)\n        return src\n\n    def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):\n        # self attention\n        src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)\n        src = src + self.dropout1(src2)\n        src = self.norm1(src)\n\n        # ffn\n        src = self.forward_ffn(src)\n\n        return src\n\n\nclass DeformableTransformerEncoder(nn.Module):\n    def __init__(self, encoder_layer, num_layers):\n        super().__init__()\n        self.layers = _get_clones(encoder_layer, num_layers)\n        self.num_layers = num_layers\n\n    @staticmethod\n    def get_reference_points(spatial_shapes, valid_ratios, device):\n        reference_points_list = []\n        for lvl, (D_, H_, W_) in enumerate(spatial_shapes):\n\n            ref_d, ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, D_ - 0.5, D_, dtype=torch.float32, device=device),\n                                                 torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),\n                                                 torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))\n\n            ref_d = ref_d.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * D_)\n            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 2] * H_)\n            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * W_)\n\n            ref = torch.stack((ref_d, ref_x, ref_y), -1)   # D W H\n            reference_points_list.append(ref)\n        reference_points = 
torch.cat(reference_points_list, 1)\n        reference_points = reference_points[:, :, None] * valid_ratios[:, None]\n        return reference_points\n\n    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):\n        output = src\n        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)\n        for _, layer in enumerate(self.layers):\n            output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)\n\n        return output\n\n\nclass DeformableTransformer(nn.Module):\n    def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, dim_feedforward=1024, dropout=0.1, activation=\"relu\", num_feature_levels=4, enc_n_points=4):\n        super().__init__()\n\n        self.d_model = d_model\n        self.nhead = nhead\n\n        encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points)\n        self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)\n\n        self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))\n\n        self._reset_parameters()\n\n    def _reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n        for m in self.modules():\n            if isinstance(m, MSDeformAttn):\n                m._reset_parameters()\n        normal_(self.level_embed)\n\n    def get_valid_ratio(self, mask):\n        _, D, H, W = mask.shape\n        valid_D = torch.sum(~mask[:, :, 0, 0], 1)\n        valid_H = torch.sum(~mask[:, 0, :, 0], 1)\n        valid_W = torch.sum(~mask[:, 0, 0, :], 1)\n\n        valid_ratio_d = valid_D.float() / D\n        valid_ratio_h = valid_H.float() / H\n        valid_ratio_w = valid_W.float() / W\n        valid_ratio = torch.stack([valid_ratio_d, valid_ratio_w, valid_ratio_h], -1)\n        return 
valid_ratio\n\n    def forward(self, srcs, masks, pos_embeds):\n\n        # prepare input for encoder\n        src_flatten = []\n        mask_flatten = []\n        lvl_pos_embed_flatten = []\n        spatial_shapes = []\n        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):\n            bs, c, d, h, w = src.shape\n            spatial_shape = (d, h, w)\n            spatial_shapes.append(spatial_shape)\n            src = src.flatten(2).transpose(1, 2)\n            mask = mask.flatten(1)\n            pos_embed = pos_embed.flatten(2).transpose(1, 2)\n            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)\n            lvl_pos_embed_flatten.append(lvl_pos_embed)\n            src_flatten.append(src)\n            mask_flatten.append(mask)\n        src_flatten = torch.cat(src_flatten, 1)\n        mask_flatten = torch.cat(mask_flatten, 1)\n        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)\n        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)\n        level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))\n        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)\n\n        # encoder\n        memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)\n\n        return memory\n\nclass Conv3d_wd(nn.Conv3d):\n\n    def __init__(self, in_channels, out_channels, kernel_size, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), groups=1, bias=False):\n        super(Conv3d_wd, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)\n\n    def forward(self, x):\n        weight = self.weight\n        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=4, keepdim=True)\n        weight = weight - weight_mean\n        # std = 
weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1, 1) + 1e-5\n        std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1, 1)\n        weight = weight / std.expand_as(weight)\n        return F.conv3d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)\n\n\ndef conv3x3x3(in_planes, out_planes, kernel_size, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), groups=1,\n              bias=False, weight_std=False):\n    \"3x3x3 convolution with padding\"\n    if weight_std:\n        return Conv3d_wd(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)\n    else:\n        return nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)\n\n\ndef Norm_layer(norm_cfg, inplanes):\n    if norm_cfg == 'BN':\n        out = nn.BatchNorm3d(inplanes)\n    elif norm_cfg == 'SyncBN':\n        out = nn.SyncBatchNorm(inplanes)\n    elif norm_cfg == 'GN':\n        out = nn.GroupNorm(16, inplanes)\n    elif norm_cfg == 'IN':\n        out = nn.InstanceNorm3d(inplanes, affine=True)\n\n    return out\n\n\ndef Activation_layer(activation_cfg, inplace=True):\n    if activation_cfg == 'ReLU':\n        out = nn.ReLU(inplace=inplace)\n    elif activation_cfg == 'LeakyReLU':\n        out = nn.LeakyReLU(negative_slope=1e-2, inplace=inplace)\n\n    return out\n\nclass ResBlock(nn.Module):\n    expansion = 1\n    def __init__(self, inplanes, planes, norm_cfg, activation_cfg, stride=(1, 1, 1), downsample=None, weight_std=False):\n        super(ResBlock, self).__init__()\n\n        self.conv1 = conv3x3x3(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False, weight_std=weight_std)\n        self.norm1 = Norm_layer(norm_cfg, planes)\n        self.nonlin = Activation_layer(activation_cfg, inplace=True)\n        self.downsample = downsample\n\n    def 
forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.norm1(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.nonlin(out)\n\n        return out\n\nclass Backbone(nn.Module):\n\n\n    def __init__(self, depth, in_channels=1, norm_cfg='BN', activation_cfg='ReLU', weight_std=False):\n        super(Backbone, self).__init__()\n\n        self.arch_settings = {\n            9: (ResBlock, (3, 3, 2))\n        }\n\n        if depth not in self.arch_settings:\n            raise KeyError('invalid depth {} for resnet'.format(depth))\n        self.depth = depth\n        block, layers = self.arch_settings[depth]\n        self.inplanes = 64\n        self.conv1 = conv3x3x3(in_channels, 64, kernel_size=7, stride=(1, 2, 2), padding=3, bias=False, weight_std=weight_std)\n        self.norm1 = Norm_layer(norm_cfg, 64)\n        self.nonlin = Activation_layer(activation_cfg, inplace=True)\n        self.layer1 = self._make_layer(block, 192, layers[0], stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std)\n        self.layer2 = self._make_layer(block, 384, layers[1], stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std)\n        self.layer3 = self._make_layer(block, 384, layers[2], stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std)\n        self.layers = []\n\n        for m in self.modules():\n            if isinstance(m, (nn.Conv3d, Conv3d_wd)):\n                m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')\n            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm, nn.InstanceNorm3d, nn.SyncBatchNorm)):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def _make_layer(self, block, planes, blocks, stride=(1, 1, 1), norm_cfg='BN', activation_cfg='ReLU', weight_std=False):\n        downsample = None\n      
  if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv3x3x3(\n                    self.inplanes,\n                    planes * block.expansion,\n                    kernel_size=1,\n                    stride=stride,\n                    bias=False, weight_std=weight_std), Norm_layer(norm_cfg, planes * block.expansion))\n\n        layers = []\n        layers.append(block(self.inplanes, planes, norm_cfg, activation_cfg, stride=stride, downsample=downsample, weight_std=weight_std))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes, norm_cfg, activation_cfg, weight_std=weight_std))\n\n        return nn.Sequential(*layers)\n\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, (nn.Conv3d, Conv3d_wd)):\n                m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')\n            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm, nn.InstanceNorm3d, nn.SyncBatchNorm)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        out = []\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.nonlin(x)\n        out.append(x)\n\n        x = self.layer1(x)\n        out.append(x)\n        x = self.layer2(x)\n        out.append(x)\n        x = self.layer3(x)\n        out.append(x)\n\n        return out\n\nclass Conv3dBlock(nn.Module):\n    def __init__(self, in_channels, out_channels, norm_cfg, activation_cfg, kernel_size, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), bias=False, weight_std=False):\n        super(Conv3dBlock, self).__init__()\n        self.conv = conv3x3x3(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, 
bias=bias, weight_std=weight_std)\n        self.norm = Norm_layer(norm_cfg, out_channels)\n        self.nonlin = Activation_layer(activation_cfg, inplace=True)\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = self.norm(x)\n        x = self.nonlin(x)\n        return x\n\n\nclass ResBlock_(nn.Module):\n\n    def __init__(self, inplanes, planes, norm_cfg, activation_cfg, weight_std=False):\n        super(ResBlock_, self).__init__()\n        self.resconv1 = Conv3dBlock(inplanes, planes, norm_cfg, activation_cfg, kernel_size=3, stride=1, padding=1, bias=False, weight_std=weight_std)\n        self.resconv2 = Conv3dBlock(planes, planes, norm_cfg, activation_cfg, kernel_size=3, stride=1, padding=1, bias=False, weight_std=weight_std)\n\n    def forward(self, x):\n        residual = x\n\n        out = self.resconv1(x)\n        out = self.resconv2(out)\n        out = out + residual\n\n        return out\n\nclass U_ResTran3D(nn.Module):\n    def __init__(self, in_channels, num_classes, norm_cfg='BN', activation_cfg='ReLU', weight_std=False):\n        super(U_ResTran3D, self).__init__()\n\n        self.MODEL_NUM_CLASSES = num_classes\n\n        self.upsamplex2 = nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear', align_corners=True)\n\n        self.transposeconv_stage2 = nn.ConvTranspose3d(384, 384, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)\n        self.transposeconv_stage1 = nn.ConvTranspose3d(384, 192, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)\n        self.transposeconv_stage0 = nn.ConvTranspose3d(192, 64, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)\n\n        self.stage2_de = ResBlock_(384, 384, norm_cfg, activation_cfg, weight_std=weight_std)\n        self.stage1_de = ResBlock_(192, 192, norm_cfg, activation_cfg, weight_std=weight_std)\n        self.stage0_de = ResBlock_(64, 64, norm_cfg, activation_cfg, weight_std=weight_std)\n\n        # self.ds2_cls_conv = nn.Conv3d(384, self.MODEL_NUM_CLASSES, kernel_size=1)\n        
# self.ds1_cls_conv = nn.Conv3d(192, self.MODEL_NUM_CLASSES, kernel_size=1)\n        # self.ds0_cls_conv = nn.Conv3d(64, self.MODEL_NUM_CLASSES, kernel_size=1)\n\n        self.cls_conv = nn.Conv3d(64, self.MODEL_NUM_CLASSES, kernel_size=1)\n\n        for m in self.modules():\n            if isinstance(m, (nn.Conv3d, Conv3d_wd, nn.ConvTranspose3d)):\n                m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')\n            elif isinstance(m, (nn.BatchNorm3d, nn.SyncBatchNorm, nn.InstanceNorm3d, nn.GroupNorm)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n        self.backbone = Backbone(depth=9, in_channels=in_channels, norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std)\n        # total = sum([param.nelement() for param in self.backbone.parameters()])\n        # print('  + Number of Backbone Params: %.2f(e6)' % (total / 1e6))\n\n        self.position_embed = build_position_encoding(mode='v2', hidden_dim=384)\n        self.encoder_Detrans = DeformableTransformer(d_model=384, dim_feedforward=1536, dropout=0.1, activation='gelu',\n                                                     num_feature_levels=2, nhead=6, num_encoder_layers=6,\n                                                     enc_n_points=4)\n        # total = sum([param.nelement() for param in self.encoder_Detrans.parameters()])\n        # print('  + Number of Transformer Params: %.2f(e6)' % (total / 1e6))\n\n    def posi_mask(self, x):\n\n        x_fea = []\n        x_posemb = []\n        masks = []\n        for lvl, fea in enumerate(x):\n            if lvl > 1:\n                x_fea.append(fea)\n                x_posemb.append(self.position_embed(fea))\n                masks.append(torch.zeros((fea.shape[0], fea.shape[2], fea.shape[3], fea.shape[4]), dtype=torch.bool).cuda())\n                # 
masks.append(torch.zeros((fea.shape[0], fea.shape[2], fea.shape[3], fea.shape[4]), dtype=torch.bool))\n        return x_fea, masks, x_posemb\n\n    def forward(self, inputs):\n        # # %%%%%%%%%%%%% CoTr\n        x_convs = self.backbone(inputs)\n        x_fea, masks, x_posemb = self.posi_mask(x_convs)\n        x_trans = self.encoder_Detrans(x_fea, masks, x_posemb)\n\n        # # Single_scale\n        # # x = self.transposeconv_stage2(x_trans.transpose(-1, -2).view(x_convs[-1].shape))\n        # # skip2 = x_convs[-2]\n        # Multi-scale\n        x = self.transposeconv_stage2(x_trans[:, x_fea[0].shape[-3] * x_fea[0].shape[-2] * x_fea[0].shape[-1]::].transpose(-1, -2).view(x_convs[-1].shape))  # x_trans length: 12*24*24+6*12*12=7776\n        skip2 = x_trans[:, 0:x_fea[0].shape[-3] * x_fea[0].shape[-2] * x_fea[0].shape[-1]].transpose(-1, -2).view(x_convs[-2].shape)\n\n        x = x + skip2\n        x = self.stage2_de(x)\n        # ds2 = self.ds2_cls_conv(x)\n\n        x = self.transposeconv_stage1(x)\n        skip1 = x_convs[-3]\n        x = x + skip1\n        x = self.stage1_de(x)\n        # ds1 = self.ds1_cls_conv(x)\n\n        x = self.transposeconv_stage0(x)\n        skip0 = x_convs[-4]\n        x = x + skip0\n        x = self.stage0_de(x)\n        # ds0 = self.ds0_cls_conv(x)\n\n        result = self.upsamplex2(x)\n        result = self.cls_conv(result)\n\n        return result\n\n\ndef cotr(in_channels, num_classes):\n    model = U_ResTran3D(in_channels, num_classes)\n    return model\n\n\n# if __name__ == '__main__':\n#\n#     criterion = segmentation_loss('dice', False)\n#\n#     mask = torch.ones(2, 64, 64, 64).long()\n#     model = cotr(1, 10)\n#     model.train()\n#     input = torch.rand(2, 1, 64, 64, 64)\n#     output = model(input)\n#     loss_train = criterion(output, mask)\n#     output = output.data.cpu().numpy()\n#     loss_train.backward()\n#     print(output.shape)\n#     print(loss_train)\n"
  },
  {
    "path": "models/networks_3d/dmfnet.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as F\nimport torch\n# from loss.loss_function import segmentation_loss\n\ndef normalization(planes, norm='bn'):\n    if norm == 'bn':\n        m = nn.BatchNorm3d(planes)\n    elif norm == 'gn':\n        m = nn.GroupNorm(4, planes)\n    elif norm == 'in':\n        m = nn.InstanceNorm3d(planes)\n    else:\n        raise ValueError('normalization type {} is not supported'.format(norm))\n    return m\n\nclass Conv3d_Block(nn.Module):\n    def __init__(self,num_in,num_out,kernel_size=1,stride=1,g=1,padding=None,norm=None):\n        super(Conv3d_Block, self).__init__()\n        if padding == None:\n            padding = (kernel_size - 1) // 2\n        self.bn = normalization(num_in,norm=norm)\n        self.act_fn = nn.ReLU(inplace=True)\n        self.conv = nn.Conv3d(num_in, num_out, kernel_size=kernel_size, padding=padding,stride=stride, groups=g, bias=False)\n\n    def forward(self, x): # BN + Relu + Conv\n        h = self.act_fn(self.bn(x))\n        h = self.conv(h)\n        return h\n\n\nclass DilatedConv3DBlock(nn.Module):\n    def __init__(self, num_in, num_out, kernel_size=(1,1,1), stride=1, g=1, d=(1,1,1), norm=None):\n        super(DilatedConv3DBlock, self).__init__()\n        assert isinstance(kernel_size,tuple) and isinstance(d,tuple)\n\n        padding = tuple(\n            [(ks-1)//2 *dd for ks, dd in zip(kernel_size, d)]\n        )\n\n        self.bn = normalization(num_in, norm=norm)\n        self.act_fn = nn.ReLU(inplace=True)\n        self.conv = nn.Conv3d(num_in,num_out,kernel_size=kernel_size,padding=padding,stride=stride,groups=g,dilation=d,bias=False)\n\n    def forward(self, x):\n        h = self.act_fn(self.bn(x))\n        h = self.conv(h)\n        return h\n\n\nclass MFunit(nn.Module):\n    def __init__(self, num_in, num_out, g=1, stride=1, d=(1,1),norm=None):\n        \"\"\"  The second 3x3x1 group conv is replaced by 3x3x3.\n        :param num_in: number of input channels\n        
:param num_out: number of output channels\n        :param g: groups of group conv.\n        :param stride: 1 or 2\n        :param d: tuple, d[0] for the first 3x3x3 conv while d[1] for the 3x3x1 conv\n        :param norm: Batch Normalization\n        \"\"\"\n        super(MFunit, self).__init__()\n        num_mid = num_in if num_in <= num_out else num_out\n        self.conv1x1x1_in1 = Conv3d_Block(num_in,num_in//4,kernel_size=1,stride=1,norm=norm)\n        self.conv1x1x1_in2 = Conv3d_Block(num_in//4,num_mid,kernel_size=1,stride=1,norm=norm)\n        self.conv3x3x3_m1 = DilatedConv3DBlock(num_mid,num_out,kernel_size=(3,3,3),stride=stride,g=g,d=(d[0],d[0],d[0]),norm=norm) # dilated\n        self.conv3x3x3_m2 = DilatedConv3DBlock(num_out,num_out,kernel_size=(3,3,1),stride=1,g=g,d=(d[1],d[1],1),norm=norm)\n        # self.conv3x3x3_m2 = DilatedConv3DBlock(num_out,num_out,kernel_size=(1,3,3),stride=1,g=g,d=(1,d[1],d[1]),norm=norm)\n\n        # skip connection\n        if num_in != num_out or stride != 1:\n            if stride == 1:\n                self.conv1x1x1_shortcut = Conv3d_Block(num_in, num_out, kernel_size=1, stride=1, padding=0,norm=norm)\n            if stride == 2:\n                # if MF block with stride=2, 2x2x2\n                self.conv2x2x2_shortcut = Conv3d_Block(num_in, num_out, kernel_size=2, stride=2,padding=0, norm=norm) # params\n\n    def forward(self, x):\n        x1 = self.conv1x1x1_in1(x)\n        x2 = self.conv1x1x1_in2(x1)\n        x3 = self.conv3x3x3_m1(x2)\n        x4 = self.conv3x3x3_m2(x3)\n\n        shortcut = x\n\n        if hasattr(self,'conv1x1x1_shortcut'):\n            shortcut = self.conv1x1x1_shortcut(shortcut)\n        if hasattr(self,'conv2x2x2_shortcut'):\n            shortcut = self.conv2x2x2_shortcut(shortcut)\n\n        return x4 + shortcut\n\nclass DMFUnit(nn.Module):\n    # weighred add\n    def __init__(self, num_in, num_out, g=1, stride=1,norm=None,dilation=None):\n        super(DMFUnit, self).__init__()\n        
self.weight1 = nn.Parameter(torch.ones(1))\n        self.weight2 = nn.Parameter(torch.ones(1))\n        self.weight3 = nn.Parameter(torch.ones(1))\n\n        num_mid = num_in if num_in <= num_out else num_out\n\n        self.conv1x1x1_in1 = Conv3d_Block(num_in, num_in // 4, kernel_size=1, stride=1, norm=norm)\n        self.conv1x1x1_in2 = Conv3d_Block(num_in // 4,num_mid,kernel_size=1, stride=1, norm=norm)\n\n        self.conv3x3x3_m1 = nn.ModuleList()\n        if dilation == None:\n            dilation = [1,2,3]\n        for i in range(3):\n            self.conv3x3x3_m1.append(\n                DilatedConv3DBlock(num_mid,num_out, kernel_size=(3, 3, 3), stride=stride, g=g, d=(dilation[i],dilation[i], dilation[i]),norm=norm)\n            )\n\n        # It has not Dilated operation\n        self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(3, 3, 1), stride=(1,1,1), g=g,d=(1,1,1), norm=norm)\n        # self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(1, 3, 3), stride=(1,1,1), g=g,d=(1,1,1), norm=norm)\n\n        # skip connection\n        if num_in != num_out or stride != 1:\n            if stride == 1:\n                self.conv1x1x1_shortcut = Conv3d_Block(num_in, num_out, kernel_size=1, stride=1, padding=0, norm=norm)\n            if stride == 2:\n                self.conv2x2x2_shortcut = Conv3d_Block(num_in, num_out, kernel_size=2, stride=2, padding=0, norm=norm)\n\n\n    def forward(self, x):\n        x1 = self.conv1x1x1_in1(x)\n        x2 = self.conv1x1x1_in2(x1)\n        x3 = self.weight1*self.conv3x3x3_m1[0](x2) + self.weight2*self.conv3x3x3_m1[1](x2) + self.weight3*self.conv3x3x3_m1[2](x2)\n        x4 = self.conv3x3x3_m2(x3)\n        shortcut = x\n        if hasattr(self, 'conv1x1x1_shortcut'):\n            shortcut = self.conv1x1x1_shortcut(shortcut)\n        if hasattr(self, 'conv2x2x2_shortcut'):\n            shortcut = self.conv2x2x2_shortcut(shortcut)\n        return x4 + shortcut\n\n\nclass MFNet(nn.Module): #\n  
  # [96]   Flops:  13.361G  &  Params: 1.81M\n    # [112]  Flops:  16.759G  &  Params: 2.46M\n    # [128]  Flops:  20.611G  &  Params: 3.19M\n    def __init__(self,in_channels, num_classes, n=32, channels=128, groups=16, norm='bn'):\n        super(MFNet, self).__init__()\n\n        # Entry flow\n        self.encoder_block1 = nn.Conv3d(in_channels, n, kernel_size=3, padding=1, stride=2, bias=False)# H//2\n        self.encoder_block2 = nn.Sequential(\n            MFunit(n, channels, g=groups, stride=2, norm=norm),# H//4 down\n            MFunit(channels, channels, g=groups, stride=1, norm=norm),\n            MFunit(channels, channels, g=groups, stride=1, norm=norm)\n        )\n        #\n        self.encoder_block3 = nn.Sequential(\n            MFunit(channels, channels*2, g=groups, stride=2, norm=norm), # H//8\n            MFunit(channels * 2, channels * 2, g=groups, stride=1, norm=norm),\n            MFunit(channels * 2, channels * 2, g=groups, stride=1, norm=norm)\n        )\n\n        self.encoder_block4 = nn.Sequential(# H//8,channels*4\n            MFunit(channels*2, channels*3, g=groups, stride=2, norm=norm), # H//16\n            MFunit(channels*3, channels*3, g=groups, stride=1, norm=norm),\n            MFunit(channels*3, channels*2, g=groups, stride=1, norm=norm),\n        )\n\n        self.upsample1 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) # H//8\n        self.decoder_block1 = MFunit(channels*2+channels*2, channels*2, g=groups, stride=1, norm=norm)\n\n        self.upsample2 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) # H//4\n        self.decoder_block2 = MFunit(channels*2 + channels, channels, g=groups, stride=1, norm=norm)\n\n        self.upsample3 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) # H//2\n        self.decoder_block3 = MFunit(channels + n, n, g=groups, stride=1, norm=norm)\n        self.upsample4 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) # H\n 
       self.seg = nn.Conv3d(n, num_classes, kernel_size=1, padding=0,stride=1,bias=False)\n\n        # Initialization\n        for m in self.modules():\n            if isinstance(m, nn.Conv3d):\n                torch.nn.init.kaiming_normal_(m.weight) #\n            elif isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.GroupNorm):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        # Encoder\n        x1 = self.encoder_block1(x)# H//2 down\n        x2 = self.encoder_block2(x1)# H//4 down\n        x3 = self.encoder_block3(x2)# H//8 down\n        x4 = self.encoder_block4(x3) # H//16\n        # Decoder\n        y1 = self.upsample1(x4)# H//8\n        y1 = torch.cat([x3,y1],dim=1)\n        y1 = self.decoder_block1(y1)\n\n        y2 = self.upsample2(y1)# H//4\n        y2 = torch.cat([x2,y2],dim=1)\n        y2 = self.decoder_block2(y2)\n\n        y3 = self.upsample3(y2)# H//2\n        y3 = torch.cat([x1,y3],dim=1)\n        y3 = self.decoder_block3(y3)\n        y4 = self.upsample4(y3)\n        y4 = self.seg(y4)\n\n        return y4\n\n\nclass DMFNet(MFNet): # softmax\n    # [128]  Flops:  27.045G  &  Params: 3.88M\n    def __init__(self,in_channels, num_classes, n=32,channels=128, groups=16,norm='bn'):\n        super(DMFNet, self).__init__(in_channels, num_classes, n, channels, groups, norm)\n\n        self.encoder_block2 = nn.Sequential(\n            DMFUnit(n, channels, g=groups, stride=2, norm=norm,dilation=[1,2,3]),# H//4 down\n            DMFUnit(channels, channels, g=groups, stride=1, norm=norm,dilation=[1,2,3]), # Dilated Conv 3\n            DMFUnit(channels, channels, g=groups, stride=1, norm=norm,dilation=[1,2,3])\n        )\n\n        self.encoder_block3 = nn.Sequential(\n            DMFUnit(channels, channels*2, g=groups, stride=2, norm=norm,dilation=[1,2,3]), # H//8\n            DMFUnit(channels * 2, channels * 2, g=groups, stride=1, norm=norm,dilation=[1,2,3]),# 
Dilated Conv 3\n            DMFUnit(channels * 2, channels * 2, g=groups, stride=1, norm=norm,dilation=[1,2,3])\n        )\n\ndef dmfnet(in_channels, num_classes):\n    model = DMFNet(in_channels, num_classes)\n    return model\n\n\n# if __name__ == '__main__':\n#\n#     criterion = segmentation_loss('dice', False)\n#     mask = torch.ones(2, 64, 64, 64).long()\n#\n#     model = dmfnet(1, 10)\n#     model.train()\n#     input = torch.rand(2, 1, 64, 64, 64)\n#     output = model(input)\n#\n#     loss_train = criterion(output, mask)\n#     loss_train.backward()\n#\n#     output = output.data.cpu().numpy()\n#     print(output.shape)\n#     print(loss_train)\n"
  },
  {
    "path": "models/networks_3d/espnet3d.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\n# from loss.loss_function import segmentation_loss\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=UserWarning)\n\nclass CBR(nn.Module):\n    def __init__(self, nIn, nOut, kSize, stride=1):\n        super().__init__()\n        padding = int((kSize - 1) / 2)\n        self.conv = nn.Conv3d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False)\n        self.bn = nn.BatchNorm3d(nOut, momentum=0.95, eps=1e-03)\n        self.act = nn.ReLU(inplace=True)\n\n    def forward(self, input):\n        output = self.conv(input)\n        output = self.bn(output)\n        output = self.act(output)\n        return output\n\n\nclass CB(nn.Module):\n    def __init__(self, nIn, nOut, kSize, stride=1):\n        super().__init__()\n        padding = int((kSize - 1) / 2)\n        self.conv = nn.Conv3d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False)\n        self.bn = nn.BatchNorm3d(nOut, momentum=0.95, eps=1e-03)\n\n    def forward(self, input):\n        output = self.conv(input)\n        output = self.bn(output)\n        return output\n\n\nclass C(nn.Module):\n    def __init__(self, nIn, nOut, kSize, stride=1, groups=1):\n        super().__init__()\n        padding = int((kSize - 1) / 2)\n        self.conv = nn.Conv3d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, groups=groups)\n\n    def forward(self, input):\n        output = self.conv(input)\n        return output\n\n\nclass DownSamplerA(nn.Module):\n    def __init__(self, nIn, nOut):\n        super().__init__()\n        self.conv = CBR(nIn, nOut, 3, 2)\n\n    def forward(self, input):\n        output = self.conv(input)\n        return output\n\n\nclass DownSamplerB(nn.Module):\n    def __init__(self, nIn, nOut):\n        super().__init__()\n        k = 4\n        n = int(nOut/k)\n        n1 = nOut - (k-1)*n\n        self.c1 = nn.Sequential(CBR(nIn, n, 1, 1), C(n, n, 3, 
2))\n        self.d1 = CDilated(n, n1, 3, 1, 1)\n        self.d2 = CDilated(n, n, 3, 1, 2)\n        self.d4 = CDilated(n, n, 3, 1, 3)\n        self.d8 = CDilated(n, n, 3, 1, 4)\n        self.bn = BR(nOut)\n\n    def forward(self, input):\n        output1 = self.c1(input)\n        d1 = self.d1(output1)\n        d2 = self.d2(output1)\n        d4 = self.d4(output1)\n        d8 = self.d8(output1)\n\n        add1 = d2\n        add2 = add1 + d4\n        add3 = add2 + d8\n\n        combine = torch.cat([d1, add1, add2, add3],1)\n        if input.size() == combine.size():\n            combine = input + combine\n        output = self.bn(combine)\n        return output\n\n\nclass BR(nn.Module):\n    def __init__(self, nOut):\n        super().__init__()\n        self.bn = nn.BatchNorm3d(nOut, momentum=0.95, eps=1e-03)\n        self.act = nn.ReLU(inplace=True)  # nn.PReLU(nOut)\n\n    def forward(self, input):\n        output = self.bn(input)\n        output = self.act(output)\n        return output\n\n\nclass CDilated(nn.Module):\n    def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1):\n        super().__init__()\n        padding = int((kSize - 1) / 2) * d\n        self.conv = nn.Conv3d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, dilation=d, groups=groups)\n        #self.bn = nn.BatchNorm3d(nOut, momentum=0.95, eps=1e-03)\n\n    def forward(self, input):\n        return self.conv(input)\n        #return self.bn(output)\n\n\nclass InputProjectionA(nn.Module):\n    '''\n    This class projects the input image to the same spatial dimensions as the feature map.\n    For example, if the input image is 512 x512 x3 and spatial dimensions of feature map size are 56x56xF, then\n    this class will generate an output of 56x56x3\n    '''\n\n    def __init__(self, samplingTimes):\n        '''\n        :param samplingTimes: The rate at which you want to down-sample the image\n        '''\n        super().__init__()\n        self.pool = nn.ModuleList()\n        
for i in range(0, samplingTimes):\n            # pyramid-based approach for down-sampling\n            self.pool.append(nn.AvgPool3d(3, stride=2, padding=1))\n\n    def forward(self, input):\n        '''\n        :param input: Input RGB Image\n        :return: down-sampled image (pyramid-based approach)\n        '''\n        for pool in self.pool:\n            input = pool(input)\n        return input\n\n\nclass DilatedParllelResidualBlockB1(nn.Module):  # with k=4\n    def __init__(self, nIn, nOut, stride=1):\n        super().__init__()\n        k = 4\n        n = int(nOut / k)\n        n1 = nOut - (k - 1) * n\n        self.c1 = CBR(nIn, n, 1, 1)\n        self.d1 = CDilated(n, n1, 3, stride, 1)\n        self.d2 = CDilated(n, n, 3, stride, 1)\n        self.d4 = CDilated(n, n, 3, stride, 2)\n        self.d8 = CDilated(n, n, 3, stride, 2)\n        self.bn = nn.BatchNorm3d(nOut)\n\n    def forward(self, input):\n        output1 = self.c1(input)\n        d1 = self.d1(output1)\n        d2 = self.d2(output1)\n        d4 = self.d4(output1)\n        d8 = self.d8(output1)\n\n        add1 = d2\n        add2 = add1 + d4\n        add3 = add2 + d8\n\n        combine = self.bn(torch.cat([d1, add1, add2, add3], 1))\n        if input.size() == combine.size():\n            combine = input + combine\n        output = F.relu(combine, inplace=True)\n        return output\n\nclass ASPBlock(nn.Module):  # with k=4\n    def __init__(self, nIn, nOut, stride=1):\n        super().__init__()\n        self.d1 = CB(nIn, nOut, 3, 1)\n        self.d2 = CB(nIn, nOut, 5, 1)\n        self.d4 = CB(nIn, nOut, 7, 1)\n        self.d8 = CB(nIn, nOut, 9, 1)\n        self.act = nn.ReLU(inplace=True)\n\n    def forward(self, input):\n        d1 = self.d1(input)\n        d2 = self.d2(input)\n        d3 = self.d4(input)\n        d4 = self.d8(input)\n\n        combine = d1 + d2 + d3 + d4\n        if input.size() == combine.size():\n            combine = input + combine\n        output = self.act(combine)\n    
    return output\n\n\nclass UpSampler(nn.Module):\n    '''\n    Up-sample the feature maps by 2\n    '''\n    def __init__(self, nIn, nOut):\n        super().__init__()\n        self.up = CBR(nIn, nOut, 3, 1)\n\n    def forward(self, inp):\n        return F.upsample(self.up(inp), mode='trilinear', scale_factor=2, align_corners=True)\n\n\nclass PSPDec(nn.Module):\n    '''\n    Inspired or Adapted from Pyramid Scene Network paper\n    '''\n\n    def __init__(self, nIn, nOut, downSize):\n        super().__init__()\n        self.scale = downSize\n        self.features = CBR(nIn, nOut, 3, 1)\n    def forward(self, x):\n        assert x.dim() == 5\n        inp_size = x.size()\n        out_dim1, out_dim2, out_dim3 = int(inp_size[2] * self.scale), int(inp_size[3] * self.scale), int(inp_size[4] * self.scale)\n        x_down = F.adaptive_avg_pool3d(x, output_size=(out_dim1, out_dim2, out_dim3))\n        return F.upsample(self.features(x_down), size=(inp_size[2], inp_size[3], inp_size[4]), mode='trilinear', align_corners=True)\n\nclass ESPNet(nn.Module):\n    def __init__(self, in_channels, num_classes):\n        super().__init__()\n        self.input1 = InputProjectionA(1)\n        self.input2 = InputProjectionA(1)\n\n        initial = 16 # feature maps at level 1\n        config = [32, 128, 256, 256] # feature maps at level 2 and onwards\n        reps = [2, 2, 3]\n\n        ### ENCODER\n\n        # all dimensions are listed with respect to an input  of size 4 x 128 x 128 x 128\n        self.level0 = CBR(in_channels, initial, 7, 2) # initial x 64 x 64 x64\n        self.level1 = nn.ModuleList()\n        for i in range(reps[0]):\n            if i==0:\n                self.level1.append(DilatedParllelResidualBlockB1(initial, config[0]))  # config[0] x 64 x 64 x64\n            else:\n                self.level1.append(DilatedParllelResidualBlockB1(config[0], config[0]))  # config[0] x 64 x 64 x64\n\n        # downsample the feature maps\n        self.level2 = 
DilatedParllelResidualBlockB1(config[0], config[1], stride=2) # config[1] x 32 x 32 x 32\n        self.level_2 = nn.ModuleList()\n        for i in range(0, reps[1]):\n            self.level_2.append(DilatedParllelResidualBlockB1(config[1], config[1])) # config[1] x 32 x 32 x 32\n\n        # downsample the feature maps\n        self.level3_0 = DilatedParllelResidualBlockB1(config[1], config[2], stride=2) # config[2] x 16 x 16 x 16\n        self.level_3 = nn.ModuleList()\n        for i in range(0, reps[2]):\n            self.level_3.append(DilatedParllelResidualBlockB1(config[2], config[2])) # config[2] x 16 x 16 x 16\n\n\n        ### DECODER\n\n        # upsample the feature maps\n        self.up_l3_l2 = UpSampler(config[2], config[1])  # config[1] x 32 x 32 x 32\n        # Note the 2 in below line. You need this because you are concatenating feature maps from encoder\n        # with upsampled feature maps\n        self.merge_l2 = DilatedParllelResidualBlockB1(2 * config[1], config[1]) # config[1] x 32 x 32 x 32\n        self.dec_l2 = nn.ModuleList()\n        for i in range(0, reps[0]):\n            self.dec_l2.append(DilatedParllelResidualBlockB1(config[1], config[1])) # config[1] x 32 x 32 x 32\n\n        self.up_l2_l1 = UpSampler(config[1], config[0])  # config[0] x 64 x 64 x 64\n        # Note the 2 in below line. You need this because you are concatenating feature maps from encoder\n        # with upsampled feature maps\n        self.merge_l1 = DilatedParllelResidualBlockB1(2*config[0], config[0]) # config[0] x 64 x 64 x 64\n        self.dec_l1 = nn.ModuleList()\n        for i in range(0, reps[0]):\n            self.dec_l1.append(DilatedParllelResidualBlockB1(config[0], config[0])) # config[0] x 64 x 64 x 64\n\n        self.dec_l1.append(CBR(config[0], num_classes, 3, 1)) # classes x 64 x 64 x 64\n        # We use ESP block without reduction step because the number  of input feature maps are very small (i.e. 
4 in\n        # our case)\n        self.dec_l1.append(ASPBlock(num_classes, num_classes))\n\n        # Using PSP module to learn the representations at different scales\n        self.pspModules = nn.ModuleList()\n        scales = [0.2, 0.4, 0.6, 0.8]\n        for sc in scales:\n             self.pspModules.append(PSPDec(num_classes, num_classes, sc))\n\n        # Classifier\n        self.classifier = self.classifier = nn.Sequential(\n             CBR((len(scales) + 1) * num_classes, num_classes, 3, 1),\n             ASPBlock(num_classes, num_classes), # classes x 64 x 64 x 64\n             nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True), # classes x 128 x 128 x 128\n             CBR(num_classes, num_classes, 7, 1), # classes x 128 x 128 x 128\n             C(num_classes, num_classes, 1, 1) # classes x 128 x 128 x 128\n        )\n        #\n\n        for m in self.modules():\n             if isinstance(m, nn.Conv3d):\n                 n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels\n                 m.weight.data.normal_(0, math.sqrt(2. / n))\n             if isinstance(m, nn.ConvTranspose3d):\n                 n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels\n                 m.weight.data.normal_(0, math.sqrt(2. 
/ n))\n             elif isinstance(m, nn.BatchNorm3d):\n                 m.weight.data.fill_(1)\n                 m.bias.data.zero_()\n\n    def forward(self, input1, inp_res=(128, 128, 128), inpSt2=False):\n        dim0 = input1.size(2)\n        dim1 = input1.size(3)\n        dim2 = input1.size(4)\n\n        if self.training or inp_res is None:\n            # input resolution should be divisible by 8\n            inp_res = (math.ceil(dim0 / 8) * 8, math.ceil(dim1 / 8) * 8,\n                       math.ceil(dim2 / 8) * 8)\n        if inp_res:\n            input1 = F.adaptive_avg_pool3d(input1, output_size=inp_res)\n\n        out_l0 = self.level0(input1)\n\n        for i, layer in enumerate(self.level1): #64\n            if i == 0:\n                out_l1 = layer(out_l0)\n            else:\n                out_l1 = layer(out_l1)\n\n        out_l2_down = self.level2(out_l1) #32\n        for i, layer in enumerate(self.level_2):\n            if i == 0:\n                out_l2 = layer(out_l2_down)\n            else:\n                out_l2 = layer(out_l2)\n        del out_l2_down\n\n        out_l3_down = self.level3_0(out_l2) #16\n        for i, layer in enumerate(self.level_3):\n            if i == 0:\n                out_l3 = layer(out_l3_down)\n            else:\n                out_l3 = layer(out_l3)\n        del out_l3_down\n\n        dec_l3_l2 = self.up_l3_l2(out_l3)\n        merge_l2 = self.merge_l2(torch.cat([dec_l3_l2, out_l2], 1))\n        for i, layer in enumerate(self.dec_l2):\n            if i == 0:\n                dec_l2 = layer(merge_l2)\n            else:\n                dec_l2 = layer(dec_l2)\n\n        dec_l2_l1 = self.up_l2_l1(dec_l2)\n        merge_l1 = self.merge_l1(torch.cat([dec_l2_l1, out_l1], 1))\n        for i, layer in enumerate(self.dec_l1):\n            if i == 0:\n                dec_l1 = layer(merge_l1)\n            else:\n                dec_l1 = layer(dec_l1)\n\n        psp_outs = dec_l1.clone()\n        for layer in 
self.pspModules:\n            out_psp = layer(dec_l1)\n            psp_outs = torch.cat([psp_outs, out_psp], 1)\n\n        decoded = self.classifier(psp_outs)\n        return F.upsample(decoded, size=(dim0, dim1, dim2), mode='trilinear', align_corners=True)\n\ndef espnet3d(in_channels, num_classes):\n    model = ESPNet(in_channels, num_classes)\n    return model\n\n\n# if __name__ == '__main__':\n#\n#     criterion = segmentation_loss('dice', False)\n#\n#     mask = torch.ones(2, 96, 48, 96).long()\n#     model = espnet3d(1, 10)\n#     model.train()\n#     input = torch.rand(2, 1, 96, 48, 96)\n#     output = model(input)\n#     loss_train = criterion(output, mask)\n#     output = output.data.cpu().numpy()\n#     loss_train.backward()\n#     print(output.shape)\n#     print(loss_train)"
  },
  {
    "path": "models/networks_3d/res_unet3d.py",
    "content": "import torch\nimport torch.nn as nn\nimport os\nfrom torch.nn import init\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\n\nclass UNet(nn.Module):\n    \"\"\"\n    Implementations based on the Unet3D paper: https://arxiv.org/pdf/1706.00120.pdf\n    \"\"\"\n\n    def __init__(self, in_channels, n_classes, base_n_filter=8):\n        super(UNet, self).__init__()\n        self.in_channels = in_channels\n        self.n_classes = n_classes\n        self.base_n_filter = base_n_filter\n\n        self.lrelu = nn.LeakyReLU()\n        self.dropout3d = nn.Dropout3d(p=0.6)\n        self.upsacle = nn.Upsample(scale_factor=2, mode='nearest')\n\n        self.conv3d_c1_1 = nn.Conv3d(self.in_channels, self.base_n_filter, kernel_size=3, stride=1, padding=1, bias=False)\n        self.conv3d_c1_2 = nn.Conv3d(self.base_n_filter, self.base_n_filter, kernel_size=3, stride=1, padding=1, bias=False)\n        self.lrelu_conv_c1 = 
self.lrelu_conv(self.base_n_filter, self.base_n_filter)\n        self.inorm3d_c1 = nn.InstanceNorm3d(self.base_n_filter)\n\n        self.conv3d_c2 = nn.Conv3d(self.base_n_filter, self.base_n_filter * 2, kernel_size=3, stride=2, padding=1, bias=False)\n        self.norm_lrelu_conv_c2 = self.norm_lrelu_conv(self.base_n_filter * 2, self.base_n_filter * 2)\n        self.inorm3d_c2 = nn.InstanceNorm3d(self.base_n_filter * 2)\n\n        self.conv3d_c3 = nn.Conv3d(self.base_n_filter * 2, self.base_n_filter * 4, kernel_size=3, stride=2, padding=1, bias=False)\n        self.norm_lrelu_conv_c3 = self.norm_lrelu_conv(self.base_n_filter * 4, self.base_n_filter * 4)\n        self.inorm3d_c3 = nn.InstanceNorm3d(self.base_n_filter * 4)\n\n        self.conv3d_c4 = nn.Conv3d(self.base_n_filter * 4, self.base_n_filter * 8, kernel_size=3, stride=2, padding=1, bias=False)\n        self.norm_lrelu_conv_c4 = self.norm_lrelu_conv(self.base_n_filter * 8, self.base_n_filter * 8)\n        self.inorm3d_c4 = nn.InstanceNorm3d(self.base_n_filter * 8)\n\n        self.conv3d_c5 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 16, kernel_size=3, stride=2, padding=1, bias=False)\n        self.norm_lrelu_conv_c5 = self.norm_lrelu_conv(self.base_n_filter * 16, self.base_n_filter * 16)\n        self.norm_lrelu_upscale_conv_norm_lrelu_l0 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 16, self.base_n_filter * 8)\n\n        self.conv3d_l0 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 8, kernel_size=1, stride=1, padding=0, bias=False)\n        self.inorm3d_l0 = nn.InstanceNorm3d(self.base_n_filter * 8)\n\n        self.conv_norm_lrelu_l1 = self.conv_norm_lrelu(self.base_n_filter * 16, self.base_n_filter * 16)\n        self.conv3d_l1 = nn.Conv3d(self.base_n_filter * 16, self.base_n_filter * 8, kernel_size=1, stride=1, padding=0, bias=False)\n        self.norm_lrelu_upscale_conv_norm_lrelu_l1 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 8, 
self.base_n_filter * 4)\n\n        self.conv_norm_lrelu_l2 = self.conv_norm_lrelu(self.base_n_filter * 8, self.base_n_filter * 8)\n        self.conv3d_l2 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 4, kernel_size=1, stride=1, padding=0,\n                                   bias=False)\n        self.norm_lrelu_upscale_conv_norm_lrelu_l2 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 4, self.base_n_filter * 2)\n\n        self.conv_norm_lrelu_l3 = self.conv_norm_lrelu(self.base_n_filter * 4, self.base_n_filter * 4)\n        self.conv3d_l3 = nn.Conv3d(self.base_n_filter * 4, self.base_n_filter * 2, kernel_size=1, stride=1, padding=0, bias=False)\n        self.norm_lrelu_upscale_conv_norm_lrelu_l3 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 2, self.base_n_filter)\n\n        self.conv_norm_lrelu_l4 = self.conv_norm_lrelu(self.base_n_filter * 2, self.base_n_filter * 2)\n        self.conv3d_l4 = nn.Conv3d(self.base_n_filter * 2, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)\n\n        self.ds2_1x1_conv3d = nn.Conv3d(self.base_n_filter * 8, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)\n        self.ds3_1x1_conv3d = nn.Conv3d(self.base_n_filter * 4, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)\n\n    def conv_norm_lrelu(self, feat_in, feat_out):\n        return nn.Sequential(\n            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),\n            nn.InstanceNorm3d(feat_out),\n            nn.LeakyReLU())\n\n    def norm_lrelu_conv(self, feat_in, feat_out):\n        return nn.Sequential(\n            nn.InstanceNorm3d(feat_in),\n            nn.LeakyReLU(),\n            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))\n\n    def lrelu_conv(self, feat_in, feat_out):\n        return nn.Sequential(\n            nn.LeakyReLU(),\n            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))\n\n 
   def norm_lrelu_upscale_conv_norm_lrelu(self, feat_in, feat_out):\n        return nn.Sequential(\n            nn.InstanceNorm3d(feat_in),\n            nn.LeakyReLU(),\n            nn.Upsample(scale_factor=2, mode='nearest'),\n            # should be feat_in*2 or feat_in\n            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),\n            nn.InstanceNorm3d(feat_out),\n            nn.LeakyReLU())\n\n    def forward(self, x):\n        #  Level 1 context pathway\n        out = self.conv3d_c1_1(x)\n        residual_1 = out\n        out = self.lrelu(out)\n        out = self.conv3d_c1_2(out)\n        out = self.dropout3d(out)\n        out = self.lrelu_conv_c1(out)\n        # Element Wise Summation\n        out += residual_1\n        context_1 = self.lrelu(out)\n        out = self.inorm3d_c1(out)\n        out = self.lrelu(out)\n\n        # Level 2 context pathway\n        out = self.conv3d_c2(out)\n        residual_2 = out\n        out = self.norm_lrelu_conv_c2(out)\n        out = self.dropout3d(out)\n        out = self.norm_lrelu_conv_c2(out)\n        out += residual_2\n        out = self.inorm3d_c2(out)\n        out = self.lrelu(out)\n        context_2 = out\n\n        # Level 3 context pathway\n        out = self.conv3d_c3(out)\n        residual_3 = out\n        out = self.norm_lrelu_conv_c3(out)\n        out = self.dropout3d(out)\n        out = self.norm_lrelu_conv_c3(out)\n        out += residual_3\n        out = self.inorm3d_c3(out)\n        out = self.lrelu(out)\n        context_3 = out\n\n        # Level 4 context pathway\n        out = self.conv3d_c4(out)\n        residual_4 = out\n        out = self.norm_lrelu_conv_c4(out)\n        out = self.dropout3d(out)\n        out = self.norm_lrelu_conv_c4(out)\n        out += residual_4\n        out = self.inorm3d_c4(out)\n        out = self.lrelu(out)\n        context_4 = out\n\n        # Level 5\n        out = self.conv3d_c5(out)\n        residual_5 = out\n        out = 
self.norm_lrelu_conv_c5(out)\n        out = self.dropout3d(out)\n        out = self.norm_lrelu_conv_c5(out)\n        out += residual_5\n        out = self.norm_lrelu_upscale_conv_norm_lrelu_l0(out)\n\n        out = self.conv3d_l0(out)\n        out = self.inorm3d_l0(out)\n        out = self.lrelu(out)\n\n        # Level 1 localization pathway\n        out = torch.cat([out, context_4], dim=1)\n        out = self.conv_norm_lrelu_l1(out)\n        out = self.conv3d_l1(out)\n        out = self.norm_lrelu_upscale_conv_norm_lrelu_l1(out)\n\n        # Level 2 localization pathway\n        # print(out.shape)\n        # print(context_3.shape)\n        out = torch.cat([out, context_3], dim=1)\n        out = self.conv_norm_lrelu_l2(out)\n        ds2 = out\n        out = self.conv3d_l2(out)\n        out = self.norm_lrelu_upscale_conv_norm_lrelu_l2(out)\n\n        # Level 3 localization pathway\n        out = torch.cat([out, context_2], dim=1)\n        out = self.conv_norm_lrelu_l3(out)\n        ds3 = out\n        out = self.conv3d_l3(out)\n        out = self.norm_lrelu_upscale_conv_norm_lrelu_l3(out)\n\n        # Level 4 localization pathway\n        out = torch.cat([out, context_1], dim=1)\n        out = self.conv_norm_lrelu_l4(out)\n        out_pred = self.conv3d_l4(out)\n\n        ds2_1x1_conv = self.ds2_1x1_conv3d(ds2)\n        ds1_ds2_sum_upscale = self.upsacle(ds2_1x1_conv)\n        ds3_1x1_conv = self.ds3_1x1_conv3d(ds3)\n        ds1_ds2_sum_upscale_ds3_sum = ds1_ds2_sum_upscale + ds3_1x1_conv\n        ds1_ds2_sum_upscale_ds3_sum_upscale = self.upsacle(ds1_ds2_sum_upscale_ds3_sum)\n\n        out = out_pred + ds1_ds2_sum_upscale_ds3_sum_upscale\n        seg_layer = out\n        return seg_layer\n\ndef res_unet3d(in_channels, num_classes):\n    model = UNet(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\n# if __name__ == '__main__':\n#     model = res_unet3d(1,10)\n#     model.eval()\n#     input = torch.rand(2, 1, 128, 128, 128)\n#     
output = model(input)\n#     output = output.data.cpu().numpy()\n#     # print(output)\n#     print(output.shape)\n"
  },
  {
    "path": "models/networks_3d/transbts.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\nimport torch.nn.functional as F\nfrom loss.loss_function import segmentation_loss\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\ndef normalization(planes, norm='gn'):\n    if norm == 'bn':\n        m = nn.BatchNorm3d(planes)\n    elif norm == 'gn':\n        m = nn.GroupNorm(8, planes)\n    elif norm == 'in':\n        m = nn.InstanceNorm3d(planes)\n    else:\n        raise ValueError('normalization type {} is not supported'.format(norm))\n    return m\n\nclass InitConv(nn.Module):\n    def __init__(self, in_channels=4, out_channels=16, dropout=0.2):\n        super(InitConv, self).__init__()\n\n        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)\n        self.dropout = dropout\n\n    def forward(self, x):\n        y = self.conv(x)\n        y = F.dropout3d(y, self.dropout)\n\n        return y\n\nclass 
EnBlock(nn.Module):\n    def __init__(self, in_channels, norm='gn'):\n        super(EnBlock, self).__init__()\n\n        self.bn1 = normalization(in_channels, norm=norm)\n        self.relu1 = nn.ReLU(inplace=True)\n        self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)\n\n        self.bn2 = normalization(in_channels, norm=norm)\n        self.relu2 = nn.ReLU(inplace=True)\n        self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)\n\n    def forward(self, x):\n        x1 = self.bn1(x)\n        x1 = self.relu1(x1)\n        x1 = self.conv1(x1)\n        y = self.bn2(x1)\n        y = self.relu2(y)\n        y = self.conv2(y)\n        y = y + x\n\n        return y\n\nclass EnDown(nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super(EnDown, self).__init__()\n        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)\n\n    def forward(self, x):\n        y = self.conv(x)\n\n        return y\n\nclass Unet(nn.Module):\n    def __init__(self, in_channels=4, base_channels=16):\n        super(Unet, self).__init__()\n\n        self.InitConv = InitConv(in_channels=in_channels, out_channels=base_channels, dropout=0.2)\n        self.EnBlock1 = EnBlock(in_channels=base_channels)\n        self.EnDown1 = EnDown(in_channels=base_channels, out_channels=base_channels*2)\n\n        self.EnBlock2_1 = EnBlock(in_channels=base_channels*2)\n        self.EnBlock2_2 = EnBlock(in_channels=base_channels*2)\n        self.EnDown2 = EnDown(in_channels=base_channels*2, out_channels=base_channels*4)\n\n        self.EnBlock3_1 = EnBlock(in_channels=base_channels * 4)\n        self.EnBlock3_2 = EnBlock(in_channels=base_channels * 4)\n        self.EnDown3 = EnDown(in_channels=base_channels*4, out_channels=base_channels*8)\n\n        self.EnBlock4_1 = EnBlock(in_channels=base_channels * 8)\n        self.EnBlock4_2 = EnBlock(in_channels=base_channels * 8)\n        self.EnBlock4_3 = 
EnBlock(in_channels=base_channels * 8)\n        self.EnBlock4_4 = EnBlock(in_channels=base_channels * 8)\n\n    def forward(self, x):\n        x = self.InitConv(x)       # (1, 16, 128, 128, 128)\n\n        x1_1 = self.EnBlock1(x)\n        x1_2 = self.EnDown1(x1_1)  # (1, 32, 64, 64, 64)\n\n        x2_1 = self.EnBlock2_1(x1_2)\n        x2_1 = self.EnBlock2_2(x2_1)\n        x2_2 = self.EnDown2(x2_1)  # (1, 64, 32, 32, 32)\n\n        x3_1 = self.EnBlock3_1(x2_2)\n        x3_1 = self.EnBlock3_2(x3_1)\n        x3_2 = self.EnDown3(x3_1)  # (1, 128, 16, 16, 16)\n\n        x4_1 = self.EnBlock4_1(x3_2)\n        x4_2 = self.EnBlock4_2(x4_1)\n        x4_3 = self.EnBlock4_3(x4_2)\n        output = self.EnBlock4_4(x4_3)  # (1, 128, 16, 16, 16)\n\n        return x1_1, x2_1, x3_1, output\n\nclass FixedPositionalEncoding(nn.Module):\n    def __init__(self, embedding_dim, max_length=512):\n        super(FixedPositionalEncoding, self).__init__()\n\n        pe = torch.zeros(max_length, embedding_dim)\n        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)\n        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / embedding_dim))\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n        pe = pe.unsqueeze(0).transpose(0, 1)\n        self.register_buffer('pe', pe)\n\n    def forward(self, x):\n        x = x + self.pe[: x.size(0), :]\n        return x\n\n\nclass LearnedPositionalEncoding(nn.Module):\n    def __init__(self, max_position_embeddings, embedding_dim):\n        super(LearnedPositionalEncoding, self).__init__()\n\n        self.position_embeddings = nn.Parameter(torch.zeros(1, max_position_embeddings, embedding_dim)) #8x\n\n    def forward(self, x):\n\n        position_embeddings = self.position_embeddings\n        return x + position_embeddings\n\nclass IntermediateSequential(nn.Sequential):\n    def __init__(self, *args, 
return_intermediate=True):\n        super().__init__(*args)\n        self.return_intermediate = return_intermediate\n\n    def forward(self, input):\n        if not self.return_intermediate:\n            return super().forward(input)\n\n        intermediate_outputs = {}\n        output = input\n        for name, module in self.named_children():\n            output = intermediate_outputs[name] = module(output)\n\n        return output, intermediate_outputs\n\nclass SelfAttention(nn.Module):\n    def __init__(\n        self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0\n    ):\n        super().__init__()\n        self.num_heads = heads\n        head_dim = dim // heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(dropout_rate)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(dropout_rate)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        qkv = (self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4))\n        q, k, v = (qkv[0], qkv[1], qkv[2])  # make torchscript happy (cannot use tensor as tuple)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\nclass Residual(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n\n    def forward(self, x):\n        return self.fn(x) + x\n\nclass PreNorm(nn.Module):\n    def __init__(self, dim, fn):\n        super().__init__()\n        self.norm = nn.LayerNorm(dim)\n        self.fn = fn\n\n    def forward(self, x):\n        return self.fn(self.norm(x))\n\n\nclass PreNormDrop(nn.Module):\n    def __init__(self, dim, dropout_rate, fn):\n        super().__init__()\n        self.norm = 
nn.LayerNorm(dim)\n        self.dropout = nn.Dropout(p=dropout_rate)\n        self.fn = fn\n\n    def forward(self, x):\n        return self.dropout(self.fn(self.norm(x)))\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, dim, hidden_dim, dropout_rate):\n        super().__init__()\n        self.net = nn.Sequential(\n            nn.Linear(dim, hidden_dim),\n            nn.GELU(),\n            nn.Dropout(p=dropout_rate),\n            nn.Linear(hidden_dim, dim),\n            nn.Dropout(p=dropout_rate),\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\nclass TransformerModel(nn.Module):\n    def __init__(self,dim,depth,heads,mlp_dim,dropout_rate=0.1,attn_dropout_rate=0.1):\n        super().__init__()\n        layers = []\n        for _ in range(depth):\n            layers.extend([\n                Residual(PreNormDrop(dim,dropout_rate,SelfAttention(dim, heads=heads, dropout_rate=attn_dropout_rate))),\n                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate)))])\n            # dim = dim / 2\n        self.net = IntermediateSequential(*layers)\n\n    def forward(self, x):\n        return self.net(x)\n\n\nclass TransformerBTS(nn.Module):\n    def __init__(\n        self,\n        img_dim,\n        patch_dim,\n        num_channels,\n        embedding_dim,\n        num_heads,\n        num_layers,\n        hidden_dim,\n        dropout_rate=0.0,\n        attn_dropout_rate=0.0,\n        conv_patch_representation=True,\n        positional_encoding_type=\"learned\",\n    ):\n        super(TransformerBTS, self).__init__()\n\n        assert embedding_dim % num_heads == 0\n        assert img_dim[0] % patch_dim == 0\n        assert img_dim[1] % patch_dim == 0\n        assert img_dim[2] % patch_dim == 0\n\n        self.img_dim = img_dim\n        self.embedding_dim = embedding_dim\n        self.num_heads = num_heads\n        self.patch_dim = patch_dim\n        self.num_channels = num_channels\n        self.dropout_rate = dropout_rate\n 
       self.attn_dropout_rate = attn_dropout_rate\n        self.conv_patch_representation = conv_patch_representation\n\n        self.num_patches = int((img_dim[0] // patch_dim) * (img_dim[1] // patch_dim) * (img_dim[2] // patch_dim))\n        self.seq_length = self.num_patches\n        self.flatten_dim = 128 * num_channels\n\n        self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)\n        if positional_encoding_type == \"learned\":\n            self.position_encoding = LearnedPositionalEncoding(self.seq_length, self.embedding_dim)\n        elif positional_encoding_type == \"fixed\":\n            self.position_encoding = FixedPositionalEncoding(self.embedding_dim)\n\n        self.pe_dropout = nn.Dropout(p=self.dropout_rate)\n\n        self.transformer = TransformerModel(embedding_dim,num_layers,num_heads,hidden_dim,self.dropout_rate,self.attn_dropout_rate)\n        self.pre_head_ln = nn.LayerNorm(embedding_dim)\n\n        if self.conv_patch_representation:\n            self.conv_x = nn.Conv3d(128, self.embedding_dim, kernel_size=3, stride=1, padding=1)\n\n        self.Unet = Unet(in_channels=num_channels, base_channels=16)\n        self.bn = nn.BatchNorm3d(128)\n        self.relu = nn.ReLU(inplace=True)\n\n    def encode(self, x):\n        if self.conv_patch_representation:\n            # combine embedding with conv patch distribution\n            x1_1, x2_1, x3_1, x = self.Unet(x)\n            x = self.bn(x)\n            x = self.relu(x)\n            x = self.conv_x(x)\n            x = x.permute(0, 2, 3, 4, 1).contiguous()\n            x = x.view(x.size(0), -1, self.embedding_dim)\n\n        else:\n            x = self.Unet(x)\n            x = self.bn(x)\n            x = self.relu(x)\n            x = (\n                x.unfold(2, 2, 2)\n                .unfold(3, 2, 2)\n                .unfold(4, 2, 2)\n                .contiguous()\n            )\n            x = x.view(x.size(0), x.size(1), -1, 8)\n            x = x.permute(0, 2, 3, 
1).contiguous()\n            x = x.view(x.size(0), -1, self.flatten_dim)\n            x = self.linear_encoding(x)\n\n        x = self.position_encoding(x)\n        x = self.pe_dropout(x)\n\n        # apply transformer\n        x, intmd_x = self.transformer(x)\n        x = self.pre_head_ln(x)\n\n        return x1_1, x2_1, x3_1, x, intmd_x\n\n\n    def forward(self, x, auxillary_output_layers=[1, 2, 3, 4]):\n\n        x1_1, x2_1, x3_1, encoder_output, intmd_encoder_outputs = self.encode(x)\n\n        decoder_output = self.decode(x1_1, x2_1, x3_1, encoder_output, intmd_encoder_outputs, auxillary_output_layers)\n\n        if auxillary_output_layers is not None:\n            auxillary_outputs = {}\n            for i in auxillary_output_layers:\n                val = str(2 * i - 1)\n                _key = 'Z' + str(i)\n                auxillary_outputs[_key] = intmd_encoder_outputs[val]\n\n            return decoder_output\n\n        return decoder_output\n\n    # def _get_padding(self, padding_type, kernel_size):\n    #     assert padding_type in ['SAME', 'VALID']\n    #     if padding_type == 'SAME':\n    #         _list = [(k - 1) // 2 for k in kernel_size]\n    #         return tuple(_list)\n    #     return tuple(0 for _ in kernel_size)\n\n    def _reshape_output(self, x):\n        x = x.view(\n            x.size(0),\n            int(self.img_dim[0] / self.patch_dim),\n            int(self.img_dim[1] / self.patch_dim),\n            int(self.img_dim[2] / self.patch_dim),\n            self.embedding_dim,\n        )\n        x = x.permute(0, 4, 1, 2, 3).contiguous()\n\n        return x\n\n\nclass BTS(TransformerBTS):\n    def __init__(self,\n                 in_channels,\n                 num_classes,\n                 img_shape=(128, 128, 128),\n                 patch_dim=8,\n                 embedding_dim=512,\n                 num_heads=8,\n                 num_layers=4,\n                 hidden_dim=4096,\n                 dropout_rate=0.1,\n                 
attn_dropout_rate=0.1,\n                 conv_patch_representation=True,\n                 positional_encoding_type=\"learned\"):\n        super(BTS, self).__init__(\n            img_dim=img_shape,\n            patch_dim=patch_dim,\n            num_channels=in_channels,\n            embedding_dim=embedding_dim,\n            num_heads=num_heads,\n            num_layers=num_layers,\n            hidden_dim=hidden_dim,\n            dropout_rate=dropout_rate,\n            attn_dropout_rate=attn_dropout_rate,\n            conv_patch_representation=conv_patch_representation,\n            positional_encoding_type=positional_encoding_type,\n        )\n\n        self.Enblock8_1 = EnBlock1(in_channels=self.embedding_dim)\n        self.Enblock8_2 = EnBlock2(in_channels=self.embedding_dim // 4)\n\n        self.DeUp4 = DeUp_Cat(in_channels=self.embedding_dim//4, out_channels=self.embedding_dim//8)\n        self.DeBlock4 = DeBlock(in_channels=self.embedding_dim//8)\n\n        self.DeUp3 = DeUp_Cat(in_channels=self.embedding_dim//8, out_channels=self.embedding_dim//16)\n        self.DeBlock3 = DeBlock(in_channels=self.embedding_dim//16)\n\n        self.DeUp2 = DeUp_Cat(in_channels=self.embedding_dim//16, out_channels=self.embedding_dim//32)\n        self.DeBlock2 = DeBlock(in_channels=self.embedding_dim//32)\n\n        self.endconv = nn.Conv3d(self.embedding_dim // 32, num_classes, kernel_size=1)\n\n    def decode(self, x1_1, x2_1, x3_1, x, intmd_x, intmd_layers=[1, 2, 3, 4]):\n\n        assert intmd_layers is not None, \"pass the intermediate layers for MLA\"\n        encoder_outputs = {}\n        all_keys = []\n        for i in intmd_layers:\n            val = str(2 * i - 1)\n            _key = 'Z' + str(i)\n            all_keys.append(_key)\n            encoder_outputs[_key] = intmd_x[val]\n        all_keys.reverse()\n\n        x8 = encoder_outputs[all_keys[0]]\n        x8 = self._reshape_output(x8)\n        x8 = self.Enblock8_1(x8)\n        x8 = self.Enblock8_2(x8)\n\n        
y4 = self.DeUp4(x8, x3_1)  # (1, 64, 32, 32, 32)\n        y4 = self.DeBlock4(y4)\n\n        y3 = self.DeUp3(y4, x2_1)  # (1, 32, 64, 64, 64)\n        y3 = self.DeBlock3(y3)\n\n        y2 = self.DeUp2(y3, x1_1)  # (1, 16, 128, 128, 128)\n        y2 = self.DeBlock2(y2)\n\n        y = self.endconv(y2)      # (1, 4, 128, 128, 128)\n        return y\n\nclass EnBlock1(nn.Module):\n    def __init__(self, in_channels, ):\n        super(EnBlock1, self).__init__()\n\n        self.bn1 = nn.BatchNorm3d(in_channels // 4)\n        self.relu1 = nn.ReLU(inplace=True)\n        self.bn2 = nn.BatchNorm3d(in_channels // 4)\n        self.relu2 = nn.ReLU(inplace=True)\n        self.conv1 = nn.Conv3d(in_channels, in_channels // 4, kernel_size=3, padding=1)\n        self.conv2 = nn.Conv3d(in_channels // 4, in_channels // 4, kernel_size=3, padding=1)\n\n    def forward(self, x):\n        x1 = self.conv1(x)\n        x1 = self.bn1(x1)\n        x1 = self.relu1(x1)\n        x1 = self.conv2(x1)\n        x1 = self.bn2(x1)\n        x1 = self.relu2(x1)\n\n        return x1\n\nclass EnBlock2(nn.Module):\n    def __init__(self, in_channels):\n        super(EnBlock2, self).__init__()\n\n        self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)\n        self.bn1 = nn.BatchNorm3d(in_channels)\n        self.relu1 = nn.ReLU(inplace=True)\n        self.bn2 = nn.BatchNorm3d(in_channels)\n        self.relu2 = nn.ReLU(inplace=True)\n        self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)\n\n    def forward(self, x):\n        x1 = self.conv1(x)\n        x1 = self.bn1(x1)\n        x1 = self.relu1(x1)\n        x1 = self.conv2(x1)\n        x1 = self.bn2(x1)\n        x1 = self.relu2(x1)\n        x1 = x1 + x\n\n        return x1\n\nclass DeUp_Cat(nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super(DeUp_Cat, self).__init__()\n        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=1)\n        self.conv2 = 
nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2)\n        self.conv3 = nn.Conv3d(out_channels*2, out_channels, kernel_size=1)\n\n    def forward(self, x, prev):\n        x1 = self.conv1(x)\n        y = self.conv2(x1)\n        # y = y + prev\n        y = torch.cat((prev, y), dim=1)\n        y = self.conv3(y)\n        return y\n\nclass DeBlock(nn.Module):\n    def __init__(self, in_channels):\n        super(DeBlock, self).__init__()\n\n        self.bn1 = nn.BatchNorm3d(in_channels)\n        self.relu1 = nn.ReLU(inplace=True)\n        self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)\n        self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)\n        self.bn2 = nn.BatchNorm3d(in_channels)\n        self.relu2 = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        x1 = self.conv1(x)\n        x1 = self.bn1(x1)\n        x1 = self.relu1(x1)\n        x1 = self.conv2(x1)\n        x1 = self.bn2(x1)\n        x1 = self.relu2(x1)\n        x1 = x1 + x\n\n        return x1\n\n\ndef transbts(in_channels, num_classes, **kwargs):\n    model = BTS(in_channels, num_classes, img_shape=kwargs['img_shape'])\n    init_weights(model, 'kaiming')\n    return model\n\n# if __name__ == '__main__':\n#\n#     criterion = segmentation_loss('dice', False)\n#     mask = torch.ones(2, 64, 96, 64).long()\n#     model = transbts(1, 10, img_shape=(64, 96, 64))\n#     model.train()\n#     input = torch.rand(2, 1, 64, 96, 64)\n#     output = model(input)\n#     loss_train = criterion(output, mask)\n#     loss_train.backward()\n#     output = output.data.cpu().numpy()\n#     print(output.shape)\n#     print(loss_train)"
  },
  {
    "path": "models/networks_3d/unet3d.py",
    "content": "import numpy as np\nfrom collections import OrderedDict\nimport torch\nimport torch.nn as nn\nfrom torch.nn import init\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\n\nclass UNet3D(nn.Module):\n    def __init__(self, in_channels=1, out_channels=3, init_features=64):\n        \"\"\"\n        Implementations based on the Unet3D paper: https://arxiv.org/abs/1606.06650\n        \"\"\"\n\n        super(UNet3D, self).__init__()\n\n        features = init_features\n        self.encoder1 = UNet3D._block(in_channels, features, name=\"enc1\")\n        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)\n        self.encoder2 = UNet3D._block(features, features * 2, name=\"enc2\")\n        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)\n        self.encoder3 = UNet3D._block(features * 2, features * 4, name=\"enc3\")\n        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)\n        self.encoder4 = 
UNet3D._block(features * 4, features * 8, name=\"enc4\")\n        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)\n\n        self.bottleneck = UNet3D._block(features * 8, features * 16, name=\"bottleneck\")\n\n        self.upconv4 = nn.ConvTranspose3d(\n            features * 16, features * 8, kernel_size=2, stride=2\n        )\n        self.decoder4 = UNet3D._block((features * 8) * 2 , features * 8, name=\"dec4\")\n        self.upconv3 = nn.ConvTranspose3d(\n            features * 8, features * 4, kernel_size=2, stride=2\n        )\n        self.decoder3 = UNet3D._block((features * 4) * 2, features * 4, name=\"dec3\")\n        self.upconv2 = nn.ConvTranspose3d(\n            features * 4, features * 2, kernel_size=2, stride=2\n        )\n        self.decoder2 = UNet3D._block((features * 2) * 2, features * 2, name=\"dec2\")\n        self.upconv1 = nn.ConvTranspose3d(\n            features * 2, features, kernel_size=2, stride=2\n        )\n        self.decoder1 = UNet3D._block(features * 2, features, name=\"dec1\")\n\n        self.conv = nn.Conv3d(\n            in_channels=features, out_channels=out_channels, kernel_size=1\n        )\n\n    def forward(self, x):\n        enc1 = self.encoder1(x)\n        enc2 = self.encoder2(self.pool1(enc1))\n        enc3 = self.encoder3(self.pool2(enc2))\n        enc4 = self.encoder4(self.pool3(enc3))\n\n        bottleneck = self.bottleneck(self.pool4(enc4))\n\n        dec4 = self.upconv4(bottleneck)\n        dec4 = torch.cat((dec4, enc4), dim=1)\n        dec4 = self.decoder4(dec4)\n        dec3 = self.upconv3(dec4)\n        dec3 = torch.cat((dec3, enc3), dim=1)\n        dec3 = self.decoder3(dec3)\n        dec2 = self.upconv2(dec3)\n        dec2 = torch.cat((dec2, enc2), dim=1)\n        dec2 = self.decoder2(dec2)\n        dec1 = self.upconv1(dec2)\n        dec1 = torch.cat((dec1, enc1), dim=1)\n        dec1 = self.decoder1(dec1)\n        outputs = self.conv(dec1)\n        return outputs\n\n    @staticmethod\n    def 
_block(in_channels, features, name):\n        return nn.Sequential(\n            OrderedDict(\n                [\n                    (\n                        name + \"conv1\",\n                        nn.Conv3d(\n                            in_channels=in_channels,\n                            out_channels=features,\n                            kernel_size=3,\n                            padding=1,\n                            bias=True,\n                        ),\n                    ),\n                    (name + \"norm1\", nn.BatchNorm3d(num_features=features)),\n                    (name + \"relu1\", nn.ReLU(inplace=True)),\n                    (\n                        name + \"conv2\",\n                        nn.Conv3d(\n                            in_channels=features,\n                            out_channels=features,\n                            kernel_size=3,\n                            padding=1,\n                            bias=True,\n                        ),\n                    ),\n                    (name + \"norm2\", nn.BatchNorm3d(num_features=features)),\n                    (name + \"relu2\", nn.ReLU(inplace=True)),\n                ]\n            )\n        )\n\n\nclass UNet3D_min(nn.Module):\n    def __init__(self, in_channels=1, out_channels=3, init_features=32):\n        \"\"\"\n        Implementations based on the Unet3D paper: https://arxiv.org/abs/1606.06650\n        \"\"\"\n\n        super(UNet3D_min, self).__init__()\n\n        features = init_features\n        self.encoder1 = UNet3D._block(in_channels, features, name=\"enc1\")\n        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)\n        self.encoder2 = UNet3D._block(features, features * 2, name=\"enc2\")\n        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)\n        self.encoder3 = UNet3D._block(features * 2, features * 4, name=\"enc3\")\n        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)\n        self.encoder4 = UNet3D._block(features * 4, features * 
8, name=\"enc4\")\n        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)\n\n        self.bottleneck = UNet3D._block(features * 8, features * 16, name=\"bottleneck\")\n\n        self.upconv4 = nn.ConvTranspose3d(\n            features * 16, features * 8, kernel_size=2, stride=2\n        )\n        self.decoder4 = UNet3D._block((features * 8) * 2 , features * 8, name=\"dec4\")\n        self.upconv3 = nn.ConvTranspose3d(\n            features * 8, features * 4, kernel_size=2, stride=2\n        )\n        self.decoder3 = UNet3D._block((features * 4) * 2, features * 4, name=\"dec3\")\n        self.upconv2 = nn.ConvTranspose3d(\n            features * 4, features * 2, kernel_size=2, stride=2\n        )\n        self.decoder2 = UNet3D._block((features * 2) * 2, features * 2, name=\"dec2\")\n        self.upconv1 = nn.ConvTranspose3d(\n            features * 2, features, kernel_size=2, stride=2\n        )\n        self.decoder1 = UNet3D._block(features * 2, features, name=\"dec1\")\n\n        self.conv = nn.Conv3d(\n            in_channels=features, out_channels=out_channels, kernel_size=1\n        )\n\n    def forward(self, x):\n        enc1 = self.encoder1(x)\n        enc2 = self.encoder2(self.pool1(enc1))\n        enc3 = self.encoder3(self.pool2(enc2))\n        enc4 = self.encoder4(self.pool3(enc3))\n\n        bottleneck = self.bottleneck(self.pool4(enc4))\n\n        dec4 = self.upconv4(bottleneck)\n        dec4 = torch.cat((dec4, enc4), dim=1)\n        dec4 = self.decoder4(dec4)\n        dec3 = self.upconv3(dec4)\n        dec3 = torch.cat((dec3, enc3), dim=1)\n        dec3 = self.decoder3(dec3)\n        dec2 = self.upconv2(dec3)\n        dec2 = torch.cat((dec2, enc2), dim=1)\n        dec2 = self.decoder2(dec2)\n        dec1 = self.upconv1(dec2)\n        dec1 = torch.cat((dec1, enc1), dim=1)\n        dec1 = self.decoder1(dec1)\n        outputs = self.conv(dec1)\n        return outputs\n\n    @staticmethod\n    def _block(in_channels, features, name):\n        return 
nn.Sequential(\n            OrderedDict(\n                [\n                    (\n                        name + \"conv1\",\n                        nn.Conv3d(\n                            in_channels=in_channels,\n                            out_channels=features,\n                            kernel_size=3,\n                            padding=1,\n                            bias=True,\n                        ),\n                    ),\n                    (name + \"norm1\", nn.BatchNorm3d(num_features=features)),\n                    (name + \"relu1\", nn.ReLU(inplace=True)),\n                    (\n                        name + \"conv2\",\n                        nn.Conv3d(\n                            in_channels=features,\n                            out_channels=features,\n                            kernel_size=3,\n                            padding=1,\n                            bias=True,\n                        ),\n                    ),\n                    (name + \"norm2\", nn.BatchNorm3d(num_features=features)),\n                    (name + \"relu2\", nn.ReLU(inplace=True)),\n                ]\n            )\n        )\n\ndef unet3d(in_channels, num_classes):\n    model = UNet3D(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\ndef unet3d_min(in_channels, num_classes):\n    model = UNet3D_min(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\n# if __name__ == '__main__':\n#     model = unet3d(1,10)\n#     model.eval()\n#     input = torch.rand(2, 1, 128, 128, 128)\n#     output = model(input)\n#     output = output.data.cpu().numpy()\n#     # print(output)\n#     print(output.shape)\n"
  },
  {
    "path": "models/networks_3d/unet3d_cct.py",
    "content": "import numpy as np\nfrom collections import OrderedDict\nimport torch\nimport torch.nn as nn\nfrom torch.nn import init\nfrom torch.distributions.uniform import Uniform\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\nclass FeatureNoise(nn.Module):\n    def __init__(self, uniform_range=0.3):\n        super(FeatureNoise, self).__init__()\n        self.uni_dist = Uniform(-uniform_range, uniform_range)\n\n    def feature_based_noise(self, x):\n        noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0)\n        x_noise = x.mul(noise_vector) + x\n        return x_noise\n\n    def forward(self, x):\n        x = self.feature_based_noise(x)\n        return x\n\ndef Dropout(x, p=0.3):\n    x = torch.nn.functional.dropout(x, p)\n    return x\n\ndef FeatureDropout(x):\n    attention = torch.mean(x, dim=1, keepdim=True)\n    max_val, _ = torch.max(attention.view(x.size(0), -1), dim=1, 
keepdim=True)\n    threshold = max_val * np.random.uniform(0.7, 0.9)\n    threshold = threshold.view(x.size(0), 1, 1, 1, 1).expand_as(attention)\n    drop_mask = (attention < threshold).float()\n    x = x.mul(drop_mask)\n    return x\n\nclass Decoder(nn.Module):\n    def __init__(self, features, out_channels):\n        super(Decoder, self).__init__()\n\n        self.upconv4 = nn.ConvTranspose3d(features * 16, features * 8, kernel_size=2, stride=2)\n        self.decoder4 = Decoder._block((features * 8) * 2, features * 8, name=\"dec4\")\n        self.upconv3 = nn.ConvTranspose3d(features * 8, features * 4, kernel_size=2, stride=2)\n        self.decoder3 = Decoder._block((features * 4) * 2, features * 4, name=\"dec3\")\n        self.upconv2 = nn.ConvTranspose3d(features * 4, features * 2, kernel_size=2, stride=2)\n        self.decoder2 = Decoder._block((features * 2) * 2, features * 2, name=\"dec2\")\n        self.upconv1 = nn.ConvTranspose3d(features * 2, features, kernel_size=2, stride=2)\n        self.decoder1 = Decoder._block(features * 2, features, name=\"dec1\")\n\n        self.conv = nn.Conv3d(in_channels=features, out_channels=out_channels, kernel_size=1)\n\n    def forward(self, x5, x4, x3, x2, x1):\n\n        dec4 = self.upconv4(x5)\n        dec4 = torch.cat((dec4, x4), dim=1)\n        dec4 = self.decoder4(dec4)\n        dec3 = self.upconv3(dec4)\n        dec3 = torch.cat((dec3, x3), dim=1)\n        dec3 = self.decoder3(dec3)\n        dec2 = self.upconv2(dec3)\n        dec2 = torch.cat((dec2, x2), dim=1)\n        dec2 = self.decoder2(dec2)\n        dec1 = self.upconv1(dec2)\n        dec1 = torch.cat((dec1, x1), dim=1)\n        dec1 = self.decoder1(dec1)\n        outputs = self.conv(dec1)\n\n        return outputs\n\n    @staticmethod\n    def _block(in_channels, features, name):\n        return nn.Sequential(\n            OrderedDict(\n                [\n                    (\n                        name + \"conv1\",\n                        nn.Conv3d(\n    
                        in_channels=in_channels,\n                            out_channels=features,\n                            kernel_size=3,\n                            padding=1,\n                            bias=True,\n                        ),\n                    ),\n                    (name + \"norm1\", nn.BatchNorm3d(num_features=features)),\n                    (name + \"relu1\", nn.ReLU(inplace=True)),\n                    (\n                        name + \"conv2\",\n                        nn.Conv3d(\n                            in_channels=features,\n                            out_channels=features,\n                            kernel_size=3,\n                            padding=1,\n                            bias=True,\n                        ),\n                    ),\n                    (name + \"norm2\", nn.BatchNorm3d(num_features=features)),\n                    (name + \"relu2\", nn.ReLU(inplace=True)),\n                ]\n            )\n        )\n\nclass UNet3D_CCT(nn.Module):\n    def __init__(self, in_channels=1, out_channels=3, init_features=64):\n        \"\"\"\n        Implementations based on the Unet3D paper: https://arxiv.org/abs/1606.06650\n        \"\"\"\n\n        super(UNet3D_CCT, self).__init__()\n\n        features = init_features\n        self.encoder1 = UNet3D_CCT._block(in_channels, features, name=\"enc1\")\n        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)\n        self.encoder2 = UNet3D_CCT._block(features, features * 2, name=\"enc2\")\n        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)\n        self.encoder3 = UNet3D_CCT._block(features * 2, features * 4, name=\"enc3\")\n        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)\n        self.encoder4 = UNet3D_CCT._block(features * 4, features * 8, name=\"enc4\")\n        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)\n\n        self.bottleneck = UNet3D_CCT._block(features * 8, features * 16, name=\"bottleneck\")\n\n        self.main_decoder = 
Decoder(features, out_channels)\n\n        self.aux_decoder1 = Decoder(features, out_channels)\n        self.aux_decoder2 = Decoder(features, out_channels)\n        self.aux_decoder3 = Decoder(features, out_channels)\n\n\n    def forward(self, x):\n        enc1 = self.encoder1(x)\n        enc2 = self.encoder2(self.pool1(enc1))\n        enc3 = self.encoder3(self.pool2(enc2))\n        enc4 = self.encoder4(self.pool3(enc3))\n\n        bottleneck = self.bottleneck(self.pool4(enc4))\n\n        main_seg = self.main_decoder(bottleneck, enc4, enc3, enc2, enc1)\n\n        aux_seg1 = self.main_decoder(FeatureNoise()(bottleneck), FeatureNoise()(enc4), FeatureNoise()(enc3), FeatureNoise()(enc2), FeatureNoise()(enc1))\n        aux_seg2 = self.main_decoder(Dropout(bottleneck), Dropout(enc4), Dropout(enc3), Dropout(enc2), Dropout(enc1))\n        aux_seg3 = self.main_decoder(FeatureDropout(bottleneck), FeatureDropout(enc4), FeatureDropout(enc3), FeatureDropout(enc2), FeatureDropout(enc1))\n\n        return main_seg, aux_seg1, aux_seg2, aux_seg3\n\n    @staticmethod\n    def _block(in_channels, features, name):\n        return nn.Sequential(\n            OrderedDict(\n                [\n                    (\n                        name + \"conv1\",\n                        nn.Conv3d(\n                            in_channels=in_channels,\n                            out_channels=features,\n                            kernel_size=3,\n                            padding=1,\n                            bias=True,\n                        ),\n                    ),\n                    (name + \"norm1\", nn.BatchNorm3d(num_features=features)),\n                    (name + \"relu1\", nn.ReLU(inplace=True)),\n                    (\n                        name + \"conv2\",\n                        nn.Conv3d(\n                            in_channels=features,\n                            out_channels=features,\n                            kernel_size=3,\n                            
padding=1,\n                            bias=True,\n                        ),\n                    ),\n                    (name + \"norm2\", nn.BatchNorm3d(num_features=features)),\n                    (name + \"relu2\", nn.ReLU(inplace=True)),\n                ]\n            )\n        )\n\nclass UNet3D_CCT_min(nn.Module):\n    def __init__(self, in_channels=1, out_channels=3, init_features=32):\n        \"\"\"\n        Implementations based on the Unet3D paper: https://arxiv.org/abs/1606.06650\n        \"\"\"\n\n        super(UNet3D_CCT_min, self).__init__()\n\n        features = init_features\n        self.encoder1 = UNet3D_CCT._block(in_channels, features, name=\"enc1\")\n        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)\n        self.encoder2 = UNet3D_CCT._block(features, features * 2, name=\"enc2\")\n        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)\n        self.encoder3 = UNet3D_CCT._block(features * 2, features * 4, name=\"enc3\")\n        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)\n        self.encoder4 = UNet3D_CCT._block(features * 4, features * 8, name=\"enc4\")\n        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)\n\n        self.bottleneck = UNet3D_CCT._block(features * 8, features * 16, name=\"bottleneck\")\n\n        self.main_decoder = Decoder(features, out_channels)\n\n        self.aux_decoder1 = Decoder(features, out_channels)\n        self.aux_decoder2 = Decoder(features, out_channels)\n        self.aux_decoder3 = Decoder(features, out_channels)\n\n\n    def forward(self, x):\n        enc1 = self.encoder1(x)\n        enc2 = self.encoder2(self.pool1(enc1))\n        enc3 = self.encoder3(self.pool2(enc2))\n        enc4 = self.encoder4(self.pool3(enc3))\n\n        bottleneck = self.bottleneck(self.pool4(enc4))\n\n        main_seg = self.main_decoder(bottleneck, enc4, enc3, enc2, enc1)\n\n        aux_seg1 = self.main_decoder(FeatureNoise()(bottleneck), FeatureNoise()(enc4), FeatureNoise()(enc3), 
FeatureNoise()(enc2), FeatureNoise()(enc1))\n        aux_seg2 = self.main_decoder(Dropout(bottleneck), Dropout(enc4), Dropout(enc3), Dropout(enc2), Dropout(enc1))\n        aux_seg3 = self.main_decoder(FeatureDropout(bottleneck), FeatureDropout(enc4), FeatureDropout(enc3), FeatureDropout(enc2), FeatureDropout(enc1))\n\n        return main_seg, aux_seg1, aux_seg2, aux_seg3\n\n    @staticmethod\n    def _block(in_channels, features, name):\n        return nn.Sequential(\n            OrderedDict(\n                [\n                    (\n                        name + \"conv1\",\n                        nn.Conv3d(\n                            in_channels=in_channels,\n                            out_channels=features,\n                            kernel_size=3,\n                            padding=1,\n                            bias=True,\n                        ),\n                    ),\n                    (name + \"norm1\", nn.BatchNorm3d(num_features=features)),\n                    (name + \"relu1\", nn.ReLU(inplace=True)),\n                    (\n                        name + \"conv2\",\n                        nn.Conv3d(\n                            in_channels=features,\n                            out_channels=features,\n                            kernel_size=3,\n                            padding=1,\n                            bias=True,\n                        ),\n                    ),\n                    (name + \"norm2\", nn.BatchNorm3d(num_features=features)),\n                    (name + \"relu2\", nn.ReLU(inplace=True)),\n                ]\n            )\n        )\n\ndef unet3d_cct(in_channels, num_classes):\n    model = UNet3D_CCT(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\ndef unet3d_cct_min(in_channels, num_classes):\n    model = UNet3D_CCT_min(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\n# if __name__ == '__main__':\n#     model = unet3d_cct(1,10)\n#     
model.eval()\n#     input = torch.rand(2, 1, 128, 128, 128)\n#     output, aux_output1, aux_output2, aux_output3 = model(input)\n#     output = output.data.cpu().numpy()\n#     # print(output)\n#     print(output.shape)\n"
  },
  {
    "path": "models/networks_3d/unet3d_dtc.py",
    "content": "import numpy as np\nfrom collections import OrderedDict\nimport torch\nimport torch.nn as nn\nfrom torch.nn import init\n# from loss.loss_function import segmentation_loss\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\n\nclass UNet3D_DTC(nn.Module):\n    def __init__(self, in_channels=1, out_channels=3, init_features=64):\n        \"\"\"\n        Implementations based on the Unet3D paper: https://arxiv.org/abs/1606.06650\n        \"\"\"\n\n        super(UNet3D_DTC, self).__init__()\n\n        features = init_features\n        self.encoder1 = UNet3D_DTC._block(in_channels, features, name=\"enc1\")\n        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)\n        self.encoder2 = UNet3D_DTC._block(features, features * 2, name=\"enc2\")\n        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)\n        self.encoder3 = UNet3D_DTC._block(features * 2, features * 4, name=\"enc3\")\n        self.pool3 = 
nn.MaxPool3d(kernel_size=2, stride=2)\n        self.encoder4 = UNet3D_DTC._block(features * 4, features * 8, name=\"enc4\")\n        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)\n\n        self.bottleneck = UNet3D_DTC._block(features * 8, features * 16, name=\"bottleneck\")\n\n        self.upconv4 = nn.ConvTranspose3d(features * 16, features * 8, kernel_size=2, stride=2)\n        self.decoder4 = UNet3D_DTC._block((features * 8) * 2 , features * 8, name=\"dec4\")\n        self.upconv3 = nn.ConvTranspose3d(features * 8, features * 4, kernel_size=2, stride=2)\n        self.decoder3 = UNet3D_DTC._block((features * 4) * 2, features * 4, name=\"dec3\")\n        self.upconv2 = nn.ConvTranspose3d(features * 4, features * 2, kernel_size=2, stride=2)\n        self.decoder2 = UNet3D_DTC._block((features * 2) * 2, features * 2, name=\"dec2\")\n        self.upconv1 = nn.ConvTranspose3d(features * 2, features, kernel_size=2, stride=2)\n        self.decoder1 = UNet3D_DTC._block(features * 2, features, name=\"dec1\")\n\n        self.out_sdf = nn.Sequential(\n            nn.Conv3d(in_channels=features, out_channels=out_channels, kernel_size=1),\n            nn.Tanh()\n        )\n        self.out_seg = nn.Conv3d(in_channels=features, out_channels=out_channels, kernel_size=1)\n\n\n    def forward(self, x):\n        enc1 = self.encoder1(x)\n        enc2 = self.encoder2(self.pool1(enc1))\n        enc3 = self.encoder3(self.pool2(enc2))\n        enc4 = self.encoder4(self.pool3(enc3))\n\n        bottleneck = self.bottleneck(self.pool4(enc4))\n\n        dec4 = self.upconv4(bottleneck)\n        dec4 = torch.cat((dec4, enc4), dim=1)\n        dec4 = self.decoder4(dec4)\n        dec3 = self.upconv3(dec4)\n        dec3 = torch.cat((dec3, enc3), dim=1)\n        dec3 = self.decoder3(dec3)\n        dec2 = self.upconv2(dec3)\n        dec2 = torch.cat((dec2, enc2), dim=1)\n        dec2 = self.decoder2(dec2)\n        dec1 = self.upconv1(dec2)\n        dec1 = torch.cat((dec1, enc1), dim=1)\n     
   dec1 = self.decoder1(dec1)\n\n        out_sdf = self.out_sdf(dec1)\n        out_seg = self.out_seg(dec1)\n        return out_sdf, out_seg\n\n    @staticmethod\n    def _block(in_channels, features, name):\n        return nn.Sequential(\n            OrderedDict(\n                [\n                    (\n                        name + \"conv1\",\n                        nn.Conv3d(\n                            in_channels=in_channels,\n                            out_channels=features,\n                            kernel_size=3,\n                            padding=1,\n                            bias=True,\n                        ),\n                    ),\n                    (name + \"norm1\", nn.BatchNorm3d(num_features=features)),\n                    (name + \"relu1\", nn.ReLU(inplace=True)),\n                    (\n                        name + \"conv2\",\n                        nn.Conv3d(\n                            in_channels=features,\n                            out_channels=features,\n                            kernel_size=3,\n                            padding=1,\n                            bias=True,\n                        ),\n                    ),\n                    (name + \"norm2\", nn.BatchNorm3d(num_features=features)),\n                    (name + \"relu2\", nn.ReLU(inplace=True)),\n                ]\n            )\n        )\n\n\ndef unet3d_dtc(in_channels, num_classes):\n    model = UNet3D_DTC(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\n# if __name__ == '__main__':\n#\n#     criterion = segmentation_loss('dice', False)\n#     mask = torch.ones(2, 96, 96, 96).long()\n#     model = unet3d_dtc(1, 10)\n#     model.train()\n#     input1 = torch.rand(2,1,96,96,96)\n#     out_sdf, out_seg = model(input1)\n#     loss_train = criterion(out_sdf, mask)\n#     loss_train.backward()\n#     # print(output)\n#     print(out_sdf.data.cpu().numpy().shape)\n#     print(out_seg.data.cpu().numpy().shape)\n#     
print(loss_train)\n"
  },
  {
    "path": "models/networks_3d/unet3d_urpc.py",
    "content": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\n\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\nclass UnetConv3(nn.Module):\n    def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3,3,1), padding_size=(1,1,0), init_stride=(1,1,1)):\n        super(UnetConv3, self).__init__()\n\n        if is_batchnorm:\n            self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size),\n                                       nn.InstanceNorm3d(out_size),\n                                       nn.ReLU(inplace=True),)\n            self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size),\n                                       nn.InstanceNorm3d(out_size),\n                                       nn.ReLU(inplace=True),)\n        else:\n            self.conv1 = 
nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size),\n                                       nn.ReLU(inplace=True),)\n            self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size),\n                                       nn.ReLU(inplace=True),)\n\n        # initialise the blocks\n        # for m in self.children():\n        #     init_weights(m, init_type='kaiming')\n\n    def forward(self, inputs):\n        outputs = self.conv1(inputs)\n        outputs = self.conv2(outputs)\n        return outputs\n\nclass UnetUp3(nn.Module):\n    def __init__(self, in_size, out_size, is_deconv, is_batchnorm=True):\n        super(UnetUp3, self).__init__()\n        if is_deconv:\n            self.conv = UnetConv3(in_size, out_size, is_batchnorm)\n            self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=(4,4,1), stride=(2,2,1), padding=(1,1,0))\n        else:\n            self.conv = UnetConv3(in_size+out_size, out_size, is_batchnorm)\n            self.up = nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear', align_corners=True)\n\n        # initialise the blocks\n        # for m in self.children():\n        #     if m.__class__.__name__.find('UnetConv3') != -1: continue\n        #     init_weights(m, init_type='kaiming')\n\n    def forward(self, inputs1, inputs2):\n        outputs2 = self.up(inputs2)\n        offset = outputs2.size()[2] - inputs1.size()[2]\n        padding = 2 * [offset // 2, offset // 2, 0]\n        outputs1 = F.pad(inputs1, padding)\n        return self.conv(torch.cat([outputs1, outputs2], 1))\n\n\nclass UnetUp3_CT(nn.Module):\n    def __init__(self, in_size, out_size, is_batchnorm=True):\n        super(UnetUp3_CT, self).__init__()\n        self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))\n        self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear', align_corners=True)\n\n        # initialise the blocks\n       
 # for m in self.children():\n        #     if m.__class__.__name__.find('UnetConv3') != -1: continue\n        #     init_weights(m, init_type='kaiming')\n\n    def forward(self, inputs1, inputs2):\n        outputs2 = self.up(inputs2)\n        offset = outputs2.size()[2] - inputs1.size()[2]\n        padding = 2 * [offset // 2, offset // 2, 0]\n        outputs1 = F.pad(inputs1, padding)\n        return self.conv(torch.cat([outputs1, outputs2], 1))\n\n\nclass UnetDsv3(nn.Module):\n    def __init__(self, in_size, out_size, scale_factor):\n        super(UnetDsv3, self).__init__()\n        self.dsv = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size=1, stride=1, padding=0),\n                                 nn.Upsample(scale_factor=scale_factor, mode='trilinear', align_corners=True), )\n\n    def forward(self, input):\n        return self.dsv(input)\n\nclass unet_3D_dv_semi(nn.Module):\n\n    def __init__(self, in_channels=3, n_classes=21, feature_scale=4, is_deconv=True, is_batchnorm=True):\n        super(unet_3D_dv_semi, self).__init__()\n        self.is_deconv = is_deconv\n        self.in_channels = in_channels\n        self.is_batchnorm = is_batchnorm\n        self.feature_scale = feature_scale\n\n        filters = [64, 128, 256, 512, 1024]\n        filters = [int(x / self.feature_scale) for x in filters]\n\n        # downsampling\n        self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(\n            3, 3, 3), padding_size=(1, 1, 1))\n        self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))\n\n        self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(\n            3, 3, 3), padding_size=(1, 1, 1))\n        self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))\n\n        self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(\n            3, 3, 3), padding_size=(1, 1, 1))\n        self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))\n\n        self.conv4 = 
UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(\n            3, 3, 3), padding_size=(1, 1, 1))\n        self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))\n\n        self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(\n            3, 3, 3), padding_size=(1, 1, 1))\n\n        # upsampling\n        self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm)\n        self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm)\n        self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm)\n        self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm)\n\n        # deep supervision\n        self.dsv4 = UnetDsv3(\n            in_size=filters[3], out_size=n_classes, scale_factor=8)\n        self.dsv3 = UnetDsv3(\n            in_size=filters[2], out_size=n_classes, scale_factor=4)\n        self.dsv2 = UnetDsv3(\n            in_size=filters[1], out_size=n_classes, scale_factor=2)\n        self.dsv1 = nn.Conv3d(\n            in_channels=filters[0], out_channels=n_classes, kernel_size=1)\n\n        self.dropout1 = nn.Dropout3d(p=0.5)\n        self.dropout2 = nn.Dropout3d(p=0.3)\n        self.dropout3 = nn.Dropout3d(p=0.2)\n        self.dropout4 = nn.Dropout3d(p=0.1)\n\n        # initialise weights\n        # for m in self.modules():\n        #     if isinstance(m, nn.Conv3d):\n        #         init_weights(m, init_type='kaiming')\n        #     elif isinstance(m, nn.BatchNorm3d):\n        #         init_weights(m, init_type='kaiming')\n\n    def forward(self, inputs):\n        conv1 = self.conv1(inputs)\n        maxpool1 = self.maxpool1(conv1)\n\n        conv2 = self.conv2(maxpool1)\n        maxpool2 = self.maxpool2(conv2)\n\n        conv3 = self.conv3(maxpool2)\n        maxpool3 = self.maxpool3(conv3)\n\n        conv4 = self.conv4(maxpool3)\n        maxpool4 = self.maxpool4(conv4)\n\n        center = self.center(maxpool4)\n\n        up4 = self.up_concat4(conv4, center)\n        up4 = 
self.dropout1(up4)\n\n        up3 = self.up_concat3(conv3, up4)\n        up3 = self.dropout2(up3)\n\n        up2 = self.up_concat2(conv2, up3)\n        up2 = self.dropout3(up2)\n\n        up1 = self.up_concat1(conv1, up2)\n        up1 = self.dropout4(up1)\n\n        # Deep Supervision\n        dsv4 = self.dsv4(up4)\n        dsv3 = self.dsv3(up3)\n        dsv2 = self.dsv2(up2)\n        dsv1 = self.dsv1(up1)\n\n        return dsv1, dsv2, dsv3, dsv4\n\n    @staticmethod\n    def apply_argmax_softmax(pred):\n        log_p = F.softmax(pred, dim=1)\n\n        return log_p\n\ndef unet3d_urpc(in_channels, num_classes):\n    model = unet_3D_dv_semi(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\n# if __name__ == '__main__':\n#     model = unet3d_urpc(1,10)\n#     model.eval()\n#     input = torch.rand(2, 1, 128, 128, 128)\n#     output, output2, output3, output4 = model(input)\n#     output = output.data.cpu().numpy()\n#     # print(output)\n#     print(output.shape)"
  },
  {
    "path": "models/networks_3d/unetr.py",
    "content": "import copy\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nfrom torch.nn import init\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\n\nclass SingleDeconv3DBlock(nn.Module):\n    def __init__(self, in_planes, out_planes):\n        super().__init__()\n        self.block = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=2, stride=2, padding=0, output_padding=0)\n\n    def forward(self, x):\n        return self.block(x)\n\n\nclass SingleConv3DBlock(nn.Module):\n    def __init__(self, in_planes, out_planes, kernel_size):\n        super().__init__()\n        self.block = nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1,\n                               padding=((kernel_size - 1) // 2))\n\n    def forward(self, x):\n        return self.block(x)\n\n\nclass Conv3DBlock(nn.Module):\n    def __init__(self, in_planes, out_planes, kernel_size=3):\n        
super().__init__()\n        self.block = nn.Sequential(\n            SingleConv3DBlock(in_planes, out_planes, kernel_size),\n            nn.BatchNorm3d(out_planes),\n            nn.ReLU(True)\n        )\n\n    def forward(self, x):\n        return self.block(x)\n\n\nclass Deconv3DBlock(nn.Module):\n    def __init__(self, in_planes, out_planes, kernel_size=3):\n        super().__init__()\n        self.block = nn.Sequential(\n            SingleDeconv3DBlock(in_planes, out_planes),\n            SingleConv3DBlock(out_planes, out_planes, kernel_size),\n            nn.BatchNorm3d(out_planes),\n            nn.ReLU(True)\n        )\n\n    def forward(self, x):\n        return self.block(x)\n\n\nclass SelfAttention(nn.Module):\n    def __init__(self, num_heads, embed_dim, dropout):\n        super().__init__()\n        self.num_attention_heads = num_heads\n        self.attention_head_size = int(embed_dim / num_heads)\n        self.all_head_size = self.num_attention_heads * self.attention_head_size\n\n        self.query = nn.Linear(embed_dim, self.all_head_size)\n        self.key = nn.Linear(embed_dim, self.all_head_size)\n        self.value = nn.Linear(embed_dim, self.all_head_size)\n\n        self.out = nn.Linear(embed_dim, embed_dim)\n        self.attn_dropout = nn.Dropout(dropout)\n        self.proj_dropout = nn.Dropout(dropout)\n\n        self.softmax = nn.Softmax(dim=-1)\n\n        self.vis = False\n\n    def transpose_for_scores(self, x):\n        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)\n        x = x.view(*new_x_shape)\n        return x.permute(0, 2, 1, 3)\n\n    def forward(self, hidden_states):\n        mixed_query_layer = self.query(hidden_states)\n        mixed_key_layer = self.key(hidden_states)\n        mixed_value_layer = self.value(hidden_states)\n\n        query_layer = self.transpose_for_scores(mixed_query_layer)\n        key_layer = self.transpose_for_scores(mixed_key_layer)\n        value_layer = 
self.transpose_for_scores(mixed_value_layer)\n\n        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))\n        attention_scores = attention_scores / math.sqrt(self.attention_head_size)\n        attention_probs = self.softmax(attention_scores)\n        # weights = attention_probs if self.vis else None\n        attention_probs = self.attn_dropout(attention_probs)\n\n        context_layer = torch.matmul(attention_probs, value_layer)\n        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()\n        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)\n        context_layer = context_layer.view(*new_context_layer_shape)\n        attention_output = self.out(context_layer)\n        attention_output = self.proj_dropout(attention_output)\n        # return attention_output, weights\n        return attention_output\n\n\n# class Mlp(nn.Module):\n#     def __init__(self, in_features, act_layer=nn.GELU, drop=0.):\n#         super().__init__()\n#         self.fc1 = nn.Linear(in_features, in_features)\n#         self.act = act_layer()\n#         self.drop = nn.Dropout(drop)\n#\n#     def forward(self, x):\n#         x = self.fc1()\n#         x = self.act(x)\n#         x = self.drop(x)\n#         return x\n\n\nclass PositionwiseFeedForward(nn.Module):\n    def __init__(self, d_model=786, d_ff=2048, dropout=0.1):\n        super().__init__()\n        # Torch linears have a `b` by default.\n        self.w_1 = nn.Linear(d_model, d_ff)\n        self.w_2 = nn.Linear(d_ff, d_model)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x):\n        return self.w_2(self.dropout(F.relu(self.w_1(x))))\n\n\nclass Embeddings(nn.Module):\n    def __init__(self, input_dim, embed_dim, cube_size, patch_size, dropout):\n        super().__init__()\n        self.n_patches = int((cube_size[0] * cube_size[1] * cube_size[2]) / (patch_size * patch_size * patch_size))\n        self.patch_size = patch_size\n        
self.embed_dim = embed_dim\n        self.patch_embeddings = nn.Conv3d(in_channels=input_dim, out_channels=embed_dim,\n                                          kernel_size=patch_size, stride=patch_size)\n        self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, embed_dim))\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x):\n        x = self.patch_embeddings(x)\n        x = x.flatten(2)\n        x = x.transpose(-1, -2)\n        embeddings = x + self.position_embeddings\n        embeddings = self.dropout(embeddings)\n        return embeddings\n\n\nclass TransformerBlock(nn.Module):\n    def __init__(self, embed_dim, num_heads, dropout, cube_size, patch_size):\n        super().__init__()\n        self.attention_norm = nn.LayerNorm(embed_dim, eps=1e-6)\n        self.mlp_norm = nn.LayerNorm(embed_dim, eps=1e-6)\n        self.mlp_dim = int((cube_size[0] * cube_size[1] * cube_size[2]) / (patch_size * patch_size * patch_size))\n        self.mlp = PositionwiseFeedForward(embed_dim, 2048)\n        self.attn = SelfAttention(num_heads, embed_dim, dropout)\n\n    def forward(self, x):\n        h = x\n        x = self.attention_norm(x)\n        # x, weights = self.attn(x)\n        x = self.attn(x)\n        x = x + h\n        h = x\n\n        x = self.mlp_norm(x)\n        x = self.mlp(x)\n\n        x = x + h\n        # return x, weights\n        return x\n\n\nclass Transformer(nn.Module):\n    def __init__(self, input_dim, embed_dim, cube_size, patch_size, num_heads, num_layers, dropout, extract_layers):\n        super().__init__()\n        self.embeddings = Embeddings(input_dim, embed_dim, cube_size, patch_size, dropout)\n        self.layer = nn.ModuleList()\n        self.encoder_norm = nn.LayerNorm(embed_dim, eps=1e-6)\n        self.extract_layers = extract_layers\n        for _ in range(num_layers):\n            layer = TransformerBlock(embed_dim, num_heads, dropout, cube_size, patch_size)\n            
self.layer.append(copy.deepcopy(layer))\n\n    def forward(self, x):\n        extract_layers = []\n        hidden_states = self.embeddings(x)\n\n        for depth, layer_block in enumerate(self.layer):\n            # hidden_states, _ = layer_block(hidden_states)\n            hidden_states = layer_block(hidden_states)\n            if depth + 1 in self.extract_layers:\n                extract_layers.append(hidden_states)\n\n        return extract_layers\n\n\nclass UNETR(nn.Module):\n    def __init__(self, input_dim=4, output_dim=3, img_shape=(128, 128, 128), embed_dim=768, patch_size=16, num_heads=12,\n                 dropout=0.1):\n        super().__init__()\n        self.input_dim = input_dim\n        self.output_dim = output_dim\n        self.embed_dim = embed_dim\n        self.img_shape = img_shape\n        self.patch_size = patch_size\n        self.num_heads = num_heads\n        self.dropout = dropout\n        self.num_layers = 12\n        self.ext_layers = [3, 6, 9, 12]\n\n        self.patch_dim = [int(x / patch_size) for x in img_shape]\n\n        # Transformer Encoder\n        self.transformer = \\\n            Transformer(\n                input_dim,\n                embed_dim,\n                img_shape,\n                patch_size,\n                num_heads,\n                self.num_layers,\n                dropout,\n                self.ext_layers\n            )\n\n        # U-Net Decoder\n        self.decoder0 = \\\n            nn.Sequential(\n                Conv3DBlock(input_dim, 32, 3),\n                Conv3DBlock(32, 64, 3)\n            )\n\n        self.decoder3 = \\\n            nn.Sequential(\n                Deconv3DBlock(embed_dim, 512),\n                Deconv3DBlock(512, 256),\n                Deconv3DBlock(256, 128)\n            )\n\n        self.decoder6 = \\\n            nn.Sequential(\n                Deconv3DBlock(embed_dim, 512),\n                Deconv3DBlock(512, 256),\n            )\n\n        self.decoder9 = \\\n            
Deconv3DBlock(embed_dim, 512)\n\n        self.decoder12_upsampler = \\\n            SingleDeconv3DBlock(embed_dim, 512)\n\n        self.decoder9_upsampler = \\\n            nn.Sequential(\n                Conv3DBlock(1024, 512),\n                Conv3DBlock(512, 512),\n                Conv3DBlock(512, 512),\n                SingleDeconv3DBlock(512, 256)\n            )\n\n        self.decoder6_upsampler = \\\n            nn.Sequential(\n                Conv3DBlock(512, 256),\n                Conv3DBlock(256, 256),\n                SingleDeconv3DBlock(256, 128)\n            )\n\n        self.decoder3_upsampler = \\\n            nn.Sequential(\n                Conv3DBlock(256, 128),\n                Conv3DBlock(128, 128),\n                SingleDeconv3DBlock(128, 64)\n            )\n\n        self.decoder0_header = \\\n            nn.Sequential(\n                Conv3DBlock(128, 64),\n                Conv3DBlock(64, 64),\n                SingleConv3DBlock(64, output_dim, 1)\n            )\n\n    def forward(self, x):\n        z = self.transformer(x)\n        z0, z3, z6, z9, z12 = x, *z\n        z3 = z3.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)\n        z6 = z6.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)\n        z9 = z9.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)\n        z12 = z12.transpose(-1, -2).view(-1, self.embed_dim, *self.patch_dim)\n\n        z12 = self.decoder12_upsampler(z12)\n        z9 = self.decoder9(z9)\n        z9 = self.decoder9_upsampler(torch.cat([z9, z12], dim=1))\n        z6 = self.decoder6(z6)\n        z6 = self.decoder6_upsampler(torch.cat([z6, z9], dim=1))\n        z3 = self.decoder3(z3)\n        z3 = self.decoder3_upsampler(torch.cat([z3, z6], dim=1))\n        z0 = self.decoder0(z0)\n        output = self.decoder0_header(torch.cat([z0, z3], dim=1))\n        return output\n\n\ndef unertr(in_channels, num_classes, **kwargs):\n    model = UNETR(in_channels, num_classes, 
img_shape=kwargs['img_shape'])\n    init_weights(model, 'kaiming')\n    return model\n\n\n# if __name__ == '__main__':\n#     model = unertr(1,10, img_shape=(96, 96, 96))\n#     model.eval()\n#     input = torch.rand(2, 1, 96, 96, 96)\n#     output = model(input)\n#     output = output.data.cpu().numpy()\n#     # print(output)\n#     print(output.shape)\n"
  },
  {
    "path": "models/networks_3d/vnet.py",
    "content": "import torch\nimport torch.nn as nn\nimport os\nimport numpy as np\nfrom collections import OrderedDict\nfrom torch.nn import init\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\n\n\ndef passthrough(x, **kwargs):\n    return x\n\n\ndef ELUCons(elu, nchan):\n    if elu:\n        return nn.ELU(inplace=True)\n    else:\n        return nn.PReLU(nchan)\n\n\nclass LUConv(nn.Module):\n    def __init__(self, nchan, elu):\n        super(LUConv, self).__init__()\n        self.relu1 = ELUCons(elu, nchan)\n        self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)\n\n        self.bn1 = torch.nn.BatchNorm3d(nchan)\n\n    def forward(self, x):\n        out = self.relu1(self.bn1(self.conv1(x)))\n        return out\n\n\ndef _make_nConv(nchan, depth, elu):\n    layers = []\n    for _ in range(depth):\n        layers.append(LUConv(nchan, elu))\n    return nn.Sequential(*layers)\n\n\nclass 
InputTransition(nn.Module):\n    def __init__(self, in_channels, elu):\n        super(InputTransition, self).__init__()\n        self.num_features = 16\n        self.in_channels = in_channels\n\n        self.conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=5, padding=2)\n\n        self.bn1 = torch.nn.BatchNorm3d(self.num_features)\n\n        self.relu1 = ELUCons(elu, self.num_features)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        repeat_rate = int(self.num_features / self.in_channels)\n        out = self.bn1(out)\n        x16 = x.repeat(1, repeat_rate, 1, 1, 1)\n        return self.relu1(torch.add(out, x16))\n\n\nclass DownTransition(nn.Module):\n    def __init__(self, inChans, nConvs, elu, dropout=False):\n        super(DownTransition, self).__init__()\n        outChans = 2 * inChans\n        self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2)\n        self.bn1 = torch.nn.BatchNorm3d(outChans)\n\n        self.do1 = passthrough\n        self.relu1 = ELUCons(elu, outChans)\n        self.relu2 = ELUCons(elu, outChans)\n        if dropout:\n            self.do1 = nn.Dropout3d()\n        self.ops = _make_nConv(outChans, nConvs, elu)\n\n    def forward(self, x):\n        down = self.relu1(self.bn1(self.down_conv(x)))\n        out = self.do1(down)\n        out = self.ops(out)\n        out = self.relu2(torch.add(out, down))\n        return out\n\n\nclass UpTransition(nn.Module):\n    def __init__(self, inChans, outChans, nConvs, elu, dropout=False):\n        super(UpTransition, self).__init__()\n        self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2)\n\n        self.bn1 = torch.nn.BatchNorm3d(outChans // 2)\n        self.do1 = passthrough\n        self.do2 = nn.Dropout3d()\n        self.relu1 = ELUCons(elu, outChans // 2)\n        self.relu2 = ELUCons(elu, outChans)\n        if dropout:\n            self.do1 = nn.Dropout3d()\n        self.ops = _make_nConv(outChans, nConvs, 
elu)\n\n    def forward(self, x, skipx):\n        out = self.do1(x)\n        skipxdo = self.do2(skipx)\n        out = self.relu1(self.bn1(self.up_conv(out)))\n        xcat = torch.cat((out, skipxdo), 1)\n        out = self.ops(xcat)\n        out = self.relu2(torch.add(out, xcat))\n        return out\n\n\nclass OutputTransition(nn.Module):\n    def __init__(self, in_channels, classes, elu):\n        super(OutputTransition, self).__init__()\n        self.classes = classes\n        self.conv1 = nn.Conv3d(in_channels, classes, kernel_size=5, padding=2)\n        self.bn1 = torch.nn.BatchNorm3d(classes)\n\n        self.conv2 = nn.Conv3d(classes, classes, kernel_size=1)\n        self.relu1 = ELUCons(elu, classes)\n\n    def forward(self, x):\n        # convolve 32 down to channels as the desired classes\n        out = self.relu1(self.bn1(self.conv1(x)))\n        out = self.conv2(out)\n        return out\n\n\nclass VNet(nn.Module):\n    \"\"\"\n    Implementations based on the Vnet paper: https://arxiv.org/abs/1606.04797\n    \"\"\"\n\n    def __init__(self, in_channels=1, classes=1, elu=True):\n        super(VNet, self).__init__()\n        self.classes = classes\n        self.in_channels = in_channels\n\n        self.in_tr = InputTransition(in_channels, elu=elu)\n        self.down_tr32 = DownTransition(16, 1, elu)\n        self.down_tr64 = DownTransition(32, 2, elu)\n        self.down_tr128 = DownTransition(64, 3, elu, dropout=False)\n        self.down_tr256 = DownTransition(128, 2, elu, dropout=False)\n        self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False)\n        self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False)\n        self.up_tr64 = UpTransition(128, 64, 1, elu)\n        self.up_tr32 = UpTransition(64, 32, 1, elu)\n        self.out_tr = OutputTransition(32, classes, elu)\n\n\n    def forward(self, x):\n        out16 = self.in_tr(x)\n        out32 = self.down_tr32(out16)\n        out64 = self.down_tr64(out32)\n        out128 = 
self.down_tr128(out64)\n        out256 = self.down_tr256(out128)\n        out = self.up_tr256(out256, out128)\n        out = self.up_tr128(out, out64)\n        out = self.up_tr64(out, out32)\n        out = self.up_tr32(out, out16)\n        out = self.out_tr(out)\n        return out\n\ndef vnet(in_channels, num_classes):\n    model = VNet(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\n\n# if __name__ == '__main__':\n#     model = vnet(1,10)\n#     model.eval()\n#     input = torch.rand(2, 1, 128, 128, 128)\n#     output = model(input)\n#     output = output.data.cpu().numpy()\n#     # print(output)\n#     print(output.shape)\n"
  },
  {
    "path": "models/networks_3d/vnet_cct.py",
    "content": "import torch\nimport torch.nn as nn\nimport os\nimport numpy as np\nfrom collections import OrderedDict\nfrom torch.nn import init\nfrom torch.distributions.uniform import Uniform\n# from loss.loss_function import segmentation_loss\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\nclass FeatureNoise(nn.Module):\n    def __init__(self, uniform_range=0.3):\n        super(FeatureNoise, self).__init__()\n        self.uni_dist = Uniform(-uniform_range, uniform_range)\n\n    def feature_based_noise(self, x):\n        noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0)\n        x_noise = x.mul(noise_vector) + x\n        return x_noise\n\n    def forward(self, x):\n        x = self.feature_based_noise(x)\n        return x\n\ndef Dropout(x, p=0.3):\n    x = torch.nn.functional.dropout(x, p)\n    return x\n\ndef FeatureDropout(x):\n    attention = torch.mean(x, dim=1, keepdim=True)\n    max_val, _ 
= torch.max(attention.view(x.size(0), -1), dim=1, keepdim=True)\n    threshold = max_val * np.random.uniform(0.7, 0.9)\n    threshold = threshold.view(x.size(0), 1, 1, 1, 1).expand_as(attention)\n    drop_mask = (attention < threshold).float()\n    x = x.mul(drop_mask)\n    return x\n\n\n\ndef passthrough(x, **kwargs):\n    return x\n\n\ndef ELUCons(elu, nchan):\n    if elu:\n        return nn.ELU(inplace=True)\n    else:\n        return nn.PReLU(nchan)\n\n\nclass LUConv(nn.Module):\n    def __init__(self, nchan, elu):\n        super(LUConv, self).__init__()\n        self.relu1 = ELUCons(elu, nchan)\n        self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)\n\n        self.bn1 = torch.nn.BatchNorm3d(nchan)\n\n    def forward(self, x):\n        out = self.relu1(self.bn1(self.conv1(x)))\n        return out\n\n\ndef _make_nConv(nchan, depth, elu):\n    layers = []\n    for _ in range(depth):\n        layers.append(LUConv(nchan, elu))\n    return nn.Sequential(*layers)\n\n\nclass InputTransition(nn.Module):\n    def __init__(self, in_channels, elu):\n        super(InputTransition, self).__init__()\n        self.num_features = 16\n        self.in_channels = in_channels\n\n        self.conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=5, padding=2)\n\n        self.bn1 = torch.nn.BatchNorm3d(self.num_features)\n\n        self.relu1 = ELUCons(elu, self.num_features)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        repeat_rate = int(self.num_features / self.in_channels)\n        out = self.bn1(out)\n        x16 = x.repeat(1, repeat_rate, 1, 1, 1)\n        return self.relu1(torch.add(out, x16))\n\n\nclass DownTransition(nn.Module):\n    def __init__(self, inChans, nConvs, elu, dropout=False):\n        super(DownTransition, self).__init__()\n        outChans = 2 * inChans\n        self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2)\n        self.bn1 = torch.nn.BatchNorm3d(outChans)\n\n        self.do1 = 
passthrough\n        self.relu1 = ELUCons(elu, outChans)\n        self.relu2 = ELUCons(elu, outChans)\n        if dropout:\n            self.do1 = nn.Dropout3d()\n        self.ops = _make_nConv(outChans, nConvs, elu)\n\n    def forward(self, x):\n        down = self.relu1(self.bn1(self.down_conv(x)))\n        out = self.do1(down)\n        out = self.ops(out)\n        out = self.relu2(torch.add(out, down))\n        return out\n\n\nclass UpTransition(nn.Module):\n    def __init__(self, inChans, outChans, nConvs, elu, dropout=False):\n        super(UpTransition, self).__init__()\n        self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2)\n\n        self.bn1 = torch.nn.BatchNorm3d(outChans // 2)\n        self.do1 = passthrough\n        self.do2 = nn.Dropout3d()\n        self.relu1 = ELUCons(elu, outChans // 2)\n        self.relu2 = ELUCons(elu, outChans)\n        if dropout:\n            self.do1 = nn.Dropout3d()\n        self.ops = _make_nConv(outChans, nConvs, elu)\n\n    def forward(self, x, skipx):\n        out = self.do1(x)\n        skipxdo = self.do2(skipx)\n        out = self.relu1(self.bn1(self.up_conv(out)))\n        xcat = torch.cat((out, skipxdo), 1)\n        out = self.ops(xcat)\n        out = self.relu2(torch.add(out, xcat))\n        return out\n\n\nclass OutputTransition(nn.Module):\n    def __init__(self, in_channels, classes, elu):\n        super(OutputTransition, self).__init__()\n        self.classes = classes\n        self.conv1 = nn.Conv3d(in_channels, classes, kernel_size=5, padding=2)\n        self.bn1 = torch.nn.BatchNorm3d(classes)\n\n        self.conv2 = nn.Conv3d(classes, classes, kernel_size=1)\n        self.relu1 = ELUCons(elu, classes)\n\n    def forward(self, x):\n        # convolve 32 down to channels as the desired classes\n        out = self.relu1(self.bn1(self.conv1(x)))\n        out = self.conv2(out)\n        return out\n\n\nclass Decoder(nn.Module):\n    def __init__(self, out_channels, elu):\n        
super(Decoder, self).__init__()\n\n        self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False)\n        self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False)\n        self.up_tr64 = UpTransition(128, 64, 1, elu)\n        self.up_tr32 = UpTransition(64, 32, 1, elu)\n        self.out_tr = OutputTransition(32, out_channels, elu)\n\n    def forward(self, out256, out128, out64, out32, out16):\n        out = self.up_tr256(out256, out128)\n        out = self.up_tr128(out, out64)\n        out = self.up_tr64(out, out32)\n        out = self.up_tr32(out, out16)\n        out = self.out_tr(out)\n\n        return out\n\nclass VNet_CCT(nn.Module):\n    \"\"\"\n    Implementations based on the Vnet paper: https://arxiv.org/abs/1606.04797\n    \"\"\"\n\n    def __init__(self, in_channels=1, classes=1, elu=True):\n        super(VNet_CCT, self).__init__()\n        self.classes = classes\n        self.in_channels = in_channels\n\n        self.in_tr = InputTransition(in_channels, elu=elu)\n        self.down_tr32 = DownTransition(16, 1, elu)\n        self.down_tr64 = DownTransition(32, 2, elu)\n        self.down_tr128 = DownTransition(64, 3, elu, dropout=False)\n        self.down_tr256 = DownTransition(128, 2, elu, dropout=False)\n\n        self.main_decoder = Decoder(classes, elu)\n\n        self.aux_decoder1 = Decoder(classes, elu)\n        self.aux_decoder2 = Decoder(classes, elu)\n        self.aux_decoder3 = Decoder(classes, elu)\n\n\n    def forward(self, x):\n        out16 = self.in_tr(x)\n        out32 = self.down_tr32(out16)\n        out64 = self.down_tr64(out32)\n        out128 = self.down_tr128(out64)\n        out256 = self.down_tr256(out128)\n\n        main_seg = self.main_decoder(out256, out128, out64, out32, out16)\n\n        aux_seg1 = self.main_decoder(FeatureNoise()(out256), FeatureNoise()(out128), FeatureNoise()(out64), FeatureNoise()(out32), FeatureNoise()(out16))\n        aux_seg2 = self.main_decoder(Dropout(out256), Dropout(out128), Dropout(out64), 
Dropout(out32), Dropout(out16))\n        aux_seg3 = self.main_decoder(FeatureDropout(out256), FeatureDropout(out128), FeatureDropout(out64), FeatureDropout(out32), FeatureDropout(out16))\n\n        return main_seg, aux_seg1, aux_seg2, aux_seg3\n\ndef vnet_cct(in_channels, num_classes):\n    model = VNet_CCT(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\n\n# if __name__ == '__main__':\n#\n#     criterion = segmentation_loss('dice', False)\n#     mask = torch.ones(2, 64, 96, 64).long()\n#     model = vnet_cct(1, 10)\n#     model.train()\n#     input = torch.rand(2, 1, 64, 96, 64)\n#     output, output1, output2, output3 = model(input)\n#     loss_train = criterion(output, mask)\n#     loss_train.backward()\n#     output = output.data.cpu().numpy()\n#     print(output.shape)\n#     print(loss_train)\n"
  },
  {
    "path": "models/networks_3d/vnet_dtc.py",
    "content": "import torch\nimport torch.nn as nn\nimport os\nimport numpy as np\nfrom collections import OrderedDict\nfrom torch.nn import init\n# from loss.loss_function import segmentation_loss\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    def init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:\n            init.normal_(m.weight.data, 1.0, gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)\n\n\n\ndef passthrough(x, **kwargs):\n    return x\n\n\ndef ELUCons(elu, nchan):\n    if elu:\n        return nn.ELU(inplace=True)\n    else:\n        return nn.PReLU(nchan)\n\n\nclass LUConv(nn.Module):\n    def __init__(self, nchan, elu):\n        super(LUConv, self).__init__()\n        self.relu1 = ELUCons(elu, nchan)\n        self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)\n\n        self.bn1 = torch.nn.BatchNorm3d(nchan)\n\n    def forward(self, x):\n        out = self.relu1(self.bn1(self.conv1(x)))\n        return out\n\n\ndef _make_nConv(nchan, depth, elu):\n    layers = []\n    for _ in range(depth):\n        layers.append(LUConv(nchan, elu))\n    return 
nn.Sequential(*layers)\n\n\nclass InputTransition(nn.Module):\n    def __init__(self, in_channels, elu):\n        super(InputTransition, self).__init__()\n        self.num_features = 16\n        self.in_channels = in_channels\n\n        self.conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=5, padding=2)\n\n        self.bn1 = torch.nn.BatchNorm3d(self.num_features)\n\n        self.relu1 = ELUCons(elu, self.num_features)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        repeat_rate = int(self.num_features / self.in_channels)\n        out = self.bn1(out)\n        x16 = x.repeat(1, repeat_rate, 1, 1, 1)\n        return self.relu1(torch.add(out, x16))\n\n\nclass DownTransition(nn.Module):\n    def __init__(self, inChans, nConvs, elu, dropout=False):\n        super(DownTransition, self).__init__()\n        outChans = 2 * inChans\n        self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2)\n        self.bn1 = torch.nn.BatchNorm3d(outChans)\n\n        self.do1 = passthrough\n        self.relu1 = ELUCons(elu, outChans)\n        self.relu2 = ELUCons(elu, outChans)\n        if dropout:\n            self.do1 = nn.Dropout3d()\n        self.ops = _make_nConv(outChans, nConvs, elu)\n\n    def forward(self, x):\n        down = self.relu1(self.bn1(self.down_conv(x)))\n        out = self.do1(down)\n        out = self.ops(out)\n        out = self.relu2(torch.add(out, down))\n        return out\n\n\nclass UpTransition(nn.Module):\n    def __init__(self, inChans, outChans, nConvs, elu, dropout=False):\n        super(UpTransition, self).__init__()\n        self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2)\n\n        self.bn1 = torch.nn.BatchNorm3d(outChans // 2)\n        self.do1 = passthrough\n        self.do2 = nn.Dropout3d()\n        self.relu1 = ELUCons(elu, outChans // 2)\n        self.relu2 = ELUCons(elu, outChans)\n        if dropout:\n            self.do1 = nn.Dropout3d()\n        self.ops = 
_make_nConv(outChans, nConvs, elu)\n\n    def forward(self, x, skipx):\n        out = self.do1(x)\n        skipxdo = self.do2(skipx)\n        out = self.relu1(self.bn1(self.up_conv(out)))\n        xcat = torch.cat((out, skipxdo), 1)\n        out = self.ops(xcat)\n        out = self.relu2(torch.add(out, xcat))\n        return out\n\n\nclass OutputTransition(nn.Module):\n    def __init__(self, in_channels, classes, elu):\n        super(OutputTransition, self).__init__()\n        self.classes = classes\n        self.conv1 = nn.Conv3d(in_channels, classes, kernel_size=5, padding=2)\n        self.bn1 = torch.nn.BatchNorm3d(classes)\n\n        self.conv2 = nn.Conv3d(classes, classes, kernel_size=1)\n        self.relu1 = ELUCons(elu, classes)\n\n    def forward(self, x):\n        # convolve 32 down to channels as the desired classes\n        out = self.relu1(self.bn1(self.conv1(x)))\n        out = self.conv2(out)\n        return out\n\n\nclass VNet_DTC(nn.Module):\n    \"\"\"\n    Implementations based on the Vnet paper: https://arxiv.org/abs/1606.04797\n    \"\"\"\n\n    def __init__(self, in_channels=1, classes=1, elu=True):\n        super(VNet_DTC, self).__init__()\n        self.classes = classes\n        self.in_channels = in_channels\n\n        self.in_tr = InputTransition(in_channels, elu=elu)\n        self.down_tr32 = DownTransition(16, 1, elu)\n        self.down_tr64 = DownTransition(32, 2, elu)\n        self.down_tr128 = DownTransition(64, 3, elu, dropout=False)\n        self.down_tr256 = DownTransition(128, 2, elu, dropout=False)\n        self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False)\n        self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False)\n        self.up_tr64 = UpTransition(128, 64, 1, elu)\n        self.up_tr32 = UpTransition(64, 32, 1, elu)\n        self.out_tr = OutputTransition(32, 16, elu)\n\n        self.out_sdf = nn.Sequential(\n            nn.Conv3d(16, classes, 1, padding=0),\n            nn.Tanh()\n        )\n        
self.out_seg = nn.Conv3d(16, classes, 1, padding=0)\n\n\n    def forward(self, x):\n        out16 = self.in_tr(x)\n        out32 = self.down_tr32(out16)\n        out64 = self.down_tr64(out32)\n        out128 = self.down_tr128(out64)\n        out256 = self.down_tr256(out128)\n        out = self.up_tr256(out256, out128)\n        out = self.up_tr128(out, out64)\n        out = self.up_tr64(out, out32)\n        out = self.up_tr32(out, out16)\n        out = self.out_tr(out)\n\n        out_sdf = self.out_sdf(out)\n        out_seg = self.out_seg(out)\n        return out_sdf, out_seg\n\ndef vnet_dtc(in_channels, num_classes):\n    model = VNet_DTC(in_channels, num_classes)\n    init_weights(model, 'kaiming')\n    return model\n\n\n# if __name__ == '__main__':\n#\n#     criterion = segmentation_loss('dice', False)\n#     mask = torch.ones(2, 96, 96, 96).long()\n#     model = vnet_dtc(1, 10)\n#     model.train()\n#     input1 = torch.rand(2,1,96,96,96)\n#     out_sdf, out_seg = model(input1)\n#     loss_train = criterion(out_sdf, mask)\n#     loss_train.backward()\n#     # print(output)\n#     print(out_sdf.data.cpu().numpy().shape)\n#     print(out_seg.data.cpu().numpy().shape)\n#     print(loss_train)\n"
  },
  {
    "path": "models/networks_3d/xnet3d.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\nimport functools\nfrom torch.distributions.uniform import Uniform\nimport numpy as np\n# from loss.loss_function import segmentation_loss\n\n# BN\nBatchNorm3d = nn.InstanceNorm3d\nBN_MOMENTUM = 0.1\n# BN_MOMENTUM = 0.01\n\n# AF\nrelu_inplace = True\nActivationFunction = nn.ReLU\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\ndef conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)\n\nclass up_conv(nn.Module):\n    def __init__(self, ch_in, ch_out):\n        super(up_conv, self).__init__()\n        self.up = nn.Sequential(\n            nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True),\n            nn.Conv3d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),\n            BatchNorm3d(ch_out, momentum=BN_MOMENTUM),\n            ActivationFunction(inplace=relu_inplace)\n        )\n\n    def forward(self, x):\n        x = self.up(x)\n        return x\n\nclass down_conv(nn.Module):\n    def __init__(self, ch_in, ch_out):\n        super(down_conv, self).__init__()\n        self.down = nn.Sequential(\n            nn.Conv3d(ch_in, ch_out, kernel_size=3, stride=2, padding=1, bias=False),\n            BatchNorm3d(ch_out, momentum=BN_MOMENTUM),\n            ActivationFunction(inplace=relu_inplace)\n        )\n    def forward(self, x):\n        x = self.down(x)\n        return x\n\nclass same_conv(nn.Module):\n    def __init__(self, ch_in, ch_out):\n        super(same_conv, self).__init__()\n        self.same = nn.Sequential(\n            nn.Conv3d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False),\n            
BatchNorm3d(ch_out, momentum=BN_MOMENTUM),\n            ActivationFunction(inplace=relu_inplace)\n        )\n    def forward(self, x):\n        x = self.same(x)\n        return x\n\nclass transition_conv(nn.Module):\n    def __init__(self, ch_in, ch_out):\n        super(transition_conv, self).__init__()\n        self.transition = nn.Sequential(\n            nn.Conv3d(ch_in, ch_out, kernel_size=1, stride=1, padding=0, bias=False),\n            BatchNorm3d(ch_out, momentum=BN_MOMENTUM),\n            ActivationFunction(inplace=relu_inplace)\n        )\n    def forward(self, x):\n        x = self.transition(x)\n        return x\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n                 base_width=64, dilation=1, norm_layer=None):\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = BatchNorm3d\n        if groups != 1 or base_width != 64:\n            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes, momentum=BN_MOMENTUM)\n        self.relu = ActivationFunction(inplace=relu_inplace)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = norm_layer(planes, momentum=BN_MOMENTUM)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        # out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out = self.bn2(out) + identity\n        out = self.relu(out)\n\n      
  return out\n\nclass DoubleBasicBlock(nn.Module):\n    def __init__(self, inplanes, planes, downsample=None):\n        super(DoubleBasicBlock, self).__init__()\n\n        self.DBB = nn.Sequential(\n            BasicBlock(inplanes=inplanes, planes=planes, downsample=downsample),\n            BasicBlock(inplanes=planes, planes=planes)\n        )\n\n    def forward(self, x):\n        out = self.DBB(x)\n        return out\n\nclass XNet3D(nn.Module):\n    def __init__(self, in_channels, num_classes):\n        super(XNet3D, self).__init__()\n\n        # l1c, l2c, l3c, l4c = 64, 128, 256, 512\n        # l1c, l2c, l3c, l4c, l5c = 8, 16, 32, 64, 128\n        # l1c, l2c, l3c, l4c, l5c = 16, 32, 64, 128, 256\n        l1c, l2c, l3c, l4c, l5c = 32, 64, 128, 256, 512\n        # branch1\n        # branch1_layer1\n        self.b1_1_1 = nn.Sequential(\n            conv3x3(in_channels, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b1_1_2_down = down_conv(l1c, l2c)\n        self.b1_1_3 = BasicBlock(l1c+l1c, l1c, downsample=nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm3d(l1c, momentum=BN_MOMENTUM)))\n        self.b1_1_4 = nn.Conv3d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch1_layer2\n        self.b1_2_1 = BasicBlock(l2c, l2c)\n        self.b1_2_2_down = down_conv(l2c, l3c)\n        self.b1_2_3 = BasicBlock(l2c+l2c, l2c, downsample=nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm3d(l2c, momentum=BN_MOMENTUM)))\n        self.b1_2_4_up = up_conv(l2c, l1c)\n        # branch1_layer3\n        self.b1_3_1 = BasicBlock(l3c, l3c)\n        self.b1_3_2_down = down_conv(l3c, l4c)\n        self.b1_3_3 = BasicBlock(l3c+l3c, l3c, downsample=nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm3d(l3c, momentum=BN_MOMENTUM)))\n        self.b1_3_4_up = up_conv(l3c, l2c)\n        # branch1_layer4\n        self.b1_4_1 = BasicBlock(l4c, l4c)\n        self.b1_4_2_down = 
down_conv(l4c, l5c)\n        self.b1_4_2 = BasicBlock(l4c, l4c)\n        self.b1_4_3_down = down_conv(l4c, l4c)\n        self.b1_4_3_same = same_conv(l4c, l4c)\n        self.b1_4_4_transition = transition_conv(l4c+l5c+l4c, l4c)\n        self.b1_4_5 = BasicBlock(l4c, l4c)\n        self.b1_4_6 = BasicBlock(l4c+l4c, l4c, downsample=nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm3d(l4c, momentum=BN_MOMENTUM)))\n        self.b1_4_7_up = up_conv(l4c, l3c)\n        # branch1_layer5\n        self.b1_5_1 = BasicBlock(l5c, l5c)\n        self.b1_5_2_up = up_conv(l5c, l5c)\n        self.b1_5_2_same = same_conv(l5c, l5c)\n        self.b1_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)\n        self.b1_5_4 = BasicBlock(l5c, l5c)\n        self.b1_5_5_up = up_conv(l5c, l4c)\n\n        # branch2\n        # branch2_layer1\n        self.b2_1_1 = nn.Sequential(\n            conv3x3(1, l1c),\n            conv3x3(l1c, l1c),\n            BasicBlock(l1c, l1c)\n        )\n        self.b2_1_2_down = down_conv(l1c, l2c)\n        self.b2_1_3 = BasicBlock(l1c+l1c, l1c, downsample=nn.Sequential(conv1x1(in_planes=l1c+l1c, out_planes=l1c), BatchNorm3d(l1c, momentum=BN_MOMENTUM)))\n        self.b2_1_4 = nn.Conv3d(l1c, num_classes, kernel_size=1, stride=1, padding=0)\n        # branch2_layer2\n        self.b2_2_1 = BasicBlock(l2c, l2c)\n        self.b2_2_2_down = down_conv(l2c, l3c)\n        self.b2_2_3 = BasicBlock(l2c+l2c, l2c, downsample=nn.Sequential(conv1x1(in_planes=l2c+l2c, out_planes=l2c), BatchNorm3d(l2c, momentum=BN_MOMENTUM)))\n        self.b2_2_4_up = up_conv(l2c, l1c)\n        # branch2_layer3\n        self.b2_3_1 = BasicBlock(l3c, l3c)\n        self.b2_3_2_down = down_conv(l3c, l4c)\n        self.b2_3_3 = BasicBlock(l3c+l3c, l3c, downsample=nn.Sequential(conv1x1(in_planes=l3c+l3c, out_planes=l3c), BatchNorm3d(l3c, momentum=BN_MOMENTUM)))\n        self.b2_3_4_up = up_conv(l3c, l2c)\n        # branch2_layer4\n        self.b2_4_1 = BasicBlock(l4c, l4c)\n        
self.b2_4_2_down = down_conv(l4c, l5c)\n        self.b2_4_2 = BasicBlock(l4c, l4c)\n        self.b2_4_3_down = down_conv(l4c, l4c)\n        self.b2_4_3_same = same_conv(l4c, l4c)\n        self.b2_4_4_transition = transition_conv(l4c+l5c+l4c, l4c)\n        self.b2_4_5 = BasicBlock(l4c, l4c)\n        self.b2_4_6 = BasicBlock(l4c+l4c, l4c, downsample=nn.Sequential(conv1x1(in_planes=l4c+l4c, out_planes=l4c), BatchNorm3d(l4c, momentum=BN_MOMENTUM)))\n        self.b2_4_7_up = up_conv(l4c, l3c)\n        # branch2_layer5\n        self.b2_5_1 = BasicBlock(l5c, l5c)\n        self.b2_5_2_up = up_conv(l5c, l5c)\n        self.b2_5_2_same = same_conv(l5c, l5c)\n        self.b2_5_3_transition = transition_conv(l5c+l5c+l4c, l5c)\n        self.b2_5_4 = BasicBlock(l5c, l5c)\n        self.b2_5_5_up = up_conv(l5c, l4c)\n\n        # initialization\n        for m in self.modules():\n            if isinstance(m, nn.Conv3d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, nn.BatchNorm3d):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABNSync):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            # elif isinstance(m, InPlaceABN):\n            #     nn.init.constant_(m.weight, 1)\n            #     nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def forward(self, input1, input2):\n        # code\n        # branch1\n        x1_1 = self.b1_1_1(input1)\n\n        x1_2 = self.b1_1_2_down(x1_1)\n        x1_2 = self.b1_2_1(x1_2)\n\n        x1_3 = self.b1_2_2_down(x1_2)\n        x1_3 = self.b1_3_1(x1_3)\n\n        
x1_4_1 = self.b1_3_2_down(x1_3)\n        x1_4_1 = self.b1_4_1(x1_4_1)\n        x1_4_2 = self.b1_4_2(x1_4_1)\n        x1_4_3_down = self.b1_4_3_down(x1_4_2)\n        x1_4_3_same = self.b1_4_3_same(x1_4_2)\n\n        x1_5_1 = self.b1_4_2_down(x1_4_1)\n        x1_5_1 = self.b1_5_1(x1_5_1)\n        x1_5_2_up = self.b1_5_2_up(x1_5_1)\n        x1_5_2_same = self.b1_5_2_same(x1_5_1)\n        # branch2\n        x2_1 = self.b2_1_1(input2)\n\n        x2_2 = self.b2_1_2_down(x2_1)\n        x2_2 = self.b2_2_1(x2_2)\n\n        x2_3 = self.b2_2_2_down(x2_2)\n        x2_3 = self.b2_3_1(x2_3)\n\n        x2_4_1 = self.b2_3_2_down(x2_3)\n        x2_4_1 = self.b2_4_1(x2_4_1)\n        x2_4_2 = self.b2_4_2(x2_4_1)\n        x2_4_3_down = self.b2_4_3_down(x2_4_2)\n        x2_4_3_same = self.b2_4_3_same(x2_4_2)\n\n        x2_5_1 = self.b2_4_2_down(x2_4_1)\n        x2_5_1 = self.b2_5_1(x2_5_1)\n        x2_5_2_up = self.b2_5_2_up(x2_5_1)\n        x2_5_2_same = self.b2_5_2_same(x2_5_1)\n\n        # merge\n        # branch1\n        x1_5_3 = torch.cat((x1_5_2_same, x2_5_2_same, x2_4_3_down), dim=1)\n        x1_5_3 = self.b1_5_3_transition(x1_5_3)\n        x1_5_3 = self.b1_5_4(x1_5_3)\n        x1_5_3 = self.b1_5_5_up(x1_5_3)\n\n        x1_4_4 = torch.cat((x1_4_3_same, x2_4_3_same, x2_5_2_up), dim=1)\n        x1_4_4 = self.b1_4_4_transition(x1_4_4)\n        x1_4_4 = self.b1_4_5(x1_4_4)\n        x1_4_4 = torch.cat((x1_4_4, x1_5_3), dim=1)\n        x1_4_4 = self.b1_4_6(x1_4_4)\n        x1_4_4 = self.b1_4_7_up(x1_4_4)\n        # branch2\n        x2_5_3 = torch.cat((x2_5_2_same, x1_5_2_same, x1_4_3_down), dim=1)\n        x2_5_3 = self.b2_5_3_transition(x2_5_3)\n        x2_5_3 = self.b2_5_4(x2_5_3)\n        x2_5_3 = self.b2_5_5_up(x2_5_3)\n\n        x2_4_4 = torch.cat((x2_4_3_same, x1_4_3_same, x1_5_2_up), dim=1)\n        x2_4_4 = self.b2_4_4_transition(x2_4_4)\n        x2_4_4 = self.b2_4_5(x2_4_4)\n        x2_4_4 = torch.cat((x2_4_4, x2_5_3), dim=1)\n        x2_4_4 = self.b2_4_6(x2_4_4)\n        
x2_4_4 = self.b2_4_7_up(x2_4_4)\n\n        # decode\n        # branch1\n        x1_3 = torch.cat((x1_3, x1_4_4), dim=1)\n        x1_3 = self.b1_3_3(x1_3)\n        x1_3 = self.b1_3_4_up(x1_3)\n\n        x1_2 = torch.cat((x1_2, x1_3), dim=1)\n        x1_2 = self.b1_2_3(x1_2)\n        x1_2 = self.b1_2_4_up(x1_2)\n\n        x1_1 = torch.cat((x1_1, x1_2), dim=1)\n        x1_1 = self.b1_1_3(x1_1)\n        x1_1 = self.b1_1_4(x1_1)\n        # branch2\n        x2_3 = torch.cat((x2_3, x2_4_4), dim=1)\n        x2_3 = self.b2_3_3(x2_3)\n        x2_3 = self.b2_3_4_up(x2_3)\n\n        x2_2 = torch.cat((x2_2, x2_3), dim=1)\n        x2_2 = self.b2_2_3(x2_2)\n        x2_2 = self.b2_2_4_up(x2_2)\n\n        x2_1 = torch.cat((x2_1, x2_2), dim=1)\n        x2_1 = self.b2_1_3(x2_1)\n        x2_1 = self.b2_1_4(x2_1)\n\n        return x1_1, x2_1\n\ndef xnet3d(in_channels, num_classes):\n    model = XNet3D(in_channels, num_classes)\n    return model\n\n\n# if __name__ == '__main__':\n#\n#     criterion = segmentation_loss('dice', False)\n#     mask = torch.ones(2, 96, 96, 96).long()\n#     model = XNet3D(1, 10)\n#     model.train()\n#     input1 = torch.rand(2,1,96,96,96)\n#     input2 = torch.rand(2,1,96,96,96)\n#     x1_1_main, x1_1_aux1, x1_1_aux2, x1_1_aux3, x2_1_main, x2_1_aux1, x2_1_aux2, x2_1_aux3 = model(input1, input2)\n#     loss_train = criterion(x1_1_main, mask)\n#     loss_train.backward()\n#     # print(output)\n#     print(x1_1_main.data.cpu().numpy().shape)\n#     print(x2_1_main.data.cpu().numpy().shape)\n#     print(loss_train)\n"
  },
  {
    "path": "requirements.txt",
    "content": "albumentations==0.5.2\neinops==0.4.1\nMedPy==0.4.0\nnumpy==1.20.2\nopencv_python==4.2.0.34\nopencv_python_headless==4.5.1.48\nPillow==8.0.0\nPyWavelets==1.1.1\nscikit_image==0.18.1\nscikit_learn==1.0.1\nscipy==1.4.1\nSimpleITK==2.1.0\ntimm==0.6.7\ntorch==1.8.0+cu111\ntorchio==0.18.53\ntorchvision==0.9.0+cu111\ntqdm==4.65.0\nvisdom==0.1.8.9\n"
  },
  {
    "path": "test.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nfrom torch.backends import cudnn\nimport random\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.augmentation.online_aug import data_transform_2d, data_normalize_2d\nfrom models.getnetwork import get_network\nfrom dataload.dataset_2d import imagefloder_itn\nfrom config.train_test_config.train_test_config import print_test_eval, save_test_2d\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=FutureWarning)\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/GlaS')\n    parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/sup/GlaS/best_kiunet_Jc_0.7779.pth')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test')\n    parser.add_argument('--dataset_name', default='GlaS', help='CREMI')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--if_mask', default=True)\n    parser.add_argument('--threshold', default=0.5400, help='0.5600, 5400')\n    parser.add_argument('-ds', '--deep_supervision', default=False)\n    parser.add_argument('-b', '--batch_size', default=4, type=int)\n    parser.add_argument('-n', '--network', default='kiunet')\n    parser.add_argument('--local_rank', default=-1, type=int)\n    
parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    # Config\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7\n    print_num_minus = print_num - 2\n\n    # Results Save\n    if not os.path.exists(args.path_seg_results) and rank == args.rank_index:\n        os.mkdir(args.path_seg_results)\n    path_seg_results = args.path_seg_results + '/' + str(dataset_name)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    # print(path_seg_results)\n\n    if args.input1 == 'image':\n        input1_mean = 'MEAN'\n        input1_std = 'STD'\n    else:\n        input1_mean = 'MEAN_' + args.input1\n        input1_std = 'STD_' + args.input1\n\n    # Dataset\n    data_transforms = data_transform_2d()\n    data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std])\n\n    dataset_val = imagefloder_itn(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        data_transform_1=data_transforms['val'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=None\n    )\n\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler)\n\n    num_batches = {'val': len(dataloaders['val'])}\n\n    # 
Model\n    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n    model = model.cuda()\n\n    # if rank == args.rank_index:\n    #     state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank))\n    #     model.load_state_dict(state_dict=state_dict)\n    # model = DistributedDataParallel(model, device_ids=[args.local_rank])\n\n    model = DistributedDataParallel(model, device_ids=[args.local_rank])\n    state_dict = torch.load(args.path_model)\n    model.load_state_dict(state_dict=state_dict)\n    dist.barrier()\n\n    # Test\n    since = time.time()\n\n    with torch.no_grad():\n        model.eval()\n\n        for i, data in enumerate(dataloaders['val']):\n            inputs_test = data['image']\n            inputs_test = Variable(inputs_test.cuda(non_blocking=True))\n            name_test = data['ID']\n            if args.if_mask:\n                mask_test = data['mask']\n                mask_test = Variable(mask_test.cuda(non_blocking=True))\n\n            outputs_test = model(inputs_test)\n            if args.deep_supervision:\n                outputs_test = outputs_test[0]\n\n            if args.if_mask:\n                if i == 0:\n                    score_list_test = outputs_test\n                    name_list_test = name_test\n                    mask_list_test = mask_test\n                else:\n                # elif 0 < i <= num_batches['val'] / 16:\n                    score_list_test = torch.cat((score_list_test, outputs_test), dim=0)\n                    name_list_test = np.append(name_list_test, name_test, axis=0)\n                    mask_list_test = torch.cat((mask_list_test, mask_test), dim=0)\n                torch.cuda.empty_cache()\n            else:\n                save_test_2d(cfg['NUM_CLASSES'], outputs_test, name_test, args.threshold, path_seg_results, cfg['PALETTE'])\n                torch.cuda.empty_cache()\n\n        if args.if_mask:\n            score_gather_list_test = 
[torch.zeros_like(score_list_test) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_test, score_list_test)\n            score_list_test = torch.cat(score_gather_list_test, dim=0)\n\n            mask_gather_list_test = [torch.zeros_like(mask_list_test) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_test, mask_list_test)\n            mask_list_test = torch.cat(mask_gather_list_test, dim=0)\n\n            name_gather_list_test = [None for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather_object(name_gather_list_test, name_list_test)\n            name_list_test = np.concatenate(name_gather_list_test, axis=0)\n\n        if args.if_mask and rank == args.rank_index:\n            print('=' * print_num)\n            test_eval_list = print_test_eval(cfg['NUM_CLASSES'], score_list_test, mask_list_test, print_num_minus)\n            save_test_2d(cfg['NUM_CLASSES'], score_list_test, name_list_test, test_eval_list[0], path_seg_results, cfg['PALETTE'])\n            torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n        print('-' * print_num)\n        print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('=' * print_num)"
  },
  {
    "path": "test_3d.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nfrom torch.backends import cudnn\nimport random\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nimport torchio as tio\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.augmentation.online_aug import data_transform_3d\nfrom models.getnetwork import get_network\nfrom dataload.dataset_3d import dataset_it\nfrom config.train_test_config.train_test_config import save_test_3d\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=FutureWarning)\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/LiTS')\n    parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/semi/LiTS/best_result1_Jc_0.7677.pth')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test')\n    parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--threshold', default=None)\n    parser.add_argument('--patch_size', default=(112, 112, 32))\n    parser.add_argument('--patch_overlap', default=(56, 56, 16))\n    parser.add_argument('-b', '--batch_size', default=1, type=int)\n    parser.add_argument('-n', '--network', default='unet3d_min')\n    parser.add_argument('-ds', '--deep_supervision', default=False)\n    parser.add_argument('--local_rank', 
default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    # Config\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7\n    print_num_minus = print_num - 2\n\n    # Results Save\n    if not os.path.exists(args.path_seg_results) and rank == args.rank_index:\n        os.mkdir(args.path_seg_results)\n    path_seg_results = args.path_seg_results + '/' + str(dataset_name)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    data_transform = data_transform_3d(cfg['NORMALIZE'])\n    dataset_val = dataset_it(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        transform_1=data_transform['test'],\n    )\n\n    # Model\n    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size)\n    model = model.cuda()\n\n    # if rank == args.rank_index:\n    #     state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank))\n    #     model.load_state_dict(state_dict=state_dict)\n    # model = DistributedDataParallel(model, device_ids=[args.local_rank])\n\n    model = DistributedDataParallel(model, device_ids=[args.local_rank])\n    state_dict = torch.load(args.path_model)\n    model.load_state_dict(state_dict=state_dict)\n    dist.barrier()\n\n    # Test\n    since = time.time()\n\n    for i, subject in 
enumerate(dataset_val.dataset_1):\n\n        grid_sampler = tio.inference.GridSampler(\n            subject=subject,\n            patch_size=args.patch_size,\n            patch_overlap=args.patch_overlap\n        )\n\n        # val_sampler = torch.utils.data.distributed.DistributedSampler(grid_sampler, shuffle=False)\n\n        dataloaders = dict()\n        dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16)\n        # dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler)\n        aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode='average')\n\n        with torch.no_grad():\n            model.eval()\n\n            for data in dataloaders['test']:\n\n                inputs_test = Variable(data['image'][tio.DATA].cuda())\n                location_test = data[tio.LOCATION]\n\n                outputs_test = model(inputs_test)\n                if args.deep_supervision:\n                    outputs_test = outputs_test[0]\n\n                aggregator.add_batch(outputs_test, location_test)\n\n        outputs_tensor = aggregator.get_output_tensor()\n        save_test_3d(cfg['NUM_CLASSES'], outputs_tensor, subject['ID'], args.threshold, path_seg_results, subject['image']['affine'])\n\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n        print('-' * print_num)\n        print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('=' * print_num)"
  },
  {
    "path": "test_ConResNet.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nfrom torch.backends import cudnn\nimport random\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nimport torchio as tio\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.augmentation.online_aug import data_transform_3d\nfrom models.getnetwork import get_network\nfrom dataload.dataset_3d import dataset_iit_conresnet\nfrom config.train_test_config.train_test_config import save_test_3d\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=FutureWarning)\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/LiTS')\n    parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/sup/LiTS/best_conresnet_Jc_0.8545.pth')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test')\n    parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--input2', default='image_res')\n    parser.add_argument('--threshold', default=None)\n    parser.add_argument('--patch_size', default=(112, 112, 32))\n    parser.add_argument('--patch_overlap', default=(56, 56, 16))\n    parser.add_argument('-b', '--batch_size', default=1, type=int)\n    parser.add_argument('-n', '--network', default='conresnet')\n    parser.add_argument('--local_rank', 
default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    # Config\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7\n    print_num_minus = print_num - 2\n\n    # Results Save\n    if not os.path.exists(args.path_seg_results) and rank == args.rank_index:\n        os.mkdir(args.path_seg_results)\n    path_seg_results = args.path_seg_results + '/' + str(dataset_name)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    data_transform = data_transform_3d(cfg['NORMALIZE'])\n    dataset_val = dataset_iit_conresnet(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        input2=args.input2,\n        transform_1=data_transform['test'],\n    )\n\n    # Model\n    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size)\n    model = model.cuda()\n\n    # if rank == args.rank_index:\n    #     state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank))\n    #     model.load_state_dict(state_dict=state_dict)\n    # model = DistributedDataParallel(model, device_ids=[args.local_rank])\n\n    model = DistributedDataParallel(model, device_ids=[args.local_rank])\n    state_dict = torch.load(args.path_model)\n    model.load_state_dict(state_dict=state_dict)\n    dist.barrier()\n\n    # Test\n    since = time.time()\n\n    for i, 
subject in enumerate(dataset_val.dataset_1):\n\n        grid_sampler = tio.inference.GridSampler(\n            subject=subject,\n            patch_size=args.patch_size,\n            patch_overlap=args.patch_overlap\n        )\n\n        # val_sampler = torch.utils.data.distributed.DistributedSampler(grid_sampler, shuffle=False)\n\n        dataloaders = dict()\n        dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16)\n        # dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler)\n        aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode='average')\n\n        with torch.no_grad():\n            model.eval()\n\n            for data in dataloaders['test']:\n\n                inputs_test = Variable(data['image'][tio.DATA].cuda())\n                inputs_test_2 = Variable(data['image2'][tio.DATA].cuda())\n                location_test = data[tio.LOCATION]\n\n                outputs_test = model(inputs_test, inputs_test_2)\n                aggregator.add_batch(outputs_test[0], location_test)\n\n        outputs_tensor = aggregator.get_output_tensor()\n        save_test_3d(cfg['NUM_CLASSES'], outputs_tensor, subject['ID'], args.threshold, path_seg_results, subject['image']['affine'])\n\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n        print('-' * print_num)\n        print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('=' * print_num)"
  },
  {
    "path": "test_DTC.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nfrom torch.backends import cudnn\nimport random\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nimport torchio as tio\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.augmentation.online_aug import data_transform_3d\nfrom models.getnetwork import get_network\nfrom dataload.dataset_3d import dataset_it_dtc\nfrom config.train_test_config.train_test_config import save_test_3d\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=FutureWarning)\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/LiTS')\n    parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/semi/LiTS/best_DTC_Jc_0.7594.pth')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test')\n    parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--threshold', default=None)\n    parser.add_argument('--patch_size', default=(112, 112, 32))\n    parser.add_argument('--patch_overlap', default=(56, 56, 16))\n    parser.add_argument('-b', '--batch_size', default=1, type=int)\n    parser.add_argument('-n', '--network', default='vnet_dtc')\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, 
help='0, 1, 2, 3')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    # Config\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7\n    print_num_minus = print_num - 2\n\n    # Results Save\n    if not os.path.exists(args.path_seg_results) and rank == args.rank_index:\n        os.mkdir(args.path_seg_results)\n    path_seg_results = args.path_seg_results + '/' + str(dataset_name)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    data_transform = data_transform_3d(cfg['NORMALIZE'])\n    dataset_val = dataset_it_dtc(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        num_classes=cfg['NUM_CLASSES'],\n        transform_1=data_transform['test'],\n    )\n\n    # Model\n    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size)\n    model = model.cuda()\n\n    # if rank == args.rank_index:\n    #     state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank))\n    #     model.load_state_dict(state_dict=state_dict)\n    # model = DistributedDataParallel(model, device_ids=[args.local_rank])\n\n    model = DistributedDataParallel(model, device_ids=[args.local_rank])\n    state_dict = torch.load(args.path_model)\n    model.load_state_dict(state_dict=state_dict)\n    dist.barrier()\n\n    # Test\n    since = time.time()\n\n    for i, subject in enumerate(dataset_val.dataset_1):\n\n        grid_sampler = 
tio.inference.GridSampler(\n            subject=subject,\n            patch_size=args.patch_size,\n            patch_overlap=args.patch_overlap\n        )\n\n        # val_sampler = torch.utils.data.distributed.DistributedSampler(grid_sampler, shuffle=False)\n\n        dataloaders = dict()\n        dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16)\n        # dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler)\n        aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode='average')\n\n        with torch.no_grad():\n            model.eval()\n\n            for data in dataloaders['test']:\n\n                inputs_test = Variable(data['image'][tio.DATA].cuda())\n                location_test = data[tio.LOCATION]\n\n                outputs_test_sdf, outputs_test_seg = model(inputs_test)\n                aggregator.add_batch(outputs_test_seg, location_test)\n\n        outputs_tensor = aggregator.get_output_tensor()\n        save_test_3d(cfg['NUM_CLASSES'], outputs_tensor, subject['ID'], args.threshold, path_seg_results, subject['image']['affine'])\n\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n        print('-' * print_num)\n        print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('=' * print_num)"
  },
  {
    "path": "test_xnet.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nfrom torch.backends import cudnn\nimport random\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.augmentation.online_aug import data_transform_2d, data_normalize_2d\nfrom models.getnetwork import get_network\nfrom dataload.dataset_2d import imagefloder_iitnn\nfrom config.train_test_config.train_test_config import print_test_eval, save_test_2d\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=FutureWarning)\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/GlaS')\n    parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/semi_xnet/GlaS/best_result2_Jc_0.7898.pth')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test')\n    parser.add_argument('--dataset_name', default='GlaS', help='CREMI, ISIC-2017, GlaS')\n    parser.add_argument('--input1', default='L')\n    parser.add_argument('--input2', default='H')\n    parser.add_argument('--if_mask', default=True)\n    parser.add_argument('--threshold', default=0.5400, help='0.5600, 5400')\n    parser.add_argument('--if_cct', default=False)\n    parser.add_argument('--result', default='result2', help='result1, result2')\n    parser.add_argument('-n', '--network', default='xnet')\n    parser.add_argument('-b', 
'--batch_size', default=8, type=int)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    # Config\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7\n    print_num_minus = print_num - 2\n\n    # Results Save\n    if not os.path.exists(args.path_seg_results) and rank == args.rank_index:\n        os.mkdir(args.path_seg_results)\n    path_seg_results = args.path_seg_results + '/' + str(dataset_name)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    # Dataset\n    if args.input1 == 'image':\n        input1_mean = 'MEAN'\n        input1_std = 'STD'\n    else:\n        input1_mean = 'MEAN_' + args.input1\n        input1_std = 'STD_' + args.input1\n\n    if args.input2 == 'image':\n        input2_mean = 'MEAN'\n        input2_std = 'STD'\n    else:\n        input2_mean = 'MEAN_' + args.input2\n        input2_std = 'STD_' + args.input2\n\n    data_transforms = data_transform_2d()\n    data_normalize_1 = data_normalize_2d(cfg[input1_mean], cfg[input1_std])\n    data_normalize_2 = data_normalize_2d(cfg[input2_mean], cfg[input2_std])\n\n\n    dataset_val = imagefloder_iitnn(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        input2=args.input2,\n        data_transform_1=data_transforms['val'],\n        data_normalize_1=data_normalize_1,\n      
  data_normalize_2=data_normalize_2,\n        sup=True,\n        num_images=None,\n    )\n\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler)\n\n    num_batches = {'val': len(dataloaders['val'])}\n\n    # Model\n    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n    model = model.cuda()\n\n    # if rank == args.rank_index:\n    #     state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank))\n    #     model.load_state_dict(state_dict=state_dict)\n    # model = DistributedDataParallel(model, device_ids=[args.local_rank])\n\n    model = DistributedDataParallel(model, device_ids=[args.local_rank])\n    state_dict = torch.load(args.path_model)\n    model.load_state_dict(state_dict=state_dict)\n    dist.barrier()\n\n    # Test\n    since = time.time()\n\n    with torch.no_grad():\n        model.eval()\n\n        for i, data in enumerate(dataloaders['val']):\n\n            inputs_test = Variable(data['image'].cuda(non_blocking=True))\n            inputs_wavelet_test = Variable(data['image_2'].cuda(non_blocking=True))\n            name_test = data['ID']\n            if args.if_mask:\n                mask_test = Variable(data['mask'].cuda(non_blocking=True))\n\n            if args.if_cct:\n                outputs_test1, outputs_test1_aux1, outputs_test1_aux2, outputs_test1_aux3, outputs_test2, outputs_test2_aux1, outputs_test2_aux2, outputs_test2_aux3 = model(inputs_test, inputs_wavelet_test)\n            else:\n                outputs_test1, outputs_test2 = model(inputs_test, inputs_wavelet_test)\n            if args.result == 'result1':\n                outputs_test = outputs_test1\n            else:\n                outputs_test = outputs_test2\n\n            if args.if_mask:\n                if i == 
0:\n                    score_list_test = outputs_test\n                    name_list_test = name_test\n                    mask_list_test = mask_test\n                else:\n                # elif 0 < i <= num_batches['val'] / 16:\n                    score_list_test = torch.cat((score_list_test, outputs_test), dim=0)\n                    name_list_test = np.append(name_list_test, name_test, axis=0)\n                    mask_list_test = torch.cat((mask_list_test, mask_test), dim=0)\n                torch.cuda.empty_cache()\n            else:\n                save_test_2d(cfg['NUM_CLASSES'], outputs_test, name_test, args.threshold, path_seg_results, cfg['PALETTE'])\n                torch.cuda.empty_cache()\n\n        if args.if_mask:\n            score_gather_list_test = [torch.zeros_like(score_list_test) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_test, score_list_test)\n            score_list_test = torch.cat(score_gather_list_test, dim=0)\n\n            mask_gather_list_test = [torch.zeros_like(mask_list_test) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_test, mask_list_test)\n            mask_list_test = torch.cat(mask_gather_list_test, dim=0)\n\n            name_gather_list_test = [None for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather_object(name_gather_list_test, name_list_test)\n            name_list_test = np.concatenate(name_gather_list_test, axis=0)\n\n        if args.if_mask and rank == args.rank_index:\n            print('=' * print_num)\n            test_eval_list = print_test_eval(cfg['NUM_CLASSES'], score_list_test, mask_list_test, print_num_minus)\n            save_test_2d(cfg['NUM_CLASSES'], score_list_test, name_list_test, test_eval_list[0], path_seg_results, cfg['PALETTE'])\n            torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 
60)\n        h, m = divmod(m, 60)\n        print('-' * print_num)\n        print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('=' * print_num)"
  },
  {
    "path": "test_xnet3d.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nfrom torch.backends import cudnn\nimport random\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nimport torchio as tio\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.augmentation.online_aug import data_transform_3d\nfrom models.getnetwork import get_network\nfrom dataload.dataset_3d import dataset_iit\nfrom config.train_test_config.train_test_config import save_test_3d\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=FutureWarning)\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/LiTS')\n    parser.add_argument('-p', '--path_model', default='/mnt/data1/XNet/pretrained_model/semi_xnet/LiTS/best_result1_Jc_0.7794.pth')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/test')\n    parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial')\n    parser.add_argument('--input1', default='L')\n    parser.add_argument('--input2', default='H')\n    parser.add_argument('--threshold', default=None)\n    parser.add_argument('--result', default='result1', help='result1, result2')\n    parser.add_argument('--patch_size', default=(112, 112, 32))\n    parser.add_argument('--patch_overlap', default=(56, 56, 16))\n    parser.add_argument('-b', '--batch_size', default=1, type=int)\n    parser.add_argument('-n', '--network', 
default='xnet3d')\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    # Config\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7\n    print_num_minus = print_num - 2\n\n    # Results Save\n    if not os.path.exists(args.path_seg_results) and rank == args.rank_index:\n        os.mkdir(args.path_seg_results)\n    path_seg_results = args.path_seg_results + '/' + str(dataset_name)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + str(os.path.splitext(os.path.split(args.path_model)[1])[0])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    data_transform = data_transform_3d(cfg['NORMALIZE'])\n    dataset_val = dataset_iit(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        input2=args.input2,\n        transform_1=data_transform['test'],\n    )\n\n    # Model\n    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n    model = model.cuda()\n\n    # if rank == args.rank_index:\n    #     state_dict = torch.load(args.path_model, map_location=torch.device(args.local_rank))\n    #     model.load_state_dict(state_dict=state_dict)\n    # model = DistributedDataParallel(model, device_ids=[args.local_rank])\n\n    model = DistributedDataParallel(model, device_ids=[args.local_rank])\n    state_dict = torch.load(args.path_model)\n    model.load_state_dict(state_dict=state_dict)\n    dist.barrier()\n\n    # Test\n    since = 
time.time()\n\n    for i, subject in enumerate(dataset_val.dataset_1):\n\n        grid_sampler = tio.inference.GridSampler(\n            subject=subject,\n            patch_size=args.patch_size,\n            patch_overlap=args.patch_overlap\n        )\n\n        # val_sampler = torch.utils.data.distributed.DistributedSampler(grid_sampler, shuffle=False)\n\n        dataloaders = dict()\n        dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16)\n        # dataloaders['test'] = DataLoader(grid_sampler, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=16, sampler=val_sampler)\n        aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode='average')\n\n        with torch.no_grad():\n            model.eval()\n\n            for data in dataloaders['test']:\n\n                inputs_test_1 = Variable(data['image'][tio.DATA].cuda())\n                inputs_test_2 = Variable(data['image2'][tio.DATA].cuda())\n                location_test = data[tio.LOCATION]\n\n                outputs_test_1, outputs_test_2 = model(inputs_test_1, inputs_test_2)\n                if args.result == 'result1':\n                    outputs_test = outputs_test_1\n                else:\n                    outputs_test = outputs_test_2\n\n                aggregator.add_batch(outputs_test, location_test)\n\n        outputs_tensor = aggregator.get_output_tensor()\n        save_test_3d(cfg['NUM_CLASSES'], outputs_tensor, subject['ID'], args.threshold, path_seg_results, subject['image']['affine'])\n\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n        print('-' * print_num)\n        print('| Testing Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('=' * print_num)"
  },
  {
    "path": "tools/Atrial/__init__.py",
    "content": ""
  },
  {
    "path": "tools/Atrial/postprocess.py",
    "content": "import numpy as np\nimport argparse\nimport os\nimport SimpleITK as sitk\nfrom skimage.morphology import remove_small_objects, remove_small_holes\nimport skimage\n\ndef save_max_objects(image):\n    # Keep only the largest connected component of a binary mask.\n    labeled_image = skimage.measure.label(image)\n    labeled_list = skimage.measure.regionprops(labeled_image)\n    # No foreground regions: return the all-zero label image unchanged.\n    # (The original code raised NameError here because label_num was only\n    # assigned inside the per-region loop.)\n    if not labeled_list:\n        return labeled_image\n    # 1-based label of the region with the largest area, computed once\n    # after collecting all areas instead of on every loop iteration.\n    areas = [region.area for region in labeled_list]\n    label_num = areas.index(max(areas)) + 1\n\n    labeled_image[labeled_image != label_num] = 0\n    labeled_image[labeled_image == label_num] = 1\n\n    return labeled_image\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--pred_path', default='//10.0.5.233/shared_data//XNet/seg_pred/test/Atrial/best_DTC_Jc_0.8730')\n    parser.add_argument('--save_path', default='//10.0.5.233/shared_data//XNet/seg_pred/test/Atrial/best_DTC_Jc_0.8730_mor')\n    parser.add_argument('--fill_hole_thr', default=500, help='300-500')\n    args = parser.parse_args()\n\n    if not os.path.exists(args.save_path):\n        os.mkdir(args.save_path)\n\n    for i in os.listdir(args.pred_path):\n\n        pred_path = os.path.join(args.pred_path, i)\n        save_path = os.path.join(args.save_path, i)\n\n        pred = sitk.ReadImage(pred_path)\n        pred = sitk.GetArrayFromImage(pred)\n\n        pred = pred.astype(bool)\n        pred = remove_small_holes(pred, args.fill_hole_thr)\n        pred = pred.astype(np.uint8)\n\n        pred = save_max_objects(pred)\n        pred = pred.astype(np.uint8)\n\n        pred = sitk.GetImageFromArray(pred)\n        sitk.WriteImage(pred, save_path)"
  },
  {
    "path": "tools/Atrial/preprocess.py",
    "content": "import numpy as np\nimport torchio as tio\nimport os\nimport argparse\nfrom tqdm import tqdm\nimport SimpleITK as sitk\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--data_path', default='E:/Biomedical datasets/2018 Atrial Segmentation Challenge/Training Set')\n    parser.add_argument('--save_path', default='E:/Biomedical datasets/2018 Atrial Segmentation Challenge/dataset')\n    args = parser.parse_args()\n\n    if not os.path.exists(args.save_path):\n        os.mkdir(args.save_path)\n    save_image_path = args.save_path + '/image'\n    save_mask_path = args.save_path + '/mask'\n    if not os.path.exists(save_image_path):\n        os.mkdir(save_image_path)\n    if not os.path.exists(save_mask_path):\n        os.mkdir(save_mask_path)\n\n    for i in os.listdir(args.data_path):\n        save_name = i + '.nrrd'\n        image_path = args.data_path + '/' + i + '/' + 'lgemri.nrrd'\n        mask_path = args.data_path + '/' + i + '/' + 'laendo.nrrd'\n\n        image = tio.ScalarImage(image_path)\n        mask = tio.LabelMap(mask_path)\n\n        _, w, h, d = mask.data.shape\n        tempL = np.nonzero(np.array(mask.data))\n        minx, maxx = np.min(tempL[1]), np.max(tempL[1])\n        miny, maxy = np.min(tempL[2]), np.max(tempL[2])\n        # minz, maxz = np.min(tempL[3]), np.max(tempL[3])\n\n        px = max(112 - (maxx - minx), 0) // 2\n        py = max(112 - (maxy - miny), 0) // 2\n        # pz = max(80 - (maxz - minz), 0) // 2\n        minx = max(minx - np.random.randint(10, 20) - px, 0)\n        maxx = min(maxx + np.random.randint(10, 20) + px, w)\n        miny = max(miny - np.random.randint(10, 20) - py, 0)\n        maxy = min(maxy + np.random.randint(10, 20) + py, h)\n        # minz = max(minz - np.random.randint(5, 10) - pz, 0)\n        # maxz = min(maxz + np.random.randint(5, 10) + pz, d)\n\n        image_np = image.data[:, minx:maxx, miny:maxy, :]\n        image.set_data(image_np)\n\n        
mask_np = mask.data[:, minx:maxx, miny:maxy, :]\n        mask.set_data(mask_np)\n\n        print(image_np.shape)\n        image.save(os.path.join(save_image_path, save_name))\n        mask.save(os.path.join(save_mask_path, save_name))\n\n\n"
  },
  {
    "path": "tools/LiTS/__init__.py",
    "content": ""
  },
  {
    "path": "tools/LiTS/postprocess.py",
    "content": "import numpy as np\nimport argparse\nimport os\nimport SimpleITK as sitk\nfrom skimage.morphology import remove_small_objects, remove_small_holes\nimport skimage\n\ndef save_max_objects(image):\n    labeled_image = skimage.measure.label(image)\n    labeled_list = skimage.measure.regionprops(labeled_image)\n    box = []\n    for i in range(len(labeled_list)):\n        box.append(labeled_list[i].area)\n        label_num = box.index(max(box)) + 1\n\n    labeled_image[labeled_image != label_num] = 0\n    labeled_image[labeled_image == label_num] = 1\n\n    return labeled_image\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--pred_path', default='//10.0.5.233/shared_data/XNet/seg_pred/test/LiTS/best_DTC_Jc_0.7594')\n    parser.add_argument('--save_path', default='//10.0.5.233/shared_data/XNet/seg_pred/test/LiTS/best_DTC_Jc_0.7594_mor')\n    parser.add_argument('--fill_hole_thr', default=100)\n    args = parser.parse_args()\n\n    if not os.path.exists(args.save_path):\n        os.mkdir(args.save_path)\n\n    for i in os.listdir(args.pred_path):\n\n        pred_path = os.path.join(args.pred_path, i)\n        save_path = os.path.join(args.save_path, i)\n\n        pred = sitk.ReadImage(pred_path)\n        pred = sitk.GetArrayFromImage(pred)\n\n        pred_ = pred.copy()\n        pred_[pred != 0] = 1\n\n        pred_ = pred_.astype(bool)\n        pred_ = remove_small_holes(pred_, args.fill_hole_thr)\n        pred_ = pred_.astype(np.uint8)\n\n        pred_ = save_max_objects(pred_)\n        pred_[(pred_ == 1) & (pred == 2)] = 2\n        pred_ = pred_.astype(np.uint8)\n\n        pred_ = sitk.GetImageFromArray(pred_)\n        sitk.WriteImage(pred_, save_path)"
  },
  {
    "path": "tools/LiTS/preprocess.py",
    "content": "import numpy as np\nimport os\nimport argparse\nfrom tqdm import tqdm\nimport SimpleITK as sitk\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--data_path', default='E:/Biomedical datasets/LiTS')\n    parser.add_argument('--save_path', default='E:/Biomedical datasets/LiTS/dataset')\n    parser.add_argument('--min_hu', default=-100)\n    parser.add_argument('--max_hu', default=250)\n    parser.add_argument('--target_spacing', default=[1.00, 1.00, 1.00])\n    parser.add_argument('--crop_pixel', default=25)\n    args = parser.parse_args()\n\n    if not os.path.exists(args.save_path):\n        os.mkdir(args.save_path)\n    save_image_path = args.save_path + '/image'\n    save_mask_path = args.save_path + '/mask'\n    if not os.path.exists(save_image_path):\n        os.mkdir(save_image_path)\n    if not os.path.exists(save_mask_path):\n        os.mkdir(save_mask_path)\n\n    image_path = args.data_path + '/image'\n    mask_path = args.data_path + '/mask'\n\n    for i in os.listdir(image_path):\n\n        image_dir = os.path.join(image_path, i)\n        mask_dir = os.path.join(mask_path, i)\n\n        image = sitk.ReadImage(image_dir)\n        mask = sitk.ReadImage(mask_dir)\n\n        size = np.array(image.GetSize())\n        spacing = np.array(image.GetSpacing())\n        new_size = size * spacing / args.target_spacing\n        new_size = [int(s) for s in new_size]\n\n        print(new_size, size)\n\n        resample_image = sitk.ResampleImageFilter()\n        resample_image.SetOutputDirection(image.GetDirection())\n        resample_image.SetOutputOrigin(image.GetOrigin())\n        resample_image.SetSize(new_size)\n        resample_image.SetOutputSpacing(args.target_spacing)\n        resample_image.SetInterpolator(sitk.sitkLinear)\n        image = resample_image.Execute(image)\n\n        resample_mask = sitk.ResampleImageFilter()\n        resample_mask.SetOutputDirection(mask.GetDirection())\n        
resample_mask.SetOutputOrigin(mask.GetOrigin())\n        resample_mask.SetSize(new_size)\n        resample_mask.SetOutputSpacing(args.target_spacing)\n        resample_mask.SetInterpolator(sitk.sitkNearestNeighbor)\n        mask = resample_mask.Execute(mask)\n\n        image_np = sitk.GetArrayFromImage(image)\n        mask_np = sitk.GetArrayFromImage(mask)\n\n        w, h, d = mask_np.shape\n        templ = np.nonzero(mask_np)\n        w_min = max(np.min(templ[0]) - args.crop_pixel, 0)\n        w_max = min(np.max(templ[0]) + args.crop_pixel, w)\n        h_min = max(np.min(templ[1]) - args.crop_pixel, 0)\n        h_max = min(np.max(templ[1]) + args.crop_pixel, h)\n        d_min = max(np.min(templ[2]) - args.crop_pixel, 0)\n        d_max = min(np.max(templ[2]) + args.crop_pixel, d)\n\n        image_np = image_np[w_min:w_max, h_min:h_max, d_min:d_max]\n        # image_np = image.data\n        image_np[image_np < args.min_hu] = args.min_hu\n        image_np[image_np > args.max_hu] = args.max_hu\n\n        mask_np = mask_np[w_min:w_max, h_min:h_max, d_min:d_max]\n\n\n        image_save = sitk.GetImageFromArray(image_np)\n        image_save.SetSpacing(args.target_spacing)\n        image_save.SetDirection(image.GetDirection())\n        image_save.SetOrigin(image.GetOrigin())\n\n        mask_save = sitk.GetImageFromArray(mask_np)\n        mask_save.SetSpacing(args.target_spacing)\n        mask_save.SetDirection(image.GetDirection())\n        mask_save.SetOrigin(image.GetOrigin())\n\n        sitk.WriteImage(image_save, os.path.join(save_image_path, i))\n        sitk.WriteImage(mask_save, os.path.join(save_mask_path, i))\n        # image_save.save(os.path.join(save_image_path, save_name))\n        # mask_save.save(os.path.join(save_mask_path, save_name))\n\n\n"
  },
  {
    "path": "tools/LiTS/split_train_val.py",
    "content": "import numpy as np\nimport os\nimport argparse\nimport shutil\nimport random\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--image_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/train_sup_100/image')\n    parser.add_argument('--mask_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/train_sup_100/mask')\n    parser.add_argument('--save_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/val')\n    parser.add_argument('--amount', default=31)\n    parser.add_argument('--random_seed', default=10)\n    args = parser.parse_args()\n\n    random.seed(args.random_seed)\n\n    if not os.path.exists(args.save_path):\n        os.mkdir(args.save_path)\n    save_image_path = args.save_path + '/image'\n    save_mask_path = args.save_path + '/mask'\n    if not os.path.exists(save_image_path):\n        os.mkdir(save_image_path)\n    if not os.path.exists(save_mask_path):\n        os.mkdir(save_mask_path)\n\n    image_path_list = os.listdir(args.image_path)\n\n    image_path_list = random.sample(image_path_list, args.amount)\n\n    for i in image_path_list:\n        shutil.move(os.path.join(args.image_path, i), save_image_path)\n        shutil.move(os.path.join(args.mask_path, i), save_mask_path)\n"
  },
  {
    "path": "tools/__init__.py",
    "content": ""
  },
  {
    "path": "tools/eval.py",
    "content": "from sklearn.metrics import confusion_matrix\nimport numpy as np\nimport argparse\nimport os\nfrom PIL import Image\nfrom medpy.metric.binary import hd95, assd\nimport albumentations as A\nimport SimpleITK as sitk\n\ndef eval_distance(mask_list, seg_result_list, num_classes):\n\n    print_num = 42 + (num_classes - 3) * 7\n    print_num_minus = print_num - 2\n\n    assert len(mask_list) == len(seg_result_list)\n    if num_classes == 2:\n        hd_list = []\n        sd_list = []\n        for i in range(len(mask_list)):\n\n            if np.any(seg_result_list[i]) and np.any(mask_list[i]):\n\n                hd_ = hd95(seg_result_list[i], mask_list[i])\n                sd_ = assd(seg_result_list[i], mask_list[i])\n                hd_list.append(hd_)\n                sd_list.append(sd_)\n\n        hd = np.mean(hd_list)\n        sd = np.mean(sd_list)\n\n        print('| Hd: {:.4f}'.format(hd).ljust(print_num_minus, ' '), '|')\n        print('| Sd: {:.4f}'.format(sd).ljust(print_num_minus, ' '), '|')\n\n    else:\n        hd_list = []\n        sd_list = []\n\n        for cls in range(num_classes-1):\n\n            hd_list_ = []\n            sd_list_ = []\n\n            for i in range(len(mask_list)):\n\n                mask_list_ = mask_list[i].copy()\n                seg_result_list_ = seg_result_list[i].copy()\n\n                mask_list_[mask_list[i] != (cls + 1)] = 0\n                seg_result_list_[seg_result_list[i] != (cls + 1)] = 0\n\n                if np.any(seg_result_list_) and np.any(mask_list_):\n                    hd_ = hd95(seg_result_list_, mask_list_)\n                    sd_ = assd(seg_result_list_, mask_list_)\n                    hd_list_.append(hd_)\n                    sd_list_.append(sd_)\n\n            hd = np.mean(hd_list_)\n            sd = np.mean(sd_list_)\n\n            hd_list.append(hd)\n            sd_list.append(sd)\n\n        hd_list = np.array(hd_list)\n        sd_list = np.array(sd_list)\n\n        m_hd = 
np.mean(hd_list)\n        m_sd = np.mean(sd_list)\n\n        np.set_printoptions(precision=4, suppress=True)\n        print('|  Hd: {}'.format(hd_list).ljust(print_num_minus, ' '), '|')\n        print('|  Sd: {}'.format(sd_list).ljust(print_num_minus, ' '), '|')\n        print('| mHd: {:.4f}'.format(m_hd).ljust(print_num_minus, ' '), '|')\n        print('| mSd: {:.4f}'.format(m_sd).ljust(print_num_minus, ' '), '|')\n\n    print('-' * print_num)\n\ndef eval_pixel(mask_list, seg_result_list, num_classes):\n\n    c = confusion_matrix(mask_list, seg_result_list)\n\n    hist_diag = np.diag(c)\n    hist_sum_0 = c.sum(axis=0)\n    hist_sum_1 = c.sum(axis=1)\n\n    jaccard = hist_diag / (hist_sum_1 + hist_sum_0 - hist_diag)\n    dice = 2 * hist_diag / (hist_sum_1 + hist_sum_0)\n\n    print_num = 42 + (num_classes - 3) * 7\n    print_num_minus = print_num - 2\n\n    print('-' * print_num)\n    if num_classes > 2:\n        m_jaccard = np.nanmean(jaccard)\n        m_dice = np.nanmean(dice)\n        np.set_printoptions(precision=4, suppress=True)\n        print('|  Jc: {}'.format(jaccard).ljust(print_num_minus, ' '), '|')\n        print('|  Dc: {}'.format(dice).ljust(print_num_minus, ' '), '|')\n        print('| mJc: {:.4f}'.format(m_jaccard).ljust(print_num_minus, ' '), '|')\n        print('| mDc: {:.4f}'.format(m_dice).ljust(print_num_minus, ' '), '|')\n    else:\n        print('| Jc: {:.4f}'.format(jaccard[1]).ljust(print_num_minus, ' '), '|')\n        print('| Dc: {:.4f}'.format(dice[1]).ljust(print_num_minus, ' '), '|')\n\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--pred_path', default='/mnt/data1/XNet/seg_pred/test/LiTS/best_result1_Jc_0.7677_mor')\n    parser.add_argument('--mask_path', default='/mnt/data1/XNet/dataset/LiTS/val/mask')\n    parser.add_argument('--if_3D', default=True)\n    parser.add_argument('--resize_shape', default=(128, 128))\n    parser.add_argument('--num_classes', default=3)\n    args = 
parser.parse_args()\n\n    pred_list = []\n    mask_list = []\n\n    pred_flatten_list = []\n    mask_flatten_list = []\n\n    num = 0\n\n    for i in os.listdir(args.pred_path):\n        pred_path = os.path.join(args.pred_path, i)\n        mask_path = os.path.join(args.mask_path, i)\n\n        if args.if_3D:\n            pred = sitk.ReadImage(pred_path)\n            pred = sitk.GetArrayFromImage(pred)\n\n            mask = sitk.ReadImage(mask_path)\n            mask = sitk.GetArrayFromImage(mask)\n\n        else:\n            pred = Image.open(pred_path)\n            # pred = pred.resize((args.resize_shape[1], args.resize_shape[0]))\n            pred = np.array(pred)\n\n            mask = Image.open(mask_path)\n            # mask = mask.resize((args.resize_shape[1], args.resize_shape[0]))\n            mask = np.array(mask)\n            resize = A.Resize(args.resize_shape[1], args.resize_shape[0], p=1)(image=pred, mask=mask)\n            mask = resize['mask']\n            pred = resize['image']\n\n\n        pred_list.append(pred)\n        mask_list.append(mask)\n\n        if num == 0:\n            pred_flatten_list = pred.flatten()\n            mask_flatten_list = mask.flatten()\n        else:\n            pred_flatten_list = np.append(pred_flatten_list, pred.flatten())\n            mask_flatten_list = np.append(mask_flatten_list, mask.flatten())\n\n        num += 1\n\n    eval_pixel(mask_flatten_list, pred_flatten_list, args.num_classes)\n    eval_distance(mask_list, pred_list, args.num_classes)\n\n"
  },
  {
    "path": "tools/mask2sdf.py",
    "content": "import numpy as np\nimport os\nimport argparse\nimport SimpleITK as sitk\nfrom scipy.ndimage import distance_transform_edt\nfrom skimage import segmentation\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--data_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/val')\n    parser.add_argument('--num_classes', default=3)\n    args = parser.parse_args()\n\n    mask_path = args.data_path + '/mask'\n\n    for i in range(args.num_classes-1):\n\n        save_sdf_mask_path = args.data_path + '/mask_sdf' + str(i+1)\n        if not os.path.exists(save_sdf_mask_path):\n            os.mkdir(save_sdf_mask_path)\n\n        for j in os.listdir(mask_path):\n\n            mask = sitk.ReadImage(os.path.join(mask_path, j))\n            mask_np = sitk.GetArrayFromImage(mask)\n\n            mask_np[mask_np != (i+1)] = 0\n            mask_np = mask_np.astype(bool)\n            if mask_np.any():\n                mask_neg = ~mask_np\n                posdis = distance_transform_edt(mask_np)\n                negdis = distance_transform_edt(mask_neg)\n                boundary = segmentation.find_boundaries(mask_np, mode='inner').astype(np.uint8)\n                sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis))\n                sdf[boundary == 1] = 0\n                # sdf = ((sdf - np.min(sdf)) / (np.max(sdf) - np.min(sdf))) * 255\n            else:\n                sdf = np.zeros(mask_np.shape)\n\n            sdf = sitk.GetImageFromArray(sdf)\n            sdf.SetSpacing(mask.GetSpacing())\n            sdf.SetDirection(mask.GetDirection())\n            sdf.SetOrigin(mask.GetOrigin())\n            sitk.WriteImage(sdf, os.path.join(save_sdf_mask_path, j))\n\n\n\n\n\n\n"
  },
  {
    "path": "tools/res_image_mask.py",
    "content": "import numpy as np\nimport os\nimport argparse\nimport SimpleITK as sitk\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--data_path', default='//10.0.5.233/shared_data/XNet/dataset/Atrial/train_sup_100')\n    args = parser.parse_args()\n\n    image_path = args.data_path + '/image'\n    mask_path = args.data_path + '/mask'\n\n    save_res_path = args.data_path + '/image_res'\n    save_res_mask_path = args.data_path + '/mask_res'\n    if not os.path.exists(save_res_path):\n        os.mkdir(save_res_path)\n    if not os.path.exists(save_res_mask_path):\n        os.mkdir(save_res_mask_path)\n\n    for i in os.listdir(image_path):\n\n        image = sitk.ReadImage(os.path.join(image_path, i))\n        image_np = sitk.GetArrayFromImage(image)\n        mask = sitk.ReadImage(os.path.join(mask_path, i))\n        mask_np = sitk.GetArrayFromImage(mask)\n\n        image_copy = np.zeros(image_np.shape)\n        image_copy[1:, :, :] = image_np[0:image_np.shape[0] - 1, :, :]\n        image_res = image_np - image_copy\n        image_res[0, :, :] = 0\n        image_res = np.abs(image_res)\n        image_res = sitk.GetImageFromArray(image_res)\n        image_res.SetSpacing(image.GetSpacing())\n        image_res.SetDirection(image.GetDirection())\n        image_res.SetOrigin(image.GetOrigin())\n\n        mask_copy = np.zeros(mask_np.shape)\n        mask_copy[1:, :, :] = mask_np[0:mask_np.shape[0] - 1, :, :]\n        mask_res = mask_np - mask_copy\n        mask_res[0, :, :] = 0\n        mask_res = np.abs(mask_res)\n        mask_res = sitk.GetImageFromArray(mask_res)\n        mask_res.SetSpacing(image.GetSpacing())\n        mask_res.SetDirection(image.GetDirection())\n        mask_res.SetOrigin(image.GetOrigin())\n\n        sitk.WriteImage(image_res, os.path.join(save_res_path, i))\n        sitk.WriteImage(mask_res, os.path.join(save_res_mask_path, i))\n\n\n\n"
  },
  {
    "path": "tools/wavelet2D.py",
    "content": "import numpy as np\nfrom PIL import Image\nimport pywt\nimport argparse\nimport os\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--image_path', default='//10.0.5.233/shared_data/XNet/dataset/CREMI/train_unsup_80/image')\n    parser.add_argument('--L_path', default='//10.0.5.233/shared_data/XNet/dataset/CREMI/train_unsup_80/L')\n    parser.add_argument('--H_path', default='//10.0.5.233/shared_data/XNet/dataset/CREMI/train_unsup_80/H')\n    parser.add_argument('--wavelet_type', default='db2', help='haar, db2, bior1.5, bior2.4, coif1, dmey')\n    parser.add_argument('--if_RGB', default=False)\n    args = parser.parse_args()\n\n    if not os.path.exists(args.L_path):\n        os.mkdir(args.L_path)\n    if not os.path.exists(args.H_path):\n        os.mkdir(args.H_path)\n\n    for i in os.listdir(args.image_path):\n        image_path = os.path.join(args.image_path, i)\n        L_path = os.path.join(args.L_path, i)\n        H_path = os.path.join(args.H_path, i)\n\n        if args.if_RGB:\n            image = Image.open(image_path).convert('L')\n        else:\n            image = Image.open(image_path)\n        image = np.array(image)\n\n        LL, (LH, HL, HH) = pywt.dwt2(image, args.wavelet_type)\n\n        LL = (LL - LL.min()) / (LL.max() - LL.min()) * 255\n\n        LL = Image.fromarray(LL.astype(np.uint8))\n        LL.save(L_path)\n\n        LH = (LH - LH.min()) / (LH.max() - LH.min()) * 255\n        HL = (HL - HL.min()) / (HL.max() - HL.min()) * 255\n        HH = (HH - HH.min()) / (HH.max() - HH.min()) * 255\n\n        merge1 = HH + HL + LH\n        merge1 = (merge1-merge1.min()) / (merge1.max()-merge1.min()) * 255\n\n        merge1 = Image.fromarray(merge1.astype(np.uint8))\n        merge1.save(H_path)\n\n"
  },
  {
    "path": "tools/wavelet3D.py",
    "content": "import numpy as np\nfrom PIL import Image\nimport pywt\nimport argparse\nimport os\nimport SimpleITK as sitk\nimport torchio as tio\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--image_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/val/image')\n    parser.add_argument('--L_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/train_sup_100/L')\n    parser.add_argument('--H_path', default='//10.0.5.233/shared_data/XNet/dataset/LiTS/train_sup_100/H')\n    parser.add_argument('--wavelet_type', default='db2', help='haar, db2, bior1.5, bior2.4, coif1, dmey')\n\n    args = parser.parse_args()\n\n    if not os.path.exists(args.L_path):\n        os.mkdir(args.L_path)\n\n    if not os.path.exists(args.H_path):\n        os.mkdir(args.H_path)\n\n    for i in os.listdir(args.image_path):\n        image_path = os.path.join(args.image_path, i)\n        L_path = os.path.join(args.L_path, i)\n        H_path = os.path.join(args.H_path, i)\n\n        image = sitk.ReadImage(image_path)\n        image_np = sitk.GetArrayFromImage(image)\n\n        image_wave = pywt.dwtn(image_np, args.wavelet_type)\n        LLL = image_wave['aaa']\n        LLH = image_wave['aad']\n        LHL = image_wave['ada']\n        LHH = image_wave['add']\n        HLL = image_wave['daa']\n        HLH = image_wave['dad']\n        HHL = image_wave['dda']\n        HHH = image_wave['ddd']\n\n        LLL = (LLL - LLL.min()) / (LLL.max() - LLL.min()) * 255\n\n        resample_image = sitk.ResampleImageFilter()\n        resample_image.SetSize(image.GetSize())\n        resample_image.SetOutputSpacing([0.5, 0.5, 0.5])\n        resample_image.SetInterpolator(sitk.sitkLinear)\n        LLL = resample_image.Execute(LLL)\n\n        LLL.SetSpacing(image.GetSpacing())\n        LLL.SetDirection(image.GetDirection())\n        LLL.SetOrigin(image.GetOrigin())\n\n        sitk.WriteImage(LLL, L_path)\n\n\n        LLH = (LLH - LLH.min()) / (LLH.max() 
- LLH.min()) * 255\n        LHL = (LHL - LHL.min()) / (LHL.max() - LHL.min()) * 255\n        LHH = (LHH - LHH.min()) / (LHH.max() - LHH.min()) * 255\n        HLL = (HLL - HLL.min()) / (HLL.max() - HLL.min()) * 255\n        HLH = (HLH - HLH.min()) / (HLH.max() - HLH.min()) * 255\n        HHL = (HHL - HHL.min()) / (HHL.max() - HHL.min()) * 255\n        HHH = (HHH - HHH.min()) / (HHH.max() - HHH.min()) * 255\n\n        merge1 = LLH + LHL + LHH + HLL + HLH + HHL + HHH\n        merge1 = (merge1 - merge1.min()) / (merge1.max() - merge1.min()) * 255\n\n        merge1 = sitk.GetImageFromArray(merge1)\n\n        resample_image = sitk.ResampleImageFilter()\n        resample_image.SetSize(image.GetSize())\n        resample_image.SetOutputSpacing([0.5, 0.5, 0.5])\n        resample_image.SetInterpolator(sitk.sitkLinear)\n        merge1 = resample_image.Execute(merge1)\n\n        merge1.SetSpacing(image.GetSpacing())\n        merge1.SetDirection(image.GetDirection())\n        merge1.SetOrigin(image.GetOrigin())\n\n        sitk.WriteImage(merge1, H_path)\n\n\n"
  },
  {
    "path": "train_semi_CCT.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nfrom models.getnetwork import get_network\nimport argparse\nimport time\nimport os\nimport numpy as np\nfrom torch.backends import cudnn\nimport random\nfrom PIL import Image\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nimport sys\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.augmentation.online_aug import data_transform_2d, data_normalize_2d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_2d import imagefloder_itn\nfrom config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM, visual_image_sup\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')\n    parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI')\n    parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, 
GlaS')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--sup_mark', default='20')\n    parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=2, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.5, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-u', '--unsup_weight', default=1, type=float)\n    parser.add_argument('--loss', default='dice')\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='unet_cct', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672)\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    # trained model save\n    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    
path_trained_models = path_trained_models + '/' + 'CCT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    # seg results save\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + 'CCT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    # vis\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-CCT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))\n        visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port)\n\n    if args.input1 == 'image':\n        input1_mean = 'MEAN'\n        input1_std = 'STD'\n    else:\n        input1_mean = 'MEAN_' + args.input1\n        input1_std = 'STD_' + args.input1\n\n    data_transforms = data_transform_2d()\n    data_normalize = data_normalize_2d(cfg[input1_mean], 
cfg[input1_std])\n\n    dataset_train_unsup = imagefloder_itn(\n        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,\n        input1=args.input1,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize,\n        sup=False,\n        num_images=None,\n    )\n    num_images_unsup = len(dataset_train_unsup)\n\n    dataset_train_sup = imagefloder_itn(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=num_images_unsup,\n    )\n    dataset_val = imagefloder_itn(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        data_transform_1=data_transforms['val'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=None,\n    )\n\n    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup)\n    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n    model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n\n    model1 = model1.cuda()\n    
model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])\n    dist.barrier()\n\n    criterion = segmentation_loss(args.loss, False).cuda()\n    kl_distance = nn.KLDivLoss(reduction='none')\n\n    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)\n    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)\n\n    since = time.time()\n    count_iter = 0\n\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model1.train()\n\n        train_loss_sup_1 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n\n        val_loss_sup_1 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n\n        dataset_train_sup = iter(dataloaders['train_sup'])\n        dataset_train_unsup = iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n\n            unsup_index = next(dataset_train_unsup)\n            img_train_unsup_1 = unsup_index['image']\n            img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True))\n\n            optimizer1.zero_grad()\n\n            pred_train_unsup1, pred_train_unsup2, pred_train_unsup3, pred_train_unsup4 = model1(img_train_unsup_1)\n            pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1)\n            pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1)\n            pred_train_unsup3 = torch.softmax(pred_train_unsup3, 1)\n            pred_train_unsup4 = 
torch.softmax(pred_train_unsup4, 1)\n\n            consistency_loss_aux1 = torch.mean((pred_train_unsup1 - pred_train_unsup2) ** 2)\n            consistency_loss_aux2 = torch.mean((pred_train_unsup1 - pred_train_unsup3) ** 2)\n            consistency_loss_aux3 = torch.mean((pred_train_unsup1 - pred_train_unsup4) ** 2)\n\n            loss_train_unsup = (consistency_loss_aux1 + consistency_loss_aux2 + consistency_loss_aux3) / 3\n\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train_unsup.backward(retain_graph=True)\n            torch.cuda.empty_cache()\n\n            sup_index = next(dataset_train_sup)\n            img_train_sup = sup_index['image']\n            img_train_sup = Variable(img_train_sup.cuda(non_blocking=True))\n            mask_train_sup = sup_index['mask']\n            mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))\n\n            pred_train_sup1, pred_train_sup2, pred_train_sup3, pred_train_sup4 = model1(img_train_sup)\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup1\n                    mask_list_train = mask_train_sup\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n\n            loss_train_sup1 = (criterion(pred_train_sup1, mask_train_sup) + criterion(pred_train_sup2, mask_train_sup) + criterion(pred_train_sup3, mask_train_sup) + criterion(pred_train_sup4, mask_train_sup)) / 4\n            loss_train_sup = loss_train_sup1\n            loss_train_sup.backward()\n\n            optimizer1.step()\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss_sup_1 += 
loss_train_sup1.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup1.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus)\n                train_eval_list1, train_m_jc1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model1.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val'] / 16:\n\n                    inputs_val = Variable(data['image'].cuda(non_blocking=True))\n                    mask_val = Variable(data['mask'].cuda(non_blocking=True))\n                    name_val = data['ID']\n\n                    optimizer1.zero_grad()\n                    outputs_val1, outputs_val2, outputs_val3, outputs_val4 = model1(inputs_val)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val1 
= outputs_val1\n                        mask_list_val = mask_val\n                        name_list_val = name_val\n                    else:\n                        score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n                        name_list_val = np.append(name_list_val, name_val, axis=0)\n\n                    loss_val_sup1 = criterion(outputs_val1, mask_val)\n                    val_loss_sup_1 += loss_val_sup1.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val1, score_list_val1)\n                score_list_val1 = torch.cat(score_gather_list_val1, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n\n                name_gather_list_val = [None for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)\n                name_list_val = np.concatenate(name_gather_list_val, axis=0)\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus)\n                    val_eval_list1, val_m_jc1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val1, mask_list_val, print_num_minus)\n                    best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val1, name_list_val, val_eval_list1, path_trained_models, path_seg_results, cfg['PALETTE'], 'CCT')\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        
draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, outputs_val1, train_eval_list1, val_eval_list1)\n                        visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_cps, train_m_jc1, val_epoch_loss_sup1, val_m_jc1)\n                        visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3])\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_semi_CCT_3d.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\nimport torchio as tio\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup\nfrom config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_3d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_3d import dataset_it\nfrom warnings import simplefilter\n\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')\n    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/LiTS')\n    parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--sup_mark', default='20')\n    
parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=1, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.1, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-c', '--unsup_weight', default=1, type=float)\n    parser.add_argument('--loss', default='dice', type=str)\n    parser.add_argument('--patch_size', default=(112, 112, 32))\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n    parser.add_argument('--queue_length', default=48, type=int)\n    parser.add_argument('--samples_per_volume_train', default=8, type=int)\n    parser.add_argument('--samples_per_volume_val', default=12, type=int)\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='unet3d_cct', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    path_trained_models = args.path_trained_models + '/' + 
str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models + '/' + 'CCT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + 'CCT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_mask_results = path_seg_results + '/mask'\n    if not os.path.exists(path_mask_results) and rank == args.rank_index:\n        os.mkdir(path_mask_results)\n    path_seg_results = path_seg_results + '/pred'\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-CCT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight)+ '-w=' + str(args.warm_up_duration) + '-' + 
str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))\n        visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port)\n\n    # Dataset\n    data_transform = data_transform_3d(cfg['NORMALIZE'])\n\n    dataset_train_unsup = dataset_it(\n        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,\n        input1=args.input1,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=False,\n        num_images=None\n    )\n    num_images_unsup = len(dataset_train_unsup.dataset_1)\n\n    dataset_train_sup = dataset_it(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=True,\n        num_images=num_images_unsup\n    )\n    dataset_val = dataset_it(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        transform_1=data_transform['val'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_val,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=False,\n        shuffle_patches=False,\n        sup=True,\n        num_images=None\n    )\n\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)\n    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)\n    val_sampler = 
torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)\n    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n\n    model1 = model1.cuda()\n    model1 = DistributedDataParallel(model1, device_ids=[args.local_rank], find_unused_parameters=True)\n    dist.barrier()\n\n    # Training Strategy\n    criterion = segmentation_loss(args.loss, False).cuda()\n    kl_distance = nn.KLDivLoss(reduction='none')\n\n    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)\n    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)\n\n    # Train & Val\n    since = time.time()\n    count_iter = 0\n\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model1.train()\n\n        
train_loss_sup_1 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n\n        val_loss_sup_1 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n\n        dataset_train_sup = iter(dataloaders['train_sup'])\n        dataset_train_unsup = iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n\n            unsup_index = next(dataset_train_unsup)\n            img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())\n\n            optimizer1.zero_grad()\n\n            pred_train_unsup1, pred_train_unsup2, pred_train_unsup3, pred_train_unsup4 = model1(img_train_unsup_1)\n            pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1)\n            pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1)\n            pred_train_unsup3 = torch.softmax(pred_train_unsup3, 1)\n            pred_train_unsup4 = torch.softmax(pred_train_unsup4, 1)\n\n            consistency_loss_aux1 = torch.mean((pred_train_unsup1 - pred_train_unsup2) ** 2)\n            consistency_loss_aux2 = torch.mean((pred_train_unsup1 - pred_train_unsup3) ** 2)\n            consistency_loss_aux3 = torch.mean((pred_train_unsup1 - pred_train_unsup4) ** 2)\n\n            loss_train_unsup = (consistency_loss_aux1 + consistency_loss_aux2 + consistency_loss_aux3) / 3\n\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train_unsup.backward(retain_graph=True)\n            torch.cuda.empty_cache()\n\n            sup_index = next(dataset_train_sup)\n            img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())\n            mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())\n\n            pred_train_sup1, pred_train_sup2, pred_train_sup3, pred_train_sup4 = model1(img_train_sup_1)\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup1\n            
        mask_list_train = mask_train_sup\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n\n            loss_train_sup1 = (criterion(pred_train_sup1, mask_train_sup)+criterion(pred_train_sup2, mask_train_sup)+criterion(pred_train_sup3, mask_train_sup)+criterion(pred_train_sup4, mask_train_sup)) / 4\n            loss_train_sup = loss_train_sup1\n\n            loss_train_sup.backward()\n            optimizer1.step()\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup1.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus)\n                
train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model1.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val_1 = Variable(data['image'][tio.DATA].cuda())\n                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())\n\n                    optimizer1.zero_grad()\n                    outputs_val_1, outputs_val_2, outputs_val_3, outputs_val_4 = model1(inputs_val_1)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val_1 = outputs_val_1\n                        mask_list_val = mask_val\n                    else:\n                        score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n\n                    loss_val_sup_1 = criterion(outputs_val_1, mask_val)\n                    val_loss_sup_1 += loss_val_sup_1.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)\n                score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n                torch.cuda.empty_cache()\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup_1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, 
print_num_minus)\n                    val_eval_list_1, val_m_jc_1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val_1, mask_list_val, print_num_minus)\n                    best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val_1, mask_list_val, val_eval_list_1, path_trained_models, path_seg_results, path_mask_results, 'CCT', cfg['FORMAT'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_m_jc_1)\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_semi_CPS.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nfrom models.getnetwork import get_network\nimport argparse\nimport time\nimport os\nimport numpy as np\nfrom torch.backends import cudnn\nimport random\nfrom PIL import Image\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nimport sys\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.augmentation.online_aug import data_transform_2d, data_normalize_2d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_2d import imagefloder_itn\nfrom config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')\n    parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/GlaS')\n    parser.add_argument('--dataset_name', default='GlaS', help='CREMI, ISIC-2017, GlaS')\n   
 parser.add_argument('--input1', default='image')\n    parser.add_argument('--sup_mark', default='20')\n    parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=2, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.5, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-u', '--unsup_weight', default=5, type=float)\n    parser.add_argument('--loss', default='dice')\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='xnet_sb', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672)\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    # trained model save\n    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = 
path_trained_models+'/'+'CPS'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    # seg results save\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results+'/'+'CPS'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    # vis\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-CPS-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1))\n        visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)\n\n    if args.input1 == 'image':\n        input1_mean = 'MEAN'\n        input1_std = 'STD'\n    else:\n        input1_mean = 'MEAN_' + args.input1\n        input1_std = 'STD_' + args.input1\n\n\n    data_transforms = data_transform_2d()\n    data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std])\n\n    dataset_train_unsup = imagefloder_itn(\n        data_dir=args.path_dataset + '/train_unsup_'+args.unsup_mark,\n        input1=args.input1,\n   
     data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize,\n        sup=False,\n        num_images=None,\n    )\n    num_images_unsup = len(dataset_train_unsup)\n\n    dataset_train_sup = imagefloder_itn(\n        data_dir=args.path_dataset + '/train_sup_'+args.sup_mark,\n        input1=args.input1,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=num_images_unsup,\n    )\n    dataset_val = imagefloder_itn(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        data_transform_1=data_transforms['val'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=None,\n    )\n\n    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup)\n    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n\n    model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n    model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n\n    model1 = model1.cuda()\n    model2 = model2.cuda()\n    model1 = DistributedDataParallel(model1, 
device_ids=[args.local_rank])\n    model2 = DistributedDataParallel(model2, device_ids=[args.local_rank])\n    dist.barrier()\n\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)\n    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)\n\n    optimizer2 = optim.SGD(model2.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)\n    exp_lr_scheduler2 = lr_scheduler.StepLR(optimizer2, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup2 = GradualWarmupScheduler(optimizer2, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler2)\n\n\n    since = time.time()\n    count_iter = 0\n\n    best_model = model1\n    best_result = 'Result1'\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter-1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model1.train()\n        model2.train()\n\n        train_loss_sup_1 = 0.0\n        train_loss_sup_2 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n\n        val_loss_sup_1 = 0.0\n        val_loss_sup_2 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch+1) / args.num_epochs\n\n        dist.barrier()\n\n        dataset_train_sup = iter(dataloaders['train_sup'])\n        dataset_train_unsup = iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n            \n            unsup_index = next(dataset_train_unsup)\n            img_train_unsup = unsup_index['image']\n            
img_train_unsup = Variable(img_train_unsup.cuda(non_blocking=True))\n            \n            optimizer1.zero_grad()\n            optimizer2.zero_grad()\n            \n            pred_train_unsup1 = model1(img_train_unsup)\n            pred_train_unsup2 = model2(img_train_unsup)\n\n            max_train1 = torch.max(pred_train_unsup1, dim=1)[1]\n            max_train2 = torch.max(pred_train_unsup2, dim=1)[1]\n            max_train1 = max_train1.long()\n            max_train2 = max_train2.long()\n\n            loss_train_unsup = criterion(pred_train_unsup1, max_train2) + criterion(pred_train_unsup2, max_train1)\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train_unsup.backward(retain_graph=True)\n            torch.cuda.empty_cache()\n            \n            sup_index = next(dataset_train_sup)\n            img_train_sup = sup_index['image']\n            img_train_sup = Variable(img_train_sup.cuda(non_blocking=True))\n            mask_train_sup = sup_index['mask']\n            mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))\n\n            pred_train_sup1 = model1(img_train_sup)\n            pred_train_sup2 = model2(img_train_sup)\n            \n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup1\n                    score_list_train2 = pred_train_sup2\n                    mask_list_train = mask_train_sup\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)\n                    score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n            \n            loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)\n            loss_train_sup2 = criterion(pred_train_sup2, 
mask_train_sup)\n\n            loss_train_sup = loss_train_sup1 + loss_train_sup2\n            loss_train_sup.backward()\n            \n            optimizer1.step()\n            optimizer2.step()\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss_sup_2 += loss_train_sup2.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup1.step()\n        scheduler_warmup2.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train2, score_list_train2)\n            score_list_train2 = torch.cat(score_gather_list_train2, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch+1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)\n                train_eval_list1, train_eval_list2, 
train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model1.eval()\n                model2.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val'] / 16:\n\n                    inputs_val = Variable(data['image'].cuda(non_blocking=True))\n                    mask_val = Variable(data['mask'].cuda(non_blocking=True))\n                    name_val = data['ID']\n\n                    optimizer1.zero_grad()\n                    optimizer2.zero_grad()\n\n                    outputs_val1 = model1(inputs_val)\n                    outputs_val2 = model2(inputs_val)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val1 = outputs_val1\n                        score_list_val2 = outputs_val2\n                        mask_list_val = mask_val\n                        name_list_val = name_val\n                    else:\n                        score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)\n                        score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n                        name_list_val = np.append(name_list_val, name_val, axis=0)\n\n                    loss_val_sup1 = criterion(outputs_val1, mask_val)\n                    loss_val_sup2 = criterion(outputs_val2, mask_val)\n\n                    val_loss_sup_1 += loss_val_sup1.item()\n                    val_loss_sup_2 += loss_val_sup2.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val1, score_list_val1)\n  
              score_list_val1 = torch.cat(score_gather_list_val1, dim=0)\n\n                score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val2, score_list_val2)\n                score_list_val2 = torch.cat(score_gather_list_val2, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n\n                name_gather_list_val = [None for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)\n                name_list_val = np.concatenate(name_gather_list_val, axis=0)\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)\n                    val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half)\n                    best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        draw_img = draw_pred_XNet(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, pred_train_sup2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2)\n                        visualization_XNet(visdom, epoch+1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, 
train_m_jc1, train_m_jc2, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2)\n                        visual_image_XNet(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4], draw_img[5])\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time()-begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)\n        print('=' * print_num)\n\n"
  },
  {
    "path": "train_semi_CPS_3d.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\nimport torchio as tio\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_3d, print_best\nfrom config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_3d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_3d import dataset_it\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=FutureWarning)\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')\n    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/LiTS')\n    parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--sup_mark', default='20')\n    
parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=1, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.1, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-c', '--unsup_weight', default=1, type=float)\n    parser.add_argument('--loss', default='dice', type=str)\n    parser.add_argument('--patch_size', default=(112, 112, 32))\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n    parser.add_argument('--queue_length', default=48, type=int)\n    parser.add_argument('--samples_per_volume_train', default=8, type=int)\n    parser.add_argument('--samples_per_volume_val', default=12, type=int)\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='unet3d_min', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    path_trained_models = 
args.path_trained_models+'/'+str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models+'/'+'CPS'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s=' + str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w=' + str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results+'/' +str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results+'/'+'CPS'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_mask_results = path_seg_results + '/mask'\n    if not os.path.exists(path_mask_results) and rank == args.rank_index:\n        os.mkdir(path_mask_results)\n    path_seg_results = path_seg_results + '/pred'\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-CPS-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-w=' + str(args.warm_up_duration)+'-'+str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1))\n        visdom = 
visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)\n\n\n    # Dataset\n    data_transform = data_transform_3d(cfg['NORMALIZE'])\n\n    dataset_train_unsup = dataset_it(\n        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,\n        input1=args.input1,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=False,\n        num_images=None\n    )\n    num_images_unsup = len(dataset_train_unsup.dataset_1)\n\n    dataset_train_sup = dataset_it(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=True,\n        num_images=num_images_unsup\n    )\n    dataset_val = dataset_it(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        transform_1=data_transform['val'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_val,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=False,\n        shuffle_patches=False,\n        sup=True,\n        num_images=None\n    )\n\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)\n    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train_sup'] = 
DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)\n    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n    model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n\n    model1 = model1.cuda()\n    model2 = model2.cuda()\n    model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])\n    model2 = DistributedDataParallel(model2, device_ids=[args.local_rank])\n    dist.barrier()\n\n    # Training Strategy\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)\n    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)\n\n    optimizer2 = optim.SGD(model2.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)\n    exp_lr_scheduler2 = lr_scheduler.StepLR(optimizer2, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup2 = GradualWarmupScheduler(optimizer2, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler2)\n\n\n    # Train & Val\n    since = time.time()\n    count_iter = 0\n\n    best_model = model1\n    best_result = 'Result1'\n    
best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model1.train()\n        model2.train()\n\n\n        train_loss_sup_1 = 0.0\n        train_loss_sup_2 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n\n        val_loss_sup_1 = 0.0\n        val_loss_sup_2 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n\n        dataset_train_sup = iter(dataloaders['train_sup'])\n        dataset_train_unsup = iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n\n            unsup_index = next(dataset_train_unsup)\n            img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())\n\n            optimizer1.zero_grad()\n            optimizer2.zero_grad()\n\n            pred_train_unsup1 = model1(img_train_unsup_1)\n            pred_train_unsup2 = model2(img_train_unsup_1)\n\n            max_train_unsup1 = torch.max(pred_train_unsup1, dim=1)[1]\n            max_train_unsup2 = torch.max(pred_train_unsup2, dim=1)[1]\n            max_train_unsup1 = max_train_unsup1.long()\n            max_train_unsup2 = max_train_unsup2.long()\n\n            loss_train_unsup = criterion(pred_train_unsup1, max_train_unsup2) + criterion(pred_train_unsup2, max_train_unsup1)\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train_unsup.backward(retain_graph=True)\n            torch.cuda.empty_cache()\n\n            sup_index = next(dataset_train_sup)\n            img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())\n            mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())\n\n            pred_train_sup1 = 
model1(img_train_sup_1)\n            pred_train_sup2 = model2(img_train_sup_1)\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup1\n                    score_list_train2 = pred_train_sup2\n                    mask_list_train = mask_train_sup\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)\n                    score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n\n            loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)\n            loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup)\n\n            loss_train_sup = loss_train_sup1 + loss_train_sup2\n            loss_train_sup.backward()\n\n            optimizer1.step()\n            optimizer2.step()\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss_sup_2 += loss_train_sup2.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup1.step()\n        scheduler_warmup2.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train2, score_list_train2)\n            score_list_train2 = 
torch.cat(score_gather_list_train2, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)\n                train_eval_list_1, train_eval_list_2, train_m_jc_1, train_m_jc_2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model1.eval()\n                model2.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val_1 = Variable(data['image'][tio.DATA].cuda())\n                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())\n\n                    optimizer1.zero_grad()\n                    optimizer2.zero_grad()\n\n                    outputs_val_1 = model1(inputs_val_1)\n                    outputs_val_2 = model2(inputs_val_1)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val_1 = outputs_val_1\n                        score_list_val_2 = outputs_val_2\n                        mask_list_val = mask_val\n                    else:\n                        score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)\n             
           score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n\n                    loss_val_sup_1 = criterion(outputs_val_1, mask_val)\n                    loss_val_sup_2 = criterion(outputs_val_2, mask_val)\n\n                    val_loss_sup_1 += loss_val_sup_1.item()\n                    val_loss_sup_2 += loss_val_sup_2.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)\n                score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)\n\n                score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2)\n                score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n                torch.cuda.empty_cache()\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2,num_batches, print_num, print_num_half)\n                    val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half)\n                    best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val_1, score_list_val_2, mask_list_val, val_eval_list_1, val_eval_list_2, path_trained_models, 
path_seg_results, path_mask_results, cfg['FORMAT'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        visualization_XNet(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_m_jc_1, train_m_jc_2, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2)\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_semi_CT.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nfrom models.getnetwork import get_network\nimport argparse\nimport time\nimport os\nimport numpy as np\nfrom torch.backends import cudnn\nimport random\nfrom PIL import Image\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nimport sys\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.augmentation.online_aug import data_transform_2d, data_normalize_2d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_2d import imagefloder_itn\nfrom config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best\nfrom warnings import simplefilter\n\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')\n    parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI')\n    parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, 
GlaS')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--sup_mark', default='20')\n    parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=2, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.5, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-u', '--unsup_weight', default=1, type=float)\n    parser.add_argument('--loss', default='dice')\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n1', '--network1', default='unet', type=str)\n    parser.add_argument('-n2', '--network2', default='swinunet', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672)\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    # trained model save\n    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank 
== args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models + '/' + 'CT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    # seg results save\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + 'CT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    # vis\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-CT-' + str(os.path.split(args.path_dataset)[1]) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))\n        visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)\n\n    if args.input1 == 'image':\n        input1_mean = 'MEAN'\n        input1_std = 'STD'\n    else:\n        input1_mean = 'MEAN_' + args.input1\n        input1_std = 'STD_' + args.input1\n\n    data_transforms = data_transform_2d()\n    data_normalize = 
data_normalize_2d(cfg[input1_mean], cfg[input1_std])\n\n    dataset_train_unsup = imagefloder_itn(\n        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,\n        input1=args.input1,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize,\n        sup=False,\n        num_images=None,\n    )\n    num_images_unsup = len(dataset_train_unsup)\n\n    dataset_train_sup = imagefloder_itn(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=num_images_unsup,\n    )\n    dataset_val = imagefloder_itn(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        data_transform_1=data_transforms['val'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=None,\n    )\n\n    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup)\n    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n    model1 = get_network(args.network1, cfg['IN_CHANNELS'], 
cfg['NUM_CLASSES'])\n    model2 = get_network(args.network2, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n\n    model1 = model1.cuda()\n    model2 = model2.cuda()\n    model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])\n    model2 = DistributedDataParallel(model2, device_ids=[args.local_rank])\n    dist.barrier()\n\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)\n    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)\n\n    optimizer2 = optim.SGD(model2.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)\n    exp_lr_scheduler2 = lr_scheduler.StepLR(optimizer2, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup2 = GradualWarmupScheduler(optimizer2, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler2)\n\n    since = time.time()\n    count_iter = 0\n\n    best_model = model1\n    best_result = 'Result1'\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model1.train()\n        model2.train()\n\n        train_loss_sup_1 = 0.0\n        train_loss_sup_2 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n\n        val_loss_sup_1 = 0.0\n        val_loss_sup_2 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n\n        dataset_train_sup = iter(dataloaders['train_sup'])\n        dataset_train_unsup = 
iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n\n            unsup_index = next(dataset_train_unsup)\n            img_train_unsup = unsup_index['image']\n            img_train_unsup = Variable(img_train_unsup.cuda(non_blocking=True))\n\n            optimizer1.zero_grad()\n            optimizer2.zero_grad()\n\n            pred_train_unsup1 = model1(img_train_unsup)\n            pred_train_unsup2 = model2(img_train_unsup)\n\n            max_train1 = torch.max(pred_train_unsup1, dim=1)[1]\n            max_train2 = torch.max(pred_train_unsup2, dim=1)[1]\n            max_train1 = max_train1.long()\n            max_train2 = max_train2.long()\n\n            loss_train_unsup = criterion(pred_train_unsup1, max_train2) + criterion(pred_train_unsup2, max_train1)\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train_unsup.backward(retain_graph=True)\n            torch.cuda.empty_cache()\n\n            sup_index = next(dataset_train_sup)\n            img_train_sup = sup_index['image']\n            img_train_sup = Variable(img_train_sup.cuda(non_blocking=True))\n            mask_train_sup = sup_index['mask']\n            mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))\n\n            pred_train_sup1 = model1(img_train_sup)\n            pred_train_sup2 = model2(img_train_sup)\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup1\n                    score_list_train2 = pred_train_sup2\n                    mask_list_train = mask_train_sup\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)\n                    score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n\n   
         loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)\n            loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup)\n\n            loss_train_sup = loss_train_sup1 + loss_train_sup2\n            loss_train_sup.backward()\n\n            optimizer1.step()\n            optimizer2.step()\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss_sup_2 += loss_train_sup2.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup1.step()\n        scheduler_warmup2.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train2, score_list_train2)\n            score_list_train2 = torch.cat(score_gather_list_train2, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, 
train_loss_unsup, train_loss, num_batches, print_num, print_num_half)\n                train_eval_list1, train_eval_list2, train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model1.eval()\n                model2.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val'] / 16:\n\n                    inputs_val = Variable(data['image'].cuda(non_blocking=True))\n                    mask_val = Variable(data['mask'].cuda(non_blocking=True))\n                    name_val = data['ID']\n\n                    optimizer1.zero_grad()\n                    optimizer2.zero_grad()\n\n                    outputs_val1 = model1(inputs_val)\n                    outputs_val2 = model2(inputs_val)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val1 = outputs_val1\n                        score_list_val2 = outputs_val2\n                        mask_list_val = mask_val\n                        name_list_val = name_val\n                    else:\n                        score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)\n                        score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n                        name_list_val = np.append(name_list_val, name_val, axis=0)\n\n                    loss_val_sup1 = criterion(outputs_val1, mask_val)\n                    loss_val_sup2 = criterion(outputs_val2, mask_val)\n\n                    val_loss_sup_1 += loss_val_sup1.item()\n                    val_loss_sup_2 += loss_val_sup2.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val1 = 
[torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val1, score_list_val1)\n                score_list_val1 = torch.cat(score_gather_list_val1, dim=0)\n\n                score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val2, score_list_val2)\n                score_list_val2 = torch.cat(score_gather_list_val2, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n\n                name_gather_list_val = [None for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)\n                name_list_val = np.concatenate(name_gather_list_val, axis=0)\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)\n                    val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half)\n                    best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        draw_img = draw_pred_XNet(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, pred_train_sup2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2)\n 
                       visualization_XNet(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_m_jc1, train_m_jc2, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2)\n                        visual_image_XNet(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4], draw_img[5])\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)\n        print('=' * print_num)\n\n"
  },
  {
    "path": "train_semi_CT_3d.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\nimport torchio as tio\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_3d, print_best\nfrom config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_3d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_3d import dataset_it\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=FutureWarning)\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')\n    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/LiTS')\n    parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--sup_mark', default='20')\n    
parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=1, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.1, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-c', '--unsup_weight', default=1, type=float)\n    parser.add_argument('--loss', default='dice', type=str)\n    parser.add_argument('--patch_size', default=(112, 112, 32))\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n    parser.add_argument('--queue_length', default=48, type=int)\n    parser.add_argument('--samples_per_volume_train', default=8, type=int)\n    parser.add_argument('--samples_per_volume_val', default=12, type=int)\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n1', '--network1', default='unet3d_min', type=str)\n    parser.add_argument('-n2', '--network2', default='unertr', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 
- 1)\n\n    path_trained_models = args.path_trained_models+'/'+str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models+'/'+'CT'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s=' + str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+ '-cw' + str(args.unsup_weight)+'-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results+'/' +str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results+'/'+'CT'+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+ '-cw' + str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_mask_results = path_seg_results + '/mask'\n    if not os.path.exists(path_mask_results) and rank == args.rank_index:\n        os.mkdir(path_mask_results)\n    path_seg_results = path_seg_results + '/pred'\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-CT-' + str(os.path.split(args.path_dataset)[1]) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + 
str(args.unsup_mark) + '-' + str(args.input1))\n        visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)\n\n\n    # Dataset\n    data_transform = data_transform_3d(cfg['NORMALIZE'])\n\n    dataset_train_unsup = dataset_it(\n        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,\n        input1=args.input1,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=False,\n        num_images=None\n    )\n    num_images_unsup = len(dataset_train_unsup.dataset_1)\n\n    dataset_train_sup = dataset_it(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=True,\n        num_images=num_images_unsup\n    )\n    dataset_val = dataset_it(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        transform_1=data_transform['val'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_val,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=False,\n        shuffle_patches=False,\n        sup=True,\n        num_images=None\n    )\n\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)\n    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)\n\n   
 dataloaders = dict()\n    dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)\n    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model1 = get_network(args.network1, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n    model2 = get_network(args.network2, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size)\n\n    model1 = model1.cuda()\n    model2 = model2.cuda()\n    model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])\n    model2 = DistributedDataParallel(model2, device_ids=[args.local_rank], find_unused_parameters=True)\n    dist.barrier()\n\n    # Training Strategy\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)\n    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)\n\n    optimizer2 = optim.SGD(model2.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)\n    exp_lr_scheduler2 = lr_scheduler.StepLR(optimizer2, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup2 = GradualWarmupScheduler(optimizer2, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler2)\n\n\n    # Train & 
Val\n    since = time.time()\n    count_iter = 0\n\n    best_model = model1\n    best_result = 'Result1'\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model1.train()\n        model2.train()\n\n\n        train_loss_sup_1 = 0.0\n        train_loss_sup_2 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n\n        val_loss_sup_1 = 0.0\n        val_loss_sup_2 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n\n        dataset_train_sup = iter(dataloaders['train_sup'])\n        dataset_train_unsup = iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n\n            unsup_index = next(dataset_train_unsup)\n            img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())\n\n            optimizer1.zero_grad()\n            optimizer2.zero_grad()\n\n            pred_train_unsup1 = model1(img_train_unsup_1)\n            pred_train_unsup2 = model2(img_train_unsup_1)\n\n            max_train_unsup1 = torch.max(pred_train_unsup1, dim=1)[1]\n            max_train_unsup2 = torch.max(pred_train_unsup2, dim=1)[1]\n            max_train_unsup1 = max_train_unsup1.long()\n            max_train_unsup2 = max_train_unsup2.long()\n\n            loss_train_unsup = criterion(pred_train_unsup1, max_train_unsup2) + criterion(pred_train_unsup2, max_train_unsup1)\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train_unsup.backward(retain_graph=True)\n            torch.cuda.empty_cache()\n\n            sup_index = next(dataset_train_sup)\n            img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())\n            mask_train_sup = 
Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())\n\n            pred_train_sup1 = model1(img_train_sup_1)\n            pred_train_sup2 = model2(img_train_sup_1)\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup1\n                    score_list_train2 = pred_train_sup2\n                    mask_list_train = mask_train_sup\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)\n                    score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n\n            loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)\n            loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup)\n\n            loss_train_sup = loss_train_sup1 + loss_train_sup2\n            loss_train_sup.backward()\n\n            optimizer1.step()\n            optimizer2.step()\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss_sup_2 += loss_train_sup2.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup1.step()\n        scheduler_warmup2.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]\n            
torch.distributed.all_gather(score_gather_list_train2, score_list_train2)\n            score_list_train2 = torch.cat(score_gather_list_train2, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)\n                train_eval_list_1, train_eval_list_2, train_m_jc_1, train_m_jc_2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model1.eval()\n                model2.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val_1 = Variable(data['image'][tio.DATA].cuda())\n                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())\n\n                    optimizer1.zero_grad()\n                    optimizer2.zero_grad()\n\n                    outputs_val_1 = model1(inputs_val_1)\n                    outputs_val_2 = model2(inputs_val_1)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val_1 = outputs_val_1\n                        score_list_val_2 = outputs_val_2\n                        mask_list_val = mask_val\n                    else:\n         
               score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)\n                        score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n\n                    loss_val_sup_1 = criterion(outputs_val_1, mask_val)\n                    loss_val_sup_2 = criterion(outputs_val_2, mask_val)\n\n                    val_loss_sup_1 += loss_val_sup_1.item()\n                    val_loss_sup_2 += loss_val_sup_2.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)\n                score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)\n\n                score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2)\n                score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n                torch.cuda.empty_cache()\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2,num_batches, print_num, print_num_half)\n                    val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half)\n                    best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, 
score_list_val_1, score_list_val_2, mask_list_val, val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        visualization_XNet(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_m_jc_1, train_m_jc_2, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2)\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_semi_DTC.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\nimport torchio as tio\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup\nfrom config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_3d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_3d import dataset_it_dtc\nfrom warnings import simplefilter\n\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')\n    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/LiTS')\n    parser.add_argument('--dataset_name', default='LiTS', help='LiTS, Atrial')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--sup_mark', default='20')\n    
parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=1, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.1, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-c', '--unsup_weight', default=1, type=float)\n    parser.add_argument('--beta', default=0.3, type=float)\n    parser.add_argument('--loss', default='dice', type=str)\n    parser.add_argument('--patch_size', default=(112, 112, 32))\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n    parser.add_argument('--queue_length', default=48, type=int)\n    parser.add_argument('--samples_per_volume_train', default=8, type=int)\n    parser.add_argument('--samples_per_volume_val', default=12, type=int)\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='unet3d_dtc', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models + '/' + 'DTC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + 'DTC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_mask_results = path_seg_results + '/mask'\n    if not os.path.exists(path_mask_results) and rank == args.rank_index:\n        os.mkdir(path_mask_results)\n    path_seg_results = path_seg_results + '/pred'\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-DTC-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + 
'-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))\n        visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port)\n\n    # Dataset\n    data_transform = data_transform_3d(cfg['NORMALIZE'])\n\n    dataset_train_unsup = dataset_it_dtc(\n        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,\n        input1=args.input1,\n        num_classes=cfg['NUM_CLASSES'],\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=False,\n        num_images=None\n    )\n    num_images_unsup = len(dataset_train_unsup.dataset_1)\n\n    dataset_train_sup = dataset_it_dtc(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        num_classes=cfg['NUM_CLASSES'],\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=True,\n        num_images=num_images_unsup\n    )\n    dataset_val = dataset_it_dtc(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        num_classes=cfg['NUM_CLASSES'],\n        transform_1=data_transform['val'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_val,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=False,\n        shuffle_patches=False,\n        sup=True,\n        num_images=None\n    )\n\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)\n    train_sampler_sup = 
torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)\n    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n\n    model1 = model1.cuda()\n    model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])\n    dist.barrier()\n\n    # Training Strategy\n    criterion = segmentation_loss(args.loss, False).cuda()\n    mseloss = torch.nn.MSELoss().cuda()\n\n    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)\n    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)\n\n    # Train & Val\n    since = time.time()\n    count_iter = 0\n\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        
dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model1.train()\n\n        train_loss_sup = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n\n        val_loss_sup_1 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n\n        dataset_train_sup = iter(dataloaders['train_sup'])\n        dataset_train_unsup = iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n\n            unsup_index = next(dataset_train_unsup)\n            img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())\n\n            optimizer1.zero_grad()\n            pred_train_unsup_sdf, pred_train_unsup_seg = model1(img_train_unsup_1)\n\n            pred_train_unsup_seg_soft = torch.sigmoid(pred_train_unsup_seg)\n            dis_to_mask = torch.sigmoid(-1500 * pred_train_unsup_sdf)\n\n            loss_train_unsup = torch.mean((dis_to_mask - pred_train_unsup_seg_soft) ** 2)\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train_unsup.backward(retain_graph=True)\n            torch.cuda.empty_cache()\n\n            sup_index = next(dataset_train_sup)\n            img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())\n            mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())\n            mask_train_sup_sdf1 = Variable(sup_index['mask2'][tio.DATA].squeeze(1).float().cuda())\n            if cfg['NUM_CLASSES'] == 3:\n                mask_train_sup_sdf2 = Variable(sup_index['mask3'][tio.DATA].squeeze(1).float().cuda())\n\n            pred_train_sup_sdf, pred_train_sup_seg = model1(img_train_sup_1)\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup_seg\n                    mask_list_train = mask_train_sup\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    
score_list_train1 = torch.cat((score_list_train1, pred_train_sup_seg), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n\n            if cfg['NUM_CLASSES'] == 3:\n                loss_train_sdf = mseloss(pred_train_sup_sdf[:, 1, ...], mask_train_sup_sdf1) + mseloss(pred_train_sup_sdf[:, 2, ...], mask_train_sup_sdf2)\n            else:\n                loss_train_sdf = mseloss(pred_train_sup_sdf[:, 1, ...], mask_train_sup_sdf1)\n            loss_train_seg = criterion(pred_train_sup_seg, mask_train_sup)\n            loss_train_sup = loss_train_seg + args.beta * loss_train_sdf\n\n            loss_train_sup.backward()\n            optimizer1.step()\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss_sup += loss_train_sup.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup1.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup, train_loss_unsup, train_loss, 
num_batches, print_num, print_num_minus)\n                train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model1.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val_1 = Variable(data['image'][tio.DATA].cuda())\n                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())\n\n                    optimizer1.zero_grad()\n                    outputs_val_sdf, outputs_val_seg = model1(inputs_val_1)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val_1 = outputs_val_seg\n                        mask_list_val = mask_val\n                    else:\n                        score_list_val_1 = torch.cat((score_list_val_1, outputs_val_seg), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n\n                    loss_val_sup_1 = criterion(outputs_val_seg, mask_val)\n                    val_loss_sup_1 += loss_val_sup_1.item()\n\n                torch.cuda.empty_cache()\n\n                score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)\n                score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n                torch.cuda.empty_cache()\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup_1 = 
print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus)\n                    val_eval_list_1, val_m_jc_1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val_1, mask_list_val, print_num_minus)\n                    best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val_1, mask_list_val, val_eval_list_1, path_trained_models, path_seg_results, path_mask_results, 'DTC', cfg['FORMAT'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_m_jc_1)\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_semi_EM.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nfrom models.getnetwork import get_network\nimport argparse\nimport time\nimport os\nimport numpy as np\nfrom torch.backends import cudnn\nimport random\nfrom PIL import Image\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nimport sys\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.augmentation.online_aug import data_transform_2d, data_normalize_2d\nfrom loss.loss_function import segmentation_loss, entropy_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_2d import imagefloder_itn\nfrom config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM, visual_image_sup\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup\nfrom warnings import simplefilter\n\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')\n    parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI')\n    parser.add_argument('--dataset_name', default='CREMI', 
help='CREMI, ISIC-2017, GlaS')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--sup_mark', default='20')\n    parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=2, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.5, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-u', '--unsup_weight', default=1, type=float)\n    parser.add_argument('--loss', default='dice')\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='unet', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672)\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    # trained model save\n    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        
os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models + '/' + 'EM' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    # seg results save\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + 'EM' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    # vis\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-EM-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))\n        visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port)\n\n    if args.input1 == 'image':\n        input1_mean = 'MEAN'\n        input1_std = 'STD'\n    else:\n        input1_mean = 'MEAN_' + args.input1\n        input1_std = 'STD_' + args.input1\n\n    data_transforms = data_transform_2d()\n    data_normalize = 
data_normalize_2d(cfg[input1_mean], cfg[input1_std])\n\n    dataset_train_unsup = imagefloder_itn(\n        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,\n        input1=args.input1,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize,\n        sup=False,\n        num_images=None,\n    )\n    num_images_unsup = len(dataset_train_unsup)\n\n    dataset_train_sup = imagefloder_itn(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=num_images_unsup,\n    )\n    dataset_val = imagefloder_itn(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        data_transform_1=data_transforms['val'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=None,\n    )\n\n    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup)\n    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n    model1 = get_network(args.network, cfg['IN_CHANNELS'], 
cfg['NUM_CLASSES'])\n\n    model1 = model1.cuda()\n    model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])\n    dist.barrier()\n\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)\n    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)\n\n    since = time.time()\n    count_iter = 0\n\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model1.train()\n\n        train_loss_sup_1 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n\n        val_loss_sup_1 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n\n        dataset_train_sup = iter(dataloaders['train_sup'])\n        dataset_train_unsup = iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n\n            unsup_index = next(dataset_train_unsup)\n            img_train_unsup_1 = unsup_index['image']\n            img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True))\n\n            optimizer1.zero_grad()\n\n            pred_train_unsup1 = model1(img_train_unsup_1)\n            pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1)\n\n            loss_train_unsup = entropy_loss(pred_train_unsup1, C=2)\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train_unsup.backward(retain_graph=True)\n            torch.cuda.empty_cache()\n\n            
sup_index = next(dataset_train_sup)\n            img_train_sup = sup_index['image']\n            img_train_sup = Variable(img_train_sup.cuda(non_blocking=True))\n            mask_train_sup = sup_index['mask']\n            mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))\n\n            pred_train_sup1 = model1(img_train_sup)\n            pred_train_sup1_soft = torch.softmax(pred_train_sup1, 1)\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup1\n                    mask_list_train = mask_train_sup\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n\n            loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup) + entropy_loss(pred_train_sup1_soft, C=2)\n\n            loss_train_sup = loss_train_sup1\n            loss_train_sup.backward()\n\n            optimizer1.step()\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup1.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = 
torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus)\n                train_eval_list1, train_m_jc1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model1.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val'] / 16:\n\n                    inputs_val = Variable(data['image'].cuda(non_blocking=True))\n                    mask_val = Variable(data['mask'].cuda(non_blocking=True))\n                    name_val = data['ID']\n\n                    optimizer1.zero_grad()\n                    outputs_val1 = model1(inputs_val)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val1 = outputs_val1\n                        mask_list_val = mask_val\n                        name_list_val = name_val\n                    else:\n                        score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n                        name_list_val = np.append(name_list_val, name_val, axis=0)\n\n                    loss_val_sup1 = criterion(outputs_val1, mask_val)\n                    val_loss_sup_1 += loss_val_sup1.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]\n                
torch.distributed.all_gather(score_gather_list_val1, score_list_val1)\n                score_list_val1 = torch.cat(score_gather_list_val1, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n\n                name_gather_list_val = [None for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)\n                name_list_val = np.concatenate(name_gather_list_val, axis=0)\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus)\n                    val_eval_list1, val_m_jc1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val1, mask_list_val, print_num_minus)\n                    best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val1, name_list_val, val_eval_list1, path_trained_models, path_seg_results, cfg['PALETTE'], 'EM')\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, outputs_val1, train_eval_list1, val_eval_list1)\n                        visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_cps, train_m_jc1, val_epoch_loss_sup1, val_m_jc1)\n                        visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3])\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        
time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_semi_EM_3d.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\nimport torchio as tio\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup\nfrom config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_3d\nfrom loss.loss_function import segmentation_loss, entropy_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_3d import dataset_it\nfrom warnings import simplefilter\n\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')\n    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')\n    parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--sup_mark', 
default='20')\n    parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=1, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.1, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-c', '--unsup_weight', default=50, type=float)\n    parser.add_argument('--loss', default='dice', type=str)\n    parser.add_argument('--patch_size', default=(96, 96, 80))\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n    parser.add_argument('--queue_length', default=48, type=int)\n    parser.add_argument('--samples_per_volume_train', default=4, type=int)\n    parser.add_argument('--samples_per_volume_val', default=8, type=int)\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='unet3d', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    path_trained_models = args.path_trained_models + 
'/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models + '/' + 'EM' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + 'EM' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_mask_results = path_seg_results + '/mask'\n    if not os.path.exists(path_mask_results) and rank == args.rank_index:\n        os.mkdir(path_mask_results)\n    path_seg_results = path_seg_results + '/pred'\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-EM-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + 
str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))\n        visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port)\n\n    # Dataset\n    data_transform = data_transform_3d(cfg['NORMALIZE'])\n\n    dataset_train_unsup = dataset_it(\n        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,\n        input1=args.input1,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=False,\n        num_images=None\n    )\n    num_images_unsup = len(dataset_train_unsup.dataset_1)\n\n    dataset_train_sup = dataset_it(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=True,\n        num_images=num_images_unsup\n    )\n    dataset_val = dataset_it(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        transform_1=data_transform['val'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_val,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=False,\n        shuffle_patches=False,\n        sup=True,\n        num_images=None\n    )\n\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)\n    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)\n    val_sampler = 
torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)\n    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n\n    model1 = model1.cuda()\n    model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])\n    dist.barrier()\n\n    # Training Strategy\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)\n    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)\n\n    # Train & Val\n    since = time.time()\n    count_iter = 0\n\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model1.train()\n\n        train_loss_sup_1 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n\n   
     val_loss_sup_1 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n\n        dataset_train_sup = iter(dataloaders['train_sup'])\n        dataset_train_unsup = iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n\n            unsup_index = next(dataset_train_unsup)\n            img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())\n\n            optimizer1.zero_grad()\n\n            pred_train_unsup1 = model1(img_train_unsup_1)\n            pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1)\n\n            loss_train_unsup = entropy_loss(pred_train_unsup1, C=2)\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train_unsup.backward(retain_graph=True)\n            torch.cuda.empty_cache()\n\n            sup_index = next(dataset_train_sup)\n            img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())\n            mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())\n\n            pred_train_sup1 = model1(img_train_sup_1)\n            pred_train_sup1_soft = torch.softmax(pred_train_sup1, 1)\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup1\n                    mask_list_train = mask_train_sup\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n\n            loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup) + entropy_loss(pred_train_sup1_soft, C=2)\n            loss_train_sup = loss_train_sup1\n\n            loss_train_sup.backward()\n            optimizer1.step()\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n     
       train_loss_unsup += loss_train_unsup.item()\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup1.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus)\n                train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model1.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val_1 = Variable(data['image'][tio.DATA].cuda())\n                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())\n\n                    optimizer1.zero_grad()\n                    outputs_val_1 = model1(inputs_val_1)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        
score_list_val_1 = outputs_val_1\n                        mask_list_val = mask_val\n                    else:\n                        score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n\n                    loss_val_sup_1 = criterion(outputs_val_1, mask_val)\n                    val_loss_sup_1 += loss_val_sup_1.item()\n\n                torch.cuda.empty_cache()\n\n                score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)\n                score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n                torch.cuda.empty_cache()\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup_1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus)\n                    val_eval_list_1, val_m_jc_1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val_1, mask_list_val, print_num_minus)\n                    best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val_1, mask_list_val, val_eval_list_1, path_trained_models, path_seg_results, path_mask_results, 'EM', cfg['FORMAT'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_m_jc_1)\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / 
args.display_iter).ljust(\n                        print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_semi_MT.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nfrom models.getnetwork import get_network\nimport argparse\nimport time\nimport os\nimport numpy as np\nfrom torch.backends import cudnn\nimport random\nfrom PIL import Image\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nimport sys\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.augmentation.online_aug import data_transform_2d, data_normalize_2d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_2d import imagefloder_itn\nfrom config.visdom_config.visual_visdom import visdom_initialization_MT, visualization_MT, visual_image_MT\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.train_test_config.train_test_config import print_train_loss_MT, print_val_loss, print_train_eval_sup, print_val_eval, save_val_best_2d, draw_pred_MT, print_best\nfrom warnings import simplefilter\n\nsimplefilter(action='ignore', category=FutureWarning)\n\ndef update_ema_variables(model, ema_model, alpha, global_step):\n    # Use the true average until the exponential average is more correct\n    alpha = min(1 - 1 / (global_step + 1), alpha)\n    for ema_param, param in zip(ema_model.parameters(), model.parameters()):\n        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    
parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')\n    parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI')\n    parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--sup_mark', default='20')\n    parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=2, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.5, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-u', '--unsup_weight', default=5, type=float)\n    parser.add_argument('--loss', default='dice')\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--ema_decay', default=0.99, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='unet', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672)\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = 
dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    # trained model save\n    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models + '/' + 'MT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    # seg results save\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + 'MT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    # vis\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-MT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + 
str(args.input1))\n        visdom = visdom_initialization_MT(env=visdom_env, port=args.visdom_port)\n\n    if args.input1 == 'image':\n        input1_mean = 'MEAN'\n        input1_std = 'STD'\n    else:\n        input1_mean = 'MEAN_' + args.input1\n        input1_std = 'STD_' + args.input1\n\n    data_transforms = data_transform_2d()\n    data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std])\n\n    dataset_train_unsup = imagefloder_itn(\n        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,\n        input1=args.input1,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize,\n        sup=False,\n        num_images=None,\n    )\n    num_images_unsup = len(dataset_train_unsup)\n\n    dataset_train_sup = imagefloder_itn(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=num_images_unsup,\n    )\n    dataset_val = imagefloder_itn(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        data_transform_1=data_transforms['val'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=None,\n    )\n\n    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup)\n    
dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n    model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n    model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n\n    model1 = model1.cuda()\n    model2 = model2.cuda()\n    # for param in model2.parameters():\n    #     param.detach_()\n    model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])\n    model2 = DistributedDataParallel(model2, device_ids=[args.local_rank])\n    dist.barrier()\n\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)\n    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)\n\n    since = time.time()\n    count_iter = 0\n\n    best_model = model1\n    best_result = 'Result1'\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model1.train()\n        model2.train()\n\n        train_loss_sup_1 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n\n        val_loss_sup_1 = 0.0\n        val_loss_sup_2 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n\n        dataset_train_sup = iter(dataloaders['train_sup'])\n        
dataset_train_unsup = iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n\n            unsup_index = next(dataset_train_unsup)\n            img_train_unsup_1 = unsup_index['image']\n            img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True))\n\n            noise = torch.clamp(torch.randn_like(img_train_unsup_1) * 0.1, -0.2, 0.2)\n            img_train_unsup_2 = img_train_unsup_1 + noise\n\n            optimizer1.zero_grad()\n\n            pred_train_unsup1 = model1(img_train_unsup_1)\n            pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1)\n            with torch.no_grad():\n                pred_train_unsup2 = model2(img_train_unsup_2)\n                pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1)\n\n            loss_train_unsup = torch.mean((pred_train_unsup1 - pred_train_unsup2)**2)\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train_unsup.backward(retain_graph=True)\n            torch.cuda.empty_cache()\n\n            sup_index = next(dataset_train_sup)\n            img_train_sup = sup_index['image']\n            img_train_sup = Variable(img_train_sup.cuda(non_blocking=True))\n            mask_train_sup = sup_index['mask']\n            mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))\n\n            pred_train_sup1 = model1(img_train_sup)\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup1\n                    mask_list_train = mask_train_sup\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n\n            loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)\n\n            loss_train_sup = loss_train_sup1\n           
 loss_train_sup.backward()\n\n            optimizer1.step()\n            update_ema_variables(model1, model2, args.ema_decay, epoch)\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup1.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_MT(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_half, print_num_minus)\n                train_eval_list1, train_m_jc1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model1.eval()\n                model2.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val'] / 16:\n\n                    inputs_val = Variable(data['image'].cuda(non_blocking=True))\n                
    mask_val = Variable(data['mask'].cuda(non_blocking=True))\n                    name_val = data['ID']\n\n                    optimizer1.zero_grad()\n\n                    outputs_val1 = model1(inputs_val)\n                    outputs_val2 = model2(inputs_val)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val1 = outputs_val1\n                        score_list_val2 = outputs_val2\n                        mask_list_val = mask_val\n                        name_list_val = name_val\n                    else:\n                        score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)\n                        score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n                        name_list_val = np.append(name_list_val, name_val, axis=0)\n\n                    loss_val_sup1 = criterion(outputs_val1, mask_val)\n                    loss_val_sup2 = criterion(outputs_val2, mask_val)\n\n                    val_loss_sup_1 += loss_val_sup1.item()\n                    val_loss_sup_2 += loss_val_sup2.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val1, score_list_val1)\n                score_list_val1 = torch.cat(score_gather_list_val1, dim=0)\n\n                score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val2, score_list_val2)\n                score_list_val2 = torch.cat(score_gather_list_val2, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                
mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n\n                name_gather_list_val = [None for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)\n                name_list_val = np.concatenate(name_gather_list_val, axis=0)\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)\n                    val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half)\n                    best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        draw_img = draw_pred_MT(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, outputs_val1, outputs_val2, train_eval_list1, val_eval_list1, val_eval_list2)\n                        visualization_MT(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_cps, train_m_jc1, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2)\n                        visual_image_MT(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4])\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * 
print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_semi_MT_3d.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\nimport torchio as tio\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_MT, print_val_loss, print_train_eval_sup, print_val_eval, save_val_best_3d, print_best\nfrom config.visdom_config.visual_visdom import visdom_initialization_MT, visualization_MT\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_3d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_3d import dataset_it\nfrom warnings import simplefilter\n\nsimplefilter(action='ignore', category=FutureWarning)\n\ndef update_ema_variables(model, ema_model, alpha, global_step):\n    # Use the true average until the exponential average is more correct\n    alpha = min(1 - 1 / (global_step + 1), alpha)\n    for ema_param, param in zip(ema_model.parameters(), model.parameters()):\n        ema_param.data.mul_(alpha).add_(1 - alpha, param.data)\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')\n    parser.add_argument('--path_seg_results', 
default='/mnt/data1/XNet/seg_pred/semi')\n    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')\n    parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--sup_mark', default='20')\n    parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=1, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.1, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-c', '--unsup_weight', default=5, type=float)\n    parser.add_argument('--loss', default='dice', type=str)\n    parser.add_argument('--patch_size', default=(96, 96, 80))\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--ema_decay', default=0.99, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n    parser.add_argument('--queue_length', default=48, type=int)\n    parser.add_argument('--samples_per_volume_train', default=4, type=int)\n    parser.add_argument('--samples_per_volume_val', default=8, type=int)\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='unet3d', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = 
torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models + '/' + 'MT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + 'MT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_mask_results = path_seg_results + '/mask'\n    if not os.path.exists(path_mask_results) and rank == args.rank_index:\n        os.mkdir(path_mask_results)\n    path_seg_results = path_seg_results + '/pred'\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    if 
args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-MT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))\n        visdom = visdom_initialization_MT(env=visdom_env, port=args.visdom_port)\n\n    # Dataset\n    data_transform = data_transform_3d(cfg['NORMALIZE'])\n\n    dataset_train_unsup = dataset_it(\n        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,\n        input1=args.input1,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=False,\n        num_images=None\n    )\n    num_images_unsup = len(dataset_train_unsup.dataset_1)\n\n    dataset_train_sup = dataset_it(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=True,\n        num_images=num_images_unsup\n    )\n    dataset_val = dataset_it(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        transform_1=data_transform['val'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_val,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=False,\n        shuffle_patches=False,\n        sup=True,\n        
num_images=None\n    )\n\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)\n    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)\n    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n    model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n\n    model1 = model1.cuda()\n    model2 = model2.cuda()\n    # for param in model2.parameters():\n    #     param.detach_()\n    model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])\n    model2 = DistributedDataParallel(model2, device_ids=[args.local_rank])\n    dist.barrier()\n\n    # Training Strategy\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)\n    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, 
after_scheduler=exp_lr_scheduler1)\n\n    # Train & Val\n    since = time.time()\n    count_iter = 0\n\n    best_model = model1\n    best_result = 'Result1'\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model1.train()\n        model2.train()\n\n        train_loss_sup_1 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n\n        val_loss_sup_1 = 0.0\n        val_loss_sup_2 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n\n        dataset_train_sup = iter(dataloaders['train_sup'])\n        dataset_train_unsup = iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n\n            unsup_index = next(dataset_train_unsup)\n            img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())\n\n            noise = torch.clamp(torch.randn_like(img_train_unsup_1) * 0.1, -0.2, 0.2)\n            img_train_unsup_2 = img_train_unsup_1 + noise\n\n            optimizer1.zero_grad()\n\n            pred_train_unsup1 = model1(img_train_unsup_1)\n            pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1)\n            with torch.no_grad():\n                pred_train_unsup2 = model2(img_train_unsup_2)\n                pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1)\n\n            loss_train_unsup = torch.mean((pred_train_unsup1 - pred_train_unsup2)**2)\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train_unsup.backward(retain_graph=True)\n            torch.cuda.empty_cache()\n\n            sup_index = next(dataset_train_sup)\n            img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())\n            
mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())\n\n            pred_train_sup1 = model1(img_train_sup_1)\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup1\n                    mask_list_train = mask_train_sup\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n\n            loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)\n            loss_train_sup = loss_train_sup1\n\n            loss_train_sup.backward()\n            optimizer1.step()\n            update_ema_variables(model1, model2, args.ema_decay, epoch)\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup1.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, 
args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_MT(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_half, print_num_minus)\n                train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model1.eval()\n                model2.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val_1 = Variable(data['image'][tio.DATA].cuda())\n                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())\n\n                    optimizer1.zero_grad()\n\n                    outputs_val_1 = model1(inputs_val_1)\n                    outputs_val_2 = model2(inputs_val_1)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val_1 = outputs_val_1\n                        score_list_val_2 = outputs_val_2\n                        mask_list_val = mask_val\n                    else:\n                        score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)\n                        score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n\n                    loss_val_sup_1 = criterion(outputs_val_1, mask_val)\n                    loss_val_sup_2 = criterion(outputs_val_2, mask_val)\n\n                    val_loss_sup_1 += loss_val_sup_1.item()\n                    val_loss_sup_2 += loss_val_sup_2.item()\n\n                torch.cuda.empty_cache()\n\n                score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]\n                
torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)\n                score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)\n\n                score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2)\n                score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n                torch.cuda.empty_cache()\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)\n                    val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half)\n                    best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val_1, score_list_val_2, mask_list_val,  val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        visualization_MT(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2)\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if 
rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_semi_UAMT.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nfrom models.getnetwork import get_network\nimport argparse\nimport time\nimport os\nimport numpy as np\nfrom torch.backends import cudnn\nimport random\nfrom PIL import Image\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nimport sys\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.augmentation.online_aug import data_transform_2d, data_normalize_2d\nfrom loss.loss_function import segmentation_loss, softmax_mse_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_2d import imagefloder_itn\nfrom config.visdom_config.visual_visdom import visdom_initialization_MT, visualization_MT, visual_image_MT\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.train_test_config.train_test_config import print_train_loss_MT, print_val_loss, print_train_eval_sup, print_val_eval, save_val_best_2d, draw_pred_MT, print_best\nfrom warnings import simplefilter\nfrom config.ramps import ramps\n\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef update_ema_variables(model, ema_model, alpha, global_step):\n    # Use the true average until the exponential average is more correct\n    alpha = min(1 - 1 / (global_step + 1), alpha)\n    for ema_param, param in zip(ema_model.parameters(), model.parameters()):\n        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n\n    parser = 
argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')\n    parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI')\n    parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--sup_mark', default='20')\n    parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=2, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.5, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-u', '--unsup_weight', default=0.05, type=float)\n    parser.add_argument('--loss', default='dice')\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--ema_decay', default=0.99, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='unet', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672)\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = 
args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    # trained model save\n    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models + '/' + 'UAMT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    # seg results save\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + 'UAMT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    # vis\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-UAMT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + 
'-' + str(args.unsup_mark) + '-' + str(args.input1))\n        visdom = visdom_initialization_MT(env=visdom_env, port=args.visdom_port)\n\n    if args.input1 == 'image':\n        input1_mean = 'MEAN'\n        input1_std = 'STD'\n    else:\n        input1_mean = 'MEAN_' + args.input1\n        input1_std = 'STD_' + args.input1\n\n    data_transforms = data_transform_2d()\n    data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std])\n\n    dataset_train_unsup = imagefloder_itn(\n        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,\n        input1=args.input1,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize,\n        sup=False,\n        num_images=None,\n    )\n    num_images_unsup = len(dataset_train_unsup)\n\n    dataset_train_sup = imagefloder_itn(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=num_images_unsup,\n    )\n    dataset_val = imagefloder_itn(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        data_transform_1=data_transforms['val'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=None,\n    )\n\n    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, 
sampler=train_sampler_unsup)\n    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n    model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n    model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n\n    model1 = model1.cuda()\n    model2 = model2.cuda()\n    # for param in model2.parameters():\n    #     param.detach_()\n    model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])\n    model2 = DistributedDataParallel(model2, device_ids=[args.local_rank])\n    dist.barrier()\n\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)\n    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)\n\n    since = time.time()\n    count_iter = 0\n\n    best_model = model1\n    best_result = 'Result1'\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model1.train()\n        model2.train()\n\n        train_loss_sup_1 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n\n        val_loss_sup_1 = 0.0\n        val_loss_sup_2 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n        dist.barrier()\n\n        dataset_train_sup = 
iter(dataloaders['train_sup'])\n        dataset_train_unsup = iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n\n            unsup_index = next(dataset_train_unsup)\n            img_train_unsup_1 = unsup_index['image']\n            img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True))\n\n            noise = torch.clamp(torch.randn_like(img_train_unsup_1) * 0.1, -0.2, 0.2)\n            img_train_unsup_2 = img_train_unsup_1 + noise\n\n            optimizer1.zero_grad()\n\n            pred_train_unsup1 = model1(img_train_unsup_1)\n            with torch.no_grad():\n                pred_train_unsup2 = model2(img_train_unsup_2)\n\n            T = 8\n\n            _, _, w, h = img_train_unsup_1.shape\n            volume_batch_r = img_train_unsup_1.repeat(2, 1, 1, 1)\n            stride = volume_batch_r.shape[0] // 2\n            preds = torch.zeros([stride * T, cfg['NUM_CLASSES'], w, h]).cuda()\n            for i_ in range(T // 2):\n                ema_inputs = volume_batch_r + torch.clamp(torch.randn_like(volume_batch_r) * 0.1, -0.2, 0.2)\n                with torch.no_grad():\n                    preds[2 * stride * i_:2 * stride * (i_ + 1)] = model2(ema_inputs)\n            preds = torch.softmax(preds, dim=1)\n            preds = preds.reshape(T, stride, cfg['NUM_CLASSES'], w, h)\n            preds = torch.mean(preds, dim=0)\n            uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True)\n\n            consistency_dist = softmax_mse_loss(pred_train_unsup1, pred_train_unsup2)  # (batch, 2, 112,112,80)\n            threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(epoch, args.num_epochs)) * np.log(2)\n            mask = (uncertainty < threshold).float()\n            loss_train_unsup = torch.sum(mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16)\n\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train_unsup.backward(retain_graph=True)\n            
torch.cuda.empty_cache()\n\n            sup_index = next(dataset_train_sup)\n            img_train_sup = sup_index['image']\n            img_train_sup = Variable(img_train_sup.cuda(non_blocking=True))\n            mask_train_sup = sup_index['mask']\n            mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))\n\n            pred_train_sup1 = model1(img_train_sup)\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup1\n                    mask_list_train = mask_train_sup\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n\n            loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)\n\n            loss_train_sup = loss_train_sup1\n            loss_train_sup.backward()\n\n            optimizer1.step()\n            update_ema_variables(model1, model2, args.ema_decay, epoch)\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup1.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = 
torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_MT(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_half, print_num_minus)\n                train_eval_list1, train_m_jc1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model1.eval()\n                model2.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val'] / 16:\n\n                    inputs_val = Variable(data['image'].cuda(non_blocking=True))\n                    mask_val = Variable(data['mask'].cuda(non_blocking=True))\n                    name_val = data['ID']\n\n                    optimizer1.zero_grad()\n\n                    outputs_val1 = model1(inputs_val)\n                    outputs_val2 = model2(inputs_val)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val1 = outputs_val1\n                        score_list_val2 = outputs_val2\n                        mask_list_val = mask_val\n                        name_list_val = name_val\n                    else:\n                        score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)\n                        score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n                        name_list_val = np.append(name_list_val, name_val, axis=0)\n\n                    loss_val_sup1 = criterion(outputs_val1, 
mask_val)\n                    loss_val_sup2 = criterion(outputs_val2, mask_val)\n\n                    val_loss_sup_1 += loss_val_sup1.item()\n                    val_loss_sup_2 += loss_val_sup2.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val1, score_list_val1)\n                score_list_val1 = torch.cat(score_gather_list_val1, dim=0)\n\n                score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val2, score_list_val2)\n                score_list_val2 = torch.cat(score_gather_list_val2, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n\n                name_gather_list_val = [None for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)\n                name_list_val = np.concatenate(name_gather_list_val, axis=0)\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)\n                    val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half)\n                    best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE'])\n                    
torch.cuda.empty_cache()\n\n                    if args.vis:\n                        draw_img = draw_pred_MT(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, outputs_val1, outputs_val2, train_eval_list1, val_eval_list1, val_eval_list2)\n                        visualization_MT(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_cps, train_m_jc1, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2)\n                        visual_image_MT(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4])\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_semi_UAMT_3d.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\nimport torchio as tio\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_MT, print_val_loss, print_train_eval_sup, print_val_eval, save_val_best_3d, print_best\nfrom config.visdom_config.visual_visdom import visdom_initialization_MT, visualization_MT\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_3d\nfrom config.ramps import ramps\nfrom loss.loss_function import segmentation_loss, softmax_mse_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_3d import dataset_it\nfrom warnings import simplefilter\n\nsimplefilter(action='ignore', category=FutureWarning)\n\ndef update_ema_variables(model, ema_model, alpha, global_step):\n    # Use the true average until the exponential average is more correct\n    alpha = min(1 - 1 / (global_step + 1), alpha)\n    for ema_param, param in zip(ema_model.parameters(), model.parameters()):\n        ema_param.data.mul_(alpha).add_(1 - alpha, param.data)\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')\n    
parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')\n    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')\n    parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--sup_mark', default='20')\n    parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=1, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.1, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-c', '--unsup_weight', default=5, type=float)\n    parser.add_argument('--loss', default='dice', type=str)\n    parser.add_argument('--patch_size', default=(96, 96, 80))\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--ema_decay', default=0.99, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n    parser.add_argument('--queue_length', default=48, type=int)\n    parser.add_argument('--samples_per_volume_train', default=4, type=int)\n    parser.add_argument('--samples_per_volume_val', default=8, type=int)\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='unet3d', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', 
init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models + '/' + 'UAMT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + 'UAMT' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_mask_results = path_seg_results + '/mask'\n    if not os.path.exists(path_mask_results) and rank == args.rank_index:\n        os.mkdir(path_mask_results)\n    path_seg_results = path_seg_results + '/pred'\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        
os.mkdir(path_seg_results)\n\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-UAMT-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))\n        visdom = visdom_initialization_MT(env=visdom_env, port=args.visdom_port)\n\n    # Dataset\n    data_transform = data_transform_3d(cfg['NORMALIZE'])\n\n    dataset_train_unsup = dataset_it(\n        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,\n        input1=args.input1,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=False,\n        num_images=None\n    )\n    num_images_unsup = len(dataset_train_unsup.dataset_1)\n\n    dataset_train_sup = dataset_it(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=True,\n        num_images=num_images_unsup\n    )\n    dataset_val = dataset_it(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        transform_1=data_transform['val'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_val,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=False,\n        shuffle_patches=False,\n    
    sup=True,\n        num_images=None\n    )\n\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)\n    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)\n    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n    model2 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n\n    model1 = model1.cuda()\n    model2 = model2.cuda()\n    # for param in model2.parameters():\n    #     param.detach_()\n    model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])\n    model2 = DistributedDataParallel(model2, device_ids=[args.local_rank])\n    dist.barrier()\n\n    # Training Strategy\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)\n    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, 
total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)\n\n    # Train & Val\n    since = time.time()\n    count_iter = 0\n\n    best_model = model1\n    best_result = 'Result1'\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model1.train()\n        model2.train()\n\n        train_loss_sup_1 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n\n        val_loss_sup_1 = 0.0\n        val_loss_sup_2 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n\n        dataset_train_sup = iter(dataloaders['train_sup'])\n        dataset_train_unsup = iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n\n            unsup_index = next(dataset_train_unsup)\n            img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())\n\n            noise = torch.clamp(torch.randn_like(img_train_unsup_1) * 0.1, -0.2, 0.2)\n            img_train_unsup_2 = img_train_unsup_1 + noise\n\n            optimizer1.zero_grad()\n\n            pred_train_unsup1 = model1(img_train_unsup_1)\n            with torch.no_grad():\n                pred_train_unsup2 = model2(img_train_unsup_2)\n\n            T = 8\n\n            _, _, d, w, h = img_train_unsup_1.shape\n            volume_batch_r = img_train_unsup_1.repeat(2, 1, 1, 1, 1)\n            stride = volume_batch_r.shape[0] // 2\n            preds = torch.zeros([stride * T, cfg['NUM_CLASSES'], d, w, h]).cuda()\n            for i_ in range(T // 2):\n                ema_inputs = volume_batch_r + torch.clamp(torch.randn_like(volume_batch_r) * 0.1, -0.2, 0.2)\n                with torch.no_grad():\n                    preds[2 * stride * 
i_:2 * stride * (i_ + 1)] = model2(ema_inputs)\n            preds = torch.softmax(preds, dim=1)\n            preds = preds.reshape(T, stride, cfg['NUM_CLASSES'], d, w, h)\n            preds = torch.mean(preds, dim=0)\n            uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True)\n\n            consistency_dist = softmax_mse_loss(pred_train_unsup1, pred_train_unsup2)  # (batch, 2, 112,112,80)\n            threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(epoch, args.num_epochs)) * np.log(2)\n            mask = (uncertainty < threshold).float()\n            loss_train_unsup = torch.sum(mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16)\n\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train_unsup.backward(retain_graph=True)\n            torch.cuda.empty_cache()\n\n            sup_index = next(dataset_train_sup)\n            img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())\n            mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())\n\n            pred_train_sup1 = model1(img_train_sup_1)\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup1\n                    mask_list_train = mask_train_sup\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n\n            loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)\n            loss_train_sup = loss_train_sup1\n\n            loss_train_sup.backward()\n            optimizer1.step()\n            update_ema_variables(model1, model2, args.ema_decay, epoch)\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n            train_loss_unsup += 
loss_train_unsup.item()\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup1.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_MT(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_half, print_num_minus)\n                train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model1.eval()\n                model2.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val_1 = Variable(data['image'][tio.DATA].cuda())\n                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())\n\n                    optimizer1.zero_grad()\n\n                    outputs_val_1 = model1(inputs_val_1)\n                    outputs_val_2 = model2(inputs_val_1)\n                    
torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val_1 = outputs_val_1\n                        score_list_val_2 = outputs_val_2\n                        mask_list_val = mask_val\n                    else:\n                        score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)\n                        score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n\n                    loss_val_sup_1 = criterion(outputs_val_1, mask_val)\n                    loss_val_sup_2 = criterion(outputs_val_2, mask_val)\n\n                    val_loss_sup_1 += loss_val_sup_1.item()\n                    val_loss_sup_2 += loss_val_sup_2.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)\n                score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)\n\n                score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2)\n                score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n                torch.cuda.empty_cache()\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)\n                    val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = 
print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half)\n                    best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val_1, score_list_val_2, mask_list_val,  val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        visualization_MT(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2)\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(\n                        print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_semi_URPC.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nfrom models.getnetwork import get_network\nimport argparse\nimport time\nimport os\nimport numpy as np\nfrom torch.backends import cudnn\nimport random\nfrom PIL import Image\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nimport sys\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.augmentation.online_aug import data_transform_2d, data_normalize_2d\nfrom loss.loss_function import segmentation_loss, entropy_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_2d import imagefloder_itn\nfrom config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM, visual_image_sup\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup\nfrom warnings import simplefilter\n\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')\n    parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/CREMI')\n    parser.add_argument('--dataset_name', default='CREMI', 
help='CREMI, ISIC-2017, GlaS')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--sup_mark', default='20')\n    parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=2, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.5, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-u', '--unsup_weight', default=1, type=float)\n    parser.add_argument('--loss', default='dice')\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='unet_urpc', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672)\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    # trained model save\n    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        
os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models + '/' + 'URPC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    # seg results save\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + 'URPC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    # vis\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-URPC-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))\n        visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port)\n\n    if args.input1 == 'image':\n        input1_mean = 'MEAN'\n        input1_std = 'STD'\n    else:\n        input1_mean = 'MEAN_' + args.input1\n        input1_std = 'STD_' + args.input1\n\n    data_transforms = data_transform_2d()\n    data_normalize = 
data_normalize_2d(cfg[input1_mean], cfg[input1_std])\n\n    dataset_train_unsup = imagefloder_itn(\n        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,\n        input1=args.input1,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize,\n        sup=False,\n        num_images=None,\n    )\n    num_images_unsup = len(dataset_train_unsup)\n\n    dataset_train_sup = imagefloder_itn(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=num_images_unsup,\n    )\n    dataset_val = imagefloder_itn(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        data_transform_1=data_transforms['val'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=None,\n    )\n\n    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup)\n    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n    model1 = get_network(args.network, cfg['IN_CHANNELS'], 
cfg['NUM_CLASSES'])\n\n    model1 = model1.cuda()\n    model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])\n    dist.barrier()\n\n    criterion = segmentation_loss(args.loss, False).cuda()\n    kl_distance = nn.KLDivLoss(reduction='none')\n\n    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)\n    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)\n\n    since = time.time()\n    count_iter = 0\n\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model1.train()\n\n        train_loss_sup_1 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n\n        val_loss_sup_1 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n\n        dataset_train_sup = iter(dataloaders['train_sup'])\n        dataset_train_unsup = iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n\n            unsup_index = next(dataset_train_unsup)\n            img_train_unsup_1 = unsup_index['image']\n            img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True))\n\n            optimizer1.zero_grad()\n\n            pred_train_unsup1, pred_train_unsup2, pred_train_unsup3, pred_train_unsup4 = model1(img_train_unsup_1)\n            pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1)\n            pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1)\n            pred_train_unsup3 = torch.softmax(pred_train_unsup3, 1)\n      
      pred_train_unsup4 = torch.softmax(pred_train_unsup4, 1)\n\n            preds = (pred_train_unsup1 + pred_train_unsup2 + pred_train_unsup3 + pred_train_unsup4) / 4\n\n            variance_aux1 = torch.sum(kl_distance(torch.log(preds), pred_train_unsup1), dim=1, keepdim=True)\n            exp_variance_aux1 = torch.exp(-variance_aux1)\n            variance_aux2 = torch.sum(kl_distance(torch.log(preds), pred_train_unsup2), dim=1, keepdim=True)\n            exp_variance_aux2 = torch.exp(-variance_aux2)\n            variance_aux3 = torch.sum(kl_distance(torch.log(preds), pred_train_unsup3), dim=1, keepdim=True)\n            exp_variance_aux3 = torch.exp(-variance_aux3)\n            variance_aux4 = torch.sum(kl_distance(torch.log(preds), pred_train_unsup4), dim=1, keepdim=True)\n            exp_variance_aux4 = torch.exp(-variance_aux4)\n\n            consistency_dist_aux1 = (preds - pred_train_unsup1) ** 2\n            consistency_loss_aux1 = torch.mean(consistency_dist_aux1 * exp_variance_aux1) / (torch.mean(exp_variance_aux1) + 1e-8) + torch.mean(variance_aux1)\n            consistency_dist_aux2 = (preds - pred_train_unsup2) ** 2\n            consistency_loss_aux2 = torch.mean(consistency_dist_aux2 * exp_variance_aux2) / (torch.mean(exp_variance_aux2) + 1e-8) + torch.mean(variance_aux2)\n            consistency_dist_aux3 = (preds - pred_train_unsup3) ** 2\n            consistency_loss_aux3 = torch.mean(consistency_dist_aux3 * exp_variance_aux3) / (torch.mean(exp_variance_aux3) + 1e-8) + torch.mean(variance_aux3)\n            consistency_dist_aux4 = (preds - pred_train_unsup4) ** 2\n            consistency_loss_aux4 = torch.mean(consistency_dist_aux4 * exp_variance_aux4) / (torch.mean(exp_variance_aux4) + 1e-8) + torch.mean(variance_aux4)\n            loss_train_unsup = (consistency_loss_aux1 + consistency_loss_aux2 + consistency_loss_aux3 + consistency_loss_aux4) / 4\n\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            
loss_train_unsup.backward(retain_graph=True)\n            torch.cuda.empty_cache()\n\n            sup_index = next(dataset_train_sup)\n            img_train_sup = sup_index['image']\n            img_train_sup = Variable(img_train_sup.cuda(non_blocking=True))\n            mask_train_sup = sup_index['mask']\n            mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))\n\n            pred_train_sup1, pred_train_sup2, pred_train_sup3, pred_train_sup4 = model1(img_train_sup)\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup1\n                    mask_list_train = mask_train_sup\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n\n            loss_train_sup1 = (criterion(pred_train_sup1, mask_train_sup)+criterion(pred_train_sup2, mask_train_sup)+criterion(pred_train_sup3, mask_train_sup)+criterion(pred_train_sup4, mask_train_sup)) / 4\n            loss_train_sup = loss_train_sup1\n            loss_train_sup.backward()\n\n            optimizer1.step()\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup1.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            mask_gather_list_train = 
[torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus)\n                train_eval_list1, train_m_jc1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model1.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val'] / 16:\n\n                    inputs_val = Variable(data['image'].cuda(non_blocking=True))\n                    mask_val = Variable(data['mask'].cuda(non_blocking=True))\n                    name_val = data['ID']\n\n                    optimizer1.zero_grad()\n                    outputs_val1, outputs_val2, outputs_val3, outputs_val4 = model1(inputs_val)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val1 = outputs_val1\n                        mask_list_val = mask_val\n                        name_list_val = name_val\n                    else:\n                        score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n                        name_list_val = np.append(name_list_val, name_val, axis=0)\n\n                    loss_val_sup1 = criterion(outputs_val1, mask_val)\n                    
val_loss_sup_1 += loss_val_sup1.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val1, score_list_val1)\n                score_list_val1 = torch.cat(score_gather_list_val1, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n\n                name_gather_list_val = [None for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)\n                name_list_val = np.concatenate(name_gather_list_val, axis=0)\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus)\n                    val_eval_list1, val_m_jc1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val1, mask_list_val, print_num_minus)\n                    best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val1, name_list_val, val_eval_list1, path_trained_models, path_seg_results, cfg['PALETTE'], 'URPC')\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, outputs_val1, train_eval_list1, val_eval_list1)\n                        visualization_EM(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_cps, train_m_jc1, val_epoch_loss_sup1, val_m_jc1)\n                        visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3])\n\n                    print('-' * print_num)\n                    print('| Epoch Time: 
{:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_semi_URPC_3d.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\nimport torchio as tio\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_EM, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup\nfrom config.visdom_config.visual_visdom import visdom_initialization_EM, visualization_EM\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_3d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_3d import dataset_it\nfrom warnings import simplefilter\n\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi')\n    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')\n    parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--sup_mark', default='20')\n    
parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=1, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.1, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-c', '--unsup_weight', default=5, type=float)\n    parser.add_argument('--patch_size', default=(96, 96, 80))\n    parser.add_argument('--loss', default='dice', type=str)\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n    parser.add_argument('--queue_length', default=48, type=int)\n    parser.add_argument('--samples_per_volume_train', default=4, type=int)\n    parser.add_argument('--samples_per_volume_val', default=8, type=int)\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='unet3d_urpc', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    path_trained_models = args.path_trained_models + '/' + 
str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models + '/' + 'URPC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + 'URPC' + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_mask_results = path_seg_results + '/mask'\n    if not os.path.exists(path_mask_results) and rank == args.rank_index:\n        os.mkdir(path_mask_results)\n    path_seg_results = path_seg_results + '/pred'\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-URPC-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration)+ '-' + 
str(args.sup_mark) + '-' + str(args.unsup_mark) + '-' + str(args.input1))\n        visdom = visdom_initialization_EM(env=visdom_env, port=args.visdom_port)\n\n    # Dataset\n    data_transform = data_transform_3d(cfg['NORMALIZE'])\n\n    dataset_train_unsup = dataset_it(\n        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,\n        input1=args.input1,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=False,\n        num_images=None\n    )\n    num_images_unsup = len(dataset_train_unsup.dataset_1)\n\n    dataset_train_sup = dataset_it(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=True,\n        num_images=num_images_unsup\n    )\n    dataset_val = dataset_it(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        transform_1=data_transform['val'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_val,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=False,\n        shuffle_patches=False,\n        sup=True,\n        num_images=None\n    )\n\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)\n    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)\n    val_sampler = 
torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)\n    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model1 = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n\n    model1 = model1.cuda()\n    model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])\n    dist.barrier()\n\n    # Training Strategy\n    criterion = segmentation_loss(args.loss, False).cuda()\n    kl_distance = nn.KLDivLoss(reduction='none')\n\n    optimizer1 = optim.SGD(model1.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)\n    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)\n\n    # Train & Val\n    since = time.time()\n    count_iter = 0\n\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model1.train()\n\n        train_loss_sup_1 = 0.0\n        
train_loss_unsup = 0.0\n        train_loss = 0.0\n\n        val_loss_sup_1 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n\n        dataset_train_sup = iter(dataloaders['train_sup'])\n        dataset_train_unsup = iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n\n            unsup_index = next(dataset_train_unsup)\n            img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())\n\n            optimizer1.zero_grad()\n\n            pred_train_unsup1, pred_train_unsup2, pred_train_unsup3, pred_train_unsup4 = model1(img_train_unsup_1)\n            pred_train_unsup1 = torch.softmax(pred_train_unsup1, 1)\n            pred_train_unsup2 = torch.softmax(pred_train_unsup2, 1)\n            pred_train_unsup3 = torch.softmax(pred_train_unsup3, 1)\n            pred_train_unsup4 = torch.softmax(pred_train_unsup4, 1)\n\n            preds = (pred_train_unsup1 + pred_train_unsup2 + pred_train_unsup3 + pred_train_unsup4) / 4\n\n            variance_aux1 = torch.sum(kl_distance(torch.log(pred_train_unsup1), preds), dim=1, keepdim=True)\n            exp_variance_aux1 = torch.exp(-variance_aux1)\n\n            variance_aux2 = torch.sum(kl_distance(torch.log(pred_train_unsup2), preds), dim=1, keepdim=True)\n            exp_variance_aux2 = torch.exp(-variance_aux2)\n\n            variance_aux3 = torch.sum(kl_distance(torch.log(pred_train_unsup3), preds), dim=1, keepdim=True)\n            exp_variance_aux3 = torch.exp(-variance_aux3)\n\n            variance_aux4 = torch.sum(kl_distance(torch.log(pred_train_unsup4), preds), dim=1, keepdim=True)\n            exp_variance_aux4 = torch.exp(-variance_aux4)\n\n            consistency_dist_aux1 = (preds - pred_train_unsup1) ** 2\n            consistency_loss_aux1 = torch.mean(consistency_dist_aux1 * exp_variance_aux1) / (torch.mean(exp_variance_aux1) + 1e-8) + torch.mean(variance_aux1)\n\n            consistency_dist_aux2 = (preds 
- pred_train_unsup2) ** 2\n            consistency_loss_aux2 = torch.mean(consistency_dist_aux2 * exp_variance_aux2) / (torch.mean(exp_variance_aux2) + 1e-8) + torch.mean(variance_aux2)\n\n            consistency_dist_aux3 = (preds - pred_train_unsup3) ** 2\n            consistency_loss_aux3 = torch.mean(consistency_dist_aux3 * exp_variance_aux3) / (torch.mean(exp_variance_aux3) + 1e-8) + torch.mean(variance_aux3)\n\n            consistency_dist_aux4 = (preds - pred_train_unsup4) ** 2\n            consistency_loss_aux4 = torch.mean(consistency_dist_aux4 * exp_variance_aux4) / (torch.mean(exp_variance_aux4) + 1e-8) + torch.mean(variance_aux4)\n\n            loss_train_unsup = (consistency_loss_aux1 + consistency_loss_aux2 + consistency_loss_aux3 + consistency_loss_aux4) / 4\n\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train_unsup.backward(retain_graph=True)\n            torch.cuda.empty_cache()\n\n            sup_index = next(dataset_train_sup)\n            img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())\n            mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())\n\n            pred_train_sup1, pred_train_sup2, pred_train_sup3, pred_train_sup4 = model1(img_train_sup_1)\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup1\n                    mask_list_train = mask_train_sup\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n\n            loss_train_sup1 = (criterion(pred_train_sup1, mask_train_sup)+criterion(pred_train_sup2, mask_train_sup)+criterion(pred_train_sup3, mask_train_sup)+criterion(pred_train_sup4, mask_train_sup)) / 4\n            loss_train_sup = 
loss_train_sup1\n\n            loss_train_sup.backward()\n            optimizer1.step()\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup1.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup_1, train_epoch_loss_cps, train_epoch_loss = print_train_loss_EM(train_loss_sup_1, train_loss_unsup, train_loss, num_batches, print_num, print_num_minus)\n                train_eval_list_1, train_m_jc_1 = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train1, mask_list_train, print_num_minus)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model1.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val_1 = Variable(data['image'][tio.DATA].cuda())\n                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())\n\n                    
optimizer1.zero_grad()\n                    outputs_val_1, outputs_val_2, outputs_val_3, outputs_val_4 = model1(inputs_val_1)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val_1 = outputs_val_1\n                        mask_list_val = mask_val\n                    else:\n                        score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n\n                    loss_val_sup_1 = criterion(outputs_val_1, mask_val)\n                    val_loss_sup_1 += loss_val_sup_1.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)\n                score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n                torch.cuda.empty_cache()\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup_1 = print_val_loss_sup(val_loss_sup_1, num_batches, print_num, print_num_minus)\n                    val_eval_list_1, val_m_jc_1 = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val_1, mask_list_val, print_num_minus)\n                    best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model1, score_list_val_1, mask_list_val, val_eval_list_1, path_trained_models, path_seg_results, path_mask_results, 'URPC', cfg['FORMAT'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        visualization_EM(visdom, epoch + 1, train_epoch_loss, 
train_epoch_loss_sup_1, train_epoch_loss_cps, train_m_jc_1, val_epoch_loss_sup_1, val_m_jc_1)\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_semi_XNet.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nfrom models.getnetwork import get_network\nimport argparse\nimport time\nimport os\nimport numpy as np\nfrom torch.backends import cudnn\nimport random\nfrom PIL import Image\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nimport sys\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.augmentation.online_aug import data_transform_2d, data_normalize_2d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_2d import imagefloder_iitnn\nfrom config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi_xnet')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi_xnet')\n    parser.add_argument('-pd', '--path_dataset', default='/mnt/data1/XNet/dataset/GlaS')\n    parser.add_argument('--dataset_name', default='GlaS', help='CREMI, 
ISIC-2017, GlaS')\n    parser.add_argument('--input1', default='L')\n    parser.add_argument('--input2', default='H')\n    parser.add_argument('--sup_mark', default='20')\n    parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=2, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.5, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-u', '--unsup_weight', default=5, type=float)\n    parser.add_argument('--loss', default='dice')\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='xnet', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672)\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    # trained model save\n    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == 
args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+ str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)+'-'+str(args.input2)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    # seg results save\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+ str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)+'-'+str(args.input2)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    # vis\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-XNet-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+ str(args.sup_mark)+'-'+str(args.unsup_mark)+'-'+str(args.input1)+'-'+str(args.input2))\n        visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)\n\n    if args.input1 == 'image':\n        input1_mean = 'MEAN'\n        input1_std = 'STD'\n    else:\n        input1_mean = 'MEAN_' + args.input1\n        input1_std = 'STD_' + args.input1\n\n    if args.input2 == 'image':\n        input2_mean = 'MEAN'\n        input2_std = 'STD'\n    
else:\n        input2_mean = 'MEAN_' + args.input2\n        input2_std = 'STD_' + args.input2\n\n    data_transforms = data_transform_2d()\n    data_normalize_1 = data_normalize_2d(cfg[input1_mean], cfg[input1_std])\n    data_normalize_2 = data_normalize_2d(cfg[input2_mean], cfg[input2_std])\n\n    dataset_train_unsup = imagefloder_iitnn(\n        data_dir=args.path_dataset + '/train_unsup_'+args.unsup_mark,\n        input1=args.input1,\n        input2=args.input2,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize_1,\n        data_normalize_2=data_normalize_2,\n        sup=False,\n        num_images=None,\n    )\n    num_images_unsup = len(dataset_train_unsup)\n\n    dataset_train_sup = imagefloder_iitnn(\n        data_dir=args.path_dataset + '/train_sup_'+args.sup_mark,\n        input1=args.input1,\n        input2=args.input2,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize_1,\n        data_normalize_2=data_normalize_2,\n        sup=True,\n        num_images=num_images_unsup,\n    )\n    dataset_val = imagefloder_iitnn(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        input2=args.input2,\n        data_transform_1=data_transforms['val'],\n        data_normalize_1=data_normalize_1,\n        data_normalize_2=data_normalize_2,\n        sup=True,\n        num_images=None,\n    )\n\n    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup, shuffle=True)\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train_sup'] = DataLoader(dataset_train_sup, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup, 
batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler_unsup)\n    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n\n    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n    model = model.cuda()\n    model = DistributedDataParallel(model, device_ids=[args.local_rank])\n    dist.barrier()\n\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)\n    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)\n\n    since = time.time()\n    count_iter = 0\n\n    best_model = model\n    best_result = 'Result1'\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter-1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model.train()\n\n        train_loss_sup_1 = 0.0\n        train_loss_sup_2 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n\n        val_loss_sup_1 = 0.0\n        val_loss_sup_2 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch+1) / args.num_epochs\n\n        dist.barrier()\n\n        dataset_train_sup = iter(dataloaders['train_sup'])\n        dataset_train_unsup = iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n\n            unsup_index = next(dataset_train_unsup)\n    
        img_train_unsup_1 = unsup_index['image']\n            img_train_unsup_1 = Variable(img_train_unsup_1.cuda(non_blocking=True))\n            img_train_unsup_2 = unsup_index['image_2']\n            img_train_unsup_2 = Variable(img_train_unsup_2.cuda(non_blocking=True))\n\n            optimizer.zero_grad()\n            pred_train_unsup1, pred_train_unsup2 = model(img_train_unsup_1, img_train_unsup_2)\n\n            max_train1 = torch.max(pred_train_unsup1, dim=1)[1]\n            max_train2 = torch.max(pred_train_unsup2, dim=1)[1]\n            max_train1 = max_train1.long()\n            max_train2 = max_train2.long()\n\n            loss_train_unsup = criterion(pred_train_unsup1, max_train2) + criterion(pred_train_unsup2, max_train1)\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train_unsup.backward(retain_graph=True)\n            torch.cuda.empty_cache()\n\n            sup_index = next(dataset_train_sup)\n            img_train_sup_1 = sup_index['image']\n            img_train_sup_1 = Variable(img_train_sup_1.cuda(non_blocking=True))\n            img_train_sup_2 = sup_index['image_2']\n            img_train_sup_2 = Variable(img_train_sup_2.cuda(non_blocking=True))\n            mask_train_sup = sup_index['mask']\n            mask_train_sup = Variable(mask_train_sup.cuda(non_blocking=True))\n\n            pred_train_sup1, pred_train_sup2 = model(img_train_sup_1, img_train_sup_2)\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup1\n                    score_list_train2 = pred_train_sup2\n                    mask_list_train = mask_train_sup\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)\n                    score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0)\n                    
mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n\n            loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)\n            loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup)\n\n            loss_train_sup = loss_train_sup1 + loss_train_sup2\n            loss_train_sup.backward()\n            optimizer.step()\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss_sup_2 += loss_train_sup2.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train2, score_list_train2)\n            score_list_train2 = torch.cat(score_gather_list_train2, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch+1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, 
train_loss_unsup, train_loss, num_batches, print_num, print_num_half)\n                train_eval_list1, train_eval_list2, train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val'] / 16:\n\n                    inputs_val_1 = Variable(data['image'].cuda(non_blocking=True))\n                    inputs_val_2 = Variable(data['image_2'].cuda(non_blocking=True))\n                    mask_val = Variable(data['mask'].cuda(non_blocking=True))\n                    name_val = data['ID']\n\n                    optimizer.zero_grad()\n\n                    outputs_val1, outputs_val2 = model(inputs_val_1, inputs_val_2)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val1 = outputs_val1\n                        score_list_val2 = outputs_val2\n                        mask_list_val = mask_val\n                        name_list_val = name_val\n                    else:\n                        score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)\n                        score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n                        name_list_val = np.append(name_list_val, name_val, axis=0)\n\n                    loss_val_sup1 = criterion(outputs_val1, mask_val)\n                    loss_val_sup2 = criterion(outputs_val2, mask_val)\n\n                    val_loss_sup_1 += loss_val_sup1.item()\n                    val_loss_sup_2 += loss_val_sup2.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in 
range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val1, score_list_val1)\n                score_list_val1 = torch.cat(score_gather_list_val1, dim=0)\n\n                score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val2, score_list_val2)\n                score_list_val2 = torch.cat(score_gather_list_val2, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n\n                name_gather_list_val = [None for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)\n                name_list_val = np.concatenate(name_gather_list_val, axis=0)\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)\n                    val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half)\n                    best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model, model, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        draw_img = draw_pred_XNet(cfg['NUM_CLASSES'], mask_train_sup, mask_val, pred_train_sup1, pred_train_sup2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2)\n                        
visualization_XNet(visdom, epoch+1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_m_jc1, train_m_jc2, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2)\n                        visual_image_XNet(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4], draw_img[5])\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time()-begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)\n        print('=' * print_num)\n\n"
  },
  {
    "path": "train_semi_XNet3d.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\nimport torchio as tio\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_3d, print_best\nfrom config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_3d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_3d import dataset_iit\nfrom warnings import simplefilter\n\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/semi_xnet')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/semi_xnet')\n    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')\n    parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--input2', default='DB2_H')\n    
parser.add_argument('--sup_mark', default='20')\n    parser.add_argument('--unsup_mark', default='80')\n    parser.add_argument('-b', '--batch_size', default=1, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.1, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-u', '--unsup_weight', default=5, type=float)\n    parser.add_argument('--loss', default='dice', type=str)\n    parser.add_argument('--patch_size', default=(96, 96, 80))\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n    parser.add_argument('--queue_length', default=48, type=int)\n    parser.add_argument('--samples_per_volume_train', default=4, type=int)\n    parser.add_argument('--samples_per_volume_val', default=8, type=int)\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='xnet3d', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    
path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+'-cw='+str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + str(args.input1) + '-' + str(args.input2)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+'-cw='+str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + str(args.input1) + '-' + str(args.input2)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_mask_results = path_seg_results + '/mask'\n    if not os.path.exists(path_mask_results) and rank == args.rank_index:\n        os.mkdir(path_mask_results)\n    path_seg_results_1 = path_seg_results + '/pred'\n    if not os.path.exists(path_seg_results_1) and rank == args.rank_index:\n        os.mkdir(path_seg_results_1)\n\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Semi-XNet-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+'-cw='+str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + str(args.input1) + 
'-' + str(args.input2))\n        visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)\n\n    # Dataset\n    data_transform = data_transform_3d(cfg['NORMALIZE'])\n\n    dataset_train_unsup = dataset_iit(\n        data_dir=args.path_dataset + '/train_unsup_' + args.unsup_mark,\n        input1=args.input1,\n        input2=args.input2,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=False,\n        num_images=None\n    )\n    num_images_unsup = len(dataset_train_unsup.dataset_1)\n\n    dataset_train_sup = dataset_iit(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        input2=args.input2,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=True,\n        num_images=num_images_unsup\n    )\n    dataset_val = dataset_iit(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        input2=args.input2,\n        transform_1=data_transform['val'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_val,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=False,\n        shuffle_patches=False,\n        sup=True,\n        num_images=None\n    )\n\n    train_sampler_unsup = torch.utils.data.distributed.DistributedSampler(dataset_train_unsup.queue_train_set_1, shuffle=True)\n    train_sampler_sup = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)\n    val_sampler = 
torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train_sup'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_sup)\n    dataloaders['train_unsup'] = DataLoader(dataset_train_unsup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler_unsup)\n    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train_sup']), 'train_unsup': len(dataloaders['train_unsup']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n    model = model.cuda()\n    model = DistributedDataParallel(model, device_ids=[args.local_rank])\n\n    # Training Strategy\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5 * 10 ** args.wd)\n    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)\n\n    # Train & Val\n    since = time.time()\n    count_iter = 0\n\n    best_model = model\n    best_result = 'Result1'\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train_sup'].sampler.set_epoch(epoch)\n        dataloaders['train_unsup'].sampler.set_epoch(epoch)\n        model.train()\n\n        train_loss_sup_1 = 0.0\n        train_loss_sup_2 = 0.0\n        
train_loss_unsup = 0.0\n        train_loss = 0.0\n        val_loss_sup_1 = 0.0\n        val_loss_sup_2 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n\n        dataset_train_sup = iter(dataloaders['train_sup'])\n        dataset_train_unsup = iter(dataloaders['train_unsup'])\n\n        for i in range(num_batches['train_sup']):\n\n            unsup_index = next(dataset_train_unsup)\n            img_train_unsup_1 = Variable(unsup_index['image'][tio.DATA].cuda())\n            img_train_unsup_2 = Variable(unsup_index['image2'][tio.DATA].cuda())\n\n            optimizer.zero_grad()\n            pred_train_unsup1, pred_train_unsup2 = model(img_train_unsup_1, img_train_unsup_2)\n\n            max_train_unsup1 = torch.max(pred_train_unsup1, dim=1)[1]\n            max_train_unsup2 = torch.max(pred_train_unsup2, dim=1)[1]\n            max_train_unsup1 = max_train_unsup1.long()\n            max_train_unsup2 = max_train_unsup2.long()\n\n            loss_train_unsup = criterion(pred_train_unsup1, max_train_unsup2) + criterion(pred_train_unsup2, max_train_unsup1)\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train_unsup.backward(retain_graph=True)\n            torch.cuda.empty_cache()\n\n\n            sup_index = next(dataset_train_sup)\n            img_train_sup_1 = Variable(sup_index['image'][tio.DATA].cuda())\n            img_train_sup_2 = Variable(sup_index['image2'][tio.DATA].cuda())\n            mask_train_sup = Variable(sup_index['mask'][tio.DATA].squeeze(1).long().cuda())\n\n            pred_train_sup1, pred_train_sup2 = model(img_train_sup_1, img_train_sup_2)\n            torch.cuda.empty_cache()\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = pred_train_sup1\n                    score_list_train2 = pred_train_sup2\n                    mask_list_train = mask_train_sup\n                # else:\n       
         elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train1 = torch.cat((score_list_train1, pred_train_sup1), dim=0)\n                    score_list_train2 = torch.cat((score_list_train2, pred_train_sup2), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train_sup), dim=0)\n\n            loss_train_sup1 = criterion(pred_train_sup1, mask_train_sup)\n            loss_train_sup2 = criterion(pred_train_sup2, mask_train_sup)\n\n            loss_train_sup = loss_train_sup1 + loss_train_sup2\n            loss_train_sup.backward()\n            optimizer.step()\n            torch.cuda.empty_cache()\n\n            loss_train = loss_train_unsup + loss_train_sup\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss_sup_2 += loss_train_sup2.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train2, score_list_train2)\n            score_list_train2 = torch.cat(score_gather_list_train2, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n        
        print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)\n                train_eval_list_1, train_eval_list_2, train_m_jc_1, train_m_jc_2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val_1 = Variable(data['image'][tio.DATA].cuda())\n                    inputs_val_2 = Variable(data['image2'][tio.DATA].cuda())\n                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())\n\n                    optimizer.zero_grad()\n                    outputs_val_1, outputs_val_2 = model(inputs_val_1, inputs_val_2)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val_1 = outputs_val_1\n                        score_list_val_2 = outputs_val_2\n                        mask_list_val = mask_val\n                    else:\n                        score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)\n                        score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n\n                    loss_val_sup_1 = criterion(outputs_val_1, mask_val)\n                    loss_val_sup_2 = criterion(outputs_val_2, mask_val)\n\n                    val_loss_sup_1 += loss_val_sup_1.item()\n                    val_loss_sup_2 += loss_val_sup_2.item()\n\n                
torch.cuda.empty_cache()\n                score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)\n                score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)\n\n                score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2)\n                score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n                torch.cuda.empty_cache()\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2,num_batches, print_num, print_num_half)\n                    val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half)\n                    best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model, model, score_list_val_1, score_list_val_2, mask_list_val, val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        visualization_XNet(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_m_jc_1, train_m_jc_2, val_epoch_loss_sup_1, val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2)\n\n                    print('-' * print_num)\n                    print('| Epoch 
Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_sup.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_sup, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup\nfrom config.visdom_config.visual_visdom import visdom_initialization_sup, visualization_sup, visual_image_sup\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_2d, data_normalize_2d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_2d import imagefloder_itn\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup')\n    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/CREMI')\n    parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS')\n    parser.add_argument('--input1', default='image')\n    
parser.add_argument('--sup_mark', default='100')\n    parser.add_argument('-b', '--batch_size', default=4, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.5, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('--loss', default='dice', type=str)\n    parser.add_argument('-ds', '--deep_supervision', default=False)\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='unet', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7\n    print_num_minus = print_num - 2\n\n    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = 
path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n            os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Sup-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark))\n        visdom = visdom_initialization_sup(env=visdom_env, port=args.visdom_port)\n\n    if args.input1 == 'image':\n        input1_mean = 'MEAN'\n        input1_std = 'STD'\n    else:\n        input1_mean = 'MEAN_' + args.input1\n        input1_std = 'STD_' + args.input1\n\n    # Dataset\n    data_transforms = data_transform_2d()\n    data_normalize = data_normalize_2d(cfg[input1_mean], cfg[input1_std])\n\n    dataset_train = imagefloder_itn(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=None,\n    )\n    dataset_val = imagefloder_itn(\n        data_dir=args.path_dataset + '/val',\n        
input1=args.input1,\n        data_transform_1=data_transforms['val'],\n        data_normalize_1=data_normalize,\n        sup=True,\n        num_images=None,\n    )\n\n    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler)\n    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n    model = model.cuda()\n    model = DistributedDataParallel(model, device_ids=[args.local_rank])\n\n    # Training Strategy\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)\n    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)\n\n    # Train & Val\n    since = time.time()\n    count_iter = 0\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter-1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train'].sampler.set_epoch(epoch)\n        model.train()\n\n        train_loss = 0.0\n        val_loss = 0.0\n\n        dist.barrier()\n        for i, data in enumerate(dataloaders['train']):\n\n            inputs_train = Variable(data['image'].cuda())\n            mask_train = 
Variable(data['mask'].cuda())\n\n            optimizer.zero_grad()\n            outputs_train = model(inputs_train)\n            torch.cuda.empty_cache()\n\n            if args.deep_supervision:\n                loss_train = 0\n                for output_train in outputs_train:\n                    loss_train += criterion(output_train, mask_train)\n                loss_train /= len(outputs_train)\n                outputs_train = outputs_train[0]\n            else:\n                loss_train = criterion(outputs_train, mask_train)\n\n            loss_train.backward()\n            optimizer.step()\n            train_loss += loss_train.item()\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train = outputs_train\n                    mask_list_train = mask_train\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 4:\n                    score_list_train = torch.cat((score_list_train, outputs_train), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)\n\n        scheduler_warmup.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train = [torch.zeros_like(score_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train, score_list_train)\n            score_list_train = torch.cat(score_gather_list_train, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                
train_epoch_loss = print_train_loss_sup(train_loss, num_batches, print_num, print_num_minus)\n                train_eval_list, train_m_jc = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train, mask_list_train, print_num_minus)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val = Variable(data['image'].cuda())\n                    mask_val = Variable(data['mask'].cuda())\n                    name_val = data['ID']\n\n                    optimizer.zero_grad()\n                    outputs_val = model(inputs_val)\n                    torch.cuda.empty_cache()\n\n                    if args.deep_supervision:\n                        loss_val = 0\n                        for output_val in outputs_val:\n                            loss_val += criterion(output_val, mask_val)\n                        loss_val /= len(outputs_val)\n                        outputs_val = outputs_val[0]\n                    else:\n                        loss_val = criterion(outputs_val, mask_val)\n                    val_loss += loss_val.item()\n\n                    if i == 0:\n                        score_list_val = outputs_val\n                        mask_list_val = mask_val\n                        name_list_val = name_val\n                    else:\n                        score_list_val = torch.cat((score_list_val, outputs_val), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n                        name_list_val = np.append(name_list_val, name_val, axis=0)\n\n                torch.cuda.empty_cache()\n                score_gather_list_val = [torch.zeros_like(score_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val, score_list_val)\n                score_list_val = 
torch.cat(score_gather_list_val, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n\n                name_gather_list_val = [None for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)\n                name_list_val = np.concatenate(name_gather_list_val, axis=0)\n                torch.cuda.empty_cache()\n\n                if rank == args.rank_index:\n                    val_epoch_loss = print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus)\n                    val_eval_list, val_m_jc = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val, mask_list_val, print_num_minus)\n                    best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model, score_list_val, name_list_val, val_eval_list, path_trained_models, path_seg_results, cfg['PALETTE'], args.network)\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train, mask_val, outputs_train, outputs_val, train_eval_list, val_eval_list)\n                        visualization_sup(visdom, epoch+1, train_epoch_loss, train_m_jc, val_epoch_loss, val_m_jc)\n                        visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3])\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * 
print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_sup_3d.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\nimport torchio as tio\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_sup, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup\nfrom config.visdom_config.visual_visdom import visdom_initialization_sup, visualization_sup\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_3d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_3d import dataset_it\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=FutureWarning)\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup')\n    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')\n    parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')\n    parser.add_argument('--input1', default='image')\n    parser.add_argument('--sup_mark', default='100')\n    
parser.add_argument('-b', '--batch_size', default=1, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.005, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('--patch_size', default=(96, 96, 80))\n    parser.add_argument('--loss', default='dice', type=str)\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n    parser.add_argument('--queue_length', default=48, type=int)\n    parser.add_argument('--samples_per_volume_train', default=4, type=int)\n    parser.add_argument('--samples_per_volume_val', default=8, type=int)\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='vnet', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7\n    print_num_minus = print_num - 2\n\n    path_trained_models = args.path_trained_models+'/'+str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = 
path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s=' + str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w=' + str(args.warm_up_duration)+'-'+str(args.sup_mark)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results+'/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_mask_results = path_seg_results + '/mask'\n    if not os.path.exists(path_mask_results) and rank == args.rank_index:\n        os.mkdir(path_mask_results)\n    path_seg_results = path_seg_results + '/pred'\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Sup-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-w=' + str(args.warm_up_duration)+'-'+str(args.sup_mark))\n        visdom = visdom_initialization_sup(env=visdom_env, port=args.visdom_port)\n\n\n    # Dataset\n    data_transform = data_transform_3d(cfg['NORMALIZE'])\n\n    dataset_train_sup = dataset_it(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        
samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=True,\n        num_images=None\n    )\n    dataset_val = dataset_it(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        transform_1=data_transform['val'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_val,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=False,\n        shuffle_patches=False,\n        sup=True,\n        num_images=None\n    )\n\n    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)\n\n\n    dataloaders = dict()\n    dataloaders['train'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler)\n    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size)\n    model = model.cuda()\n    model = DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)\n\n    # Training Strategy\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)\n    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, 
total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)\n\n    # Train & Val\n    since = time.time()\n    count_iter = 0\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter-1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train'].sampler.set_epoch(epoch)\n        model.train()\n\n        train_loss = 0.0\n        val_loss = 0.0\n\n        dist.barrier()\n        for i, data in enumerate(dataloaders['train']):\n\n            inputs_train = Variable(data['image'][tio.DATA].cuda())\n            mask_train = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())\n\n            optimizer.zero_grad()\n            outputs_train = model(inputs_train)\n\n            torch.cuda.empty_cache()\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train = outputs_train\n                    mask_list_train = mask_train\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train = torch.cat((score_list_train, outputs_train), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)\n\n            loss_train = criterion(outputs_train, mask_train)\n            loss_train.backward()\n            optimizer.step()\n            train_loss += loss_train.item()\n        scheduler_warmup.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train = [torch.zeros_like(score_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train, score_list_train)\n            score_list_train = torch.cat(score_gather_list_train, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            
torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss = print_train_loss_sup(train_loss, num_batches, print_num, print_num_minus)\n                train_eval_list, train_m_jc = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train, mask_list_train, print_num_minus)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val = Variable(data['image'][tio.DATA].cuda())\n                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())\n\n                    optimizer.zero_grad()\n                    outputs_val = model(inputs_val)\n\n                    torch.cuda.empty_cache()\n                    if i == 0:\n                        score_list_val = outputs_val\n                        mask_list_val = mask_val\n                    else:\n                        score_list_val = torch.cat((score_list_val, outputs_val), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n\n\n                    loss_val = criterion(outputs_val, mask_val)\n                    val_loss += loss_val.item()\n\n                torch.cuda.empty_cache()\n\n                score_gather_list_val = [torch.zeros_like(score_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val, score_list_val)\n                score_list_val = torch.cat(score_gather_list_val, dim=0)\n\n                mask_gather_list_val = 
[torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n\n                torch.cuda.empty_cache()\n\n                if rank == args.rank_index:\n                    val_epoch_loss = print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus)\n                    val_eval_list, val_m_jc = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val, mask_list_val, print_num_minus)\n                    best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model, score_list_val, mask_list_val, val_eval_list, path_trained_models, path_seg_results, path_mask_results, args.network, cfg['FORMAT'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        visualization_sup(visdom, epoch + 1, train_epoch_loss, train_m_jc, val_epoch_loss, val_m_jc)\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_sup_ConResNet.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\nimport torchio as tio\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_ConResNet, print_val_loss_ConResNet, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_3d, print_best_sup\nfrom config.visdom_config.visual_visdom import visdom_initialization_ConResNet, visualization_ConResNet\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_3d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_3d import dataset_iit_conresnet\nfrom warnings import simplefilter\n\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup')\n    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')\n    parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')\n    parser.add_argument('--input1', default='image')\n    
parser.add_argument('--input2', default='image_res')\n    parser.add_argument('--sup_mark', default='100')\n    parser.add_argument('-b', '--batch_size', default=1, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.1, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('--patch_size', default=(96, 96, 80))\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n    parser.add_argument('--queue_length', default=48, type=int)\n    parser.add_argument('--samples_per_volume_train', default=4, type=int)\n    parser.add_argument('--samples_per_volume_val', default=8, type=int)\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='conresnet', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) 
and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_mask_results = path_seg_results + '/mask'\n    if not os.path.exists(path_mask_results) and rank == args.rank_index:\n        os.mkdir(path_mask_results)\n    path_seg_results = path_seg_results + '/pred'\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Sup-ConResNet-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2))\n        visdom = visdom_initialization_ConResNet(env=visdom_env, port=args.visdom_port)\n\n\n    data_transform = 
data_transform_3d(cfg['NORMALIZE'])\n\n    dataset_train_sup = dataset_iit_conresnet(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        input2=args.input2,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=True,\n        num_images=None,\n    )\n    dataset_val = dataset_iit_conresnet(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        input2=args.input2,\n        transform_1=data_transform['val'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_val,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=False,\n        shuffle_patches=False,\n        sup=True,\n        num_images=None\n    )\n\n    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler)\n    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'], img_shape=args.patch_size)\n    model = model.cuda()\n    model = DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)\n\n    # Training Strategy\n    criterion_dice = 
segmentation_loss('dice', False).cuda()\n    criterion_ce = segmentation_loss('CE', False).cuda()\n    criterion_bound = segmentation_loss('bcebound', False, num_classes=cfg['NUM_CLASSES']).cuda()\n\n    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)\n    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)\n\n    # Train & Val\n    since = time.time()\n    count_iter = 0\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train'].sampler.set_epoch(epoch)\n        model.train()\n\n        train_loss_seg = 0.0\n        train_loss_res = 0.0\n        train_loss = 0.0\n        val_loss_seg = 0.0\n        val_loss_res = 0.0\n\n        dist.barrier()\n        for i, data in enumerate(dataloaders['train']):\n\n            inputs_train_1 = Variable(data['image'][tio.DATA].cuda())\n            inputs_train_2 = Variable(data['image2'][tio.DATA].cuda())\n            mask_train = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())\n            mask_train_2 = Variable(data['mask2'][tio.DATA].squeeze(1).long().cuda())\n\n            optimizer.zero_grad()\n            outputs_train = model(inputs_train_1, inputs_train_2)\n            torch.cuda.empty_cache()\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train = outputs_train[0]\n                    mask_list_train = mask_train\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train = torch.cat((score_list_train, outputs_train[0]), dim=0)\n                    
mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)\n\n\n            loss_train_seg = criterion_dice(outputs_train[0], mask_train) + criterion_ce(outputs_train[0], mask_train)\n            loss_train_res = criterion_bound(outputs_train[1], mask_train_2) + 0.5 * (criterion_bound(outputs_train[2], mask_train_2) + criterion_bound(outputs_train[3], mask_train_2))\n            loss_train = loss_train_seg + loss_train_res\n            loss_train.backward()\n            optimizer.step()\n\n            train_loss_seg += loss_train_seg.item()\n            train_loss_res += loss_train_res.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train = [torch.zeros_like(score_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train, score_list_train)\n            score_list_train = torch.cat(score_gather_list_train, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_seg, train_epoch_loss_res, train_epoch_loss = print_train_loss_ConResNet(train_loss_seg, train_loss_res, train_loss, num_batches, print_num, print_num_half, print_num_minus)\n                train_eval_list, train_m_jc = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train, mask_list_train, print_num_minus)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model.eval()\n\n     
           for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val = Variable(data['image'][tio.DATA].cuda())\n                    inputs_val_2 = Variable(data['image2'][tio.DATA].cuda())\n                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())\n                    mask_val_2 = Variable(data['mask2'][tio.DATA].squeeze(1).long().cuda())\n\n                    optimizer.zero_grad()\n                    outputs_val = model(inputs_val, inputs_val_2)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val = outputs_val[0]\n                        mask_list_val = mask_val\n                    else:\n                        score_list_val = torch.cat((score_list_val, outputs_val[0]), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n\n                    loss_val_seg = criterion_dice(outputs_val[0], mask_val) + criterion_ce(outputs_val[0], mask_val)\n                    loss_val_res = criterion_bound(outputs_val[1], mask_val_2) + 0.5 * (criterion_bound(outputs_val[2], mask_val_2) + criterion_bound(outputs_val[3], mask_val_2))\n\n                    val_loss_seg += loss_val_seg.item()\n                    val_loss_res += loss_val_res.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val = [torch.zeros_like(score_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val, score_list_val)\n                score_list_val = torch.cat(score_gather_list_val, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n\n                torch.cuda.empty_cache()\n\n        
        if rank == args.rank_index:\n                    val_epoch_loss_seg, val_epoch_loss_res = print_val_loss_ConResNet(val_loss_seg, val_loss_res, num_batches, print_num, print_num_half)\n                    val_eval_list, val_m_jc = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val, mask_list_val, print_num_minus)\n                    best_val_eval_list = save_val_best_sup_3d(cfg['NUM_CLASSES'], best_val_eval_list, model, score_list_val, mask_list_val, val_eval_list, path_trained_models, path_seg_results, path_mask_results, args.network, cfg['FORMAT'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        visualization_ConResNet(visdom, epoch + 1, train_epoch_loss, train_epoch_loss_seg, train_epoch_loss_res, train_m_jc, val_epoch_loss_seg, val_epoch_loss_res, val_m_jc)\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_sup_XNet.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best\nfrom config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_2d, data_normalize_2d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_2d import imagefloder_iitnn\nfrom warnings import simplefilter\n\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup_xnet')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup_xnet')\n    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/GlaS')\n    parser.add_argument('--dataset_name', default='GlaS', help='CREMI, ISIC-2017, GlaS')\n    parser.add_argument('--input1', default='L')\n    
parser.add_argument('--input2', default='H')\n    parser.add_argument('--sup_mark', default='100')\n    parser.add_argument('-b', '--batch_size', default=2, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.5, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-u', '--unsup_weight', default=5, type=float)\n    parser.add_argument('--loss', default='dice', type=str)\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='xnet', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models + '/' + str(args.network) + '-l=' + 
str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Sup-XNet-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size)+ '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2))\n        visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)\n\n    # Dataset\n    if args.input1 == 'image':\n        input1_mean = 'MEAN'\n        input1_std = 'STD'\n    else:\n        input1_mean = 'MEAN_' + args.input1\n        input1_std = 'STD_' + args.input1\n\n    if args.input2 == 'image':\n        input2_mean = 'MEAN'\n        input2_std = 'STD'\n    else:\n        input2_mean = 'MEAN_' + args.input2\n        input2_std = 'STD_' + args.input2\n\n    data_transforms = 
data_transform_2d()\n    data_normalize_1 = data_normalize_2d(cfg[input1_mean], cfg[input1_std])\n    data_normalize_2 = data_normalize_2d(cfg[input2_mean], cfg[input2_std])\n\n    dataset_train = imagefloder_iitnn(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        input2=args.input2,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize_1,\n        data_normalize_2=data_normalize_2,\n        sup=True,\n        num_images=None,\n    )\n    dataset_val = imagefloder_iitnn(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        input2=args.input2,\n        data_transform_1=data_transforms['val'],\n        data_normalize_1=data_normalize_1,\n        data_normalize_2=data_normalize_2,\n        sup=True,\n        num_images=None,\n    )\n\n    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler)\n    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n    model = model.cuda()\n    model = DistributedDataParallel(model, device_ids=[args.local_rank])\n\n    # Training Strategy\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)\n    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup = 
GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)\n\n    # Train & Val\n    since = time.time()\n    count_iter = 0\n\n    best_model = model\n    best_result = 'Result1'\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train'].sampler.set_epoch(epoch)\n        model.train()\n\n        train_loss_sup_1 = 0.0\n        train_loss_sup_2 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n        val_loss_sup_1 = 0.0\n        val_loss_sup_2 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n\n        for i, data in enumerate(dataloaders['train']):\n\n            inputs_train_1 = Variable(data['image'].cuda())\n            inputs_train_2 = Variable(data['image_2'].cuda())\n            mask_train = Variable(data['mask'].cuda())\n\n            optimizer.zero_grad()\n            outputs_train1, outputs_train2 = model(inputs_train_1, inputs_train_2)\n            torch.cuda.empty_cache()\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = outputs_train1\n                    score_list_train2 = outputs_train2\n                    mask_list_train = mask_train\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 4:\n                    score_list_train1 = torch.cat((score_list_train1, outputs_train1), dim=0)\n                    score_list_train2 = torch.cat((score_list_train2, outputs_train2), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)\n\n            max_train1 = torch.max(outputs_train1, dim=1)[1]\n            max_train2 = torch.max(outputs_train2, dim=1)[1]\n            max_train1 = max_train1.long()\n   
         max_train2 = max_train2.long()\n\n            loss_train_sup1 = criterion(outputs_train1, mask_train)\n            loss_train_sup2 = criterion(outputs_train2, mask_train)\n            loss_train_unsup = criterion(outputs_train1, max_train2) + criterion(outputs_train2, max_train1)\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train = loss_train_sup1 + loss_train_sup2 + loss_train_unsup\n\n            loss_train.backward()\n            optimizer.step()\n\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss_sup_2 += loss_train_sup2.item()\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train2, score_list_train2)\n            score_list_train2 = torch.cat(score_gather_list_train2, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = 
print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)\n                train_eval_list1, train_eval_list2, train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val = Variable(data['image'].cuda())\n                    inputs_val_wavelet = Variable(data['image_2'].cuda())\n                    mask_val = Variable(data['mask'].cuda())\n                    name_val = data['ID']\n\n                    optimizer.zero_grad()\n                    outputs_val1, outputs_val2 = model(inputs_val, inputs_val_wavelet)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val1 = outputs_val1\n                        score_list_val2 = outputs_val2\n                        mask_list_val = mask_val\n                        name_list_val = name_val\n                    else:\n                        score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)\n                        score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n                        name_list_val = np.append(name_list_val, name_val, axis=0)\n\n                    loss_val_sup1 = criterion(outputs_val1, mask_val)\n                    loss_val_sup2 = criterion(outputs_val2, mask_val)\n\n                    val_loss_sup_1 += loss_val_sup1.item()\n                    val_loss_sup_2 += loss_val_sup2.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ 
in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val1, score_list_val1)\n                score_list_val1 = torch.cat(score_gather_list_val1, dim=0)\n\n                score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val2, score_list_val2)\n                score_list_val2 = torch.cat(score_gather_list_val2, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n\n                name_gather_list_val = [None for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)\n                name_list_val = np.concatenate(name_gather_list_val, axis=0)\n                torch.cuda.empty_cache()\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)\n                    val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half)\n                    best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model, model, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        draw_img = draw_pred_XNet(cfg['NUM_CLASSES'], mask_train, mask_val, outputs_train1, outputs_train2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2)\n        
                visualization_XNet(visdom, epoch+1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_m_jc1, train_m_jc2, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2)\n                        visual_image_XNet(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4], draw_img[5])\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_sup_XNet3d.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\nimport torchio as tio\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_3d, print_best\nfrom config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_3d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_3d import dataset_iit\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=FutureWarning)\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup_xnet')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup_xnet')\n    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/Atrial')\n    parser.add_argument('--dataset_name', default='Atrial', help='LiTS, Atrial')\n    parser.add_argument('--input1', default='L')\n    parser.add_argument('--input2', default='H')\n    
parser.add_argument('--sup_mark', default='100', help='100')\n    parser.add_argument('-b', '--batch_size', default=1, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.5, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-u', '--unsup_weight', default=5, type=float)\n    parser.add_argument('--loss', default='dice', type=str)\n    parser.add_argument('--patch_size', default=(96, 96, 80))\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n    parser.add_argument('--queue_length', default=48, type=int)\n    parser.add_argument('--samples_per_volume_train', default=4, type=int)\n    parser.add_argument('--samples_per_volume_val', default=8, type=int)\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='xnet3d', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    path_trained_models = args.path_trained_models + '/' + 
str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+str(args.input1)+'-'+str(args.input2)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)+str(args.input1)+'-'+str(args.input2)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_mask_results = path_seg_results + '/mask'\n    if not os.path.exists(path_mask_results) and rank == args.rank_index:\n        os.mkdir(path_mask_results)\n    path_seg_results_1 = path_seg_results + '/pred'\n    if not os.path.exists(path_seg_results_1) and rank == args.rank_index:\n        os.mkdir(path_seg_results_1)\n\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Sup-XNet-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-cw='+str(args.unsup_weight)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark))+str(args.input1)+'-'+str(args.input2)\n        visdom = visdom_initialization_XNet(env=visdom_env, 
port=args.visdom_port)\n\n    # Dataset\n    data_transform = data_transform_3d(cfg['NORMALIZE'])\n\n    dataset_train_sup = dataset_iit(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        input2=args.input2,\n        transform_1=data_transform['train'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_train,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=True,\n        shuffle_patches=True,\n        sup=True,\n        num_images=None\n    )\n    dataset_val = dataset_iit(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        input2=args.input2,\n        transform_1=data_transform['val'],\n        queue_length=args.queue_length,\n        samples_per_volume=args.samples_per_volume_val,\n        patch_size=args.patch_size,\n        num_workers=8,\n        shuffle_subjects=False,\n        shuffle_patches=False,\n        sup=True,\n        num_images=None\n    )\n\n    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train_sup.queue_train_set_1, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val.queue_train_set_1, shuffle=False)\n\n\n    dataloaders = dict()\n    dataloaders['train'] = DataLoader(dataset_train_sup.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=train_sampler)\n    dataloaders['val'] = DataLoader(dataset_val.queue_train_set_1, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n    model = model.cuda()\n    model = DistributedDataParallel(model, device_ids=[args.local_rank])\n\n    # Training Strategy\n    criterion = 
segmentation_loss(args.loss, False).cuda()\n\n    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)\n    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)\n\n    # Train & Val\n    since = time.time()\n    count_iter = 0\n\n    best_model = model\n    best_result = 'Result1'\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter-1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train'].sampler.set_epoch(epoch)\n        model.train()\n\n        train_loss_sup_1 = 0.0\n        train_loss_sup_2 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n        val_loss_sup_1 = 0.0\n        val_loss_sup_2 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n        for i, data in enumerate(dataloaders['train']):\n\n            inputs_train_1 = Variable(data['image'][tio.DATA].cuda())\n            inputs_train_2 = Variable(data['image2'][tio.DATA].cuda())\n            mask_train = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())\n\n            optimizer.zero_grad()\n            outputs_train_1, outputs_train_2 = model(inputs_train_1, inputs_train_2)\n            torch.cuda.empty_cache()\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train_1 = outputs_train_1\n                    score_list_train_2 = outputs_train_2\n                    mask_list_train = mask_train\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 32:\n                    score_list_train_1 = torch.cat((score_list_train_1, outputs_train_1), dim=0)\n                    
score_list_train_2 = torch.cat((score_list_train_2, outputs_train_2), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)\n\n            max_train1 = torch.max(outputs_train_1, dim=1)[1]\n            max_train2 = torch.max(outputs_train_2, dim=1)[1]\n            max_train1 = max_train1.long()\n            max_train2 = max_train2.long()\n\n            loss_train_sup1 = criterion(outputs_train_1, mask_train)\n            loss_train_sup2 = criterion(outputs_train_2, mask_train)\n            loss_train_unsup = criterion(outputs_train_1, max_train2) + criterion(outputs_train_2, max_train1)\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train = loss_train_sup1 + loss_train_sup2 + loss_train_unsup\n\n            loss_train.backward()\n            optimizer.step()\n\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss_sup_2 += loss_train_sup2.item()\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train_1 = [torch.zeros_like(score_list_train_1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train_1, score_list_train_1)\n            score_list_train_1 = torch.cat(score_gather_list_train_1, dim=0)\n\n            score_gather_list_train_2 = [torch.zeros_like(score_list_train_2) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train_2, score_list_train_2)\n            score_list_train_2 = torch.cat(score_gather_list_train_2, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, 
dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)\n                train_eval_list_1, train_eval_list_2, train_m_jc_1, train_m_jc_2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train_1, score_list_train_2, mask_list_train, print_num_half)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val_1 = Variable(data['image'][tio.DATA].cuda())\n                    inputs_val_2 = Variable(data['image2'][tio.DATA].cuda())\n                    mask_val = Variable(data['mask'][tio.DATA].squeeze(1).long().cuda())\n\n                    optimizer.zero_grad()\n                    outputs_val_1, outputs_val_2 = model(inputs_val_1, inputs_val_2)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val_1 = outputs_val_1\n                        score_list_val_2 = outputs_val_2\n                        mask_list_val = mask_val\n                    else:\n                        score_list_val_1 = torch.cat((score_list_val_1, outputs_val_1), dim=0)\n                        score_list_val_2 = torch.cat((score_list_val_2, outputs_val_2), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n\n                    loss_val_sup_1 = criterion(outputs_val_1, mask_val)\n                    loss_val_sup_2 = criterion(outputs_val_2, mask_val)\n\n                
    val_loss_sup_1 += loss_val_sup_1.item()\n                    val_loss_sup_2 += loss_val_sup_2.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val_1 = [torch.zeros_like(score_list_val_1) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val_1, score_list_val_1)\n                score_list_val_1 = torch.cat(score_gather_list_val_1, dim=0)\n\n                score_gather_list_val_2 = [torch.zeros_like(score_list_val_2) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val_2, score_list_val_2)\n                score_list_val_2 = torch.cat(score_gather_list_val_2, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n                torch.cuda.empty_cache()\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup_1, val_epoch_loss_sup_2 = print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)\n                    val_eval_list_1, val_eval_list_2, val_m_jc_1, val_m_jc_2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val_1, score_list_val_2, mask_list_val, print_num_half)\n                    best_val_eval_list, best_model, best_result = save_val_best_3d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model, model, score_list_val_1, score_list_val_2, mask_list_val, val_eval_list_1, val_eval_list_2, path_trained_models, path_seg_results, path_mask_results, cfg['FORMAT'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        visualization_XNet(visdom, epoch+1, train_epoch_loss, train_epoch_loss_sup_1, train_epoch_loss_sup_2, train_epoch_loss_cps, train_m_jc_1, train_m_jc_2, val_epoch_loss_sup_1, 
val_epoch_loss_sup_2, val_m_jc_1, val_m_jc_2)\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_sup_XNet_sb.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best\nfrom config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_2d, data_normalize_2d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_2d import imagefloder_iitnn\nfrom warnings import simplefilter\n\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/XNet/checkpoints/sup_xnet')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/XNet/seg_pred/sup_xnet')\n    parser.add_argument('--path_dataset', default='/mnt/data1/XNet/dataset/GlaS')\n    parser.add_argument('--dataset_name', default='GlaS', help='CREMI, ISIC-2017, GlaS')\n    parser.add_argument('--input1', default='L')\n    
parser.add_argument('--input2', default='H')\n    parser.add_argument('--sup_mark', default='100')\n    parser.add_argument('-b', '--batch_size', default=2, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.5, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('-u', '--unsup_weight', default=5, type=float)\n    parser.add_argument('--loss', default='dice', type=str)\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='xnet_sb', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14\n    print_num_minus = print_num - 2\n    print_num_half = int(print_num / 2 - 1)\n\n    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models + '/' + str(args.network) + '-l=' + 
str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results + '/' + str(args.network) + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Sup-XNet-' + str(os.path.split(args.path_dataset)[1]) + '-' + args.network + '-l=' + str(args.lr) + '-e=' + str(args.num_epochs) + '-s=' + str(args.step_size) + '-g=' + str(args.gamma) + '-b=' + str(args.batch_size) + '-cw=' + str(args.unsup_weight) + '-w=' + str(args.warm_up_duration) + '-' + str(args.sup_mark)+'-'+str(args.input1)+'-'+str(args.input2))\n        visdom = visdom_initialization_XNet(env=visdom_env, port=args.visdom_port)\n\n    # Dataset\n    if args.input1 == 'image':\n        input1_mean = 'MEAN'\n        input1_std = 'STD'\n    else:\n        input1_mean = 'MEAN_' + args.input1\n        input1_std = 'STD_' + args.input1\n\n    if args.input2 == 'image':\n        input2_mean = 'MEAN'\n        input2_std = 'STD'\n    else:\n        input2_mean = 'MEAN_' + args.input2\n        input2_std = 'STD_' + args.input2\n\n    data_transforms = 
data_transform_2d()\n    data_normalize_1 = data_normalize_2d(cfg[input1_mean], cfg[input1_std])\n    data_normalize_2 = data_normalize_2d(cfg[input2_mean], cfg[input2_std])\n\n    dataset_train = imagefloder_iitnn(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        input1=args.input1,\n        input2=args.input2,\n        data_transform_1=data_transforms['train'],\n        data_normalize_1=data_normalize_1,\n        data_normalize_2=data_normalize_2,\n        sup=True,\n        num_images=None,\n    )\n    dataset_val = imagefloder_iitnn(\n        data_dir=args.path_dataset + '/val',\n        input1=args.input1,\n        input2=args.input2,\n        data_transform_1=data_transforms['val'],\n        data_normalize_1=data_normalize_1,\n        data_normalize_2=data_normalize_2,\n        sup=True,\n        num_images=None,\n    )\n\n    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler)\n    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model1 = get_network(args.network, 3, cfg['NUM_CLASSES'])\n    model2 = get_network(args.network, 1, cfg['NUM_CLASSES'])\n\n    model1 = model1.cuda()\n    model2 = model2.cuda()\n    model1 = DistributedDataParallel(model1, device_ids=[args.local_rank])\n    model2 = DistributedDataParallel(model2, device_ids=[args.local_rank])\n    dist.barrier()\n\n    # Training Strategy\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer1 = optim.SGD(model1.parameters(), 
lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)\n    exp_lr_scheduler1 = lr_scheduler.StepLR(optimizer1, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup1 = GradualWarmupScheduler(optimizer1, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler1)\n\n    optimizer2 = optim.SGD(model2.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10 ** args.wd)\n    exp_lr_scheduler2 = lr_scheduler.StepLR(optimizer2, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup2 = GradualWarmupScheduler(optimizer2, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler2)\n\n    # Train & Val\n    since = time.time()\n    count_iter = 0\n\n    best_model = model1\n    best_result = 'Result1'\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter - 1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train'].sampler.set_epoch(epoch)\n        model1.train()\n        model2.train()\n\n        train_loss_sup_1 = 0.0\n        train_loss_sup_2 = 0.0\n        train_loss_unsup = 0.0\n        train_loss = 0.0\n        val_loss_sup_1 = 0.0\n        val_loss_sup_2 = 0.0\n\n        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs\n\n        dist.barrier()\n\n        for i, data in enumerate(dataloaders['train']):\n\n            inputs_train_1 = Variable(data['image'].cuda())\n            inputs_train_2 = Variable(data['image_2'].cuda())\n            mask_train = Variable(data['mask'].cuda())\n\n            optimizer1.zero_grad()\n            optimizer2.zero_grad()\n\n            outputs_train1 = model1(inputs_train_1)\n            outputs_train2 = model2(inputs_train_2)\n            torch.cuda.empty_cache()\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train1 = 
outputs_train1\n                    score_list_train2 = outputs_train2\n                    mask_list_train = mask_train\n                # else:\n                elif 0 < i <= num_batches['train_sup'] / 4:\n                    score_list_train1 = torch.cat((score_list_train1, outputs_train1), dim=0)\n                    score_list_train2 = torch.cat((score_list_train2, outputs_train2), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)\n\n            max_train1 = torch.max(outputs_train1, dim=1)[1]\n            max_train2 = torch.max(outputs_train2, dim=1)[1]\n            max_train1 = max_train1.long()\n            max_train2 = max_train2.long()\n\n            loss_train_sup1 = criterion(outputs_train1, mask_train)\n            loss_train_sup2 = criterion(outputs_train2, mask_train)\n            loss_train_unsup = criterion(outputs_train1, max_train2) + criterion(outputs_train2, max_train1)\n            loss_train_unsup = loss_train_unsup * unsup_weight\n            loss_train = loss_train_sup1 + loss_train_sup2 + loss_train_unsup\n\n            loss_train.backward()\n            optimizer1.step()\n            optimizer2.step()\n\n            train_loss_sup_1 += loss_train_sup1.item()\n            train_loss_sup_2 += loss_train_sup2.item()\n            train_loss_unsup += loss_train_unsup.item()\n            train_loss += loss_train.item()\n\n        scheduler_warmup1.step()\n        scheduler_warmup2.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train1 = [torch.zeros_like(score_list_train1) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train1, score_list_train1)\n            score_list_train1 = torch.cat(score_gather_list_train1, dim=0)\n\n            score_gather_list_train2 = [torch.zeros_like(score_list_train2) for _ in range(ngpus_per_node)]\n            
torch.distributed.all_gather(score_gather_list_train2, score_list_train2)\n            score_list_train2 = torch.cat(score_gather_list_train2, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num, print_num_half)\n                train_eval_list1, train_eval_list2, train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model1.eval()\n                model2.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val_1 = Variable(data['image'].cuda())\n                    inputs_val_2 = Variable(data['image_2'].cuda())\n                    mask_val = Variable(data['mask'].cuda())\n                    name_val = data['ID']\n\n                    optimizer1.zero_grad()\n                    optimizer2.zero_grad()\n\n                    outputs_val1 = model1(inputs_val_1)\n                    outputs_val2 = model2(inputs_val_2)\n                    torch.cuda.empty_cache()\n\n                    if i == 0:\n                        score_list_val1 = outputs_val1\n                        score_list_val2 = outputs_val2\n                        
mask_list_val = mask_val\n                        name_list_val = name_val\n                    else:\n                        score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)\n                        score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n                        name_list_val = np.append(name_list_val, name_val, axis=0)\n\n                    loss_val_sup1 = criterion(outputs_val1, mask_val)\n                    loss_val_sup2 = criterion(outputs_val2, mask_val)\n\n                    val_loss_sup_1 += loss_val_sup1.item()\n                    val_loss_sup_2 += loss_val_sup2.item()\n\n                torch.cuda.empty_cache()\n                score_gather_list_val1 = [torch.zeros_like(score_list_val1) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val1, score_list_val1)\n                score_list_val1 = torch.cat(score_gather_list_val1, dim=0)\n\n                score_gather_list_val2 = [torch.zeros_like(score_list_val2) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val2, score_list_val2)\n                score_list_val2 = torch.cat(score_gather_list_val2, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n\n                name_gather_list_val = [None for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)\n                name_list_val = np.concatenate(name_gather_list_val, axis=0)\n                torch.cuda.empty_cache()\n\n                if rank == args.rank_index:\n                    val_epoch_loss_sup1, val_epoch_loss_sup2 = 
print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half)\n                    val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'], score_list_val1, score_list_val2, mask_list_val, print_num_half)\n                    best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model, best_val_eval_list, best_result, model1, model2, score_list_val1, score_list_val2, name_list_val, val_eval_list1, val_eval_list2, path_trained_models, path_seg_results, cfg['PALETTE'])\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        draw_img = draw_pred_XNet(cfg['NUM_CLASSES'], mask_train, mask_val, outputs_train1, outputs_train2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2)\n                        visualization_XNet(visdom, epoch+1, train_epoch_loss, train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_m_jc1, train_m_jc2, val_epoch_loss_sup1, val_epoch_loss_sup2, val_m_jc1, val_m_jc2)\n                        visual_image_XNet(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3], draw_img[4], draw_img[5])\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best(cfg['NUM_CLASSES'], best_val_eval_list, best_model, best_result, path_trained_models, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_sup_alnet.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_sup, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup\nfrom config.visdom_config.visual_visdom import visdom_initialization_sup, visualization_sup, visual_image_sup\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_2d, data_normalize_2d, data_transform_aerial_lanenet\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_2d import imagefloder_aerial_lanenet\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/GeYang_shared/XNet/checkpoints/sup')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/GeYang_shared/XNet/seg_pred/sup')\n    parser.add_argument('--path_dataset', default='/mnt/data1/GeYang_shared/XNet/dataset/CREMI')\n    parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, 
GlaS')\n    parser.add_argument('--sup_mark', default='100', help='20, 100')\n    parser.add_argument('-b', '--batch_size', default=32, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.5, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('--loss', default='dice', type=str)\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='alnet', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7\n    print_num_minus = print_num - 2\n\n    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = 
path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n            os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Sup-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark))\n        visdom = visdom_initialization_sup(env=visdom_env, port=args.visdom_port)\n\n\n    # Dataset\n    data_transforms = data_transform_2d()\n    data_normalize = data_normalize_2d(cfg['MEAN'], cfg['STD'])\n    data_normalize_l1 = data_transform_aerial_lanenet(64, 64)\n    data_normalize_l2 = data_transform_aerial_lanenet(32, 32)\n    data_normalize_l3 = data_transform_aerial_lanenet(16, 16)\n    data_normalize_l4 = data_transform_aerial_lanenet(8, 8)\n\n    dataset_train = imagefloder_aerial_lanenet(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        data_transform=data_transforms['train'],\n        data_normalize=data_normalize,\n        data_normalize_l1=data_normalize_l1,\n        data_normalize_l2=data_normalize_l2,\n        
data_normalize_l3=data_normalize_l3,\n        data_normalize_l4=data_normalize_l4\n    )\n    dataset_val = imagefloder_aerial_lanenet(\n        data_dir=args.path_dataset + '/val',\n        data_transform=data_transforms['val'],\n        data_normalize=data_normalize,\n        data_normalize_l1=data_normalize_l1,\n        data_normalize_l2=data_normalize_l2,\n        data_normalize_l3=data_normalize_l3,\n        data_normalize_l4=data_normalize_l4\n    )\n\n    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler)\n    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])\n    model = model.cuda()\n    model = DistributedDataParallel(model, device_ids=[args.local_rank])\n\n    # Training Strategy\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)\n    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)\n\n    # Train & Val\n    since = time.time()\n    count_iter = 0\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter-1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        
dataloaders['train'].sampler.set_epoch(epoch)\n        model.train()\n\n        train_loss = 0.0\n        val_loss = 0.0\n\n        dist.barrier()\n        for i, data in enumerate(dataloaders['train']):\n\n            inputs_train = Variable(data['image'].cuda())\n            inputs_train_l1 = Variable(data['image_l1'].cuda())\n            inputs_train_l2 = Variable(data['image_l2'].cuda())\n            inputs_train_l3 = Variable(data['image_l3'].cuda())\n            inputs_train_l4 = Variable(data['image_l4'].cuda())\n            mask_train = Variable(data['mask'].cuda())\n\n            optimizer.zero_grad()\n            outputs_train = model(inputs_train, inputs_train_l1, inputs_train_l2, inputs_train_l3, inputs_train_l4)\n            torch.cuda.empty_cache()\n\n            loss_train = criterion(outputs_train, mask_train)\n            loss_train.backward()\n            optimizer.step()\n            train_loss += loss_train.item()\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train = outputs_train\n                    mask_list_train = mask_train\n                else:\n                # elif 0 < i <= num_batches['train_sup'] / 16:\n                    score_list_train = torch.cat((score_list_train, outputs_train), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)\n\n        scheduler_warmup.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train = [torch.zeros_like(score_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train, score_list_train)\n            score_list_train = torch.cat(score_gather_list_train, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            
mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss = print_train_loss_sup(train_loss, num_batches, print_num, print_num_minus)\n                train_eval_list, train_m_jc = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train, mask_list_train, print_num_minus)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val = Variable(data['image'].cuda())\n                    inputs_val_l1 = Variable(data['image_l1'].cuda())\n                    inputs_val_l2 = Variable(data['image_l2'].cuda())\n                    inputs_val_l3 = Variable(data['image_l3'].cuda())\n                    inputs_val_l4 = Variable(data['image_l4'].cuda())\n                    mask_val = Variable(data['mask'].cuda())\n                    name_val = data['ID']\n\n                    optimizer.zero_grad()\n                    outputs_val = model(inputs_val, inputs_val_l1, inputs_val_l2, inputs_val_l3, inputs_val_l4)\n                    torch.cuda.empty_cache()\n\n                    loss_val = criterion(outputs_val, mask_val)\n                    val_loss += loss_val.item()\n\n                    if i == 0:\n                        score_list_val = outputs_val\n                        mask_list_val = mask_val\n                        name_list_val = name_val\n                    else:\n                        score_list_val = torch.cat((score_list_val, outputs_val), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n                        name_list_val = 
np.append(name_list_val, name_val, axis=0)\n\n                torch.cuda.empty_cache()\n                score_gather_list_val = [torch.zeros_like(score_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val, score_list_val)\n                score_list_val = torch.cat(score_gather_list_val, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n\n                name_gather_list_val = [None for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)\n                name_list_val = np.concatenate(name_gather_list_val, axis=0)\n                torch.cuda.empty_cache()\n\n                if rank == args.rank_index:\n                    val_epoch_loss = print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus)\n                    val_eval_list, val_m_jc = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val, mask_list_val, print_num_minus)\n                    best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model, score_list_val, name_list_val, val_eval_list, path_trained_models, path_seg_results, cfg['PALETTE'], args.network)\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train, mask_val, outputs_train, outputs_val, train_eval_list, val_eval_list)\n                        visualization_sup(visdom, epoch+1, train_epoch_loss, train_m_jc, val_epoch_loss, val_m_jc)\n                        visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3])\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - 
begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, print_num_minus)\n        print('=' * print_num)"
  },
  {
    "path": "train_sup_wds.py",
    "content": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nimport argparse\nimport time\nimport os\nimport numpy as np\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.backends import cudnn\nimport random\n\nfrom config.dataset_config.dataset_cfg import dataset_cfg\nfrom config.train_test_config.train_test_config import print_train_loss_sup, print_val_loss_sup, print_train_eval_sup, print_val_eval_sup, save_val_best_sup_2d, draw_pred_sup, print_best_sup\nfrom config.visdom_config.visual_visdom import visdom_initialization_sup, visualization_sup, visual_image_sup\nfrom config.warmup_config.warmup import GradualWarmupScheduler\nfrom config.augmentation.online_aug import data_transform_2d, data_normalize_2d\nfrom loss.loss_function import segmentation_loss\nfrom models.getnetwork import get_network\nfrom dataload.dataset_2d import imagefloder_wds\nfrom warnings import simplefilter\nsimplefilter(action='ignore', category=FutureWarning)\n\n\ndef init_seeds(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(0)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--path_trained_models', default='/mnt/data1/GeYang_shared/XNet/checkpoints/sup')\n    parser.add_argument('--path_seg_results', default='/mnt/data1/GeYang_shared/XNet/seg_pred/sup')\n    parser.add_argument('--path_dataset', default='/mnt/data1/GeYang_shared/XNet/dataset/CREMI')\n    parser.add_argument('--dataset_name', default='CREMI', help='CREMI, ISIC-2017, GlaS')\n    
parser.add_argument('--sup_mark', default='100')\n    parser.add_argument('-b', '--batch_size', default=32, type=int)\n    parser.add_argument('-e', '--num_epochs', default=200, type=int)\n    parser.add_argument('-s', '--step_size', default=50, type=int)\n    parser.add_argument('-l', '--lr', default=0.5, type=float)\n    parser.add_argument('-g', '--gamma', default=0.5, type=float)\n    parser.add_argument('--loss', default='dice', type=str)\n    parser.add_argument('-w', '--warm_up_duration', default=20)\n    parser.add_argument('--momentum', default=0.9, type=float)\n    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')\n\n    parser.add_argument('-i', '--display_iter', default=5, type=int)\n    parser.add_argument('-n', '--network', default='wds', type=str)\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--rank_index', default=0, help='0, 1, 2, 3')\n    parser.add_argument('-v', '--vis', default=True, help='need visualization or not')\n    parser.add_argument('--visdom_port', default=16672, help='16672')\n    args = parser.parse_args()\n\n    torch.cuda.set_device(args.local_rank)\n    dist.init_process_group(backend='nccl', init_method='env://')\n\n    rank = torch.distributed.get_rank()\n    ngpus_per_node = torch.cuda.device_count()\n    init_seeds(rank + 1)\n\n    dataset_name = args.dataset_name\n    cfg = dataset_cfg(dataset_name)\n\n    print_num = 42 + (cfg['NUM_CLASSES'] - 3) * 7\n    print_num_minus = print_num - 2\n\n    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_trained_models) and rank == args.rank_index:\n        os.mkdir(path_trained_models)\n    path_trained_models = path_trained_models+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)\n    if not 
os.path.exists(path_trained_models) and rank == args.rank_index:\n            os.mkdir(path_trained_models)\n\n    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n    path_seg_results = path_seg_results+'/'+str(args.network)+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark)\n    if not os.path.exists(path_seg_results) and rank == args.rank_index:\n        os.mkdir(path_seg_results)\n\n    if args.vis and rank == args.rank_index:\n        visdom_env = str('Sup-'+str(os.path.split(args.path_dataset)[1])+'-'+args.network+'-l='+str(args.lr)+'-e='+str(args.num_epochs)+'-s='+str(args.step_size)+'-g='+str(args.gamma)+'-b='+str(args.batch_size)+'-w='+str(args.warm_up_duration)+'-'+str(args.sup_mark))\n        visdom = visdom_initialization_sup(env=visdom_env, port=args.visdom_port)\n\n    # Dataset\n    data_transforms = data_transform_2d()\n    data_normalize_LL = data_normalize_2d(cfg['MEAN_LL'], cfg['STD_LL'])\n    data_normalize_LH = data_normalize_2d(cfg['MEAN_LH'], cfg['STD_LH'])\n    data_normalize_HL = data_normalize_2d(cfg['MEAN_HL'], cfg['STD_HL'])\n    data_normalize_HH = data_normalize_2d(cfg['MEAN_HH'], cfg['STD_HH'])\n\n    dataset_train = imagefloder_wds(\n        data_dir=args.path_dataset + '/train_sup_' + args.sup_mark,\n        data_transform_1=data_transforms['train'],\n        data_normalize_LL=data_normalize_LL,\n        data_normalize_LH=data_normalize_LH,\n        data_normalize_HL=data_normalize_HL,\n        data_normalize_HH=data_normalize_HH\n    )\n    dataset_val = imagefloder_wds(\n        data_dir=args.path_dataset + '/val',\n        data_transform_1=data_transforms['val'],\n        data_normalize_LL=data_normalize_LL,\n        data_normalize_LH=data_normalize_LH,\n   
     data_normalize_HL=data_normalize_HL,\n        data_normalize_HH=data_normalize_HH\n    )\n\n    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True)\n    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)\n\n    dataloaders = dict()\n    dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=train_sampler)\n    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8, sampler=val_sampler)\n\n    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}\n\n    # Model\n    model = get_network(args.network, 1, cfg['NUM_CLASSES'])\n    model = model.cuda()\n    model = DistributedDataParallel(model, device_ids=[args.local_rank])\n\n    # Training Strategy\n    criterion = segmentation_loss(args.loss, False).cuda()\n\n    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)\n    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)\n    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)\n\n    # Train & Val\n    since = time.time()\n    count_iter = 0\n    best_val_eval_list = [0 for i in range(4)]\n\n    for epoch in range(args.num_epochs):\n\n        count_iter += 1\n        if (count_iter-1) % args.display_iter == 0:\n            begin_time = time.time()\n\n        dataloaders['train'].sampler.set_epoch(epoch)\n        model.train()\n\n        train_loss = 0.0\n        val_loss = 0.0\n\n        dist.barrier()\n        for i, data in enumerate(dataloaders['train']):\n\n            inputs_train_LL = Variable(data['image_LL'].cuda())\n            inputs_train_LH = Variable(data['image_LH'].cuda())\n            inputs_train_HL = 
Variable(data['image_HL'].cuda())\n            inputs_train_HH = Variable(data['image_HH'].cuda())\n            mask_train = Variable(data['mask'].cuda())\n\n            optimizer.zero_grad()\n            outputs_train = model(inputs_train_LL, inputs_train_LH, inputs_train_HL, inputs_train_HH)\n            torch.cuda.empty_cache()\n\n            loss_train = criterion(outputs_train, mask_train)\n\n            loss_train.backward()\n            optimizer.step()\n            train_loss += loss_train.item()\n\n            if count_iter % args.display_iter == 0:\n                if i == 0:\n                    score_list_train = outputs_train\n                    mask_list_train = mask_train\n                else:\n                # elif 0 < i <= num_batches['train_sup'] / 16:\n                    score_list_train = torch.cat((score_list_train, outputs_train), dim=0)\n                    mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)\n\n        scheduler_warmup.step()\n        torch.cuda.empty_cache()\n\n        if count_iter % args.display_iter == 0:\n\n            score_gather_list_train = [torch.zeros_like(score_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(score_gather_list_train, score_list_train)\n            score_list_train = torch.cat(score_gather_list_train, dim=0)\n\n            mask_gather_list_train = [torch.zeros_like(mask_list_train) for _ in range(ngpus_per_node)]\n            torch.distributed.all_gather(mask_gather_list_train, mask_list_train)\n            mask_list_train = torch.cat(mask_gather_list_train, dim=0)\n\n            if rank == args.rank_index:\n                torch.cuda.empty_cache()\n                print('=' * print_num)\n                print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')\n                train_epoch_loss = print_train_loss_sup(train_loss, num_batches, print_num, print_num_minus)\n                train_eval_list, 
train_m_jc = print_train_eval_sup(cfg['NUM_CLASSES'], score_list_train, mask_list_train, print_num_minus)\n                torch.cuda.empty_cache()\n\n            with torch.no_grad():\n                model.eval()\n\n                for i, data in enumerate(dataloaders['val']):\n\n                    # if 0 <= i <= num_batches['val']:\n\n                    inputs_val_LL = Variable(data['image_LL'].cuda())\n                    inputs_val_LH = Variable(data['image_LH'].cuda())\n                    inputs_val_HL = Variable(data['image_HL'].cuda())\n                    inputs_val_HH = Variable(data['image_HH'].cuda())\n                    mask_val = Variable(data['mask'].cuda())\n                    name_val = data['ID']\n\n                    optimizer.zero_grad()\n                    outputs_val = model(inputs_val_LL, inputs_val_LH, inputs_val_HL, inputs_val_HH)\n                    torch.cuda.empty_cache()\n\n                    loss_val = criterion(outputs_val, mask_val)\n                    val_loss += loss_val.item()\n\n                    if i == 0:\n                        score_list_val = outputs_val\n                        mask_list_val = mask_val\n                        name_list_val = name_val\n                    else:\n                        score_list_val = torch.cat((score_list_val, outputs_val), dim=0)\n                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)\n                        name_list_val = np.append(name_list_val, name_val, axis=0)\n\n                torch.cuda.empty_cache()\n                score_gather_list_val = [torch.zeros_like(score_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(score_gather_list_val, score_list_val)\n                score_list_val = torch.cat(score_gather_list_val, dim=0)\n\n                mask_gather_list_val = [torch.zeros_like(mask_list_val) for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather(mask_gather_list_val, 
mask_list_val)\n                mask_list_val = torch.cat(mask_gather_list_val, dim=0)\n\n                name_gather_list_val = [None for _ in range(ngpus_per_node)]\n                torch.distributed.all_gather_object(name_gather_list_val, name_list_val)\n                name_list_val = np.concatenate(name_gather_list_val, axis=0)\n                torch.cuda.empty_cache()\n\n                if rank == args.rank_index:\n                    val_epoch_loss = print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus)\n                    val_eval_list, val_m_jc = print_val_eval_sup(cfg['NUM_CLASSES'], score_list_val, mask_list_val, print_num_minus)\n                    best_val_eval_list = save_val_best_sup_2d(cfg['NUM_CLASSES'], best_val_eval_list, model, score_list_val, name_list_val, val_eval_list, path_trained_models, path_seg_results, cfg['PALETTE'], args.network)\n                    torch.cuda.empty_cache()\n\n                    if args.vis:\n                        draw_img = draw_pred_sup(cfg['NUM_CLASSES'], mask_train, mask_val, outputs_train, outputs_val, train_eval_list, val_eval_list)\n                        visualization_sup(visdom, epoch+1, train_epoch_loss, train_m_jc, val_epoch_loss, val_m_jc)\n                        visual_image_sup(visdom, draw_img[0], draw_img[1], draw_img[2], draw_img[3])\n\n                    print('-' * print_num)\n                    print('| Epoch Time: {:.4f}s'.format((time.time() - begin_time) / args.display_iter).ljust(print_num_minus, ' '), '|')\n            torch.cuda.empty_cache()\n        torch.cuda.empty_cache()\n\n    if rank == args.rank_index:\n        time_elapsed = time.time() - since\n        m, s = divmod(time_elapsed, 60)\n        h, m = divmod(m, 60)\n\n        print('=' * print_num)\n        print('| Training Completed In {:.0f}h {:.0f}mins {:.0f}s'.format(h, m, s).ljust(print_num_minus, ' '), '|')\n        print('-' * print_num)\n        print_best_sup(cfg['NUM_CLASSES'], best_val_eval_list, 
print_num_minus)\n        print('=' * print_num)"
  }
]