[
  {
    "path": ".gitignore",
    "content": "__pycache__/\n*.pyc\n*_bkp.py\ncheckpoint/\nlog/\nmodel/backup/\ninitmodel/\n*.swp\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2021 yhygao\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# UTNet (Accepted at MICCAI 2021)\nOfficial implementation of [UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation](https://arxiv.org/abs/2107.00781)\n\n\n## Update\n* Our new work, Hermes, has been released on arXiv: [Training Like a Medical Resident: Universal Medical Image Segmentation via Context Prior Learning](https://arxiv.org/pdf/2306.02416.pdf). Inspired by the training of medical residents, we explore universal medical image segmentation, whose goal is to learn from diverse medical imaging sources covering a range of clinical targets, body regions, and image modalities. Following this paradigm, we propose Hermes, a context prior learning approach that addresses the challenges related to the heterogeneity on data, modality, and annotations in the proposed universal paradigm. Code will be released at https://github.com/yhygao/universal-medical-image-segmentation.\n* Our new paper, the improved version of UTNet: UTNetV2, is released on Arxiv: [A Multi-scale Transformer for Medical Image Segmentation: Architectures, Model Efficiency, and Benchmarks](https://arxiv.org/abs/2203.00131). The UTNetV2 has an improved architecture for 2D and 3D setting. We also provide a more general framework for CNN and Transformer comparison, including more dataset support, more SOTA model support, see in our new repo：https://github.com/yhygao/CBIM-Medical-Image-Segmentation\n\nData preprocess code uploaded.\n\n\n## Introduction \n\nTransformer architecture has emerged to be successful in a\nnumber of natural language processing tasks. However, its applications\nto medical vision remain largely unexplored. In this study, we present\nUTNet, a simple yet powerful hybrid Transformer architecture that integrates self-attention into a convolutional neural network for enhancing\nmedical image segmentation. UTNet applies self-attention modules in\nboth encoder and decoder for capturing long-range dependency at dif-\nferent scales with minimal overhead. To this end, we propose an efficient\nself-attention mechanism along with relative position encoding that reduces the complexity of self-attention operation significantly from O(n2)\nto approximate O(n). A new self-attention decoder is also proposed to\nrecover fine-grained details from the skipped connections in the encoder.\nOur approach addresses the dilemma that Transformer requires huge\namounts of data to learn vision inductive bias. Our hybrid layer design allows the initialization of Transformer into convolutional networks\nwithout a need of pre-training. We have evaluated UTNet on the multi-\nlabel, multi-vendor cardiac magnetic resonance imaging cohort. UTNet\ndemonstrates superior segmentation performance and robustness against\nthe state-of-the-art approaches, holding the promise to generalize well on\nother medical image segmentations.\n\n![image](https://user-images.githubusercontent.com/55367673/134997310-69c3576d-bbf2-40c8-ad5a-9b3bf3e9e97d.png)\n![image](https://user-images.githubusercontent.com/55367673/134997347-a581cda7-7050-48ef-9af3-d4628fefac9a.png)\n\n## Supportting models\nUTNet\n\nTransUNet\n\nResNet50-UTNet\n\nResNet50-UNet\n\nSwinUNet\n\nTo be continue ...\n\n## Getting Started\n\nCurrently, we only support [M&Ms dataset](https://www.ub.edu/mnms/).\n\n### Prerequisites\n```\nPython >= 3.6\npytorch = 1.8.1\nSimpleITK = 2.0.2\nnumpy = 1.19.5\neinops = 0.3.2\n```\n\n### Preprocess\n\nResample all data to spacing of 1.2x1.2 mm in x-y plane. We don't change the spacing of z-axis, as UTNet is a 2D network. 
Then put all data into 'dataset/'\n\n### Training\n\nThe M&M dataset provides data from 4 venders, where vendor AB are provided for training while ABCD for testing. The '--domain' is used to control using which vendor for training. '--domain A' for using vender A only. '--domain B' for using vender B only. '--domain AB' for using both vender A and B. For testing, all 4 venders will be used.\n\n#### UTNet\nFor default UTNet setting, training with:\n```\npython train_deep.py -m UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 1234 --num_blocks 1,1,1,1 --domain AB --gpu 0 --aux_loss\n```\nOr you can use '-m UTNet_encoder' to use transformer blocks in the encoder only. This setting is more stable than the default setting in some cases.\n\nTo optimize UTNet in your own task, there are several hyperparameters to tune:\n\n'--block_list': indicates apply transformer blocks in which resolution. The number means the number of downsamplings, e.g. 3,4 means apply transformer blocks in features after 3 and 4 times downsampling. Apply transformer blocks in higher resolution feature maps will introduce much more computation.\n\n'--num_blocks': indicates the number of transformer blocks applied in each level. e.g. block_list='3,4', num_blocks=2,4 means apply 2 transformer blocks in 3-times downsampling level and apply 4 transformer blocks in 4-time downsampling level.\n\n'--reduce_size': indicates the size of downsampling for efficient attention. In our experiments, reduce_size 8 and 16 don't have much difference, but 16 will introduce more computation, so we choost 8 as our default setting. 16 might have better performance in other applications.\n\n'--aux_loss': applies deep supervision in training, will introduce some computation overhead but has slightly better performance.\n\nHere are some recomended parameter setting:\n```\n--block_list 1234 --num_blocks 1,1,1,1\n```\nOur default setting, most efficient setting. Suitable for tasks with limited training data, and most errors occur in the boundary of ROI where high resolution information is important.\n\n```\n--block_list 1234 --num_blocks 1,1,4,8\n```\nSimilar to the previous one. The model capacity is larger as more transformer blocks are including, but needs larger dataset for training.\n\n```\n--block_list 234 --num_blocks 2,4,8\n```\nSuitable for tasks that has complex contexts and errors occurs inside ROI. More transformer blocks can help learn higher-level relationship.\n\n\nFeel free to try other combinations of the hyperparameter like base_chan, reduce_size and num_blocks in each level etc. to trade off between capacity and efficiency to fit your own tasks and datasets.\n\n#### TransUNet\nWe borrow code from the original TransUNet repo and fit it into our training framework. If you want to use pre-trained weight, please download from the [original repo](https://github.com/Beckschen/TransUNet). The configuration is not parsed by command line, so if you want change the configuration of TransUNet, you need change it inside the train_deep.py.\n```\npython train_deep.py -m TransUNet -u EXP_NAME --data_path YOUR_OWN_PATH --gpu 0\n```\n\n#### ResNet50-UTNet\nFor fair comparison with TransUNet, we implement the efficient attention proposed in UTNet into ResNet50 backbone, which is basically append transformer blocks into specified level after ResNet blocks. 
ResNet50-UTNet is slightly better in performance than the default UTNet in M&M dataset.\n```\npython train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 123 --num_blocks 1,1,1 --gpu 0\n```\nSimilar to UTNet, this is the most efficient setting, suitable for tasks with limited training data.\n```\n--block_list 23 --num_blocks 2,4\n```\nSuitable for tasks that has complex contexts and errors occurs inside ROI. More transformer blocks can help learn higher-level relationship.\n\n\n#### ResNet50-UNet\nIf you don't use Transformer blocks in ResNet50-UTNet, it is actually ResNet50-UNet. So you can use this as the baseline to compare the performance improvement from Transformer for fair comparision with TransUNet and our UTNet.\n\n```\npython train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --block_list ''  --gpu 0\n```\n\n#### SwinUNet\nDownload pre-trained model from the [origin repo](https://github.com/HuCaoFighting/Swin-Unet/tree/4375a8d6fa7d9c38184c5d3194db990a00a3e912).\nAs Swin-Transformer's input size is related to window size and is hard to change after pretraining, so we adapt our input size to 224. Without pre-training, SwinUNet's performance is very low.\n```\npython train_deep.py -m SwinUNet -u EXP_NAME --data_path YOUR_OWN_PATH --crop_size 224\n```\n\n## Citation\nIf you find this repo helps, please kindly cite our paper, thanks!\n```\n@inproceedings{gao2021utnet,\n  title={UTNet: a hybrid transformer architecture for medical image segmentation},\n  author={Gao, Yunhe and Zhou, Mu and Metaxas, Dimitris N},\n  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},\n  pages={61--71},\n  year={2021},\n  organization={Springer}\n}\n\n@misc{gao2022datascalable,\n      title={A Data-scalable Transformer for Medical Image Segmentation: Architecture, Model Efficiency, and Benchmark}, \n      author={Yunhe Gao and Mu Zhou and Di Liu and Zhennan Yan and Shaoting Zhang and Dimitris N. Metaxas},\n      year={2022},\n      eprint={2203.00131},\n      archivePrefix={arXiv},\n      primaryClass={eess.IV}\n}\n\n```\n\n"
  },
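The flags above are consumed by train_deep.py, which is not part of this listing, so the exact parsing may differ. Below is a minimal, hypothetical sketch of how the README's comma-separated values could map onto Python structures; only the flag names come from the commands above, while the destination names and defaults are assumptions.

```python
# Hypothetical sketch: train_deep.py is not shown here, so this only illustrates
# how the README flags could be parsed; it is not the repository's actual parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', default='UTNet')
parser.add_argument('-u', '--unique_name', default='EXP_NAME')
parser.add_argument('--block_list', default='1234')        # which downsampling levels get transformer blocks
parser.add_argument('--num_blocks', default='1,1,1,1')     # transformer blocks per listed level
parser.add_argument('--reduce_size', type=int, default=8)  # key/value map size for efficient attention
parser.add_argument('--aux_loss', action='store_true')     # deep supervision
args = parser.parse_args(['--block_list', '234', '--num_blocks', '2,4,8', '--aux_loss'])

num_blocks = [int(n) for n in args.num_blocks.split(',')]
assert len(num_blocks) == len(args.block_list)             # one block count per listed level
print(args.block_list, num_blocks, args.reduce_size, args.aux_loss)
```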
  {
    "path": "data_preprocess.py",
    "content": "import numpy as np\nimport SimpleITK as sitk\n\nimport os\nimport pdb \n\ndef ResampleXYZAxis(imImage, space=(1., 1., 1.), interp=sitk.sitkLinear):\n    identity1 = sitk.Transform(3, sitk.sitkIdentity)\n    sp1 = imImage.GetSpacing()\n    sz1 = imImage.GetSize()\n\n    sz2 = (int(round(sz1[0]*sp1[0]*1.0/space[0])), int(round(sz1[1]*sp1[1]*1.0/space[1])), int(round(sz1[2]*sp1[2]*1.0/space[2])))\n\n    imRefImage = sitk.Image(sz2, imImage.GetPixelIDValue())\n    imRefImage.SetSpacing(space)\n    imRefImage.SetOrigin(imImage.GetOrigin())\n    imRefImage.SetDirection(imImage.GetDirection())\n\n    imOutImage = sitk.Resample(imImage, imRefImage, identity1, interp)\n\n    return imOutImage\n\ndef ResampleFullImageToRef(imImage, imRef, interp=sitk.sitkNearestNeighbor):\n    identity1 = sitk.Transform(3, sitk.sitkIdentity)\n\n    imRefImage = sitk.Image(imRef.GetSize(), imImage.GetPixelIDValue())\n    imRefImage.SetSpacing(imRef.GetSpacing())\n    imRefImage.SetOrigin(imRef.GetOrigin())\n    imRefImage.SetDirection(imRef.GetDirection())\n\n\n    imOutImage = sitk.Resample(imImage, imRefImage, identity1, interp)\n\n    return imOutImage\n\n\ndef ResampleCMRImage(imImage, imLabel, save_path, patient_name, name, target_space=(1., 1.)):\n    \n    assert imImage.GetSpacing() == imLabel.GetSpacing()\n    assert imImage.GetSize() == imLabel.GetSize()\n\n\n    spacing = imImage.GetSpacing()\n    origin = imImage.GetOrigin()\n\n\n    npimg = sitk.GetArrayFromImage(imImage)\n    nplab = sitk.GetArrayFromImage(imLabel)\n    t, z, y, x = npimg.shape\n    \n    if not os.path.exists('%s/%s'%(save_path, patient_name)):\n        os.mkdir('%s/%s'%(save_path, patient_name))\n    flag = 0\n    for i in range(t):\n        tmp_img = npimg[i]\n        tmp_lab = nplab[i]\n        \n        if tmp_lab.max() == 0:\n            continue\n\n        \n        print(i)\n        flag += 1\n        tmp_itkimg = sitk.GetImageFromArray(tmp_img)\n        tmp_itkimg.SetSpacing(spacing[0:3])\n        tmp_itkimg.SetOrigin(origin[0:3])\n            \n        tmp_itklab = sitk.GetImageFromArray(tmp_lab)\n        tmp_itklab.SetSpacing(spacing[0:3])\n        tmp_itklab.SetOrigin(origin[0:3])\n\n        \n        re_img = ResampleXYZAxis(tmp_itkimg, space=(target_space[0], target_space[1], spacing[2]))\n        re_lab = ResampleFullImageToRef(tmp_itklab, re_img)\n\n        \n        sitk.WriteImage(re_img, '%s/%s/%s_%d.nii.gz'%(save_path, patient_name, name, flag))\n        sitk.WriteImage(re_lab, '%s/%s/%s_gt_%d.nii.gz'%(save_path, patient_name, name, flag))\n        \n    return flag\n\n\nif __name__ == '__main__':\n    \n\n    src_path = 'OpenDataset/Training/Labeled'\n    tgt_path = 'dataset/Training'\n\n    os.chdir(src_path)\n    for name in os.listdir('.'):\n        os.chdir(name)\n\n        for i in os.listdir('.'):\n            if 'gt' in i:\n                tmp = i.split('_')\n                img_name = tmp[0] + '_' + tmp[1]\n                patient_name = tmp[0]\n                \n                img = sitk.ReadImage('%s.nii.gz'%img_name)\n                lab = sitk.ReadImage('%s_gt.nii.gz'%img_name)\n                \n                flag = ResampleCMRImage(img, lab, tgt_path, patient_name, img_name, (1.2, 1.2))\n                \n                print(name, 'done', flag)\n\n        os.chdir('..')\n\n\n\n\n"
  },
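As a quick, self-contained sanity check of the resampling step, the sketch below builds a synthetic volume in memory and resamples it to 1.2 x 1.2 mm in-plane with ResampleXYZAxis while keeping the z spacing; the array shape and spacing values are made up for illustration.

```python
# Minimal sketch of the in-plane resampling used by data_preprocess.py.
# The synthetic volume and its spacing are illustrative values only.
import numpy as np
import SimpleITK as sitk
from data_preprocess import ResampleXYZAxis

img = sitk.GetImageFromArray(np.random.rand(10, 200, 180).astype(np.float32))  # z, y, x array
img.SetSpacing((1.0, 1.0, 8.0))                                                # x, y, z spacing in mm
re_img = ResampleXYZAxis(img, space=(1.2, 1.2, img.GetSpacing()[2]))           # keep z spacing unchanged
print(img.GetSize(), '->', re_img.GetSize())   # in-plane size shrinks by roughly a factor of 1/1.2
```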
  {
    "path": "dataset/Testing_A.csv",
    "content": "name\nA2L1N6\nB0H7V0\nB3F0V9\nB5F8L9\nB9H8N8\nC4E9I1\nC6E0F9\nC6U4W8\nD1H6U2\nD9F5P1\nE0J7L9\nF0K4T6\nG7S6V0\nG9N5V9\nH1N8S6\nI2J6Z6\nI5L3S2\nJ4J8Q3\nK6N4N7\nP3P9S5\n"
  },
  {
    "path": "dataset/Testing_B.csv",
    "content": "name\nA2H5K9\nA4A8V9\nA5C2D2\nA5D0G0\nA6J0Y2\nA7E4J0\nB1G9J3\nB3S2Z4\nB4E1K1\nB4S1Y2\nB5L5Y4\nB5T6V0\nB7F5P0\nC7L8Z8\nC8I7P7\nD1R0Y5\nD3Q0W9\nD6E9U8\nD8O0W2\nD9I8O7\nE1L8Y4\nE3L8U8\nE4I9O7\nE4O8P3\nE5S7W7\nE6J4N8\nG1K1V3\nG3M5S4\nG7N8R7\nG7Q2W0\nH7K5U5\nH8K2K7\nI0I2J8\nI6T4W8\nI7W4Y8\nJ9L4S2\nK5K6N1\nK7L2Y6\nK9N0W0\nL7Y7Z2\nL8N7Z0\nM4T4V6\nM6V2Y0\nN9P5Z0\nO4T6Y7\nO9V8W5\nP3R6Y5\nP8W4Z0\nQ3Q6R8\nY6Y9Z2\n"
  },
  {
    "path": "dataset/Testing_C.csv",
    "content": "name\nA3H5R1\nA5H1Q2\nA8C5E9\nA9F3T5\nA9L7Y7\nB0L3Y2\nB2L0L2\nB8P5Q9\nC0N8P4\nC7M6W0\nC8J7L5\nC8O0P2\nD1H2O9\nD2U0V0\nD3K5Q2\nD5G3W8\nD6N7Q8\nD7M8P9\nD7T3V8\nE3F2U7\nE3F5U2\nE6M6P2\nE7L0N6\nE9V9Z2\nF0I6U8\nF1K2S9\nF5I1Z8\nG8R0Z9\nH2M9S1\nH3R6S9\nH7L8R8\nH7P5Z4\nI4R8V6\nI6P4R0\nI8Z0Z6\nJ6K4V3\nK3R0Y7\nK7O3Q0\nL2V5Z0\nL8N7P0\nM6M9N1\nO4O6U5\nP3T5U1\nQ1Q3T1\nQ5V8W3\nR1R6Y8\nR2R7Z5\nR6V5W3\nR8V0Y4\nV4W8Z5\n"
  },
  {
    "path": "dataset/Testing_D.csv",
    "content": "name\nA1K2P5\nA3P9V7\nA4B9O6\nA4K8R4\nA4R4T0\nA5P5W0\nA5Q1W8\nA6A8H0\nA6B7Y4\nA7F4G2\nB3E2W8\nB6I0T4\nB9G4U2\nC0L7V1\nC5L0R0\nD1L6T4\nD1S5T8\nE0J2Z9\nE1L7M3\nE4H7L4\nE5J6L2\nE6H0V9\nF6J9L9\nF9M1R2\nG1J5K3\nG4I7V2\nG8K0M3\nG8L0Z0\nH6P7T1\nI4L4V7\nI8N8Y1\nJ6M5O2\nK3P3Y6\nK5L4S1\nK5M7V5\nK7N0R7\nL5U7Y4\nL6T2T5\nL8M2U8\nM2P5T8\nN2O7U5\nN7P3T8\nN7W6Z8\nN9Q4T8\nO5U2U7\nO7Q7U3\nP8V0Y7\nQ4W5Z8\nR3V5W7\nT2Z1Z9\n"
  },
  {
    "path": "dataset/Training_A.csv",
    "content": "name\nA0S9V9\nA1D9Z7\nA1E9Q1\nA2C0I1\nA2N8V0\nA3H1O5\nA4B5U4\nA4J4S4\nA4U9V5\nA6B5G9\nA6D5F9\nA7D9L8\nA7M7P8\nA7O4T6\nB0I2Z0\nB2C2Z7\nB2D9M2\nB2G5R2\nB3O1S0\nB3P3R1\nB4O3V3\nB8H5H6\nB8J7R4\nB9E0Q1\nC0K1P0\nC1K8P5\nC2J0K3\nC5M4S2\nC6J5P1\nC8P3S7\nD0R0R9\nD1J5P6\nD1M1S6\nD3D4Y5\nD3F3O5\nD3F9H9\nD3O9U9\nD4M3Q2\nD4N6W6\nD8E4F4\nD9L1Z3\nE0M3U7\nE0O0S0\nE9H1U4\nE9V4Z8\nF2H5S1\nF3G5K5\nG4S9U3\nG5P4U3\nG7I5V7\nH1I3W0\nH5N0P0\nI7T3U1\nJ1T9Y1\nJ6K6P5\nJ6P5T8\nJ8R5W2\nK4T7Y0\nK5L2U3\nK5P0Y1\nL1Q9V8\nL4Q2U3\nM1R4S1\nM2P1R1\nN5S7Y1\nN8N9U0\nO0S9V7\nP0S5Y0\nP5R1Y4\nQ0U0V5\nQ3R9W7\nR4Y1Z9\nS1S3Z7\nT2T9Z9\nT9U9W2\n"
  },
  {
    "path": "dataset/Training_B.csv",
    "content": "name\nA1D0Q7\nA1O8Z3\nA3B7E5\nA5E0T8\nA6M1Q7\nA7G0P5\nA8C9U8\nA8E1F4\nA9C5P4\nA9E3G9\nA9J5Q7\nA9J8W7\nB0N3W8\nB2D9O2\nB2F4K5\nB3D0N1\nB6D0U7\nB9O1Q0\nC0S7W0\nC1G5Q0\nC2L5P7\nC2M6P8\nC3I2K3\nC4R8T7\nC4S8W9\nD0H9I4\nD1L4Q9\nD6H6O2\nE3T0Z2\nE4M2Q7\nE4W8Z7\nE5E6O8\nE5F5V7\nE9H2K7\nE9L1W5\nF0J2R8\nF1F3I6\nF4K3S1\nF5I9Q2\nF8N2S1\nG0H4J3\nG0I6P3\nG1N6S7\nG2J1M5\nG2M7W4\nG2O2S6\nG4L8Z7\nG8N2U5\nG9L0O9\nH0K3Q4\nH1J5W8\nH1M5Y6\nH1W2Y1\nH3U1Y1\nH4I2T8\nH6I0I6\nH7I4J3\nH7N4V9\nI0J5U3\nI2K2Y8\nI6N3P3\nJ4J9W6\nJ9L6N9\nK2S1U6\nL1Q1Z5\nL5Q6T7\nM0P8U8\nM4P7Q6\nN1P8Q9\nN7V9W9\nO3R8Y5\nP6U0Y0\nP9S7W2\nQ7V1Y5\nW5Z4Z8\n"
  },
  {
    "path": "dataset/Training_C.csv",
    "content": "name\n"
  },
  {
    "path": "dataset/Training_D.csv",
    "content": "name\n"
  },
  {
    "path": "dataset/info.csv",
    "content": "External code,VendorName,Vendor,Centre,ED,ES\nA0S9V9,Siemens,A,1,0,9\nA1D0Q7,Philips,B,2,0,9\nA1D9Z7,Siemens,A,1,22,11\nA1E9Q1,Siemens,A,1,0,9\nA1K2P5,Canon,D,5,33,11\nA1O8Z3,Philips,B,3,23,10\nA2C0I1,Siemens,A,1,0,7\nA2E3W4,GE,C,4,0,0\nA2H5K9,Philips,B,2,29,8\nA2L1N6,Siemens,A,1,0,12\nA2N8V0,Siemens,A,1,0,9\nA3B7E5,Philips,B,2,29,12\nA3H1O5,Siemens,A,1,0,12\nA3H5R1,GE,C,4,24,6\nA3P9V7,Canon,D,5,27,13\nA4A8V9,Philips,B,3,0,10\nA4B5U4,Siemens,A,1,0,10\nA4B9O6,Canon,D,5,0,11\nA4J4S4,Siemens,A,1,0,7\nA4K8R4,Canon,D,5,24,11\nA4R4T0,Canon,D,5,21,8\nA4U9V5,Siemens,A,1,0,8\nA5C2D2,Philips,B,3,24,9\nA5D0G0,Philips,B,2,0,10\nA5E0T8,Philips,B,3,24,8\nA5H1Q2,GE,C,4,0,9\nA5P5W0,Canon,D,5,33,10\nA5Q1W8,Canon,D,5,27,10\nA6A8H0,Canon,D,5,27,8\nA6B5G9,Siemens,A,1,0,11\nA6B7Y4,Canon,D,5,1,10\nA6D5F9,Siemens,A,1,0,11\nA6J0Y2,Philips,B,2,27,13\nA6M1Q7,Philips,B,2,29,11\nA7D9L8,Siemens,A,1,0,11\nA7E4J0,Philips,B,3,1,12\nA7F4G2,Canon,D,5,25,10\nA7G0P5,Philips,B,2,28,9\nA7M7P8,Siemens,A,1,0,9\nA7O4T6,Siemens,A,1,0,10\nA8C5E9,GE,C,4,24,9\nA8C9H8,GE,C,4,0,0\nA8C9U8,Philips,B,2,28,9\nA8E1F4,Philips,B,3,24,9\nA8I1U6,GE,C,4,0,0\nA9C5P4,Philips,B,2,29,8\nA9E3G9,Philips,B,3,23,8\nA9F3T5,GE,C,4,24,8\nA9J5Q7,Philips,B,3,24,7\nA9J8W7,Philips,B,2,29,10\nA9L7Y7,GE,C,4,24,5\nB0H7V0,Siemens,A,1,0,9\nB0I2Z0,Siemens,A,1,0,8\nB0L3Y2,GE,C,4,24,10\nB0N3W8,Philips,B,2,28,15\nB1G9J3,Philips,B,2,29,10\nB1K7U1,GE,C,4,0,0\nB2C2Z7,Siemens,A,1,0,8\nB2D9M2,Siemens,A,1,0,8\nB2D9O2,Philips,B,2,29,13\nB2F4K5,Philips,B,2,29,10\nB2G5R2,Siemens,A,1,0,7\nB2L0L2,GE,C,4,24,8\nB3D0N1,Philips,B,3,24,8\nB3E2W8,Canon,D,5,23,9\nB3F0V9,Siemens,A,1,0,12\nB3O1S0,Siemens,A,1,0,8\nB3P3R1,Siemens,A,1,0,11\nB3S2Z4,Philips,B,2,28,11\nB4E1K1,Philips,B,3,0,9\nB4I8Z7,GE,C,4,0,0\nB4O3V3,Siemens,A,1,0,6\nB4S1Y2,Philips,B,2,29,10\nB5F8L9,Siemens,A,1,0,8\nB5L5Y4,Philips,B,3,24,10\nB5T6V0,Philips,B,3,0,10\nB6D0U7,Philips,B,2,29,9\nB6I0T4,Canon,D,5,33,10\nB7F5P0,Philips,B,3,29,12\nB8H5H6,Siemens,A,1,24,8\nB8J7R4,Siemens,A,1,0,12\nB8P5Q9,GE,C,4,23,9\nB9E0Q1,Siemens,A,1,0,11\nB9G4U2,Canon,D,5,28,8\nB9H8N8,Siemens,A,1,0,10\nB9O1Q0,Philips,B,2,29,11\nC0K1P0,Siemens,A,1,0,9\nC0L7V1,Canon,D,5,33,14\nC0N8P4,GE,C,4,24,9\nC0S7W0,Philips,B,3,24,7\nC0W2Y5,GE,C,4,0,0\nC1G5Q0,Philips,B,3,24,8\nC1K8P5,Siemens,A,1,0,8\nC2J0K3,Siemens,A,1,0,9\nC2L5P7,Philips,B,2,28,12\nC2M6P8,Philips,B,3,24,8\nC3I2K3,Philips,B,2,29,11\nC4E9I1,Siemens,A,1,0,8\nC4K8M1,GE,C,4,0,0\nC4R8T7,Philips,B,2,28,8\nC4S8W9,Philips,B,3,24,8\nC5L0R0,Canon,D,5,29,12\nC5M4S2,Siemens,A,1,0,9\nC5Q2Y5,GE,C,4,0,0\nC6E0F9,Siemens,A,1,0,8\nC6J5P1,Siemens,A,1,0,10\nC6U4W8,Siemens,A,1,0,9\nC7L8Z8,Philips,B,2,29,9\nC7M6W0,GE,C,4,24,8\nC8I7P7,Philips,B,2,29,11\nC8J7L5,GE,C,4,24,7\nC8O0P2,GE,C,4,24,8\nC8P3S7,Siemens,A,1,0,8\nC8V5W8,GE,C,4,0,0\nD0E2W0,GE,C,4,0,0\nD0H9I4,Philips,B,2,0,10\nD0R0R9,Siemens,A,1,0,11\nD1H2O9,GE,C,4,24,8\nD1H6U2,Siemens,A,1,0,8\nD1J5P6,Siemens,A,1,0,13\nD1L4Q9,Philips,B,2,29,10\nD1L6T4,Canon,D,5,27,10\nD1M1S6,Siemens,A,1,0,10\nD1R0Y5,Philips,B,3,24,9\nD1S5T8,Canon,D,5,23,14\nD2U0V0,GE,C,4,24,10\nD3D4Y5,Siemens,A,1,0,9\nD3F3O5,Siemens,A,1,24,10\nD3F9H9,Siemens,A,1,24,8\nD3K5Q2,GE,C,4,11,0\nD3O9U9,Siemens,A,1,0,10\nD3Q0W9,Philips,B,3,24,10\nD4M3Q2,Siemens,A,1,0,9\nD4N6W6,Siemens,A,1,0,10\nD5G3W8,GE,C,4,23,8\nD6E9U8,Philips,B,2,0,12\nD6H6O2,Philips,B,2,29,9\nD6N7Q8,GE,C,4,24,9\nD7M8P9,GE,C,4,24,8\nD7T3V8,GE,C,4,24,7\nD8E4F4,Siemens,A,1,0,10\nD8O0W2,Philips,B,3,24,8\nD9F5P1,Siemens,A,1,0,10\nD9I8O7,Philips,B,3,29,11\nD9L1Z3,Siemens,A,1,0,12\nE0J2Z9,Canon,D,5,25,11\nE0J7L9,Siemens,A,1,0,9\nE0M
3U7,Siemens,A,1,0,9\nE0O0S0,Siemens,A,1,0,11\nE1L7M3,Canon,D,5,1,12\nE1L8Y4,Philips,B,3,24,10\nE3F2U7,GE,C,4,0,7\nE3F5U2,GE,C,4,24,8\nE3I4V1,GE,C,4,0,0\nE3L8U8,Philips,B,3,29,9\nE3T0Z2,Philips,B,2,29,10\nE4H7L4,Canon,D,5,28,10\nE4I9O7,Philips,B,3,29,10\nE4M2Q7,Philips,B,3,0,8\nE4O8P3,Philips,B,2,29,12\nE4W8Z7,Philips,B,2,29,10\nE5E6O8,Philips,B,2,29,10\nE5F5V7,Philips,B,3,0,9\nE5J6L2,Canon,D,5,29,11\nE5S7W7,Philips,B,2,28,12\nE6H0V9,Canon,D,5,33,12\nE6J4N8,Philips,B,2,29,11\nE6M6P2,GE,C,4,24,7\nE7L0N6,GE,C,4,22,8\nE9H1U4,Siemens,A,1,0,10\nE9H2K7,Philips,B,2,29,11\nE9L1W5,Philips,B,3,24,9\nE9L4N2,GE,C,4,0,0\nE9V4Z8,Siemens,A,1,0,8\nE9V9Z2,GE,C,4,28,10\nF0I6U8,GE,C,4,10,23\nF0J2R8,Philips,B,2,29,9\nF0K4T6,Siemens,A,1,0,8\nF1F3I6,Philips,B,2,28,8\nF1K2S9,GE,C,4,23,7\nF2H5S1,Siemens,A,1,24,10\nF3G5K5,Siemens,A,1,24,12\nF3L0M1,GE,C,4,0,0\nF4K3S1,Philips,B,3,24,8\nF5I1Z8,GE,C,4,23,8\nF5I9Q2,Philips,B,2,29,11\nF6J9L9,Canon,D,5,29,10\nF8N2S1,Philips,B,3,24,8\nF9M1R2,Canon,D,5,29,11\nG0H4J3,Philips,B,2,29,10\nG0I6P3,Philips,B,3,29,12\nG0I7Z6,GE,C,4,0,0\nG1J5K3,Canon,D,5,24,13\nG1K1V3,Philips,B,2,29,11\nG1N6S7,Philips,B,2,0,12\nG2J1M5,Philips,B,2,29,9\nG2M7W4,Philips,B,3,24,9\nG2O2S6,Philips,B,2,29,10\nG3M5S4,Philips,B,3,29,10\nG4I7V2,Canon,D,5,33,9\nG4K8P3,GE,C,4,0,0\nG4L8Z7,Philips,B,2,29,8\nG4S9U3,Siemens,A,1,0,11\nG4U3U8,GE,C,4,0,0\nG5P4U3,Siemens,A,1,0,9\nG6T0Z6,GE,C,4,0,0\nG7I5V7,Siemens,A,1,0,10\nG7N8R7,Philips,B,2,28,9\nG7Q2W0,Philips,B,2,0,8\nG7S6V0,Siemens,A,1,0,8\nG8K0M3,Canon,D,5,25,10\nG8L0Z0,Canon,D,5,3,10\nG8N2U5,Philips,B,3,23,9\nG8R0Z9,GE,C,4,21,8\nG9L0O9,Philips,B,2,29,10\nG9N5V9,Siemens,A,1,0,11\nH0K3Q4,Philips,B,3,24,9\nH1I3W0,Siemens,A,1,0,9\nH1J5W8,Philips,B,2,3,15\nH1M5Y6,Philips,B,2,29,8\nH1N7P7,GE,C,4,0,0\nH1N8S6,Siemens,A,1,0,10\nH1W2Y1,Philips,B,2,29,10\nH2M9S1,GE,C,4,23,8\nH3R6S9,GE,C,4,24,9\nH3U1Y1,Philips,B,3,24,8\nH4I2T8,Philips,B,3,24,8\nH5N0P0,Siemens,A,1,0,8\nH6I0I6,Philips,B,2,28,10\nH6P7T1,Canon,D,5,25,12\nH7I4J3,Philips,B,3,24,8\nH7K5U5,Philips,B,3,24,8\nH7L8R8,GE,C,4,23,9\nH7N4V9,Philips,B,2,0,12\nH7P5Z4,GE,C,4,24,9\nH8K2K7,Philips,B,3,23,7\nH9J6L5,GE,C,4,0,0\nI0I2J8,Philips,B,2,21,7\nI0J5U3,Philips,B,2,28,10\nI2J6Z6,Siemens,A,1,0,7\nI2K2Y8,Philips,B,2,29,8\nI4J8P4,GE,C,4,0,0\nI4L4V7,Canon,D,5,22,9\nI4R8V6,GE,C,4,24,8\nI5L3S2,Siemens,A,1,0,10\nI6N3P3,Philips,B,2,29,10\nI6P4R0,GE,C,4,18,11\nI6T4W8,Philips,B,2,28,12\nI7T3U1,Siemens,A,1,0,13\nI7W4Y8,Philips,B,2,27,10\nI8N8Y1,Canon,D,5,33,14\nI8Z0Z6,GE,C,4,24,8\nJ1T9Y1,Siemens,A,1,0,12\nJ2S5T1,GE,C,4,0,0\nJ4J8Q3,Siemens,A,1,0,8\nJ4J9W6,Philips,B,2,29,11\nJ6K4V3,GE,C,4,21,3\nJ6K6P5,Siemens,A,1,0,9\nJ6M5O2,Canon,D,5,33,11\nJ6P5T8,Siemens,A,1,0,9\nJ8R5W2,Siemens,A,1,0,9\nJ9L4S2,Philips,B,3,29,13\nJ9L6N9,Philips,B,2,29,9\nK2S1U6,Philips,B,2,29,12\nK3P3Y6,Canon,D,5,32,11\nK3R0Y7,GE,C,4,28,9\nK4T7Y0,Siemens,A,1,24,12\nK5K6N1,Philips,B,3,24,9\nK5L2U3,Siemens,A,1,0,7\nK5L4S1,Canon,D,5,27,10\nK5M7V5,Canon,D,5,23,10\nK5P0Y1,Siemens,A,1,0,9\nK6N4N7,Siemens,A,1,0,8\nK7L2Y6,Philips,B,2,29,10\nK7N0R7,Canon,D,5,29,9\nK7O3Q0,GE,C,4,23,10\nK9N0W0,Philips,B,2,29,12\nL1Q1Z5,Philips,B,2,0,15\nL1Q9V8,Siemens,A,1,0,10\nL2V5Z0,GE,C,4,3,17\nL4Q2U3,Siemens,A,1,0,8\nL5Q6T7,Philips,B,2,28,11\nL5U7Y4,Canon,D,5,33,10\nL6T2T5,Canon,D,5,1,18\nL7Y7Z2,Philips,B,2,0,11\nL8M2U8,Canon,D,5,0,11\nL8N7P0,GE,C,4,23,9\nL8N7Z0,Philips,B,3,29,9\nM0P8U8,Philips,B,3,0,9\nM1R4S1,Siemens,A,1,0,9\nM2P1R1,Siemens,A,1,0,10\nM2P5T8,Canon,D,5,0,11\nM4P7Q6,Philips,B,2,2,13\nM4T4V6,Philips,B,2,29,12\nM6M9N1,GE,C,4,24,11\nM6V2Y0,Philips,B,3,0,9\nM8N3Z3,GE,C,4,0,0\nN1P8Q9
,Philips,B,2,29,10\nN1S7Z2,GE,C,4,0,0\nN2O7U5,Canon,D,5,18,7\nN5S7Y1,Siemens,A,1,0,9\nN7P3T8,Canon,D,5,25,10\nN7V9W9,Philips,B,3,24,7\nN7W6Z8,Canon,D,5,23,11\nN8N9U0,Siemens,A,1,0,11\nN9P5Z0,Philips,B,3,0,13\nN9Q4T8,Canon,D,5,32,10\nO0S9V7,Siemens,A,1,0,9\nO1O9Y6,GE,C,4,0,0\nO3R8Y5,Philips,B,2,0,8\nO4O6U5,GE,C,4,24,8\nO4T6Y7,Philips,B,2,0,13\nO5U2U7,Canon,D,5,1,9\nO7Q7U3,Canon,D,5,29,11\nO9V8W5,Philips,B,2,29,12\nP0S5Y0,Siemens,A,1,0,10\nP3P9S5,Siemens,A,1,0,10\nP3R6Y5,Philips,B,3,24,9\nP3T5U1,GE,C,4,27,9\nP5R1Y4,Siemens,A,1,0,9\nP6U0Y0,Philips,B,2,29,10\nP8V0Y7,Canon,D,5,27,9\nP8W4Z0,Philips,B,3,0,12\nP9S7W2,Philips,B,3,23,10\nQ0Q1Y4,GE,C,4,0,0\nQ0U0V5,Siemens,A,1,0,10\nQ1Q3T1,GE,C,4,28,10\nQ3Q6R8,Philips,B,3,24,9\nQ3R9W7,Siemens,A,1,0,10\nQ4W5Z8,Canon,D,5,33,10\nQ5V8W3,GE,C,4,24,10\nQ7V1Y5,Philips,B,2,0,11\nR1R6Y8,GE,C,4,23,10\nR2R7Z5,GE,C,4,23,8\nR3V5W7,Canon,D,5,27,13\nR4Y1Z9,Siemens,A,1,24,7\nR6V5W3,GE,C,4,28,10\nR8V0Y4,GE,C,4,24,7\nS1S3Z7,Siemens,A,1,0,9\nT2T9Z9,Siemens,A,1,0,13\nT2Z1Z9,Canon,D,5,29,9\nT9U9W2,Siemens,A,1,0,10\nV4W8Z5,GE,C,4,19,9\nW5Z4Z8,Philips,B,2,29,11\nY6Y9Z2,Philips,B,3,29,9\n"
  },
  {
    "path": "dataset_domain.py",
    "content": "import os\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom torch.utils.data import Dataset, DataLoader\nimport pandas as pd\nimport SimpleITK as sitk\nfrom skimage.measure import label, regionprops\nimport math\nimport pdb\n\nclass CMRDataset(Dataset):\n    def __init__(self, dataset_dir, mode='train', domain='A', crop_size=256, scale=0.1, rotate=10, debug=False):\n\n        self.mode = mode\n        self.dataset_dir = dataset_dir\n        self.crop_size = crop_size\n        self.scale = scale\n        self.rotate = rotate\n\n        if self.mode == 'train':\n            pre_face = 'Training'\n            if 'C' in domain or 'D' in domain:\n                print('No domain C or D in Training set')\n                raise StandardError\n\n        elif self.mode == 'test':\n            pre_face = 'Testing'\n\n        else:\n            print('Wrong mode')\n            raise StandardError\n        if debug:\n            # validation set is the smallest, need the shortest time for load data.\n           pre_face = 'Testing'\n\n        path = self.dataset_dir + pre_face + '/'\n        print('start loading data')\n        \n        name_list = []\n\n        if 'A' in domain:\n            df = pd.read_csv(self.dataset_dir+pre_face+'_A.csv')\n            name_list += np.array(df['name']).tolist()\n        if 'B' in domain:\n            df = pd.read_csv(self.dataset_dir+pre_face+'_B.csv')\n            name_list += np.array(df['name']).tolist()\n        if 'C' in domain:\n            df = pd.read_csv(self.dataset_dir+pre_face+'_C.csv')\n            name_list += np.array(df['name']).tolist()\n        if 'D' in domain:\n            df = pd.read_csv(self.dataset_dir+pre_face+'_D.csv')\n            name_list += np.array(df['name']).tolist()\n\n\n\n        \n        img_list = []\n        lab_list = []\n        spacing_list = []\n\n        for name in name_list:\n            for name_idx in os.listdir(path+name):\n                if 'gt' in name_idx:\n                    continue\n                else:\n                    idx = name_idx.split('_')[2].split('.')[0]\n                    \n                    itk_img = sitk.ReadImage(path+name+'/%s_sa_%s.nii.gz'%(name, idx))\n                    itk_lab = sitk.ReadImage(path+name+'/%s_sa_gt_%s.nii.gz'%(name, idx))\n                    \n                    spacing = np.array(itk_lab.GetSpacing()).tolist()\n                    spacing_list.append(spacing[::-1])\n\n                    assert itk_img.GetSize() == itk_lab.GetSize()\n                    img, lab = self.preprocess(itk_img, itk_lab)\n\n                    img_list.append(img)\n                    lab_list.append(lab)\n\n       \n        self.img_slice_list = []\n        self.lab_slice_list = []\n        if self.mode == 'train':\n            for i in range(len(img_list)):\n                tmp_img = img_list[i]\n                tmp_lab = lab_list[i]\n\n                z, x, y = tmp_img.shape\n\n                for j in range(z):\n                    self.img_slice_list.append(tmp_img[j])\n                    self.lab_slice_list.append(tmp_lab[j])\n\n        else:\n            self.img_slice_list = img_list\n            self.lab_slice_list = lab_list\n            self.spacing_list = spacing_list\n\n        print('load done, length of dataset:', len(self.img_slice_list))\n        \n    def __len__(self):\n        return len(self.img_slice_list)\n\n    def preprocess(self, itk_img, itk_lab):\n        \n        img = 
sitk.GetArrayFromImage(itk_img)\n        lab = sitk.GetArrayFromImage(itk_lab)\n\n        max98 = np.percentile(img, 98)\n        img = np.clip(img, 0, max98)\n            \n        z, y, x = img.shape\n        if x < self.crop_size:\n            diff = (self.crop_size + 10 - x) // 2\n            img = np.pad(img, ((0,0), (0,0), (diff, diff)))\n            lab = np.pad(lab, ((0,0), (0,0), (diff,diff)))\n        if y < self.crop_size:\n            diff = (self.crop_size + 10 -y) // 2\n            img = np.pad(img, ((0,0), (diff, diff), (0,0)))\n            lab = np.pad(lab, ((0,0), (diff, diff), (0,0)))\n\n        img = img / max98\n        tensor_img = torch.from_numpy(img).float()\n        tensor_lab = torch.from_numpy(lab).long()\n\n        return tensor_img, tensor_lab\n\n\n    def __getitem__(self, idx):\n        tensor_image = self.img_slice_list[idx]\n        tensor_label = self.lab_slice_list[idx]\n       \n        if self.mode == 'train':\n            tensor_image = tensor_image.unsqueeze(0).unsqueeze(0)\n            tensor_label = tensor_label.unsqueeze(0).unsqueeze(0)\n            \n            # Gaussian Noise\n            tensor_image += torch.randn(tensor_image.shape) * 0.02\n            # Additive brightness\n            rnd_bn = np.random.normal(0, 0.7)#0.03\n            tensor_image += rnd_bn\n            # gamma\n            minm = tensor_image.min()\n            rng = tensor_image.max() - minm\n            gamma = np.random.uniform(0.5, 1.6)\n            tensor_image = torch.pow((tensor_image-minm)/rng, gamma)*rng + minm\n\n            tensor_image, tensor_label = self.random_zoom_rotate(tensor_image, tensor_label)\n            tensor_image, tensor_label = self.randcrop(tensor_image, tensor_label)\n        else:\n            tensor_image, tensor_label = self.center_crop(tensor_image, tensor_label)\n        \n        assert tensor_image.shape == tensor_label.shape\n        \n        if self.mode == 'train':\n            return tensor_image, tensor_label\n        else:\n            return tensor_image, tensor_label, np.array(self.spacing_list[idx])\n\n    def randcrop(self, img, label):\n        _, _, H, W = img.shape\n        \n        diff_H = H - self.crop_size\n        diff_W = W - self.crop_size\n        \n        rand_x = np.random.randint(0, diff_H)\n        rand_y = np.random.randint(0, diff_W)\n        \n        croped_img = img[0, :, rand_x:rand_x+self.crop_size, rand_y:rand_y+self.crop_size]\n        croped_lab = label[0, :, rand_x:rand_x+self.crop_size, rand_y:rand_y+self.crop_size]\n\n        return croped_img, croped_lab\n\n\n    def center_crop(self, img, label):\n        D, H, W = img.shape\n        \n        diff_H = H - self.crop_size\n        diff_W = W - self.crop_size\n        \n        rand_x = diff_H // 2\n        rand_y = diff_W // 2\n\n        croped_img = img[:, rand_x:rand_x+self.crop_size, rand_y:rand_y+self.crop_size]\n        croped_lab = label[:, rand_x:rand_x+self.crop_size, rand_y:rand_y+self.crop_size]\n\n        return croped_img, croped_lab\n\n    def random_zoom_rotate(self, img, label):\n        scale_x = np.random.random() * 2 * self.scale + (1 - self.scale)\n        scale_y = np.random.random() * 2 * self.scale + (1 - self.scale)\n\n\n        theta_scale = torch.tensor([[scale_x, 0, 0],\n                                    [0, scale_y, 0],\n                                    [0, 0, 1]]).float()\n        angle = (float(np.random.randint(-self.rotate, self.rotate)) / 180.) 
* math.pi\n\n        theta_rotate = torch.tensor( [  [math.cos(angle), -math.sin(angle), 0], \n                                        [math.sin(angle), math.cos(angle), 0], \n                                        ]).float()\n        \n    \n        theta_rotate = theta_rotate.unsqueeze(0)\n        grid = F.affine_grid(theta_rotate, img.size())\n        img = F.grid_sample(img, grid, mode='bilinear')\n        label = F.grid_sample(label.float(), grid, mode='nearest').long()\n    \n        return img, label\n\n\n"
  },
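A minimal sketch of how CMRDataset is typically wrapped in data loaders, assuming the preprocessed M&Ms data sits under dataset/ as described in the README; the batch size and worker count are illustrative, since train_deep.py is not part of this listing.

```python
# Illustrative usage of CMRDataset; dataset_dir must point to the preprocessed M&Ms data.
from torch.utils.data import DataLoader
from dataset_domain import CMRDataset

train_set = CMRDataset('dataset/', mode='train', domain='AB', crop_size=256)
train_loader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)

test_set = CMRDataset('dataset/', mode='test', domain='ABCD', crop_size=256)
test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

img, lab = next(iter(train_loader))              # img: B x 1 x 256 x 256 augmented slices
vol, vol_lab, spacing = next(iter(test_loader))  # test mode yields whole volumes plus voxel spacing
```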
  {
    "path": "losses.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\n\nimport pdb \n\nclass DiceLoss(nn.Module):\n\n    def __init__(self, alpha=0.5, beta=0.5, size_average=True, reduce=True):\n        super(DiceLoss, self).__init__()\n        self.alpha = alpha\n        self.beta = beta\n\n        self.size_average = size_average\n        self.reduce = reduce\n\n    def forward(self, preds, targets):\n        N = preds.size(0)\n        C = preds.size(1)\n        \n\n        P = F.softmax(preds, dim=1)\n        smooth = torch.zeros(C, dtype=torch.float32).fill_(0.00001)\n\n        class_mask = torch.zeros(preds.shape).to(preds.device)\n        class_mask.scatter_(1, targets, 1.) \n\n        ones = torch.ones(preds.shape).to(preds.device)\n        P_ = ones - P \n        class_mask_ = ones - class_mask\n\n        TP = P * class_mask\n        FP = P * class_mask_\n        FN = P_ * class_mask\n\n        smooth = smooth.to(preds.device)\n        self.alpha = FP.transpose(0, 1).reshape(C, -1).sum(dim=(1)) / ((FP.transpose(0, 1).reshape(C, -1).sum(dim=(1)) + FN.transpose(0, 1).reshape(C, -1).sum(dim=(1))) + smooth)\n    \n        self.alpha = torch.clamp(self.alpha, min=0.2, max=0.8) \n        #print('alpha:', self.alpha)\n        self.beta = 1 - self.alpha\n        num = torch.sum(TP.transpose(0, 1).reshape(C, -1), dim=(1)).float()\n        den = num + self.alpha * torch.sum(FP.transpose(0, 1).reshape(C, -1), dim=(1)).float() + self.beta * torch.sum(FN.transpose(0, 1).reshape(C, -1), dim=(1)).float()\n\n        dice = num / (den + smooth)\n\n        if not self.reduce:\n            loss = torch.ones(C).to(dice.device) - dice\n            return loss\n\n        loss = 1 - dice\n        loss = loss.sum()\n\n        if self.size_average:\n            loss /= C\n\n        return loss\n\nclass FocalLoss(nn.Module):\n    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):\n        super(FocalLoss, self).__init__()\n\n        if alpha is None:\n            self.alpha = torch.ones(class_num)\n        else:\n            self.alpha = alpha\n\n        self.gamma = gamma\n        self.size_average = size_average\n\n    def forward(self, preds, targets):\n        N = preds.size(0)\n        C = preds.size(1)\n\n        targets = targets.unsqueeze(1)\n        P = F.softmax(preds, dim=1)\n        log_P = F.log_softmax(preds, dim=1)\n\n        class_mask = torch.zeros(preds.shape).to(preds.device)\n        class_mask.scatter_(1, targets, 1.)\n        \n        if targets.size(1) == 1:\n            # squeeze the chaneel for target\n            targets = targets.squeeze(1)\n        alpha = self.alpha[targets.data].to(preds.device)\n\n        probs = (P * class_mask).sum(1)\n        log_probs = (log_P * class_mask).sum(1)\n        \n        batch_loss = -alpha * (1-probs).pow(self.gamma)*log_probs\n\n        if self.size_average:\n            loss = batch_loss.mean()\n        else:\n            loss = batch_loss.sum()\n\n        return loss\n\nif __name__ == '__main__':\n    \n    DL = DiceLoss()\n    FL = FocalLoss(10)\n    \n    pred = torch.randn(2, 10, 128, 128)\n    target = torch.zeros((2, 1, 128, 128)).long()\n\n    dl_loss = DL(pred, target)\n    fl_loss = FL(pred, target)\n\n    print('2D:', dl_loss.item(), fl_loss.item())\n\n    pred = torch.randn(2, 10, 64, 128, 128)\n    target = torch.zeros(2, 1, 64, 128, 128).long()\n\n    dl_loss = DL(pred, target)\n    fl_loss = FL(pred, target)\n\n    print('3D:', dl_loss.item(), fl_loss.item())\n\n    \n"
  },
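Note that the two losses expect slightly different target layouts: DiceLoss scatters the targets directly, so it needs an explicit channel dimension (B x 1 x H x W), while FocalLoss unsqueezes the targets internally, so it takes them without one (B x H x W). A minimal sketch with illustrative shapes:

```python
# Illustrative shapes only; 4 classes and 128x128 slices are arbitrary choices.
import torch
from losses import DiceLoss, FocalLoss

pred = torch.randn(2, 4, 128, 128)              # B x C x H x W logits
target = torch.randint(0, 4, (2, 128, 128))     # B x H x W integer labels

dice = DiceLoss()(pred, target.unsqueeze(1))    # DiceLoss scatters B x 1 x H x W targets
focal = FocalLoss(class_num=4)(pred, target)    # FocalLoss adds the channel dimension itself
print(dice.item(), focal.item())
```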
  {
    "path": "model/__init__.py",
    "content": ""
  },
  {
    "path": "model/conv_trans_utils.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\nimport pdb\n\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\ndef conv1x1(in_planes, out_planes, stride=1):\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)\n\nclass depthwise_separable_conv(nn.Module):\n    def __init__(self, in_ch, out_ch, stride=1, kernel_size=3, padding=1, bias=False):\n        super().__init__()\n        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size=kernel_size, padding=padding, groups=in_ch, bias=bias, stride=stride)\n        self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=bias)\n\n    def forward(self, x): \n        out = self.depthwise(x)\n        out = self.pointwise(out)\n\n        return out \n\nclass Mlp(nn.Module):\n    def __init__(self, in_ch, hid_ch=None, out_ch=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_ch = out_ch or in_ch\n        hid_ch = hid_ch or in_ch\n\n        self.fc1 = nn.Conv2d(in_ch, hid_ch, kernel_size=1)\n        self.act = act_layer()\n        self.fc2 = nn.Conv2d(hid_ch, out_ch, kernel_size=1)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n\n        return x\n\nclass BasicBlock(nn.Module):\n\n    def __init__(self, inplanes, planes, stride=1):\n        super().__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(inplanes)\n        self.relu = nn.ReLU(inplace=True)\n\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.BatchNorm2d(planes)\n\n        self.shortcut = nn.Sequential()\n        if stride != 1 or inplanes != planes:\n            self.shortcut = nn.Sequential(\n                    nn.BatchNorm2d(inplanes),\n                    self.relu, \n                    nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)\n                    )\n\n    def forward(self, x): \n        residue = x\n\n        out = self.bn1(x)\n        out = self.relu(out)\n        out = self.conv1(out)\n\n        out = self.bn2(out)\n        out = self.relu(out)\n        out = self.conv2(out)\n\n        out += self.shortcut(residue)\n\n        return out\n\nclass BasicTransBlock(nn.Module):\n\n    def __init__(self, in_ch, heads, dim_head, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):\n        super().__init__()\n        self.bn1 = nn.BatchNorm2d(in_ch)\n\n        self.attn = LinearAttention(in_ch, heads=heads, dim_head=in_ch//heads, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n        \n        self.bn2 = nn.BatchNorm2d(in_ch)\n        self.relu = nn.ReLU(inplace=True)\n        self.mlp = nn.Conv2d(in_ch, in_ch, kernel_size=1, bias=False)\n        # conv1x1 has not difference with mlp in performance\n\n    def forward(self, x):\n\n        out = self.bn1(x)\n        out, q_k_attn = self.attn(out)\n        \n        out = out + x\n        residue = out\n\n        out = self.bn2(out)\n        out = self.relu(out)\n        out = self.mlp(out)\n\n        out += residue\n\n        return out\n\nclass BasicTransDecoderBlock(nn.Module):\n\n    def __init__(self, in_ch, out_ch, heads, dim_head, attn_drop=0., proj_drop=0., 
reduce_size=16, projection='interp', rel_pos=True):\n        super().__init__()\n\n        self.bn_l = nn.BatchNorm2d(in_ch)\n        self.bn_h = nn.BatchNorm2d(out_ch)\n\n        self.conv_ch = nn.Conv2d(in_ch, out_ch, kernel_size=1)\n        self.attn = LinearAttentionDecoder(in_ch, out_ch, heads=heads, dim_head=out_ch//heads, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n        \n        self.bn2 = nn.BatchNorm2d(out_ch)\n        self.relu = nn.ReLU(inplace=True)\n        self.mlp = nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False)\n\n    def forward(self, x1, x2):\n\n        residue = F.interpolate(self.conv_ch(x1), size=x2.shape[-2:], mode='bilinear', align_corners=True)\n        #x1: low-res, x2: high-res\n        x1 = self.bn_l(x1)\n        x2 = self.bn_h(x2)\n\n        out, q_k_attn = self.attn(x2, x1)\n        \n        out = out + residue\n        residue = out\n\n        out = self.bn2(out)\n        out = self.relu(out)\n        out = self.mlp(out)\n\n        out += residue\n\n        return out\n\n\n\n\n########################################################################\n# Transformer components\n\nclass LinearAttention(nn.Module):\n    \n    def __init__(self, dim, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):\n        super().__init__()\n\n        self.inner_dim = dim_head * heads\n        self.heads = heads\n        self.scale = dim_head ** (-0.5)\n        self.dim_head = dim_head\n        self.reduce_size = reduce_size\n        self.projection = projection\n        self.rel_pos = rel_pos\n        \n        # depthwise conv is slightly better than conv1x1\n        #self.to_qkv = nn.Conv2d(dim, self.inner_dim*3, kernel_size=1, stride=1, padding=0, bias=True)\n        #self.to_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, stride=1, padding=0, bias=True)\n       \n        self.to_qkv = depthwise_separable_conv(dim, self.inner_dim*3)\n        self.to_out = depthwise_separable_conv(self.inner_dim, dim)\n\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        if self.rel_pos:\n            # 2D input-independent relative position encoding is a little bit better than\n            # 1D input-denpendent counterpart\n            self.relative_position_encoding = RelativePositionBias(heads, reduce_size, reduce_size)\n            #self.relative_position_encoding = RelativePositionEmbedding(dim_head, reduce_size)\n\n    def forward(self, x):\n\n        B, C, H, W = x.shape\n\n        #B, inner_dim, H, W\n        qkv = self.to_qkv(x)\n        q, k, v = qkv.chunk(3, dim=1)\n\n        if self.projection == 'interp' and H != self.reduce_size:\n            k, v = map(lambda t: F.interpolate(t, size=self.reduce_size, mode='bilinear', align_corners=True), (k, v))\n\n        elif self.projection == 'maxpool' and H != self.reduce_size:\n            k, v = map(lambda t: F.adaptive_max_pool2d(t, output_size=self.reduce_size), (k, v))\n        \n        q = rearrange(q, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads, h=H, w=W)\n        k, v = map(lambda t: rearrange(t, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads, h=self.reduce_size, w=self.reduce_size), (k, v))\n\n        q_k_attn = torch.einsum('bhid,bhjd->bhij', q, k)\n        \n        if self.rel_pos:\n            relative_position_bias = 
self.relative_position_encoding(H, W)\n            q_k_attn += relative_position_bias\n            #rel_attn_h, rel_attn_w = self.relative_position_encoding(q, self.heads, H, W, self.dim_head)\n            #q_k_attn = q_k_attn + rel_attn_h + rel_attn_w\n\n        q_k_attn *= self.scale\n        q_k_attn = F.softmax(q_k_attn, dim=-1)\n        q_k_attn = self.attn_drop(q_k_attn)\n\n        out = torch.einsum('bhij,bhjd->bhid', q_k_attn, v)\n        out = rearrange(out, 'b heads (h w) dim_head -> b (dim_head heads) h w', h=H, w=W, dim_head=self.dim_head, heads=self.heads)\n\n        out = self.to_out(out)\n        out = self.proj_drop(out)\n\n        return out, q_k_attn\n\nclass LinearAttentionDecoder(nn.Module):\n    \n    def __init__(self, in_dim, out_dim, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):\n        super().__init__()\n\n        self.inner_dim = dim_head * heads\n        self.heads = heads\n        self.scale = dim_head ** (-0.5)\n        self.dim_head = dim_head\n        self.reduce_size = reduce_size\n        self.projection = projection\n        self.rel_pos = rel_pos\n        \n        # depthwise conv is slightly better than conv1x1\n        #self.to_kv = nn.Conv2d(dim, self.inner_dim*2, kernel_size=1, stride=1, padding=0, bias=True)\n        #self.to_q = nn.Conv2d(dim, self.inner_dim, kernel_size=1, stride=1, padding=0, bias=True)\n        #self.to_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, stride=1, padding=0, bias=True)\n       \n        self.to_kv = depthwise_separable_conv(in_dim, self.inner_dim*2)\n        self.to_q = depthwise_separable_conv(out_dim, self.inner_dim)\n        self.to_out = depthwise_separable_conv(self.inner_dim, out_dim)\n\n\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        if self.rel_pos:\n            self.relative_position_encoding = RelativePositionBias(heads, reduce_size, reduce_size)\n            #self.relative_position_encoding = RelativePositionEmbedding(dim_head, reduce_size)\n\n    def forward(self, q, x):\n\n        B, C, H, W = x.shape # low-res feature shape\n        BH, CH, HH, WH = q.shape # high-res feature shape\n\n        k, v = self.to_kv(x).chunk(2, dim=1) #B, inner_dim, H, W\n        q = self.to_q(q) #BH, inner_dim, HH, WH\n\n        if self.projection == 'interp' and H != self.reduce_size:\n            k, v = map(lambda t: F.interpolate(t, size=self.reduce_size, mode='bilinear', align_corners=True), (k, v))\n\n        elif self.projection == 'maxpool' and H != self.reduce_size:\n            k, v = map(lambda t: F.adaptive_max_pool2d(t, output_size=self.reduce_size), (k, v))\n        \n        q = rearrange(q, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads, h=HH, w=WH)\n        k, v = map(lambda t: rearrange(t, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads, h=self.reduce_size, w=self.reduce_size), (k, v))\n\n        q_k_attn = torch.einsum('bhid,bhjd->bhij', q, k)\n        \n        if self.rel_pos:\n            relative_position_bias = self.relative_position_encoding(HH, WH)\n            q_k_attn += relative_position_bias\n            #rel_attn_h, rel_attn_w = self.relative_position_encoding(q, self.heads, HH, WH, self.dim_head)\n            #q_k_attn = q_k_attn + rel_attn_h + rel_attn_w\n       \n        q_k_attn *= self.scale\n        q_k_attn = F.softmax(q_k_attn, dim=-1)\n        q_k_attn = 
self.attn_drop(q_k_attn)\n\n        out = torch.einsum('bhij,bhjd->bhid', q_k_attn, v)\n        out = rearrange(out, 'b heads (h w) dim_head -> b (dim_head heads) h w', h=HH, w=WH, dim_head=self.dim_head, heads=self.heads)\n\n        out = self.to_out(out)\n        out = self.proj_drop(out)\n\n        return out, q_k_attn\n\nclass RelativePositionEmbedding(nn.Module):\n    # input-dependent relative position\n    def __init__(self, dim, shape):\n        super().__init__()\n\n        self.dim = dim\n        self.shape = shape\n\n        self.key_rel_w = nn.Parameter(torch.randn((2*self.shape-1, dim))*0.02)\n        self.key_rel_h = nn.Parameter(torch.randn((2*self.shape-1, dim))*0.02)\n\n        coords = torch.arange(self.shape)\n        relative_coords = coords[None, :] - coords[:, None] # h, h\n        relative_coords += self.shape - 1 # shift to start from 0\n        \n        self.register_buffer('relative_position_index', relative_coords)\n\n\n\n    def forward(self, q, Nh, H, W, dim_head):\n        # q: B, Nh, HW, dim\n        B, _, _, dim = q.shape\n\n        # q: B, Nh, H, W, dim_head\n        q = rearrange(q, 'b heads (h w) dim_head -> b heads h w dim_head', b=B, dim_head=dim_head, heads=Nh, h=H, w=W)\n\n        rel_logits_w = self.relative_logits_1d(q, self.key_rel_w, 'w')\n\n        rel_logits_h = self.relative_logits_1d(q.permute(0, 1, 3, 2, 4), self.key_rel_h, 'h')\n\n        return rel_logits_w, rel_logits_h\n\n    def relative_logits_1d(self, q, rel_k, case):\n        \n        B, Nh, H, W, dim = q.shape\n\n        rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k) # B, Nh, H, W, 2*shape-1\n\n        if W != self.shape:\n            # self_relative_position_index origin shape: w, w\n            # after repeat: W, w\n            relative_index= torch.repeat_interleave(self.relative_position_index, W//self.shape, dim=0) # W, shape\n        relative_index = relative_index.view(1, 1, 1, W, self.shape)\n        relative_index = relative_index.repeat(B, Nh, H, 1, 1)\n\n        rel_logits = torch.gather(rel_logits, 4, relative_index) # B, Nh, H, W, shape\n        rel_logits = rel_logits.unsqueeze(3)\n        rel_logits = rel_logits.repeat(1, 1, 1, self.shape, 1, 1)\n\n        if case == 'w':\n            rel_logits = rearrange(rel_logits, 'b heads H h W w -> b heads (H W) (h w)')\n\n        elif case == 'h':\n            rel_logits = rearrange(rel_logits, 'b heads W w H h -> b heads (H W) (h w)')\n\n        return rel_logits\n\n\n\n\nclass RelativePositionBias(nn.Module):\n    # input-independent relative position attention\n    # As the number of parameters is smaller, so use 2D here\n    # Borrowed some code from SwinTransformer: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py\n    def __init__(self, num_heads, h, w):\n        super().__init__()\n        self.num_heads = num_heads\n        self.h = h\n        self.w = w\n\n        self.relative_position_bias_table = nn.Parameter(\n                torch.randn((2*h-1) * (2*w-1), num_heads)*0.02)\n\n        coords_h = torch.arange(self.h)\n        coords_w = torch.arange(self.w)\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, h, w\n        coords_flatten = torch.flatten(coords, 1) # 2, hw\n\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()\n        relative_coords[:, :, 0] += self.h - 1\n        relative_coords[:, :, 1] += self.w - 1\n        relative_coords[:, :, 0] 
*= 2 * self.h - 1\n        relative_position_index = relative_coords.sum(-1) # hw, hw\n    \n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n    def forward(self, H, W):\n        \n        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.h, self.w, self.h*self.w, -1) #h, w, hw, nH\n        relative_position_bias_expand_h = torch.repeat_interleave(relative_position_bias, H//self.h, dim=0)\n        relative_position_bias_expanded = torch.repeat_interleave(relative_position_bias_expand_h, W//self.w, dim=1) #HW, hw, nH\n        \n        relative_position_bias_expanded = relative_position_bias_expanded.view(H*W, self.h*self.w, self.num_heads).permute(2, 0, 1).contiguous().unsqueeze(0)\n\n        return relative_position_bias_expanded\n\n\n###########################################################################\n# Unet Transformer building block\n\nclass down_block_trans(nn.Module):\n    def __init__(self, in_ch, out_ch, num_block, bottleneck=False, maxpool=True, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):\n\n        super().__init__()\n\n        block_list = []\n\n        if bottleneck:\n            block = BottleneckBlock\n        else:\n            block = BasicBlock\n\n        attn_block = BasicTransBlock\n\n        if maxpool:\n            block_list.append(nn.MaxPool2d(2))\n            block_list.append(block(in_ch, out_ch, stride=1))\n        else:\n            block_list.append(block(in_ch, out_ch, stride=2))\n        \n        assert num_block > 0\n        for i in range(num_block):\n            block_list.append(attn_block(out_ch, heads, dim_head, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos))\n        self.blocks = nn.Sequential(*block_list)\n\n        \n    def forward(self, x):\n        \n        out = self.blocks(x)\n\n\n        return out\n\nclass up_block_trans(nn.Module):\n    def __init__(self, in_ch, out_ch, num_block, bottleneck=False, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):\n        super().__init__()\n \n        self.attn_decoder = BasicTransDecoderBlock(in_ch, out_ch, heads=heads, dim_head=dim_head, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n\n        if bottleneck:\n            block = BottleneckBlock\n        else:\n            block = BasicBlock\n        attn_block = BasicTransBlock\n        \n        block_list = []\n\n        for i in range(num_block):\n            block_list.append(attn_block(out_ch, heads, dim_head, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos))\n\n        block_list.append(block(2*out_ch, out_ch, stride=1))\n\n        self.blocks = nn.Sequential(*block_list)\n\n    def forward(self, x1, x2):\n        # x1: low-res feature, x2: high-res feature\n        out = self.attn_decoder(x1, x2)\n        out = torch.cat([out, x2], dim=1)\n        out = self.blocks(out)\n\n        return out\n\nclass block_trans(nn.Module):\n    def __init__(self, in_ch, num_block, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):\n\n        super().__init__()\n\n        block_list = []\n\n        attn_block = BasicTransBlock\n\n        assert num_block > 0\n        for i in range(num_block):\n            
block_list.append(attn_block(in_ch, heads, dim_head, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos))\n        self.blocks = nn.Sequential(*block_list)\n\n        \n    def forward(self, x):\n        \n        out = self.blocks(x)\n\n\n        return out\n\n"
  },
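A small sketch of what the efficient attention computes: keys and values are pooled down to reduce_size x reduce_size (by bilinear interpolation or max pooling), so the attention map has shape (H*W) x reduce_size^2 and grows linearly with the input resolution instead of quadratically. The feature size and channel count below are illustrative.

```python
# Shape check for LinearAttention with the default reduce_size=8 setting from the README.
import torch
from model.conv_trans_utils import LinearAttention

attn = LinearAttention(dim=64, heads=4, dim_head=16, reduce_size=8)

x = torch.randn(1, 64, 32, 32)     # B x C x H x W feature map
out, q_k_attn = attn(x)
print(out.shape)                   # torch.Size([1, 64, 32, 32]), same as the input
print(q_k_attn.shape)              # torch.Size([1, 4, 1024, 64]): (H*W) x reduce_size^2, not (H*W)^2
```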
  {
    "path": "model/resnet_utnet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .unet_utils import up_block\nfrom .transunet import ResNetV2\nfrom .conv_trans_utils import block_trans\nimport pdb\n\n\n\n\nclass ResNet_UTNet(nn.Module):\n    def __init__(self, in_ch, num_class, reduce_size=8, block_list='234', num_blocks=[1,2,4], projection='interp', num_heads=[4,4,4], attn_drop=0., proj_drop=0., rel_pos=True, block_units=(3,4,9), width_factor=1):\n        \n        super().__init__()\n        self.resnet = ResNetV2(block_units, width_factor)\n\n\n        if '0' in block_list:\n            self.trans_0 = block_trans(64, num_blocks[-4], 64//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n        else:\n            self.trans_0 = nn.Identity()\n\n\n        if '1' in block_list:\n            self.trans_1 = block_trans(256, num_blocks[-3], 256//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n\n        else:\n            self.trans_1 = nn.Identity()\n\n        if '2' in block_list:\n            self.trans_2 =  block_trans(512, num_blocks[-2], 512//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n\n        else:\n            self.trans_2 = nn.Identity()\n\n        if '3' in block_list:\n            self.trans_3 = block_trans(1024, num_blocks[-1], 1024//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n\n        else:\n            self.trans_3 = nn.Identity()\n        \n        self.up1 = up_block(1024, 512, scale=(2,2), num_block=1)\n        self.up2 = up_block(512, 256, scale=(2,2), num_block=1)\n        self.up3 = up_block(256, 64, scale=(2,2), num_block=1)\n        self.up4 = nn.UpsamplingBilinear2d(scale_factor=2)\n\n        self.output = nn.Conv2d(64, num_class, kernel_size=3, padding=1, bias=True)\n\n    def forward(self, x):\n        if x.shape[1] == 1: \n            x = x.repeat(1, 3, 1, 1)\n        x, features = self.resnet(x)\n\n        out3 = self.trans_3(x)\n        out2 = self.trans_2(features[0])\n        out1 = self.trans_1(features[1])\n        out0 = self.trans_0(features[2])\n\n        out = self.up1(out3, out2)\n        out = self.up2(out, out1)\n        out = self.up3(out, out0)\n        out = self.up4(out)\n\n        out = self.output(out)\n\n        return out\n            \n\n\n"
  },
  {
    "path": "model/swin_unet.py",
    "content": "# code is borrowed from the original repo and fit into our training framework\n# https://github.com/HuCaoFighting/Swin-Unet/tree/4375a8d6fa7d9c38184c5d3194db990a00a3e912\n\n\nfrom __future__ import absolute_import\n\nfrom __future__ import division\n\nfrom __future__ import print_function\n\n\nimport torch\n\nimport torch.nn as nn\n\nimport torch.utils.checkpoint as checkpoint\n\nfrom einops import rearrange\n\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\n\n\n\nimport copy\n\nimport logging\n\nimport math\n\n\nfrom os.path import join as pjoin\n\nimport numpy as np\n\n\n\nfrom torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm\n\nfrom torch.nn.modules.utils import _pair\n\nfrom scipy import ndimage\n\n\n\n\n\nclass Mlp(nn.Module):\n\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n\n        super().__init__()\n\n        out_features = out_features or in_features\n\n        hidden_features = hidden_features or in_features\n\n        self.fc1 = nn.Linear(in_features, hidden_features)\n\n        self.act = act_layer()\n\n        self.fc2 = nn.Linear(hidden_features, out_features)\n\n        self.drop = nn.Dropout(drop)\n\n\n\n    def forward(self, x):\n\n        x = self.fc1(x)\n\n        x = self.act(x)\n\n        x = self.drop(x)\n\n        x = self.fc2(x)\n\n        x = self.drop(x)\n\n        return x\n\n\n\n\n\ndef window_partition(x, window_size):\n\n    \"\"\"\n\n    Args:\n\n        x: (B, H, W, C)\n\n        window_size (int): window size\n\n\n\n    Returns:\n\n        windows: (num_windows*B, window_size, window_size, C)\n\n    \"\"\"\n\n    B, H, W, C = x.shape\n\n    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n\n    return windows\n\n\n\n\n\ndef window_reverse(windows, window_size, H, W):\n\n    \"\"\"\n\n    Args:\n\n        windows: (num_windows*B, window_size, window_size, C)\n\n        window_size (int): Window size\n\n        H (int): Height of image\n\n        W (int): Width of image\n\n\n\n    Returns:\n\n        x: (B, H, W, C)\n\n    \"\"\"\n\n    B = int(windows.shape[0] / (H * W / window_size / window_size))\n\n    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n\n    return x\n\n\n\n\n\nclass WindowAttention(nn.Module):\n\n    r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n\n    It supports both of shifted and non-shifted window.\n\n\n\n    Args:\n\n        dim (int): Number of input channels.\n\n        window_size (tuple[int]): The height and width of the window.\n\n        num_heads (int): Number of attention heads.\n\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n\n        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n\n        proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0\n\n    \"\"\"\n\n\n\n    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):\n\n\n\n        super().__init__()\n\n        self.dim = dim\n\n        self.window_size = window_size  # Wh, Ww\n\n        self.num_heads = num_heads\n\n        head_dim = dim // num_heads\n\n        self.scale = qk_scale or head_dim ** -0.5\n\n\n\n        # define a parameter table of relative position bias\n\n        self.relative_position_bias_table = nn.Parameter(\n\n            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH\n\n\n\n        # get pair-wise relative position index for each token inside the window\n\n        coords_h = torch.arange(self.window_size[0])\n\n        coords_w = torch.arange(self.window_size[1])\n\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n\n        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n\n        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww\n\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2\n\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n\n        self.attn_drop = nn.Dropout(attn_drop)\n\n        self.proj = nn.Linear(dim, dim)\n\n        self.proj_drop = nn.Dropout(proj_drop)\n\n\n\n        trunc_normal_(self.relative_position_bias_table, std=.02)\n\n        self.softmax = nn.Softmax(dim=-1)\n\n\n\n    def forward(self, x, mask=None):\n\n        \"\"\"\n\n        Args:\n\n            x: input features with shape of (num_windows*B, N, C)\n\n            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n\n        \"\"\"\n\n        B_, N, C = x.shape\n\n        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n\n\n        q = q * self.scale\n\n        attn = (q @ k.transpose(-2, -1))\n\n\n\n        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n\n            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH\n\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww\n\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n\n\n        if mask is not None:\n\n            nW = mask.shape[0]\n\n            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n\n            attn = attn.view(-1, self.num_heads, N, N)\n\n            attn = self.softmax(attn)\n\n        else:\n\n            attn = self.softmax(attn)\n\n\n\n        attn = self.attn_drop(attn)\n\n\n\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n\n        x = self.proj(x)\n\n        x = self.proj_drop(x)\n\n        return x\n\n\n\n    def extra_repr(self) -> str:\n\n        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'\n\n\n\n    def flops(self, N):\n\n        
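# rough per-window estimate: qkv projection, q @ k^T, attn @ v and the output projection; softmax and the relative position bias are not counted\n        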
# calculate flops for 1 window with token length of N\n\n        flops = 0\n\n        # qkv = self.qkv(x)\n\n        flops += N * self.dim * 3 * self.dim\n\n        # attn = (q @ k.transpose(-2, -1))\n\n        flops += self.num_heads * N * (self.dim // self.num_heads) * N\n\n        #  x = (attn @ v)\n\n        flops += self.num_heads * N * N * (self.dim // self.num_heads)\n\n        # x = self.proj(x)\n\n        flops += N * self.dim * self.dim\n\n        return flops\n\n\n\n\n\nclass SwinTransformerBlock(nn.Module):\n\n    r\"\"\" Swin Transformer Block.\n\n\n\n    Args:\n\n        dim (int): Number of input channels.\n\n        input_resolution (tuple[int]): Input resulotion.\n\n        num_heads (int): Number of attention heads.\n\n        window_size (int): Window size.\n\n        shift_size (int): Shift size for SW-MSA.\n\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n\n        drop (float, optional): Dropout rate. Default: 0.0\n\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n\n    \"\"\"\n\n\n\n    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n\n        super().__init__()\n\n        self.dim = dim\n\n        self.input_resolution = input_resolution\n\n        self.num_heads = num_heads\n\n        self.window_size = window_size\n\n        self.shift_size = shift_size\n\n        self.mlp_ratio = mlp_ratio\n\n        if min(self.input_resolution) <= self.window_size:\n\n            # if window size is larger than input resolution, we don't partition windows\n\n            self.shift_size = 0\n\n            self.window_size = min(self.input_resolution)\n\n        assert 0 <= self.shift_size < self.window_size, \"shift_size must in 0-window_size\"\n\n\n\n        self.norm1 = norm_layer(dim)\n\n        self.attn = WindowAttention(\n\n            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n\n            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n\n\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n\n        self.norm2 = norm_layer(dim)\n\n        mlp_hidden_dim = int(dim * mlp_ratio)\n\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n\n\n        if self.shift_size > 0:\n\n            # calculate attention mask for SW-MSA\n\n            H, W = self.input_resolution\n\n            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n\n            h_slices = (slice(0, -self.window_size),\n\n                        slice(-self.window_size, -self.shift_size),\n\n                        slice(-self.shift_size, None))\n\n            w_slices = (slice(0, -self.window_size),\n\n                        slice(-self.window_size, -self.shift_size),\n\n                        slice(-self.shift_size, None))\n\n            cnt = 0\n\n            for h in h_slices:\n\n                for w in w_slices:\n\n                    img_mask[:, h, w, :] = cnt\n\n                    cnt += 1\n\n\n\n            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1\n\n            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n\n            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n\n        else:\n\n            attn_mask = None\n\n\n\n        self.register_buffer(\"attn_mask\", attn_mask)\n\n\n\n    def forward(self, x):\n\n        H, W = self.input_resolution\n\n        B, L, C = x.shape\n\n        assert L == H * W, \"input feature has wrong size\"\n\n\n\n        shortcut = x\n\n        x = self.norm1(x)\n\n        x = x.view(B, H, W, C)\n\n\n\n        # cyclic shift\n\n        if self.shift_size > 0:\n\n            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n\n        else:\n\n            shifted_x = x\n\n\n\n        # partition windows\n\n        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n\n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C\n\n\n\n        # W-MSA/SW-MSA\n\n        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C\n\n\n\n        # merge windows\n\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n\n        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n\n\n\n        # reverse cyclic shift\n\n        if self.shift_size > 0:\n\n            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n\n        else:\n\n            x = shifted_x\n\n        x = x.view(B, H * W, C)\n\n\n\n        # FFN\n\n        x = shortcut + self.drop_path(x)\n\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n\n\n        return x\n\n\n\n    def extra_repr(self) -> str:\n\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n            f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n\n\n\n    def flops(self):\n\n        flops = 0\n\n        H, W = self.input_resolution\n\n        # norm1\n\n        flops += self.dim * H * W\n\n        # W-MSA/SW-MSA\n\n        nW = H * W / self.window_size / self.window_size\n\n        flops += nW * self.attn.flops(self.window_size * self.window_size)\n\n        # mlp\n\n        flops += 2 * H * 
W * self.dim * self.dim * self.mlp_ratio\n\n        # norm2\n\n        flops += self.dim * H * W\n\n        return flops\n\n\n\n\n\nclass PatchMerging(nn.Module):\n\n    r\"\"\" Patch Merging Layer.\n\n\n\n    Args:\n\n        input_resolution (tuple[int]): Resolution of input feature.\n\n        dim (int): Number of input channels.\n\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n\n    \"\"\"\n\n\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n\n        super().__init__()\n\n        self.input_resolution = input_resolution\n\n        self.dim = dim\n\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n\n        self.norm = norm_layer(4 * dim)\n\n\n\n    def forward(self, x):\n\n        \"\"\"\n\n        x: B, H*W, C\n\n        \"\"\"\n\n        H, W = self.input_resolution\n\n        B, L, C = x.shape\n\n        assert L == H * W, \"input feature has wrong size\"\n\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\"\n\n\n\n        x = x.view(B, H, W, C)\n\n\n\n        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C\n\n        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C\n\n        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C\n\n        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C\n\n        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C\n\n        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C\n\n\n\n        x = self.norm(x)\n\n        x = self.reduction(x)\n\n\n\n        return x\n\n\n\n    def extra_repr(self) -> str:\n\n        return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n\n\n\n    def flops(self):\n\n        H, W = self.input_resolution\n\n        flops = H * W * self.dim\n\n        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim\n\n        return flops\n\n\n\nclass PatchExpand(nn.Module):\n\n    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):\n\n        super().__init__()\n\n        self.input_resolution = input_resolution\n\n        self.dim = dim\n\n        self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity()\n\n        self.norm = norm_layer(dim // dim_scale)\n\n\n\n    def forward(self, x):\n\n        \"\"\"\n\n        x: B, H*W, C\n\n        \"\"\"\n\n        H, W = self.input_resolution\n\n        x = self.expand(x)\n\n        B, L, C = x.shape\n\n        assert L == H * W, \"input feature has wrong size\"\n\n\n\n        x = x.view(B, H, W, C)\n\n        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)\n\n        x = x.view(B,-1,C//4)\n\n        x= self.norm(x)\n\n\n\n        return x\n\n\n\nclass FinalPatchExpand_X4(nn.Module):\n\n    def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):\n\n        super().__init__()\n\n        self.input_resolution = input_resolution\n\n        self.dim = dim\n\n        self.dim_scale = dim_scale\n\n        self.expand = nn.Linear(dim, 16*dim, bias=False)\n\n        self.output_dim = dim \n\n        self.norm = norm_layer(self.output_dim)\n\n\n\n    def forward(self, x):\n\n        \"\"\"\n\n        x: B, H*W, C\n\n        \"\"\"\n\n        H, W = self.input_resolution\n\n        x = self.expand(x)\n\n        B, L, C = x.shape\n\n        assert L == H * W, \"input feature has wrong size\"\n\n\n\n        x = x.view(B, H, W, C)\n\n        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))\n\n        x = x.view(B,-1,self.output_dim)\n\n     
   x= self.norm(x)\n\n\n\n        return x\n\n\n\nclass BasicLayer(nn.Module):\n\n    \"\"\" A basic Swin Transformer layer for one stage.\n\n\n\n    Args:\n\n        dim (int): Number of input channels.\n\n        input_resolution (tuple[int]): Input resolution.\n\n        depth (int): Number of blocks.\n\n        num_heads (int): Number of attention heads.\n\n        window_size (int): Local window size.\n\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n\n        drop (float, optional): Dropout rate. Default: 0.0\n\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n\n    \"\"\"\n\n\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):\n\n\n\n        super().__init__()\n\n        self.dim = dim\n\n        self.input_resolution = input_resolution\n\n        self.depth = depth\n\n        self.use_checkpoint = use_checkpoint\n\n\n\n        # build blocks\n\n        self.blocks = nn.ModuleList([\n\n            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n\n                                 num_heads=num_heads, window_size=window_size,\n\n                                 shift_size=0 if (i % 2 == 0) else window_size // 2,\n\n                                 mlp_ratio=mlp_ratio,\n\n                                 qkv_bias=qkv_bias, qk_scale=qk_scale,\n\n                                 drop=drop, attn_drop=attn_drop,\n\n                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n\n                                 norm_layer=norm_layer)\n\n            for i in range(depth)])\n\n\n\n        # patch merging layer\n\n        if downsample is not None:\n\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n\n        else:\n\n            self.downsample = None\n\n\n\n    def forward(self, x):\n\n        for blk in self.blocks:\n\n            if self.use_checkpoint:\n\n                x = checkpoint.checkpoint(blk, x)\n\n            else:\n\n                x = blk(x)\n\n        if self.downsample is not None:\n\n            x = self.downsample(x)\n\n        return x\n\n\n\n    def extra_repr(self) -> str:\n\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n\n\n    def flops(self):\n\n        flops = 0\n\n        for blk in self.blocks:\n\n            flops += blk.flops()\n\n        if self.downsample is not None:\n\n            flops += self.downsample.flops()\n\n        return flops\n\n\n\nclass BasicLayer_up(nn.Module):\n\n    \"\"\" A basic Swin Transformer layer for one stage.\n\n\n\n    Args:\n\n        dim (int): Number of input channels.\n\n        input_resolution (tuple[int]): Input resolution.\n\n        
depth (int): Number of blocks.\n\n        num_heads (int): Number of attention heads.\n\n        window_size (int): Local window size.\n\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n\n        drop (float, optional): Dropout rate. Default: 0.0\n\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n\n        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n\n    \"\"\"\n\n\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n\n                 drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):\n\n\n\n        super().__init__()\n\n        self.dim = dim\n\n        self.input_resolution = input_resolution\n\n        self.depth = depth\n\n        self.use_checkpoint = use_checkpoint\n\n\n\n        # build blocks\n\n        self.blocks = nn.ModuleList([\n\n            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n\n                                 num_heads=num_heads, window_size=window_size,\n\n                                 shift_size=0 if (i % 2 == 0) else window_size // 2,\n\n                                 mlp_ratio=mlp_ratio,\n\n                                 qkv_bias=qkv_bias, qk_scale=qk_scale,\n\n                                 drop=drop, attn_drop=attn_drop,\n\n                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n\n                                 norm_layer=norm_layer)\n\n            for i in range(depth)])\n\n\n\n        # patch merging layer\n\n        if upsample is not None:\n\n            self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)\n\n        else:\n\n            self.upsample = None\n\n\n\n    def forward(self, x):\n\n        for blk in self.blocks:\n\n            if self.use_checkpoint:\n\n                x = checkpoint.checkpoint(blk, x)\n\n            else:\n\n                x = blk(x)\n\n        if self.upsample is not None:\n\n            x = self.upsample(x)\n\n        return x\n\n\n\nclass PatchEmbed(nn.Module):\n\n    r\"\"\" Image to Patch Embedding\n\n\n\n    Args:\n\n        img_size (int): Image size.  Default: 224.\n\n        patch_size (int): Patch token size. Default: 4.\n\n        in_chans (int): Number of input image channels. Default: 3.\n\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n\n        norm_layer (nn.Module, optional): Normalization layer. 
Default: None\n\n    \"\"\"\n\n\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n\n        super().__init__()\n\n        img_size = to_2tuple(img_size)\n\n        patch_size = to_2tuple(patch_size)\n\n        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]\n\n        self.img_size = img_size\n\n        self.patch_size = patch_size\n\n        self.patches_resolution = patches_resolution\n\n        self.num_patches = patches_resolution[0] * patches_resolution[1]\n\n\n\n        self.in_chans = in_chans\n\n        self.embed_dim = embed_dim\n\n\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n        if norm_layer is not None:\n\n            self.norm = norm_layer(embed_dim)\n\n        else:\n\n            self.norm = None\n\n\n\n    def forward(self, x):\n\n        B, C, H, W = x.shape\n\n        # FIXME look at relaxing size constraints\n\n        assert H == self.img_size[0] and W == self.img_size[1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n\n        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C\n\n        if self.norm is not None:\n\n            x = self.norm(x)\n\n        return x\n\n\n\n    def flops(self):\n\n        Ho, Wo = self.patches_resolution\n\n        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n\n        if self.norm is not None:\n\n            flops += Ho * Wo * self.embed_dim\n\n        return flops\n\n\n\n\n\nclass SwinTransformerSys(nn.Module):\n\n    r\"\"\" Swin Transformer\n\n        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -\n\n          https://arxiv.org/pdf/2103.14030\n\n\n\n    Args:\n\n        img_size (int | tuple(int)): Input image size. Default 224\n\n        patch_size (int | tuple(int)): Patch size. Default: 4\n\n        in_chans (int): Number of input image channels. Default: 3\n\n        num_classes (int): Number of classes for classification head. Default: 1000\n\n        embed_dim (int): Patch embedding dimension. Default: 96\n\n        depths (tuple(int)): Depth of each Swin Transformer layer.\n\n        num_heads (tuple(int)): Number of attention heads in different layers.\n\n        window_size (int): Window size. Default: 7\n\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None\n\n        drop_rate (float): Dropout rate. Default: 0\n\n        attn_drop_rate (float): Attention dropout rate. Default: 0\n\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False\n\n    \"\"\"\n\n\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,\n\n                 embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],\n\n                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n\n                 use_checkpoint=False, final_upsample=\"expand_first\", **kwargs):\n\n        super().__init__()\n\n\n\n        print(\"SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}\".format(depths,\n\n        depths_decoder,drop_path_rate,num_classes))\n\n\n\n        self.num_classes = num_classes\n\n        self.num_layers = len(depths)\n\n        self.embed_dim = embed_dim\n\n        self.ape = ape\n\n        self.patch_norm = patch_norm\n\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n\n        self.num_features_up = int(embed_dim * 2)\n\n        self.mlp_ratio = mlp_ratio\n\n        self.final_upsample = final_upsample\n\n\n\n        # split image into non-overlapping patches\n\n        self.patch_embed = PatchEmbed(\n\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n\n            norm_layer=norm_layer if self.patch_norm else None)\n\n        num_patches = self.patch_embed.num_patches\n\n        patches_resolution = self.patch_embed.patches_resolution\n\n        self.patches_resolution = patches_resolution\n\n\n\n        # absolute position embedding\n\n        if self.ape:\n\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n\n\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n\n\n        # stochastic depth\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n\n\n        # build encoder and bottleneck layers\n\n        self.layers = nn.ModuleList()\n\n        for i_layer in range(self.num_layers):\n\n            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),\n\n                               input_resolution=(patches_resolution[0] // (2 ** i_layer),\n\n                                                 patches_resolution[1] // (2 ** i_layer)),\n\n                               depth=depths[i_layer],\n\n                               num_heads=num_heads[i_layer],\n\n                               window_size=window_size,\n\n                               mlp_ratio=self.mlp_ratio,\n\n                               qkv_bias=qkv_bias, qk_scale=qk_scale,\n\n                               drop=drop_rate, attn_drop=attn_drop_rate,\n\n                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],\n\n                               norm_layer=norm_layer,\n\n                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n\n                               use_checkpoint=use_checkpoint)\n\n            self.layers.append(layer)\n\n        \n\n        # build decoder layers\n\n        self.layers_up = nn.ModuleList()\n\n        self.concat_back_dim = nn.ModuleList()\n\n        for i_layer in range(self.num_layers):\n\n            concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)),\n\n            int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else 
nn.Identity()\n\n            if i_layer ==0 :\n\n                layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),\n\n                patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer)\n\n            else:\n\n                layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)),\n\n                                input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),\n\n                                                    patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))),\n\n                                depth=depths[(self.num_layers-1-i_layer)],\n\n                                num_heads=num_heads[(self.num_layers-1-i_layer)],\n\n                                window_size=window_size,\n\n                                mlp_ratio=self.mlp_ratio,\n\n                                qkv_bias=qkv_bias, qk_scale=qk_scale,\n\n                                drop=drop_rate, attn_drop=attn_drop_rate,\n\n                                drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])],\n\n                                norm_layer=norm_layer,\n\n                                upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,\n\n                                use_checkpoint=use_checkpoint)\n\n            self.layers_up.append(layer_up)\n\n            self.concat_back_dim.append(concat_linear)\n\n\n\n        self.norm = norm_layer(self.num_features)\n\n        self.norm_up= norm_layer(self.embed_dim)\n\n\n\n        if self.final_upsample == \"expand_first\":\n\n            print(\"---final upsample expand_first---\")\n\n            self.up = FinalPatchExpand_X4(input_resolution=(img_size//patch_size,img_size//patch_size),dim_scale=4,dim=embed_dim)\n\n            self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=1,bias=False)\n\n\n\n        self.apply(self._init_weights)\n\n\n\n    def _init_weights(self, m):\n\n        if isinstance(m, nn.Linear):\n\n            trunc_normal_(m.weight, std=.02)\n\n            if isinstance(m, nn.Linear) and m.bias is not None:\n\n                nn.init.constant_(m.bias, 0)\n\n        elif isinstance(m, nn.LayerNorm):\n\n            nn.init.constant_(m.bias, 0)\n\n            nn.init.constant_(m.weight, 1.0)\n\n\n\n    @torch.jit.ignore\n\n    def no_weight_decay(self):\n\n        return {'absolute_pos_embed'}\n\n\n\n    @torch.jit.ignore\n\n    def no_weight_decay_keywords(self):\n\n        return {'relative_position_bias_table'}\n\n\n\n    #Encoder and Bottleneck\n\n    def forward_features(self, x):\n\n        x = self.patch_embed(x)\n\n        if self.ape:\n\n            x = x + self.absolute_pos_embed\n\n        x = self.pos_drop(x)\n\n        x_downsample = []\n\n\n\n        for layer in self.layers:\n\n            x_downsample.append(x)\n\n            x = layer(x)\n\n\n\n        x = self.norm(x)  # B L C\n\n  \n\n        return x, x_downsample\n\n\n\n    #Dencoder and Skip connection\n\n    def forward_up_features(self, x, x_downsample):\n\n        for inx, layer_up in enumerate(self.layers_up):\n\n            if inx == 0:\n\n                x = layer_up(x)\n\n            else:\n\n                x = torch.cat([x,x_downsample[3-inx]],-1)\n\n                x = self.concat_back_dim[inx](x)\n\n                x = layer_up(x)\n\n\n\n        
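# final normalization of the decoder tokens; up_x4 later expands them 4x back to the input resolution\n        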
x = self.norm_up(x)  # B L C\n\n  \n\n        return x\n\n\n\n    def up_x4(self, x):\n\n        H, W = self.patches_resolution\n\n        B, L, C = x.shape\n\n        assert L == H*W, \"input features has wrong size\"\n\n\n\n        if self.final_upsample==\"expand_first\":\n\n            x = self.up(x)\n\n            x = x.view(B,4*H,4*W,-1)\n\n            x = x.permute(0,3,1,2) #B,C,H,W\n\n            x = self.output(x)\n\n            \n\n        return x\n\n\n\n    def forward(self, x):\n\n        x, x_downsample = self.forward_features(x)\n\n        x = self.forward_up_features(x,x_downsample)\n\n        x = self.up_x4(x)\n\n\n\n        return x\n\n\n\n    def flops(self):\n\n        flops = 0\n\n        flops += self.patch_embed.flops()\n\n        for i, layer in enumerate(self.layers):\n\n            flops += layer.flops()\n\n        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)\n\n        flops += self.num_features * self.num_classes\n\n        return flops\n\n\n\n\nlogger = logging.getLogger(__name__)\n\n\n\nclass SwinUnet_config():\n    def __init__(self):\n        self.patch_size = 4\n        self.in_chans = 3\n        self.num_classes = 4\n        self.embed_dim = 96\n        self.depths = [2, 2, 6, 2]\n        self.num_heads = [3, 6, 12, 24]\n        self.window_size = 7\n        self.mlp_ratio = 4.\n        self.qkv_bias = True\n        self.qk_scale = None\n        self.drop_rate = 0.\n        self.drop_path_rate = 0.1\n        self.ape = False\n        self.patch_norm = True\n        self.use_checkpoint = False\n\nclass SwinUnet(nn.Module):\n\n    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):\n\n        super(SwinUnet, self).__init__()\n\n        self.num_classes = num_classes\n\n        self.zero_head = zero_head\n\n        self.config = config\n\n\n\n        self.swin_unet = SwinTransformerSys(img_size=img_size,\n\n                                patch_size=config.patch_size,\n\n                                in_chans=config.in_chans,\n\n                                num_classes=self.num_classes,\n\n                                embed_dim=config.embed_dim,\n\n                                depths=config.depths,\n\n                                num_heads=config.num_heads,\n\n                                window_size=config.window_size,\n\n                                mlp_ratio=config.mlp_ratio,\n\n                                qkv_bias=config.qkv_bias,\n\n                                qk_scale=config.qk_scale,\n\n                                drop_rate=config.drop_rate,\n\n                                drop_path_rate=config.drop_path_rate,\n\n                                ape=config.ape,\n\n                                patch_norm=config.patch_norm,\n\n                                use_checkpoint=config.use_checkpoint)\n\n\n\n    def forward(self, x):\n\n        if x.size()[1] == 1:\n\n            x = x.repeat(1,3,1,1)\n\n        logits = self.swin_unet(x)\n\n        return logits\n\n\n\n    def load_from(self, pretrained_path):\n\n\n        if pretrained_path is not None:\n\n            print(\"pretrained_path:{}\".format(pretrained_path))\n\n            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n            pretrained_dict = torch.load(pretrained_path, map_location=device)\n\n            if \"model\"  not in pretrained_dict:\n\n                print(\"---start load pretrained modle by splitting---\")\n\n      
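          # the checkpoint keys are assumed to carry a 17-character prefix; strip it so the remaining names match swin_unet's state_dict\n      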
          pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()}\n\n                for k in list(pretrained_dict.keys()):\n\n                    if \"output\" in k:\n\n                        print(\"delete key:{}\".format(k))\n\n                        del pretrained_dict[k]\n\n                msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False)\n\n                # print(msg)\n\n                return\n\n            pretrained_dict = pretrained_dict['model']\n\n            print(\"---start load pretrained model of swin encoder---\")\n\n\n\n            model_dict = self.swin_unet.state_dict()\n\n            full_dict = copy.deepcopy(pretrained_dict)\n\n            for k, v in pretrained_dict.items():\n\n                if \"layers.\" in k:\n\n                    current_layer_num = 3-int(k[7:8])\n\n                    current_k = \"layers_up.\" + str(current_layer_num) + k[8:]\n\n                    full_dict.update({current_k:v})\n\n            for k in list(full_dict.keys()):\n\n                if k in model_dict:\n\n                    if full_dict[k].shape != model_dict[k].shape:\n\n                        print(\"delete:{};shape pretrain:{};shape model:{}\".format(k,full_dict[k].shape,model_dict[k].shape))\n\n                        del full_dict[k]\n\n\n\n            msg = self.swin_unet.load_state_dict(full_dict, strict=False)\n\n            # print(msg)\n\n        else:\n\n            print(\"none pretrain\")\n"
  },
  {
    "path": "model/transunet.py",
    "content": "# The code is borrowed from original repo: https://github.com/Beckschen/TransUNet/tree/d68a53a2da73ecb496bb7585340eb660ecda1d59\n# coding=utf-8\n\nfrom __future__ import absolute_import\n\nfrom __future__ import division\n\nfrom __future__ import print_function\n\n\nimport ml_collections\n\nimport copy\n\nimport logging\n\nimport math\n\nfrom collections import OrderedDict\n\nfrom os.path import join as pjoin\n\n\n\nimport torch\n\nimport torch.nn as nn\n\nimport torch.nn.functional as F\n\nimport numpy as np\n\n\n\nfrom torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm\n\nfrom torch.nn.modules.utils import _pair\n\nfrom scipy import ndimage\n\n\n\n\n\n\nlogger = logging.getLogger(__name__)\n\n\n\n\n\nATTENTION_Q = \"MultiHeadDotProductAttention_1/query\"\n\nATTENTION_K = \"MultiHeadDotProductAttention_1/key\"\n\nATTENTION_V = \"MultiHeadDotProductAttention_1/value\"\n\nATTENTION_OUT = \"MultiHeadDotProductAttention_1/out\"\n\nFC_0 = \"MlpBlock_3/Dense_0\"\n\nFC_1 = \"MlpBlock_3/Dense_1\"\n\nATTENTION_NORM = \"LayerNorm_0\"\n\nMLP_NORM = \"LayerNorm_2\"\n\n\n\n\n\ndef np2th(weights, conv=False):\n\n    \"\"\"Possibly convert HWIO to OIHW.\"\"\"\n\n    if conv:\n\n        weights = weights.transpose([3, 2, 0, 1])\n\n    return torch.from_numpy(weights)\n\n\n\n\n\ndef swish(x):\n\n    return x * torch.sigmoid(x)\n\n\n\n\n\nACT2FN = {\"gelu\": torch.nn.functional.gelu, \"relu\": torch.nn.functional.relu, \"swish\": swish}\n\n\n\n\n\nclass Attention(nn.Module):\n\n    def __init__(self, config, vis):\n\n        super(Attention, self).__init__()\n\n        self.vis = vis\n\n        self.num_attention_heads = config.transformer[\"num_heads\"]\n\n        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)\n\n        self.all_head_size = self.num_attention_heads * self.attention_head_size\n\n\n\n        self.query = Linear(config.hidden_size, self.all_head_size)\n\n        self.key = Linear(config.hidden_size, self.all_head_size)\n\n        self.value = Linear(config.hidden_size, self.all_head_size)\n\n\n\n        self.out = Linear(config.hidden_size, config.hidden_size)\n\n        self.attn_dropout = Dropout(config.transformer[\"attention_dropout_rate\"])\n\n        self.proj_dropout = Dropout(config.transformer[\"attention_dropout_rate\"])\n\n\n\n        self.softmax = Softmax(dim=-1)\n\n\n\n    def transpose_for_scores(self, x):\n\n        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)\n\n        x = x.view(*new_x_shape)\n\n        return x.permute(0, 2, 1, 3)\n\n\n\n    def forward(self, hidden_states):\n\n        mixed_query_layer = self.query(hidden_states)\n\n        mixed_key_layer = self.key(hidden_states)\n\n        mixed_value_layer = self.value(hidden_states)\n\n\n\n        query_layer = self.transpose_for_scores(mixed_query_layer)\n\n        key_layer = self.transpose_for_scores(mixed_key_layer)\n\n        value_layer = self.transpose_for_scores(mixed_value_layer)\n\n\n\n        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))\n\n        attention_scores = attention_scores / math.sqrt(self.attention_head_size)\n\n        attention_probs = self.softmax(attention_scores)\n\n        weights = attention_probs if self.vis else None\n\n        attention_probs = self.attn_dropout(attention_probs)\n\n\n\n        context_layer = torch.matmul(attention_probs, value_layer)\n\n        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()\n\n        
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)\n\n        context_layer = context_layer.view(*new_context_layer_shape)\n\n        attention_output = self.out(context_layer)\n\n        attention_output = self.proj_dropout(attention_output)\n\n        return attention_output, weights\n\n\n\n\n\nclass Mlp(nn.Module):\n\n    def __init__(self, config):\n\n        super(Mlp, self).__init__()\n\n        self.fc1 = Linear(config.hidden_size, config.transformer[\"mlp_dim\"])\n\n        self.fc2 = Linear(config.transformer[\"mlp_dim\"], config.hidden_size)\n\n        self.act_fn = ACT2FN[\"gelu\"]\n\n        self.dropout = Dropout(config.transformer[\"dropout_rate\"])\n\n\n\n        self._init_weights()\n\n\n\n    def _init_weights(self):\n\n        nn.init.xavier_uniform_(self.fc1.weight)\n\n        nn.init.xavier_uniform_(self.fc2.weight)\n\n        nn.init.normal_(self.fc1.bias, std=1e-6)\n\n        nn.init.normal_(self.fc2.bias, std=1e-6)\n\n\n\n    def forward(self, x):\n\n        x = self.fc1(x)\n\n        x = self.act_fn(x)\n\n        x = self.dropout(x)\n\n        x = self.fc2(x)\n\n        x = self.dropout(x)\n\n        return x\n\n\n\n\n\nclass Embeddings(nn.Module):\n\n    \"\"\"Construct the embeddings from patch, position embeddings.\n\n    \"\"\"\n\n    def __init__(self, config, img_size, in_channels=3):\n\n        super(Embeddings, self).__init__()\n\n        self.hybrid = None\n\n        self.config = config\n\n        img_size = _pair(img_size)\n\n\n\n        if config.patches.get(\"grid\") is not None:   # ResNet\n\n            grid_size = config.patches[\"grid\"]\n\n            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])\n\n            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)\n\n            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])  \n\n            self.hybrid = True\n\n        else:\n\n            patch_size = _pair(config.patches[\"size\"])\n\n            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])\n\n            self.hybrid = False\n\n\n\n        if self.hybrid:\n\n            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)\n\n            in_channels = self.hybrid_model.width * 16\n\n        self.patch_embeddings = Conv2d(in_channels=in_channels,\n\n                                       out_channels=config.hidden_size,\n\n                                       kernel_size=patch_size,\n\n                                       stride=patch_size)\n\n        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))\n\n\n\n        self.dropout = Dropout(config.transformer[\"dropout_rate\"])\n\n\n\n\n\n    def forward(self, x):\n\n        if self.hybrid:\n\n            x, features = self.hybrid_model(x)\n\n        else:\n\n            features = None\n\n        x = self.patch_embeddings(x)  # (B, hidden. 
n_patches^(1/2), n_patches^(1/2))\n\n        x = x.flatten(2)\n\n        x = x.transpose(-1, -2)  # (B, n_patches, hidden)\n\n\n\n        embeddings = x + self.position_embeddings\n\n        embeddings = self.dropout(embeddings)\n\n        return embeddings, features\n\n\n\n\n\nclass Block(nn.Module):\n\n    def __init__(self, config, vis):\n\n        super(Block, self).__init__()\n\n        self.hidden_size = config.hidden_size\n\n        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)\n\n        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)\n\n        self.ffn = Mlp(config)\n\n        self.attn = Attention(config, vis)\n\n\n\n    def forward(self, x):\n\n        h = x\n\n        x = self.attention_norm(x)\n\n        x, weights = self.attn(x)\n\n        x = x + h\n\n\n\n        h = x\n\n        x = self.ffn_norm(x)\n\n        x = self.ffn(x)\n\n        x = x + h\n\n        return x, weights\n\n\n\n    def load_from(self, weights, n_block):\n\n        ROOT = f\"Transformer/encoderblock_{n_block}\"\n\n        with torch.no_grad():\n\n            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, \"kernel\")]).view(self.hidden_size, self.hidden_size).t()\n\n            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, \"kernel\")]).view(self.hidden_size, self.hidden_size).t()\n\n            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, \"kernel\")]).view(self.hidden_size, self.hidden_size).t()\n\n            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, \"kernel\")]).view(self.hidden_size, self.hidden_size).t()\n\n\n\n            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, \"bias\")]).view(-1)\n\n            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, \"bias\")]).view(-1)\n\n            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, \"bias\")]).view(-1)\n\n            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, \"bias\")]).view(-1)\n\n\n\n            self.attn.query.weight.copy_(query_weight)\n\n            self.attn.key.weight.copy_(key_weight)\n\n            self.attn.value.weight.copy_(value_weight)\n\n            self.attn.out.weight.copy_(out_weight)\n\n            self.attn.query.bias.copy_(query_bias)\n\n            self.attn.key.bias.copy_(key_bias)\n\n            self.attn.value.bias.copy_(value_bias)\n\n            self.attn.out.bias.copy_(out_bias)\n\n\n\n            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, \"kernel\")]).t()\n\n            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, \"kernel\")]).t()\n\n            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, \"bias\")]).t()\n\n            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, \"bias\")]).t()\n\n\n\n            self.ffn.fc1.weight.copy_(mlp_weight_0)\n\n            self.ffn.fc2.weight.copy_(mlp_weight_1)\n\n            self.ffn.fc1.bias.copy_(mlp_bias_0)\n\n            self.ffn.fc2.bias.copy_(mlp_bias_1)\n\n\n\n            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, \"scale\")]))\n\n            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, \"bias\")]))\n\n            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, \"scale\")]))\n\n            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, \"bias\")]))\n\n\n\n\n\nclass Encoder(nn.Module):\n\n    def __init__(self, config, vis):\n\n        super(Encoder, self).__init__()\n\n        self.vis = vis\n\n        self.layer = nn.ModuleList()\n\n        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)\n\n   
     for _ in range(config.transformer[\"num_layers\"]):\n\n            layer = Block(config, vis)\n\n            self.layer.append(copy.deepcopy(layer))\n\n\n\n    def forward(self, hidden_states):\n\n        attn_weights = []\n\n        for layer_block in self.layer:\n\n            hidden_states, weights = layer_block(hidden_states)\n\n            if self.vis:\n\n                attn_weights.append(weights)\n\n        encoded = self.encoder_norm(hidden_states)\n\n        return encoded, attn_weights\n\n\n\n\n\nclass Transformer(nn.Module):\n\n    def __init__(self, config, img_size, vis):\n\n        super(Transformer, self).__init__()\n\n        self.embeddings = Embeddings(config, img_size=img_size)\n\n        self.encoder = Encoder(config, vis)\n\n\n\n    def forward(self, input_ids):\n\n        embedding_output, features = self.embeddings(input_ids)\n\n        encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patch, hidden)\n\n        return encoded, attn_weights, features\n\n\n\n\n\nclass Conv2dReLU(nn.Sequential):\n\n    def __init__(\n\n            self,\n\n            in_channels,\n\n            out_channels,\n\n            kernel_size,\n\n            padding=0,\n\n            stride=1,\n\n            use_batchnorm=True,\n\n    ):\n\n        conv = nn.Conv2d(\n\n            in_channels,\n\n            out_channels,\n\n            kernel_size,\n\n            stride=stride,\n\n            padding=padding,\n\n            bias=not (use_batchnorm),\n\n        )\n\n        relu = nn.ReLU(inplace=True)\n\n\n\n        bn = nn.BatchNorm2d(out_channels)\n\n\n\n        super(Conv2dReLU, self).__init__(conv, bn, relu)\n\n\n\n\n\nclass DecoderBlock(nn.Module):\n\n    def __init__(\n\n            self,\n\n            in_channels,\n\n            out_channels,\n\n            skip_channels=0,\n\n            use_batchnorm=True,\n\n    ):\n\n        super().__init__()\n\n        self.conv1 = Conv2dReLU(\n\n            in_channels + skip_channels,\n\n            out_channels,\n\n            kernel_size=3,\n\n            padding=1,\n\n            use_batchnorm=use_batchnorm,\n\n        )\n\n        self.conv2 = Conv2dReLU(\n\n            out_channels,\n\n            out_channels,\n\n            kernel_size=3,\n\n            padding=1,\n\n            use_batchnorm=use_batchnorm,\n\n        )\n\n        self.up = nn.UpsamplingBilinear2d(scale_factor=2)\n\n\n\n    def forward(self, x, skip=None):\n\n        x = self.up(x)\n\n        if skip is not None:\n\n            x = torch.cat([x, skip], dim=1)\n\n        x = self.conv1(x)\n\n        x = self.conv2(x)\n\n        return x\n\n\n\n\n\nclass SegmentationHead(nn.Sequential):\n\n\n\n    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):\n\n        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)\n\n        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()\n\n        super().__init__(conv2d, upsampling)\n\n\n\n\n\nclass DecoderCup(nn.Module):\n\n    def __init__(self, config):\n\n        super().__init__()\n\n        self.config = config\n\n        head_channels = 512\n\n        self.conv_more = Conv2dReLU(\n\n            config.hidden_size,\n\n            head_channels,\n\n            kernel_size=3,\n\n            padding=1,\n\n            use_batchnorm=True,\n\n        )\n\n        decoder_channels = config.decoder_channels\n\n        in_channels = [head_channels] + list(decoder_channels[:-1])\n\n        out_channels = 
decoder_channels\n\n\n\n        if self.config.n_skip != 0:\n\n            skip_channels = self.config.skip_channels\n\n            for i in range(4-self.config.n_skip):  # re-select the skip channels according to n_skip\n\n                skip_channels[3-i]=0\n\n\n\n        else:\n\n            skip_channels=[0,0,0,0]\n\n\n\n        blocks = [\n\n            DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)\n\n        ]\n\n        self.blocks = nn.ModuleList(blocks)\n\n\n\n    def forward(self, hidden_states, features=None):\n\n        B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hidden)\n\n        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))\n\n        x = hidden_states.permute(0, 2, 1)\n\n        x = x.contiguous().view(B, hidden, h, w)\n\n        x = self.conv_more(x)\n\n        for i, decoder_block in enumerate(self.blocks):\n\n            if features is not None:\n\n                skip = features[i] if (i < self.config.n_skip) else None\n\n            else:\n\n                skip = None\n\n            x = decoder_block(x, skip=skip)\n\n        return x\n\n\n\n\n\nclass VisionTransformer(nn.Module):\n\n    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):\n\n        super(VisionTransformer, self).__init__()\n\n        self.num_classes = num_classes\n\n        self.zero_head = zero_head\n\n        self.classifier = config.classifier\n\n        self.transformer = Transformer(config, img_size, vis)\n\n        self.decoder = DecoderCup(config)\n\n        self.segmentation_head = SegmentationHead(\n\n            in_channels=config['decoder_channels'][-1],\n\n            out_channels=config['n_classes'],\n\n            kernel_size=3,\n\n        )\n\n        self.config = config\n\n\n\n    def forward(self, x):\n\n        if x.size()[1] == 1:\n\n            x = x.repeat(1,3,1,1)\n\n        x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)\n\n        x = self.decoder(x, features)\n\n        logits = self.segmentation_head(x)\n\n        return logits\n\n\n\n    def load_from(self, weights):\n\n        with torch.no_grad():\n\n\n\n            res_weight = weights\n\n            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights[\"embedding/kernel\"], conv=True))\n\n            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights[\"embedding/bias\"]))\n\n\n\n            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights[\"Transformer/encoder_norm/scale\"]))\n\n            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights[\"Transformer/encoder_norm/bias\"]))\n\n\n\n            posemb = np2th(weights[\"Transformer/posembed_input/pos_embedding\"])\n\n\n\n            posemb_new = self.transformer.embeddings.position_embeddings\n\n            if posemb.size() == posemb_new.size():\n\n                self.transformer.embeddings.position_embeddings.copy_(posemb)\n\n            elif posemb.size()[1]-1 == posemb_new.size()[1]:\n\n                posemb = posemb[:, 1:]\n\n                self.transformer.embeddings.position_embeddings.copy_(posemb)\n\n            else:\n\n                logger.info(\"load_pretrained: resized variant: %s to %s\" % (posemb.size(), posemb_new.size()))\n\n                ntok_new = posemb_new.size(1)\n\n                if self.classifier == \"seg\":\n\n                    _, posemb_grid = posemb[:, :1], posemb[0, 1:]\n\n                
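# resize the pretrained position-embedding grid to the new patch grid with bilinear interpolation\n                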
gs_old = int(np.sqrt(len(posemb_grid)))\n\n                gs_new = int(np.sqrt(ntok_new))\n\n                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))\n\n                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)\n\n                zoom = (gs_new / gs_old, gs_new / gs_old, 1)\n\n                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)  # th2np\n\n                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)\n\n                posemb = posemb_grid\n\n                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))\n\n\n\n            # Encoder whole\n\n            for bname, block in self.transformer.encoder.named_children():\n\n                for uname, unit in block.named_children():\n\n                    unit.load_from(weights, n_block=uname)\n\n\n\n            if self.transformer.embeddings.hybrid:\n\n                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight[\"conv_root/kernel\"], conv=True))\n\n                gn_weight = np2th(res_weight[\"gn_root/scale\"]).view(-1)\n\n                gn_bias = np2th(res_weight[\"gn_root/bias\"]).view(-1)\n\n                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)\n\n                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)\n\n\n\n                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():\n\n                    for uname, unit in block.named_children():\n\n                        unit.load_from(res_weight, n_block=bname, n_unit=uname)\n\n\n\n\ndef get_b16_config():\n\n    \"\"\"Returns the ViT-B/16 configuration.\"\"\"\n\n    config = ml_collections.ConfigDict()\n\n    config.patches = ml_collections.ConfigDict({'size': (16, 16)})\n\n    config.hidden_size = 768\n\n    config.transformer = ml_collections.ConfigDict()\n\n    config.transformer.mlp_dim = 3072\n\n    config.transformer.num_heads = 12\n\n    config.transformer.num_layers = 12\n\n    config.transformer.attention_dropout_rate = 0.0\n\n    config.transformer.dropout_rate = 0.1\n\n\n\n    config.classifier = 'seg'\n\n    config.representation_size = None\n\n    config.resnet_pretrained_path = None\n\n    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz'\n\n    config.patch_size = 16\n\n\n    config.skip_channels = [0, 0, 0, 0]\n\n    config.decoder_channels = (256, 128, 64, 16)\n\n    config.n_classes = 2\n\n    config.activation = 'softmax'\n\n    return config\n\n\n\n\n\ndef get_testing():\n\n    \"\"\"Returns a minimal configuration for testing.\"\"\"\n\n    config = ml_collections.ConfigDict()\n\n    config.patches = ml_collections.ConfigDict({'size': (16, 16)})\n\n    config.hidden_size = 1\n\n    config.transformer = ml_collections.ConfigDict()\n\n    config.transformer.mlp_dim = 1\n\n    config.transformer.num_heads = 1\n\n    config.transformer.num_layers = 1\n\n    config.transformer.attention_dropout_rate = 0.0\n\n    config.transformer.dropout_rate = 0.1\n\n    config.classifier = 'token'\n\n    config.representation_size = None\n\n    return config\n\n\n\ndef get_r50_b16_config():\n\n    \"\"\"Returns the Resnet50 + ViT-B/16 configuration.\"\"\"\n\n    config = get_b16_config()\n\n    config.patches.grid = (16, 16)\n\n    config.resnet = ml_collections.ConfigDict()\n\n    config.resnet.num_layers = (3, 4, 9)\n\n    config.resnet.width_factor = 1\n\n\n\n    config.classifier = 'seg'\n\n    config.pretrained_path = 
'../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'\n\n    config.decoder_channels = (256, 128, 64, 16)\n\n    config.skip_channels = [512, 256, 64, 16]\n\n    config.n_classes = 2\n\n    config.n_skip = 3\n\n    config.activation = 'softmax'\n\n\n\n    return config\n\n\n\n\n\ndef get_b32_config():\n\n    \"\"\"Returns the ViT-B/32 configuration.\"\"\"\n\n    config = get_b16_config()\n\n    config.patches.size = (32, 32)\n\n    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz'\n\n    return config\n\n\n\n\n\ndef get_l16_config():\n\n    \"\"\"Returns the ViT-L/16 configuration.\"\"\"\n\n    config = ml_collections.ConfigDict()\n\n    config.patches = ml_collections.ConfigDict({'size': (16, 16)})\n\n    config.hidden_size = 1024\n\n    config.transformer = ml_collections.ConfigDict()\n\n    config.transformer.mlp_dim = 4096\n\n    config.transformer.num_heads = 16\n\n    config.transformer.num_layers = 24\n\n    config.transformer.attention_dropout_rate = 0.0\n\n    config.transformer.dropout_rate = 0.1\n\n    config.representation_size = None\n\n\n\n    # custom\n\n    config.classifier = 'seg'\n\n    config.resnet_pretrained_path = None\n\n    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz'\n\n    config.decoder_channels = (256, 128, 64, 16)\n\n    config.n_classes = 2\n\n    config.activation = 'softmax'\n\n    return config\n\n\n\n\n\ndef get_r50_l16_config():\n\n    \"\"\"Returns the Resnet50 + ViT-L/16 configuration. customized \"\"\"\n\n    config = get_l16_config()\n\n    config.patches.grid = (16, 16)\n\n    config.resnet = ml_collections.ConfigDict()\n\n    config.resnet.num_layers = (3, 4, 9)\n\n    config.resnet.width_factor = 1\n\n\n\n    config.classifier = 'seg'\n\n    config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'\n\n    config.decoder_channels = (256, 128, 64, 16)\n\n    config.skip_channels = [512, 256, 64, 16]\n\n    config.n_classes = 2\n\n    config.activation = 'softmax'\n\n    return config\n\n\n\n\n\ndef get_l32_config():\n\n    \"\"\"Returns the ViT-L/32 configuration.\"\"\"\n\n    config = get_l16_config()\n\n    config.patches.size = (32, 32)\n\n    return config\n\n\n\n\n\ndef get_h14_config():\n\n    \"\"\"Returns the ViT-L/16 configuration.\"\"\"\n\n    config = ml_collections.ConfigDict()\n\n    config.patches = ml_collections.ConfigDict({'size': (14, 14)})\n\n    config.hidden_size = 1280\n\n    config.transformer = ml_collections.ConfigDict()\n\n    config.transformer.mlp_dim = 5120\n\n    config.transformer.num_heads = 16\n\n    config.transformer.num_layers = 32\n\n    config.transformer.attention_dropout_rate = 0.0\n\n    config.transformer.dropout_rate = 0.1\n\n    config.classifier = 'token'\n\n    config.representation_size = None\n\n\n\n    return config\n\n\n\n\n\n\n\n\n\nCONFIGS = {\n\n    'ViT-B_16': get_b16_config(),\n\n    'ViT-B_32': get_b32_config(),\n\n    'ViT-L_16': get_l16_config(),\n\n    'ViT-L_32': get_l32_config(),\n\n    'ViT-H_14': get_h14_config(),\n\n    'R50-ViT-B_16': get_r50_b16_config(),\n\n    'R50-ViT-L_16': get_r50_l16_config(),\n\n    'testing': get_testing(),\n\n}\n\n\n\n\n\n\n\ndef np2th(weights, conv=False):\n\n    \"\"\"Possibly convert HWIO to OIHW.\"\"\"\n\n    if conv:\n\n        weights = weights.transpose([3, 2, 0, 1])\n\n    return torch.from_numpy(weights)\n\n\n\n\n\nclass StdConv2d(nn.Conv2d):\n\n\n\n    def forward(self, x):\n\n        w = self.weight\n\n        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, 
unbiased=False)\n\n        w = (w - m) / torch.sqrt(v + 1e-5)\n\n        return F.conv2d(x, w, self.bias, self.stride, self.padding,\n\n                        self.dilation, self.groups)\n\n\n\n\n\ndef conv3x3(cin, cout, stride=1, groups=1, bias=False):\n\n    return StdConv2d(cin, cout, kernel_size=3, stride=stride,\n\n                     padding=1, bias=bias, groups=groups)\n\n\n\n\n\ndef conv1x1(cin, cout, stride=1, bias=False):\n\n    return StdConv2d(cin, cout, kernel_size=1, stride=stride,\n\n                     padding=0, bias=bias)\n\n\n\n\n\nclass PreActBottleneck(nn.Module):\n\n    \"\"\"Pre-activation (v2) bottleneck block.\n\n    \"\"\"\n\n\n\n    def __init__(self, cin, cout=None, cmid=None, stride=1):\n\n        super().__init__()\n\n        cout = cout or cin\n\n        cmid = cmid or cout//4\n\n\n\n        self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)\n\n        self.conv1 = conv1x1(cin, cmid, bias=False)\n\n        self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)\n\n        self.conv2 = conv3x3(cmid, cmid, stride, bias=False)  # Original code has it on conv1!!\n\n        self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)\n\n        self.conv3 = conv1x1(cmid, cout, bias=False)\n\n        self.relu = nn.ReLU(inplace=True)\n\n\n\n        if (stride != 1 or cin != cout):\n\n            # Projection also with pre-activation according to paper.\n\n            self.downsample = conv1x1(cin, cout, stride, bias=False)\n\n            self.gn_proj = nn.GroupNorm(cout, cout)\n\n\n\n    def forward(self, x):\n\n\n\n        # Residual branch\n\n        residual = x\n\n        if hasattr(self, 'downsample'):\n\n            residual = self.downsample(x)\n\n            residual = self.gn_proj(residual)\n\n\n\n        # Unit's branch\n\n        y = self.relu(self.gn1(self.conv1(x)))\n\n        y = self.relu(self.gn2(self.conv2(y)))\n\n        y = self.gn3(self.conv3(y))\n\n\n\n        y = self.relu(residual + y)\n\n        return y\n\n\n\n    def load_from(self, weights, n_block, n_unit):\n\n        conv1_weight = np2th(weights[pjoin(n_block, n_unit, \"conv1/kernel\")], conv=True)\n\n        conv2_weight = np2th(weights[pjoin(n_block, n_unit, \"conv2/kernel\")], conv=True)\n\n        conv3_weight = np2th(weights[pjoin(n_block, n_unit, \"conv3/kernel\")], conv=True)\n\n\n\n        gn1_weight = np2th(weights[pjoin(n_block, n_unit, \"gn1/scale\")])\n\n        gn1_bias = np2th(weights[pjoin(n_block, n_unit, \"gn1/bias\")])\n\n\n\n        gn2_weight = np2th(weights[pjoin(n_block, n_unit, \"gn2/scale\")])\n\n        gn2_bias = np2th(weights[pjoin(n_block, n_unit, \"gn2/bias\")])\n\n\n\n        gn3_weight = np2th(weights[pjoin(n_block, n_unit, \"gn3/scale\")])\n\n        gn3_bias = np2th(weights[pjoin(n_block, n_unit, \"gn3/bias\")])\n\n\n\n        self.conv1.weight.copy_(conv1_weight)\n\n        self.conv2.weight.copy_(conv2_weight)\n\n        self.conv3.weight.copy_(conv3_weight)\n\n\n\n        self.gn1.weight.copy_(gn1_weight.view(-1))\n\n        self.gn1.bias.copy_(gn1_bias.view(-1))\n\n\n\n        self.gn2.weight.copy_(gn2_weight.view(-1))\n\n        self.gn2.bias.copy_(gn2_bias.view(-1))\n\n\n\n        self.gn3.weight.copy_(gn3_weight.view(-1))\n\n        self.gn3.bias.copy_(gn3_bias.view(-1))\n\n\n\n        if hasattr(self, 'downsample'):\n\n            proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, \"conv_proj/kernel\")], conv=True)\n\n            proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, \"gn_proj/scale\")])\n\n            proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, 
\"gn_proj/bias\")])\n\n\n\n            self.downsample.weight.copy_(proj_conv_weight)\n\n            self.gn_proj.weight.copy_(proj_gn_weight.view(-1))\n\n            self.gn_proj.bias.copy_(proj_gn_bias.view(-1))\n\n\n\nclass ResNetV2(nn.Module):\n\n    \"\"\"Implementation of Pre-activation (v2) ResNet mode.\"\"\"\n\n\n\n    def __init__(self, block_units, width_factor):\n\n        super().__init__()\n\n        width = int(64 * width_factor)\n\n        self.width = width\n\n\n\n        self.root = nn.Sequential(OrderedDict([\n\n            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),\n\n            ('gn', nn.GroupNorm(32, width, eps=1e-6)),\n\n            ('relu', nn.ReLU(inplace=True)),\n\n            # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))\n\n        ]))\n\n\n\n        self.body = nn.Sequential(OrderedDict([\n\n            ('block1', nn.Sequential(OrderedDict(\n\n                [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +\n\n                [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],\n\n                ))),\n\n            ('block2', nn.Sequential(OrderedDict(\n\n                [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +\n\n                [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],\n\n                ))),\n\n            ('block3', nn.Sequential(OrderedDict(\n\n                [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +\n\n                [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],\n\n                ))),\n\n        ]))\n\n\n\n    def forward(self, x):\n\n        features = []\n\n        b, c, in_size, _ = x.size()\n\n        x = self.root(x)\n\n        features.append(x)\n\n        x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)\n\n        for i in range(len(self.body)-1):\n\n            x = self.body[i](x)\n\n            right_size = int(in_size / 4 / (i+1))\n\n            if x.size()[2] != right_size:\n\n                pad = right_size - x.size()[2]\n\n                assert pad < 3 and pad > 0, \"x {} should {}\".format(x.size(), right_size)\n\n                feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)\n\n                feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]\n\n            else:\n\n                feat = x\n\n            features.append(feat)\n\n        x = self.body[-1](x)\n\n        return x, features[::-1]\n"
  },
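# --- Usage sketch (not part of the original files) ---
# A minimal sketch of how the TransUNet pieces in model/transunet.py are wired
# together for M&Ms segmentation, mirroring the 'TransUNet' branch of
# train_deep.py. Assumes it is run from the repository root; the pretrained-
# weight line is optional and only works if the R50+ViT-B_16.npz file from the
# TransUNet repo has been downloaded to ./initmodel/.
import torch
from model.transunet import VisionTransformer as ViT_seg
from model.transunet import CONFIGS as CONFIGS_ViT_seg

config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']
config_vit.n_classes = 4                          # background + 3 cardiac structures
config_vit.n_skip = 3                             # use three ResNet skip connections in DecoderCup
config_vit.patches.grid = (256 // 16, 256 // 16)  # 16x16 patch grid for 256x256 crops

net = ViT_seg(config_vit, img_size=256, num_classes=4)
# net.load_from(weights=np.load('./initmodel/R50+ViT-B_16.npz'))  # optional pretrained init

x = torch.randn(2, 1, 256, 256)   # single-channel CMR slices; forward() repeats them to 3 channels
logits = net(x)                   # segmentation logits at input resolution, here (2, 4, 256, 256)
print(logits.shape)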
  {
    "path": "model/unet_utils.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport pdb \n\nclass DoubleConv(nn.Module):\n\n    \"\"\"(convolution => [BN] => ReLU) * 2\"\"\"\n\n\n\n    def __init__(self, in_channels, out_channels, mid_channels=None):\n\n        super().__init__()\n\n        if not mid_channels:\n\n            mid_channels = out_channels\n\n        self.double_conv = nn.Sequential(\n\n            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),\n\n            nn.BatchNorm2d(mid_channels),\n\n            nn.ReLU(inplace=True),\n\n            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),\n\n            nn.BatchNorm2d(out_channels),\n\n            nn.ReLU(inplace=True)\n\n        )\n\n\n\n    def forward(self, x):\n\n        return self.double_conv(x)\n\n\n\n\n\nclass Down(nn.Module):\n\n    \"\"\"Downscaling with maxpool then double conv\"\"\"\n\n\n\n    def __init__(self, in_channels, out_channels):\n\n        super().__init__()\n\n        self.maxpool_conv = nn.Sequential(\n\n            nn.MaxPool2d(2),\n\n            DoubleConv(in_channels, out_channels)\n\n        )\n\n\n\n    def forward(self, x):\n\n        return self.maxpool_conv(x)\n\n\n\n\n\nclass Up(nn.Module):\n\n    \"\"\"Upscaling then double conv\"\"\"\n\n\n\n    def __init__(self, in_channels, out_channels, bilinear=True):\n\n        super().__init__()\n\n\n\n        # if bilinear, use the normal convolutions to reduce the number of channels\n\n        if bilinear:\n\n            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)\n\n            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)\n\n        else:\n\n            self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)\n\n            self.conv = DoubleConv(in_channels, out_channels)\n\n\n\n\n\n    def forward(self, x1, x2):\n\n        x1 = self.up(x1)\n\n        # input is CHW\n\n        diffY = x2.size()[2] - x1.size()[2]\n\n        diffX = x2.size()[3] - x1.size()[3]\n\n\n\n        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,\n\n                        diffY // 2, diffY - diffY // 2])\n\n        # if you have padding issues, see\n\n        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a\n\n        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd\n\n        x = torch.cat([x2, x1], dim=1)\n\n        return self.conv(x)\n\n\n\n\n\nclass OutConv(nn.Module):\n\n    def __init__(self, in_channels, out_channels):\n\n        super(OutConv, self).__init__()\n\n        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)\n\n\n\n    def forward(self, x):\n\n        return self.conv(x)\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\ndef conv1x1(in_planes, out_planes, stride=1):\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)\n    \n\n\nclass BasicBlock(nn.Module):\n\n    def __init__(self, inplanes, planes, stride=1):\n        super().__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(inplanes)\n        self.relu = nn.ReLU(inplace=True)\n\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.BatchNorm2d(planes)\n\n        self.shortcut = nn.Sequential()\n        if stride != 1 or inplanes != planes:\n      
      self.shortcut = nn.Sequential(\n                    nn.BatchNorm2d(inplanes),\n                    self.relu,\n                    nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)\n                    )\n\n    def forward(self, x): \n        residue = x \n\n        out = self.bn1(x)\n        out = self.relu(out)\n        out = self.conv1(out)\n\n        out = self.bn2(out)\n        out = self.relu(out)\n        out = self.conv2(out)\n\n        out += self.shortcut(residue)\n\n        return out \n\nclass BottleneckBlock(nn.Module):\n\n    def __init__(self, inplanes, planes, stride=1):\n        super().__init__()\n        self.conv1 = conv1x1(inplanes, planes//4, stride=1)\n        self.bn1 = nn.BatchNorm2d(inplanes)\n        self.relu = nn.ReLU(inplace=True)\n\n        self.conv2 = conv3x3(planes//4, planes//4, stride=stride)\n        self.bn2 = nn.BatchNorm2d(planes//4)\n\n        self.conv3 = conv1x1(planes//4, planes, stride=1)\n        self.bn3 = nn.BatchNorm2d(planes//4)\n\n        self.shortcut = nn.Sequential()\n        if stride != 1 or inplanes != planes:\n            self.shortcut = nn.Sequential(\n                    nn.BatchNorm2d(inplanes),\n                    self.relu,\n                    nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)\n                    )\n\n    def forward(self, x):\n        residue = x\n\n        out = self.bn1(x)\n        out = self.relu(out)\n        out = self.conv1(out)\n\n        out = self.bn2(out)\n        out = self.relu(out)\n        out = self.conv2(out)\n\n        out = self.bn3(out)\n        out = self.relu(out)\n        out = self.conv3(out)\n\n        out += self.shortcut(residue)\n\n        return out\n\n\n\nclass inconv(nn.Module):\n    def __init__(self, in_ch, out_ch, bottleneck=False):\n        super().__init__()\n        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)\n        self.relu = nn.ReLU(inplace=True)\n\n        if bottleneck:\n            self.conv2 = BottleneckBlock(out_ch, out_ch)\n        else:\n            self.conv2 = BasicBlock(out_ch, out_ch)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.conv2(out)\n\n        return out\n\n\nclass down_block(nn.Module):\n    def __init__(self, in_ch, out_ch, scale, num_block, bottleneck=False, pool=True):\n        super().__init__()\n\n        block_list = []\n\n        if bottleneck:\n            block = BottleneckBlock\n        else:\n            block = BasicBlock\n\n\n        if pool:\n            block_list.append(nn.MaxPool2d(scale))\n            block_list.append(block(in_ch, out_ch))\n        else:\n            block_list.append(block(in_ch, out_ch, stride=2))\n\n        for i in range(num_block-1):\n            block_list.append(block(out_ch, out_ch, stride=1))\n\n        self.conv = nn.Sequential(*block_list)\n\n    def forward(self, x):\n        return self.conv(x)\n\n\n\n\nclass up_block(nn.Module):\n    def __init__(self, in_ch, out_ch, num_block, scale=(2,2),bottleneck=False):\n        super().__init__()\n        self.scale=scale\n\n        self.conv_ch = nn.Conv2d(in_ch, out_ch, kernel_size=1)\n\n        if bottleneck:\n            block = BottleneckBlock\n        else:\n            block = BasicBlock\n\n\n        block_list = []\n        block_list.append(block(2*out_ch, out_ch))\n\n        for i in range(num_block-1):\n            block_list.append(block(out_ch, out_ch))\n\n        self.conv = nn.Sequential(*block_list)\n\n    def forward(self, x1, x2):\n   
     x1 = F.interpolate(x1, scale_factor=self.scale, mode='bilinear', align_corners=True)\n        x1 = self.conv_ch(x1)\n\n        out = torch.cat([x2, x1], dim=1)\n        out = self.conv(out)\n\n        return out\n\n\n\n"
  },
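# --- Usage sketch (not part of the original files) ---
# A small, self-contained check of how the residual UNet building blocks in
# model/unet_utils.py compose: inconv keeps the resolution, down_block halves
# it, and up_block upsamples, concatenates the skip feature and fuses it back
# to out_ch channels. Assumes it is run from the repository root so that
# `model` is importable.
import torch
from model.unet_utils import inconv, down_block, up_block

x = torch.randn(1, 1, 64, 64)

stem = inconv(1, 32)                                  # 1 -> 32 channels, same resolution
enc = down_block(32, 64, scale=(2, 2), num_block=2)   # maxpool then two BasicBlocks
dec = up_block(64, 32, num_block=2, scale=(2, 2))     # bilinear upsample + skip fusion

x1 = stem(x)        # (1, 32, 64, 64)
x2 = enc(x1)        # (1, 64, 32, 32)
out = dec(x2, x1)   # (1, 32, 64, 64): upsample x2, 1x1 conv to 32 ch, concat with x1, fuse

print(x1.shape, x2.shape, out.shape)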
  {
    "path": "model/utnet.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom .unet_utils import up_block, down_block\nfrom .conv_trans_utils import *\n\nimport pdb\n\n\n\nclass UTNet(nn.Module):\n    \n    def __init__(self, in_chan, base_chan, num_classes=1, reduce_size=8, block_list='234', num_blocks=[1, 2, 4], projection='interp', num_heads=[2,4,8], attn_drop=0., proj_drop=0., bottleneck=False, maxpool=True, rel_pos=True, aux_loss=False):\n        super().__init__()\n\n        self.aux_loss = aux_loss\n        self.inc = [BasicBlock(in_chan, base_chan)]\n        if '0' in block_list:\n            self.inc.append(BasicTransBlock(base_chan, heads=num_heads[-5], dim_head=base_chan//num_heads[-5], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos))\n            self.up4 = up_block_trans(2*base_chan, base_chan, num_block=0, bottleneck=bottleneck, heads=num_heads[-4], dim_head=base_chan//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n        \n        else:\n            self.inc.append(BasicBlock(base_chan, base_chan))\n            self.up4 = up_block(2*base_chan, base_chan, scale=(2,2), num_block=2)\n        self.inc = nn.Sequential(*self.inc)\n\n\n        if '1' in block_list:\n            self.down1 = down_block_trans(base_chan, 2*base_chan, num_block=num_blocks[-4], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-4], dim_head=2*base_chan//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n            self.up3 = up_block_trans(4*base_chan, 2*base_chan, num_block=0, bottleneck=bottleneck, heads=num_heads[-3], dim_head=2*base_chan//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n        else:\n            self.down1 = down_block(base_chan, 2*base_chan, (2,2), num_block=2)\n            self.up3 = up_block(4*base_chan, 2*base_chan, scale=(2,2), num_block=2)\n\n        if '2' in block_list:\n            self.down2 = down_block_trans(2*base_chan, 4*base_chan, num_block=num_blocks[-3], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-3], dim_head=4*base_chan//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n            self.up2 = up_block_trans(8*base_chan, 4*base_chan, num_block=0, bottleneck=bottleneck, heads=num_heads[-2], dim_head=4*base_chan//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n\n        else:\n            self.down2 = down_block(2*base_chan, 4*base_chan, (2, 2), num_block=2)\n            self.up2 = up_block(8*base_chan, 4*base_chan, scale=(2,2), num_block=2)\n\n        if '3' in block_list:\n            self.down3 = down_block_trans(4*base_chan, 8*base_chan, num_block=num_blocks[-2], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-2], dim_head=8*base_chan//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n            self.up1 = up_block_trans(16*base_chan, 8*base_chan, num_block=0, bottleneck=bottleneck, heads=num_heads[-1], dim_head=8*base_chan//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n\n        else:\n            self.down3 = down_block(4*base_chan, 8*base_chan, (2,2), num_block=2)\n            
self.up1 = up_block(16*base_chan, 8*base_chan, scale=(2,2), num_block=2)\n\n        if '4' in block_list:\n            self.down4 = down_block_trans(8*base_chan, 16*base_chan, num_block=num_blocks[-1], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-1], dim_head=16*base_chan//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n        else:\n            self.down4 = down_block(8*base_chan, 16*base_chan, (2,2), num_block=2)\n\n\n        self.outc = nn.Conv2d(base_chan, num_classes, kernel_size=1, bias=True)\n\n        if aux_loss:\n            self.out1 = nn.Conv2d(8*base_chan, num_classes, kernel_size=1, bias=True)\n            self.out2 = nn.Conv2d(4*base_chan, num_classes, kernel_size=1, bias=True)\n            self.out3 = nn.Conv2d(2*base_chan, num_classes, kernel_size=1, bias=True)\n            \n\n\n    def forward(self, x):\n        \n        x1 = self.inc(x)\n        x2 = self.down1(x1)\n        x3 = self.down2(x2)\n        x4 = self.down3(x3)\n        x5 = self.down4(x4)\n        \n        if self.aux_loss:\n            out = self.up1(x5, x4)\n            out1 = F.interpolate(self.out1(out), size=x.shape[-2:], mode='bilinear', align_corners=True)\n\n            out = self.up2(out, x3)\n            out2 = F.interpolate(self.out2(out), size=x.shape[-2:], mode='bilinear', align_corners=True)\n\n            out = self.up3(out, x2)\n            out3 = F.interpolate(self.out3(out), size=x.shape[-2:], mode='bilinear', align_corners=True)\n\n            out = self.up4(out, x1)\n            out = self.outc(out)\n\n            return out, out3, out2, out1\n\n        else:\n            out = self.up1(x5, x4)\n            out = self.up2(out, x3)\n            out = self.up3(out, x2)\n\n            out = self.up4(out, x1)\n            out = self.outc(out)\n\n            return out\n\n\n\n        \n\n\n\nclass UTNet_Encoderonly(nn.Module):\n    \n    def __init__(self, in_chan, base_chan, num_classes=1, reduce_size=8, block_list='234', num_blocks=[1, 2, 4], projection='interp', num_heads=[2,4,8], attn_drop=0., proj_drop=0., bottleneck=False, maxpool=True, rel_pos=True, aux_loss=False):\n        super().__init__()\n\n        self.aux_loss = aux_loss\n        \n        self.inc = [BasicBlock(in_chan, base_chan)]\n        if '0' in block_list:\n            self.inc.append(BasicTransBlock(base_chan, heads=num_heads[-5], dim_head=base_chan//num_heads[-5], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos))\n        \n        else:\n            self.inc.append(BasicBlock(base_chan, base_chan))\n        self.inc = nn.Sequential(*self.inc)\n\n\n        if '1' in block_list:\n            self.down1 = down_block_trans(base_chan, 2*base_chan, num_block=num_blocks[-4], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-4], dim_head=2*base_chan//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n        else:\n            self.down1 = down_block(base_chan, 2*base_chan, (2,2), num_block=2)\n\n\n        if '2' in block_list:\n            self.down2 = down_block_trans(2*base_chan, 4*base_chan, num_block=num_blocks[-3], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-3], dim_head=4*base_chan//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n        else:\n            self.down2 = down_block(2*base_chan, 4*base_chan, (2, 
2), num_block=2)\n\n\n        if '3' in block_list:\n            self.down3 = down_block_trans(4*base_chan, 8*base_chan, num_block=num_blocks[-2], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-2], dim_head=8*base_chan//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n        else:\n            self.down3 = down_block(4*base_chan, 8*base_chan, (2,2), num_block=2)\n\n        if '4' in block_list:\n            self.down4 = down_block_trans(8*base_chan, 16*base_chan, num_block=num_blocks[-1], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-1], dim_head=16*base_chan//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)\n        else:\n            self.down4 = down_block(8*base_chan, 16*base_chan, (2,2), num_block=2)\n\n\n        self.up1 = up_block(16*base_chan, 8*base_chan, scale=(2,2), num_block=2)\n        self.up2 = up_block(8*base_chan, 4*base_chan, scale=(2,2), num_block=2)\n        self.up3 = up_block(4*base_chan, 2*base_chan, scale=(2,2), num_block=2)\n        self.up4 = up_block(2*base_chan, base_chan, scale=(2,2), num_block=2)\n\n        self.outc = nn.Conv2d(base_chan, num_classes, kernel_size=1, bias=True)\n\n        if aux_loss:\n            self.out1 = nn.Conv2d(8*base_chan, num_classes, kernel_size=1, bias=True)\n            self.out2 = nn.Conv2d(4*base_chan, num_classes, kernel_size=1, bias=True)\n            self.out3 = nn.Conv2d(2*base_chan, num_classes, kernel_size=1, bias=True)\n            \n\n\n    def forward(self, x):\n        \n        x1 = self.inc(x)\n        x2 = self.down1(x1)\n        x3 = self.down2(x2)\n        x4 = self.down3(x3)\n        x5 = self.down4(x4)\n        \n        if self.aux_loss:\n            out = self.up1(x5, x4)\n            out1 = F.interpolate(self.out1(out), size=x.shape[-2:], mode='bilinear', align_corners=True)\n\n            out = self.up2(out, x3)\n            out2 = F.interpolate(self.out2(out), size=x.shape[-2:], mode='bilinear', align_corners=True)\n\n            out = self.up3(out, x2)\n            out3 = F.interpolate(self.out3(out), size=x.shape[-2:], mode='bilinear', align_corners=True)\n\n            out = self.up4(out, x1)\n            out = self.outc(out)\n\n            return out, out3, out2, out1\n\n        else:\n            out = self.up1(x5, x4)\n            out = self.up2(out, x3)\n            out = self.up3(out, x2)\n\n            out = self.up4(out, x1)\n            out = self.outc(out)\n\n            return out\n\n\n"
  },
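# --- Usage sketch (not part of the original files) ---
# A minimal sketch of building UTNet with the defaults used in train_deep.py:
# block_list='1234' places the efficient self-attention blocks at the four
# lower resolutions while the highest-resolution stage stays convolutional.
# Assumes it is run from the repository root and that model/conv_trans_utils.py
# (imported by model/utnet.py) is present.
import torch
from model.utnet import UTNet

net = UTNet(
    in_chan=1, base_chan=32, num_classes=4,
    reduce_size=8, block_list='1234', num_blocks=[1, 1, 1, 1],
    num_heads=[4, 4, 4, 4], projection='interp',
    attn_drop=0.1, proj_drop=0.1, rel_pos=True,
    aux_loss=False, maxpool=True,
)

x = torch.randn(2, 1, 256, 256)   # 2D CMR slices cropped to 256x256
logits = net(x)                   # (2, 4, 256, 256); with aux_loss=True a 4-tuple of outputs is returned
print(logits.shape)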
  {
    "path": "train_deep.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import optim\n\nimport numpy as np\nfrom model.utnet import UTNet, UTNet_Encoderonly\n\nfrom dataset_domain import CMRDataset\n\nfrom torch.utils import data\nfrom losses import DiceLoss\nfrom utils.utils import *\nfrom utils import metrics\nfrom optparse import OptionParser\nimport SimpleITK as sitk\n\nfrom torch.utils.tensorboard import SummaryWriter\nimport time\nimport math\nimport os\nimport sys\nimport pdb\nimport warnings\nwarnings.filterwarnings(\"ignore\", category=UserWarning)\nDEBUG = False\n\ndef train_net(net, options):\n    \n    data_path = options.data_path\n\n    trainset = CMRDataset(data_path, mode='train', domain=options.domain, debug=DEBUG, scale=options.scale, rotate=options.rotate, crop_size=options.crop_size)\n    trainLoader = data.DataLoader(trainset, batch_size=options.batch_size, shuffle=True, num_workers=16)\n\n    testset_A = CMRDataset(data_path, mode='test', domain='A', debug=DEBUG, crop_size=options.crop_size)\n    testLoader_A = data.DataLoader(testset_A, batch_size=1, shuffle=False, num_workers=2)\n    testset_B = CMRDataset(data_path, mode='test', domain='B', debug=DEBUG, crop_size=options.crop_size)\n    testLoader_B = data.DataLoader(testset_B, batch_size=1, shuffle=False, num_workers=2)\n    testset_C = CMRDataset(data_path, mode='test', domain='C', debug=DEBUG, crop_size=options.crop_size)\n    testLoader_C = data.DataLoader(testset_C, batch_size=1, shuffle=False, num_workers=2)\n    testset_D = CMRDataset(data_path, mode='test', domain='D', debug=DEBUG, crop_size=options.crop_size)\n    testLoader_D = data.DataLoader(testset_D, batch_size=1, shuffle=False, num_workers=2)\n\n\n\n\n\n    writer = SummaryWriter(options.log_path + options.unique_name)\n\n    optimizer = optim.SGD(net.parameters(), lr=options.lr, momentum=0.9, weight_decay=options.weight_decay)\n\n    criterion = nn.CrossEntropyLoss(weight=torch.tensor(options.weight).cuda())\n    criterion_dl = DiceLoss()\n\n\n    best_dice = 0\n    for epoch in range(options.epochs):\n        print('Starting epoch {}/{}'.format(epoch+1, options.epochs))\n        epoch_loss = 0\n\n        exp_scheduler = exp_lr_scheduler_with_warmup(optimizer, init_lr=options.lr, epoch=epoch, warmup_epoch=5, max_epoch=options.epochs)\n\n        print('current lr:', exp_scheduler)\n\n        for i, (img, label) in enumerate(trainLoader, 0):\n\n            img = img.cuda()\n            label = label.cuda()\n\n            end = time.time()\n            net.train()\n\n            optimizer.zero_grad()\n            \n            result = net(img)\n            \n            loss = 0\n            \n            if isinstance(result, tuple) or isinstance(result, list):\n                for j in range(len(result)):\n                    loss += options.aux_weight[j] * (criterion(result[j], label.squeeze(1)) + criterion_dl(result[j], label))\n            else:\n                loss = criterion(result, label.squeeze(1)) + criterion_dl(result, label)\n\n\n            loss.backward()\n            optimizer.step()\n\n\n            epoch_loss += loss.item()\n            batch_time = time.time() - end\n            print('batch loss: %.5f, batch_time:%.5f'%(loss.item(), batch_time))\n        print('[epoch %d] epoch loss: %.5f'%(epoch+1, epoch_loss/(i+1)))\n\n        writer.add_scalar('Train/Loss', epoch_loss/(i+1), epoch+1)\n        writer.add_scalar('LR', exp_scheduler, epoch+1)\n\n        if os.path.isdir('%s%s/'%(options.cp_path, 
options.unique_name)):\n            pass\n        else:\n            os.mkdir('%s%s/'%(options.cp_path, options.unique_name))\n\n        if epoch % 20 == 0 or epoch > options.epochs-10:\n            torch.save(net.state_dict(), '%s%s/CP%d.pth'%(options.cp_path, options.unique_name, epoch))\n        \n        if (epoch+1) >90 or (epoch+1) % 10 == 0:\n            dice_list_A, ASD_list_A, HD_list_A = validation(net, testLoader_A, options)\n            log_evaluation_result(writer, dice_list_A, ASD_list_A, HD_list_A, 'A', epoch)\n            \n            dice_list_B, ASD_list_B, HD_list_B = validation(net, testLoader_B, options)\n            log_evaluation_result(writer, dice_list_B, ASD_list_B, HD_list_B, 'B', epoch)\n\n            dice_list_C, ASD_list_C, HD_list_C = validation(net, testLoader_C, options)\n            log_evaluation_result(writer, dice_list_C, ASD_list_C, HD_list_C, 'C', epoch)\n\n            dice_list_D, ASD_list_D, HD_list_D = validation(net, testLoader_D, options)\n            log_evaluation_result(writer, dice_list_D, ASD_list_D, HD_list_D, 'D', epoch)\n\n\n            AVG_dice_list = 20 * dice_list_A + 50 * dice_list_B + 50 * dice_list_C + 50 * dice_list_D\n            AVG_dice_list /= 170\n\n            AVG_ASD_list = 20 * ASD_list_A + 50 * ASD_list_B + 50 * ASD_list_C + 50 * ASD_list_D\n            AVG_ASD_list /= 170\n\n            AVG_HD_list = 20 * HD_list_A + 50 * HD_list_B + 50 * HD_list_C + 50 * HD_list_D\n            AVG_HD_list /= 170\n\n            log_evaluation_result(writer, AVG_dice_list, AVG_ASD_list, AVG_HD_list, 'mean', epoch)\n\n\n\n            if dice_list_A.mean() >= best_dice:\n                best_dice = dice_list_A.mean()\n                torch.save(net.state_dict(), '%s%s/best.pth'%(options.cp_path, options.unique_name))\n\n            print('save done')\n            print('dice: %.5f/best dice: %.5f'%(dice_list_A.mean(), best_dice))\n\n\ndef validation(net, test_loader, options):\n\n    net.eval()\n    \n    dice_list = np.zeros(3)\n    ASD_list = np.zeros(3)\n    HD_list = np.zeros(3)\n\n    counter = 0\n    with torch.no_grad():\n        for i, (data, label, spacing) in enumerate(test_loader):\n            \n            inputs, labels = data.float().cuda(), label.long().cuda()\n            inputs = inputs.permute(1, 0, 2, 3)\n            labels = labels.permute(1, 0, 2, 3)\n    \n            pred = net(inputs)\n            if options.model == 'FCN_Res50' or options.model == 'FCN_Res101':\n                pred = pred['out']\n            elif isinstance(pred, tuple):\n                pred = pred[0]\n            pred = F.softmax(pred, dim=1)\n    \n            _, label_pred = torch.max(pred, dim=1)\n            \n            tmp_ASD_list, tmp_HD_list = cal_distance(label_pred, labels, spacing)\n            ASD_list += np.clip(np.nan_to_num(tmp_ASD_list, nan=500), 0, 500)\n            HD_list += np.clip(np.nan_to_num(tmp_HD_list, nan=500), 0, 500)\n            \n            label_pred = label_pred.view(-1, 1)\n            label_true = labels.view(-1, 1)\n\n            dice, _, _ = cal_dice(label_pred, label_true, 4)\n        \n            dice_list += dice.cpu().numpy()[1:]\n            counter += 1\n    dice_list /= counter\n    avg_dice = dice_list.mean()\n    ASD_list /= counter\n    HD_list /= counter\n\n    return dice_list , ASD_list, HD_list\n\n\n\n\ndef cal_distance(label_pred, label_true, spacing):\n    label_pred = label_pred.squeeze(1).cpu().numpy()\n    label_true = label_true.squeeze(1).cpu().numpy()\n    spacing = 
spacing.numpy()[0]\n\n    ASD_list = np.zeros(3)\n    HD_list = np.zeros(3)\n\n    for i in range(3):\n        tmp_surface = metrics.compute_surface_distances(label_true==(i+1), label_pred==(i+1), spacing)\n        dis_gt_to_pred, dis_pred_to_gt = metrics.compute_average_surface_distance(tmp_surface)\n        ASD_list[i] = (dis_gt_to_pred + dis_pred_to_gt) / 2\n\n        HD = metrics.compute_robust_hausdorff(tmp_surface, 100)\n        HD_list[i] = HD\n\n    return ASD_list, HD_list\n\n\n\n\nif __name__ == '__main__':\n    parser = OptionParser()\n    def get_comma_separated_int_args(option, opt, value, parser):\n        value_list = value.split(',')\n        value_list = [int(i) for i in value_list]\n        setattr(parser.values, option.dest, value_list)\n\n    parser.add_option('-e', '--epochs', dest='epochs', default=150, type='int', help='number of epochs')\n    parser.add_option('-b', '--batch_size', dest='batch_size', default=32, type='int', help='batch size')\n    parser.add_option('-l', '--learning-rate', dest='lr', default=0.05, type='float', help='learning rate')\n    parser.add_option('-c', '--resume', type='str', dest='load', default=False, help='load pretrained model')\n    parser.add_option('-p', '--checkpoint-path', type='str', dest='cp_path', default='./checkpoint/', help='checkpoint path')\n    parser.add_option('--data_path', type='str', dest='data_path', default='/research/cbim/vast/yg397/vision_transformer/dataset/resampled_dataset/', help='dataset path')\n\n    parser.add_option('-o', '--log-path', type='str', dest='log_path', default='./log/', help='log path')\n    parser.add_option('-m', type='str', dest='model', default='UTNet', help='use which model')\n    parser.add_option('--num_class', type='int', dest='num_class', default=4, help='number of segmentation classes')\n    parser.add_option('--base_chan', type='int', dest='base_chan', default=32, help='number of channels of first expansion in UNet')\n    parser.add_option('-u', '--unique_name', type='str', dest='unique_name', default='test', help='unique experiment name')\n    parser.add_option('--rlt', type='float', dest='rlt', default=1, help='relation between CE/FL and dice')\n    parser.add_option('--weight', type='float', dest='weight',\n                      default=[0.5,1,1,1] , help='weight each class in loss function')\n    parser.add_option('--weight_decay', type='float', dest='weight_decay',\n                      default=0.0001)\n    parser.add_option('--scale', type='float', dest='scale', default=0.30)\n    parser.add_option('--rotate', type='float', dest='rotate', default=180)\n    parser.add_option('--crop_size', type='int', dest='crop_size', default=256)\n    parser.add_option('--domain', type='str', dest='domain', default='A')\n    parser.add_option('--aux_weight', type='float', dest='aux_weight', default=[1, 0.4, 0.2, 0.1])\n    parser.add_option('--reduce_size', dest='reduce_size', default=8, type='int')\n    parser.add_option('--block_list', dest='block_list', default='1234', type='str')\n    parser.add_option('--num_blocks', dest='num_blocks', default=[1,1,1,1], type='string', action='callback', callback=get_comma_separated_int_args)\n    parser.add_option('--aux_loss', dest='aux_loss', action='store_true', help='using aux loss for deep supervision')\n\n    \n    parser.add_option('--gpu', type='str', dest='gpu', default='0')\n    options, args = parser.parse_args()\n\n    os.environ['CUDA_VISIBLE_DEVICES'] = options.gpu\n\n    print('Using model:', options.model)\n\n    if options.model == 
'UTNet':\n        net = UTNet(1, options.base_chan, options.num_class, reduce_size=options.reduce_size, block_list=options.block_list, num_blocks=options.num_blocks, num_heads=[4,4,4,4], projection='interp', attn_drop=0.1, proj_drop=0.1, rel_pos=True, aux_loss=options.aux_loss, maxpool=True)\n    elif options.model == 'UTNet_encoder':\n        # Apply transformer blocks only in the encoder\n        net = UTNet_Encoderonly(1, options.base_chan, options.num_class, reduce_size=options.reduce_size, block_list=options.block_list, num_blocks=options.num_blocks, num_heads=[4,4,4,4], projection='interp', attn_drop=0.1, proj_drop=0.1, rel_pos=True, aux_loss=options.aux_loss, maxpool=True)\n    elif options.model =='TransUNet':\n        from model.transunet import VisionTransformer as ViT_seg\n        from model.transunet import CONFIGS as CONFIGS_ViT_seg\n        config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']\n        config_vit.n_classes = 4 \n        config_vit.n_skip = 3 \n        config_vit.patches.grid = (int(256/16), int(256/16))\n        net = ViT_seg(config_vit, img_size=256, num_classes=4)\n        #net.load_from(weights=np.load('./initmodel/R50+ViT-B_16.npz')) # uncomment this to use pretrain model download from TransUnet git repo\n\n    elif options.model == 'ResNet_UTNet':\n        from model.resnet_utnet import ResNet_UTNet\n        net = ResNet_UTNet(1, options.num_class, reduce_size=options.reduce_size, block_list=options.block_list, num_blocks=options.num_blocks, num_heads=[4,4,4,4], projection='interp', attn_drop=0.1, proj_drop=0.1, rel_pos=True)\n    \n    elif options.model == 'SwinUNet':\n        from model.swin_unet import SwinUnet, SwinUnet_config\n        config = SwinUnet_config()\n        net = SwinUnet(config, img_size=224, num_classes=options.num_class)\n        net.load_from('./initmodel/swin_tiny_patch4_window7_224.pth')\n\n\n    else:\n        raise NotImplementedError(options.model + \" has not been implemented\")\n    if options.load:\n        net.load_state_dict(torch.load(options.load))\n        print('Model loaded from {}'.format(options.load))\n    \n    param_num = sum(p.numel() for p in net.parameters() if p.requires_grad)\n    \n    print(net)\n    print(param_num)\n    \n    net.cuda()\n    \n    train_net(net, options)\n\n    print('done')\n\n    sys.exit(0)\n"
  },
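# --- Usage note (not part of the original files) ---
# train_deep.py is driven entirely by the optparse flags defined in its
# __main__ block above. A representative invocation for the UTNet setting
# (with a hypothetical experiment name 'utnet_d1234') could look like:
#
#   python train_deep.py -m UTNet -u utnet_d1234 --batch_size 16 -l 0.05 \
#       --block_list 1234 --num_blocks 1,1,1,1 --crop_size 256 --domain A --gpu 0
#
# --num_blocks is parsed by the comma-separated callback into [1, 1, 1, 1];
# checkpoints are written under ./checkpoint/utnet_d1234/ and TensorBoard logs
# under ./log/utnet_d1234, following the cp_path/log_path defaults above.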
  {
    "path": "utils/__init__.py",
    "content": ""
  },
  {
    "path": "utils/lookup_tables.py",
    "content": "# Copyright 2018 Google Inc. All Rights Reserved.\n\n#\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n\n# you may not use this file except in compliance with the License.\n\n# You may obtain a copy of the License at\n\n#\n\n#      http://www.apache.org/licenses/LICENSE-2.0\n\n#\n\n# Unless required by applicable law or agreed to in writing, software\n\n# distributed under the License is distributed on an \"AS-IS\" BASIS,\n\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n\n# See the License for the specific language governing permissions and\n\n# limitations under the License.\n\n\"\"\"Lookup tables used by surface distance metrics.\"\"\"\n\n\n\nfrom __future__ import absolute_import\n\nfrom __future__ import division\n\nfrom __future__ import print_function\n\n\n\nimport math\n\nimport numpy as np\n\n\n\nENCODE_NEIGHBOURHOOD_3D_KERNEL = np.array([[[128, 64], [32, 16]], [[8, 4],\n\n                                                                   [2, 1]]])\n\n\n\n# _NEIGHBOUR_CODE_TO_NORMALS is a lookup table.\n\n# For every binary neighbour code\n\n# (2x2x2 neighbourhood = 8 neighbours = 8 bits = 256 codes)\n\n# it contains the surface normals of the triangles (called \"surfel\" for\n\n# \"surface element\" in the following). The length of the normal\n\n# vector encodes the surfel area.\n\n#\n\n# created using the marching_cube algorithm\n\n# see e.g. https://en.wikipedia.org/wiki/Marching_cubes\n\n# pylint: disable=line-too-long\n\n_NEIGHBOUR_CODE_TO_NORMALS = [\n\n    [[0, 0, 0]],\n\n    [[0.125, 0.125, 0.125]],\n\n    [[-0.125, -0.125, 0.125]],\n\n    [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],\n\n    [[0.125, -0.125, 0.125]],\n\n    [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]],\n\n    [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n\n    [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],\n\n    [[-0.125, 0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125]],\n\n    [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]],\n\n    [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],\n\n    [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],\n\n    [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125]],\n\n    [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],\n\n    [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0]],\n\n    [[0.125, -0.125, -0.125]],\n\n    [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25]],\n\n    [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],\n\n    [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],\n\n    [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],\n\n    [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125]],\n\n    [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],\n\n    [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],\n\n    [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125]],\n\n    [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125]],\n\n    [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125]],\n\n    [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]],\n\n    [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],\n\n    [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],\n\n    [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]],\n\n    [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25]],\n\n    [[0.125, -0.125, 
0.125]],\n\n    [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],\n\n    [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25]],\n\n    [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]],\n\n    [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125]],\n\n    [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]],\n\n    [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],\n\n    [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],\n\n    [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]],\n\n    [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],\n\n    [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]],\n\n    [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]],\n\n    [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],\n\n    [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]],\n\n    [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],\n\n    [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25]],\n\n    [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0]],\n\n    [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125]],\n\n    [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],\n\n    [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],\n\n    [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],\n\n    [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]],\n\n    [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n\n    [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],\n\n    [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],\n\n    [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],\n\n    [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]],\n\n    [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]],\n\n    [[-0.125, -0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]],\n\n    [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n\n    [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],\n\n    [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25]],\n\n    [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]],\n\n    [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125]],\n\n    [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]],\n\n    [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],\n\n    [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]],\n\n    [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n\n    [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]],\n\n    [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]],\n\n    [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]],\n\n    [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]],\n\n    [[-0.25, 0.0, 
0.25], [0.25, 0.0, -0.25]],\n\n    [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],\n\n    [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],\n\n    [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]],\n\n    [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],\n\n    [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5]],\n\n    [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],\n\n    [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],\n\n    [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125]],\n\n    [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],\n\n    [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],\n\n    [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]],\n\n    [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],\n\n    [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],\n\n    [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],\n\n    [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],\n\n    [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125]],\n\n    [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],\n\n    [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125]],\n\n    [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n\n    [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],\n\n    [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],\n\n    [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],\n\n    [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125]],\n\n    [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125]],\n\n    [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],\n\n    [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],\n\n    [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],\n\n    [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]],\n\n    [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],\n\n    [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],\n\n    [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],\n\n    [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],\n\n    [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],\n\n    [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]],\n\n    
[[0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]],\n\n    [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125]],\n\n    [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],\n\n    [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],\n\n    [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],\n\n    [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],\n\n    [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],\n\n    [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]],\n\n    [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],\n\n    [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],\n\n    [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],\n\n    [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125]],\n\n    [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125]],\n\n    [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],\n\n    [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],\n\n    [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n\n    [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],\n\n    [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n\n    [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125]],\n\n    [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],\n\n    [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125]],\n\n    [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],\n\n    [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],\n\n    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],\n\n    [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],\n\n    [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]],\n\n    [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]],\n\n    [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],\n\n    [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],\n\n    [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],\n\n    [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],\n\n    [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5]],\n\n    [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],\n\n    [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]],\n\n    [[-0.125, 
-0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],\n\n    [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],\n\n    [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],\n\n    [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]],\n\n    [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]],\n\n    [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]],\n\n    [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]],\n\n    [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n\n    [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]],\n\n    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],\n\n    [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],\n\n    [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]],\n\n    [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125]],\n\n    [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]],\n\n    [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25]],\n\n    [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],\n\n    [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]],\n\n    [[-0.125, -0.125, 0.125]],\n\n    [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]],\n\n    [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]],\n\n    [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125]],\n\n    [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],\n\n    [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],\n\n    [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],\n\n    [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n\n    [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]],\n\n    [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],\n\n    [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],\n\n    [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],\n\n    [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125]],\n\n    [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0]],\n\n    [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25]],\n\n    [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],\n\n    [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]],\n\n    [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],\n\n    [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]],\n\n    [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]],\n\n    [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],\n\n    [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]],\n\n    [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],\n\n    [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]],\n\n    [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125]],\n\n    [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]],\n\n    [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[0.0, -0.5, 0.0], 
[0.125, 0.125, -0.125], [0.25, 0.25, -0.25]],\n\n    [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],\n\n    [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],\n\n    [[0.125, -0.125, 0.125]],\n\n    [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25]],\n\n    [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]],\n\n    [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],\n\n    [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],\n\n    [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]],\n\n    [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125]],\n\n    [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125]],\n\n    [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125]],\n\n    [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],\n\n    [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],\n\n    [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125]],\n\n    [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],\n\n    [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],\n\n    [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],\n\n    [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25]],\n\n    [[0.125, -0.125, -0.125]],\n\n    [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0]],\n\n    [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],\n\n    [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125]],\n\n    [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],\n\n    [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],\n\n    [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]],\n\n    [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125]],\n\n    [[-0.125, 0.125, 0.125]],\n\n    [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],\n\n    [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n\n    [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]],\n\n    [[0.125, -0.125, 0.125]],\n\n    [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],\n\n    [[-0.125, -0.125, 0.125]],\n\n    [[0.125, 0.125, 0.125]],\n\n    [[0, 0, 0]]]\n\n# pylint: enable=line-too-long\n\n\n\n\n\ndef create_table_neighbour_code_to_surface_area(spacing_mm):\n\n  \"\"\"Returns an array mapping neighbourhood code to the surface elements area.\n\n  Note that the normals encode the initial surface area. This function computes\n\n  the area corresponding to the given `spacing_mm`.\n\n  Args:\n\n    spacing_mm: 3-element list-like structure. 
Voxel spacing in x0, x1 and x2\n\n      directions.\n\n  \"\"\"\n\n  # compute the area for all 256 possible surface elements\n\n  # (given a 2x2x2 neighbourhood) according to the spacing_mm\n\n  neighbour_code_to_surface_area = np.zeros([256])\n\n  for code in range(256):\n\n    normals = np.array(_NEIGHBOUR_CODE_TO_NORMALS[code])\n\n    sum_area = 0\n\n    for normal_idx in range(normals.shape[0]):\n\n      # normal vector\n\n      n = np.zeros([3])\n\n      n[0] = normals[normal_idx, 0] * spacing_mm[1] * spacing_mm[2]\n\n      n[1] = normals[normal_idx, 1] * spacing_mm[0] * spacing_mm[2]\n\n      n[2] = normals[normal_idx, 2] * spacing_mm[0] * spacing_mm[1]\n\n      area = np.linalg.norm(n)\n\n      sum_area += area\n\n    neighbour_code_to_surface_area[code] = sum_area\n\n\n\n  return neighbour_code_to_surface_area\n\n\n\n\n\n# In the neighbourhood, points are ordered: top left, top right, bottom left,\n\n# bottom right.\n\nENCODE_NEIGHBOURHOOD_2D_KERNEL = np.array([[8, 4], [2, 1]])\n\n\n\n\n\ndef create_table_neighbour_code_to_contour_length(spacing_mm):\n\n  \"\"\"Returns an array mapping neighbourhood code to the contour length.\n\n  For the list of possible cases and their figures, see page 38 from:\n\n  https://nccastaff.bournemouth.ac.uk/jmacey/MastersProjects/MSc14/06/thesis.pdf\n\n  In 2D, each point has 4 neighbors. Thus, there are 16 configurations. A\n\n  configuration is encoded with '1' meaning \"inside the object\" and '0' \"outside\n\n  the object\". The points are ordered: top left, top right, bottom left, bottom\n\n  right.\n\n  The x0 axis is assumed vertical downward, and the x1 axis is horizontal to the\n\n  right:\n\n   (0, 0) --> (0, 1)\n\n     |\n\n   (1, 0)\n\n  Args:\n\n    spacing_mm: 2-element list-like structure. Voxel spacing in x0 and x1\n\n      directions.\n\n  \"\"\"\n\n  neighbour_code_to_contour_length = np.zeros([16])\n\n\n\n  vertical = spacing_mm[0]\n\n  horizontal = spacing_mm[1]\n\n  diag = 0.5 * math.sqrt(spacing_mm[0]**2 + spacing_mm[1]**2)\n\n  # pyformat: disable\n\n  neighbour_code_to_contour_length[int(\"00\"\n\n                                       \"01\", 2)] = diag\n\n\n\n  neighbour_code_to_contour_length[int(\"00\"\n\n                                       \"10\", 2)] = diag\n\n\n\n  neighbour_code_to_contour_length[int(\"00\"\n\n                                       \"11\", 2)] = horizontal\n\n\n\n  neighbour_code_to_contour_length[int(\"01\"\n\n                                       \"00\", 2)] = diag\n\n\n\n  neighbour_code_to_contour_length[int(\"01\"\n\n                                       \"01\", 2)] = vertical\n\n\n\n  neighbour_code_to_contour_length[int(\"01\"\n\n                                       \"10\", 2)] = 2*diag\n\n\n\n  neighbour_code_to_contour_length[int(\"01\"\n\n                                       \"11\", 2)] = diag\n\n\n\n  neighbour_code_to_contour_length[int(\"10\"\n\n                                       \"00\", 2)] = diag\n\n\n\n  neighbour_code_to_contour_length[int(\"10\"\n\n                                       \"01\", 2)] = 2*diag\n\n\n\n  neighbour_code_to_contour_length[int(\"10\"\n\n                                       \"10\", 2)] = vertical\n\n\n\n  neighbour_code_to_contour_length[int(\"10\"\n\n                                       \"11\", 2)] = diag\n\n\n\n  neighbour_code_to_contour_length[int(\"11\"\n\n                                       \"00\", 2)] = horizontal\n\n\n\n  neighbour_code_to_contour_length[int(\"11\"\n\n                                       \"01\", 2)] = diag\n\n\n\n  neighbour_code_to_contour_length[int(\"11\"\n\n                                       \"10\", 2)] = diag\n\n  # pyformat: enable\n\n\n\n  return neighbour_code_to_contour_length\n
  },
  {
    "path": "utils/metrics.py",
    "content": "# Copyright 2018 Google Inc. All Rights Reserved.\n\n#\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n\n# you may not use this file except in compliance with the License.\n\n# You may obtain a copy of the License at\n\n#\n\n#      http://www.apache.org/licenses/LICENSE-2.0\n\n#\n\n# Unless required by applicable law or agreed to in writing, software\n\n# distributed under the License is distributed on an \"AS-IS\" BASIS,\n\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n\n# See the License for the specific language governing permissions and\n\n# limitations under the License.\n\n\"\"\"Module exposing surface distance based measures.\"\"\"\n\n\n\nfrom __future__ import absolute_import\n\nfrom __future__ import division\n\nfrom __future__ import print_function\n\n\n\nfrom . import lookup_tables  # pylint: disable=relative-beyond-top-level\n\nimport numpy as np\n\nfrom scipy import ndimage\n\n\n\n\n\ndef _assert_is_numpy_array(name, array):\n\n  \"\"\"Raises an exception if `array` is not a numpy array.\"\"\"\n\n  if not isinstance(array, np.ndarray):\n\n    raise ValueError(\"The argument {!r} should be a numpy array, not a \"\n\n                     \"{}\".format(name, type(array)))\n\n\n\n\n\ndef _check_nd_numpy_array(name, array, num_dims):\n\n  \"\"\"Raises an exception if `array` is not a `num_dims`-D numpy array.\"\"\"\n\n  if len(array.shape) != num_dims:\n\n    raise ValueError(\"The argument {!r} should be a {}D array, not of \"\n\n                     \"shape {}\".format(name, num_dims, array.shape))\n\n\n\n\n\ndef _check_2d_numpy_array(name, array):\n\n  _check_nd_numpy_array(name, array, num_dims=2)\n\n\n\n\n\ndef _check_3d_numpy_array(name, array):\n\n  _check_nd_numpy_array(name, array, num_dims=3)\n\n\n\n\n\ndef _assert_is_bool_numpy_array(name, array):\n\n  _assert_is_numpy_array(name, array)\n\n  if array.dtype != np.bool:\n\n    raise ValueError(\"The argument {!r} should be a numpy array of type bool, \"\n\n                     \"not {}\".format(name, array.dtype))\n\n\n\n\n\ndef _compute_bounding_box(mask):\n\n  \"\"\"Computes the bounding box of the mask.\n\n  This function generalizes to an arbitrary number of dimensions greater than\n\n  or equal to 1.\n\n  Args:\n\n    mask: The 2D or 3D numpy mask, where '0' means background and non-zero means\n\n      foreground.\n\n  Returns:\n\n    A tuple:\n\n     - The coordinates of the first point of the bounding box (smallest on all\n\n       axes), or `None` if the mask contains only zeros.\n\n     - The coordinates of the second point of the bounding box (greatest on all\n\n       axes), or `None` if the mask contains only zeros.\n\n  \"\"\"\n\n  num_dims = len(mask.shape)\n\n  bbox_min = np.zeros(num_dims, np.int64)\n\n  bbox_max = np.zeros(num_dims, np.int64)\n\n\n\n  # max projection to the x0-axis\n\n  proj_0 = np.amax(mask, axis=tuple(range(num_dims))[1:])\n\n  idx_nonzero_0 = np.nonzero(proj_0)[0]\n\n  if len(idx_nonzero_0) == 0:  # pylint: disable=g-explicit-length-test\n\n    return None, None\n\n\n\n  bbox_min[0] = np.min(idx_nonzero_0)\n\n  bbox_max[0] = np.max(idx_nonzero_0)\n\n\n\n  # max projection to the i-th-axis for i in {1, ..., num_dims - 1}\n\n  for axis in range(1, num_dims):\n\n    max_over_axes = list(range(num_dims))  # Python 3 compatible\n\n    max_over_axes.pop(axis)  # Remove the i-th dimension from the max\n\n    max_over_axes = tuple(max_over_axes)  # numpy expects a tuple of ints\n\n    proj = np.amax(mask, axis=max_over_axes)\n\n    idx_nonzero = np.nonzero(proj)[0]\n\n    bbox_min[axis] = np.min(idx_nonzero)\n\n    bbox_max[axis] = np.max(idx_nonzero)\n\n\n\n  return bbox_min, bbox_max\n\n\n\n\n\ndef _crop_to_bounding_box(mask, bbox_min, bbox_max):\n\n  \"\"\"Crops a 2D or 3D mask to the bounding box specified by `bbox_{min,max}`.\"\"\"\n\n  # we need to zeropad the cropped region with 1 voxel at the lower,\n\n  # the right (and the back on 3D) sides. This is required to obtain the\n\n  # \"full\" convolution result with the 2x2 (or 2x2x2 in 3D) kernel.\n\n  # TODO: This is correct only if the object is interior to the\n\n  # bounding box.\n\n  cropmask = np.zeros((bbox_max - bbox_min) + 2, np.uint8)\n\n\n\n  num_dims = len(mask.shape)\n\n  # pyformat: disable\n\n  if num_dims == 2:\n\n    cropmask[0:-1, 0:-1] = mask[bbox_min[0]:bbox_max[0] + 1,\n\n                                bbox_min[1]:bbox_max[1] + 1]\n\n  elif num_dims == 3:\n\n    cropmask[0:-1, 0:-1, 0:-1] = mask[bbox_min[0]:bbox_max[0] + 1,\n\n                                      bbox_min[1]:bbox_max[1] + 1,\n\n                                      bbox_min[2]:bbox_max[2] + 1]\n\n  # pyformat: enable\n\n  else:\n\n    assert False\n\n\n\n  return cropmask\n\n\n\n\n\ndef _sort_distances_surfels(distances, surfel_areas):\n\n  \"\"\"Sorts the two lists with respect to the tuple of (distance, surfel_area).\n\n\n\n  Args:\n\n    distances: The distances from A to B (e.g. `distances_gt_to_pred`).\n\n    surfel_areas: The surfel areas for A (e.g. `surfel_areas_gt`).\n\n\n\n  Returns:\n\n    A tuple of the sorted (distances, surfel_areas).\n\n  \"\"\"\n\n  sorted_surfels = np.array(sorted(zip(distances, surfel_areas)))\n\n  return sorted_surfels[:, 0], sorted_surfels[:, 1]\n\n\n\n\n\ndef compute_surface_distances(mask_gt,\n\n                              mask_pred,\n\n                              spacing_mm):\n\n  \"\"\"Computes closest distances from all surface points to the other surface.\n\n\n\n  This function can be applied to 2D or 3D tensors. For 2D, both masks must be\n\n  2D and `spacing_mm` must be a 2-element list. For 3D, both masks must be 3D\n\n  and `spacing_mm` must be a 3-element list. The description is done for the 2D\n\n  case, and the formulation for the 3D case is presented in parentheses,\n\n  introduced by \"resp.\".\n\n\n\n  Finds all contour elements (resp. surface elements \"surfels\" in 3D) in the\n\n  ground truth mask `mask_gt` and the predicted mask `mask_pred`, computes their\n\n  length in mm (resp. area in mm^2) and the distance to the closest point on the\n\n  other contour (resp. surface). It returns two sorted lists of distances\n\n  together with the corresponding contour lengths (resp. surfel areas). If one\n\n  of the masks is empty, the corresponding lists are empty and all distances in\n\n  the other list are `inf`.\n\n\n\n  Args:\n\n    mask_gt: 2-dim (resp. 3-dim) bool Numpy array. The ground truth mask.\n\n    mask_pred: 2-dim (resp. 3-dim) bool Numpy array. The predicted mask.\n\n    spacing_mm: 2-element (resp. 3-element) list-like structure. Voxel spacing\n\n      in x0 and x1 (resp. x0, x1 and x2) directions.\n\n\n\n  Returns:\n\n    A dict with:\n\n    \"distances_gt_to_pred\": 1-dim numpy array of type float. The distances in mm\n\n        from all ground truth surface elements to the predicted surface,\n\n        sorted from smallest to largest.\n\n    \"distances_pred_to_gt\": 1-dim numpy array of type float. The distances in mm\n\n        from all predicted surface elements to the ground truth surface,\n\n        sorted from smallest to largest.\n\n    \"surfel_areas_gt\": 1-dim numpy array of type float. The length of the\n\n      ground truth contours in mm (resp. the surface elements area in\n\n      mm^2) in the same order as distances_gt_to_pred.\n\n    \"surfel_areas_pred\": 1-dim numpy array of type float. The length of the\n\n      predicted contours in mm (resp. the surface elements area in\n\n      mm^2) in the same order as distances_pred_to_gt.\n\n\n\n  Raises:\n\n    ValueError: If the masks and the `spacing_mm` arguments are of incompatible\n\n      shape or type. Or if the masks are not 2D or 3D.\n\n  \"\"\"\n\n  # The terms used in this function are for the 3D case. In particular, surface\n\n  # in 2D stands for contours in 3D. The surface elements in 3D correspond to\n\n  # the line elements in 2D.\n\n\n\n  _assert_is_bool_numpy_array(\"mask_gt\", mask_gt)\n\n  _assert_is_bool_numpy_array(\"mask_pred\", mask_pred)\n\n\n\n  if not len(mask_gt.shape) == len(mask_pred.shape) == len(spacing_mm):\n\n    raise ValueError(\"The arguments must be of compatible shape. Got mask_gt \"\n\n                     \"with {} dimensions ({}) and mask_pred with {} dimensions \"\n\n                     \"({}), while the spacing_mm was {} elements.\".format(\n\n                         len(mask_gt.shape),\n\n                         mask_gt.shape, len(mask_pred.shape), mask_pred.shape,\n\n                         len(spacing_mm)))\n\n\n\n  num_dims = len(spacing_mm)\n\n  if num_dims == 2:\n\n    _check_2d_numpy_array(\"mask_gt\", mask_gt)\n\n    _check_2d_numpy_array(\"mask_pred\", mask_pred)\n\n\n\n    # compute the area for all 16 possible surface elements\n\n    # (given a 2x2 neighbourhood) according to the spacing_mm\n\n    neighbour_code_to_surface_area = (\n\n        lookup_tables.create_table_neighbour_code_to_contour_length(spacing_mm))\n\n    kernel = lookup_tables.ENCODE_NEIGHBOURHOOD_2D_KERNEL\n\n    full_true_neighbours = 0b1111\n\n  elif num_dims == 3:\n\n    _check_3d_numpy_array(\"mask_gt\", mask_gt)\n\n    _check_3d_numpy_array(\"mask_pred\", mask_pred)\n\n\n\n    # compute the area for all 256 possible surface elements\n\n    # (given a 2x2x2 neighbourhood) according to the spacing_mm\n\n    neighbour_code_to_surface_area = (\n\n        lookup_tables.create_table_neighbour_code_to_surface_area(spacing_mm))\n\n    kernel = lookup_tables.ENCODE_NEIGHBOURHOOD_3D_KERNEL\n\n    full_true_neighbours = 0b11111111\n\n  else:\n\n    raise ValueError(\"Only 2D and 3D masks are supported, not \"\n\n                     \"{}D.\".format(num_dims))\n\n\n\n  # compute the bounding box of the masks to trim the volume to the smallest\n\n  # possible processing subvolume\n\n  bbox_min, bbox_max = _compute_bounding_box(mask_gt | mask_pred)\n\n  # Both the min/max bbox are None at the same time, so we only check one.\n\n  if bbox_min is None:\n\n    return {\n\n        \"distances_gt_to_pred\": np.array([]),\n\n        \"distances_pred_to_gt\": np.array([]),\n\n        \"surfel_areas_gt\": np.array([]),\n\n        \"surfel_areas_pred\": np.array([]),\n\n    }\n\n\n\n  # crop the processing subvolume.\n\n  cropmask_gt = _crop_to_bounding_box(mask_gt, bbox_min, bbox_max)\n\n  cropmask_pred = _crop_to_bounding_box(mask_pred, bbox_min, bbox_max)\n\n\n\n  # compute the neighbour code (local binary pattern) for each voxel\n\n  # the resulting arrays are spatially shifted by minus half a voxel 
in each\n\n  # axis.\n\n  # i.e. the points are located at the corners of the original voxels\n\n  neighbour_code_map_gt = ndimage.filters.correlate(\n\n      cropmask_gt.astype(np.uint8), kernel, mode=\"constant\", cval=0)\n\n  neighbour_code_map_pred = ndimage.filters.correlate(\n\n      cropmask_pred.astype(np.uint8), kernel, mode=\"constant\", cval=0)\n\n\n\n  # create masks with the surface voxels\n\n  borders_gt = ((neighbour_code_map_gt != 0) &\n\n                (neighbour_code_map_gt != full_true_neighbours))\n\n  borders_pred = ((neighbour_code_map_pred != 0) &\n\n                  (neighbour_code_map_pred != full_true_neighbours))\n\n\n\n  # compute the distance transform (closest distance of each voxel to the\n\n  # surface voxels)\n\n  if borders_gt.any():\n\n    distmap_gt = ndimage.morphology.distance_transform_edt(\n\n        ~borders_gt, sampling=spacing_mm)\n\n  else:\n\n    distmap_gt = np.Inf * np.ones(borders_gt.shape)\n\n\n\n  if borders_pred.any():\n\n    distmap_pred = ndimage.morphology.distance_transform_edt(\n\n        ~borders_pred, sampling=spacing_mm)\n\n  else:\n\n    distmap_pred = np.Inf * np.ones(borders_pred.shape)\n\n\n\n  # compute the area of each surface element\n\n  surface_area_map_gt = neighbour_code_to_surface_area[neighbour_code_map_gt]\n\n  surface_area_map_pred = neighbour_code_to_surface_area[\n\n      neighbour_code_map_pred]\n\n\n\n  # create a list of all surface elements with distance and area\n\n  distances_gt_to_pred = distmap_pred[borders_gt]\n\n  distances_pred_to_gt = distmap_gt[borders_pred]\n\n  surfel_areas_gt = surface_area_map_gt[borders_gt]\n\n  surfel_areas_pred = surface_area_map_pred[borders_pred]\n\n\n\n  # sort them by distance\n\n  if distances_gt_to_pred.shape != (0,):\n\n    distances_gt_to_pred, surfel_areas_gt = _sort_distances_surfels(\n\n        distances_gt_to_pred, surfel_areas_gt)\n\n\n\n  if distances_pred_to_gt.shape != (0,):\n\n    distances_pred_to_gt, surfel_areas_pred = _sort_distances_surfels(\n\n        distances_pred_to_gt, surfel_areas_pred)\n\n\n\n  return {\n\n      \"distances_gt_to_pred\": distances_gt_to_pred,\n\n      \"distances_pred_to_gt\": distances_pred_to_gt,\n\n      \"surfel_areas_gt\": surfel_areas_gt,\n\n      \"surfel_areas_pred\": surfel_areas_pred,\n\n  }\n\n\n\n\n\ndef compute_average_surface_distance(surface_distances):\n\n  \"\"\"Returns the average surface distance.\n\n\n\n  Computes the average surface distances by correctly taking the area of each\n\n  surface element into account. Call compute_surface_distances(...) 
before, to\n\n  obtain the `surface_distances` dict.\n\n\n\n  Args:\n\n    surface_distances: dict with \"distances_gt_to_pred\", \"distances_pred_to_gt\"\n\n    \"surfel_areas_gt\", \"surfel_areas_pred\" created by\n\n    compute_surface_distances()\n\n\n\n  Returns:\n\n    A tuple with two float values:\n\n      - the average distance (in mm) from the ground truth surface to the\n\n        predicted surface\n\n      - the average distance from the predicted surface to the ground truth\n\n        surface.\n\n  \"\"\"\n\n  distances_gt_to_pred = surface_distances[\"distances_gt_to_pred\"]\n\n  distances_pred_to_gt = surface_distances[\"distances_pred_to_gt\"]\n\n  surfel_areas_gt = surface_distances[\"surfel_areas_gt\"]\n\n  surfel_areas_pred = surface_distances[\"surfel_areas_pred\"]\n\n  average_distance_gt_to_pred = (\n\n      np.sum(distances_gt_to_pred * surfel_areas_gt) / np.sum(surfel_areas_gt))\n\n  average_distance_pred_to_gt = (\n\n      np.sum(distances_pred_to_gt * surfel_areas_pred) /\n\n      np.sum(surfel_areas_pred))\n\n  return (average_distance_gt_to_pred, average_distance_pred_to_gt)\n\n\n\n\n\ndef compute_robust_hausdorff(surface_distances, percent):\n\n  \"\"\"Computes the robust Hausdorff distance.\n\n\n\n  Computes the robust Hausdorff distance. \"Robust\", because it uses the\n\n  `percent` percentile of the distances instead of the maximum distance. The\n\n  percentage is computed by correctly taking the area of each surface element\n\n  into account.\n\n\n\n  Args:\n\n    surface_distances: dict with \"distances_gt_to_pred\", \"distances_pred_to_gt\"\n\n      \"surfel_areas_gt\", \"surfel_areas_pred\" created by\n\n      compute_surface_distances()\n\n    percent: a float value between 0 and 100.\n\n\n\n  Returns:\n\n    a float value. The robust Hausdorff distance in mm.\n\n  \"\"\"\n\n  distances_gt_to_pred = surface_distances[\"distances_gt_to_pred\"]\n\n  distances_pred_to_gt = surface_distances[\"distances_pred_to_gt\"]\n\n  surfel_areas_gt = surface_distances[\"surfel_areas_gt\"]\n\n  surfel_areas_pred = surface_distances[\"surfel_areas_pred\"]\n\n  if len(distances_gt_to_pred) > 0:  # pylint: disable=g-explicit-length-test\n\n    surfel_areas_cum_gt = np.cumsum(surfel_areas_gt) / np.sum(surfel_areas_gt)\n\n    idx = np.searchsorted(surfel_areas_cum_gt, percent/100.0)\n\n    perc_distance_gt_to_pred = distances_gt_to_pred[\n\n        min(idx, len(distances_gt_to_pred)-1)]\n\n  else:\n\n    perc_distance_gt_to_pred = np.Inf\n\n\n\n  if len(distances_pred_to_gt) > 0:  # pylint: disable=g-explicit-length-test\n\n    surfel_areas_cum_pred = (np.cumsum(surfel_areas_pred) /\n\n                             np.sum(surfel_areas_pred))\n\n    idx = np.searchsorted(surfel_areas_cum_pred, percent/100.0)\n\n    perc_distance_pred_to_gt = distances_pred_to_gt[\n\n        min(idx, len(distances_pred_to_gt)-1)]\n\n  else:\n\n    perc_distance_pred_to_gt = np.Inf\n\n\n\n  return max(perc_distance_gt_to_pred, perc_distance_pred_to_gt)\n\n\n\n\n\ndef compute_surface_overlap_at_tolerance(surface_distances, tolerance_mm):\n\n  \"\"\"Computes the overlap of the surfaces at a specified tolerance.\n\n\n\n  Computes the overlap of the ground truth surface with the predicted surface\n\n  and vice versa allowing a specified tolerance (maximum surface-to-surface\n\n  distance that is regarded as overlapping). 
The overlapping fraction is\n\n  computed by correctly taking the area of each surface element into account.\n\n\n\n  Args:\n\n    surface_distances: dict with \"distances_gt_to_pred\", \"distances_pred_to_gt\"\n\n      \"surfel_areas_gt\", \"surfel_areas_pred\" created by\n\n      compute_surface_distances()\n\n    tolerance_mm: a float value. The tolerance in mm\n\n\n\n  Returns:\n\n    A tuple of two float values. The overlap fraction in [0.0, 1.0] of the\n\n    ground truth surface with the predicted surface and vice versa.\n\n  \"\"\"\n\n  distances_gt_to_pred = surface_distances[\"distances_gt_to_pred\"]\n\n  distances_pred_to_gt = surface_distances[\"distances_pred_to_gt\"]\n\n  surfel_areas_gt = surface_distances[\"surfel_areas_gt\"]\n\n  surfel_areas_pred = surface_distances[\"surfel_areas_pred\"]\n\n  rel_overlap_gt = (\n\n      np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm]) /\n\n      np.sum(surfel_areas_gt))\n\n  rel_overlap_pred = (\n\n      np.sum(surfel_areas_pred[distances_pred_to_gt <= tolerance_mm]) /\n\n      np.sum(surfel_areas_pred))\n\n  return (rel_overlap_gt, rel_overlap_pred)\n\n\n\n\n\ndef compute_surface_dice_at_tolerance(surface_distances, tolerance_mm):\n\n  \"\"\"Computes the _surface_ DICE coefficient at a specified tolerance.\n\n\n\n  Computes the _surface_ DICE coefficient at a specified tolerance. Not to be\n\n  confused with the standard _volumetric_ DICE coefficient. The surface DICE\n\n  measures the overlap of two surfaces instead of two volumes. A surface\n\n  element is counted as overlapping (or touching), when the closest distance to\n\n  the other surface is less than or equal to the specified tolerance. The DICE\n\n  coefficient is in the range from 0.0 (no overlap) to 1.0 (perfect overlap).\n\n\n\n  Args:\n\n    surface_distances: dict with \"distances_gt_to_pred\", \"distances_pred_to_gt\"\n\n      \"surfel_areas_gt\", \"surfel_areas_pred\" created by\n\n      compute_surface_distances()\n\n    tolerance_mm: a float value. The tolerance in mm\n\n\n\n  Returns:\n\n    A float value. The surface DICE coefficient in [0.0, 1.0].\n\n  \"\"\"\n\n  distances_gt_to_pred = surface_distances[\"distances_gt_to_pred\"]\n\n  distances_pred_to_gt = surface_distances[\"distances_pred_to_gt\"]\n\n  surfel_areas_gt = surface_distances[\"surfel_areas_gt\"]\n\n  surfel_areas_pred = surface_distances[\"surfel_areas_pred\"]\n\n  overlap_gt = np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm])\n\n  overlap_pred = np.sum(surfel_areas_pred[distances_pred_to_gt <= tolerance_mm])\n\n  surface_dice = (overlap_gt + overlap_pred) / (\n\n      np.sum(surfel_areas_gt) + np.sum(surfel_areas_pred))\n\n  return surface_dice\n\n\n\n\n\ndef compute_dice_coefficient(mask_gt, mask_pred):\n\n  \"\"\"Computes the soerensen-dice coefficient.\n\n\n\n  Computes the soerensen-dice coefficient between the ground truth mask `mask_gt`\n\n  and the predicted mask `mask_pred`.\n\n\n\n  Args:\n\n    mask_gt: 3-dim Numpy array of type bool. The ground truth mask.\n\n    mask_pred: 3-dim Numpy array of type bool. The predicted mask.\n\n\n\n  Returns:\n\n    the dice coefficient as a float. If both masks are empty, the result is NaN.\n\n  \"\"\"\n\n  volume_sum = mask_gt.sum() + mask_pred.sum()\n\n  if volume_sum == 0:\n\n    return np.NaN\n\n  volume_intersect = (mask_gt & mask_pred).sum()\n\n  return 2*volume_intersect / volume_sum\n
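\n\n\n\n\nif __name__ == \"__main__\":\n\n  # Minimal usage sketch added for illustration; it is not part of the original\n\n  # Google module. Run it from the package level (e.g. `python -m utils.metrics`,\n\n  # assuming `utils` is importable as a package) so the relative import above\n\n  # resolves. The masks and the spacing below are made-up placeholders.\n\n  _mask_gt = np.zeros((32, 32, 16), dtype=bool)\n\n  _mask_pred = np.zeros((32, 32, 16), dtype=bool)\n\n  _mask_gt[10:20, 10:20, 5:10] = True\n\n  _mask_pred[11:21, 10:20, 5:10] = True\n\n  _sd = compute_surface_distances(_mask_gt, _mask_pred,\n\n                                  spacing_mm=(1.2, 1.2, 10.0))\n\n  print(\"ASD (gt->pred, pred->gt):\", compute_average_surface_distance(_sd))\n\n  print(\"Robust HD95:\", compute_robust_hausdorff(_sd, percent=95))\n\n  print(\"Surface Dice @ 1 mm:\", compute_surface_dice_at_tolerance(_sd, tolerance_mm=1.0))\n\n  print(\"Volumetric Dice:\", compute_dice_coefficient(_mask_gt, _mask_pred))\n"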
  },
  {
    "path": "utils/utils.py",
    "content": "import torch\nimport numpy as np\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport SimpleITK as sitk\nimport pdb \n\n\ndef log_evaluation_result(writer, dice_list, ASD_list, HD_list, name, epoch):\n    \n    writer.add_scalar('Test_Dice/%s_AVG'%name, dice_list.mean(), epoch+1)\n    for idx in range(3):\n        writer.add_scalar('Test_Dice/%s_Dice%d'%(name, idx+1), dice_list[idx], epoch+1)\n    writer.add_scalar('Test_ASD/%s_AVG'%name, ASD_list.mean(), epoch+1)\n    for idx in range(3):\n        writer.add_scalar('Test_ASD/%s_ASD%d'%(name, idx+1), ASD_list[idx], epoch+1)\n    writer.add_scalar('Test_HD/%s_AVG'%name, HD_list.mean(), epoch+1)\n    for idx in range(3):\n        writer.add_scalar('Test_HD/%s_HD%d'%(name, idx+1), HD_list[idx], epoch+1)\n\n\ndef multistep_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch, lr_decay_epoch, max_epoch, gamma=0.1):\n\n    if epoch >= 0 and epoch <= warmup_epoch:\n        lr = init_lr * 2.718 ** (10*(float(epoch) / float(warmup_epoch) - 1.))\n        if epoch == warmup_epoch:\n            lr = init_lr\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = lr\n\n        return lr\n\n    flag = False\n    for i in range(len(lr_decay_epoch)):\n        if epoch == lr_decay_epoch[i]:\n            flag = True\n            break\n\n    if flag == True:\n        lr = init_lr * gamma**(i+1)\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = lr\n\n    else:\n        return optimizer.param_groups[0]['lr']\n\n    return lr\n\ndef exp_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch, max_epoch):\n\n    if epoch >= 0 and epoch <= warmup_epoch:\n        lr = init_lr * 2.718 ** (10*(float(epoch) / float(warmup_epoch) - 1.))\n        if epoch == warmup_epoch:\n            lr = init_lr\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = lr\n\n        return lr\n\n    else:\n        lr = init_lr * (1 - epoch / max_epoch)**0.9\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = lr\n\n    return lr\n\n\n\ndef cal_dice(pred, target, C): \n    N = pred.shape[0]\n    target_mask = target.data.new(N, C).fill_(0)\n    target_mask.scatter_(1, target, 1.) \n\n    pred_mask = pred.data.new(N, C).fill_(0)\n    pred_mask.scatter_(1, pred, 1.) 
\n\n    intersection= pred_mask * target_mask\n    summ = pred_mask + target_mask\n\n    intersection = intersection.sum(0).type(torch.float32)\n    summ = summ.sum(0).type(torch.float32)\n    \n    eps = torch.rand(C, dtype=torch.float32)\n    eps = eps.fill_(1e-7)\n\n    summ += eps.cuda()\n    dice = 2 * intersection / summ\n\n    return dice, intersection, summ\n\ndef cal_asd(itkPred, itkGT):\n    \n    reference_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(itkGT, squaredDistance=False))\n    reference_surface = sitk.LabelContour(itkGT)\n\n    statistics_image_filter = sitk.StatisticsImageFilter()\n    statistics_image_filter.Execute(reference_surface)\n    num_reference_surface_pixels = int(statistics_image_filter.GetSum())\n\n    segmented_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(itkPred, squaredDistance=False))\n    segmented_surface = sitk.LabelContour(itkPred)\n\n    seg2ref_distance_map = reference_distance_map * sitk.Cast(segmented_surface, sitk.sitkFloat32)\n    ref2seg_distance_map = segmented_distance_map * sitk.Cast(reference_surface, sitk.sitkFloat32)\n\n    statistics_image_filter.Execute(segmented_surface)\n    num_segmented_surface_pixels = int(statistics_image_filter.GetSum())\n\n    seg2ref_distance_map_arr = sitk.GetArrayViewFromImage(seg2ref_distance_map)\n    seg2ref_distances = list(seg2ref_distance_map_arr[seg2ref_distance_map_arr!=0])\n    seg2ref_distances = seg2ref_distances + \\\n                        list(np.zeros(num_segmented_surface_pixels - len(seg2ref_distances)))\n    ref2seg_distance_map_arr = sitk.GetArrayViewFromImage(ref2seg_distance_map)\n    ref2seg_distances = list(ref2seg_distance_map_arr[ref2seg_distance_map_arr!=0])\n    ref2seg_distances = ref2seg_distances + \\\n                        list(np.zeros(num_reference_surface_pixels - len(ref2seg_distances)))\n    \n    all_surface_distances = seg2ref_distances + ref2seg_distances\n\n    ASD = np.mean(all_surface_distances)\n\n    return ASD\n\n"
  }
]