Full Code of Yanfeng-Zhou/XNet for AI

Repository: Yanfeng-Zhou/XNet
Branch: main
Commit: 58181894053f
Files: 108
Total size: 1.1 MB

Directory structure:
XNet/

├── .idea/
│   ├── XNet.iml
│   ├── deployment.xml
│   ├── inspectionProfiles/
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── LICENSE
├── README.md
├── config/
│   ├── __init__.py
│   ├── augmentation/
│   │   ├── __init__.py
│   │   └── online_aug.py
│   ├── dataset_config/
│   │   ├── __init__.py
│   │   └── dataset_cfg.py
│   ├── eval_config/
│   │   ├── __init__.py
│   │   └── eval.py
│   ├── ramps/
│   │   ├── __init__.py
│   │   └── ramps.py
│   ├── train_test_config/
│   │   ├── __init__.py
│   │   └── train_test_config.py
│   ├── visdom_config/
│   │   ├── __init__.py
│   │   └── visual_visdom.py
│   └── warmup_config/
│       ├── __init__.py
│       └── warmup.py
├── dataload/
│   ├── __init__.py
│   ├── dataset_2d.py
│   └── dataset_3d.py
├── loss/
│   ├── __init__.py
│   └── loss_function.py
├── models/
│   ├── __init__.py
│   ├── getnetwork.py
│   ├── networks_2d/
│   │   ├── __init__.py
│   │   ├── aerial_lanenet.py
│   │   ├── hrnet.py
│   │   ├── mwcnn.py
│   │   ├── resunet.py
│   │   ├── resunet_plusplus.py
│   │   ├── swinunet.py
│   │   ├── u2net.py
│   │   ├── unet.py
│   │   ├── unet_3plus.py
│   │   ├── unet_cct.py
│   │   ├── unet_plusplus.py
│   │   ├── unet_urpc.py
│   │   ├── wavesnet.py
│   │   ├── wds.py
│   │   └── xnet.py
│   └── networks_3d/
│       ├── __init__.py
│       ├── conresnet.py
│       ├── cotr.py
│       ├── dmfnet.py
│       ├── espnet3d.py
│       ├── res_unet3d.py
│       ├── transbts.py
│       ├── unet3d.py
│       ├── unet3d_cct.py
│       ├── unet3d_dtc.py
│       ├── unet3d_urpc.py
│       ├── unetr.py
│       ├── vnet.py
│       ├── vnet_cct.py
│       ├── vnet_dtc.py
│       └── xnet3d.py
├── requirements.txt
├── test.py
├── test_3d.py
├── test_ConResNet.py
├── test_DTC.py
├── test_xnet.py
├── test_xnet3d.py
├── tools/
│   ├── Atrial/
│   │   ├── __init__.py
│   │   ├── postprocess.py
│   │   └── preprocess.py
│   ├── LiTS/
│   │   ├── __init__.py
│   │   ├── postprocess.py
│   │   ├── preprocess.py
│   │   └── split_train_val.py
│   ├── __init__.py
│   ├── eval.py
│   ├── mask2sdf.py
│   ├── res_image_mask.py
│   ├── wavelet2D.py
│   └── wavelet3D.py
├── train_semi_CCT.py
├── train_semi_CCT_3d.py
├── train_semi_CPS.py
├── train_semi_CPS_3d.py
├── train_semi_CT.py
├── train_semi_CT_3d.py
├── train_semi_DTC.py
├── train_semi_EM.py
├── train_semi_EM_3d.py
├── train_semi_MT.py
├── train_semi_MT_3d.py
├── train_semi_UAMT.py
├── train_semi_UAMT_3d.py
├── train_semi_URPC.py
├── train_semi_URPC_3d.py
├── train_semi_XNet.py
├── train_semi_XNet3d.py
├── train_sup.py
├── train_sup_3d.py
├── train_sup_ConResNet.py
├── train_sup_XNet.py
├── train_sup_XNet3d.py
├── train_sup_XNet_sb.py
├── train_sup_alnet.py
└── train_sup_wds.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .idea/XNet.iml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="jdk" jdkName="Python 3.7" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="TestRunnerService">
    <option name="PROJECT_TEST_RUNNER" value="Unittests" />
  </component>
</module>

================================================
FILE: .idea/deployment.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="PublishConfigData">
    <serverData>
      <paths name="EM1">
        <serverdata>
          <mappings>
            <mapping deploy="/XNet" local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="EM2">
        <serverdata>
          <mappings>
            <mapping deploy="/XNet" local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="GPU0">
        <serverdata>
          <mappings>
            <mapping deploy="/XNet" local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="GPU4">
        <serverdata>
          <mappings>
            <mapping deploy="/XNet" local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="GPU5">
        <serverdata>
          <mappings>
            <mapping deploy="/XNet" local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="N22">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="N30">
        <serverdata>
          <mappings>
            <mapping deploy="/run/XNet" local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
    </serverData>
  </component>
</project>

================================================
FILE: .idea/inspectionProfiles/profiles_settings.xml
================================================
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>

================================================
FILE: .idea/misc.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="JavaScriptSettings">
    <option name="languageLevel" value="ES6" />
  </component>
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7" project-jdk-type="Python SDK" />
</project>

================================================
FILE: .idea/modules.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/XNet.iml" filepath="$PROJECT_DIR$/.idea/XNet.iml" />
    </modules>
  </component>
</project>

================================================
FILE: .idea/vcs.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$" vcs="Git" />
  </component>
</project>

================================================
FILE: .idea/workspace.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ChangeListManager">
    <list default="true" id="a6084e43-8a3d-4f65-bc86-b2baf8115879" name="Default Changelist" comment="">
      <change afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
      <change afterPath="$PROJECT_DIR$/train_sup_alnet.py" afterDir="false" />
      <change afterPath="$PROJECT_DIR$/train_sup_wds.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/README.md" beforeDir="false" afterPath="$PROJECT_DIR$/README.md" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/config/augmentation/online_aug.py" beforeDir="false" afterPath="$PROJECT_DIR$/config/augmentation/online_aug.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/dataload/dataset_2d.py" beforeDir="false" afterPath="$PROJECT_DIR$/dataload/dataset_2d.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/figure/Figure 3 v11.png" beforeDir="false" />
      <change beforePath="$PROJECT_DIR$/figure/figure 1 v2.png" beforeDir="false" />
      <change beforePath="$PROJECT_DIR$/figure/figure 2.png" beforeDir="false" />
      <change beforePath="$PROJECT_DIR$/models/__init__.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/__init__.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/models/getnetwork.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/getnetwork.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/test_xnet.py" beforeDir="false" afterPath="$PROJECT_DIR$/test_xnet.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/test_xnet3d.py" beforeDir="false" afterPath="$PROJECT_DIR$/test_xnet3d.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/tools/eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/tools/eval.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/tools/wavelet2D.py" beforeDir="false" afterPath="$PROJECT_DIR$/tools/wavelet2D.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/tools/wavelet3D.py" beforeDir="false" afterPath="$PROJECT_DIR$/tools/wavelet3D.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_CCT.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_CCT.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_CCT_3d.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_CCT_3d.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_CPS.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_CPS.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_CPS_3d.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_CPS_3d.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_CT.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_CT.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_CT_3d.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_CT_3d.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_DTC.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_DTC.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_EM.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_EM.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_EM_3d.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_EM_3d.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_MT.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_MT.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_MT_3d.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_MT_3d.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_UAMT.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_UAMT.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_UAMT_3d.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_UAMT_3d.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_URPC.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_URPC.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_URPC_3d.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_URPC_3d.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_XNet.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_XNet.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_semi_XNet3d.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_semi_XNet3d.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_sup.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_sup.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_sup_3d.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_sup_3d.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_sup_ConResNet.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_sup_ConResNet.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_sup_XNet.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_sup_XNet.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_sup_XNet3d.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_sup_XNet3d.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_sup_XNet_EM.py" beforeDir="false" />
      <change beforePath="$PROJECT_DIR$/train_sup_XNet_sb.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_sup_XNet_sb.py" afterDir="false" />
    </list>
    <option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
    <option name="SHOW_DIALOG" value="false" />
    <option name="HIGHLIGHT_CONFLICTS" value="true" />
    <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
    <option name="LAST_RESOLUTION" value="IGNORE" />
  </component>
  <component name="FileTemplateManagerImpl">
    <option name="RECENT_TEMPLATES">
      <list>
        <option value="Python Script" />
      </list>
    </option>
  </component>
  <component name="Git.Settings">
    <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
  </component>
  <component name="ProjectId" id="2ASvhrldQmnvrIuEBF6ewe1Fctz" />
  <component name="ProjectLevelVcsManager" settingsEditedManually="true" />
  <component name="PropertiesComponent">
    <property name="SHARE_PROJECT_CONFIGURATION_FILES" value="true" />
    <property name="WebServerToolWindowFactoryState" value="false" />
    <property name="last_opened_file_path" value="$PROJECT_DIR$" />
    <property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
  </component>
  <component name="RecentsManager">
    <key name="MoveFile.RECENT_KEYS">
      <recent name="D:\Desktop\XNet\tools\LiTS" />
      <recent name="D:\Desktop\XNet\tools\Atrial" />
      <recent name="D:\Desktop\XNet\tools" />
      <recent name="D:\Desktop\XNet\models\2d_networks" />
      <recent name="D:\Desktop\XNet\tools\CREMI" />
    </key>
  </component>
  <component name="RunDashboard">
    <option name="ruleStates">
      <list>
        <RuleState>
          <option name="name" value="ConfigurationTypeDashboardGroupingRule" />
        </RuleState>
        <RuleState>
          <option name="name" value="StatusDashboardGroupingRule" />
        </RuleState>
      </list>
    </option>
  </component>
  <component name="RunManager" selected="Python.demo">
    <configuration name="demo" type="PythonConfigurationType" factoryName="Python" temporary="true">
      <module name="XNet" />
      <option name="INTERPRETER_OPTIONS" value="" />
      <option name="PARENT_ENVS" value="true" />
      <envs>
        <env name="PYTHONUNBUFFERED" value="1" />
      </envs>
      <option name="SDK_HOME" value="" />
      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
      <option name="IS_MODULE_SDK" value="true" />
      <option name="ADD_CONTENT_ROOTS" value="true" />
      <option name="ADD_SOURCE_ROOTS" value="true" />
      <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/demo.py" />
      <option name="PARAMETERS" value="" />
      <option name="SHOW_COMMAND_LINE" value="false" />
      <option name="EMULATE_TERMINAL" value="false" />
      <option name="MODULE_MODE" value="false" />
      <option name="REDIRECT_INPUT" value="false" />
      <option name="INPUT_FILE" value="" />
      <method v="2" />
    </configuration>
    <configuration name="demo1" type="PythonConfigurationType" factoryName="Python" temporary="true">
      <module name="XNet" />
      <option name="INTERPRETER_OPTIONS" value="" />
      <option name="PARENT_ENVS" value="true" />
      <envs>
        <env name="PYTHONUNBUFFERED" value="1" />
      </envs>
      <option name="SDK_HOME" value="" />
      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
      <option name="IS_MODULE_SDK" value="true" />
      <option name="ADD_CONTENT_ROOTS" value="true" />
      <option name="ADD_SOURCE_ROOTS" value="true" />
      <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/demo1.py" />
      <option name="PARAMETERS" value="" />
      <option name="SHOW_COMMAND_LINE" value="false" />
      <option name="EMULATE_TERMINAL" value="false" />
      <option name="MODULE_MODE" value="false" />
      <option name="REDIRECT_INPUT" value="false" />
      <option name="INPUT_FILE" value="" />
      <method v="2" />
    </configuration>
    <configuration name="hrnet" type="PythonConfigurationType" factoryName="Python" temporary="true">
      <module name="XNet" />
      <option name="INTERPRETER_OPTIONS" value="" />
      <option name="PARENT_ENVS" value="true" />
      <envs>
        <env name="PYTHONUNBUFFERED" value="1" />
      </envs>
      <option name="SDK_HOME" value="" />
      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/models/networks_2d" />
      <option name="IS_MODULE_SDK" value="true" />
      <option name="ADD_CONTENT_ROOTS" value="true" />
      <option name="ADD_SOURCE_ROOTS" value="true" />
      <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/models/networks_2d/hrnet.py" />
      <option name="PARAMETERS" value="" />
      <option name="SHOW_COMMAND_LINE" value="false" />
      <option name="EMULATE_TERMINAL" value="false" />
      <option name="MODULE_MODE" value="false" />
      <option name="REDIRECT_INPUT" value="false" />
      <option name="INPUT_FILE" value="" />
      <method v="2" />
    </configuration>
    <configuration name="vis" type="PythonConfigurationType" factoryName="Python" temporary="true">
      <module name="XNet" />
      <option name="INTERPRETER_OPTIONS" value="" />
      <option name="PARENT_ENVS" value="true" />
      <envs>
        <env name="PYTHONUNBUFFERED" value="1" />
      </envs>
      <option name="SDK_HOME" value="" />
      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/tools/Atrial" />
      <option name="IS_MODULE_SDK" value="true" />
      <option name="ADD_CONTENT_ROOTS" value="true" />
      <option name="ADD_SOURCE_ROOTS" value="true" />
      <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/tools/Atrial/vis.py" />
      <option name="PARAMETERS" value="" />
      <option name="SHOW_COMMAND_LINE" value="false" />
      <option name="EMULATE_TERMINAL" value="false" />
      <option name="MODULE_MODE" value="false" />
      <option name="REDIRECT_INPUT" value="false" />
      <option name="INPUT_FILE" value="" />
      <method v="2" />
    </configuration>
    <configuration name="xnet" type="PythonConfigurationType" factoryName="Python" temporary="true">
      <module name="XNet" />
      <option name="INTERPRETER_OPTIONS" value="" />
      <option name="PARENT_ENVS" value="true" />
      <envs>
        <env name="PYTHONUNBUFFERED" value="1" />
      </envs>
      <option name="SDK_HOME" value="" />
      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/models/networks_2d" />
      <option name="IS_MODULE_SDK" value="true" />
      <option name="ADD_CONTENT_ROOTS" value="true" />
      <option name="ADD_SOURCE_ROOTS" value="true" />
      <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/models/networks_2d/xnet.py" />
      <option name="PARAMETERS" value="" />
      <option name="SHOW_COMMAND_LINE" value="false" />
      <option name="EMULATE_TERMINAL" value="false" />
      <option name="MODULE_MODE" value="false" />
      <option name="REDIRECT_INPUT" value="false" />
      <option name="INPUT_FILE" value="" />
      <method v="2" />
    </configuration>
    <recent_temporary>
      <list>
        <item itemvalue="Python.demo" />
        <item itemvalue="Python.hrnet" />
        <item itemvalue="Python.xnet" />
        <item itemvalue="Python.demo1" />
        <item itemvalue="Python.vis" />
      </list>
    </recent_temporary>
  </component>
  <component name="SvnConfiguration">
    <configuration />
  </component>
  <component name="TaskManager">
    <task active="true" id="Default" summary="Default task">
      <changelist id="a6084e43-8a3d-4f65-bc86-b2baf8115879" name="Default Changelist" comment="" />
      <created>1655015922020</created>
      <option name="number" value="Default" />
      <option name="presentableId" value="Default" />
      <updated>1655015922020</updated>
      <workItem from="1655015923192" duration="37990000" />
      <workItem from="1655341600741" duration="188507000" />
      <workItem from="1657943138956" duration="168310000" />
      <workItem from="1660184258941" duration="265576000" />
      <workItem from="1661862792193" duration="353353000" />
      <workItem from="1664521887562" duration="3255000" />
      <workItem from="1665197785788" duration="606000" />
      <workItem from="1665400082786" duration="606000" />
      <workItem from="1665932248031" duration="4039000" />
      <workItem from="1665987707938" duration="6000" />
      <workItem from="1665987723627" duration="1392000" />
      <workItem from="1665996697302" duration="1393000" />
      <workItem from="1666072441496" duration="745000" />
      <workItem from="1676863559005" duration="498000" />
      <workItem from="1690253128423" duration="1477000" />
      <workItem from="1690260775935" duration="683000" />
      <workItem from="1690263220286" duration="1342000" />
    </task>
    <servers />
  </component>
  <component name="TypeScriptGeneratedFilesManager">
    <option name="version" value="1" />
  </component>
  <component name="Vcs.Log.Tabs.Properties">
    <option name="TAB_STATES">
      <map>
        <entry key="MAIN">
          <value>
            <State>
              <option name="COLUMN_ORDER" />
            </State>
          </value>
        </entry>
      </map>
    </option>
  </component>
  <component name="com.intellij.coverage.CoverageDataManagerImpl">
    <SUITE FILE_PATH="coverage/XNet$wavelet_trans.coverage" NAME="wavelet_trans Coverage Results" MODIFIED="1656748230352" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/GlaS" />
    <SUITE FILE_PATH="coverage/XNet$res_image_mask.coverage" NAME="res_image_mask Coverage Results" MODIFIED="1663396533063" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools" />
    <SUITE FILE_PATH="coverage/XNet$wavelet_trans__2_.coverage" NAME="wavelet_trans (2) Coverage Results" MODIFIED="1657178663899" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/ISIC-2017" />
    <SUITE FILE_PATH="coverage/XNet$postprocess.coverage" NAME="postprocess Coverage Results" MODIFIED="1663946921242" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/Atrial" />
    <SUITE FILE_PATH="coverage/XNet$demo1.coverage" NAME="demo1 Coverage Results" MODIFIED="1664027395517" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
    <SUITE FILE_PATH="coverage/XNet$postprocess__1_.coverage" NAME="postprocess (1) Coverage Results" MODIFIED="1663998882582" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/LiTS" />
    <SUITE FILE_PATH="coverage/XNet$xnet3d.coverage" NAME="xnet3d Coverage Results" MODIFIED="1662386984840" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_3d" />
    <SUITE FILE_PATH="coverage/XNet$window_crop.coverage" NAME="window_crop Coverage Results" MODIFIED="1656508185058" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/CREMI" />
    <SUITE FILE_PATH="coverage/XNet$mask2sdf.coverage" NAME="mask2sdf Coverage Results" MODIFIED="1663901733624" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools" />
    <SUITE FILE_PATH="coverage/XNet$preprocess__2_.coverage" NAME="preprocess (2) Coverage Results" MODIFIED="1661672037231" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/CBMI-CA" />
    <SUITE FILE_PATH="coverage/XNet$wavelet3D.coverage" NAME="wavelet3D Coverage Results" MODIFIED="1662713970023" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools" />
    <SUITE FILE_PATH="coverage/XNet$split_train_val.coverage" NAME="split_train_val Coverage Results" MODIFIED="1662220162317" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/CBMI-CA" />
    <SUITE FILE_PATH="coverage/XNet$resunet_plusplus.coverage" NAME="resunet_plusplus Coverage Results" MODIFIED="1660993443784" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_2d" />
    <SUITE FILE_PATH="coverage/XNet$unetr.coverage" NAME="unetr Coverage Results" MODIFIED="1661761635389" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_3d" />
    <SUITE FILE_PATH="coverage/XNet$vnet_cct.coverage" NAME="vnet_cct Coverage Results" MODIFIED="1663554055273" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_3d" />
    <SUITE FILE_PATH="coverage/XNet$mask_color.coverage" NAME="mask_color Coverage Results" MODIFIED="1655357975942" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools" />
    <SUITE FILE_PATH="coverage/XNet$mean_std.coverage" NAME="mean_std Coverage Results" MODIFIED="1657175540046" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/CREMI" />
    <SUITE FILE_PATH="coverage/XNet$split_semi__1_.coverage" NAME="split_semi (1) Coverage Results" MODIFIED="1661613383049" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/Atrial" />
    <SUITE FILE_PATH="coverage/XNet$split_semi.coverage" NAME="split_semi Coverage Results" MODIFIED="1661940512952" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/ISIC-2017" />
    <SUITE FILE_PATH="coverage/XNet$16bitto8bit.coverage" NAME="16bitto8bit Coverage Results" MODIFIED="1655346328619" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools" />
    <SUITE FILE_PATH="coverage/XNet$demo.coverage" NAME="demo Coverage Results" MODIFIED="1664615715686" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
    <SUITE FILE_PATH="coverage/XNet$unet3d_dtc.coverage" NAME="unet3d_dtc Coverage Results" MODIFIED="1663903842933" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_3d" />
    <SUITE FILE_PATH="coverage/XNet$nnunet.coverage" NAME="nnunet Coverage Results" MODIFIED="1661924361071" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_2d" />
    <SUITE FILE_PATH="coverage/XNet$kiunet__1_.coverage" NAME="kiunet (1) Coverage Results" MODIFIED="1661915451342" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_2d" />
    <SUITE FILE_PATH="coverage/XNet$mean_std__2_.coverage" NAME="mean_std (2) Coverage Results" MODIFIED="1656504773540" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/ISIC-2017" />
    <SUITE FILE_PATH="coverage/XNet$preprocess__1_.coverage" NAME="preprocess (1) Coverage Results" MODIFIED="1662224730108" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/LiTS" />
    <SUITE FILE_PATH="coverage/XNet$mask_pro__1_.coverage" NAME="mask_pro (1) Coverage Results" MODIFIED="1656640935024" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/ISIC-2017" />
    <SUITE FILE_PATH="coverage/XNet$bound.coverage" NAME="bound Coverage Results" MODIFIED="1661946279003" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools" />
    <SUITE FILE_PATH="coverage/XNet$cotr.coverage" NAME="cotr Coverage Results" MODIFIED="1661845197933" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_3d" />
    <SUITE FILE_PATH="coverage/XNet$vis.coverage" NAME="vis Coverage Results" MODIFIED="1664027158868" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/Atrial" />
    <SUITE FILE_PATH="coverage/XNet$vnet_dtc.coverage" NAME="vnet_dtc Coverage Results" MODIFIED="1663903484888" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_3d" />
    <SUITE FILE_PATH="coverage/XNet$wavelet_trans__1_.coverage" NAME="wavelet_trans (1) Coverage Results" MODIFIED="1656899007333" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/CREMI" />
    <SUITE FILE_PATH="coverage/XNet$unet_3plus.coverage" NAME="unet_3plus Coverage Results" MODIFIED="1661000368219" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_2d" />
    <SUITE FILE_PATH="coverage/XNet$nnunet3d.coverage" NAME="nnunet3d Coverage Results" MODIFIED="1661925089667" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_3d" />
    <SUITE FILE_PATH="coverage/XNet$split_train_val__1_.coverage" NAME="split_train_val (1) Coverage Results" MODIFIED="1662225861524" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/LiTS" />
    <SUITE FILE_PATH="coverage/XNet$unet_urpc.coverage" NAME="unet_urpc Coverage Results" MODIFIED="1659324102938" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models" />
    <SUITE FILE_PATH="coverage/XNet$u2net.coverage" NAME="u2net Coverage Results" MODIFIED="1661000154317" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_2d" />
    <SUITE FILE_PATH="coverage/XNet$espnet3d.coverage" NAME="espnet3d Coverage Results" MODIFIED="1661921901877" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_3d" />
    <SUITE FILE_PATH="coverage/XNet$transbts.coverage" NAME="transbts Coverage Results" MODIFIED="1661848844623" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_3d" />
    <SUITE FILE_PATH="coverage/XNet$kiunet3d.coverage" NAME="kiunet3d Coverage Results" MODIFIED="1662030942646" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_3d" />
    <SUITE FILE_PATH="coverage/XNet$mask_pro.coverage" NAME="mask_pro Coverage Results" MODIFIED="1656516186672" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/GlaS" />
    <SUITE FILE_PATH="coverage/XNet$unsup_split.coverage" NAME="unsup_split Coverage Results" MODIFIED="1655464356868" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools" />
    <SUITE FILE_PATH="coverage/XNet$dfmnet.coverage" NAME="dfmnet Coverage Results" MODIFIED="1661845257224" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_3d" />
    <SUITE FILE_PATH="coverage/XNet$xnet_single_branch.coverage" NAME="xnet_single_branch Coverage Results" MODIFIED="1656514918760" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models" />
    <SUITE FILE_PATH="coverage/XNet$resunet.coverage" NAME="resunet Coverage Results" MODIFIED="1660993328964" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_2d" />
    <SUITE FILE_PATH="coverage/XNet$conresnet.coverage" NAME="conresnet Coverage Results" MODIFIED="1662087140386" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_3d" />
    <SUITE FILE_PATH="coverage/XNet$swinunet.coverage" NAME="swinunet Coverage Results" MODIFIED="1661924117865" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_2d" />
    <SUITE FILE_PATH="coverage/XNet$xnet.coverage" NAME="xnet Coverage Results" MODIFIED="1664431291705" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_2d" />
    <SUITE FILE_PATH="coverage/XNet$vae_seg.coverage" NAME="vae_seg Coverage Results" MODIFIED="1663923536933" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_3d" />
    <SUITE FILE_PATH="coverage/XNet$split_dataset.coverage" NAME="split_dataset Coverage Results" MODIFIED="1662226197150" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools" />
    <SUITE FILE_PATH="coverage/XNet$vis__1_.coverage" NAME="vis (1) Coverage Results" MODIFIED="1663742524885" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/LiTS" />
    <SUITE FILE_PATH="coverage/XNet$preprocess.coverage" NAME="preprocess Coverage Results" MODIFIED="1662220017369" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/CBMI-CA" />
    <SUITE FILE_PATH="coverage/XNet$mean_std__1_.coverage" NAME="mean_std (1) Coverage Results" MODIFIED="1657175481052" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/tools/ISIC-2017" />
    <SUITE FILE_PATH="coverage/XNet$kiunet.coverage" NAME="kiunet Coverage Results" MODIFIED="1661933613759" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_2d" />
    <SUITE FILE_PATH="coverage/XNet$hrnet.coverage" NAME="hrnet Coverage Results" MODIFIED="1664431322196" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models/networks_2d" />
    <SUITE FILE_PATH="coverage/XNet$unet_cct.coverage" NAME="unet_cct Coverage Results" MODIFIED="1659324460896" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/models" />
  </component>
</project>

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2024 Yanfeng Zhou

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================

# XNet: Wavelet-Based Low and High Frequency Fusion Networks for Fully- and Semi-Supervised Semantic Segmentation of Biomedical Images

This is the official code of [XNet: Wavelet-Based Low and High Frequency Fusion Networks for Fully- and Semi-Supervised Semantic Segmentation of Biomedical Images](https://openaccess.thecvf.com/content/ICCV2023/html/Zhou_XNet_Wavelet-Based_Low_and_High_Frequency_Fusion_Networks_for_Fully-_ICCV_2023_paper.html) (ICCV 2023).

## Overview
<p align="center">
<img src="https://github.com/Yanfeng-Zhou/XNet/blob/main/figure/Architecture%20of%20XNet.png" width="100%" ></img>
<br>Architecture of XNet.
</p>
<p align="center">
<img src="https://github.com/Yanfeng-Zhou/XNet/blob/main/figure/visualize%20LF%20and%20HF%20images.png" width="100%" ></img>
<br>Visualization of the dual-branch inputs. (a) Raw image. (b) Wavelet transform results. (c) Low-frequency image. (d) High-frequency image.
</p>

<p align="center">
<img src="https://github.com/Yanfeng-Zhou/XNet/blob/main/figure/Architecture%20of%20LF%20and%20HF%20fusion%20module.png" width="50%" ></img>
<br>Architecture of LF and HF fusion module.
</p>


## Quantitative Comparison

Comparison with fully- and semi-supervised state-of-the-art models on the GlaS and CREMI test sets. Semi-supervised models are based on UNet. DS indicates deep supervision. * indicates lightweight models. ‡ indicates training for 1000 epochs. - indicates training failed. <font color="Red">**Red**</font> and **bold** indicate the best and second best performance.

<p align="center">
<img src="https://github.com/Yanfeng-Zhou/XNet/blob/main/figure/Comparison%20results%20on%20GlaS%20and%20CREMI.png" width="100%" >
</p>

Comparison with fully- and semi-supervised state-of-the-art models on the LA and LiTS test sets. Due to GPU memory limitations, some semi-supervised models use smaller architectures: ✝ and * indicate models based on a lightweight 3D UNet (half the channels) and VNet, respectively. ‡ indicates training for 1000 epochs. - indicates training failed. <font color="Red">**Red**</font> and **bold** indicate the best and second best performance.

<p align="center">
<img src="https://github.com/Yanfeng-Zhou/XNet/blob/main/figure/Comparison%20results%20on%20LA%20and%20P-CT.png" width="100%" >
</p>

## Qualitative Comparison

<p align="center">
<img src="https://github.com/Yanfeng-Zhou/XNet/blob/main/figure/Qualitative%20results.png" width="100%" >
<br>Qualitative results on GlaS, CREMI, LA and LiTS. (a) Raw images. (b) Ground truth. (c) MT. (d) Semi-supervised XNet (3D XNet). (e) UNet (3D UNet). (f) Fully-supervised XNet (3D XNet). The orange arrows highlight the differences among the results.
</p>

## Reimplemented Architecture
We have reimplemented a number of 2D and 3D models for fully- and semi-supervised semantic segmentation.
<table>
<tr><th align="left">Method</th> <th align="left">Dimension</th><th align="left">Model</th><th align="left">Code</th></tr>
<tr><td rowspan="23">Supervised</td> <td rowspan="13">2D</td><td>UNet</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/unet.py">models/networks_2d/unet.py</a></td></tr>
<tr><td>UNet++</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/unet_plusplus.py">models/networks_2d/unet_plusplus.py</a></td></tr>
<tr><td>Att-UNet</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/unet.py">models/networks_2d/unet.py</a></td></tr>
<tr><td>Aerial LaneNet</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/aerial_lanenet.py">models/networks_2d/aerial_lanenet.py</a></td></tr>
<tr><td>MWCNN</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/mwcnn.py">models/networks_2d/mwcnn.py</a></td></tr>
<tr><td>HRNet</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/hrnet.py">models/networks_2d/hrnet.py</a></td></tr>
<tr><td>Res-UNet</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/resunet.py">models/networks_2d/resunet.py</a></td></tr>
<tr><td>WDS</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/wds.py">models/networks_2d/wds.py</a></td></tr>
<tr><td>U<sup>2</sup>-Net</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/u2net.py">models/networks_2d/u2net.py</a></td></tr>
<tr><td>UNet 3+</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/unet_3plus.py">models/networks_2d/unet_3plus.py</a></td></tr>
<tr><td>SwinUNet</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/swinunet.py">models/networks_2d/swinunet.py</a></td></tr>
<tr><td>WaveSNet</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/wavesnet.py">models/networks_2d/wavesnet.py</a></td></tr>
<tr><td>XNet (Ours)</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_2d/xnet.py">models/networks_2d/xnet.py</a></td></tr>
<tr><td rowspan="10">3D</td><td>VNet</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/vnet.py">models/networks_3d/vnet.py</a></td></tr>
<tr><td>UNet 3D</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/unet3d.py">models/networks_3d/unet3d.py</a></td></tr>
<tr><td>Res-UNet 3D</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/res_unet3d.py">models/networks_3d/res_unet3d.py</a></td></tr>
<tr><td>ESPNet 3D</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/espnet3d.py">models/networks_3d/espnet3d.py</a></td></tr>
<tr><td>DMFNet 3D</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/dmfnet.py">models/networks_3d/dmfnet.py</a></td></tr>
<tr><td>ConResNet</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/conresnet.py">models/networks_3d/conresnet.py</a></td></tr>
<tr><td>CoTr</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/cotr.py">models/networks_3d/cotr.py</a></td></tr>
<tr><td>TransBTS</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/transbts.py">models/networks_3d/transbts.py</a></td></tr>
<tr><td>UNETR</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/unetr.py">models/networks_3d/unetr.py</a></td></tr>
<tr><td>XNet 3D (Ours)</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/models/networks_3d/xnet3d.py">models/networks_3d/xnet3d.py</a></td></tr>
<tr><td rowspan="17">Semi-Supervised</td> <td rowspan="8">2D</td><td>MT</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_MT.py">train_semi_MT.py</a></td></tr>
<tr><td>EM</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_EM.py">train_semi_EM.py</a></td></tr>
<tr><td>UAMT</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_UAMT.py">train_semi_UAMT.py</a></td></tr>
<tr><td>CCT</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_CCT.py">train_semi_CCT.py</a></td></tr>
<tr><td>CPS</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_CPS.py">train_semi_CPS.py</a></td></tr>
<tr><td>URPC</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_URPC.py">train_semi_URPC.py</a></td></tr>
<tr><td>CT</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_CT.py">train_semi_CT.py</a></td></tr>
<tr><td>XNet (Ours)</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_XNet.py">train_semi_XNet.py</a></td></tr>
<tr><td rowspan="9">3D</td><td>MT</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_MT_3d.py">train_semi_MT_3d.py</a></td></tr>
<tr><td>EM</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_EM_3d.py">train_semi_EM_3d.py</a></td></tr>
<tr><td>UAMT</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_UAMT_3d.py">train_semi_UAMT_3d.py</a></td></tr>
<tr><td>CCT</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_CCT_3d.py">train_semi_CCT_3d.py</a></td></tr>
<tr><td>CPS</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_CPS_3d.py">train_semi_CPS_3d.py</a></td></tr>
<tr><td>URPC</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_URPC_3d.py">train_semi_URPC_3d.py</a></td></tr>
<tr><td>CT</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_CT_3d.py">train_semi_CT_3d.py</a></td></tr>
<tr><td>DTC</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_DTC.py">train_semi_DTC.py</a></td></tr>
<tr><td>XNet 3D (Ours)</td><td><a href="https://github.com/Yanfeng-Zhou/XNet/blob/main/train_semi_XNet3d.py">train_semi_XNet3d.py</a></td></tr>
</table>

## Requirements
```
albumentations==0.5.2
einops==0.4.1
MedPy==0.4.0
numpy==1.20.2
opencv_python==4.2.0.34
opencv_python_headless==4.5.1.48
Pillow==8.0.0
PyWavelets==1.1.1
scikit_image==0.18.1
scikit_learn==1.0.1
scipy==1.4.1
SimpleITK==2.1.0
timm==0.6.7
torch==1.8.0+cu111
torchio==0.18.53
torchvision==0.9.0+cu111
tqdm==4.65.0
visdom==0.1.8.9
```

## Usage
**Data preparation**
Your dataset directory tree should look like this:
>See [tools/wavelet2D.py](https://github.com/Yanfeng-Zhou/XNet/blob/main/tools/wavelet2D.py) and [tools/wavelet3D.py](https://github.com/Yanfeng-Zhou/XNet/blob/main/tools/wavelet3D.py) for generating **L** and **H**; a rough sketch of the decomposition follows the tree below.
```
dataset
├── train_sup_100
    ├── L
        ├── 1.tif
        ├── 2.tif
        └── ...
    ├── H
        ├── 1.tif
        ├── 2.tif
        └── ...
    └── mask
        ├── 1.tif
        ├── 2.tif
        └── ...
├── train_sup_20
    ├── L
    ├── H
    └── mask
├── train_unsup_80
    ├── L
    └── H
└── val
    ├── L
    ├── H
    └── mask
```
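The **L** and **H** folders hold the offline wavelet decompositions of the raw images. As a rough, hypothetical sketch of that decomposition (the repository's actual preprocessing lives in [tools/wavelet2D.py](https://github.com/Yanfeng-Zhou/XNet/blob/main/tools/wavelet2D.py); the file names and rescaling below are illustrative), a single-level 2D DWT with PyWavelets can build a low-frequency image from the approximation band and a high-frequency image from the detail bands:
```python
# Hypothetical sketch: build L (low-frequency) and H (high-frequency) images with a 2D DWT.
# The db2 wavelet matches the *_DB2_H statistics in config/dataset_config/dataset_cfg.py.
import cv2
import numpy as np
import pywt

img = cv2.imread('1.tif', cv2.IMREAD_GRAYSCALE).astype(np.float32)

# Single-level 2D discrete wavelet transform.
LL, (LH, HL, HH) = pywt.dwt2(img, 'db2')

# Low-frequency image: keep the approximation band, zero the detail bands.
low = pywt.idwt2((LL, (None, None, None)), 'db2')
# High-frequency image: zero the approximation band, keep the detail bands.
high = pywt.idwt2((None, (LH, HL, HH)), 'db2')

def to_uint8(x):
    # Rescale to 8-bit so the result can be written as a .tif input.
    x = (x - x.min()) / (x.max() - x.min() + 1e-8)
    return (255 * x).astype(np.uint8)

cv2.imwrite('L/1.tif', to_uint8(low))
cv2.imwrite('H/1.tif', to_uint8(high))
```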
**Supervised training**
```
python -m torch.distributed.launch --nproc_per_node=4 train_sup_XNet.py
```
**Semi-supervised training**
```
python -m torch.distributed.launch --nproc_per_node=4 train_semi_XNet.py
```
**Testing**
```
python -m torch.distributed.launch --nproc_per_node=4 test.py
```
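The three commands above assume a 4-GPU node. For a quick single-GPU run, the same launcher can be used with a single process, e.g.:
```
python -m torch.distributed.launch --nproc_per_node=1 train_sup_XNet.py
```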

## Citation
If our work is useful for your research, please cite our paper:
```
@InProceedings{Zhou_2023_ICCV,
  author = {Zhou, Yanfeng and Huang, Jiaxing and Wang, Chenlong and Song, Le and Yang, Ge}, 
  title = {XNet: Wavelet-Based Low and High Frequency Fusion Networks for Fully- and Semi-Supervised Semantic Segmentation of Biomedical Images}, 
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 
  month = {October}, 
  year = {2023}, 
  pages = {21085-21096}
  }
```





================================================
FILE: config/__init__.py
================================================


================================================
FILE: config/augmentation/__init__.py
================================================


================================================
FILE: config/augmentation/online_aug.py
================================================
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchio import transforms as T
import torchio as tio

def data_transform_2d():
    data_transforms = {
        'train': A.Compose([
            A.Resize(128, 128, p=1),
            A.Flip(p=0.75),
            A.Transpose(p=0.5),
            A.RandomRotate90(p=1),
        ],
            additional_targets={'image2': 'image', 'mask2': 'mask'}
        ),
        'val': A.Compose([
            A.Resize(128, 128, p=1),
        ],
            additional_targets={'image2': 'image', 'mask2': 'mask'}
        ),
        'test': A.Compose([
            A.Resize(128, 128, p=1),
        ],
            additional_targets={'image2': 'image', 'mask2': 'mask'}
        )
    }
    return data_transforms


def data_normalize_2d(mean, std):
    data_normalize = A.Compose([
            A.Normalize(mean, std),
            ToTensorV2()
        ],
            additional_targets={'image2': 'image', 'mask2': 'mask'}
    )
    return data_normalize

def data_transform_aerial_lanenet(H, W):
    data_transforms = A.Compose([
            A.Resize(H, W, p=1),
            ToTensorV2()
        ])
    return data_transforms


def data_transform_3d(normalization):
    data_transform = {
        'train': T.Compose([
            T.RandomFlip(),
            T.RandomBiasField(coefficients=(0.12, 0.15), order=2, p=0.2),
            T.OneOf({
               T.RandomNoise(): 0.5,
               T.RandomBlur(std=1): 0.5,
            }, p=0.2),
            T.ZNormalization(masking_method=normalization),
        ]),
        'val': T.Compose([
            # T.CropOrPad(pad_size),
            T.ZNormalization(masking_method=normalization),
            # T.Resize(target_shape=(512, 512, 512), p=1)
        ]),
        'test': T.Compose([
            # T.CropOrPad(pad_size),
            T.ZNormalization(masking_method=normalization),
            # T.Resize(target_shape=(512, 512, 512), p=1)
        ])
    }

    return data_transform
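
# A minimal, hypothetical usage sketch of the 2D pipelines above (array shapes and values
# are made up; in the repository they are wired up through dataload/dataset_2d.py). The
# additional_targets mapping makes 'image2'/'mask2' receive exactly the same random
# flip/rotate as 'image'/'mask', keeping the low- and high-frequency branches aligned:
#
#     import numpy as np
#
#     transforms = data_transform_2d()['train']   # albumentations Compose defined above
#
#     low = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)   # low-frequency image
#     high = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)  # high-frequency image
#     mask = np.random.randint(0, 2, (256, 256), dtype=np.uint8)
#
#     # The same random spatial transform is applied to image, image2, mask and mask2.
#     out = transforms(image=low, image2=high, mask=mask, mask2=mask)
#     low_aug, high_aug, mask_aug = out['image'], out['image2'], out['mask']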

================================================
FILE: config/dataset_config/__init__.py
================================================


================================================
FILE: config/dataset_config/dataset_cfg.py
================================================
import numpy as np
import torchio as tio

def dataset_cfg(dataset_name):

    config = {
        'CREMI':
            {
                'IN_CHANNELS': 1,
                'NUM_CLASSES': 2,
                'MEAN': [0.503902],
                'STD': [0.110739],
                'MEAN_DB2_H': [0.505787],
                'STD_DB2_H': [0.115504],
                'PALETTE': list(np.array([
                    [255, 255, 255],
                    [0, 0, 0],
                ]).flatten())
            },
        'GlaS':
            {
                'IN_CHANNELS': 3,
                'NUM_CLASSES': 2,
                'MEAN': [0.787803, 0.512017, 0.784938],
                'STD': [0.428206, 0.507778, 0.426366],
                'MEAN_HAAR_H': [0.528318],
                'STD_HAAR_H': [0.076766],
                'MEAN_HAAR_L': [0.579144],
                'STD_HAAR_L': [0.227451],
                'MEAN_HAAR_HHL': [0.542428],
                'STD_HAAR_HHL': [0.142663],
                'MEAN_HAAR_HLL': [0.569150],
                'STD_HAAR_HLL': [0.220854],
                'MEAN_BIOR1.5_H': [0.525711],
                'STD_BIOR1.5_H': [0.076606],
                'MEAN_BIOR2.4_H': [0.516579],
                'STD_BIOR2.4_H': [0.078798],
                'MEAN_COIF1_H': [0.523858],
                'STD_COIF1_H': [0.081001],
                'MEAN_DB2_H': [0.505234],
                'STD_DB2_H': [0.080919],
                'MEAN_DMEY_H': [0.502698],
                'STD_DMEY_H': [0.078861],
                'PALETTE': list(np.array([
                    [0, 0, 0],
                    [255, 255, 255],
                ]).flatten())
            },
        'ISIC-2017':
            {
                'IN_CHANNELS': 3,
                'NUM_CLASSES': 2,
                'MEAN': [0.699002, 0.556046, 0.512134],
                'STD': [0.365650, 0.317347, 0.339400],
                'MEAN_DB2_H': [0.489676],
                'STD_DB2_H': [0.081749],
                'PALETTE': list(np.array([
                    [0, 0, 0],
                    [255, 255, 255],
                ]).flatten())
            },
        'LiTS':
            {
                'IN_CHANNELS': 1,
                'NUM_CLASSES': 3,
                'NORMALIZE': tio.ZNormalization.mean,
                'PATCH_SIZE': (112, 112, 32),
                'FORMAT': '.nii',
                'NUM_SAMPLE_TRAIN': 8,
                'NUM_SAMPLE_VAL': 12
            },
        'Atrial':
            {
                'IN_CHANNELS': 1,
                'NUM_CLASSES': 2,
                'NORMALIZE': tio.ZNormalization.mean,
                'PATCH_SIZE': (96, 96, 80),
                'FORMAT': '.nrrd',
                'NUM_SAMPLE_TRAIN': 4,
                'NUM_SAMPLE_VAL': 8
            },
    }

    return config[dataset_name]
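

# --- Illustrative usage sketch (only existing config entries are read) ---
# The configuration is a plain nested dict; a caller typically looks a dataset up by name
# and reads its fields:
if __name__ == '__main__':
    cfg = dataset_cfg('GlaS')
    print(cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
    print(cfg['MEAN'], cfg['STD'])      # channel statistics of the raw RGB images
    print(len(cfg['PALETTE']))          # flattened RGB palette used when saving predicted masks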


================================================
FILE: config/eval_config/__init__.py
================================================


================================================
FILE: config/eval_config/eval.py
================================================
import numpy as np
from sklearn.metrics import confusion_matrix
from scipy.spatial.distance import directed_hausdorff
import torch


def evaluate(y_scores, y_true, interval=0.02):

    y_scores = torch.softmax(y_scores, dim=1)
    y_scores = y_scores[:, 1, ...].cpu().detach().numpy().flatten()
    y_true = y_true.data.cpu().numpy().flatten()

    thresholds = np.arange(0, 0.9, interval)
    jaccard = np.zeros(len(thresholds))
    dice = np.zeros(len(thresholds))
    y_true = y_true.astype(np.int8)

    for indy in range(len(thresholds)):
        threshold = thresholds[indy]
        y_pred = (y_scores > threshold).astype(np.int8)

        sum_area = (y_pred + y_true)
        tp = float(np.sum(sum_area == 2))            # pixels where prediction and ground truth are both 1
        union = np.sum(sum_area == 1)                # pixels where exactly one of them is 1 (FP + FN)
        jaccard[indy] = tp / float(union + tp)       # TP / (TP + FP + FN)
        dice[indy] = 2 * tp / float(union + 2 * tp)  # 2TP / (2TP + FP + FN)

    thred_indx = np.argmax(jaccard)
    m_jaccard = jaccard[thred_indx]
    m_dice = dice[thred_indx]

    return thresholds[thred_indx], m_jaccard, m_dice



def evaluate_multi(y_scores, y_true):

    y_scores = torch.softmax(y_scores, dim=1)
    y_pred = torch.max(y_scores, 1)[1]
    y_pred = y_pred.data.cpu().numpy().flatten()
    y_true = y_true.data.cpu().numpy().flatten()

    hist = confusion_matrix(y_true, y_pred)

    hist_diag = np.diag(hist)
    hist_sum_0 = hist.sum(axis=0)
    hist_sum_1 = hist.sum(axis=1)

    jaccard = hist_diag / (hist_sum_1 + hist_sum_0 - hist_diag)
    m_jaccard = np.nanmean(jaccard)
    dice = 2 * hist_diag / (hist_sum_1 + hist_sum_0)
    m_dice = np.nanmean(dice)

    return jaccard, m_jaccard, dice, m_dice
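

# --- Illustrative usage sketch (hypothetical shapes; random tensors stand in for network outputs) ---
# evaluate() sweeps binarization thresholds on the foreground probability and keeps the one with the
# best Jaccard; evaluate_multi() computes per-class Jaccard/Dice from a confusion matrix.
if __name__ == '__main__':
    torch.manual_seed(0)

    logits = torch.randn(2, 2, 64, 64)            # (batch, classes, H, W) raw logits, binary case
    labels = torch.randint(0, 2, (2, 64, 64))     # integer ground-truth mask
    thr, m_jc, m_dc = evaluate(logits, labels)
    print('best threshold {:.2f}, Jc {:.4f}, Dc {:.4f}'.format(thr, m_jc, m_dc))

    logits3 = torch.randn(2, 3, 64, 64)           # three-class case
    labels3 = torch.randint(0, 3, (2, 64, 64))
    jaccard, m_jc, dice, m_dc = evaluate_multi(logits3, labels3)
    print('per-class Jc', jaccard, 'mean Jc {:.4f}'.format(m_jc))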






================================================
FILE: config/ramps/__init__.py
================================================


================================================
FILE: config/ramps/ramps.py
================================================
import numpy as np


def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242"""
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase))


def linear_rampup(current, rampup_length):
    """Linear rampup"""
    assert current >= 0 and rampup_length >= 0
    if current >= rampup_length:
        return 1.0
    else:
        return current / rampup_length


def cosine_rampdown(current, rampdown_length):
    """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
    assert 0 <= current <= rampdown_length
    return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))
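

# --- Illustrative usage sketch (hypothetical weight and epoch counts) ---
# These ramps are typically used to scale the unsupervised/consistency loss weight during
# semi-supervised training, e.g. weight(t) = max_weight * sigmoid_rampup(t, rampup_length):
if __name__ == '__main__':
    max_unsup_weight = 1.0
    rampup_epochs = 50
    for epoch in (0, 10, 25, 50, 80):
        w = max_unsup_weight * sigmoid_rampup(epoch, rampup_epochs)
        print('epoch {:3d}: unsup weight {:.4f}'.format(epoch, w))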


================================================
FILE: config/train_test_config/__init__.py
================================================


================================================
FILE: config/train_test_config/train_test_config.py
================================================
import numpy as np
from config.eval_config.eval import evaluate, evaluate_multi
import torch
import os
from PIL import Image
import torchio as tio

def print_train_loss_sup(train_loss, num_batches, print_num, print_num_minus):
    train_epoch_loss = train_loss / num_batches['train_sup']
    print('-' * print_num)
    print('| Train Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_minus, ' '), '|')
    print('-' * print_num)
    return train_epoch_loss

def print_train_loss_MT(train_loss_sup_1, train_loss_cps, train_loss, num_batches, print_num, print_num_half, print_num_minus):
    train_epoch_loss_sup1 = train_loss_sup_1 / num_batches['train_sup']
    train_epoch_loss_cps = train_loss_cps / num_batches['train_sup']
    train_epoch_loss = train_loss / num_batches['train_sup']
    print('-' * print_num)
    print('| Train  Sup  Loss: {:.4f}'.format(train_epoch_loss_sup1).ljust(print_num_half, ' '), '| Train Unsup Loss: {:.4f}'.format(train_epoch_loss_cps).ljust(print_num_half, ' '), '|')
    print('| Train Total Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_minus, ' '), '|')
    print('-' * print_num)
    return train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss

def print_train_loss_ConResNet(train_loss_seg, train_loss_res, train_loss, num_batches, print_num, print_num_half, print_num_minus):
    train_epoch_loss_seg = train_loss_seg / num_batches['train_sup']
    train_epoch_loss_res = train_loss_res / num_batches['train_sup']
    train_epoch_loss = train_loss / num_batches['train_sup']
    print('-' * print_num)
    print('| Train  Seg  Loss: {:.4f}'.format(train_epoch_loss_seg).ljust(print_num_half, ' '), '| Train Res Loss: {:.4f}'.format(train_epoch_loss_res).ljust(print_num_half, ' '), '|')
    print('| Train Total Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_minus, ' '), '|')
    print('-' * print_num)
    return train_epoch_loss_seg, train_epoch_loss_res, train_epoch_loss


def print_train_loss_EM(train_loss_sup_1, train_loss_cps, train_loss, num_batches, print_num, print_num_minus):
    train_epoch_loss_sup1 = train_loss_sup_1 / num_batches['train_sup']
    train_epoch_loss_cps = train_loss_cps / num_batches['train_sup']
    train_epoch_loss = train_loss / num_batches['train_sup']
    print('-' * print_num)
    print('| Train  Sup  Loss: {:.4f}'.format(train_epoch_loss_sup1).ljust(print_num_minus, ' '), '|')
    print('| Train Unsup Loss: {:.4f}'.format(train_epoch_loss_cps).ljust(print_num_minus, ' '), '|')
    print('| Train Total Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_minus, ' '), '|')
    print('-' * print_num)
    return train_epoch_loss_sup1, train_epoch_loss_cps, train_epoch_loss


def print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_cps, train_loss, num_batches, print_num, print_num_half):
    train_epoch_loss_sup1 = train_loss_sup_1 / num_batches['train_sup']
    train_epoch_loss_sup2 = train_loss_sup_2 / num_batches['train_sup']
    train_epoch_loss_cps = train_loss_cps / num_batches['train_sup']
    train_epoch_loss = train_loss / num_batches['train_sup']
    print('-' * print_num)
    print('| Train Sup Loss 1: {:.4f}'.format(train_epoch_loss_sup1).ljust(print_num_half, ' '), '| Train Sup Loss 2: {:.4f}'.format(train_epoch_loss_sup2).ljust(print_num_half, ' '), '|')
    print('| Train Unsup Loss: {:.4f}'.format(train_epoch_loss_cps).ljust(print_num_half, ' '), '| Train Total Loss: {:.4f}'.format(train_epoch_loss).ljust(print_num_half, ' '), '|')
    print('-' * print_num)
    return train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss

def print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus):
    val_epoch_loss = val_loss / num_batches['val']
    print('-' * print_num)
    print('| Val Loss: {:.4f}'.format(val_epoch_loss).ljust(print_num_minus, ' '), '|')
    print('-' * print_num)
    return val_epoch_loss

def print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_num, print_num_half):
    val_epoch_loss_sup1 = val_loss_sup_1 / num_batches['val']
    val_epoch_loss_sup2 = val_loss_sup_2 / num_batches['val']
    print('-' * print_num)
    print('| Val Sup Loss 1: {:.4f}'.format(val_epoch_loss_sup1).ljust(print_num_half, ' '), '| Val Sup Loss 2: {:.4f}'.format(val_epoch_loss_sup2).ljust(print_num_half, ' '), '|')
    print('-' * print_num)
    return val_epoch_loss_sup1, val_epoch_loss_sup2

def print_val_loss_ConResNet(val_loss_seg, val_loss_res, num_batches, print_num, print_num_half):
    val_epoch_loss_seg = val_loss_seg / num_batches['val']
    val_epoch_loss_res = val_loss_res / num_batches['val']
    print('-' * print_num)
    print('| Val Seg Loss: {:.4f}'.format(val_epoch_loss_seg).ljust(print_num_half, ' '), '| Val Res Loss: {:.4f}'.format(val_epoch_loss_res).ljust(print_num_half, ' '), '|')
    print('-' * print_num)
    return val_epoch_loss_seg, val_epoch_loss_res

def print_train_eval_sup(num_classes, score_list_train, mask_list_train, print_num):

    if num_classes == 2:
        eval_list = evaluate(score_list_train, mask_list_train)
        print('| Train Thr: {:.4f}'.format(eval_list[0]).ljust(print_num, ' '), '|')
        print('| Train  Jc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
        print('| Train  Dc: {:.4f}'.format(eval_list[2]).ljust(print_num, ' '), '|')
        train_m_jc = eval_list[1]

    else:
        eval_list = evaluate_multi(score_list_train, mask_list_train)

        np.set_printoptions(precision=4, suppress=True)
        print('| Train  Jc: {}'.format(eval_list[0]).ljust(print_num, ' '), '|')
        print('| Train  Dc: {}'.format(eval_list[2]).ljust(print_num, ' '), '|')
        print('| Train mJc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
        print('| Train mDc: {:.4f}'.format(eval_list[3]).ljust(print_num, ' '), '|')
        train_m_jc = eval_list[1]

    return eval_list, train_m_jc

def print_train_eval_XNet(num_classes, score_list_train1, score_list_train2, mask_list_train, print_num):

    if num_classes == 2:
        eval_list1 = evaluate(score_list_train1, mask_list_train)
        eval_list2 = evaluate(score_list_train2, mask_list_train)
        print('| Train Thr 1: {:.4f}'.format(eval_list1[0]).ljust(print_num, ' '), '| Train Thr 2: {:.4f}'.format(eval_list2[0]).ljust(print_num, ' '), '|')
        print('| Train  Jc 1: {:.4f}'.format(eval_list1[1]).ljust(print_num, ' '), '| Train  Jc 2: {:.4f}'.format(eval_list2[1]).ljust(print_num, ' '), '|')
        print('| Train  Dc 1: {:.4f}'.format(eval_list1[2]).ljust(print_num, ' '), '| Train  Dc 2: {:.4f}'.format(eval_list2[2]).ljust(print_num, ' '), '|')
        train_m_jc1 = eval_list1[1]
        train_m_jc2 = eval_list2[1]
    else:
        eval_list1 = evaluate_multi(score_list_train1, mask_list_train)
        eval_list2 = evaluate_multi(score_list_train2, mask_list_train)
        np.set_printoptions(precision=4, suppress=True)
        print('| Train  Jc 1: {}'.format(eval_list1[0]).ljust(print_num, ' '), '| Train  Jc 2: {}'.format(eval_list2[0]).ljust(print_num, ' '), '|')
        print('| Train  Dc 1: {}'.format(eval_list1[2]).ljust(print_num, ' '), '| Train  Dc 2: {}'.format(eval_list2[2]).ljust(print_num, ' '), '|')
        print('| Train mJc 1: {:.4f}'.format(eval_list1[1]).ljust(print_num, ' '), '| Train mJc 2: {:.4f}'.format(eval_list2[1]).ljust(print_num, ' '), '|')
        print('| Train mDc 1: {:.4f}'.format(eval_list1[3]).ljust(print_num, ' '), '| Train mDc 2: {:.4f}'.format(eval_list2[3]).ljust(print_num, ' '), '|')
        train_m_jc1 = eval_list1[1]
        train_m_jc2 = eval_list2[1]

    return eval_list1, eval_list2, train_m_jc1, train_m_jc2

def print_val_eval_sup(num_classes, score_list_val, mask_list_val, print_num):
    if num_classes == 2:
        eval_list = evaluate(score_list_val, mask_list_val)
        print('| Val Thr: {:.4f}'.format(eval_list[0]).ljust(print_num, ' '), '|')
        print('| Val  Jc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
        print('| Val  Dc: {:.4f}'.format(eval_list[2]).ljust(print_num, ' '), '|')
        val_m_jc = eval_list[1]
    else:
        eval_list = evaluate_multi(score_list_val, mask_list_val)
        np.set_printoptions(precision=4, suppress=True)
        print('| Val  Jc: {}  '.format(eval_list[0]).ljust(print_num, ' '), '|')
        print('| Val  Dc: {}  '.format(eval_list[2]).ljust(print_num, ' '), '|')
        print('| Val mJc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
        print('| Val mDc: {:.4f}'.format(eval_list[3]).ljust(print_num, ' '), '|')
        val_m_jc = eval_list[1]
    return eval_list, val_m_jc

def print_val_eval(num_classes, score_list_val1, score_list_val2, mask_list_val, print_num):
    if num_classes == 2:
        eval_list1 = evaluate(score_list_val1, mask_list_val)
        eval_list2 = evaluate(score_list_val2, mask_list_val)
        print('| Val Thr 1: {:.4f}'.format(eval_list1[0]).ljust(print_num, ' '), '| Val Thr 2: {:.4f}'.format(eval_list2[0]).ljust(print_num, ' '), '|')
        print('| Val  Jc 1: {:.4f}'.format(eval_list1[1]).ljust(print_num, ' '), '| Val  Jc 2: {:.4f}'.format(eval_list2[1]).ljust(print_num, ' '), '|')
        print('| Val  Dc 1: {:.4f}'.format(eval_list1[2]).ljust(print_num, ' '), '| Val  Dc 2: {:.4f}'.format(eval_list2[2]).ljust(print_num, ' '), '|')
        val_m_jc1 = eval_list1[1]
        val_m_jc2 = eval_list2[1]
    else:
        eval_list1 = evaluate_multi(score_list_val1, mask_list_val)
        eval_list2 = evaluate_multi(score_list_val2, mask_list_val)
        np.set_printoptions(precision=4, suppress=True)
        print('| Val  Jc 1: {}  '.format(eval_list1[0]).ljust(print_num, ' '), '| Val  Jc 2: {}'.format(eval_list2[0]).ljust(print_num, ' '), '|')
        print('| Val  Dc 1: {}  '.format(eval_list1[2]).ljust(print_num, ' '), '| Val  Dc 2: {}'.format(eval_list2[2]).ljust(print_num, ' '), '|')
        print('| Val mJc 1: {:.4f}'.format(eval_list1[1]).ljust(print_num, ' '), '| Val mJc 2: {:.4f}'.format(eval_list2[1]).ljust(print_num, ' '), '|')
        print('| Val mDc 1: {:.4f}'.format(eval_list1[3]).ljust(print_num, ' '), '| Val mDc 2: {:.4f}'.format(eval_list2[3]).ljust(print_num, ' '), '|')
        val_m_jc1 = eval_list1[1]
        val_m_jc2 = eval_list2[1]
    return eval_list1, eval_list2, val_m_jc1, val_m_jc2

def save_val_best_sup_2d(num_classes, best_list, model, score_list_val, name_list_val, eval_list, path_trained_model, path_seg_results, palette, model_name):

    if num_classes == 2:
        if best_list[1] < eval_list[1]:
            best_list = eval_list

            torch.save(model.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format(model_name, best_list[1])))

            score_list_val = torch.softmax(score_list_val, dim=1)
            pred_results = score_list_val[:, 1, :, :].cpu().numpy()
            pred_results[pred_results > eval_list[0]] = 1
            pred_results[pred_results <= eval_list[0]] = 0

            assert len(name_list_val) == pred_results.shape[0]
            for i in range(len(name_list_val)):
                color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
                color_results.putpalette(palette)
                color_results.save(os.path.join(path_seg_results, name_list_val[i]))

    else:
        if best_list[1] < eval_list[1]:
            best_list = eval_list

            torch.save(model.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format(model_name, best_list[1])))

            pred_results = torch.max(score_list_val, 1)[1]
            pred_results = pred_results.cpu().numpy()

            assert len(name_list_val) == pred_results.shape[0]
            for i in range(len(name_list_val)):
                color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
                color_results.putpalette(palette)
                color_results.save(os.path.join(path_seg_results, name_list_val[i]))

    return best_list

def save_val_best_sup_3d(num_classes, best_list, model, score_list_val, mask_list_val, eval_list, path_trained_model, path_seg_results, path_mask_results, model_name, format):

    if num_classes == 2:
        if best_list[1] < eval_list[1]:
            best_list = eval_list

            torch.save(model.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format(model_name, best_list[1])))

    else:
        if best_list[1] < eval_list[1]:
            best_list = eval_list

            torch.save(model.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format(model_name, best_list[1])))

    return best_list

def save_val_best_2d(num_classes, best_model, best_list, best_result, model1, model2, score_list_val_1, score_list_val_2, name_list_val, eval_list_1, eval_list_2, path_trained_model, path_seg_results, palette):

    if eval_list_1[1] < eval_list_2[1]:
        if best_list[1] < eval_list_2[1]:

            best_model = model2
            best_list = eval_list_2
            best_result = 'Result2'

            torch.save(model2.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format('result2', best_list[1])))

            if num_classes == 2:
                score_list_val_2 = torch.softmax(score_list_val_2, dim=1)
                pred_results = score_list_val_2[:, 1, ...].cpu().numpy()
                pred_results[pred_results > eval_list_2[0]] = 1
                pred_results[pred_results <= eval_list_2[0]] = 0
            else:
                pred_results = torch.max(score_list_val_2, 1)[1]
                pred_results = pred_results.cpu().numpy()

            assert len(name_list_val) == pred_results.shape[0]
            for i in range(len(name_list_val)):
                color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
                color_results.putpalette(palette)
                color_results.save(os.path.join(path_seg_results, name_list_val[i]))

    else:
        if best_list[1] < eval_list_1[1]:

            best_model = model1
            best_list = eval_list_1
            best_result = 'Result1'

            torch.save(model1.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format('result1', best_list[1])))

            if num_classes == 2:
                score_list_val_1 = torch.softmax(score_list_val_1, dim=1)
                pred_results = score_list_val_1[:, 1, ...].cpu().numpy()
                pred_results[pred_results > eval_list_1[0]] = 1
                pred_results[pred_results <= eval_list_1[0]] = 0
            else:
                pred_results = torch.max(score_list_val_1, 1)[1]
                pred_results = pred_results.cpu().numpy()

            assert len(name_list_val) == pred_results.shape[0]
            for i in range(len(name_list_val)):
                color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
                color_results.putpalette(palette)
                color_results.save(os.path.join(path_seg_results, name_list_val[i]))


    return best_list, best_model, best_result


def save_val_best_3d(num_classes, best_model, best_list, best_result, model1, model2, score_list_val_1, score_list_val_2, mask_list_val, eval_list_1, eval_list_2, path_trained_model, path_seg_results, path_mask_results, format):

    if eval_list_1[1] < eval_list_2[1]:
        if best_list[1] < eval_list_2[1]:

            best_model = model2
            best_list = eval_list_2
            best_result = 'Result2'

            torch.save(model2.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format('result2', best_list[1])))


    else:
        if best_list[1] < eval_list_1[1]:

            best_model = model1
            best_list = eval_list_1
            best_result = 'Result1'

            torch.save(model1.state_dict(), os.path.join(path_trained_model, 'best_{}_Jc_{:.4f}.pth'.format('result1', best_list[1])))


    return best_list, best_model, best_result

def draw_pred_sup(num_classes, mask_train_sup, mask_val, pred_train_sup, outputs_val, train_eval_list, val_eval_list):


    mask_image_train_sup = mask_train_sup[0, :, :].data.cpu().numpy()
    mask_image_val = mask_val[0, :, :].data.cpu().numpy()

    if num_classes == 2:
        pred_image_train_sup = pred_train_sup[0, 1, :, :].data.cpu().numpy()
        pred_image_train_sup[pred_image_train_sup > train_eval_list[0]] = 1
        pred_image_train_sup[pred_image_train_sup <= train_eval_list[0]] = 0

        pred_image_val = outputs_val[0, 1, :, :].data.cpu().numpy()
        pred_image_val[pred_image_val > val_eval_list[0]] = 1
        pred_image_val[pred_image_val <= val_eval_list[0]] = 0

    else:
        pred_image_train_sup = torch.max(pred_train_sup, 1)[1]
        pred_image_train_sup = pred_image_train_sup[0, :, :].cpu().numpy()

        pred_image_val = torch.max(outputs_val, 1)[1]
        pred_image_val = pred_image_val[0, :, :].cpu().numpy()

    return mask_image_train_sup, pred_image_train_sup, mask_image_val, pred_image_val


def draw_pred_XNet(num_classes, mask_train, mask_val, pred_train_sup1, pred_train_sup2, outputs_val1, outputs_val2, train_eval_list1, train_eval_list2, val_eval_list1, val_eval_list2):


    mask_image_train_sup = mask_train[0, :, :].data.cpu().numpy()
    mask_image_val = mask_val[0, :, :].data.cpu().numpy()

    if num_classes == 2:

        pred_image_train_sup1 = pred_train_sup1[0, 1, :, :].data.cpu().numpy()
        pred_image_train_sup1[pred_image_train_sup1 > train_eval_list1[0]] = 1
        pred_image_train_sup1[pred_image_train_sup1 <= train_eval_list1[0]] = 0

        pred_image_train_sup2 = pred_train_sup2[0, 1, :, :].data.cpu().numpy()
        pred_image_train_sup2[pred_image_train_sup2 > train_eval_list2[0]] = 1
        pred_image_train_sup2[pred_image_train_sup2 <= train_eval_list2[0]] = 0

        pred_image_val1 = outputs_val1[0, 1, :, :].data.cpu().numpy()
        pred_image_val1[pred_image_val1 > val_eval_list1[0]] = 1
        pred_image_val1[pred_image_val1 <= val_eval_list1[0]] = 0

        pred_image_val2 = outputs_val2[0, 1, :, :].data.cpu().numpy()
        pred_image_val2[pred_image_val2 > val_eval_list2[0]] = 1
        pred_image_val2[pred_image_val2 <= val_eval_list2[0]] = 0
    else:

        pred_image_train_sup1 = torch.max(pred_train_sup1, 1)[1]
        pred_image_train_sup1 = pred_image_train_sup1[0, :, :].cpu().numpy()

        pred_image_train_sup2 = torch.max(pred_train_sup2, 1)[1]
        pred_image_train_sup2 = pred_image_train_sup2[0, :, :].cpu().numpy()

        pred_image_val1 = torch.max(outputs_val1, 1)[1]
        pred_image_val1 = pred_image_val1[0, :, :].cpu().numpy()

        pred_image_val2 = torch.max(outputs_val2, 1)[1]
        pred_image_val2 = pred_image_val2[0, :, :].cpu().numpy()

    return mask_image_train_sup, pred_image_train_sup1, pred_image_train_sup2, mask_image_val, pred_image_val1, pred_image_val2

def draw_pred_MT(num_classes, mask_train, mask_val, pred_train_sup1, outputs_val1, outputs_val2, train_eval_list1, val_eval_list1, val_eval_list2):


    mask_image_train_sup = mask_train[0, :, :].data.cpu().numpy()
    mask_image_val = mask_val[0, :, :].data.cpu().numpy()

    if num_classes == 2:

        pred_image_train_sup1 = pred_train_sup1[0, 1, :, :].data.cpu().numpy()
        pred_image_train_sup1[pred_image_train_sup1 > train_eval_list1[0]] = 1
        pred_image_train_sup1[pred_image_train_sup1 <= train_eval_list1[0]] = 0

        pred_image_val1 = outputs_val1[0, 1, :, :].data.cpu().numpy()
        pred_image_val1[pred_image_val1 > val_eval_list1[0]] = 1
        pred_image_val1[pred_image_val1 <= val_eval_list1[0]] = 0

        pred_image_val2 = outputs_val2[0, 1, :, :].data.cpu().numpy()
        pred_image_val2[pred_image_val2 > val_eval_list2[0]] = 1
        pred_image_val2[pred_image_val2 <= val_eval_list2[0]] = 0
    else:

        pred_image_train_sup1 = torch.max(pred_train_sup1, 1)[1]
        pred_image_train_sup1 = pred_image_train_sup1[0, :, :].cpu().numpy()

        pred_image_val1 = torch.max(outputs_val1, 1)[1]
        pred_image_val1 = pred_image_val1[0, :, :].cpu().numpy()

        pred_image_val2 = torch.max(outputs_val2, 1)[1]
        pred_image_val2 = pred_image_val2[0, :, :].cpu().numpy()

    return mask_image_train_sup, pred_image_train_sup1, mask_image_val, pred_image_val1, pred_image_val2


def print_best_sup(num_classes, best_val_list, print_num):
    if num_classes == 2:
        print('| Best Val Thr: {:.4f}'.format(best_val_list[0]).ljust(print_num, ' '), '|')
        print('| Best Val  Jc: {:.4f}'.format(best_val_list[1]).ljust(print_num, ' '), '|')
        print('| Best Val  Dc: {:.4f}'.format(best_val_list[2]).ljust(print_num, ' '), '|')
    else:
        np.set_printoptions(precision=4, suppress=True)
        print('| Best Val  Jc: {}'.format(best_val_list[0]).ljust(print_num, ' '), '|')
        print('| Best Val  Dc: {}'.format(best_val_list[2]).ljust(print_num, ' '), '|')
        print('| Best Val mJc: {:.4f}'.format(best_val_list[1]).ljust(print_num, ' '), '|')
        print('| Best Val mDc: {:.4f}'.format(best_val_list[3]).ljust(print_num, ' '), '|')

def print_best(num_classes, best_val_list, best_model, best_result, path_trained_model, print_num):
    if num_classes == 2:

        torch.save(best_model.state_dict(), os.path.join(path_trained_model, 'best_Jc_{:.4f}.pth'.format(best_val_list[1])))

        print('| Best  Result: {}'.format(best_result).ljust(print_num, ' '), '|')
        print('| Best Val Thr: {:.4f}'.format(best_val_list[0]).ljust(print_num, ' '), '|')
        print('| Best Val  Jc: {:.4f}'.format(best_val_list[1]).ljust(print_num, ' '), '|')
        print('| Best Val  Dc: {:.4f}'.format(best_val_list[2]).ljust(print_num, ' '), '|')
    else:

        torch.save(best_model.state_dict(), os.path.join(path_trained_model, 'best_Jc_{:.4f}.pth'.format(best_val_list[1])))

        np.set_printoptions(precision=4, suppress=True)
        print('| Best  Result: {}'.format(best_result).ljust(print_num, ' '), '|')
        print('| Best Val  Jc: {}'.format(best_val_list[0]).ljust(print_num, ' '), '|')
        print('| Best Val  Dc: {}'.format(best_val_list[2]).ljust(print_num, ' '), '|')
        print('| Best Val mJc: {:.4f}'.format(best_val_list[1]).ljust(print_num, ' '), '|')
        print('| Best Val mDc: {:.4f}'.format(best_val_list[3]).ljust(print_num, ' '), '|')

def print_test_eval(num_classes, score_list_test, mask_list_test, print_num):
    if num_classes == 2:
        eval_list = evaluate(score_list_test, mask_list_test)
        print('| Test Thr: {:.4f}'.format(eval_list[0]).ljust(print_num, ' '), '|')
        print('| Test  Jc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
        print('| Test  Dc: {:.4f}'.format(eval_list[2]).ljust(print_num, ' '), '|')
    else:
        eval_list = evaluate_multi(score_list_test, mask_list_test)
        np.set_printoptions(precision=4, suppress=True)
        print('| Test  Jc: {}  '.format(eval_list[0]).ljust(print_num, ' '), '|')
        print('| Test  Dc: {}  '.format(eval_list[2]).ljust(print_num, ' '), '|')
        print('| Test mJc: {:.4f}'.format(eval_list[1]).ljust(print_num, ' '), '|')
        print('| Test mDc: {:.4f}'.format(eval_list[3]).ljust(print_num, ' '), '|')

    return eval_list


def save_test_2d(num_classes, score_list_test, name_list_test, threshold, path_seg_results, palette):

    if num_classes == 2:
        score_list_test = torch.softmax(score_list_test, dim=1)
        pred_results = score_list_test[:, 1, ...].cpu().numpy()
        pred_results[pred_results > threshold] = 1
        pred_results[pred_results <= threshold] = 0

        assert len(name_list_test) == pred_results.shape[0]

        for i in range(len(name_list_test)):
            color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
            color_results.putpalette(palette)
            color_results.save(os.path.join(path_seg_results, name_list_test[i]))

    else:
        pred_results = torch.max(score_list_test, 1)[1]
        pred_results = pred_results.cpu().numpy()

        assert len(name_list_test) == pred_results.shape[0]

        for i in range(len(name_list_test)):
            color_results = Image.fromarray(pred_results[i].astype(np.uint8), mode='P')
            color_results.putpalette(palette)
            color_results.save(os.path.join(path_seg_results, name_list_test[i]))

def save_test_3d(num_classes, score_test, name_test, threshold, path_seg_results, affine):

    if num_classes == 2:
        score_list_test = torch.softmax(score_test, dim=0)
        pred_results = score_list_test[1, ...].cpu()
        pred_results[pred_results > threshold] = 1
        pred_results[pred_results <= threshold] = 0

        pred_results = pred_results.type(torch.uint8)

        output_image = tio.ScalarImage(tensor=pred_results.unsqueeze(0), affine=affine)
        output_image.save(os.path.join(path_seg_results, name_test))

    else:
        pred_results = torch.max(score_test, 0)[1]
        pred_results = pred_results.cpu()
        pred_results = pred_results.type(torch.uint8)

        output_image = tio.ScalarImage(tensor=pred_results.unsqueeze(0), affine=affine)
        output_image.save(os.path.join(path_seg_results, name_test))
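

# --- Illustrative usage sketch (hypothetical numbers, only to show the expected arguments) ---
# The print helpers need the accumulated loss, a dict of batch counts and the column widths used
# for the table-style console output:
if __name__ == '__main__':
    num_batches = {'train_sup': 100, 'val': 20}
    print_train_loss_sup(train_loss=42.0, num_batches=num_batches, print_num=77, print_num_minus=75)
    print_val_loss_sup(val_loss=6.0, num_batches=num_batches, print_num=77, print_num_minus=75)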





================================================
FILE: config/visdom_config/__init__.py
================================================


================================================
FILE: config/visdom_config/visual_visdom.py
================================================
from visdom import Visdom
import os

def visdom_initialization_sup(env, port):
    visdom = Visdom(env=env, port=port)
    visdom.line([0.], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss'], width=550, height=350))
    visdom.line([0.], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc'], width=550, height=350))
    visdom.line([0.], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Loss'], width=550, height=350))
    visdom.line([0.], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc'], width=550, height=350))
    return visdom

def visualization_sup(vis, epoch, train_loss, train_m_jc, val_loss, val_m_jc):
    vis.line([train_loss], [epoch], win='train_loss', update='append')
    vis.line([train_m_jc], [epoch], win='train_jc', update='append')
    vis.line([val_loss], [epoch], win='val_loss', update='append')
    vis.line([val_m_jc], [epoch], win='val_jc', update='append')

def visual_image_sup(vis, mask_train, pred_train, mask_val, pred_val):

    vis.heatmap(mask_train, win='train_mask', opts=dict(title='Train Mask', colormap='Viridis'))
    vis.heatmap(pred_train, win='train_pred1', opts=dict(title='Train Pred', colormap='Viridis'))
    vis.heatmap(mask_val, win='val_mask', opts=dict(title='Val Mask', colormap='Viridis'))
    vis.heatmap(pred_val, win='val_pred1', opts=dict(title='Val Pred', colormap='Viridis'))


def visdom_initialization_XNet(env, port):
    visdom = Visdom(env=env, port=port)
    visdom.line([[0., 0., 0., 0.]], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss', 'Train Sup1', 'Train Sup2', 'Train Unsup'], width=550, height=350))
    visdom.line([[0., 0.]], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc1', 'Train Jc2'], width=550, height=350))
    visdom.line([[0., 0.]], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Sup1', 'Val Sup2'], width=550, height=350))
    visdom.line([[0., 0.]], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc1', 'Val Jc2'], width=550, height=350))
    return visdom

def visualization_XNet(vis, epoch, train_loss, train_loss_sup1, train_loss_sup2, train_loss_cps, train_m_jc1, train_m_jc2, val_loss_sup1, val_loss_sup2, val_m_jc1, val_m_jc2):
    vis.line([[train_loss, train_loss_sup1, train_loss_sup2, train_loss_cps]], [epoch], win='train_loss', update='append')
    vis.line([[train_m_jc1, train_m_jc2]], [epoch], win='train_jc', update='append')
    vis.line([[val_loss_sup1, val_loss_sup2]], [epoch], win='val_loss', update='append')
    vis.line([[val_m_jc1, val_m_jc2]], [epoch], win='val_jc', update='append')

def visual_image_XNet(vis, mask_train, pred_train1, pred_train2, mask_val, pred_val1, pred_val2):

    vis.heatmap(mask_train, win='train_mask', opts=dict(title='Train Mask', colormap='Viridis'))
    vis.heatmap(pred_train1, win='train_pred1', opts=dict(title='Train Pred1', colormap='Viridis'))
    vis.heatmap(pred_train2, win='train_pred2', opts=dict(title='Train Pred2', colormap='Viridis'))

    vis.heatmap(mask_val, win='val_mask', opts=dict(title='Val Mask', colormap='Viridis'))
    vis.heatmap(pred_val1, win='val_pred1', opts=dict(title='Val Pred1', colormap='Viridis'))
    vis.heatmap(pred_val2, win='val_pred2', opts=dict(title='Val Pred2', colormap='Viridis'))


def visdom_initialization_MT(env, port):
    visdom = Visdom(env=env, port=port)
    visdom.line([[0., 0., 0.]], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss', 'Train Sup', 'Train Unsup'], width=550, height=350))
    visdom.line([0.], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc'], width=550, height=350))
    visdom.line([[0., 0.]], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Sup1', 'Val Sup2'], width=550, height=350))
    visdom.line([[0., 0.]], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc1', 'Val Jc2'], width=550, height=350))
    return visdom

def visualization_MT(vis, epoch, train_loss, train_loss_sup1, train_loss_cps, train_m_jc1, val_loss_sup1, val_loss_sup2, val_m_jc1, val_m_jc2):
    vis.line([[train_loss, train_loss_sup1, train_loss_cps]], [epoch], win='train_loss', update='append')
    vis.line([train_m_jc1], [epoch], win='train_jc', update='append')
    vis.line([[val_loss_sup1, val_loss_sup2]], [epoch], win='val_loss', update='append')
    vis.line([[val_m_jc1, val_m_jc2]], [epoch], win='val_jc', update='append')

def visual_image_MT(vis, mask_train, pred_train1, mask_val, pred_val1, pred_val2):

    vis.heatmap(mask_train, win='train_mask', opts=dict(title='Train Mask', colormap='Viridis'))
    vis.heatmap(pred_train1, win='train_pred1', opts=dict(title='Train Pred', colormap='Viridis'))
    vis.heatmap(mask_val, win='val_mask', opts=dict(title='Val Mask', colormap='Viridis'))
    vis.heatmap(pred_val1, win='val_pred1', opts=dict(title='Val Pred1', colormap='Viridis'))
    vis.heatmap(pred_val2, win='val_pred2', opts=dict(title='Val Pred2', colormap='Viridis'))


def visdom_initialization_EM(env, port):
    visdom = Visdom(env=env, port=port)
    visdom.line([[0., 0., 0.]], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss', 'Train Sup', 'Train Unsup'], width=550, height=350))
    visdom.line([0.], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc'], width=550, height=350))
    visdom.line([0.], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Sup'], width=550, height=350))
    visdom.line([0.], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc'], width=550, height=350))
    return visdom

def visualization_EM(vis, epoch, train_loss, train_loss_sup1, train_loss_cps, train_m_jc1, val_loss_sup1, val_m_jc1):
    vis.line([[train_loss, train_loss_sup1, train_loss_cps]], [epoch], win='train_loss', update='append')
    vis.line([train_m_jc1], [epoch], win='train_jc', update='append')
    vis.line([val_loss_sup1], [epoch], win='val_loss', update='append')
    vis.line([val_m_jc1], [epoch], win='val_jc', update='append')


def visdom_initialization_ConResNet(env, port):
    visdom = Visdom(env=env, port=port)
    visdom.line([[0., 0., 0.]], [0.], win='train_loss', opts=dict(title='Train Loss', xlabel='Epoch', ylabel='Train Loss', legend=['Train Loss', 'Train Seg', 'Train Res'], width=550, height=350))
    visdom.line([0.], [0.], win='train_jc', opts=dict(title='Train Jc', xlabel='Epoch', ylabel='Train Jc', legend=['Train Jc'], width=550, height=350))
    visdom.line([[0., 0.]], [0.], win='val_loss', opts=dict(title='Val Loss', xlabel='Epoch', ylabel='Val Loss', legend=['Val Seg', 'Val Res'], width=550, height=350))
    visdom.line([0.], [0.], win='val_jc', opts=dict(title='Val Jc', xlabel='Epoch', ylabel='Val Jc', legend=['Val Jc'], width=550, height=350))
    return visdom

def visualization_ConResNet(vis, epoch, train_loss, train_loss_seg, train_loss_res, train_m_jc1, val_loss_seg, val_loss_res, val_m_jc1):
    vis.line([[train_loss, train_loss_seg, train_loss_res]], [epoch], win='train_loss', update='append')
    vis.line([train_m_jc1], [epoch], win='train_jc', update='append')
    vis.line([[val_loss_seg, val_loss_res]], [epoch], win='val_loss', update='append')
    vis.line([val_m_jc1], [epoch], win='val_jc', update='append')
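

# --- Illustrative usage sketch (assumes a Visdom server is already running, e.g. started with
#     `python -m visdom.server`; the env name, port and metric values below are hypothetical) ---
# Expected call order for the supervised helpers: initialise the windows once, then append one
# point per epoch.
if __name__ == '__main__':
    vis = visdom_initialization_sup(env='demo-sup', port=8097)
    for epoch in range(1, 4):
        visualization_sup(vis, epoch,
                          train_loss=1.0 / epoch, train_m_jc=0.5 + 0.1 * epoch,
                          val_loss=1.2 / epoch, val_m_jc=0.45 + 0.1 * epoch)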

================================================
FILE: config/warmup_config/__init__.py
================================================


================================================
FILE: config/warmup_config/warmup.py
================================================
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
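

# --- Illustrative usage sketch (hypothetical model, learning rate and epoch counts) ---
# A minimal sketch of wrapping an optimizer: warm up the learning rate for `total_epoch` epochs
# (linearly, since multiplier=1.0), then hand over to the `after_scheduler` (a cosine schedule here):
if __name__ == '__main__':
    import torch
    from torch.optim.lr_scheduler import CosineAnnealingLR

    model = torch.nn.Linear(4, 2)                      # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    cosine = CosineAnnealingLR(optimizer, T_max=80)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=20, after_scheduler=cosine)

    for epoch in range(100):
        optimizer.step()                               # stands in for one epoch of training
        scheduler.step()
        if epoch % 20 == 0:
            print(epoch, optimizer.param_groups[0]['lr'])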

================================================
FILE: dataload/__init__.py
================================================


================================================
FILE: dataload/dataset_2d.py
================================================
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import cv2
import numpy as np
import pywt

class dataset_itn(Dataset):
    def __init__(self, data_dir, input1, augmentation_1, normalize_1, sup=True, num_images=None, **kwargs):
        super(dataset_itn, self).__init__()

        img_paths_1 = []
        mask_paths = []

        image_dir_1 = data_dir + '/' + input1
        if sup:
            mask_dir = data_dir + '/mask'

        for image in os.listdir(image_dir_1):

            image_path_1 = os.path.join(image_dir_1, image)
            img_paths_1.append(image_path_1)

            if sup:
                mask_path = os.path.join(mask_dir, image)
                mask_paths.append(mask_path)

        if sup:
            assert len(img_paths_1) == len(mask_paths)

        if num_images is not None:
            len_img_paths = len(img_paths_1)
            quotient = num_images // len_img_paths
            remainder = num_images % len_img_paths

            if num_images <= len_img_paths:
                img_paths_1 = img_paths_1[:num_images]
            else:
                rand_indices = torch.randperm(len_img_paths).tolist()
                new_indices = rand_indices[:remainder]

                img_paths_1 = img_paths_1 * quotient
                img_paths_1 += [img_paths_1[i] for i in new_indices]

                if sup:
                    mask_paths = mask_paths * quotient
                    mask_paths += [mask_paths[i] for i in new_indices]

        self.img_paths_1 = img_paths_1
        self.mask_paths = mask_paths
        self.augmentation_1 = augmentation_1
        self.normalize_1 = normalize_1
        self.sup = sup
        self.kwargs = kwargs

    def __getitem__(self, index):

        img_path_1 = self.img_paths_1[index]
        img_1 = Image.open(img_path_1)
        img_1 = np.array(img_1)

        if self.sup:
            mask_path = self.mask_paths[index]
            mask = Image.open(mask_path)
            mask = np.array(mask)

            augment_1 = self.augmentation_1(image=img_1, mask=mask)
            img_1 = augment_1['image']
            mask_1 = augment_1['mask']

            normalize_1 = self.normalize_1(image=img_1, mask=mask_1)
            img_1 = normalize_1['image']
            mask_1 = normalize_1['mask']
            mask_1 = mask_1.long()

            sample = {'image': img_1, 'mask': mask_1, 'ID': os.path.split(mask_path)[1]}

        else:
            augment_1 = self.augmentation_1(image=img_1)
            img_1 = augment_1['image']
            normalize_1 = self.normalize_1(image=img_1)
            img_1 = normalize_1['image']

            sample = {'image': img_1, 'ID': os.path.split(img_path_1)[1]}

        return sample

    def __len__(self):
        return len(self.img_paths_1)


def imagefloder_itn(data_dir, input1, data_transform_1, data_normalize_1, sup=True, num_images=None, **kwargs):
    dataset = dataset_itn(data_dir=data_dir,
                           input1=input1,
                           augmentation_1=data_transform_1,
                           normalize_1=data_normalize_1,
                           sup=sup,
                           num_images=num_images,
                           **kwargs
                           )
    return dataset
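

# --- Illustrative usage sketch (hypothetical path, size and statistics) ---
# How the single-input factory is typically paired with a DataLoader; it assumes a layout of
# <data_dir>/<input1>/ and <data_dir>/mask/ containing identically named image/mask files:
#
#     import albumentations as A
#     from albumentations.pytorch import ToTensorV2
#
#     transform = A.Compose([A.Resize(128, 128, p=1)])
#     normalize = A.Compose([A.Normalize([0.5], [0.25]), ToTensorV2()])
#     dataset = imagefloder_itn(data_dir='/path/to/dataset/train', input1='image',
#                               data_transform_1=transform, data_normalize_1=normalize, sup=True)
#     loader = DataLoader(dataset, batch_size=4, shuffle=True)
#     batch = next(iter(loader))   # batch['image'], batch['mask'], batch['ID']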


class dataset_iitnn(Dataset):
    def __init__(self, data_dir, input1, input2, augmentation1, normalize_1, normalize_2, sup=True,
                 num_images=None, **kwargs):
        super(dataset_iitnn, self).__init__()

        img_paths_1 = []
        img_paths_2 = []
        mask_paths = []

        image_dir_1 = data_dir + '/' + input1
        image_dir_2 = data_dir + '/' + input2
        if sup:
            mask_dir = data_dir + '/mask'

        for image in os.listdir(image_dir_1):

            image_path_1 = os.path.join(image_dir_1, image)
            img_paths_1.append(image_path_1)

            image_path_2 = os.path.join(image_dir_2, image)
            img_paths_2.append(image_path_2)

            if sup:
                mask_path = os.path.join(mask_dir, image)
                mask_paths.append(mask_path)

        assert len(img_paths_1) == len(img_paths_2)
        if sup:
            assert len(img_paths_1) == len(mask_paths)

        if num_images is not None:
            len_img_paths = len(img_paths_1)
            quotient = num_images // len_img_paths
            remainder = num_images % len_img_paths

            if num_images <= len_img_paths:
                img_paths_1 = img_paths_1[:num_images]
                img_paths_2 = img_paths_2[:num_images]
            else:
                rand_indices = torch.randperm(len_img_paths).tolist()
                new_indices = rand_indices[:remainder]

                img_paths_1 = img_paths_1 * quotient
                img_paths_1 += [img_paths_1[i] for i in new_indices]
                img_paths_2 = img_paths_2 * quotient
                img_paths_2 += [img_paths_2[i] for i in new_indices]

                if sup:
                    mask_paths = mask_paths * quotient
                    mask_paths += [mask_paths[i] for i in new_indices]

        self.img_paths_1 = img_paths_1
        self.img_paths_2 = img_paths_2
        self.mask_paths = mask_paths
        self.augmentation_1 = augmentation1
        self.normalize_1 = normalize_1
        self.normalize_2 = normalize_2
        self.sup = sup
        self.kwargs = kwargs

    def __getitem__(self, index):

        img_path_1 = self.img_paths_1[index]
        img_1 = Image.open(img_path_1)
        img_1 = np.array(img_1)

        img_path_2 = self.img_paths_2[index]
        img_2 = Image.open(img_path_2)
        img_2 = np.array(img_2)

        if self.sup:
            mask_path = self.mask_paths[index]
            mask = Image.open(mask_path)
            mask = np.array(mask)

            augment_1 = self.augmentation_1(image=img_1, image2=img_2, mask=mask)
            img_1 = augment_1['image']
            img_2 = augment_1['image2']
            mask = augment_1['mask']

            normalize_1 = self.normalize_1(image=img_1, mask=mask)
            img_1 = normalize_1['image']
            mask = normalize_1['mask']
            mask = mask.long()

            normalize_2 = self.normalize_2(image=img_2)
            img_2 = normalize_2['image']

            sample = {'image': img_1, 'image_2': img_2, 'mask': mask, 'ID': os.path.split(mask_path)[1]}

        else:
            augment_1 = self.augmentation_1(image=img_1, image2=img_2)
            img_1 = augment_1['image']
            img_2 = augment_1['image2']

            normalize_1 = self.normalize_1(image=img_1)
            img_1 = normalize_1['image']

            normalize_2 = self.normalize_2(image=img_2)
            img_2 = normalize_2['image']

            sample = {'image': img_1, 'image_2': img_2, 'ID': os.path.split(img_path_1)[1]}

        return sample

    def __len__(self):
        return len(self.img_paths_1)


def imagefloder_iitnn(data_dir, input1, input2, data_transform_1, data_normalize_1, data_normalize_2, sup=True, num_images=None, **kwargs):
    dataset = dataset_iitnn(data_dir=data_dir,
                           input1=input1,
                           input2=input2,
                           augmentation1=data_transform_1,
                           normalize_1=data_normalize_1,
                           normalize_2=data_normalize_2,
                           sup=sup,
                           num_images=num_images,
                           **kwargs
                           )
    return dataset
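

# --- Illustrative usage note (hypothetical call; the directory names are placeholders) ---
# dataset_iitnn is the two-input variant: one spatial augmentation is applied to both images
# (via the albumentations `additional_targets` mechanism), while each image gets its own
# normalization. A sketch, assuming sub-directories <data_dir>/L/ and <data_dir>/H/ for the
# low- and high-frequency inputs:
#
#     dataset = imagefloder_iitnn(data_dir='/path/to/dataset/train', input1='L', input2='H',
#                                 data_transform_1=transform, data_normalize_1=normalize_L,
#                                 data_normalize_2=normalize_H, sup=True)
#     loader = DataLoader(dataset, batch_size=4, shuffle=True)
#     batch = next(iter(loader))   # batch['image'], batch['image_2'], batch['mask'], batch['ID']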

class dataset_wds(Dataset):
    def __init__(self, data_dir, augmentation1, normalize_LL, normalize_LH, normalize_HL, normalize_HH, **kwargs):
        super(dataset_wds, self).__init__()

        img_paths_LL = []
        img_paths_LH = []
        img_paths_HL = []
        img_paths_HH = []
        mask_paths = []
        image_dir_LL = data_dir + '/LL'
        image_dir_LH = data_dir + '/LH'
        image_dir_HL = data_dir + '/HL'
        image_dir_HH = data_dir + '/HH'
        mask_dir = data_dir + '/mask'

        for image in os.listdir(image_dir_LL):

            image_path_LL = os.path.join(image_dir_LL, image)
            img_paths_LL.append(image_path_LL)
            image_path_LH = os.path.join(image_dir_LH, image)
            img_paths_LH.append(image_path_LH)
            image_path_HL = os.path.join(image_dir_HL, image)
            img_paths_HL.append(image_path_HL)
            image_path_HH = os.path.join(image_dir_HH, image)
            img_paths_HH.append(image_path_HH)

            mask_path = os.path.join(mask_dir, image)
            mask_paths.append(mask_path)

        self.img_paths_LL = img_paths_LL
        self.img_paths_LH = img_paths_LH
        self.img_paths_HL = img_paths_HL
        self.img_paths_HH = img_paths_HH
        self.mask_paths = mask_paths
        self.augmentation_1 = augmentation1
        self.normalize_LL = normalize_LL
        self.normalize_LH = normalize_LH
        self.normalize_HL = normalize_HL
        self.normalize_HH = normalize_HH
        self.kwargs = kwargs

    def __getitem__(self, index):

        img_path_LL = self.img_paths_LL[index]
        img_LL = Image.open(img_path_LL)
        img_LL = np.array(img_LL)

        img_path_LH = self.img_paths_LH[index]
        img_LH = Image.open(img_path_LH)
        img_LH = np.array(img_LH)

        img_path_HL = self.img_paths_HL[index]
        img_HL = Image.open(img_path_HL)
        img_HL = np.array(img_HL)

        img_path_HH = self.img_paths_HH[index]
        img_HH = Image.open(img_path_HH)
        img_HH = np.array(img_HH)

        mask_path = self.mask_paths[index]
        mask = Image.open(mask_path)
        mask = np.array(mask)

        augment_1 = self.augmentation_1(image=img_LL, mask=mask, imageLH=img_LH, imageHL=img_HL, imageHH=img_HH)
        img_LL = augment_1['image']
        img_LH = augment_1['imageLH']
        img_HL = augment_1['imageHL']
        img_HH = augment_1['imageHH']
        mask_1 = augment_1['mask']

        normalize_LL = self.normalize_LL(image=img_LL, mask=mask_1)
        img_LL = normalize_LL['image']
        mask_1 = normalize_LL['mask']
        mask_1 = mask_1.long()

        normalize_LH = self.normalize_LH(image=img_LH)
        img_LH = normalize_LH['image']

        normalize_HL = self.normalize_HL(image=img_HL)
        img_HL = normalize_HL['image']

        normalize_HH = self.normalize_HH(image=img_HH)
        img_HH = normalize_HH['image']

        sample = {'image_LL': img_LL, 'image_LH': img_LH, 'image_HL': img_HL, 'image_HH': img_HH, 'mask': mask_1, 'ID': os.path.split(mask_path)[1]}

        return sample

    def __len__(self):
        return len(self.img_paths_LL)


def imagefloder_wds(data_dir, data_transform_1, data_normalize_LL, data_normalize_LH, data_normalize_HL, data_normalize_HH, **kwargs):
    dataset = dataset_wds(data_dir=data_dir,
                           augmentation1=data_transform_1,
                           normalize_LL=data_normalize_LL,
                           normalize_LH=data_normalize_LH,
                           normalize_HL=data_normalize_HL,
                           normalize_HH=data_normalize_HH,
                           **kwargs
                           )
    return dataset

class dataset_aerial_lanenet(Dataset):
    def __init__(self, data_dir, augmentation1, normalize_1, normalize_l1, normalize_l2, normalize_l3, normalize_l4, **kwargs):
        super(dataset_aerial_lanenet, self).__init__()

        img_paths = []
        mask_paths = []
        image_dir = data_dir + '/image'
        mask_dir = data_dir + '/mask'

        for image in os.listdir(image_dir):

            image_path = os.path.join(image_dir, image)
            img_paths.append(image_path)

            mask_path = os.path.join(mask_dir, image)
            mask_paths.append(mask_path)

        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.augmentation_1 = augmentation1
        self.normalize_1 = normalize_1
        self.normalize_l4 = normalize_l4
        self.normalize_l3 = normalize_l3
        self.normalize_l2 = normalize_l2
        self.normalize_l1 = normalize_l1
        self.kwargs = kwargs

    def __getitem__(self, index):

        img_path = self.img_paths[index]
        img = Image.open(img_path)
        img = np.array(img)

        mask_path = self.mask_paths[index]
        mask = Image.open(mask_path)
        mask = np.array(mask)

        augment_1 = self.augmentation_1(image=img, mask=mask)
        img = augment_1['image']
        mask = augment_1['mask']

        img_ = np.array(Image.fromarray(img).convert('L'))
        _, l4, l3, l2, l1 = pywt.wavedec2(img_, 'db2', level=4)

        l4 = np.array(l4).transpose(1, 2, 0)
        l3 = np.array(l3).transpose(1, 2, 0)
        l2 = np.array(l2).transpose(1, 2, 0)
        l1 = np.array(l1).transpose(1, 2, 0)
        normalize_l4 = self.normalize_l4(image=l4)
        l4 = normalize_l4['image'].float()
        normalize_l3 = self.normalize_l3(image=l3)
        l3 = normalize_l3['image'].float()
        normalize_l2 = self.normalize_l2(image=l2)
        l2 = normalize_l2['image'].float()
        normalize_l1 = self.normalize_l1(image=l1)
        l1 = normalize_l1['image'].float()

        normalize_1 = self.normalize_1(image=img, mask=mask)
        img = normalize_1['image']
        mask = normalize_1['mask'].long()

        sample = {'image': img, 'image_l1': l1, 'image_l2': l2, 'image_l3': l3, 'image_l4': l4, 'mask': mask, 'ID': os.path.split(mask_path)[1]}

        return sample

    def __len__(self):
        return len(self.img_paths)


def imagefloder_aerial_lanenet(data_dir, data_transform, data_normalize, data_normalize_l1, data_normalize_l2, data_normalize_l3, data_normalize_l4, **kwargs):
    dataset = dataset_aerial_lanenet(data_dir=data_dir,
                           augmentation1=data_transform,
                           normalize_1=data_normalize,
                           normalize_l1=data_normalize_l1,
                           normalize_l2=data_normalize_l2,
                           normalize_l3=data_normalize_l3,
                           normalize_l4=data_normalize_l4,
                           **kwargs
                           )
    return dataset

================================================
FILE: dataload/dataset_3d.py
================================================
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import cv2
import numpy as np
import torchio as tio
import SimpleITK as sitk
from torchio.data import UniformSampler, LabelSampler


class dataset_it(Dataset):
    def __init__(self, data_dir, input1, transform_1, queue_length=20, samples_per_volume=5, patch_size=128, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=None):
        super(dataset_it, self).__init__()

        self.subjects_1 = []

        image_dir_1 = data_dir + '/' + input1
        if sup:
            mask_dir = data_dir + '/mask'

        for i in os.listdir(image_dir_1):
            image_path_1 = os.path.join(image_dir_1, i)
            if sup:
                mask_path = os.path.join(mask_dir, i)
                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), mask=tio.LabelMap(mask_path), ID=i)
            else:
                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), ID=i)

            self.subjects_1.append(subject_1)

        if num_images is not None:
            len_img_paths = len(self.subjects_1)
            quotient = num_images // len_img_paths
            remainder = num_images % len_img_paths

            if num_images <= len_img_paths:
                self.subjects_1 = self.subjects_1[:num_images]
            else:
                rand_indices = torch.randperm(len_img_paths).tolist()
                new_indices = rand_indices[:remainder]

                self.subjects_1 = self.subjects_1 * quotient
                self.subjects_1 += [self.subjects_1[i] for i in new_indices]

        self.dataset_1 = tio.SubjectsDataset(self.subjects_1, transform=transform_1)

        self.queue_train_set_1 = tio.Queue(
            subjects_dataset=self.dataset_1,
            max_length=queue_length,
            samples_per_volume=samples_per_volume,
            sampler=UniformSampler(patch_size),
            # sampler=LabelSampler(patch_size),
            num_workers=num_workers,
            shuffle_subjects=shuffle_subjects,
            shuffle_patches=shuffle_patches
        )


class dataset_it_dtc(Dataset):
    def __init__(self, data_dir, input1, num_classes, transform_1, queue_length=20, samples_per_volume=5, patch_size=128, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=None):
        super(dataset_it_dtc, self).__init__()

        self.subjects_1 = []

        image_dir_1 = data_dir + '/' + input1
        if sup:
            mask_dir_1 = data_dir + '/mask'
            mask_dir_2 = data_dir + '/mask_sdf1'
            if num_classes == 3:
                mask_dir_3 = data_dir + '/mask_sdf2'

        for i in os.listdir(image_dir_1):
            image_path_1 = os.path.join(image_dir_1, i)
            if sup:
                mask_path_1 = os.path.join(mask_dir_1, i)
                mask_path_2 = os.path.join(mask_dir_2, i)
                if num_classes == 3:
                    mask_path_3 = os.path.join(mask_dir_3, i)
                    subject_1 = tio.Subject(
                        image=tio.ScalarImage(image_path_1),
                        mask=tio.LabelMap(mask_path_1),
                        mask2=tio.LabelMap(mask_path_2),
                        mask3=tio.LabelMap(mask_path_3),
                        ID=i)
                else:
                    subject_1 = tio.Subject(
                        image=tio.ScalarImage(image_path_1),
                        mask=tio.LabelMap(mask_path_1),
                        mask2=tio.LabelMap(mask_path_2),
                        ID=i)
            else:
                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), ID=i)

            self.subjects_1.append(subject_1)

        if num_images is not None:
            len_img_paths = len(self.subjects_1)
            quotient = num_images // len_img_paths
            remainder = num_images % len_img_paths

            if num_images <= len_img_paths:
                self.subjects_1 = self.subjects_1[:num_images]
            else:
                rand_indices = torch.randperm(len_img_paths).tolist()
                new_indices = rand_indices[:remainder]

                self.subjects_1 = self.subjects_1 * quotient
                self.subjects_1 += [self.subjects_1[i] for i in new_indices]

        self.dataset_1 = tio.SubjectsDataset(self.subjects_1, transform=transform_1)

        self.queue_train_set_1 = tio.Queue(
            subjects_dataset=self.dataset_1,
            max_length=queue_length,
            samples_per_volume=samples_per_volume,
            sampler=UniformSampler(patch_size),
            # sampler=LabelSampler(patch_size),
            num_workers=num_workers,
            shuffle_subjects=shuffle_subjects,
            shuffle_patches=shuffle_patches
        )

class dataset_iit(Dataset):
    def __init__(self, data_dir, input1, input2, transform_1, queue_length=20, samples_per_volume=5, patch_size=128, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=None):
        super(dataset_iit, self).__init__()

        self.subjects_1 = []

        image_dir_1 = data_dir + '/' + input1
        image_dir_2 = data_dir + '/' + input2

        if sup:
            mask_dir_1 = data_dir + '/mask'

        for i in os.listdir(image_dir_1):
            image_path_1 = os.path.join(image_dir_1, i)
            image_path_2 = os.path.join(image_dir_2, i)
            if sup:
                mask_path_1 = os.path.join(mask_dir_1, i)
                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), image2=tio.ScalarImage(image_path_2), mask=tio.LabelMap(mask_path_1), ID=i)
            else:
                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), image2=tio.ScalarImage(image_path_2), ID=i)

            self.subjects_1.append(subject_1)

        if num_images is not None:
            len_img_paths = len(self.subjects_1)
            quotient = num_images // len_img_paths
            remainder = num_images % len_img_paths

            if num_images <= len_img_paths:
                self.subjects_1 = self.subjects_1[:num_images]
            else:
                rand_indices = torch.randperm(len_img_paths).tolist()
                new_indices = rand_indices[:remainder]

                self.subjects_1 = self.subjects_1 * quotient
                self.subjects_1 += [self.subjects_1[i] for i in new_indices]

        self.dataset_1 = tio.SubjectsDataset(self.subjects_1, transform=transform_1)

        self.queue_train_set_1 = tio.Queue(
            subjects_dataset=self.dataset_1,
            max_length=queue_length,
            samples_per_volume=samples_per_volume,
            sampler=UniformSampler(patch_size),
            # sampler=LabelSampler(patch_size),
            num_workers=num_workers,
            shuffle_subjects=shuffle_subjects,
            shuffle_patches=shuffle_patches
        )


class dataset_iit_conresnet(Dataset):
    def __init__(self, data_dir, input1, input2, transform_1, queue_length=20, samples_per_volume=5, patch_size=128, num_workers=8, shuffle_subjects=True, shuffle_patches=True, sup=True, num_images=None):
        super(dataset_iit_conresnet, self).__init__()

        self.subjects_1 = []

        image_dir_1 = data_dir + '/' + input1
        image_dir_2 = data_dir + '/' + input2

        if sup:
            mask_dir_1 = data_dir + '/mask'
            mask_dir_2 = data_dir + '/mask_res'

        for i in os.listdir(image_dir_1):
            image_path_1 = os.path.join(image_dir_1, i)
            image_path_2 = os.path.join(image_dir_2, i)
            if sup:
                mask_path_1 = os.path.join(mask_dir_1, i)
                mask_path_2 = os.path.join(mask_dir_2, i)
                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), image2=tio.ScalarImage(image_path_2), mask=tio.LabelMap(mask_path_1), mask2=tio.LabelMap(mask_path_2), ID=i)
            else:
                subject_1 = tio.Subject(image=tio.ScalarImage(image_path_1), image2=tio.ScalarImage(image_path_2), ID=i)

            self.subjects_1.append(subject_1)

        if num_images is not None:
            len_img_paths = len(self.subjects_1)
            quotient = num_images // len_img_paths
            remainder = num_images % len_img_paths

            if num_images <= len_img_paths:
                self.subjects_1 = self.subjects_1[:num_images]
            else:
                rand_indices = torch.randperm(len_img_paths).tolist()
                new_indices = rand_indices[:remainder]

                self.subjects_1 = self.subjects_1 * quotient
                self.subjects_1 += [self.subjects_1[i] for i in new_indices]

        self.dataset_1 = tio.SubjectsDataset(self.subjects_1, transform=transform_1)

        self.queue_train_set_1 = tio.Queue(
            subjects_dataset=self.dataset_1,
            max_length=queue_length,
            samples_per_volume=samples_per_volume,
            sampler=UniformSampler(patch_size),
            # sampler=LabelSampler(patch_size),
            num_workers=num_workers,
            shuffle_subjects=shuffle_subjects,
            shuffle_patches=shuffle_patches
        )
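
# A minimal consumption sketch for the classes above, which expose a torchio
# Queue as .queue_train_set_1. The queue is itself a map-style dataset of
# patches, so it is wrapped in a plain DataLoader with num_workers=0 (the
# queue handles its own worker processes). Paths and input names below are
# hypothetical.
#
# if __name__ == '__main__':
#     from torch.utils.data import DataLoader
#     data = dataset_it(data_dir='dataset/train', input1='image',
#                       transform_1=None, patch_size=96, num_workers=4)
#     loader = DataLoader(data.queue_train_set_1, batch_size=2, num_workers=0)
#     batch = next(iter(loader))
#     print(batch['image'][tio.DATA].shape, batch['mask'][tio.DATA].shape)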



================================================
FILE: loss/__init__.py
================================================


================================================
FILE: loss/loss_function.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
import sys
from torch.nn.modules.loss import _Loss

class MixSoftmaxCrossEntropyLoss(nn.CrossEntropyLoss):
    def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs):
        super(MixSoftmaxCrossEntropyLoss, self).__init__(ignore_index=ignore_index)
        self.aux = aux
        self.aux_weight = aux_weight

    def _aux_forward(self, output, target, **kwargs):
        # *preds, target = tuple(inputs)

        loss = super(MixSoftmaxCrossEntropyLoss, self).forward(output[0], target)
        for i in range(1, len(output)):
            aux_loss = super(MixSoftmaxCrossEntropyLoss, self).forward(output[i], target)
            loss += self.aux_weight * aux_loss
        return loss

    def forward(self, output, target):
        # preds, target = tuple(inputs)
        # inputs = tuple(list(preds) + [target])
        if self.aux:
            return self._aux_forward(output, target)
        else:
            return super(MixSoftmaxCrossEntropyLoss, self).forward(output, target)

class BinaryDiceLoss(nn.Module):
    """Dice loss of binary class
    Args:
        smooth: A float number to smooth loss, and avoid NaN error, default: 1
        p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
        predict: A tensor of shape [N, *]
        target: A tensor of shape same with predict
        reduction: Reduction method to apply, return mean over batch if 'mean',
            return sum if 'sum', return a tensor of shape [N,] if 'none'
    Returns:
        Loss tensor according to arg reduction
    Raise:
        Exception if unexpected reduction
    """

    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target, valid_mask):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        predict = predict.contiguous().view(predict.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1).float()
        valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1).float()

        num = torch.sum(torch.mul(predict, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum((predict.pow(self.p) + target.pow(self.p)) * valid_mask, dim=1) + self.smooth

        loss = 1 - num / den

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss
        else:
            raise Exception('Unexpected reduction {}'.format(self.reduction))
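
# A quick sanity check of the formula documented above: with smooth=1 and
# p=2, a perfect binary prediction gives a Dice loss of 0.
#
# if __name__ == '__main__':
#     pred = torch.tensor([[1., 0., 1.], [0., 1., 0.]])
#     targ = pred.clone()
#     valid = torch.ones_like(targ)
#     print(BinaryDiceLoss()(pred, targ, valid))  # tensor(0.)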


class DiceLoss(nn.Module):
    """Dice loss, need one hot encode input"""

    def __init__(self, weight=None, aux=False, aux_weight=0.4, ignore_index=-1, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index
        self.aux = aux
        self.aux_weight = aux_weight

    def _base_forward(self, predict, target, valid_mask):

        dice = BinaryDiceLoss(**self.kwargs)
        total_loss = 0
        predict = F.softmax(predict, dim=1)

        for i in range(target.shape[-1]):
            if i != self.ignore_index:
                dice_loss = dice(predict[:, i], target[..., i], valid_mask)
                if self.weight is not None:
                    assert self.weight.shape[0] == target.shape[-1], \
                        'Expect weight shape [{}], get[{}]'.format(target.shape[-1], self.weight.shape[0])
                    dice_loss *= self.weight[i]
                total_loss += dice_loss

        return total_loss / target.shape[-1]

    def _aux_forward(self, output, target, **kwargs):
        # *preds, target = tuple(inputs)
        valid_mask = (target != self.ignore_index).long()
        target_one_hot = F.one_hot(torch.clamp_min(target, 0))
        loss = self._base_forward(output[0], target_one_hot, valid_mask)
        for i in range(1, len(output)):
            aux_loss = self._base_forward(output[i], target_one_hot, valid_mask)
            loss += self.aux_weight * aux_loss
        return loss

    def forward(self, output, target):
        # preds, target = tuple(inputs)
        # inputs = tuple(list(preds) + [target])
        if self.aux:
            return self._aux_forward(output, target)
        else:
            valid_mask = (target != self.ignore_index).long()
            target_one_hot = F.one_hot(torch.clamp_min(target, 0))
            return self._base_forward(output, target_one_hot, valid_mask)


def softmax_mse_loss(input_logits, target_logits, sigmoid=False):
    """Takes softmax on both sides and returns MSE loss
    Note:
    - Returns the sum over all examples. Divide by the batch size afterwards
      if you want the mean.
    - Sends gradients to inputs but not the targets.
    """
    assert input_logits.size() == target_logits.size()
    if sigmoid:
        input_softmax = torch.sigmoid(input_logits)
        target_softmax = torch.sigmoid(target_logits)
    else:
        input_softmax = F.softmax(input_logits, dim=1)
        target_softmax = F.softmax(target_logits, dim=1)

    mse_loss = (input_softmax-target_softmax)**2
    return mse_loss


def entropy_loss(p, C=2):
    # p N*C*W*H*D
    y1 = -1*torch.sum(p*torch.log(p+1e-6), dim=1) / torch.tensor(np.log(C)).cuda()
    ent = torch.mean(y1)

    return ent
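
# Illustration of the normalization above: dividing by log(C) scales the
# entropy so that a uniform distribution over C classes gives 1.0. Note that
# the function calls .cuda(), so this sketch assumes a GPU is available.
#
# if __name__ == '__main__':
#     p = torch.full((2, 2, 4, 4, 4), 0.5).cuda()  # uniform over C=2 classes
#     print(entropy_loss(p, C=2))  # ~1.0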

class BCELossBoud(nn.Module):
    def __init__(self, num_classes, weight=None, ignore_index=None, **kwargs):
        super(BCELossBoud, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index
        self.num_classes = num_classes
        self.criterion = nn.BCEWithLogitsLoss()

    def weighted_BCE_cross_entropy(self, output, target, weights = None):
        if weights is not None:
            assert len(weights) == 2
            output = torch.clamp(output, min=1e-3, max=1-1e-3)
            bce = weights[1] * (target * torch.log(output)) + weights[0] * ((1-target) * torch.log((1-output)))
        else:
            output = torch.clamp(output, min=1e-3, max=1 - 1e-3)
            bce = target * torch.log(output) + (1-target) * torch.log((1-output))
        return torch.neg(torch.mean(bce))

    def forward(self, predict, target):

        target_one_hot = F.one_hot(torch.clamp_min(target, 0), num_classes=self.num_classes).permute(0, 4, 1, 2, 3)
        predict = torch.softmax(predict, 1)

        bs, category, depth, height, width = target_one_hot.shape
        bce_loss = []
        for i in range(predict.shape[1]):
            pred_i = predict[:,i]
            targ_i = target_one_hot[:,i]
            tt = np.log(depth * height * width / (target_one_hot[:, i].cpu().data.numpy().sum()+1))
            bce_i = self.weighted_BCE_cross_entropy(pred_i, targ_i, weights=[1, tt])
            bce_loss.append(bce_i)

        bce_loss = torch.stack(bce_loss)
        total_loss = bce_loss.mean()
        return total_loss


class CustomKLLoss(_Loss):
    '''
    KL_Loss = (|dot(mean , mean)| + |dot(std, std)| - |log(dot(std, std))| - 1) / N
    N is the total number of image voxels
    '''

    def __init__(self, *args, **kwargs):
        super(CustomKLLoss, self).__init__()

    def forward(self, mean, std):
        return torch.mean(torch.mul(mean, mean)) + torch.mean(torch.mul(std, std)) - torch.mean(
            torch.log(torch.mul(std, std))) - 1


def segmentation_loss(loss='CE', aux=False, **kwargs):

    if loss == 'dice' or loss == 'DICE':
        seg_loss = DiceLoss(aux=aux)
    elif loss == 'crossentropy' or loss == 'CE':
        seg_loss = MixSoftmaxCrossEntropyLoss(aux=aux)
    elif loss == 'bce':
        seg_loss = nn.BCELoss(reduction='mean')
    elif loss == 'bcebound':
        seg_loss = BCELossBoud(num_classes=kwargs['num_classes'])
    else:
        print('sorry, the loss you entered is not supported yet')
        sys.exit()

    return seg_loss


# if __name__ == '__main__':
#     from models import *
#     criterion = segmentation_loss(loss='dice')
#     # criterion = nn.CrossEntropyLoss()
#
#     model = unet(1, 2)
#     model.eval()
#     input = torch.rand(3, 1, 128, 128)
#     mask = torch.zeros(3, 128, 128).long()
#
#     mask[:, 40:100, 30:60] = 1
#     output = model(input)
#
#     loss = criterion(output, mask)
#     print(loss)
#     # loss.requires_grad_(True)
#     # loss.backward()



================================================
FILE: models/__init__.py
================================================
# 2d
from .networks_2d.xnet import XNet, XNet_1_1_m, XNet_1_2_m, XNet_2_1_m, XNet_3_2_m, XNet_2_3_m, XNet_3_3_m, XNet_sb
from .networks_2d.unet import unet, r2_unet, attention_unet
from .networks_2d.unet_plusplus import unet_plusplus
from .networks_2d.hrnet import hrnet18, hrnet32, hrnet48, hrnet64
from .networks_2d.swinunet import swinunet
from .networks_2d.unet_urpc import unet_urpc
from .networks_2d.unet_cct import unet_cct
from .networks_2d.resunet import res_unet
from .networks_2d.resunet_plusplus import res_unet_plusplus
from .networks_2d.u2net import u2net, u2net_small
from .networks_2d.unet_3plus import unet_3plus, unet_3plus_ds, unet_3plus_ds_cgm
from .networks_2d.wavesnet import wsegnet_vgg16_bn
from .networks_2d.mwcnn import mwcnn
from .networks_2d.aerial_lanenet import Aerial_LaneNet
from .networks_2d.wds import WDS

# 3d
from .networks_3d.unet3d import unet3d, unet3d_min
from .networks_3d.vnet import vnet
from .networks_3d.res_unet3d import res_unet3d
from .networks_3d.transbts import transbts
from .networks_3d.cotr import cotr
from .networks_3d.dmfnet import dmfnet
from .networks_3d.conresnet import conresnet
from .networks_3d.espnet3d import espnet3d
from .networks_3d.unetr import unertr
from .networks_3d.unet3d_urpc import unet3d_urpc
from .networks_3d.unet3d_cct import unet3d_cct, unet3d_cct_min
from .networks_3d.unet3d_dtc import unet3d_dtc
from .networks_3d.xnet3d import xnet3d
from .networks_3d.vnet_cct import vnet_cct
from .networks_3d.vnet_dtc import vnet_dtc

================================================
FILE: models/getnetwork.py
================================================
import sys
from models import *
import torch.nn as nn

def get_network(network, in_channels, num_classes, **kwargs):

    # 2d networks
    if network == 'xnet':
        net = XNet(in_channels, num_classes)
    elif network == 'xnet_sb':
        net = XNet_sb(in_channels, num_classes)
    elif network == 'xnet_1_1_m':
        net = XNet_1_1_m(in_channels, num_classes)
    elif network == 'xnet_1_2_m':
        net = XNet_1_2_m(in_channels, num_classes)
    elif network == 'xnet_2_1_m':
        net = XNet_2_1_m(in_channels, num_classes)
    elif network == 'xnet_3_2_m':
        net = XNet_3_2_m(in_channels, num_classes)
    elif network == 'xnet_2_3_m':
        net = XNet_2_3_m(in_channels, num_classes)
    elif network == 'xnet_3_3_m':
        net = XNet_3_3_m(in_channels, num_classes)
    elif network == 'unet':
        net = unet(in_channels, num_classes)
    elif network == 'unet_plusplus' or network == 'unet++':
        net = unet_plusplus(in_channels, num_classes)
    elif network == 'r2unet':
        net = r2_unet(in_channels, num_classes)
    elif network == 'attunet':
        net = attention_unet(in_channels, num_classes)
    elif network == 'hrnet18':
        net = hrnet18(in_channels, num_classes)
    elif network == 'hrnet48':
        net = hrnet48(in_channels, num_classes)
    elif network == 'resunet':
        net = res_unet(in_channels, num_classes)
    elif network == 'resunet++':
        net = res_unet_plusplus(in_channels, num_classes)
    elif network == 'u2net':
        net = u2net(in_channels, num_classes)
    elif network == 'u2net_s':
        net = u2net_small(in_channels, num_classes)
    elif network == 'unet3+':
        net = unet_3plus(in_channels, num_classes)
    elif network == 'unet3+_ds':
        net = unet_3plus_ds(in_channels, num_classes)
    elif network == 'unet3+_ds_cgm':
        net = unet_3plus_ds_cgm(in_channels, num_classes)
    elif network == 'swinunet':
        net = swinunet(num_classes, 224)  # img_size = 224
    elif network == 'unet_urpc':
        net = unet_urpc(in_channels, num_classes)
    elif network == 'unet_cct':
        net = unet_cct(in_channels, num_classes)
    elif network == 'wavesnet':
        net = wsegnet_vgg16_bn(in_channels, num_classes)
    elif network == 'mwcnn':
        net = mwcnn(in_channels, num_classes)
    elif network == 'alnet':
        net = Aerial_LaneNet(in_channels, num_classes)
    elif network == 'wds':
        net = WDS(in_channels, num_classes)

    # 3d networks
    elif network == 'xnet3d':
        net = xnet3d(in_channels, num_classes)
    elif network == 'unet3d':
        net = unet3d(in_channels, num_classes)
    elif network == 'unet3d_min':
        net = unet3d_min(in_channels, num_classes)
    elif network == 'unet3d_urpc':
        net = unet3d_urpc(in_channels, num_classes)
    elif network == 'unet3d_cct':
        net = unet3d_cct(in_channels, num_classes)
    elif network == 'unet3d_cct_min':
        net = unet3d_cct_min(in_channels, num_classes)
    elif network == 'unet3d_dtc':
        net = unet3d_dtc(in_channels, num_classes)
    elif network == 'vnet':
        net = vnet(in_channels, num_classes)
    elif network == 'vnet_cct':
        net = vnet_cct(in_channels, num_classes)
    elif network == 'vnet_dtc':
        net = vnet_dtc(in_channels, num_classes)
    elif network == 'resunet3d':
        net = res_unet3d(in_channels, num_classes)
    elif network == 'conresnet':
        net = conresnet(in_channels, num_classes, img_shape=kwargs['img_shape'])
    elif network == 'espnet3d':
        net = espnet3d(in_channels, num_classes)
    elif network == 'dmfnet':
        net = dmfnet(in_channels, num_classes)
    elif network == 'transbts':
        net = transbts(in_channels, num_classes, img_shape=kwargs['img_shape'])
    elif network == 'cotr':
        net = cotr(in_channels, num_classes)
    elif network == 'unertr':
        net = unertr(in_channels, num_classes, img_shape=kwargs['img_shape'])
    else:
        print('the network you have entered is not supported yet')
        sys.exit()
    return net
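
# A minimal usage sketch for the factory above. The network keys come from
# the if/elif chain; img_shape is only needed by the networks that read
# kwargs['img_shape'] (conresnet, transbts, unertr). Input sizes below are
# illustrative.
#
# if __name__ == '__main__':
#     import torch
#     model = get_network('unet', in_channels=1, num_classes=2)
#     x = torch.rand(2, 1, 128, 128)
#     print(model(x).shape)  # expected: torch.Size([2, 2, 128, 128])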


================================================
FILE: models/networks_2d/__init__.py
================================================


================================================
FILE: models/networks_2d/aerial_lanenet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import functools
from torch.distributions.uniform import Uniform
import numpy as np


class basic_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(basic_block, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True))
    def forward(self, x):
        x = self.block(x)
        return x

class Aerial_LaneNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(Aerial_LaneNet, self).__init__()

        l1, l2, l3, l4, l5 = 64, 128, 256, 512, 512
        dropout = 0.2

        # e1
        self.conv1_1 = basic_block(in_channels, l1)
        self.conv1_2 = basic_block(l1, l1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # e2
        self.conv2_1 = basic_block(l1+3, l2)
        self.conv2_2 = basic_block(l2, l2)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # e3
        self.conv3_1 = basic_block(l2+3, l3)
        self.conv3_2 = basic_block(l3, l3)
        self.conv3_3 = basic_block(l3, l3)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # e4
        self.conv4_1 = basic_block(l3+3, l4)
        self.conv4_2 = basic_block(l4, l4)
        self.conv4_3 = basic_block(l4, l4)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # e5
        self.conv5_1 = basic_block(l4+3, l5)
        self.conv5_2 = basic_block(l5, l5)
        self.conv5_3 = basic_block(l5, l5)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # e6
        self.conv6_1 = basic_block(l5, 4096)
        self.drop6_1 = nn.Dropout2d(dropout)
        self.conv6_2 = basic_block(4096, 4096)
        self.drop6_2 = nn.Dropout2d(dropout)
        self.conv6_3 = nn.ConvTranspose2d(4096, l5, kernel_size=4, stride=2, padding=1, bias=False)

        # d4
        self.conv4_4 = basic_block(2*l5, l5)
        self.drop4_4 = nn.Dropout2d(dropout)
        self.conv4_5 = nn.ConvTranspose2d(l5, l3, kernel_size=4, stride=2, padding=1, bias=False)

        # d3
        self.conv3_4 = basic_block(2*l3, l3)
        self.drop3_4 = nn.Dropout2d(dropout)
        self.conv3_5 = nn.ConvTranspose2d(l3, l2, kernel_size=4, stride=2, padding=1, bias=False)

        # d2
        self.conv2_4 = basic_block(2*l2, l2)
        self.drop2_4 = nn.Dropout2d(dropout)
        self.conv2_5 = nn.ConvTranspose2d(l2, l1, kernel_size=4, stride=2, padding=1, bias=False)

        # d1
        self.conv1_3 = basic_block(2*l1, l1)
        self.drop1_3 = nn.Dropout2d(dropout)
        self.conv1_4 = nn.ConvTranspose2d(l1, num_classes, kernel_size=4, stride=2, padding=1, bias=False)

        # initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, x_wavelet_1, x_wavelet_2, x_wavelet_3, x_wavelet_4):

        x1 = self.conv1_1(x)
        x1 = self.conv1_2(x1)
        x1 = self.pool1(x1)

        x2 = torch.cat((x1, x_wavelet_1), dim=1)
        x2 = self.conv2_1(x2)
        x2 = self.conv2_2(x2)
        x2 = self.pool2(x2)

        x3 = torch.cat((x2, x_wavelet_2), dim=1)
        x3 = self.conv3_1(x3)
        x3 = self.conv3_2(x3)
        x3 = self.conv3_3(x3)
        x3 = self.pool3(x3)

        x4 = torch.cat((x3, x_wavelet_3), dim=1)
        x4 = self.conv4_1(x4)
        x4 = self.conv4_2(x4)
        x4 = self.conv4_3(x4)
        x4 = self.pool4(x4)

        x5 = torch.cat((x4, x_wavelet_4), dim=1)
        x5 = self.conv5_1(x5)
        x5 = self.conv5_2(x5)
        x5 = self.conv5_3(x5)
        x5 = self.pool5(x5)

        x6 = self.conv6_1(x5)
        x6 = self.drop6_1(x6)
        x6 = self.conv6_2(x6)
        x6 = self.drop6_2(x6)
        x6 = self.conv6_3(x6)

        x5 = torch.cat((x6, x4), dim=1)
        x5 = self.conv4_4(x5)
        x5 = self.drop4_4(x5)
        x5 = self.conv4_5(x5)

        x4 = torch.cat((x5, x3), dim=1)
        x4 = self.conv3_4(x4)
        x4 = self.drop3_4(x4)
        x4 = self.conv3_5(x4)

        x3 = torch.cat((x4, x2), dim=1)
        x3 = self.conv2_4(x3)
        x3 = self.drop2_4(x3)
        x3 = self.conv2_5(x3)

        x2 = torch.cat((x3, x1), dim=1)
        x2 = self.conv1_3(x2)
        x2 = self.drop1_3(x2)
        x2 = self.conv1_4(x2)

        return x2

# if __name__ == '__main__':
#     from loss.loss_function import segmentation_loss
#     criterion = segmentation_loss('dice', False)
#     mask = torch.ones(2, 128, 128).long()
#     model = Aerial_LaneNet(1, 5)
#     model.train()
#     input1 = torch.rand(2, 1, 128, 128)
#     input2 = torch.rand(2, 3, 64, 64)
#     input3 = torch.rand(2, 3, 32, 32)
#     input4 = torch.rand(2, 3, 16, 16)
#     input5 = torch.rand(2, 3, 8, 8)
#
#     y = model(input1, input2, input3, input4, input5)
#     loss_train = criterion(y, mask)
#     loss_train.backward()
#     # print(output)
#     print(y.data.cpu().numpy().shape)
#     print(loss_train)

================================================
FILE: models/networks_2d/hrnet.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import functools

import numpy as np
import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F
from torch.nn import init

try:
    from .sync_bn.inplace_abn.bn import InPlaceABNSync
    BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
except ImportError:
    BatchNorm2d = nn.BatchNorm2d

BN_MOMENTUM = 0.01
logger = logging.getLogger(__name__)


model_urls = {
    'hrnetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/hrnetv2_w48-imagenet.pth',
}

import sys
try:
    from urllib import urlretrieve
except ImportError:
    from urllib.request import urlretrieve


def load_url(url, model_dir='./pretrained', map_location=None):
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        urlretrieve(url, cached_file)
    return torch.load(cached_file, map_location=map_location)

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=False)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = BatchNorm2d(planes * self.expansion,
                               momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=False)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual
        out = self.relu(out)

        return out


class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=False)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(num_channels[branch_index] * block.expansion,
                            momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride, downsample))
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j],
                                  num_inchannels[i],
                                  1,
                                  1,
                                  0,
                                  bias=False),
                        BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                BatchNorm2d(num_outchannels_conv3x3,
                                            momentum=BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                BatchNorm2d(num_outchannels_conv3x3,
                                            momentum=BN_MOMENTUM),
                                nn.ReLU(inplace=False)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                elif j > i:
                    width_output = x[i].shape[-1]
                    height_output = x[i].shape[-2]
                    y = y + F.interpolate(
                        self.fuse_layers[i][j](x[j]),
                        size=[height_output, width_output],
                        mode='bilinear',align_corners=False)
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse


blocks_dict = {
    'BASIC': BasicBlock,
    'BOTTLENECK': Bottleneck
}


class HighResolutionNet(nn.Module):

    def __init__(self, in_channels, extra, num_classes,**kwargs):
        super(HighResolutionNet, self).__init__()

        # stem net
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=False)

        self.stage1_cfg = extra['STAGE1']
        num_channels = self.stage1_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage1_cfg['BLOCK']]
        num_blocks = self.stage1_cfg['NUM_BLOCKS']
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
        stage1_out_channel = block.expansion * num_channels

        self.stage2_cfg = extra['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition1 = self._make_transition_layer(
            [stage1_out_channel], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        self.stage3_cfg = extra['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        self.stage4_cfg = extra['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True)

        last_inp_channels = int(np.sum(pre_stage_channels))

        self.last_layer = nn.Sequential(
            nn.Conv2d(
                in_channels=last_inp_channels,
                out_channels=last_inp_channels,
                kernel_size=1,
                stride=1,
                padding=0),
            BatchNorm2d(last_inp_channels, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=False),
            nn.Conv2d(
                in_channels=last_inp_channels,
                out_channels=num_classes,
                kernel_size=extra['FINAL_CONV_KERNEL'],
                stride=1,
                padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0)
        )

    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre_layer[i],
                                  num_channels_cur_layer[i],
                                  3,
                                  1,
                                  1,
                                  bias=False),
                        BatchNorm2d(
                            num_channels_cur_layer[i], momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=False)))
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(
                            inchannels, outchannels, 3, 2, 1, bias=False),
                        BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=False)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels,
                    multi_scale_output=True):
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        fuse_method = layer_config['FUSE_METHOD']

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True
            modules.append(
                HighResolutionModule(num_branches,
                                     block,
                                     num_blocks,
                                     num_inchannels,
                                     num_channels,
                                     fuse_method,
                                     reset_multi_scale_output)
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        size = x.shape[2:]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                if i < self.stage2_cfg['NUM_BRANCHES']:
                    x_list.append(self.transition2[i](y_list[i]))
                else:
                    x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                if i < self.stage3_cfg['NUM_BRANCHES']:
                    x_list.append(self.transition3[i](y_list[i]))
                else:
                    x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        x = self.stage4(x_list)

        # Upsampling
        x0_h, x0_w = x[0].size(2), x[0].size(3)
        x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=False)
        x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=False)
        x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=False)

        x = torch.cat([x[0], x1, x2, x3], 1)

        x = self.last_layer(x)
        x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)
        # outputs = []
        # outputs.append(x)
        # return outputs
        return x
    def init_weights(self, pretrained='', ):
        logger.info('=> init weights from normal distribution')
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.001)
            elif isinstance(m, InPlaceABNSync):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if os.path.isfile(pretrained):
            pretrained_dict = torch.load(pretrained)
            logger.info('=> loading pretrained model {}'.format(pretrained))
            model_dict = self.state_dict()
            pretrained_dict = {k: v for k, v in pretrained_dict.items()
                               if k in model_dict.keys()}
            for k, _ in pretrained_dict.items():
                logger.info(
                    '=> loading {} pretrained model {}'.format(k, pretrained))
            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict)

# class HRNet(nn.Module):
#     def __init__(self, in_channels, extra, num_classes, **kwargs):
#         super(HRNet, self).__init__()
#         self.branch1 = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra)
#         self.branch2 = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra)
#
#     def forward(self, data, step=1):
#         if not self.training:
#             pred1 = self.branch1(data)
#             return pred1
#
#         if step == 1:
#             return self.branch1(data)
#         elif step == 2:
#             return self.branch2(data)

extra_18 = {
            'STAGE1': {'NUM_MODULES': 1, 'NUM_BRANCHES': 1, 'BLOCK': 'BOTTLENECK', 'NUM_BLOCKS': (4), 'NUM_CHANNELS': (64), 'FUSE_METHOD': 'SUM'},
            'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (18, 36), 'FUSE_METHOD': 'SUM'},
            'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (18, 36, 72), 'FUSE_METHOD': 'SUM'},
            'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (18, 36, 72, 144), 'FUSE_METHOD': 'SUM'},
            'FINAL_CONV_KERNEL': 1
            }

extra_32 = {
            'STAGE1': {'NUM_MODULES': 1, 'NUM_BRANCHES': 1, 'BLOCK': 'BOTTLENECK', 'NUM_BLOCKS': (4), 'NUM_CHANNELS': (64), 'FUSE_METHOD': 'SUM'},
            'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (32, 64), 'FUSE_METHOD': 'SUM'},
            'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (32, 64, 128), 'FUSE_METHOD': 'SUM'},
            'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (32, 64, 128, 256), 'FUSE_METHOD': 'SUM'},
            'FINAL_CONV_KERNEL': 1
            }

extra_48 = {
            'STAGE1': {'NUM_MODULES': 1, 'NUM_BRANCHES': 1, 'BLOCK': 'BOTTLENECK', 'NUM_BLOCKS': (4), 'NUM_CHANNELS': (64), 'FUSE_METHOD': 'SUM'},
            'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (48, 96), 'FUSE_METHOD': 'SUM'},
            'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (48, 96, 192), 'FUSE_METHOD': 'SUM'},
            'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (48, 96, 192, 384), 'FUSE_METHOD': 'SUM'},
            'FINAL_CONV_KERNEL': 1
            }

extra_64 = {
            'STAGE1': {'NUM_MODULES': 1, 'NUM_BRANCHES': 1, 'BLOCK': 'BOTTLENECK', 'NUM_BLOCKS': (4), 'NUM_CHANNELS': (64), 'FUSE_METHOD': 'SUM'},
            'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (64, 128), 'FUSE_METHOD': 'SUM'},
            'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (64, 128, 256), 'FUSE_METHOD': 'SUM'},
            'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (64, 128, 256, 512), 'FUSE_METHOD': 'SUM'},
            'FINAL_CONV_KERNEL': 1
            }

def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)

# def hrnet18(in_channels, num_classes):
#     model = HRNet(in_channels=in_channels, num_classes=num_classes, extra=extra_18)
#     return model
#
# def hrnet32(in_channels, num_classes):
#     model = HRNet(in_channels=in_channels, num_classes=num_classes, extra=extra_32)
#     return model
#
# def hrnet48(in_channels, num_classes):
#     model = HRNet(in_channels=in_channels, num_classes=num_classes, extra=extra_48)
#     return model
#
# def hrnet64(in_channels, num_classes):
#     model = HRNet(in_channels=in_channels, num_classes=num_classes, extra=extra_64)
#     return model

def hrnet18(in_channels, num_classes):
    model = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra_18)
    init_weights(model, 'kaiming')
    return model

def hrnet32(in_channels, num_classes):
    model = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra_32)
    init_weights(model, 'kaiming')
    return model

def hrnet48(in_channels, num_classes):
    model = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra_48)
    init_weights(model, 'kaiming')
    return model

def hrnet64(in_channels, num_classes):
    model = HighResolutionNet(in_channels=in_channels, num_classes=num_classes, extra=extra_64)
    init_weights(model, 'kaiming')
    return model


# if __name__ == '__main__':
#     model = hrnet48(1,10)

    # total = sum([param.nelement() for param in model.parameters()])
    # from thop import profile,clever_format
    #
    # input = torch.randn(1, 1, 128, 128)
    # flops, params = profile(model, inputs=(input, ))
    # macs, params = clever_format([flops, params], "%.3f")
    # print(macs)
    # print(params)
    # print(total)
    # model.eval()
    # input = torch.rand(1,1,256,256)
    # output = model(input)
    # output = output[0].data.cpu().numpy()
    # print(output)
    # print(output.shape)



================================================
FILE: models/networks_2d/mwcnn.py
================================================
import torch
import torch.nn as nn
import scipy.io as sio
import math
import torch.nn.functional as F
from torch.autograd import Variable


def default_conv(in_channels, out_channels, kernel_size, bias=True, dilation=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2) + dilation - 1, bias=bias, dilation=dilation)


def default_conv1(in_channels, out_channels, kernel_size, bias=True, groups=3):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias, groups=groups)


# def shuffle_channel()

def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.size()

    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups,
               channels_per_group, height, width)

    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x


def pixel_down_shuffle(x, downscale_factor):
    batchsize, num_channels, height, width = x.size()

    out_height = height // downscale_factor
    out_width = width // downscale_factor
    input_view = x.contiguous().view(batchsize, num_channels, out_height, downscale_factor, out_width,
                                     downscale_factor)

    num_channels *= downscale_factor ** 2
    unshuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()

    return unshuffle_out.view(batchsize, num_channels, out_height, out_width)


def sp_init(x):
    x01 = x[:, :, 0::2, :]
    x02 = x[:, :, 1::2, :]
    x_LL = x01[:, :, :, 0::2]
    x_HL = x02[:, :, :, 0::2]
    x_LH = x01[:, :, :, 1::2]
    x_HH = x02[:, :, :, 1::2]

    return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)


def dwt_init(x):
    x01 = x[:, :, 0::2, :] / 2
    x02 = x[:, :, 1::2, :] / 2
    x1 = x01[:, :, :, 0::2]
    x2 = x02[:, :, :, 0::2]
    x3 = x01[:, :, :, 1::2]
    x4 = x02[:, :, :, 1::2]
    x_LL = x1 + x2 + x3 + x4
    x_HL = -x1 - x2 + x3 + x4
    x_LH = -x1 + x2 - x3 + x4
    x_HH = x1 - x2 - x3 + x4

    return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)


def iwt_init(x):
    # inverse of dwt_init: reassembles the four Haar sub-bands into the
    # full-resolution tensor (channels /4, spatial dimensions doubled)
    r = 2
    in_batch, in_channel, in_height, in_width = x.size()
    # print([in_batch, in_channel, in_height, in_width])
    out_batch, out_channel, out_height, out_width = in_batch, int(
        in_channel / (r ** 2)), r * in_height, r * in_width
    x1 = x[:, 0:out_channel, :, :] / 2
    x2 = x[:, out_channel:out_channel * 2, :, :] / 2
    x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
    x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2

    # allocate the output on the same device/dtype as the input so the inverse
    # transform also runs on CPU (the original hard-coded .cuda() here)
    h = torch.zeros([out_batch, out_channel, out_height, out_width], dtype=x.dtype, device=x.device)

    h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
    h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
    h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
    h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4

    return h
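
# Illustrative sketch, not part of the original file: dwt_init/iwt_init form a
# single-level Haar analysis/synthesis pair, so applying one after the other
# reproduces the input up to floating-point error. A minimal, hypothetical
# self-check; it assumes iwt_init allocates its output on the input's device
# (with the original hard-coded .cuda(), the tensor would have to be on GPU):
def _haar_roundtrip_check():
    x = torch.rand(2, 3, 8, 8)
    sub_bands = dwt_init(x)        # (2, 12, 4, 4): LL/HL/LH/HH stacked on channels
    recon = iwt_init(sub_bands)    # (2, 3, 8, 8)
    assert recon.shape == x.shape
    assert torch.allclose(recon, x, atol=1e-6)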


class Channel_Shuffle(nn.Module):
    def __init__(self, conv_groups):
        super(Channel_Shuffle, self).__init__()
        self.conv_groups = conv_groups
        self.requires_grad = False

    def forward(self, x):
        return channel_shuffle(x, self.conv_groups)


class SP(nn.Module):
    def __init__(self):
        super(SP, self).__init__()
        self.requires_grad = False

    def forward(self, x):
        return sp_init(x)


class Pixel_Down_Shuffle(nn.Module):
    def __init__(self):
        super(Pixel_Down_Shuffle, self).__init__()
        self.requires_grad = False

    def forward(self, x):
        return pixel_down_shuffle(x, 2)


class DWT(nn.Module):
    def __init__(self):
        super(DWT, self).__init__()
        self.requires_grad = False

    def forward(self, x):
        return dwt_init(x)


class IWT(nn.Module):
    def __init__(self):
        super(IWT, self).__init__()
        self.requires_grad = False

    def forward(self, x):
        return iwt_init(x)


class MeanShift(nn.Conv2d):
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        self.requires_grad = False
        if sign == -1:
            self.create_graph = False
            self.volatile = True


class MeanShift2(nn.Conv2d):
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift2, self).__init__(4, 4, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(4).view(4, 4, 1, 1)
        self.weight.data.div_(std.view(4, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        self.requires_grad = False
        if sign == -1:
            self.volatile = True


class BasicBlock(nn.Sequential):
    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, bias=False,
            bn=False, act=nn.ReLU(True)):

        m = [nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=(kernel_size // 2), stride=stride, bias=bias)
        ]
        if bn: m.append(nn.BatchNorm2d(out_channels))
        if act is not None: m.append(act)
        super(BasicBlock, self).__init__(*m)


class BBlock(nn.Module):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(BBlock, self).__init__()
        m = []
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias))
        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x).mul(self.res_scale)
        return x


class DBlock_com(nn.Module):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(DBlock_com, self).__init__()
        m = []

        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=3))
        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x)
        return x


class DBlock_inv(nn.Module):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(DBlock_inv, self).__init__()
        m = []

        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=3))
        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x)
        return x


class DBlock_com1(nn.Module):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(DBlock_com1, self).__init__()
        m = []

        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=1))
        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x)
        return x


class DBlock_inv1(nn.Module):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(DBlock_inv1, self).__init__()
        m = []

        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=1))
        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x)
        return x


class DBlock_com2(nn.Module):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(DBlock_com2, self).__init__()
        m = []

        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x)
        return x


class DBlock_inv2(nn.Module):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(DBlock_inv2, self).__init__()
        m = []

        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2))
        if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x)
        return x


class ShuffleBlock(nn.Module):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1, conv_groups=1):
        super(ShuffleBlock, self).__init__()
        m = []
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias))
        m.append(Channel_Shuffle(conv_groups))
        if bn: m.append(nn.BatchNorm2d(out_channels))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x).mul(self.res_scale)
        return x


class DWBlock(nn.Module):
    def __init__(
            self, conv, conv1, in_channels, out_channels, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(DWBlock, self).__init__()
        m = []
        m.append(conv(in_channels, out_channels, kernel_size, bias=bias))
        if bn: m.append(nn.BatchNorm2d(out_channels))
        m.append(act)

        m.append(conv1(in_channels, out_channels, 1, bias=bias))
        if bn: m.append(nn.BatchNorm2d(out_channels))
        m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        x = self.body(x).mul(self.res_scale)
        return x


class ResBlock(nn.Module):
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if i == 0: m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res


class Block(nn.Module):
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(Block, self).__init__()
        m = []
        for i in range(4):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if i == 0: m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        # res += x

        return res


class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn: m.append(nn.BatchNorm2d(n_feat))
                if act: m.append(act())
        elif scale == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if act: m.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)


class MWCNN(nn.Module):
    def __init__(self, in_channels, num_classes, conv=default_conv):
        super(MWCNN, self).__init__()
        kernel_size = 3
        self.scale_idx = 0
        n_feats = 64
        act = nn.ReLU(True)

        self.DWT = DWT()
        self.IWT = IWT()

        n = 1
        m_head = [BBlock(conv, in_channels, n_feats, kernel_size, act=act)]
        d_l0 = []
        d_l0.append(DBlock_com1(conv, n_feats, n_feats, kernel_size, act=act, bn=False))


        d_l1 = [BBlock(conv, n_feats * 4, n_feats * 2, kernel_size, act=act, bn=False)]
        d_l1.append(DBlock_com1(conv, n_feats * 2, n_feats * 2, kernel_size, act=act, bn=False))

        d_l2 = []
        d_l2.append(BBlock(conv, n_feats * 8, n_feats * 4, kernel_size, act=act, bn=False))
        d_l2.append(DBlock_com1(conv, n_feats * 4, n_feats * 4, kernel_size, act=act, bn=False))
        pro_l3 = []
        pro_l3.append(BBlock(conv, n_feats * 16, n_feats * 8, kernel_size, act=act, bn=False))
        pro_l3.append(DBlock_com(conv, n_feats * 8, n_feats * 8, kernel_size, act=act, bn=False))
        pro_l3.append(DBlock_inv(conv, n_feats * 8, n_feats * 8, kernel_size, act=act, bn=False))
        pro_l3.append(BBlock(conv, n_feats * 8, n_feats * 16, kernel_size, act=act, bn=False))

        i_l2 = [DBlock_inv1(conv, n_feats * 4, n_feats * 4, kernel_size, act=act, bn=False)]
        i_l2.append(BBlock(conv, n_feats * 4, n_feats * 8, kernel_size, act=act, bn=False))

        i_l1 = [DBlock_inv1(conv, n_feats * 2, n_feats * 2, kernel_size, act=act, bn=False)]
        i_l1.append(BBlock(conv, n_feats * 2, n_feats * 4, kernel_size, act=act, bn=False))

        i_l0 = [DBlock_inv1(conv, n_feats, n_feats, kernel_size, act=act, bn=False)]

        m_tail = [conv(n_feats, num_classes, kernel_size)]

        self.head = nn.Sequential(*m_head)
        self.d_l2 = nn.Sequential(*d_l2)
        self.d_l1 = nn.Sequential(*d_l1)
        self.d_l0 = nn.Sequential(*d_l0)
        self.pro_l3 = nn.Sequential(*pro_l3)
        self.i_l2 = nn.Sequential(*i_l2)
        self.i_l1 = nn.Sequential(*i_l1)
        self.i_l0 = nn.Sequential(*i_l0)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        x0 = self.d_l0(self.head(x))                     # level-0 features at input resolution
        x1 = self.d_l1(self.DWT(x0))                     # DWT: half resolution, 4x channels -> level-1
        x2 = self.d_l2(self.DWT(x1))                     # level-2
        x_ = self.IWT(self.pro_l3(self.DWT(x2))) + x2    # level-3 bottleneck, IWT back up + skip
        x_ = self.IWT(self.i_l2(x_)) + x1                # decode to level-1 resolution + skip
        x_ = self.IWT(self.i_l1(x_)) + x0                # decode to level-0 resolution + skip
        x_ = self.tail(self.i_l0(x_))                    # final conv to num_classes channels

        return x_

    def set_scale(self, scale_idx):
        self.scale_idx = scale_idx


def mwcnn(in_channels, num_classes):
    model = MWCNN(in_channels, num_classes)
    return model

# if __name__ == '__main__':
#     from loss.loss_function import segmentation_loss
#     criterion = segmentation_loss('dice', False)
#     mask = torch.ones(2, 128, 128).long()
#     model = mwcnn(1, 2)
#     model.train()
#     input1 = torch.rand(2, 1, 128, 128)
#     y = model(input1)
#     loss_train = criterion(y, mask)
#     loss_train.backward()
#     # print(output)
#     print(y.data.cpu().numpy().shape)
#     print(loss_train)

================================================
FILE: models/networks_2d/resunet.py
================================================
import torch
import torch.nn as nn
from torch.nn import init

def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)

class ResidualConv(nn.Module):
    def __init__(self, input_dim, output_dim, stride, padding):
        super(ResidualConv, self).__init__()

        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        )

    def forward(self, x):

        return self.conv_block(x) + self.conv_skip(x)
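
# Illustrative sketch, not part of the original file: ResidualConv is a
# pre-activation residual block whose skip path is a strided 3x3 convolution,
# so stride=2 halves the spatial size while changing the channel count. A
# minimal, hypothetical shape check matching the encoder wiring below:
def _residual_conv_shape_check():
    block = ResidualConv(64, 128, stride=2, padding=1)
    y = block(torch.rand(1, 64, 64, 64))
    assert y.shape == (1, 128, 32, 32)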


class Upsample(nn.Module):
    def __init__(self, input_dim, output_dim, kernel, stride):
        super(Upsample, self).__init__()

        self.upsample = nn.ConvTranspose2d(
            input_dim, output_dim, kernel_size=kernel, stride=stride
        )

    def forward(self, x):
        return self.upsample(x)


class ResUnet(nn.Module):
    def __init__(self, in_channels, num_classes, filters=[64, 128, 256, 512]):
        super(ResUnet, self).__init__()

        self.input_layer = nn.Sequential(
            nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1)
        )

        self.residual_conv_1 = ResidualConv(filters[0], filters[1], 2, 1)
        self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1)

        self.bridge = ResidualConv(filters[2], filters[3], 2, 1)

        self.upsample_1 = Upsample(filters[3], filters[3], 2, 2)
        self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1)

        self.upsample_2 = Upsample(filters[2], filters[2], 2, 2)
        self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1)

        self.upsample_3 = Upsample(filters[1], filters[1], 2, 2)
        self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], 1, 1)

        self.output_layer = nn.Conv2d(filters[0], num_classes, 1, 1)

    def forward(self, x):
        # Encode
        x1 = self.input_layer(x) + self.input_skip(x)
        x2 = self.residual_conv_1(x1)
        x3 = self.residual_conv_2(x2)
        # Bridge
        x4 = self.bridge(x3)
        # Decode
        x4 = self.upsample_1(x4)
        x5 = torch.cat([x4, x3], dim=1)

        x6 = self.up_residual_conv1(x5)

        x6 = self.upsample_2(x6)
        x7 = torch.cat([x6, x2], dim=1)

        x8 = self.up_residual_conv2(x7)

        x8 = self.upsample_3(x8)
        x9 = torch.cat([x8, x1], dim=1)

        x10 = self.up_residual_conv3(x9)

        output = self.output_layer(x10)

        return output


def res_unet(in_channels, num_classes):
    model = ResUnet(in_channels, num_classes)
    init_weights(model, 'kaiming')
    return model


# if __name__ == '__main__':
#     model = res_unet(1,10)
#     model.eval()
#     input = torch.rand(2,1,128,128)
#     output = model(input)
#     output = output.data.cpu().numpy()
#     # print(output)
#     print(output.shape)

================================================
FILE: models/networks_2d/resunet_plusplus.py
================================================
import torch.nn as nn
import torch
from torch.nn import init

def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)

class ResidualConv(nn.Module):
    def __init__(self, input_dim, output_dim, stride, padding):
        super(ResidualConv, self).__init__()

        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        )

    def forward(self, x):

        return self.conv_block(x) + self.conv_skip(x)


class Upsample(nn.Module):
    def __init__(self, input_dim, output_dim, kernel, stride):
        super(Upsample, self).__init__()

        self.upsample = nn.ConvTranspose2d(
            input_dim, output_dim, kernel_size=kernel, stride=stride
        )

    def forward(self, x):
        return self.upsample(x)

class Squeeze_Excite_Block(nn.Module):
    def __init__(self, channel, reduction=16):
        super(Squeeze_Excite_Block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
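
# Illustrative sketch, not part of the original file: the squeeze-and-excite
# block computes one gate per channel (global average pool -> two linear
# layers -> sigmoid) and rescales the input channel-wise, so the output shape
# matches the input. A minimal, hypothetical check:
def _squeeze_excite_shape_check():
    se = Squeeze_Excite_Block(channel=32, reduction=16)
    x = torch.rand(2, 32, 64, 64)
    assert se(x).shape == x.shape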

class ASPP(nn.Module):
    def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
        super(ASPP, self).__init__()

        self.aspp_block1 = nn.Sequential(
            nn.Conv2d(
                in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_dims),
        )
        self.aspp_block2 = nn.Sequential(
            nn.Conv2d(
                in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_dims),
        )
        self.aspp_block3 = nn.Sequential(
            nn.Conv2d(
                in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_dims),
        )

        self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
        self._init_weights()

    def forward(self, x):
        x1 = self.aspp_block1(x)
        x2 = self.aspp_block2(x)
        x3 = self.aspp_block3(x)
        out = torch.cat([x1, x2, x3], dim=1)
        return self.output(out)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
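
# Illustrative sketch, not part of the original file: ASPP applies three
# parallel 3x3 convolutions with dilation rates 6/12/18; because the padding
# equals the dilation rate, every branch preserves the spatial size, and the
# final 1x1 convolution fuses the concatenated branches back to out_dims
# channels. A minimal, hypothetical shape check matching aspp_bridge below:
def _aspp_shape_check():
    aspp = ASPP(in_dims=256, out_dims=512)
    y = aspp(torch.rand(1, 256, 16, 16))
    assert y.shape == (1, 512, 16, 16)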

class Upsample_(nn.Module):
    def __init__(self, scale=2):
        super(Upsample_, self).__init__()

        self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale, align_corners=True)

    def forward(self, x):
        return self.upsample(x)

class AttentionBlock(nn.Module):
    def __init__(self, input_encoder, input_decoder, output_dim):
        super(AttentionBlock, self).__init__()

        self.conv_encoder = nn.Sequential(
            nn.BatchNorm2d(input_encoder),
            nn.ReLU(),
            nn.Conv2d(input_encoder, output_dim, 3, padding=1),
            nn.MaxPool2d(2, 2),
        )

        self.conv_decoder = nn.Sequential(
            nn.BatchNorm2d(input_decoder),
            nn.ReLU(),
            nn.Conv2d(input_decoder, output_dim, 3, padding=1),
        )

        self.conv_attn = nn.Sequential(
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, 1, 1),
        )

    def forward(self, x1, x2):
        out = self.conv_encoder(x1) + self.conv_decoder(x2)
        out = self.conv_attn(out)
        return out * x2
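
# Illustrative sketch, not part of the original file: the attention block gates
# a low-resolution decoder feature (x2) with its matching encoder feature (x1);
# the encoder branch is max-pooled by 2 so both paths meet at the decoder
# resolution, and the result is a single-channel map that multiplies x2. A
# minimal, hypothetical shape check matching attn1 below:
def _attention_block_shape_check():
    attn = AttentionBlock(input_encoder=128, input_decoder=512, output_dim=512)
    x1 = torch.rand(1, 128, 32, 32)   # encoder feature, 2x the decoder resolution
    x2 = torch.rand(1, 512, 16, 16)   # decoder feature
    assert attn(x1, x2).shape == (1, 512, 16, 16)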



class ResUnetPlusPlus(nn.Module):
    def __init__(self, in_channels, num_classes, filters=[32, 64, 128, 256, 512]):
        super(ResUnetPlusPlus, self).__init__()

        self.input_layer = nn.Sequential(
            nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1)
        )

        self.squeeze_excite1 = Squeeze_Excite_Block(filters[0])

        self.residual_conv1 = ResidualConv(filters[0], filters[1], 2, 1)

        self.squeeze_excite2 = Squeeze_Excite_Block(filters[1])

        self.residual_conv2 = ResidualConv(filters[1], filters[2], 2, 1)

        self.squeeze_excite3 = Squeeze_Excite_Block(filters[2])

        self.residual_conv3 = ResidualConv(filters[2], filters[3], 2, 1)

        self.aspp_bridge = ASPP(filters[3], filters[4])

        self.attn1 = AttentionBlock(filters[2], filters[4], filters[4])
        self.upsample1 = Upsample_(2)
        self.up_residual_conv1 = ResidualConv(filters[4] + filters[2], filters[3], 1, 1)

        self.attn2 = AttentionBlock(filters[1], filters[3], filters[3])
        self.upsample2 = Upsample_(2)
        self.up_residual_conv2 = ResidualConv(filters[3] + filters[1], filters[2], 1, 1)

        self.attn3 = AttentionBlock(filters[0], filters[2], filters[2])
        self.upsample3 = Upsample_(2)
        self.up_residual_conv3 = ResidualConv(filters[2] + filters[0], filters[1], 1, 1)

        self.aspp_out = ASPP(filters[1], filters[0])

        self.output_layer = nn.Conv2d(filters[0], num_classes, 1)

    def forward(self, x):
        x1 = self.input_layer(x) + self.input_skip(x)

        x2 = self.squeeze_excite1(x1)
        x2 = self.residual_conv1(x2)

        x3 = self.squeeze_excite2(x2)
        x3 = self.residual_conv2(x3)

        x4 = self.squeeze_excite3(x3)
        x4 = self.residual_conv3(x4)

        x5 = self.aspp_bridge(x4)

        x6 = self.attn1(x3, x5)
        x6 = self.upsample1(x6)
        x6 = torch.cat([x6, x3], dim=1)
        x6 = self.up_residual_conv1(x6)

        x7 = self.attn2(x2, x6)
        x7 = self.upsample2(x7)
        x7 = torch.cat([x7, x2], dim=1)
        x7 = self.up_residual_conv2(x7)

        x8 = self.attn3(x1, x7)
        x8 = self.upsample3(x8)
        x8 = torch.cat([x8, x1], dim=1)
        x8 = self.up_residual_conv3(x8)

        x9 = self.aspp_out(x8)
        out = self.output_layer(x9)

        return out

def res_unet_plusplus(in_channels, num_classes):
    model = ResUnetPlusPlus(in_channels, num_classes)
    init_weights(model, 'kaiming')
    return model


# if __name__ == '__main__':
#     model = res_unet_plusplus(1,10)
#     model.eval()
#     input = torch.rand(2,1,128,128)
#     output = model(input)
#     output = output.data.cpu().numpy()
#     # print(output)
#     print(output.shape)

================================================
FILE: models/networks_2d/swinunet.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange
import torch.utils.checkpoint as checkpoint


logger = logging.getLogger(__name__)

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous(
    ).view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size,
                     window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
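
# Illustrative sketch, not part of the original file: window_partition tiles a
# (B, H, W, C) feature map into non-overlapping windows and window_reverse is
# its exact inverse, so a partition/reverse round trip returns the input
# unchanged. A minimal, hypothetical check (H and W divisible by window_size):
def _window_roundtrip_check():
    x = torch.rand(2, 56, 56, 96)
    windows = window_partition(x, window_size=7)   # (2*8*8, 7, 7, 96)
    assert windows.shape == (2 * 8 * 8, 7, 7, 96)
    assert torch.allclose(window_reverse(windows, 7, 56, 56), x)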


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - \
            coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - \
            1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index",
                             relative_position_index)
        # (preview truncated here; the remaining file contents are omitted)
================================================
SYMBOL INDEX (1024 symbols across 75 files)
================================================

FILE: config/augmentation/online_aug.py
  function data_transform_2d (line 6) | def data_transform_2d():
  function data_normalize_2d (line 30) | def data_normalize_2d(mean, std):
  function data_transform_aerial_lanenet (line 39) | def data_transform_aerial_lanenet(H, W):
  function data_transform_3d (line 47) | def data_transform_3d(normalization):

FILE: config/dataset_config/dataset_cfg.py
  function dataset_cfg (line 4) | def dataset_cfg(dataet_name):

FILE: config/eval_config/eval.py
  function evaluate (line 7) | def evaluate(y_scores, y_true, interval=0.02):
  function evaluate_multi (line 36) | def evaluate_multi(y_scores, y_true):

FILE: config/ramps/ramps.py
  function sigmoid_rampup (line 4) | def sigmoid_rampup(current, rampup_length):
  function linear_rampup (line 14) | def linear_rampup(current, rampup_length):
  function cosine_rampdown (line 23) | def cosine_rampdown(current, rampdown_length):

FILE: config/train_test_config/train_test_config.py
  function print_train_loss_sup (line 8) | def print_train_loss_sup(train_loss, num_batches, print_num, print_num_m...
  function print_train_loss_MT (line 15) | def print_train_loss_MT(train_loss_sup_1, train_loss_cps, train_loss, nu...
  function print_train_loss_ConResNet (line 25) | def print_train_loss_ConResNet(train_loss_seg, train_loss_res, train_los...
  function print_train_loss_EM (line 36) | def print_train_loss_EM(train_loss_sup_1, train_loss_cps, train_loss, nu...
  function print_train_loss_XNet (line 48) | def print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss...
  function print_val_loss_sup (line 59) | def print_val_loss_sup(val_loss, num_batches, print_num, print_num_minus):
  function print_val_loss (line 66) | def print_val_loss(val_loss_sup_1, val_loss_sup_2, num_batches, print_nu...
  function print_val_loss_ConResNet (line 74) | def print_val_loss_ConResNet(val_loss_seg, val_loss_res, num_batches, pr...
  function print_train_eval_sup (line 82) | def print_train_eval_sup(num_classes, score_list_train, mask_list_train,...
  function print_train_eval_XNet (line 103) | def print_train_eval_XNet(num_classes, score_list_train1, score_list_tra...
  function print_val_eval_sup (line 126) | def print_val_eval_sup(num_classes, score_list_val, mask_list_val, print...
  function print_val_eval (line 143) | def print_val_eval(num_classes, score_list_val1, score_list_val2, mask_l...
  function save_val_best_sup_2d (line 164) | def save_val_best_sup_2d(num_classes, best_list, model, score_list_val, ...
  function save_val_best_sup_3d (line 200) | def save_val_best_sup_3d(num_classes, best_list, model, score_list_val, ...
  function save_val_best_2d (line 216) | def save_val_best_2d(num_classes, best_model, best_list, best_result, mo...
  function save_val_best_3d (line 278) | def save_val_best_3d(num_classes, best_model, best_list, best_result, mo...
  function draw_pred_sup (line 310) | def draw_pred_sup(num_classes, mask_train_sup, mask_val, pred_train_sup,...
  function draw_pred_XNet (line 335) | def draw_pred_XNet(num_classes, mask_train, mask_val, pred_train_sup1, p...
  function draw_pred_MT (line 374) | def draw_pred_MT(num_classes, mask_train, mask_val, pred_train_sup1, out...
  function print_best_sup (line 407) | def print_best_sup(num_classes, best_val_list, print_num):
  function print_best (line 419) | def print_best(num_classes, best_val_list, best_model, best_result, path...
  function print_test_eval (line 439) | def print_test_eval(num_classes, score_list_test, mask_list_test, print_...
  function save_test_2d (line 456) | def save_test_2d(num_classes, score_list_test, name_list_test, threshold...
  function save_test_3d (line 482) | def save_test_3d(num_classes, score_test, name_test, threshold, path_seg...

FILE: config/visdom_config/visual_visdom.py
  function visdom_initialization_sup (line 4) | def visdom_initialization_sup(env, port):
  function visualization_sup (line 12) | def visualization_sup(vis, epoch, train_loss, train_m_jc, val_loss, val_...
  function visual_image_sup (line 18) | def visual_image_sup(vis, mask_train, pred_train, mask_val, pred_val):
  function visdom_initialization_XNet (line 26) | def visdom_initialization_XNet(env, port):
  function visualization_XNet (line 34) | def visualization_XNet(vis, epoch, train_loss, train_loss_sup1, train_lo...
  function visual_image_XNet (line 40) | def visual_image_XNet(vis, mask_train, pred_train1, pred_train2, mask_va...
  function visdom_initialization_MT (line 51) | def visdom_initialization_MT(env, port):
  function visualization_MT (line 59) | def visualization_MT(vis, epoch, train_loss, train_loss_sup1, train_loss...
  function visual_image_MT (line 65) | def visual_image_MT(vis, mask_train, pred_train1, mask_val, pred_val1, p...
  function visdom_initialization_EM (line 74) | def visdom_initialization_EM(env, port):
  function visualization_EM (line 82) | def visualization_EM(vis, epoch, train_loss, train_loss_sup1, train_loss...
  function visdom_initialization_ConResNet (line 89) | def visdom_initialization_ConResNet(env, port):
  function visualization_ConResNet (line 97) | def visualization_ConResNet(vis, epoch, train_loss, train_loss_seg, trai...

FILE: config/warmup_config/warmup.py
  class GradualWarmupScheduler (line 5) | class GradualWarmupScheduler(_LRScheduler):
    method __init__ (line 15) | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler...
    method get_lr (line 24) | def get_lr(self):
    method step_ReduceLROnPlateau (line 38) | def step_ReduceLROnPlateau(self, metrics, epoch=None):
    method step (line 52) | def step(self, epoch=None, metrics=None):

FILE: dataload/dataset_2d.py
  class dataset_itn (line 9) | class dataset_itn(Dataset):
    method __init__ (line 10) | def __init__(self, data_dir, input1, augmentation_1, normalize_1, sup=...
    method __getitem__ (line 57) | def __getitem__(self, index):
    method __len__ (line 89) | def __len__(self):
  function imagefloder_itn (line 93) | def imagefloder_itn(data_dir, input1, data_transform_1, data_normalize_1...
  class dataset_iitnn (line 105) | class dataset_iitnn(Dataset):
    method __init__ (line 106) | def __init__(self, data_dir, input1, input2, augmentation1, normalize_...
    method __getitem__ (line 165) | def __getitem__(self, index):
    method __len__ (line 210) | def __len__(self):
  function imagefloder_iitnn (line 214) | def imagefloder_iitnn(data_dir, input1, input2, data_transform_1, data_n...
  class dataset_wds (line 227) | class dataset_wds(Dataset):
    method __init__ (line 228) | def __init__(self, data_dir, augmentation1, normalize_LL, normalize_LH...
    method __getitem__ (line 268) | def __getitem__(self, index):
    method __len__ (line 315) | def __len__(self):
  function imagefloder_wds (line 319) | def imagefloder_wds(data_dir, data_transform_1, data_normalize_LL, data_...
  class dataset_aerial_lanenet (line 330) | class dataset_aerial_lanenet(Dataset):
    method __init__ (line 331) | def __init__(self, data_dir, augmentation1, normalize_1, normalize_l1,...
    method __getitem__ (line 357) | def __getitem__(self, index):
    method __len__ (line 395) | def __len__(self):
  function imagefloder_aerial_lanenet (line 399) | def imagefloder_aerial_lanenet(data_dir, data_transform, data_normalize,...

FILE: dataload/dataset_3d.py
  class dataset_it (line 12) | class dataset_it(Dataset):
    method __init__ (line 13) | def __init__(self, data_dir, input1, transform_1, queue_length=20, sam...
  class dataset_it_dtc (line 60) | class dataset_it_dtc(Dataset):
    method __init__ (line 61) | def __init__(self, data_dir, input1, num_classes, transform_1, queue_l...
  class dataset_iit (line 124) | class dataset_iit(Dataset):
    method __init__ (line 125) | def __init__(self, data_dir, input1, input2, transform_1, queue_length...
  class dataset_iit_conresnet (line 175) | class dataset_iit_conresnet(Dataset):
    method __init__ (line 176) | def __init__(self, data_dir, input1, input2, transform_1, queue_length...

FILE: loss/loss_function.py
  class MixSoftmaxCrossEntropyLoss (line 9) | class MixSoftmaxCrossEntropyLoss(nn.CrossEntropyLoss):
    method __init__ (line 10) | def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs):
    method _aux_forward (line 15) | def _aux_forward(self, output, target, **kwargs):
    method forward (line 24) | def forward(self, output, target):
  class BinaryDiceLoss (line 32) | class BinaryDiceLoss(nn.Module):
    method __init__ (line 47) | def __init__(self, smooth=1, p=2, reduction='mean'):
    method forward (line 53) | def forward(self, predict, target, valid_mask):
  class DiceLoss (line 74) | class DiceLoss(nn.Module):
    method __init__ (line 77) | def __init__(self, weight=None, aux=False, aux_weight=0.4, ignore_inde...
    method _base_forward (line 85) | def _base_forward(self, predict, target, valid_mask):
    method _aux_forward (line 102) | def _aux_forward(self, output, target, **kwargs):
    method forward (line 112) | def forward(self, output, target):
  function softmax_mse_loss (line 123) | def softmax_mse_loss(input_logits, target_logits, sigmoid=False):
  function entropy_loss (line 142) | def entropy_loss(p, C=2):
  class BCELossBoud (line 149) | class BCELossBoud(nn.Module):
    method __init__ (line 150) | def __init__(self, num_classes, weight=None, ignore_index=None, **kwar...
    method weighted_BCE_cross_entropy (line 158) | def weighted_BCE_cross_entropy(self, output, target, weights = None):
    method forward (line 168) | def forward(self, predict, target):
  class CustomKLLoss (line 187) | class CustomKLLoss(_Loss):
    method __init__ (line 193) | def __init__(self, *args, **kwargs):
    method forward (line 196) | def forward(self, mean, std):
  function segmentation_loss (line 201) | def segmentation_loss(loss='CE', aux=False, **kwargs):

FILE: models/getnetwork.py
  function get_network (line 5) | def get_network(network, in_channels, num_classes, **kwargs):

FILE: models/networks_2d/aerial_lanenet.py
  class basic_block (line 10) | class basic_block(nn.Module):
    method __init__ (line 11) | def __init__(self, ch_in, ch_out):
    method forward (line 16) | def forward(self, x):
  class Aerial_LaneNet (line 20) | class Aerial_LaneNet(nn.Module):
    method __init__ (line 21) | def __init__(self, in_channels, num_classes):
    method forward (line 96) | def forward(self, x, x_wavelet_1, x_wavelet_2, x_wavelet_3, x_wavelet_4):

FILE: models/networks_2d/hrnet.py
  function load_url (line 37) | def load_url(url, model_dir='./pretrained', map_location=None):
  function conv3x3 (line 47) | def conv3x3(in_planes, out_planes, stride=1):
  class BasicBlock (line 53) | class BasicBlock(nn.Module):
    method __init__ (line 56) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 66) | def forward(self, x):
  class Bottleneck (line 85) | class Bottleneck(nn.Module):
    method __init__ (line 88) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 103) | def forward(self, x):
  class HighResolutionModule (line 126) | class HighResolutionModule(nn.Module):
    method __init__ (line 127) | def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
    method _check_branches (line 144) | def _check_branches(self, num_branches, blocks, num_blocks,
    method _make_one_branch (line 164) | def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
    method _make_branches (line 188) | def _make_branches(self, num_branches, block, num_blocks, num_channels):
    method _make_fuse_layers (line 197) | def _make_fuse_layers(self):
    method get_num_inchannels (line 243) | def get_num_inchannels(self):
    method forward (line 246) | def forward(self, x):
  class HighResolutionNet (line 279) | class HighResolutionNet(nn.Module):
    method __init__ (line 281) | def __init__(self, in_channels, extra, num_classes,**kwargs):
    method _make_transition_layer (line 349) | def _make_transition_layer(
    method _make_layer (line 385) | def _make_layer(self, block, inplanes, planes, blocks, stride=1):
    method _make_stage (line 402) | def _make_stage(self, layer_config, num_inchannels,
    method forward (line 431) | def forward(self, x):
    method init_weights (line 485) | def init_weights(self, pretrained='', ):
  function init_weights (line 553) | def init_weights(net, init_type='normal', gain=0.02):
  function hrnet18 (line 592) | def hrnet18(in_channels, num_classes):
  function hrnet32 (line 597) | def hrnet32(in_channels, num_classes):
  function hrnet48 (line 602) | def hrnet48(in_channels, num_classes):
  function hrnet64 (line 607) | def hrnet64(in_channels, num_classes):

FILE: models/networks_2d/mwcnn.py
  function default_conv (line 9) | def default_conv(in_channels, out_channels, kernel_size, bias=True, dila...
  function default_conv1 (line 15) | def default_conv1(in_channels, out_channels, kernel_size, bias=True, gro...
  function channel_shuffle (line 23) | def channel_shuffle(x, groups):
  function pixel_down_shuffle (line 40) | def pixel_down_shuffle(x, downsacale_factor):
  function sp_init (line 54) | def sp_init(x):
  function dwt_init (line 65) | def dwt_init(x):
  function iwt_init (line 80) | def iwt_init(x):
  class Channel_Shuffle (line 102) | class Channel_Shuffle(nn.Module):
    method __init__ (line 103) | def __init__(self, conv_groups):
    method forward (line 108) | def forward(self, x):
  class SP (line 112) | class SP(nn.Module):
    method __init__ (line 113) | def __init__(self):
    method forward (line 117) | def forward(self, x):
  class Pixel_Down_Shuffle (line 121) | class Pixel_Down_Shuffle(nn.Module):
    method __init__ (line 122) | def __init__(self):
    method forward (line 126) | def forward(self, x):
  class DWT (line 130) | class DWT(nn.Module):
    method __init__ (line 131) | def __init__(self):
    method forward (line 135) | def forward(self, x):
  class IWT (line 139) | class IWT(nn.Module):
    method __init__ (line 140) | def __init__(self):
    method forward (line 144) | def forward(self, x):
  class MeanShift (line 148) | class MeanShift(nn.Conv2d):
    method __init__ (line 149) | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
  class MeanShift2 (line 162) | class MeanShift2(nn.Conv2d):
    method __init__ (line 163) | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
  class BasicBlock (line 175) | class BasicBlock(nn.Sequential):
    method __init__ (line 176) | def __init__(
  class BBlock (line 189) | class BBlock(nn.Module):
    method __init__ (line 190) | def __init__(
    method forward (line 202) | def forward(self, x):
  class DBlock_com (line 207) | class DBlock_com(nn.Module):
    method __init__ (line 208) | def __init__(
    method forward (line 225) | def forward(self, x):
  class DBlock_inv (line 230) | class DBlock_inv(nn.Module):
    method __init__ (line 231) | def __init__(
    method forward (line 248) | def forward(self, x):
  class DBlock_com1 (line 253) | class DBlock_com1(nn.Module):
    method __init__ (line 254) | def __init__(
    method forward (line 271) | def forward(self, x):
  class DBlock_inv1 (line 276) | class DBlock_inv1(nn.Module):
    method __init__ (line 277) | def __init__(
    method forward (line 294) | def forward(self, x):
  class DBlock_com2 (line 299) | class DBlock_com2(nn.Module):
    method __init__ (line 300) | def __init__(
    method forward (line 317) | def forward(self, x):
  class DBlock_inv2 (line 322) | class DBlock_inv2(nn.Module):
    method __init__ (line 323) | def __init__(
    method forward (line 340) | def forward(self, x):
  class ShuffleBlock (line 345) | class ShuffleBlock(nn.Module):
    method __init__ (line 346) | def __init__(
    method forward (line 359) | def forward(self, x):
  class DWBlock (line 364) | class DWBlock(nn.Module):
    method __init__ (line 365) | def __init__(
    method forward (line 382) | def forward(self, x):
  class ResBlock (line 387) | class ResBlock(nn.Module):
    method __init__ (line 388) | def __init__(
    method forward (line 402) | def forward(self, x):
  class Block (line 409) | class Block(nn.Module):
    method __init__ (line 410) | def __init__(
    method forward (line 424) | def forward(self, x):
  class Upsampler (line 431) | class Upsampler(nn.Sequential):
    method __init__ (line 432) | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
  class MWCNN (line 452) | class MWCNN(nn.Module):
    method __init__ (line 453) | def __init__(self, in_channels, num_classes, conv=default_conv):
    method forward (line 501) | def forward(self, x):
    method set_scale (line 512) | def set_scale(self, scale_idx):
  function mwcnn (line 516) | def mwcnn(in_channels, num_classes):

FILE: models/networks_2d/resunet.py
  function init_weights (line 5) | def init_weights(net, init_type='normal', gain=0.02):
  class ResidualConv (line 28) | class ResidualConv(nn.Module):
    method __init__ (line 29) | def __init__(self, input_dim, output_dim, stride, padding):
    method forward (line 47) | def forward(self, x):
  class Upsample (line 52) | class Upsample(nn.Module):
    method __init__ (line 53) | def __init__(self, input_dim, output_dim, kernel, stride):
    method forward (line 60) | def forward(self, x):
  class ResUnet (line 64) | class ResUnet(nn.Module):
    method __init__ (line 65) | def __init__(self, in_channels, num_classes, filters=[64, 128, 256, 51...
    method forward (line 94) | def forward(self, x):
  function res_unet (line 122) | def res_unet(in_channels, num_classes):

FILE: models/networks_2d/resunet_plusplus.py
  function init_weights (line 5) | def init_weights(net, init_type='normal', gain=0.02):
  class ResidualConv (line 28) | class ResidualConv(nn.Module):
    method __init__ (line 29) | def __init__(self, input_dim, output_dim, stride, padding):
    method forward (line 47) | def forward(self, x):
  class Upsample (line 52) | class Upsample(nn.Module):
    method __init__ (line 53) | def __init__(self, input_dim, output_dim, kernel, stride):
    method forward (line 60) | def forward(self, x):
  class Squeeze_Excite_Block (line 63) | class Squeeze_Excite_Block(nn.Module):
    method __init__ (line 64) | def __init__(self, channel, reduction=16):
    method forward (line 74) | def forward(self, x):
  class ASPP (line 80) | class ASPP(nn.Module):
    method __init__ (line 81) | def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
    method forward (line 109) | def forward(self, x):
    method _init_weights (line 116) | def _init_weights(self):
  class Upsample_ (line 124) | class Upsample_(nn.Module):
    method __init__ (line 125) | def __init__(self, scale=2):
    method forward (line 130) | def forward(self, x):
  class AttentionBlock (line 133) | class AttentionBlock(nn.Module):
    method __init__ (line 134) | def __init__(self, input_encoder, input_decoder, output_dim):
    method forward (line 156) | def forward(self, x1, x2):
  class ResUnetPlusPlus (line 163) | class ResUnetPlusPlus(nn.Module):
    method __init__ (line 164) | def __init__(self, in_channels, num_classes, filters=[32, 64, 128, 256...
    method forward (line 207) | def forward(self, x):
  function res_unet_plusplus (line 241) | def res_unet_plusplus(in_channels, num_classes):

FILE: models/networks_2d/swinunet.py
  class Mlp (line 25) | class Mlp(nn.Module):
    method __init__ (line 26) | def __init__(self, in_features, hidden_features=None, out_features=Non...
    method forward (line 35) | def forward(self, x):
  function window_partition (line 44) | def window_partition(x, window_size):
  function window_reverse (line 59) | def window_reverse(windows, window_size, H, W):
  class WindowAttention (line 76) | class WindowAttention(nn.Module):
    method __init__ (line 89) | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scal...
    method forward (line 127) | def forward(self, x, mask=None):
    method extra_repr (line 164) | def extra_repr(self) -> str:
    method flops (line 167) | def flops(self, N):
  class SwinTransformerBlock (line 181) | class SwinTransformerBlock(nn.Module):
    method __init__ (line 199) | def __init__(self, dim, input_resolution, num_heads, window_size=7, sh...
    method forward (line 255) | def forward(self, x):
    method extra_repr (line 301) | def extra_repr(self) -> str:
    method flops (line 305) | def flops(self):
  class PatchMerging (line 320) | class PatchMerging(nn.Module):
    method __init__ (line 328) | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
    method forward (line 335) | def forward(self, x):
    method extra_repr (line 358) | def extra_repr(self) -> str:
    method flops (line 361) | def flops(self):
  class PatchExpand (line 368) | class PatchExpand(nn.Module):
    method __init__ (line 369) | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.L...
    method forward (line 377) | def forward(self, x):
  class FinalPatchExpand_X4 (line 395) | class FinalPatchExpand_X4(nn.Module):
    method __init__ (line 396) | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.L...
    method forward (line 405) | def forward(self, x):
  class BasicLayer (line 422) | class BasicLayer(nn.Module):
    method __init__ (line 441) | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
    method forward (line 472) | def forward(self, x):
    method extra_repr (line 482) | def extra_repr(self) -> str:
    method flops (line 485) | def flops(self):
  class BasicLayer_up (line 494) | class BasicLayer_up(nn.Module):
    method __init__ (line 513) | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
    method forward (line 544) | def forward(self, x):
  class PatchEmbed (line 555) | class PatchEmbed(nn.Module):
    method __init__ (line 565) | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=9...
    method forward (line 586) | def forward(self, x):
    method flops (line 596) | def flops(self):
  class SwinTransformerSys (line 605) | class SwinTransformerSys(nn.Module):
    method __init__ (line 630) | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes...
    method _init_weights (line 733) | def _init_weights(self, m):
    method no_weight_decay (line 743) | def no_weight_decay(self):
    method no_weight_decay_keywords (line 747) | def no_weight_decay_keywords(self):
    method forward_features (line 751) | def forward_features(self, x):
    method forward_up_features (line 767) | def forward_up_features(self, x, x_downsample):
    method up_x4 (line 780) | def up_x4(self, x):
    method forward (line 793) | def forward(self, x):
    method flops (line 800) | def flops(self):
  class SwinUnet (line 811) | class SwinUnet(nn.Module):
    method __init__ (line 812) | def __init__(self, num_classes, img_size, zero_head=False, vis=False):
    method forward (line 833) | def forward(self, x):
    method load_from (line 839) | def load_from(self, config):
  function swinunet (line 877) | def swinunet(num_classes, img_size):

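Note that the swinunet factory takes (num_classes, img_size) rather than the (in_channels, num_classes) pair used by most other 2D factories in this folder. A minimal construction-and-call sketch, not taken from the repository, assuming the default 3-channel input implied by PatchEmbed's in_chans=3 and an input resolution equal to img_size:

# Hypothetical usage sketch (assumptions noted above, not repository code).
import torch
from models.networks_2d.swinunet import swinunet

model = swinunet(num_classes=2, img_size=224)
x = torch.rand(1, 3, 224, 224)             # (N, C, H, W), H = W = img_size
with torch.no_grad():
    logits = model(x)                      # per-pixel class logits (exact shape not asserted here)
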
FILE: models/networks_2d/u2net.py
  function init_weights (line 6) | def init_weights(net, init_type='normal', gain=0.02):
  class REBNCONV (line 30) | class REBNCONV(nn.Module):
    method __init__ (line 31) | def __init__(self,in_ch=3,out_ch=3,dirate=1):
    method forward (line 38) | def forward(self,x):
  function _upsample_like (line 46) | def _upsample_like(src,tar):
  class RSU7 (line 54) | class RSU7(nn.Module):#UNet07DRES(nn.Module):
    method __init__ (line 56) | def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
    method forward (line 87) | def forward(self,x):
  class RSU6 (line 131) | class RSU6(nn.Module):#UNet06DRES(nn.Module):
    method __init__ (line 133) | def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
    method forward (line 160) | def forward(self,x):
  class RSU5 (line 200) | class RSU5(nn.Module):#UNet05DRES(nn.Module):
    method __init__ (line 202) | def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
    method forward (line 225) | def forward(self,x):
  class RSU4 (line 258) | class RSU4(nn.Module):#UNet04DRES(nn.Module):
    method __init__ (line 260) | def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
    method forward (line 279) | def forward(self,x):
  class RSU4F (line 306) | class RSU4F(nn.Module):#UNet04FRES(nn.Module):
    method __init__ (line 308) | def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
    method forward (line 323) | def forward(self,x):
  class U2NET (line 343) | class U2NET(nn.Module):
    method __init__ (line 345) | def __init__(self,in_ch=3,out_ch=1):
    method forward (line 381) | def forward(self,x):
  class U2NETP (line 448) | class U2NETP(nn.Module):
    method __init__ (line 450) | def __init__(self,in_ch=3,out_ch=1):
    method forward (line 486) | def forward(self,x):
  function u2net (line 552) | def u2net(in_channels, num_classes):
  function u2net_small (line 557) | def u2net_small(in_channels, num_classes):

FILE: models/networks_2d/unet.py
  function init_weights (line 7) | def init_weights(net, init_type='normal', gain=0.02):
  class conv_block (line 31) | class conv_block(nn.Module):
    method __init__ (line 32) | def __init__(self, ch_in, ch_out):
    method forward (line 43) | def forward(self, x):
  class up_conv (line 48) | class up_conv(nn.Module):
    method __init__ (line 49) | def __init__(self, ch_in, ch_out):
    method forward (line 58) | def forward(self, x):
  class Recurrent_block (line 63) | class Recurrent_block(nn.Module):
    method __init__ (line 64) | def __init__(self, ch_out, t=2):
    method forward (line 74) | def forward(self, x):
  class RRCNN_block (line 84) | class RRCNN_block(nn.Module):
    method __init__ (line 85) | def __init__(self, ch_in, ch_out, t=2):
    method forward (line 93) | def forward(self, x):
  class single_conv (line 99) | class single_conv(nn.Module):
    method __init__ (line 100) | def __init__(self, ch_in, ch_out):
    method forward (line 108) | def forward(self, x):
  class Attention_block (line 113) | class Attention_block(nn.Module):
    method __init__ (line 114) | def __init__(self, F_g, F_l, F_int):
    method forward (line 134) | def forward(self, g, x):
  class U_Net (line 143) | class U_Net(nn.Module):
    method __init__ (line 144) | def __init__(self, in_channels=3, num_classes=1):
    method forward (line 169) | def forward(self, x):
  class R2U_Net (line 210) | class R2U_Net(nn.Module):
    method __init__ (line 211) | def __init__(self, in_channels=3, num_classes=1, t=2):
    method forward (line 241) | def forward(self, x):
  class AttU_Net (line 281) | class AttU_Net(nn.Module):
    method __init__ (line 282) | def __init__(self, in_channels=3, num_classes=1):
    method forward (line 311) | def forward(self, x):
  class R2AttU_Net (line 355) | class R2AttU_Net(nn.Module):
    method __init__ (line 356) | def __init__(self, in_channels=3, num_classes=1, t=2):
    method forward (line 390) | def forward(self, x):
  function unet (line 434) | def unet(in_channels, num_classes):
  function r2_unet (line 439) | def r2_unet(in_channels, num_classes):
  function attention_unet (line 444) | def attention_unet(in_channels, num_classes):
  function r2_attention_unet (line 449) | def r2_attention_unet(in_channels, num_classes):

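The four factories at the end of this file share the (in_channels, num_classes) signature used throughout models/networks_2d. A minimal sketch, not taken from the repository, of building one of them and running a dummy forward pass; the output is assumed to be a per-pixel logit map:

# Hypothetical usage sketch (not repository code).
import torch
from models.networks_2d.unet import unet, attention_unet

model = unet(in_channels=1, num_classes=2)     # grayscale input, 2-class segmentation
x = torch.rand(2, 1, 128, 128)                 # (N, C, H, W)
with torch.no_grad():
    logits = model(x)                          # assumed shape: (N, num_classes, H, W)

# r2_unet, attention_unet and r2_attention_unet follow the same pattern, e.g.:
# model = attention_unet(in_channels=1, num_classes=2)
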
FILE: models/networks_2d/unet_3plus.py
  function weights_init_normal (line 8) | def weights_init_normal(m):
  function weights_init_xavier (line 20) | def weights_init_xavier(m):
  function weights_init_kaiming (line 32) | def weights_init_kaiming(m):
  function weights_init_orthogonal (line 44) | def weights_init_orthogonal(m):
  function init_weights (line 56) | def init_weights(net, init_type='normal'):
  class unetConv2 (line 70) | class unetConv2(nn.Module):
    method __init__ (line 71) | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=...
    method forward (line 98) | def forward(self, inputs):
  class UNet_3Plus (line 110) | class UNet_3Plus(nn.Module):
    method __init__ (line 112) | def __init__(self, in_channels, num_classes):
    method forward (line 294) | def forward(self, inputs):
  class UNet_3Plus_DeepSup (line 352) | class UNet_3Plus_DeepSup(nn.Module):
    method __init__ (line 353) | def __init__(self, in_channels=3, num_classes=1, feature_scale=4, is_d...
    method forward (line 543) | def forward(self, inputs):
  class UNet_3Plus_DeepSup_CGM (line 613) | class UNet_3Plus_DeepSup_CGM(nn.Module):
    method __init__ (line 615) | def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_dec...
    method dotProduct (line 811) | def dotProduct(self, seg, cls):
    method forward (line 818) | def forward(self, inputs):
  function unet_3plus (line 891) | def unet_3plus(in_channels, num_classes):
  function unet_3plus_ds (line 895) | def unet_3plus_ds(in_channels, num_classes):
  function unet_3plus_ds_cgm (line 899) | def unet_3plus_ds_cgm(in_channels, num_classes):

FILE: models/networks_2d/unet_cct.py
  function init_weights (line 7) | def init_weights(net, init_type='normal', gain=0.02):
  class ConvBlock (line 30) | class ConvBlock(nn.Module):
    method __init__ (line 33) | def __init__(self, in_channels, out_channels, dropout_p):
    method forward (line 45) | def forward(self, x):
  class DownBlock (line 49) | class DownBlock(nn.Module):
    method __init__ (line 52) | def __init__(self, in_channels, out_channels, dropout_p):
    method forward (line 60) | def forward(self, x):
  class UpBlock (line 64) | class UpBlock(nn.Module):
    method __init__ (line 67) | def __init__(self, in_channels1, in_channels2, out_channels, dropout_p,
    method forward (line 80) | def forward(self, x1, x2):
  class Encoder (line 87) | class Encoder(nn.Module):
    method __init__ (line 88) | def __init__(self, params):
    method forward (line 108) | def forward(self, x):
  class Decoder (line 116) | class Decoder(nn.Module):
    method __init__ (line 117) | def __init__(self, params):
    method forward (line 138) | def forward(self, feature):
  function Dropout (line 153) | def Dropout(x, p=0.3):
  function FeatureDropout (line 158) | def FeatureDropout(x):
  class FeatureNoise (line 169) | class FeatureNoise(nn.Module):
    method __init__ (line 170) | def __init__(self, uniform_range=0.3):
    method feature_based_noise (line 174) | def feature_based_noise(self, x):
    method forward (line 180) | def forward(self, x):
  class UNet_CCT (line 184) | class UNet_CCT(nn.Module):
    method __init__ (line 185) | def __init__(self, in_chns, class_num):
    method forward (line 200) | def forward(self, x):
  function unet_cct (line 211) | def unet_cct(in_channels, num_classes):

FILE: models/networks_2d/unet_plusplus.py
  function init_weights (line 5) | def init_weights(net, init_type='normal', gain=0.02):
  class VGGBlock (line 29) | class VGGBlock(nn.Module):
    method __init__ (line 30) | def __init__(self, in_channels, middle_channels, out_channels):
    method forward (line 38) | def forward(self, x):
  class NestedUNet (line 50) | class NestedUNet(nn.Module):
    method __init__ (line 51) | def __init__(self, num_classes, input_channels=3, deep_supervision=Fal...
    method forward (line 90) | def forward(self, input):
  function unet_plusplus (line 129) | def unet_plusplus(in_channels, num_classes):

FILE: models/networks_2d/unet_urpc.py
  function init_weights (line 8) | def init_weights(net, init_type='normal', gain=0.02):
  class ConvBlock (line 31) | class ConvBlock(nn.Module):
    method __init__ (line 34) | def __init__(self, in_channels, out_channels, dropout_p):
    method forward (line 46) | def forward(self, x):
  class Encoder (line 49) | class Encoder(nn.Module):
    method __init__ (line 50) | def __init__(self, params):
    method forward (line 70) | def forward(self, x):
  class DownBlock (line 78) | class DownBlock(nn.Module):
    method __init__ (line 81) | def __init__(self, in_channels, out_channels, dropout_p):
    method forward (line 89) | def forward(self, x):
  class UpBlock (line 93) | class UpBlock(nn.Module):
    method __init__ (line 96) | def __init__(self, in_channels1, in_channels2, out_channels, dropout_p,
    method forward (line 109) | def forward(self, x1, x2):
  class FeatureNoise (line 116) | class FeatureNoise(nn.Module):
    method __init__ (line 117) | def __init__(self, uniform_range=0.3):
    method feature_based_noise (line 121) | def feature_based_noise(self, x):
    method forward (line 127) | def forward(self, x):
  function Dropout (line 131) | def Dropout(x, p=0.3):
  function FeatureDropout (line 136) | def FeatureDropout(x):
  class Decoder_URPC (line 146) | class Decoder_URPC(nn.Module):
    method __init__ (line 147) | def __init__(self, params):
    method forward (line 177) | def forward(self, feature, shape):
  class UNet_URPC (line 213) | class UNet_URPC(nn.Module):
    method __init__ (line 214) | def __init__(self, in_chns, class_num):
    method forward (line 226) | def forward(self, x):
  function unet_urpc (line 233) | def unet_urpc(in_channels, num_classes):

FILE: models/networks_2d/wavesnet.py
  class My_DownSampling_SC (line 11) | class My_DownSampling_SC(nn.Module):
    method __init__ (line 12) | def __init__(self, in_channel, out_channel, kernel_size = (1,1), strid...
    method forward (line 16) | def forward(self, input):
  class My_DownSampling_MP (line 20) | class My_DownSampling_MP(nn.Module):
    method __init__ (line 21) | def __init__(self, stride = 2, kernel_size = 2):
    method forward (line 25) | def forward(self, input):
  class My_UpSampling_SC (line 29) | class My_UpSampling_SC(nn.Module):
    method __init__ (line 30) | def __init__(self, in_channel, out_channel, kernel_size = (1,1), strid...
    method forward (line 34) | def forward(self, input, feature_map):
  class My_DownSampling_DWT (line 38) | class My_DownSampling_DWT(nn.Module):
    method __init__ (line 39) | def __init__(self, wavename = 'haar'):
    method forward (line 43) | def forward(self, input):
  class My_UpSampling_IDWT (line 48) | class My_UpSampling_IDWT(nn.Module):
    method __init__ (line 49) | def __init__(self, wavename = 'haar'):
    method forward (line 53) | def forward(self, LL, LH, HL, HH, feature_map):
  class My_Sequential (line 58) | class My_Sequential(Module):
    method __init__ (line 65) | def __init__(self, *args):
    method _get_item_by_idx (line 74) | def _get_item_by_idx(self, iterator, idx):
    method __getitem__ (line 83) | def __getitem__(self, idx):
    method __setitem__ (line 89) | def __setitem__(self, idx, module):
    method __delitem__ (line 93) | def __delitem__(self, idx):
    method __len__ (line 101) | def __len__(self):
    method __dir__ (line 104) | def __dir__(self):
    method forward (line 109) | def forward(self, input):
  class My_Sequential_re (line 123) | class My_Sequential_re(Module):
    method __init__ (line 130) | def __init__(self, *args):
    method _get_item_by_idx (line 140) | def _get_item_by_idx(self, iterator, idx):
    method __getitem__ (line 149) | def __getitem__(self, idx):
    method __setitem__ (line 155) | def __setitem__(self, idx, module):
    method __delitem__ (line 159) | def __delitem__(self, idx):
    method __len__ (line 167) | def __len__(self):
    method __dir__ (line 170) | def __dir__(self):
    method forward (line 175) | def forward(self, *input):
  class DWTFunction_1D (line 207) | class DWTFunction_1D(Function):
    method forward (line 209) | def forward(ctx, input, matrix_Low, matrix_High):
    method backward (line 215) | def backward(ctx, grad_L, grad_H):
  class IDWTFunction_1D (line 221) | class IDWTFunction_1D(Function):
    method forward (line 223) | def forward(ctx, input_L, input_H, matrix_L, matrix_H):
    method backward (line 228) | def backward(ctx, grad_output):
  class DWTFunction_2D (line 235) | class DWTFunction_2D(Function):
    method forward (line 237) | def forward(ctx, input, matrix_Low_0, matrix_Low_1, matrix_High_0, mat...
    method backward (line 247) | def backward(ctx, grad_LL, grad_LH, grad_HL, grad_HH):
  class DWTFunction_2D_tiny (line 255) | class DWTFunction_2D_tiny(Function):
    method forward (line 257) | def forward(ctx, input, matrix_Low_0, matrix_Low_1, matrix_High_0, mat...
    method backward (line 263) | def backward(ctx, grad_LL):
  class IDWTFunction_2D (line 270) | class IDWTFunction_2D(Function):
    method forward (line 272) | def forward(ctx, input_LL, input_LH, input_HL, input_HH,
    method backward (line 280) | def backward(ctx, grad_output):
  class DWTFunction_3D (line 291) | class DWTFunction_3D(Function):
    method forward (line 293) | def forward(ctx, input,
    method backward (line 315) | def backward(ctx, grad_LLL, grad_LLH, grad_LHL, grad_LHH,
  class IDWTFunction_3D (line 328) | class IDWTFunction_3D(Function):
    method forward (line 330) | def forward(ctx, input_LLL, input_LLH, input_LHL, input_LHH,
    method backward (line 345) | def backward(ctx, grad_output):
  class DWT_1D (line 364) | class DWT_1D(Module):
    method __init__ (line 370) | def __init__(self, wavename):
    method get_matrix (line 384) | def get_matrix(self):
    method forward (line 413) | def forward(self, input):
  class IDWT_1D (line 421) | class IDWT_1D(Module):
    method __init__ (line 427) | def __init__(self, wavename):
    method get_matrix (line 443) | def get_matrix(self):
    method forward (line 472) | def forward(self, L, H):
  class DWT_2D (line 480) | class DWT_2D(Module):
    method __init__ (line 488) | def __init__(self, wavename):
    method get_matrix (line 502) | def get_matrix(self):
    method forward (line 547) | def forward(self, input):
  class DWT_2D_tiny (line 557) | class DWT_2D_tiny(Module):
    method __init__ (line 562) | def __init__(self, wavename):
    method get_matrix (line 576) | def get_matrix(self):
    method forward (line 621) | def forward(self, input):
  class IDWT_2D (line 630) | class IDWT_2D(Module):
    method __init__ (line 638) | def __init__(self, wavename):
    method get_matrix (line 654) | def get_matrix(self):
    method forward (line 698) | def forward(self, LL, LH, HL, HH):
  class DWT_3D (line 707) | class DWT_3D(Module):
    method __init__ (line 719) | def __init__(self, wavename):
    method get_matrix (line 733) | def get_matrix(self):
    method forward (line 786) | def forward(self, input):
  class IDWT_3D (line 797) | class IDWT_3D(Module):
    method __init__ (line 809) | def __init__(self, wavename):
    method get_matrix (line 825) | def get_matrix(self):
    method forward (line 878) | def forward(self, LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH):
  class SegNet_VGG (line 918) | class SegNet_VGG(nn.Module):
    method __init__ (line 919) | def __init__(self, features, num_classes = 21, init_weights = True, wa...
    method forward (line 931) | def forward(self, x):
    method _initialize_weights (line 938) | def _initialize_weights(self):
    method __str__ (line 955) | def __str__(self):
  class WSegNet_VGG (line 959) | class WSegNet_VGG(nn.Module):
    method __init__ (line 960) | def __init__(self, features, num_classes, init_weights = True, wavenam...
    method forward (line 972) | def forward(self, x):
    method _initialize_weights (line 979) | def _initialize_weights(self):
    method __str__ (line 996) | def __str__(self):
  function make_layers (line 1000) | def make_layers(cfg, batch_norm = False):
  function make_w_layers (line 1037) | def make_w_layers(cfg, in_channels, batch_norm = False, wavename = 'haar'):
  function segnet_vgg11 (line 1080) | def segnet_vgg11(pretrained = False, **kwargs):
  function segnet_vgg11_bn (line 1091) | def segnet_vgg11_bn(pretrained=False, **kwargs):
  function segnet_vgg13 (line 1102) | def segnet_vgg13(pretrained=False, **kwargs):
  function segnet_vgg13_bn (line 1113) | def segnet_vgg13_bn(pretrained=False, **kwargs):
  function segnet_vgg16 (line 1124) | def segnet_vgg16(pretrained=False, **kwargs):
  function segnet_vgg16_bn (line 1135) | def segnet_vgg16_bn(pretrained=False, **kwargs):
  function segnet_vgg19 (line 1146) | def segnet_vgg19(pretrained=False, **kwargs):
  function segnet_vgg19_bn (line 1157) | def segnet_vgg19_bn(pretrained=False, **kwargs):
  function wsegnet_vgg11 (line 1169) | def wsegnet_vgg11(pretrained = False, wavename = 'haar', **kwargs):
  function wsegnet_vgg11_bn (line 1180) | def wsegnet_vgg11_bn(pretrained=False, wavename = 'haar', **kwargs):
  function wsegnet_vgg13 (line 1191) | def wsegnet_vgg13(pretrained=False, wavename = 'haar', **kwargs):
  function wsegnet_vgg13_bn (line 1202) | def wsegnet_vgg13_bn(pretrained=False, wavename = 'haar', **kwargs):
  function wsegnet_vgg16 (line 1213) | def wsegnet_vgg16(pretrained=False, wavename = 'haar', **kwargs):
  function wsegnet_vgg16_bn (line 1224) | def wsegnet_vgg16_bn(in_channels, num_classes, pretrained=False, wavenam...
  function wsegnet_vgg19 (line 1235) | def wsegnet_vgg19(pretrained=False, wavename = 'haar', **kwargs):
  function wsegnet_vgg19_bn (line 1246) | def wsegnet_vgg19_bn(pretrained=False, wavename = 'haar', **kwargs):

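DWT_2D and IDWT_2D implement a matrix-based 2D discrete wavelet transform and its inverse. A minimal round-trip sketch, not taken from the repository, assuming DWT_2D.forward returns the four subbands (LL, LH, HL, HH) in the order expected by IDWT_2D.forward:

# Hypothetical round-trip sketch (assumes DWT_2D returns LL, LH, HL, HH).
import torch
from models.networks_2d.wavesnet import DWT_2D, IDWT_2D

dwt = DWT_2D(wavename='haar')
idwt = IDWT_2D(wavename='haar')

x = torch.rand(1, 3, 64, 64)               # (N, C, H, W), even spatial size
LL, LH, HL, HH = dwt(x)                    # each subband at half the input resolution
x_rec = idwt(LL, LH, HL, HH)               # reconstruction back to the input resolution
print(LL.shape, x_rec.shape)               # Haar analysis/synthesis should be near-lossless
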
FILE: models/networks_2d/wds.py
  class basic_block (line 9) | class basic_block(nn.Module):
    method __init__ (line 10) | def __init__(self, ch_in, ch_out):
    method forward (line 15) | def forward(self, x):
  class WDS (line 19) | class WDS(nn.Module):
    method __init__ (line 20) | def __init__(self, in_channels, num_classes):
    method forward (line 92) | def forward(self, LL, LH, HL, HH):

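WDS consumes the four wavelet subbands directly instead of the raw image. A minimal sketch, not taken from the repository, that feeds it Haar subbands produced by DWT_2D from wavesnet.py; the structure of the returned value is not asserted:

# Hypothetical usage sketch (not repository code).
import torch
from models.networks_2d.wavesnet import DWT_2D
from models.networks_2d.wds import WDS

dwt = DWT_2D(wavename='haar')
model = WDS(in_channels=1, num_classes=2)

x = torch.rand(1, 1, 256, 256)             # raw image, (N, C, H, W)
LL, LH, HL, HH = dwt(x)                    # four half-resolution subbands
with torch.no_grad():
    out = model(LL, LH, HL, HH)            # WDS takes the subbands as its forward inputs
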
FILE: models/networks_2d/xnet.py
  function conv1x1 (line 15) | def conv1x1(in_planes, out_planes, stride=1):
  function conv3x3 (line 19) | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
  class up_conv (line 23) | class up_conv(nn.Module):
    method __init__ (line 24) | def __init__(self, ch_in, ch_out):
    method forward (line 33) | def forward(self, x):
  class down_conv (line 37) | class down_conv(nn.Module):
    method __init__ (line 38) | def __init__(self, ch_in, ch_out):
    method forward (line 45) | def forward(self, x):
  class same_conv (line 49) | class same_conv(nn.Module):
    method __init__ (line 50) | def __init__(self, ch_in, ch_out):
    method forward (line 56) | def forward(self, x):
  class transition_conv (line 60) | class transition_conv(nn.Module):
    method __init__ (line 61) | def __init__(self, ch_in, ch_out):
    method forward (line 67) | def forward(self, x):
  class BasicBlock (line 71) | class BasicBlock(nn.Module):
    method __init__ (line 74) | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
    method forward (line 92) | def forward(self, x):
  class DoubleBasicBlock (line 110) | class DoubleBasicBlock(nn.Module):
    method __init__ (line 111) | def __init__(self, inplanes, planes, downsample=None):
    method forward (line 119) | def forward(self, x):
  class XNet (line 124) | class XNet(nn.Module):
    method __init__ (line 125) | def __init__(self, in_channels, num_classes):
    method forward (line 226) | def forward(self, input1, input2):
  class XNet_1_1_m (line 322) | class XNet_1_1_m(nn.Module):
    method __init__ (line 323) | def __init__(self, in_channels, num_classes):
    method forward (line 412) | def forward(self, input1, input2):
  class XNet_1_2_m (line 494) | class XNet_1_2_m(nn.Module):
    method __init__ (line 495) | def __init__(self, in_channels, num_classes):
    method forward (line 590) | def forward(self, input1, input2):
  class XNet_2_1_m (line 680) | class XNet_2_1_m(nn.Module):
    method __init__ (line 681) | def __init__(self, in_channels, num_classes):
    method forward (line 776) | def forward(self, input1, input2):
  class XNet_2_3_m (line 865) | class XNet_2_3_m(nn.Module):
    method __init__ (line 866) | def __init__(self, in_channels, num_classes):
    method forward (line 975) | def forward(self, input1, input2):
  class XNet_3_2_m (line 1082) | class XNet_3_2_m(nn.Module):
    method __init__ (line 1083) | def __init__(self, in_channels, num_classes):
    method forward (line 1192) | def forward(self, input1, input2):
  class XNet_3_3_m (line 1300) | class XNet_3_3_m(nn.Module):
    method __init__ (line 1301) | def __init__(self, in_channels, num_classes):
    method forward (line 1418) | def forward(self, input1, input2):
  class XNet_sb (line 1533) | class XNet_sb(nn.Module):
    method __init__ (line 1534) | def __init__(self, in_channels, num_classes):
    method forward (line 1597) | def forward(self, input1):

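All XNet variants in this file are dual-branch: __init__ takes (in_channels, num_classes) and forward takes two inputs, which the repository's title suggests are the low- and high-frequency wavelet components of the same image (XNet_sb is the single-branch exception). A minimal construction-and-call sketch, not taken from the repository; the number and shape of the outputs are not asserted, since the forward body is not shown in this index:

# Hypothetical usage sketch (not repository code).
# input1 / input2 are assumed to be low- and high-frequency images of the same
# size, e.g. precomputed with tools/wavelet2D.py (an assumption, not verified here).
import torch
from models.networks_2d.xnet import XNet

model = XNet(in_channels=1, num_classes=2)
x_lf = torch.rand(1, 1, 128, 128)          # low-frequency branch input
x_hf = torch.rand(1, 1, 128, 128)          # high-frequency branch input
with torch.no_grad():
    outputs = model(x_lf, x_hf)            # typically one prediction per branch (not asserted)
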
FILE: models/networks_3d/conresnet.py
  function init_weights (line 8) | def init_weights(net, init_type='normal', gain=0.02):
  class Conv3d (line 32) | class Conv3d(nn.Conv3d):
    method __init__ (line 34) | def __init__(self, in_channels, out_channels, kernel_size, stride=(1,1...
    method forward (line 37) | def forward(self, x):
  function conv3x3x3 (line 45) | def conv3x3x3(in_planes, out_planes, kernel_size=(3, 3, 3), stride=(1, 1...
  class ConResAtt (line 56) | class ConResAtt(nn.Module):
    method __init__ (line 57) | def __init__(self, in_channels, in_planes, out_planes, kernel_size=(3,...
    method _res (line 95) | def _res(self, x):  # bs, channel, D, W, H
    method forward (line 106) | def forward(self, input):
  class NoBottleneck (line 146) | class NoBottleneck(nn.Module):
    method __init__ (line 147) | def __init__(self, inplanes, planes, stride=(1, 1, 1), dilation=(1, 1,...
    method forward (line 165) | def forward(self, x):
  class ConResNet (line 183) | class ConResNet(nn.Module):
    method __init__ (line 184) | def __init__(self, in_channels, num_classes, shape, block, layers, wei...
    method _make_layer (line 240) | def _make_layer(self, block, inplanes, outplanes, blocks, stride=(1, 1...
    method forward (line 261) | def forward(self, x, x_res):
  function conresnet (line 310) | def conresnet(in_channels, num_classes, **kwargs):

FILE: models/networks_3d/cotr.py
  class PositionEmbeddingSine (line 10) | class PositionEmbeddingSine(nn.Module):
    method __init__ (line 16) | def __init__(self, num_pos_feats=[64, 64, 64], temperature=10000, norm...
    method forward (line 27) | def forward(self, x):
  function build_position_encoding (line 63) | def build_position_encoding(mode, hidden_dim):
  function ms_deform_attn_core_pytorch_3D (line 77) | def ms_deform_attn_core_pytorch_3D(value, value_spatial_shapes, sampling...
  class MSDeformAttn (line 93) | class MSDeformAttn(nn.Module):
    method __init__ (line 94) | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
    method _reset_parameters (line 121) | def _reset_parameters(self):
    method forward (line 137) | def forward(self, query, reference_points, input_flatten, input_spatia...
  function _get_clones (line 168) | def _get_clones(module, N):
  function _get_activation_fn (line 172) | def _get_activation_fn(activation):
  class DeformableTransformerEncoderLayer (line 182) | class DeformableTransformerEncoderLayer(nn.Module):
    method __init__ (line 183) | def __init__(self,
    method with_pos_embed (line 203) | def with_pos_embed(tensor, pos):
    method forward_ffn (line 206) | def forward_ffn(self, src):
    method forward (line 212) | def forward(self, src, pos, reference_points, spatial_shapes, level_st...
  class DeformableTransformerEncoder (line 224) | class DeformableTransformerEncoder(nn.Module):
    method __init__ (line 225) | def __init__(self, encoder_layer, num_layers):
    method get_reference_points (line 231) | def get_reference_points(spatial_shapes, valid_ratios, device):
    method forward (line 249) | def forward(self, src, spatial_shapes, level_start_index, valid_ratios...
  class DeformableTransformer (line 258) | class DeformableTransformer(nn.Module):
    method __init__ (line 259) | def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, dim_fee...
    method _reset_parameters (line 272) | def _reset_parameters(self):
    method get_valid_ratio (line 281) | def get_valid_ratio(self, mask):
    method forward (line 293) | def forward(self, srcs, masks, pos_embeds):
  class Conv3d_wd (line 323) | class Conv3d_wd(nn.Conv3d):
    method __init__ (line 325) | def __init__(self, in_channels, out_channels, kernel_size, stride=(1, ...
    method forward (line 328) | def forward(self, x):
  function conv3x3x3 (line 338) | def conv3x3x3(in_planes, out_planes, kernel_size, stride=(1, 1, 1), padd...
  function Norm_layer (line 347) | def Norm_layer(norm_cfg, inplanes):
  function Activation_layer (line 360) | def Activation_layer(activation_cfg, inplace=True):
  class ResBlock (line 368) | class ResBlock(nn.Module):
    method __init__ (line 370) | def __init__(self, inplanes, planes, norm_cfg, activation_cfg, stride=...
    method forward (line 378) | def forward(self, x):
  class Backbone (line 392) | class Backbone(nn.Module):
    method __init__ (line 395) | def __init__(self, depth, in_channels=1, norm_cfg='BN', activation_cfg...
    method _make_layer (line 422) | def _make_layer(self, block, planes, blocks, stride=(1, 1, 1), norm_cf...
    method init_weights (line 441) | def init_weights(self):
    method forward (line 451) | def forward(self, x):
  class Conv3dBlock (line 467) | class Conv3dBlock(nn.Module):
    method __init__ (line 468) | def __init__(self, in_channels, out_channels, norm_cfg, activation_cfg...
    method forward (line 474) | def forward(self, x):
  class ResBlock_ (line 481) | class ResBlock_(nn.Module):
    method __init__ (line 483) | def __init__(self, inplanes, planes, norm_cfg, activation_cfg, weight_...
    method forward (line 488) | def forward(self, x):
  class U_ResTran3D (line 497) | class U_ResTran3D(nn.Module):
    method __init__ (line 498) | def __init__(self, in_channels, num_classes, norm_cfg='BN', activation...
    method posi_mask (line 539) | def posi_mask(self, x):
    method forward (line 552) | def forward(self, inputs):
  function cotr (line 587) | def cotr(in_channels, num_classes):

FILE: models/networks_3d/dmfnet.py
  function normalization (line 6) | def normalization(planes, norm='bn'):
  class Conv3d_Block (line 17) | class Conv3d_Block(nn.Module):
    method __init__ (line 18) | def __init__(self,num_in,num_out,kernel_size=1,stride=1,g=1,padding=No...
    method forward (line 26) | def forward(self, x): # BN + Relu + Conv
  class DilatedConv3DBlock (line 32) | class DilatedConv3DBlock(nn.Module):
    method __init__ (line 33) | def __init__(self, num_in, num_out, kernel_size=(1,1,1), stride=1, g=1...
    method forward (line 45) | def forward(self, x):
  class MFunit (line 51) | class MFunit(nn.Module):
    method __init__ (line 52) | def __init__(self, num_in, num_out, g=1, stride=1, d=(1,1),norm=None):
    method forward (line 77) | def forward(self, x):
  class DMFUnit (line 92) | class DMFUnit(nn.Module):
    method __init__ (line 94) | def __init__(self, num_in, num_out, g=1, stride=1,norm=None,dilation=N...
    method forward (line 125) | def forward(self, x):
  class MFNet (line 138) | class MFNet(nn.Module): #
    method __init__ (line 142) | def __init__(self,in_channels, num_classes, n=32, channels=128, groups...
    method forward (line 184) | def forward(self, x):
  class DMFNet (line 208) | class DMFNet(MFNet): # softmax
    method __init__ (line 210) | def __init__(self,in_channels, num_classes, n=32,channels=128, groups=...
  function dmfnet (line 225) | def dmfnet(in_channels, num_classes):

FILE: models/networks_3d/espnet3d.py
  class CBR (line 9) | class CBR(nn.Module):
    method __init__ (line 10) | def __init__(self, nIn, nOut, kSize, stride=1):
    method forward (line 17) | def forward(self, input):
  class CB (line 24) | class CB(nn.Module):
    method __init__ (line 25) | def __init__(self, nIn, nOut, kSize, stride=1):
    method forward (line 31) | def forward(self, input):
  class C (line 37) | class C(nn.Module):
    method __init__ (line 38) | def __init__(self, nIn, nOut, kSize, stride=1, groups=1):
    method forward (line 43) | def forward(self, input):
  class DownSamplerA (line 48) | class DownSamplerA(nn.Module):
    method __init__ (line 49) | def __init__(self, nIn, nOut):
    method forward (line 53) | def forward(self, input):
  class DownSamplerB (line 58) | class DownSamplerB(nn.Module):
    method __init__ (line 59) | def __init__(self, nIn, nOut):
    method forward (line 71) | def forward(self, input):
  class BR (line 89) | class BR(nn.Module):
    method __init__ (line 90) | def __init__(self, nOut):
    method forward (line 95) | def forward(self, input):
  class CDilated (line 101) | class CDilated(nn.Module):
    method __init__ (line 102) | def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1):
    method forward (line 108) | def forward(self, input):
  class InputProjectionA (line 113) | class InputProjectionA(nn.Module):
    method __init__ (line 120) | def __init__(self, samplingTimes):
    method forward (line 130) | def forward(self, input):
  class DilatedParllelResidualBlockB1 (line 140) | class DilatedParllelResidualBlockB1(nn.Module):  # with k=4
    method __init__ (line 141) | def __init__(self, nIn, nOut, stride=1):
    method forward (line 153) | def forward(self, input):
  class ASPBlock (line 170) | class ASPBlock(nn.Module):  # with k=4
    method __init__ (line 171) | def __init__(self, nIn, nOut, stride=1):
    method forward (line 179) | def forward(self, input):
  class UpSampler (line 192) | class UpSampler(nn.Module):
    method __init__ (line 196) | def __init__(self, nIn, nOut):
    method forward (line 200) | def forward(self, inp):
  class PSPDec (line 204) | class PSPDec(nn.Module):
    method __init__ (line 209) | def __init__(self, nIn, nOut, downSize):
    method forward (line 213) | def forward(self, x):
  class ESPNet (line 220) | class ESPNet(nn.Module):
    method __init__ (line 221) | def __init__(self, in_channels, num_classes):
    method forward (line 305) | def forward(self, input1, inp_res=(128, 128, 128), inpSt2=False):
  function espnet3d (line 365) | def espnet3d(in_channels, num_classes):

FILE: models/networks_3d/res_unet3d.py
  function init_weights (line 6) | def init_weights(net, init_type='normal', gain=0.02):
  class UNet (line 30) | class UNet(nn.Module):
    method __init__ (line 35) | def __init__(self, in_channels, n_classes, base_n_filter=8):
    method conv_norm_lrelu (line 88) | def conv_norm_lrelu(self, feat_in, feat_out):
    method norm_lrelu_conv (line 94) | def norm_lrelu_conv(self, feat_in, feat_out):
    method lrelu_conv (line 100) | def lrelu_conv(self, feat_in, feat_out):
    method norm_lrelu_upscale_conv_norm_lrelu (line 105) | def norm_lrelu_upscale_conv_norm_lrelu(self, feat_in, feat_out):
    method forward (line 115) | def forward(self, x):
  function res_unet3d (line 212) | def res_unet3d(in_channels, num_classes):

FILE: models/networks_3d/transbts.py
  function init_weights (line 7) | def init_weights(net, init_type='normal', gain=0.02):
  function normalization (line 30) | def normalization(planes, norm='gn'):
  class InitConv (line 41) | class InitConv(nn.Module):
    method __init__ (line 42) | def __init__(self, in_channels=4, out_channels=16, dropout=0.2):
    method forward (line 48) | def forward(self, x):
  class EnBlock (line 54) | class EnBlock(nn.Module):
    method __init__ (line 55) | def __init__(self, in_channels, norm='gn'):
    method forward (line 66) | def forward(self, x):
  class EnDown (line 77) | class EnDown(nn.Module):
    method __init__ (line 78) | def __init__(self, in_channels, out_channels):
    method forward (line 82) | def forward(self, x):
  class Unet (line 87) | class Unet(nn.Module):
    method __init__ (line 88) | def __init__(self, in_channels=4, base_channels=16):
    method forward (line 108) | def forward(self, x):
  class FixedPositionalEncoding (line 129) | class FixedPositionalEncoding(nn.Module):
    method __init__ (line 130) | def __init__(self, embedding_dim, max_length=512):
    method forward (line 141) | def forward(self, x):
  class LearnedPositionalEncoding (line 146) | class LearnedPositionalEncoding(nn.Module):
    method __init__ (line 147) | def __init__(self, max_position_embeddings, embedding_dim):
    method forward (line 152) | def forward(self, x):
  class IntermediateSequential (line 157) | class IntermediateSequential(nn.Sequential):
    method __init__ (line 158) | def __init__(self, *args, return_intermediate=True):
    method forward (line 162) | def forward(self, input):
  class SelfAttention (line 173) | class SelfAttention(nn.Module):
    method __init__ (line 174) | def __init__(
    method forward (line 187) | def forward(self, x):
  class Residual (line 201) | class Residual(nn.Module):
    method __init__ (line 202) | def __init__(self, fn):
    method forward (line 206) | def forward(self, x):
  class PreNorm (line 209) | class PreNorm(nn.Module):
    method __init__ (line 210) | def __init__(self, dim, fn):
    method forward (line 215) | def forward(self, x):
  class PreNormDrop (line 219) | class PreNormDrop(nn.Module):
    method __init__ (line 220) | def __init__(self, dim, dropout_rate, fn):
    method forward (line 226) | def forward(self, x):
  class FeedForward (line 230) | class FeedForward(nn.Module):
    method __init__ (line 231) | def __init__(self, dim, hidden_dim, dropout_rate):
    method forward (line 241) | def forward(self, x):
  class TransformerModel (line 244) | class TransformerModel(nn.Module):
    method __init__ (line 245) | def __init__(self,dim,depth,heads,mlp_dim,dropout_rate=0.1,attn_dropou...
    method forward (line 255) | def forward(self, x):
  class TransformerBTS (line 259) | class TransformerBTS(nn.Module):
    method __init__ (line 260) | def __init__(
    method encode (line 312) | def encode(self, x):
    method forward (line 347) | def forward(self, x, auxillary_output_layers=[1, 2, 3, 4]):
    method _reshape_output (line 371) | def _reshape_output(self, x):
  class BTS (line 384) | class BTS(TransformerBTS):
    method __init__ (line 385) | def __init__(self,
    method decode (line 426) | def decode(self, x1_1, x2_1, x3_1, x, intmd_x, intmd_layers=[1, 2, 3, ...
  class EnBlock1 (line 455) | class EnBlock1(nn.Module):
    method __init__ (line 456) | def __init__(self, in_channels, ):
    method forward (line 466) | def forward(self, x):
  class EnBlock2 (line 476) | class EnBlock2(nn.Module):
    method __init__ (line 477) | def __init__(self, in_channels):
    method forward (line 487) | def forward(self, x):
  class DeUp_Cat (line 498) | class DeUp_Cat(nn.Module):
    method __init__ (line 499) | def __init__(self, in_channels, out_channels):
    method forward (line 505) | def forward(self, x, prev):
  class DeBlock (line 513) | class DeBlock(nn.Module):
    method __init__ (line 514) | def __init__(self, in_channels):
    method forward (line 524) | def forward(self, x):
  function transbts (line 536) | def transbts(in_channels, num_classes, **kwargs):

FILE: models/networks_3d/unet3d.py
  function init_weights (line 7) | def init_weights(net, init_type='normal', gain=0.02):
  class UNet3D (line 31) | class UNet3D(nn.Module):
    method __init__ (line 32) | def __init__(self, in_channels=1, out_channels=3, init_features=64):
    method forward (line 72) | def forward(self, x):
    method _block (line 96) | def _block(in_channels, features, name):
  class UNet3D_min (line 129) | class UNet3D_min(nn.Module):
    method __init__ (line 130) | def __init__(self, in_channels=1, out_channels=3, init_features=32):
    method forward (line 170) | def forward(self, x):
    method _block (line 194) | def _block(in_channels, features, name):
  function unet3d (line 226) | def unet3d(in_channels, num_classes):
  function unet3d_min (line 231) | def unet3d_min(in_channels, num_classes):

FILE: models/networks_3d/unet3d_cct.py
  function init_weights (line 8) | def init_weights(net, init_type='normal', gain=0.02):
  class FeatureNoise (line 31) | class FeatureNoise(nn.Module):
    method __init__ (line 32) | def __init__(self, uniform_range=0.3):
    method feature_based_noise (line 36) | def feature_based_noise(self, x):
    method forward (line 41) | def forward(self, x):
  function Dropout (line 45) | def Dropout(x, p=0.3):
  function FeatureDropout (line 49) | def FeatureDropout(x):
  class Decoder (line 58) | class Decoder(nn.Module):
    method __init__ (line 59) | def __init__(self, features, out_channels):
    method forward (line 73) | def forward(self, x5, x4, x3, x2, x1):
    method _block (line 92) | def _block(in_channels, features, name):
  class UNet3D_CCT (line 124) | class UNet3D_CCT(nn.Module):
    method __init__ (line 125) | def __init__(self, in_channels=1, out_channels=3, init_features=64):
    method forward (line 151) | def forward(self, x):
    method _block (line 168) | def _block(in_channels, features, name):
  class UNet3D_CCT_min (line 200) | class UNet3D_CCT_min(nn.Module):
    method __init__ (line 201) | def __init__(self, in_channels=1, out_channels=3, init_features=32):
    method forward (line 227) | def forward(self, x):
    method _block (line 244) | def _block(in_channels, features, name):
  function unet3d_cct (line 276) | def unet3d_cct(in_channels, num_classes):
  function unet3d_cct_min (line 281) | def unet3d_cct_min(in_channels, num_classes):

FILE: models/networks_3d/unet3d_dtc.py
  function init_weights (line 8) | def init_weights(net, init_type='normal', gain=0.02):
  class UNet3D_DTC (line 32) | class UNet3D_DTC(nn.Module):
    method __init__ (line 33) | def __init__(self, in_channels=1, out_channels=3, init_features=64):
    method forward (line 68) | def forward(self, x):
    method _block (line 94) | def _block(in_channels, features, name):
  function unet3d_dtc (line 127) | def unet3d_dtc(in_channels, num_classes):

FILE: models/networks_3d/unet3d_urpc.py
  function init_weights (line 8) | def init_weights(net, init_type='normal', gain=0.02):
  class UnetConv3 (line 31) | class UnetConv3(nn.Module):
    method __init__ (line 32) | def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3,3,1...
    method forward (line 52) | def forward(self, inputs):
  class UnetUp3 (line 57) | class UnetUp3(nn.Module):
    method __init__ (line 58) | def __init__(self, in_size, out_size, is_deconv, is_batchnorm=True):
    method forward (line 72) | def forward(self, inputs1, inputs2):
  class UnetUp3_CT (line 80) | class UnetUp3_CT(nn.Module):
    method __init__ (line 81) | def __init__(self, in_size, out_size, is_batchnorm=True):
    method forward (line 91) | def forward(self, inputs1, inputs2):
  class UnetDsv3 (line 99) | class UnetDsv3(nn.Module):
    method __init__ (line 100) | def __init__(self, in_size, out_size, scale_factor):
    method forward (line 105) | def forward(self, input):
  class unet_3D_dv_semi (line 108) | class unet_3D_dv_semi(nn.Module):
    method __init__ (line 110) | def __init__(self, in_channels=3, n_classes=21, feature_scale=4, is_de...
    method forward (line 168) | def forward(self, inputs):
    method apply_argmax_softmax (line 204) | def apply_argmax_softmax(pred):
  function unet3d_urpc (line 209) | def unet3d_urpc(in_channels, num_classes):

FILE: models/networks_3d/unetr.py
  function init_weights (line 8) | def init_weights(net, init_type='normal', gain=0.02):
  class SingleDeconv3DBlock (line 32) | class SingleDeconv3DBlock(nn.Module):
    method __init__ (line 33) | def __init__(self, in_planes, out_planes):
    method forward (line 37) | def forward(self, x):
  class SingleConv3DBlock (line 41) | class SingleConv3DBlock(nn.Module):
    method __init__ (line 42) | def __init__(self, in_planes, out_planes, kernel_size):
    method forward (line 47) | def forward(self, x):
  class Conv3DBlock (line 51) | class Conv3DBlock(nn.Module):
    method __init__ (line 52) | def __init__(self, in_planes, out_planes, kernel_size=3):
    method forward (line 60) | def forward(self, x):
  class Deconv3DBlock (line 64) | class Deconv3DBlock(nn.Module):
    method __init__ (line 65) | def __init__(self, in_planes, out_planes, kernel_size=3):
    method forward (line 74) | def forward(self, x):
  class SelfAttention (line 78) | class SelfAttention(nn.Module):
    method __init__ (line 79) | def __init__(self, num_heads, embed_dim, dropout):
    method transpose_for_scores (line 97) | def transpose_for_scores(self, x):
    method forward (line 102) | def forward(self, hidden_states):
  class PositionwiseFeedForward (line 141) | class PositionwiseFeedForward(nn.Module):
    method __init__ (line 142) | def __init__(self, d_model=786, d_ff=2048, dropout=0.1):
    method forward (line 149) | def forward(self, x):
  class Embeddings (line 153) | class Embeddings(nn.Module):
    method __init__ (line 154) | def __init__(self, input_dim, embed_dim, cube_size, patch_size, dropout):
    method forward (line 164) | def forward(self, x):
  class TransformerBlock (line 173) | class TransformerBlock(nn.Module):
    method __init__ (line 174) | def __init__(self, embed_dim, num_heads, dropout, cube_size, patch_size):
    method forward (line 182) | def forward(self, x):
  class Transformer (line 198) | class Transformer(nn.Module):
    method __init__ (line 199) | def __init__(self, input_dim, embed_dim, cube_size, patch_size, num_he...
    method forward (line 209) | def forward(self, x):
  class UNETR (line 222) | class UNETR(nn.Module):
    method __init__ (line 223) | def __init__(self, input_dim=4, output_dim=3, img_shape=(128, 128, 128...
    method forward (line 306) | def forward(self, x):
  function unertr (line 326) | def unertr(in_channels, num_classes, **kwargs):

FILE: models/networks_3d/vnet.py
  function init_weights (line 8) | def init_weights(net, init_type='normal', gain=0.02):
  function passthrough (line 33) | def passthrough(x, **kwargs):
  function ELUCons (line 37) | def ELUCons(elu, nchan):
  class LUConv (line 44) | class LUConv(nn.Module):
    method __init__ (line 45) | def __init__(self, nchan, elu):
    method forward (line 52) | def forward(self, x):
  function _make_nConv (line 57) | def _make_nConv(nchan, depth, elu):
  class InputTransition (line 64) | class InputTransition(nn.Module):
    method __init__ (line 65) | def __init__(self, in_channels, elu):
    method forward (line 76) | def forward(self, x):
  class DownTransition (line 84) | class DownTransition(nn.Module):
    method __init__ (line 85) | def __init__(self, inChans, nConvs, elu, dropout=False):
    method forward (line 98) | def forward(self, x):
  class UpTransition (line 106) | class UpTransition(nn.Module):
    method __init__ (line 107) | def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
    method forward (line 120) | def forward(self, x, skipx):
  class OutputTransition (line 130) | class OutputTransition(nn.Module):
    method __init__ (line 131) | def __init__(self, in_channels, classes, elu):
    method forward (line 140) | def forward(self, x):
  class VNet (line 147) | class VNet(nn.Module):
    method __init__ (line 152) | def __init__(self, in_channels=1, classes=1, elu=True):
    method forward (line 169) | def forward(self, x):
  function vnet (line 182) | def vnet(in_channels, num_classes):

FILE: models/networks_3d/vnet_cct.py
  function init_weights (line 10) | def init_weights(net, init_type='normal', gain=0.02):
  class FeatureNoise (line 33) | class FeatureNoise(nn.Module):
    method __init__ (line 34) | def __init__(self, uniform_range=0.3):
    method feature_based_noise (line 38) | def feature_based_noise(self, x):
    method forward (line 43) | def forward(self, x):
  function Dropout (line 47) | def Dropout(x, p=0.3):
  function FeatureDropout (line 51) | def FeatureDropout(x):
  function passthrough (line 62) | def passthrough(x, **kwargs):
  function ELUCons (line 66) | def ELUCons(elu, nchan):
  class LUConv (line 73) | class LUConv(nn.Module):
    method __init__ (line 74) | def __init__(self, nchan, elu):
    method forward (line 81) | def forward(self, x):
  function _make_nConv (line 86) | def _make_nConv(nchan, depth, elu):
  class InputTransition (line 93) | class InputTransition(nn.Module):
    method __init__ (line 94) | def __init__(self, in_channels, elu):
    method forward (line 105) | def forward(self, x):
  class DownTransition (line 113) | class DownTransition(nn.Module):
    method __init__ (line 114) | def __init__(self, inChans, nConvs, elu, dropout=False):
    method forward (line 127) | def forward(self, x):
  class UpTransition (line 135) | class UpTransition(nn.Module):
    method __init__ (line 136) | def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
    method forward (line 149) | def forward(self, x, skipx):
  class OutputTransition (line 159) | class OutputTransition(nn.Module):
    method __init__ (line 160) | def __init__(self, in_channels, classes, elu):
    method forward (line 169) | def forward(self, x):
  class Decoder (line 176) | class Decoder(nn.Module):
    method __init__ (line 177) | def __init__(self, out_channels, elu):
    method forward (line 186) | def forward(self, out256, out128, out64, out32, out16):
  class VNet_CCT (line 195) | class VNet_CCT(nn.Module):
    method __init__ (line 200) | def __init__(self, in_channels=1, classes=1, elu=True):
    method forward (line 218) | def forward(self, x):
  function vnet_cct (line 233) | def vnet_cct(in_channels, num_classes):

FILE: models/networks_3d/vnet_dtc.py
  function init_weights (line 9) | def init_weights(net, init_type='normal', gain=0.02):
  function passthrough (line 34) | def passthrough(x, **kwargs):
  function ELUCons (line 38) | def ELUCons(elu, nchan):
  class LUConv (line 45) | class LUConv(nn.Module):
    method __init__ (line 46) | def __init__(self, nchan, elu):
    method forward (line 53) | def forward(self, x):
  function _make_nConv (line 58) | def _make_nConv(nchan, depth, elu):
  class InputTransition (line 65) | class InputTransition(nn.Module):
    method __init__ (line 66) | def __init__(self, in_channels, elu):
    method forward (line 77) | def forward(self, x):
  class DownTransition (line 85) | class DownTransition(nn.Module):
    method __init__ (line 86) | def __init__(self, inChans, nConvs, elu, dropout=False):
    method forward (line 99) | def forward(self, x):
  class UpTransition (line 107) | class UpTransition(nn.Module):
    method __init__ (line 108) | def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
    method forward (line 121) | def forward(self, x, skipx):
  class OutputTransition (line 131) | class OutputTransition(nn.Module):
    method __init__ (line 132) | def __init__(self, in_channels, classes, elu):
    method forward (line 141) | def forward(self, x):
  class VNet_DTC (line 148) | class VNet_DTC(nn.Module):
    method __init__ (line 153) | def __init__(self, in_channels=1, classes=1, elu=True):
    method forward (line 176) | def forward(self, x):
  function vnet_dtc (line 192) | def vnet_dtc(in_channels, num_classes):

FILE: models/networks_3d/xnet3d.py
  function conv1x1 (line 20) | def conv1x1(in_planes, out_planes, stride=1):
  function conv3x3 (line 24) | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
  class up_conv (line 28) | class up_conv(nn.Module):
    method __init__ (line 29) | def __init__(self, ch_in, ch_out):
    method forward (line 38) | def forward(self, x):
  class down_conv (line 42) | class down_conv(nn.Module):
    method __init__ (line 43) | def __init__(self, ch_in, ch_out):
    method forward (line 50) | def forward(self, x):
  class same_conv (line 54) | class same_conv(nn.Module):
    method __init__ (line 55) | def __init__(self, ch_in, ch_out):
    method forward (line 62) | def forward(self, x):
  class transition_conv (line 66) | class transition_conv(nn.Module):
    method __init__ (line 67) | def __init__(self, ch_in, ch_out):
    method forward (line 74) | def forward(self, x):
  class BasicBlock (line 78) | class BasicBlock(nn.Module):
    method __init__ (line 81) | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
    method forward (line 99) | def forward(self, x):
  class DoubleBasicBlock (line 117) | class DoubleBasicBlock(nn.Module):
    method __init__ (line 118) | def __init__(self, inplanes, planes, downsample=None):
    method forward (line 126) | def forward(self, x):
  class XNet3D (line 130) | class XNet3D(nn.Module):
    method __init__ (line 131) | def __init__(self, in_channels, num_classes):
    method forward (line 234) | def forward(self, input1, input2):
  function xnet3d (line 329) | def xnet3d(in_channels, num_classes):

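XNet3D mirrors the 2D XNet with volumetric tensors: the xnet3d factory takes (in_channels, num_classes) and the forward pass again takes two inputs. A minimal sketch, not taken from the repository, assuming (N, C, D, H, W) inputs large enough to survive the encoder's downsampling:

# Hypothetical usage sketch (not repository code).
import torch
from models.networks_3d.xnet3d import xnet3d

model = xnet3d(in_channels=1, num_classes=2)
x_lf = torch.rand(1, 1, 64, 64, 64)        # low-frequency volume, (N, C, D, H, W)
x_hf = torch.rand(1, 1, 64, 64, 64)        # high-frequency volume
with torch.no_grad():
    outputs = model(x_lf, x_hf)
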
FILE: test.py
  function init_seeds (line 22) | def init_seeds(seed):

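Every test_*.py and train_*.py entry point defines its own init_seeds(seed) helper. Its body is not reproduced in this index; the following is a typical implementation of such a seeding helper, offered as an assumption rather than the repository's exact code:

# Typical seeding helper (assumed implementation, not copied from the repository).
import os
import random
import numpy as np
import torch

def init_seeds(seed):
    random.seed(seed)                          # Python RNG
    np.random.seed(seed)                       # NumPy RNG
    torch.manual_seed(seed)                    # CPU RNG
    torch.cuda.manual_seed_all(seed)           # all GPU RNGs (no-op without CUDA)
    os.environ['PYTHONHASHSEED'] = str(seed)   # hash-based operations
    torch.backends.cudnn.deterministic = True  # trade speed for reproducibility
    torch.backends.cudnn.benchmark = False
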
FILE: test_3d.py
  function init_seeds (line 23) | def init_seeds(seed):

FILE: test_ConResNet.py
  function init_seeds (line 23) | def init_seeds(seed):

FILE: test_DTC.py
  function init_seeds (line 23) | def init_seeds(seed):

FILE: test_xnet.py
  function init_seeds (line 22) | def init_seeds(seed):

FILE: test_xnet3d.py
  function init_seeds (line 23) | def init_seeds(seed):

FILE: tools/Atrial/postprocess.py
  function save_max_objects (line 8) | def save_max_objects(image):

FILE: tools/LiTS/postprocess.py
  function save_max_objects (line 8) | def save_max_objects(image):

FILE: tools/eval.py
  function eval_distance (line 10) | def eval_distance(mask_list, seg_result_list, num_classes):
  function eval_pixel (line 77) | def eval_pixel(mask_list, seg_result_list, num_classes):

FILE: train_semi_CCT.py
  function init_seeds (line 32) | def init_seeds(seed):

FILE: train_semi_CCT_3d.py
  function init_seeds (line 31) | def init_seeds(seed):

FILE: train_semi_CPS.py
  function init_seeds (line 32) | def init_seeds(seed):

FILE: train_semi_CPS_3d.py
  function init_seeds (line 29) | def init_seeds(seed):

FILE: train_semi_CT.py
  function init_seeds (line 33) | def init_seeds(seed):

FILE: train_semi_CT_3d.py
  function init_seeds (line 29) | def init_seeds(seed):

FILE: train_semi_DTC.py
  function init_seeds (line 31) | def init_seeds(seed):

FILE: train_semi_EM.py
  function init_seeds (line 33) | def init_seeds(seed):

FILE: train_semi_EM_3d.py
  function init_seeds (line 31) | def init_seeds(seed):

FILE: train_semi_MT.py
  function update_ema_variables (line 32) | def update_ema_variables(model, ema_model, alpha, global_step):
  function init_seeds (line 38) | def init_seeds(seed):

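The Mean Teacher scripts (train_semi_MT*.py, train_semi_UAMT*.py) expose update_ema_variables(model, ema_model, alpha, global_step), the exponential-moving-average teacher update. Its body is not shown in this index; below is the standard formulation such a function usually implements, an assumption rather than the repository's exact code. The teacher (ema_model) is typically created as a copy of the student with requires_grad disabled:

# Standard Mean Teacher EMA update (assumed implementation, not copied from the repository).
import torch

@torch.no_grad()
def update_ema_variables(model, ema_model, alpha, global_step):
    # Ramp the decay from 0 toward `alpha` so the teacher tracks the student
    # closely during the first optimisation steps.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
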
FILE: train_semi_MT_3d.py
  function update_ema_variables (line 30) | def update_ema_variables(model, ema_model, alpha, global_step):
  function init_seeds (line 36) | def init_seeds(seed):

FILE: train_semi_UAMT.py
  function update_ema_variables (line 34) | def update_ema_variables(model, ema_model, alpha, global_step):
  function init_seeds (line 41) | def init_seeds(seed):

FILE: train_semi_UAMT_3d.py
  function update_ema_variables (line 31) | def update_ema_variables(model, ema_model, alpha, global_step):
  function init_seeds (line 37) | def init_seeds(seed):

FILE: train_semi_URPC.py
  function init_seeds (line 33) | def init_seeds(seed):

FILE: train_semi_URPC_3d.py
  function init_seeds (line 31) | def init_seeds(seed):

FILE: train_semi_XNet.py
  function init_seeds (line 32) | def init_seeds(seed):

FILE: train_semi_XNet3d.py
  function init_seeds (line 31) | def init_seeds(seed):

FILE: train_sup.py
  function init_seeds (line 29) | def init_seeds(seed):

FILE: train_sup_3d.py
  function init_seeds (line 29) | def init_seeds(seed):

FILE: train_sup_ConResNet.py
  function init_seeds (line 31) | def init_seeds(seed):

FILE: train_sup_XNet.py
  function init_seeds (line 30) | def init_seeds(seed):

FILE: train_sup_XNet3d.py
  function init_seeds (line 29) | def init_seeds(seed):

FILE: train_sup_XNet_sb.py
  function init_seeds (line 30) | def init_seeds(seed):

FILE: train_sup_alnet.py
  function init_seeds (line 29) | def init_seeds(seed):

FILE: train_sup_wds.py
  function init_seeds (line 29) | def init_seeds(seed):

Condensed preview: 108 files, each listed with its path, character count, and a truncated content snippet (the full structured content runs to 1,148K chars and is not reproduced here).
[
  {
    "path": ".idea/XNet.iml",
    "chars": 431,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<module type=\"PYTHON_MODULE\" version=\"4\">\n  <component name=\"NewModuleRootManager"
  },
  {
    "path": ".idea/deployment.xml",
    "chars": 1517,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"PublishConfigData\">\n    <serverData>\n   "
  },
  {
    "path": ".idea/inspectionProfiles/profiles_settings.xml",
    "chars": 174,
    "preview": "<component name=\"InspectionProjectProfileManager\">\n  <settings>\n    <option name=\"USE_PROJECT_PROFILE\" value=\"false\" />\n"
  },
  {
    "path": ".idea/misc.xml",
    "chars": 288,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"JavaScriptSettings\">\n    <option name=\"l"
  },
  {
    "path": ".idea/modules.xml",
    "chars": 260,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ProjectModuleManager\">\n    <modules>\n   "
  },
  {
    "path": ".idea/vcs.xml",
    "chars": 180,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"VcsDirectoryMappings\">\n    <mapping dire"
  },
  {
    "path": ".idea/workspace.xml",
    "chars": 33285,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ChangeListManager\">\n    <list default=\"t"
  },
  {
    "path": "LICENSE",
    "chars": 1069,
    "preview": "MIT License\n\nCopyright (c) 2024 Yanfeng Zhou\n\nPermission is hereby granted, free of charge, to any person obtaining a co"
  },
  {
    "path": "README.md",
    "chars": 10714,
    "preview": "\n# XNet: Wavelet-Based Low and High Frequency Merging Networks for Semi- and Supervised Semantic Segmentation of Biomed"
  },
  {
    "path": "config/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "config/augmentation/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "config/augmentation/online_aug.py",
    "chars": 1995,
    "preview": "import albumentations as A\nfrom albumentations.pytorch import ToTensorV2\nfrom torchio import transforms as T\nimport torc"
  },
  {
    "path": "config/dataset_config/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "config/dataset_config/dataset_cfg.py",
    "chars": 2784,
    "preview": "import numpy as np\nimport torchio as tio\n\ndef dataset_cfg(dataet_name):\n\n    config = {\n        'CREMI':\n            {\n "
  },
  {
    "path": "config/eval_config/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "config/eval_config/eval.py",
    "chars": 1604,
    "preview": "import numpy as np\nfrom sklearn.metrics import confusion_matrix\nfrom scipy.spatial.distance import directed_hausdorff\nim"
  },
  {
    "path": "config/ramps/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "config/ramps/ramps.py",
    "chars": 784,
    "preview": "import numpy as np\n\n\ndef sigmoid_rampup(current, rampup_length):\n    \"\"\"Exponential rampup from https://arxiv.org/abs/16"
  },
  {
    "path": "config/train_test_config/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "config/train_test_config/train_test_config.py",
    "chars": 25853,
    "preview": "import numpy as np\nfrom config.eval_config.eval import evaluate, evaluate_multi\nimport torch\nimport os\nfrom PIL import I"
  },
  {
    "path": "config/visdom_config/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "config/visdom_config/visual_visdom.py",
    "chars": 7846,
    "preview": "from visdom import Visdom\nimport os\n\ndef visdom_initialization_sup(env, port):\n    visdom = Visdom(env=env, port=port)\n "
  },
  {
    "path": "config/warmup_config/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "config/warmup_config/warmup.py",
    "chars": 3139,
    "preview": "from torch.optim.lr_scheduler import _LRScheduler\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\n\n\nclass Gradual"
  },
  {
    "path": "dataload/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "dataload/dataset_2d.py",
    "chars": 14326,
    "preview": "import os\nimport torch\nfrom torch.utils.data import Dataset, DataLoader\nfrom PIL import Image\nimport cv2\nimport numpy as"
  },
  {
    "path": "dataload/dataset_3d.py",
    "chars": 9213,
    "preview": "import os\nimport torch\nfrom torch.utils.data import Dataset, DataLoader\nfrom PIL import Image\nimport cv2\nimport numpy as"
  },
  {
    "path": "loss/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "loss/loss_function.py",
    "chars": 8589,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom torch.autograd import Variabl"
  },
  {
    "path": "models/__init__.py",
    "chars": 1505,
    "preview": "# 2d\nfrom .networks_2d.xnet import XNet, XNet_1_1_m, XNet_1_2_m, XNet_2_1_m, XNet_3_2_m, XNet_2_3_m, XNet_3_3_m, XNet_sb"
  },
  {
    "path": "models/getnetwork.py",
    "chars": 4074,
    "preview": "import sys\nfrom models import *\nimport torch.nn as nn\n\ndef get_network(network, in_channels, num_classes, **kwargs):\n\n  "
  },
  {
    "path": "models/networks_2d/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "models/networks_2d/aerial_lanenet.py",
    "chars": 5544,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\nimport functools\nfrom torch"
  },
  {
    "path": "models/networks_2d/hrnet.py",
    "chars": 25569,
    "preview": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\n"
  },
  {
    "path": "models/networks_2d/mwcnn.py",
    "chars": 16899,
    "preview": "import torch\nimport torch.nn as nn\nimport scipy.io as sio\nimport math\nimport torch.nn.functional as F\nfrom torch.autogra"
  },
  {
    "path": "models/networks_2d/resunet.py",
    "chars": 4536,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    "
  },
  {
    "path": "models/networks_2d/resunet_plusplus.py",
    "chars": 8387,
    "preview": "import torch.nn as nn\nimport torch\nfrom torch.nn import init\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    "
  },
  {
    "path": "models/networks_2d/swinunet.py",
    "chars": 36849,
    "preview": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport cop"
  },
  {
    "path": "models/networks_2d/u2net.py",
    "chars": 16208,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\n\ndef init_weights(net, init"
  },
  {
    "path": "models/networks_2d/unet.py",
    "chars": 13619,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\n\n\ndef init_weights(net, ini"
  },
  {
    "path": "models/networks_2d/unet_3plus.py",
    "chars": 44418,
    "preview": "# -*- coding: utf-8 -*-\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\nimp"
  },
  {
    "path": "models/networks_2d/unet_cct.py",
    "chars": 7805,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.distributions.uniform import Uniform\nfrom torch.nn impo"
  },
  {
    "path": "models/networks_2d/unet_plusplus.py",
    "chars": 5750,
    "preview": "import torch\nfrom torch import nn\nfrom torch.nn import init\n\ndef init_weights(net, init_type='normal', gain=0.02):\n    d"
  },
  {
    "path": "models/networks_2d/unet_urpc.py",
    "chars": 8970,
    "preview": "from __future__ import division, print_function\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.distrib"
  },
  {
    "path": "models/networks_2d/wavesnet.py",
    "chars": 53653,
    "preview": "import numpy as np\nimport math, pywt\nimport torch\nimport torch.nn as nn\nfrom torch.nn import Module\nfrom torch.autograd "
  },
  {
    "path": "models/networks_2d/wds.py",
    "chars": 5255,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\nimport functools\nfrom torch"
  },
  {
    "path": "models/networks_2d/xnet.py",
    "chars": 66700,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\nimport functools\nfrom torch"
  },
  {
    "path": "models/networks_3d/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "models/networks_3d/conresnet.py",
    "chars": 14779,
    "preview": "import torch.nn as nn\nfrom torch.nn import functional as F\nimport torch\nimport numpy as np\nfrom torch.nn import init\n# f"
  },
  {
    "path": "models/networks_3d/cotr.py",
    "chars": 27124,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom torch.nn.init import xavier_u"
  },
  {
    "path": "models/networks_3d/dmfnet.py",
    "chars": 10385,
    "preview": "import torch.nn as nn\nimport torch.nn.functional as F\nimport torch\n# from loss.loss_function import segmentation_loss\n\nd"
  },
  {
    "path": "models/networks_3d/espnet3d.py",
    "chars": 13607,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\n# from loss.loss_function import segmenta"
  },
  {
    "path": "models/networks_3d/res_unet3d.py",
    "chars": 9829,
    "preview": "import torch\nimport torch.nn as nn\nimport os\nfrom torch.nn import init\n\ndef init_weights(net, init_type='normal', gain=0"
  },
  {
    "path": "models/networks_3d/transbts.py",
    "chars": 19009,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\nimport torch.nn.functional as F\nfrom loss.loss_function imp"
  },
  {
    "path": "models/networks_3d/unet3d.py",
    "chars": 9334,
    "preview": "import numpy as np\nfrom collections import OrderedDict\nimport torch\nimport torch.nn as nn\nfrom torch.nn import init\n\ndef"
  },
  {
    "path": "models/networks_3d/unet3d_cct.py",
    "chars": 11786,
    "preview": "import numpy as np\nfrom collections import OrderedDict\nimport torch\nimport torch.nn as nn\nfrom torch.nn import init\nfrom"
  },
  {
    "path": "models/networks_3d/unet3d_dtc.py",
    "chars": 5814,
    "preview": "import numpy as np\nfrom collections import OrderedDict\nimport torch\nimport torch.nn as nn\nfrom torch.nn import init\n# fr"
  },
  {
    "path": "models/networks_3d/unet3d_urpc.py",
    "chars": 8724,
    "preview": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\n\n\ndef init_weig"
  },
  {
    "path": "models/networks_3d/unetr.py",
    "chars": 11886,
    "preview": "import copy\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nfrom torch.nn import init\n\nde"
  },
  {
    "path": "models/networks_3d/vnet.py",
    "chars": 6405,
    "preview": "import torch\nimport torch.nn as nn\nimport os\nimport numpy as np\nfrom collections import OrderedDict\nfrom torch.nn import"
  },
  {
    "path": "models/networks_3d/vnet_cct.py",
    "chars": 8562,
    "preview": "import torch\nimport torch.nn as nn\nimport os\nimport numpy as np\nfrom collections import OrderedDict\nfrom torch.nn import"
  },
  {
    "path": "models/networks_3d/vnet_dtc.py",
    "chars": 6965,
    "preview": "import torch\nimport torch.nn as nn\nimport os\nimport numpy as np\nfrom collections import OrderedDict\nfrom torch.nn import"
  },
  {
    "path": "models/networks_3d/xnet3d.py",
    "chars": 13156,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\nimport functools\nfrom torch"
  },
  {
    "path": "requirements.txt",
    "chars": 325,
    "preview": "albumentations==0.5.2\neinops==0.4.1\nMedPy==0.4.0\nnumpy==1.20.2\nopencv_python==4.2.0.34\nopencv_python_headless==4.5.1.48\n"
  },
  {
    "path": "test.py",
    "chars": 7194,
    "preview": "from torchvision import transforms, datasets\nimport torch\nfrom torch.autograd import Variable\nfrom torch.utils.data impo"
  },
  {
    "path": "test_3d.py",
    "chars": 5509,
    "preview": "from torchvision import transforms, datasets\nimport torch\nfrom torch.autograd import Variable\nfrom torch.utils.data impo"
  },
  {
    "path": "test_ConResNet.py",
    "chars": 5546,
    "preview": "from torchvision import transforms, datasets\nimport torch\nfrom torch.autograd import Variable\nfrom torch.utils.data impo"
  },
  {
    "path": "test_DTC.py",
    "chars": 5415,
    "preview": "from torchvision import transforms, datasets\nimport torch\nfrom torch.autograd import Variable\nfrom torch.utils.data impo"
  },
  {
    "path": "test_xnet.py",
    "chars": 8002,
    "preview": "from torchvision import transforms, datasets\nimport torch\nfrom torch.autograd import Variable\nfrom torch.utils.data impo"
  },
  {
    "path": "test_xnet3d.py",
    "chars": 5755,
    "preview": "from torchvision import transforms, datasets\nimport torch\nfrom torch.autograd import Variable\nfrom torch.utils.data impo"
  },
  {
    "path": "tools/Atrial/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "tools/Atrial/postprocess.py",
    "chars": 1581,
    "preview": "import numpy as np\nimport argparse\nimport os\nimport SimpleITK as sitk\nfrom skimage.morphology import remove_small_object"
  },
  {
    "path": "tools/Atrial/preprocess.py",
    "chars": 2165,
    "preview": "import numpy as np\nimport torchio as tio\nimport os\nimport argparse\nfrom tqdm import tqdm\nimport SimpleITK as sitk\n\nif __"
  },
  {
    "path": "tools/LiTS/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "tools/LiTS/postprocess.py",
    "chars": 1676,
    "preview": "import numpy as np\nimport argparse\nimport os\nimport SimpleITK as sitk\nfrom skimage.morphology import remove_small_object"
  },
  {
    "path": "tools/LiTS/preprocess.py",
    "chars": 3665,
    "preview": "import numpy as np\nimport os\nimport argparse\nfrom tqdm import tqdm\nimport SimpleITK as sitk\n\nif __name__ == '__main__':\n"
  },
  {
    "path": "tools/LiTS/split_train_val.py",
    "chars": 1245,
    "preview": "import numpy as np\nimport os\nimport argparse\nimport shutil\nimport random\n\nif __name__ == '__main__':\n\n    parser = argpa"
  },
  {
    "path": "tools/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "tools/eval.py",
    "chars": 5331,
    "preview": "from sklearn.metrics import confusion_matrix\nimport numpy as np\nimport argparse\nimport os\nfrom PIL import Image\nfrom med"
  },
  {
    "path": "tools/mask2sdf.py",
    "chars": 1747,
    "preview": "import numpy as np\nimport os\nimport argparse\nimport SimpleITK as sitk\nfrom scipy.ndimage import distance_transform_edt\nf"
  },
  {
    "path": "tools/res_image_mask.py",
    "chars": 1869,
    "preview": "import numpy as np\nimport os\nimport argparse\nimport SimpleITK as sitk\n\nif __name__ == '__main__':\n\n    parser = argparse"
  },
  {
    "path": "tools/wavelet2D.py",
    "chars": 1738,
    "preview": "import numpy as np\nfrom PIL import Image\nimport pywt\nimport argparse\nimport os\n\nif __name__ == '__main__':\n\n    parser ="
  },
  {
    "path": "tools/wavelet3D.py",
    "chars": 2962,
    "preview": "import numpy as np\nfrom PIL import Image\nimport pywt\nimport argparse\nimport os\nimport SimpleITK as sitk\nimport torchio a"
  },
  {
    "path": "train_semi_CCT.py",
    "chars": 16693,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_semi_CCT_3d.py",
    "chars": 16868,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_semi_CPS.py",
    "chars": 18519,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_semi_CPS_3d.py",
    "chars": 18596,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_semi_CT.py",
    "chars": 18639,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_semi_CT_3d.py",
    "chars": 18768,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_semi_DTC.py",
    "chars": 16963,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_semi_EM.py",
    "chars": 15948,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_semi_EM_3d.py",
    "chars": 16122,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_semi_MT.py",
    "chars": 17872,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_semi_MT_3d.py",
    "chars": 17997,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_semi_UAMT.py",
    "chars": 18861,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_semi_UAMT_3d.py",
    "chars": 19016,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_semi_URPC.py",
    "chars": 18138,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_semi_URPC_3d.py",
    "chars": 18280,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_semi_XNet.py",
    "chars": 18616,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_semi_XNet3d.py",
    "chars": 18205,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_sup.py",
    "chars": 13758,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_sup_3d.py",
    "chars": 13267,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_sup_ConResNet.py",
    "chars": 15164,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_sup_XNet.py",
    "chars": 16661,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_sup_XNet3d.py",
    "chars": 16386,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_sup_XNet_sb.py",
    "chars": 17490,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_sup_alnet.py",
    "chars": 14030,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  },
  {
    "path": "train_sup_wds.py",
    "chars": 13691,
    "preview": "from torchvision import transforms, datasets\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.o"
  }
]
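
Note: one entry above worth calling out is config/ramps/ramps.py, whose preview mentions an exponential (sigmoid-shaped) rampup. Such rampups are commonly used to scale the unsupervised consistency weight in the semi-supervised training scripts. A minimal sketch of the usual formulation (an assumption; the repository's own body is truncated in the preview):

import numpy as np

def sigmoid_rampup(current, rampup_length):
    # Returns 0.0 at step 0 and ramps smoothly up to 1.0 after rampup_length steps.
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))

# Example: weight the consistency loss during the first 200 epochs.
# consistency_weight = max_weight * sigmoid_rampup(epoch, 200)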

About this extraction

This page contains the full source code of the Yanfeng-Zhou/XNet GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 108 files (1.1 MB), approximately 309.1k tokens, and a symbol index with 1024 extracted functions, classes, methods, constants, and types.

