Repository: yhygao/UTNet Branch: main Commit: 6bdcc97e2364 Files: 27 Total size: 190.5 KB Directory structure: gitextract_vjhzdth0/ ├── .gitignore ├── LICENSE ├── README.md ├── data_preprocess.py ├── dataset/ │ ├── Testing_A.csv │ ├── Testing_B.csv │ ├── Testing_C.csv │ ├── Testing_D.csv │ ├── Training_A.csv │ ├── Training_B.csv │ ├── Training_C.csv │ ├── Training_D.csv │ └── info.csv ├── dataset_domain.py ├── losses.py ├── model/ │ ├── __init__.py │ ├── conv_trans_utils.py │ ├── resnet_utnet.py │ ├── swin_unet.py │ ├── transunet.py │ ├── unet_utils.py │ └── utnet.py ├── train_deep.py └── utils/ ├── __init__.py ├── lookup_tables.py ├── metrics.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ __pycache__/ *.pyc *_bkp.py checkpoint/ log/ model/backup/ initmodel/ *.swp ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2021 yhygao Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # UTNet (Accepted at MICCAI 2021) Official implementation of [UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation](https://arxiv.org/abs/2107.00781) ## Update * Our new work, Hermes, has been released on arXiv: [Training Like a Medical Resident: Universal Medical Image Segmentation via Context Prior Learning](https://arxiv.org/pdf/2306.02416.pdf). Inspired by the training of medical residents, we explore universal medical image segmentation, whose goal is to learn from diverse medical imaging sources covering a range of clinical targets, body regions, and image modalities. Following this paradigm, we propose Hermes, a context prior learning approach that addresses the challenges related to the heterogeneity on data, modality, and annotations in the proposed universal paradigm. Code will be released at https://github.com/yhygao/universal-medical-image-segmentation. * Our new paper, the improved version of UTNet: UTNetV2, is released on Arxiv: [A Multi-scale Transformer for Medical Image Segmentation: Architectures, Model Efficiency, and Benchmarks](https://arxiv.org/abs/2203.00131). The UTNetV2 has an improved architecture for 2D and 3D setting. 
We also provide a more general framework for CNN and Transformer comparison, with more dataset and SOTA model support; see our new repo: https://github.com/yhygao/CBIM-Medical-Image-Segmentation
* Data preprocessing code has been uploaded.

## Introduction

Transformer architecture has emerged to be successful in a number of natural language processing tasks. However, its applications to medical vision remain largely unexplored. In this study, we present UTNet, a simple yet powerful hybrid Transformer architecture that integrates self-attention into a convolutional neural network for enhancing medical image segmentation. UTNet applies self-attention modules in both encoder and decoder for capturing long-range dependency at different scales with minimal overhead. To this end, we propose an efficient self-attention mechanism along with relative position encoding that significantly reduces the complexity of the self-attention operation from O(n^2) to approximately O(n). A new self-attention decoder is also proposed to recover fine-grained details from the skip connections in the encoder. Our approach addresses the dilemma that Transformers require huge amounts of data to learn vision inductive bias. Our hybrid layer design allows the initialization of Transformer into convolutional networks without a need of pre-training. We have evaluated UTNet on the multi-label, multi-vendor cardiac magnetic resonance imaging cohort. UTNet demonstrates superior segmentation performance and robustness against state-of-the-art approaches, holding the promise to generalize well to other medical image segmentation tasks.

![image](https://user-images.githubusercontent.com/55367673/134997310-69c3576d-bbf2-40c8-ad5a-9b3bf3e9e97d.png)
![image](https://user-images.githubusercontent.com/55367673/134997347-a581cda7-7050-48ef-9af3-d4628fefac9a.png)

## Supported models

* UTNet
* TransUNet
* ResNet50-UTNet
* ResNet50-UNet
* SwinUNet
* To be continued ...

## Getting Started

Currently, we only support the [M&Ms dataset](https://www.ub.edu/mnms/).

### Prerequisites

```
Python >= 3.6
pytorch = 1.8.1
SimpleITK = 2.0.2
numpy = 1.19.5
einops = 0.3.2
```

### Preprocess

Resample all data to a spacing of 1.2x1.2 mm in the x-y plane. We don't change the z-axis spacing, as UTNet is a 2D network. Then put all data into 'dataset/'.

### Training

The M&Ms dataset provides data from 4 vendors, where vendors A and B are provided for training while all of A, B, C, and D are used for testing. The '--domain' flag controls which vendors are used for training: '--domain A' uses vendor A only, '--domain B' uses vendor B only, and '--domain AB' uses both vendors A and B. For testing, all 4 vendors are always used.

#### UTNet

For the default UTNet setting, train with:

```
python train_deep.py -m UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 1234 --num_blocks 1,1,1,1 --domain AB --gpu 0 --aux_loss
```

Or you can use '-m UTNet_encoder' to use transformer blocks in the encoder only. This setting is more stable than the default one in some cases.

To tune UTNet for your own task, there are several hyperparameters to adjust (a minimal sketch of how they map onto a model constructor follows this list):

'--block_list': indicates the resolution levels at which transformer blocks are applied. Each digit is a number of downsamplings, e.g. 3,4 means transformer blocks are applied to features after 3 and 4 downsamplings. Applying transformer blocks to higher-resolution feature maps introduces much more computation.

'--num_blocks': indicates the number of transformer blocks applied at each level, e.g. block_list='3,4', num_blocks=2,4 means 2 transformer blocks at the 3-times-downsampled level and 4 transformer blocks at the 4-times-downsampled level.

'--reduce_size': indicates the size that keys and values are downsampled to in the efficient attention. In our experiments, reduce_size 8 and 16 don't make much difference, but 16 introduces more computation, so we chose 8 as our default setting; 16 might perform better in other applications.

'--aux_loss': applies deep supervision during training; it introduces some computation overhead but gives slightly better performance.
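As a concrete illustration of how these flags fit together, here is a minimal sketch. It is not train_deep.py itself (the actual argument handling there may differ), and the in_ch/num_class values are assumptions for M&Ms-style data (1 input channel; background plus 3 cardiac structures):

```python
# Hypothetical wiring, shown with ResNet_UTNet from model/resnet_utnet.py,
# which takes block_list as a digit string and num_blocks as a list of ints.
from model.resnet_utnet import ResNet_UTNet

block_list = '123'                                  # from --block_list 123
num_blocks = [int(n) for n in '1,1,1'.split(',')]   # from --num_blocks 1,1,1

net = ResNet_UTNet(in_ch=1, num_class=4,            # assumed channel/class counts
                   reduce_size=8,                   # from --reduce_size 8
                   block_list=block_list,
                   num_blocks=num_blocks)
```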
Here are some recommended parameter settings:

```
--block_list 1234 --num_blocks 1,1,1,1
```
Our default and most efficient setting. Suitable for tasks with limited training data where most errors occur at the boundary of the ROI, where high-resolution information is important.

```
--block_list 1234 --num_blocks 1,1,4,8
```
Similar to the previous one. The model capacity is larger as more transformer blocks are included, but it needs a larger dataset for training.

```
--block_list 234 --num_blocks 2,4,8
```
Suitable for tasks with complex contexts where errors occur inside the ROI. More transformer blocks can help learn higher-level relationships.

Feel free to try other combinations of the hyperparameters, like base_chan, reduce_size, and num_blocks at each level, to trade off between capacity and efficiency to fit your own tasks and datasets.

#### TransUNet

We borrow code from the original TransUNet repo and fit it into our training framework. If you want to use pre-trained weights, please download them from the [original repo](https://github.com/Beckschen/TransUNet). The configuration is not parsed from the command line, so if you want to change the configuration of TransUNet, you need to change it inside train_deep.py.

```
python train_deep.py -m TransUNet -u EXP_NAME --data_path YOUR_OWN_PATH --gpu 0
```

#### ResNet50-UTNet

For a fair comparison with TransUNet, we implement the efficient attention proposed in UTNet on a ResNet50 backbone, which basically appends transformer blocks at specified levels after the ResNet blocks. ResNet50-UTNet performs slightly better than the default UTNet on the M&Ms dataset.

```
python train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 123 --num_blocks 1,1,1 --gpu 0
```

Similar to UTNet, this is the most efficient setting, suitable for tasks with limited training data.

```
--block_list 23 --num_blocks 2,4
```
Suitable for tasks with complex contexts where errors occur inside the ROI. More transformer blocks can help learn higher-level relationships.

#### ResNet50-UNet

If you don't use transformer blocks in ResNet50-UTNet, it is actually ResNet50-UNet, so you can use it as the baseline to measure the performance improvement from the Transformer, for a fair comparison with TransUNet and our UTNet.

```
python train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --block_list '' --gpu 0
```

#### SwinUNet

Download the pre-trained model from the [original repo](https://github.com/HuCaoFighting/Swin-Unet/tree/4375a8d6fa7d9c38184c5d3194db990a00a3e912). As Swin-Transformer's input size is tied to its window size and is hard to change after pretraining, we adapt our input size to 224. Without pre-training, SwinUNet's performance is very low.

```
python train_deep.py -m SwinUNet -u EXP_NAME --data_path YOUR_OWN_PATH --crop_size 224
```

## Citation

If you find this repo helpful, please kindly cite our papers, thanks!
``` @inproceedings{gao2021utnet, title={UTNet: a hybrid transformer architecture for medical image segmentation}, author={Gao, Yunhe and Zhou, Mu and Metaxas, Dimitris N}, booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, pages={61--71}, year={2021}, organization={Springer} } @misc{gao2022datascalable, title={A Data-scalable Transformer for Medical Image Segmentation: Architecture, Model Efficiency, and Benchmark}, author={Yunhe Gao and Mu Zhou and Di Liu and Zhennan Yan and Shaoting Zhang and Dimitris N. Metaxas}, year={2022}, eprint={2203.00131}, archivePrefix={arXiv}, primaryClass={eess.IV} } ``` ================================================ FILE: data_preprocess.py ================================================ import numpy as np import SimpleITK as sitk import os import pdb def ResampleXYZAxis(imImage, space=(1., 1., 1.), interp=sitk.sitkLinear): identity1 = sitk.Transform(3, sitk.sitkIdentity) sp1 = imImage.GetSpacing() sz1 = imImage.GetSize() sz2 = (int(round(sz1[0]*sp1[0]*1.0/space[0])), int(round(sz1[1]*sp1[1]*1.0/space[1])), int(round(sz1[2]*sp1[2]*1.0/space[2]))) imRefImage = sitk.Image(sz2, imImage.GetPixelIDValue()) imRefImage.SetSpacing(space) imRefImage.SetOrigin(imImage.GetOrigin()) imRefImage.SetDirection(imImage.GetDirection()) imOutImage = sitk.Resample(imImage, imRefImage, identity1, interp) return imOutImage def ResampleFullImageToRef(imImage, imRef, interp=sitk.sitkNearestNeighbor): identity1 = sitk.Transform(3, sitk.sitkIdentity) imRefImage = sitk.Image(imRef.GetSize(), imImage.GetPixelIDValue()) imRefImage.SetSpacing(imRef.GetSpacing()) imRefImage.SetOrigin(imRef.GetOrigin()) imRefImage.SetDirection(imRef.GetDirection()) imOutImage = sitk.Resample(imImage, imRefImage, identity1, interp) return imOutImage def ResampleCMRImage(imImage, imLabel, save_path, patient_name, name, target_space=(1., 1.)): assert imImage.GetSpacing() == imLabel.GetSpacing() assert imImage.GetSize() == imLabel.GetSize() spacing = imImage.GetSpacing() origin = imImage.GetOrigin() npimg = sitk.GetArrayFromImage(imImage) nplab = sitk.GetArrayFromImage(imLabel) t, z, y, x = npimg.shape if not os.path.exists('%s/%s'%(save_path, patient_name)): os.mkdir('%s/%s'%(save_path, patient_name)) flag = 0 for i in range(t): tmp_img = npimg[i] tmp_lab = nplab[i] if tmp_lab.max() == 0: continue print(i) flag += 1 tmp_itkimg = sitk.GetImageFromArray(tmp_img) tmp_itkimg.SetSpacing(spacing[0:3]) tmp_itkimg.SetOrigin(origin[0:3]) tmp_itklab = sitk.GetImageFromArray(tmp_lab) tmp_itklab.SetSpacing(spacing[0:3]) tmp_itklab.SetOrigin(origin[0:3]) re_img = ResampleXYZAxis(tmp_itkimg, space=(target_space[0], target_space[1], spacing[2])) re_lab = ResampleFullImageToRef(tmp_itklab, re_img) sitk.WriteImage(re_img, '%s/%s/%s_%d.nii.gz'%(save_path, patient_name, name, flag)) sitk.WriteImage(re_lab, '%s/%s/%s_gt_%d.nii.gz'%(save_path, patient_name, name, flag)) return flag if __name__ == '__main__': src_path = 'OpenDataset/Training/Labeled' tgt_path = 'dataset/Training' os.chdir(src_path) for name in os.listdir('.'): os.chdir(name) for i in os.listdir('.'): if 'gt' in i: tmp = i.split('_') img_name = tmp[0] + '_' + tmp[1] patient_name = tmp[0] img = sitk.ReadImage('%s.nii.gz'%img_name) lab = sitk.ReadImage('%s_gt.nii.gz'%img_name) flag = ResampleCMRImage(img, lab, tgt_path, patient_name, img_name, (1.2, 1.2)) print(name, 'done', flag) os.chdir('..') ================================================ FILE: dataset/Testing_A.csv 
================================================ name A2L1N6 B0H7V0 B3F0V9 B5F8L9 B9H8N8 C4E9I1 C6E0F9 C6U4W8 D1H6U2 D9F5P1 E0J7L9 F0K4T6 G7S6V0 G9N5V9 H1N8S6 I2J6Z6 I5L3S2 J4J8Q3 K6N4N7 P3P9S5 ================================================ FILE: dataset/Testing_B.csv ================================================ name A2H5K9 A4A8V9 A5C2D2 A5D0G0 A6J0Y2 A7E4J0 B1G9J3 B3S2Z4 B4E1K1 B4S1Y2 B5L5Y4 B5T6V0 B7F5P0 C7L8Z8 C8I7P7 D1R0Y5 D3Q0W9 D6E9U8 D8O0W2 D9I8O7 E1L8Y4 E3L8U8 E4I9O7 E4O8P3 E5S7W7 E6J4N8 G1K1V3 G3M5S4 G7N8R7 G7Q2W0 H7K5U5 H8K2K7 I0I2J8 I6T4W8 I7W4Y8 J9L4S2 K5K6N1 K7L2Y6 K9N0W0 L7Y7Z2 L8N7Z0 M4T4V6 M6V2Y0 N9P5Z0 O4T6Y7 O9V8W5 P3R6Y5 P8W4Z0 Q3Q6R8 Y6Y9Z2 ================================================ FILE: dataset/Testing_C.csv ================================================ name A3H5R1 A5H1Q2 A8C5E9 A9F3T5 A9L7Y7 B0L3Y2 B2L0L2 B8P5Q9 C0N8P4 C7M6W0 C8J7L5 C8O0P2 D1H2O9 D2U0V0 D3K5Q2 D5G3W8 D6N7Q8 D7M8P9 D7T3V8 E3F2U7 E3F5U2 E6M6P2 E7L0N6 E9V9Z2 F0I6U8 F1K2S9 F5I1Z8 G8R0Z9 H2M9S1 H3R6S9 H7L8R8 H7P5Z4 I4R8V6 I6P4R0 I8Z0Z6 J6K4V3 K3R0Y7 K7O3Q0 L2V5Z0 L8N7P0 M6M9N1 O4O6U5 P3T5U1 Q1Q3T1 Q5V8W3 R1R6Y8 R2R7Z5 R6V5W3 R8V0Y4 V4W8Z5 ================================================ FILE: dataset/Testing_D.csv ================================================ name A1K2P5 A3P9V7 A4B9O6 A4K8R4 A4R4T0 A5P5W0 A5Q1W8 A6A8H0 A6B7Y4 A7F4G2 B3E2W8 B6I0T4 B9G4U2 C0L7V1 C5L0R0 D1L6T4 D1S5T8 E0J2Z9 E1L7M3 E4H7L4 E5J6L2 E6H0V9 F6J9L9 F9M1R2 G1J5K3 G4I7V2 G8K0M3 G8L0Z0 H6P7T1 I4L4V7 I8N8Y1 J6M5O2 K3P3Y6 K5L4S1 K5M7V5 K7N0R7 L5U7Y4 L6T2T5 L8M2U8 M2P5T8 N2O7U5 N7P3T8 N7W6Z8 N9Q4T8 O5U2U7 O7Q7U3 P8V0Y7 Q4W5Z8 R3V5W7 T2Z1Z9 ================================================ FILE: dataset/Training_A.csv ================================================ name A0S9V9 A1D9Z7 A1E9Q1 A2C0I1 A2N8V0 A3H1O5 A4B5U4 A4J4S4 A4U9V5 A6B5G9 A6D5F9 A7D9L8 A7M7P8 A7O4T6 B0I2Z0 B2C2Z7 B2D9M2 B2G5R2 B3O1S0 B3P3R1 B4O3V3 B8H5H6 B8J7R4 B9E0Q1 C0K1P0 C1K8P5 C2J0K3 C5M4S2 C6J5P1 C8P3S7 D0R0R9 D1J5P6 D1M1S6 D3D4Y5 D3F3O5 D3F9H9 D3O9U9 D4M3Q2 D4N6W6 D8E4F4 D9L1Z3 E0M3U7 E0O0S0 E9H1U4 E9V4Z8 F2H5S1 F3G5K5 G4S9U3 G5P4U3 G7I5V7 H1I3W0 H5N0P0 I7T3U1 J1T9Y1 J6K6P5 J6P5T8 J8R5W2 K4T7Y0 K5L2U3 K5P0Y1 L1Q9V8 L4Q2U3 M1R4S1 M2P1R1 N5S7Y1 N8N9U0 O0S9V7 P0S5Y0 P5R1Y4 Q0U0V5 Q3R9W7 R4Y1Z9 S1S3Z7 T2T9Z9 T9U9W2 ================================================ FILE: dataset/Training_B.csv ================================================ name A1D0Q7 A1O8Z3 A3B7E5 A5E0T8 A6M1Q7 A7G0P5 A8C9U8 A8E1F4 A9C5P4 A9E3G9 A9J5Q7 A9J8W7 B0N3W8 B2D9O2 B2F4K5 B3D0N1 B6D0U7 B9O1Q0 C0S7W0 C1G5Q0 C2L5P7 C2M6P8 C3I2K3 C4R8T7 C4S8W9 D0H9I4 D1L4Q9 D6H6O2 E3T0Z2 E4M2Q7 E4W8Z7 E5E6O8 E5F5V7 E9H2K7 E9L1W5 F0J2R8 F1F3I6 F4K3S1 F5I9Q2 F8N2S1 G0H4J3 G0I6P3 G1N6S7 G2J1M5 G2M7W4 G2O2S6 G4L8Z7 G8N2U5 G9L0O9 H0K3Q4 H1J5W8 H1M5Y6 H1W2Y1 H3U1Y1 H4I2T8 H6I0I6 H7I4J3 H7N4V9 I0J5U3 I2K2Y8 I6N3P3 J4J9W6 J9L6N9 K2S1U6 L1Q1Z5 L5Q6T7 M0P8U8 M4P7Q6 N1P8Q9 N7V9W9 O3R8Y5 P6U0Y0 P9S7W2 Q7V1Y5 W5Z4Z8 ================================================ FILE: dataset/Training_C.csv ================================================ name ================================================ FILE: dataset/Training_D.csv ================================================ name ================================================ FILE: dataset/info.csv ================================================ External code,VendorName,Vendor,Centre,ED,ES A0S9V9,Siemens,A,1,0,9 A1D0Q7,Philips,B,2,0,9 A1D9Z7,Siemens,A,1,22,11 A1E9Q1,Siemens,A,1,0,9 A1K2P5,Canon,D,5,33,11 A1O8Z3,Philips,B,3,23,10 A2C0I1,Siemens,A,1,0,7 
A2E3W4,GE,C,4,0,0 A2H5K9,Philips,B,2,29,8 A2L1N6,Siemens,A,1,0,12 A2N8V0,Siemens,A,1,0,9 A3B7E5,Philips,B,2,29,12 A3H1O5,Siemens,A,1,0,12 A3H5R1,GE,C,4,24,6 A3P9V7,Canon,D,5,27,13 A4A8V9,Philips,B,3,0,10 A4B5U4,Siemens,A,1,0,10 A4B9O6,Canon,D,5,0,11 A4J4S4,Siemens,A,1,0,7 A4K8R4,Canon,D,5,24,11 A4R4T0,Canon,D,5,21,8 A4U9V5,Siemens,A,1,0,8 A5C2D2,Philips,B,3,24,9 A5D0G0,Philips,B,2,0,10 A5E0T8,Philips,B,3,24,8 A5H1Q2,GE,C,4,0,9 A5P5W0,Canon,D,5,33,10 A5Q1W8,Canon,D,5,27,10 A6A8H0,Canon,D,5,27,8 A6B5G9,Siemens,A,1,0,11 A6B7Y4,Canon,D,5,1,10 A6D5F9,Siemens,A,1,0,11 A6J0Y2,Philips,B,2,27,13 A6M1Q7,Philips,B,2,29,11 A7D9L8,Siemens,A,1,0,11 A7E4J0,Philips,B,3,1,12 A7F4G2,Canon,D,5,25,10 A7G0P5,Philips,B,2,28,9 A7M7P8,Siemens,A,1,0,9 A7O4T6,Siemens,A,1,0,10 A8C5E9,GE,C,4,24,9 A8C9H8,GE,C,4,0,0 A8C9U8,Philips,B,2,28,9 A8E1F4,Philips,B,3,24,9 A8I1U6,GE,C,4,0,0 A9C5P4,Philips,B,2,29,8 A9E3G9,Philips,B,3,23,8 A9F3T5,GE,C,4,24,8 A9J5Q7,Philips,B,3,24,7 A9J8W7,Philips,B,2,29,10 A9L7Y7,GE,C,4,24,5 B0H7V0,Siemens,A,1,0,9 B0I2Z0,Siemens,A,1,0,8 B0L3Y2,GE,C,4,24,10 B0N3W8,Philips,B,2,28,15 B1G9J3,Philips,B,2,29,10 B1K7U1,GE,C,4,0,0 B2C2Z7,Siemens,A,1,0,8 B2D9M2,Siemens,A,1,0,8 B2D9O2,Philips,B,2,29,13 B2F4K5,Philips,B,2,29,10 B2G5R2,Siemens,A,1,0,7 B2L0L2,GE,C,4,24,8 B3D0N1,Philips,B,3,24,8 B3E2W8,Canon,D,5,23,9 B3F0V9,Siemens,A,1,0,12 B3O1S0,Siemens,A,1,0,8 B3P3R1,Siemens,A,1,0,11 B3S2Z4,Philips,B,2,28,11 B4E1K1,Philips,B,3,0,9 B4I8Z7,GE,C,4,0,0 B4O3V3,Siemens,A,1,0,6 B4S1Y2,Philips,B,2,29,10 B5F8L9,Siemens,A,1,0,8 B5L5Y4,Philips,B,3,24,10 B5T6V0,Philips,B,3,0,10 B6D0U7,Philips,B,2,29,9 B6I0T4,Canon,D,5,33,10 B7F5P0,Philips,B,3,29,12 B8H5H6,Siemens,A,1,24,8 B8J7R4,Siemens,A,1,0,12 B8P5Q9,GE,C,4,23,9 B9E0Q1,Siemens,A,1,0,11 B9G4U2,Canon,D,5,28,8 B9H8N8,Siemens,A,1,0,10 B9O1Q0,Philips,B,2,29,11 C0K1P0,Siemens,A,1,0,9 C0L7V1,Canon,D,5,33,14 C0N8P4,GE,C,4,24,9 C0S7W0,Philips,B,3,24,7 C0W2Y5,GE,C,4,0,0 C1G5Q0,Philips,B,3,24,8 C1K8P5,Siemens,A,1,0,8 C2J0K3,Siemens,A,1,0,9 C2L5P7,Philips,B,2,28,12 C2M6P8,Philips,B,3,24,8 C3I2K3,Philips,B,2,29,11 C4E9I1,Siemens,A,1,0,8 C4K8M1,GE,C,4,0,0 C4R8T7,Philips,B,2,28,8 C4S8W9,Philips,B,3,24,8 C5L0R0,Canon,D,5,29,12 C5M4S2,Siemens,A,1,0,9 C5Q2Y5,GE,C,4,0,0 C6E0F9,Siemens,A,1,0,8 C6J5P1,Siemens,A,1,0,10 C6U4W8,Siemens,A,1,0,9 C7L8Z8,Philips,B,2,29,9 C7M6W0,GE,C,4,24,8 C8I7P7,Philips,B,2,29,11 C8J7L5,GE,C,4,24,7 C8O0P2,GE,C,4,24,8 C8P3S7,Siemens,A,1,0,8 C8V5W8,GE,C,4,0,0 D0E2W0,GE,C,4,0,0 D0H9I4,Philips,B,2,0,10 D0R0R9,Siemens,A,1,0,11 D1H2O9,GE,C,4,24,8 D1H6U2,Siemens,A,1,0,8 D1J5P6,Siemens,A,1,0,13 D1L4Q9,Philips,B,2,29,10 D1L6T4,Canon,D,5,27,10 D1M1S6,Siemens,A,1,0,10 D1R0Y5,Philips,B,3,24,9 D1S5T8,Canon,D,5,23,14 D2U0V0,GE,C,4,24,10 D3D4Y5,Siemens,A,1,0,9 D3F3O5,Siemens,A,1,24,10 D3F9H9,Siemens,A,1,24,8 D3K5Q2,GE,C,4,11,0 D3O9U9,Siemens,A,1,0,10 D3Q0W9,Philips,B,3,24,10 D4M3Q2,Siemens,A,1,0,9 D4N6W6,Siemens,A,1,0,10 D5G3W8,GE,C,4,23,8 D6E9U8,Philips,B,2,0,12 D6H6O2,Philips,B,2,29,9 D6N7Q8,GE,C,4,24,9 D7M8P9,GE,C,4,24,8 D7T3V8,GE,C,4,24,7 D8E4F4,Siemens,A,1,0,10 D8O0W2,Philips,B,3,24,8 D9F5P1,Siemens,A,1,0,10 D9I8O7,Philips,B,3,29,11 D9L1Z3,Siemens,A,1,0,12 E0J2Z9,Canon,D,5,25,11 E0J7L9,Siemens,A,1,0,9 E0M3U7,Siemens,A,1,0,9 E0O0S0,Siemens,A,1,0,11 E1L7M3,Canon,D,5,1,12 E1L8Y4,Philips,B,3,24,10 E3F2U7,GE,C,4,0,7 E3F5U2,GE,C,4,24,8 E3I4V1,GE,C,4,0,0 E3L8U8,Philips,B,3,29,9 E3T0Z2,Philips,B,2,29,10 E4H7L4,Canon,D,5,28,10 E4I9O7,Philips,B,3,29,10 E4M2Q7,Philips,B,3,0,8 E4O8P3,Philips,B,2,29,12 E4W8Z7,Philips,B,2,29,10 E5E6O8,Philips,B,2,29,10 E5F5V7,Philips,B,3,0,9 
E5J6L2,Canon,D,5,29,11 E5S7W7,Philips,B,2,28,12 E6H0V9,Canon,D,5,33,12 E6J4N8,Philips,B,2,29,11 E6M6P2,GE,C,4,24,7 E7L0N6,GE,C,4,22,8 E9H1U4,Siemens,A,1,0,10 E9H2K7,Philips,B,2,29,11 E9L1W5,Philips,B,3,24,9 E9L4N2,GE,C,4,0,0 E9V4Z8,Siemens,A,1,0,8 E9V9Z2,GE,C,4,28,10 F0I6U8,GE,C,4,10,23 F0J2R8,Philips,B,2,29,9 F0K4T6,Siemens,A,1,0,8 F1F3I6,Philips,B,2,28,8 F1K2S9,GE,C,4,23,7 F2H5S1,Siemens,A,1,24,10 F3G5K5,Siemens,A,1,24,12 F3L0M1,GE,C,4,0,0 F4K3S1,Philips,B,3,24,8 F5I1Z8,GE,C,4,23,8 F5I9Q2,Philips,B,2,29,11 F6J9L9,Canon,D,5,29,10 F8N2S1,Philips,B,3,24,8 F9M1R2,Canon,D,5,29,11 G0H4J3,Philips,B,2,29,10 G0I6P3,Philips,B,3,29,12 G0I7Z6,GE,C,4,0,0 G1J5K3,Canon,D,5,24,13 G1K1V3,Philips,B,2,29,11 G1N6S7,Philips,B,2,0,12 G2J1M5,Philips,B,2,29,9 G2M7W4,Philips,B,3,24,9 G2O2S6,Philips,B,2,29,10 G3M5S4,Philips,B,3,29,10 G4I7V2,Canon,D,5,33,9 G4K8P3,GE,C,4,0,0 G4L8Z7,Philips,B,2,29,8 G4S9U3,Siemens,A,1,0,11 G4U3U8,GE,C,4,0,0 G5P4U3,Siemens,A,1,0,9 G6T0Z6,GE,C,4,0,0 G7I5V7,Siemens,A,1,0,10 G7N8R7,Philips,B,2,28,9 G7Q2W0,Philips,B,2,0,8 G7S6V0,Siemens,A,1,0,8 G8K0M3,Canon,D,5,25,10 G8L0Z0,Canon,D,5,3,10 G8N2U5,Philips,B,3,23,9 G8R0Z9,GE,C,4,21,8 G9L0O9,Philips,B,2,29,10 G9N5V9,Siemens,A,1,0,11 H0K3Q4,Philips,B,3,24,9 H1I3W0,Siemens,A,1,0,9 H1J5W8,Philips,B,2,3,15 H1M5Y6,Philips,B,2,29,8 H1N7P7,GE,C,4,0,0 H1N8S6,Siemens,A,1,0,10 H1W2Y1,Philips,B,2,29,10 H2M9S1,GE,C,4,23,8 H3R6S9,GE,C,4,24,9 H3U1Y1,Philips,B,3,24,8 H4I2T8,Philips,B,3,24,8 H5N0P0,Siemens,A,1,0,8 H6I0I6,Philips,B,2,28,10 H6P7T1,Canon,D,5,25,12 H7I4J3,Philips,B,3,24,8 H7K5U5,Philips,B,3,24,8 H7L8R8,GE,C,4,23,9 H7N4V9,Philips,B,2,0,12 H7P5Z4,GE,C,4,24,9 H8K2K7,Philips,B,3,23,7 H9J6L5,GE,C,4,0,0 I0I2J8,Philips,B,2,21,7 I0J5U3,Philips,B,2,28,10 I2J6Z6,Siemens,A,1,0,7 I2K2Y8,Philips,B,2,29,8 I4J8P4,GE,C,4,0,0 I4L4V7,Canon,D,5,22,9 I4R8V6,GE,C,4,24,8 I5L3S2,Siemens,A,1,0,10 I6N3P3,Philips,B,2,29,10 I6P4R0,GE,C,4,18,11 I6T4W8,Philips,B,2,28,12 I7T3U1,Siemens,A,1,0,13 I7W4Y8,Philips,B,2,27,10 I8N8Y1,Canon,D,5,33,14 I8Z0Z6,GE,C,4,24,8 J1T9Y1,Siemens,A,1,0,12 J2S5T1,GE,C,4,0,0 J4J8Q3,Siemens,A,1,0,8 J4J9W6,Philips,B,2,29,11 J6K4V3,GE,C,4,21,3 J6K6P5,Siemens,A,1,0,9 J6M5O2,Canon,D,5,33,11 J6P5T8,Siemens,A,1,0,9 J8R5W2,Siemens,A,1,0,9 J9L4S2,Philips,B,3,29,13 J9L6N9,Philips,B,2,29,9 K2S1U6,Philips,B,2,29,12 K3P3Y6,Canon,D,5,32,11 K3R0Y7,GE,C,4,28,9 K4T7Y0,Siemens,A,1,24,12 K5K6N1,Philips,B,3,24,9 K5L2U3,Siemens,A,1,0,7 K5L4S1,Canon,D,5,27,10 K5M7V5,Canon,D,5,23,10 K5P0Y1,Siemens,A,1,0,9 K6N4N7,Siemens,A,1,0,8 K7L2Y6,Philips,B,2,29,10 K7N0R7,Canon,D,5,29,9 K7O3Q0,GE,C,4,23,10 K9N0W0,Philips,B,2,29,12 L1Q1Z5,Philips,B,2,0,15 L1Q9V8,Siemens,A,1,0,10 L2V5Z0,GE,C,4,3,17 L4Q2U3,Siemens,A,1,0,8 L5Q6T7,Philips,B,2,28,11 L5U7Y4,Canon,D,5,33,10 L6T2T5,Canon,D,5,1,18 L7Y7Z2,Philips,B,2,0,11 L8M2U8,Canon,D,5,0,11 L8N7P0,GE,C,4,23,9 L8N7Z0,Philips,B,3,29,9 M0P8U8,Philips,B,3,0,9 M1R4S1,Siemens,A,1,0,9 M2P1R1,Siemens,A,1,0,10 M2P5T8,Canon,D,5,0,11 M4P7Q6,Philips,B,2,2,13 M4T4V6,Philips,B,2,29,12 M6M9N1,GE,C,4,24,11 M6V2Y0,Philips,B,3,0,9 M8N3Z3,GE,C,4,0,0 N1P8Q9,Philips,B,2,29,10 N1S7Z2,GE,C,4,0,0 N2O7U5,Canon,D,5,18,7 N5S7Y1,Siemens,A,1,0,9 N7P3T8,Canon,D,5,25,10 N7V9W9,Philips,B,3,24,7 N7W6Z8,Canon,D,5,23,11 N8N9U0,Siemens,A,1,0,11 N9P5Z0,Philips,B,3,0,13 N9Q4T8,Canon,D,5,32,10 O0S9V7,Siemens,A,1,0,9 O1O9Y6,GE,C,4,0,0 O3R8Y5,Philips,B,2,0,8 O4O6U5,GE,C,4,24,8 O4T6Y7,Philips,B,2,0,13 O5U2U7,Canon,D,5,1,9 O7Q7U3,Canon,D,5,29,11 O9V8W5,Philips,B,2,29,12 P0S5Y0,Siemens,A,1,0,10 P3P9S5,Siemens,A,1,0,10 P3R6Y5,Philips,B,3,24,9 P3T5U1,GE,C,4,27,9 P5R1Y4,Siemens,A,1,0,9 
P6U0Y0,Philips,B,2,29,10 P8V0Y7,Canon,D,5,27,9 P8W4Z0,Philips,B,3,0,12 P9S7W2,Philips,B,3,23,10 Q0Q1Y4,GE,C,4,0,0 Q0U0V5,Siemens,A,1,0,10 Q1Q3T1,GE,C,4,28,10 Q3Q6R8,Philips,B,3,24,9 Q3R9W7,Siemens,A,1,0,10 Q4W5Z8,Canon,D,5,33,10 Q5V8W3,GE,C,4,24,10 Q7V1Y5,Philips,B,2,0,11 R1R6Y8,GE,C,4,23,10 R2R7Z5,GE,C,4,23,8 R3V5W7,Canon,D,5,27,13 R4Y1Z9,Siemens,A,1,24,7 R6V5W3,GE,C,4,28,10 R8V0Y4,GE,C,4,24,7 S1S3Z7,Siemens,A,1,0,9 T2T9Z9,Siemens,A,1,0,13 T2Z1Z9,Canon,D,5,29,9 T9U9W2,Siemens,A,1,0,10 V4W8Z5,GE,C,4,19,9 W5Z4Z8,Philips,B,2,29,11 Y6Y9Z2,Philips,B,3,29,9 ================================================ FILE: dataset_domain.py ================================================ import os import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from torch.utils.data import Dataset, DataLoader import pandas as pd import SimpleITK as sitk from skimage.measure import label, regionprops import math import pdb class CMRDataset(Dataset): def __init__(self, dataset_dir, mode='train', domain='A', crop_size=256, scale=0.1, rotate=10, debug=False): self.mode = mode self.dataset_dir = dataset_dir self.crop_size = crop_size self.scale = scale self.rotate = rotate if self.mode == 'train': pre_face = 'Training' if 'C' in domain or 'D' in domain: raise ValueError('No domain C or D in Training set') elif self.mode == 'test': pre_face = 'Testing' else: raise ValueError('Wrong mode') if debug: # the Testing split is the smallest and loads fastest, so use it for debugging pre_face = 'Testing' path = self.dataset_dir + pre_face + '/' print('start loading data') name_list = [] if 'A' in domain: df = pd.read_csv(self.dataset_dir+pre_face+'_A.csv') name_list += np.array(df['name']).tolist() if 'B' in domain: df = pd.read_csv(self.dataset_dir+pre_face+'_B.csv') name_list += np.array(df['name']).tolist() if 'C' in domain: df = pd.read_csv(self.dataset_dir+pre_face+'_C.csv') name_list += np.array(df['name']).tolist() if 'D' in domain: df = pd.read_csv(self.dataset_dir+pre_face+'_D.csv') name_list += np.array(df['name']).tolist() img_list = [] lab_list = [] spacing_list = [] for name in name_list: for name_idx in os.listdir(path+name): if 'gt' in name_idx: continue else: idx = name_idx.split('_')[2].split('.')[0] itk_img = sitk.ReadImage(path+name+'/%s_sa_%s.nii.gz'%(name, idx)) itk_lab = sitk.ReadImage(path+name+'/%s_sa_gt_%s.nii.gz'%(name, idx)) spacing = np.array(itk_lab.GetSpacing()).tolist() spacing_list.append(spacing[::-1]) assert itk_img.GetSize() == itk_lab.GetSize() img, lab = self.preprocess(itk_img, itk_lab) img_list.append(img) lab_list.append(lab) self.img_slice_list = [] self.lab_slice_list = [] if self.mode == 'train': for i in range(len(img_list)): tmp_img = img_list[i] tmp_lab = lab_list[i] z, x, y = tmp_img.shape for j in range(z): self.img_slice_list.append(tmp_img[j]) self.lab_slice_list.append(tmp_lab[j]) else: self.img_slice_list = img_list self.lab_slice_list = lab_list self.spacing_list = spacing_list print('load done, length of dataset:', len(self.img_slice_list)) def __len__(self): return len(self.img_slice_list) def preprocess(self, itk_img, itk_lab): img = sitk.GetArrayFromImage(itk_img) lab = sitk.GetArrayFromImage(itk_lab) max98 = np.percentile(img, 98) img = np.clip(img, 0, max98) z, y, x = img.shape if x < self.crop_size: diff = (self.crop_size + 10 - x) // 2 img = np.pad(img, ((0,0), (0,0), (diff, diff))) lab = np.pad(lab, ((0,0), (0,0), (diff,diff))) if y < self.crop_size: diff = (self.crop_size + 10 -y) // 2 img = np.pad(img, ((0,0), (diff, diff), (0,0))) lab = np.pad(lab, ((0,0), (diff, diff), (0,0))) img = img / max98 tensor_img = torch.from_numpy(img).float() tensor_lab = torch.from_numpy(lab).long() return tensor_img, tensor_lab def __getitem__(self, idx): tensor_image = self.img_slice_list[idx] tensor_label = self.lab_slice_list[idx] if self.mode == 'train': tensor_image = tensor_image.unsqueeze(0).unsqueeze(0) tensor_label = tensor_label.unsqueeze(0).unsqueeze(0) # Gaussian Noise tensor_image += torch.randn(tensor_image.shape) * 0.02 # Additive brightness rnd_bn = np.random.normal(0, 0.7)#0.03 tensor_image += rnd_bn # gamma minm = tensor_image.min() rng = tensor_image.max() - minm gamma = np.random.uniform(0.5, 1.6) tensor_image = torch.pow((tensor_image-minm)/rng, gamma)*rng + minm tensor_image, tensor_label = self.random_zoom_rotate(tensor_image, tensor_label) tensor_image, tensor_label = self.randcrop(tensor_image, tensor_label) else: tensor_image, tensor_label = self.center_crop(tensor_image, tensor_label) assert tensor_image.shape == tensor_label.shape if self.mode == 'train': return tensor_image, tensor_label else: return tensor_image, tensor_label, np.array(self.spacing_list[idx]) def randcrop(self, img, label): _, _, H, W = img.shape diff_H = H - self.crop_size diff_W = W - self.crop_size rand_x = np.random.randint(0, diff_H) rand_y = np.random.randint(0, diff_W) croped_img = img[0, :, rand_x:rand_x+self.crop_size, rand_y:rand_y+self.crop_size] croped_lab = label[0, :, rand_x:rand_x+self.crop_size, rand_y:rand_y+self.crop_size] return croped_img, croped_lab def center_crop(self, img, label): D, H, W = img.shape diff_H = H - self.crop_size diff_W = W - self.crop_size rand_x = diff_H // 2 rand_y = diff_W // 2 croped_img = img[:, rand_x:rand_x+self.crop_size, rand_y:rand_y+self.crop_size] croped_lab = label[:, rand_x:rand_x+self.crop_size, rand_y:rand_y+self.crop_size] return croped_img, croped_lab def random_zoom_rotate(self, img, label): scale_x = np.random.random() * 2 * self.scale + (1 - self.scale) scale_y = np.random.random() * 2 * self.scale + (1 - self.scale) theta_scale = torch.tensor([[scale_x, 0, 0], [0, scale_y, 0], [0, 0, 1]]).float() angle = (float(np.random.randint(-self.rotate, self.rotate)) / 180.) * math.pi theta_rotate = torch.tensor( [ [math.cos(angle), -math.sin(angle), 0], [math.sin(angle), math.cos(angle), 0], ]).float() # compose the rotation with the scaling matrix so the sampled zoom is actually applied theta = torch.mm(theta_rotate, theta_scale).unsqueeze(0) grid = F.affine_grid(theta, img.size()) img = F.grid_sample(img, grid, mode='bilinear') label = F.grid_sample(label.float(), grid, mode='nearest').long() return img, label ================================================ FILE: losses.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import pdb class DiceLoss(nn.Module): def __init__(self, alpha=0.5, beta=0.5, size_average=True, reduce=True): super(DiceLoss, self).__init__() self.alpha = alpha self.beta = beta self.size_average = size_average self.reduce = reduce def forward(self, preds, targets): N = preds.size(0) C = preds.size(1) P = F.softmax(preds, dim=1) smooth = torch.zeros(C, dtype=torch.float32).fill_(0.00001) class_mask = torch.zeros(preds.shape).to(preds.device) class_mask.scatter_(1, targets, 1.)
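# class_mask is now a one-hot encoding of targets along the class dimension:
# for targets of shape (N, 1, H, W) holding integer labels in [0, C),
# class_mask[n, c, h, w] == 1 exactly when targets[n, 0, h, w] == c.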
ones = torch.ones(preds.shape).to(preds.device) P_ = ones - P class_mask_ = ones - class_mask TP = P * class_mask FP = P * class_mask_ FN = P_ * class_mask smooth = smooth.to(preds.device) self.alpha = FP.transpose(0, 1).reshape(C, -1).sum(dim=(1)) / ((FP.transpose(0, 1).reshape(C, -1).sum(dim=(1)) + FN.transpose(0, 1).reshape(C, -1).sum(dim=(1))) + smooth) self.alpha = torch.clamp(self.alpha, min=0.2, max=0.8) #print('alpha:', self.alpha) self.beta = 1 - self.alpha num = torch.sum(TP.transpose(0, 1).reshape(C, -1), dim=(1)).float() den = num + self.alpha * torch.sum(FP.transpose(0, 1).reshape(C, -1), dim=(1)).float() + self.beta * torch.sum(FN.transpose(0, 1).reshape(C, -1), dim=(1)).float() dice = num / (den + smooth) if not self.reduce: loss = torch.ones(C).to(dice.device) - dice return loss loss = 1 - dice loss = loss.sum() if self.size_average: loss /= C return loss class FocalLoss(nn.Module): def __init__(self, class_num, alpha=None, gamma=2, size_average=True): super(FocalLoss, self).__init__() if alpha is None: self.alpha = torch.ones(class_num) else: self.alpha = alpha self.gamma = gamma self.size_average = size_average def forward(self, preds, targets): N = preds.size(0) C = preds.size(1) # add the channel dim only when targets lack it, so scatter_ gets an index with the same rank as preds if targets.dim() == preds.dim() - 1: targets = targets.unsqueeze(1) P = F.softmax(preds, dim=1) log_P = F.log_softmax(preds, dim=1) class_mask = torch.zeros(preds.shape).to(preds.device) class_mask.scatter_(1, targets, 1.) if targets.size(1) == 1: # squeeze the channel for target targets = targets.squeeze(1) alpha = self.alpha[targets.data].to(preds.device) probs = (P * class_mask).sum(1) log_probs = (log_P * class_mask).sum(1) batch_loss = -alpha * (1-probs).pow(self.gamma)*log_probs if self.size_average: loss = batch_loss.mean() else: loss = batch_loss.sum() return loss if __name__ == '__main__': DL = DiceLoss() FL = FocalLoss(10) pred = torch.randn(2, 10, 128, 128) target = torch.zeros((2, 1, 128, 128)).long() dl_loss = DL(pred, target) fl_loss = FL(pred, target) print('2D:', dl_loss.item(), fl_loss.item()) pred = torch.randn(2, 10, 64, 128, 128) target = torch.zeros(2, 1, 64, 128, 128).long() dl_loss = DL(pred, target) fl_loss = FL(pred, target) print('3D:', dl_loss.item(), fl_loss.item()) ================================================ FILE: model/__init__.py ================================================ ================================================ FILE: model/conv_trans_utils.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange import pdb def conv3x3(in_planes, out_planes, stride=1): return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) def conv1x1(in_planes, out_planes, stride=1): return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) class depthwise_separable_conv(nn.Module): def __init__(self, in_ch, out_ch, stride=1, kernel_size=3, padding=1, bias=False): super().__init__() self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size=kernel_size, padding=padding, groups=in_ch, bias=bias, stride=stride) self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=bias) def forward(self, x): out = self.depthwise(x) out = self.pointwise(out) return out class Mlp(nn.Module): def __init__(self, in_ch, hid_ch=None, out_ch=None, act_layer=nn.GELU, drop=0.): super().__init__() out_ch = out_ch or in_ch hid_ch = hid_ch or in_ch self.fc1 = nn.Conv2d(in_ch, hid_ch, kernel_size=1) self.act = act_layer() self.fc2 = nn.Conv2d(hid_ch, out_ch, kernel_size=1) self.drop
= nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class BasicBlock(nn.Module): def __init__(self, inplanes, planes, stride=1): super().__init__() self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = nn.BatchNorm2d(inplanes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = nn.BatchNorm2d(planes) self.shortcut = nn.Sequential() if stride != 1 or inplanes != planes: self.shortcut = nn.Sequential( nn.BatchNorm2d(inplanes), self.relu, nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) ) def forward(self, x): residue = x out = self.bn1(x) out = self.relu(out) out = self.conv1(out) out = self.bn2(out) out = self.relu(out) out = self.conv2(out) out += self.shortcut(residue) return out class BasicTransBlock(nn.Module): def __init__(self, in_ch, heads, dim_head, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True): super().__init__() self.bn1 = nn.BatchNorm2d(in_ch) self.attn = LinearAttention(in_ch, heads=heads, dim_head=in_ch//heads, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) self.bn2 = nn.BatchNorm2d(in_ch) self.relu = nn.ReLU(inplace=True) self.mlp = nn.Conv2d(in_ch, in_ch, kernel_size=1, bias=False) # conv1x1 has not difference with mlp in performance def forward(self, x): out = self.bn1(x) out, q_k_attn = self.attn(out) out = out + x residue = out out = self.bn2(out) out = self.relu(out) out = self.mlp(out) out += residue return out class BasicTransDecoderBlock(nn.Module): def __init__(self, in_ch, out_ch, heads, dim_head, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True): super().__init__() self.bn_l = nn.BatchNorm2d(in_ch) self.bn_h = nn.BatchNorm2d(out_ch) self.conv_ch = nn.Conv2d(in_ch, out_ch, kernel_size=1) self.attn = LinearAttentionDecoder(in_ch, out_ch, heads=heads, dim_head=out_ch//heads, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) self.bn2 = nn.BatchNorm2d(out_ch) self.relu = nn.ReLU(inplace=True) self.mlp = nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False) def forward(self, x1, x2): residue = F.interpolate(self.conv_ch(x1), size=x2.shape[-2:], mode='bilinear', align_corners=True) #x1: low-res, x2: high-res x1 = self.bn_l(x1) x2 = self.bn_h(x2) out, q_k_attn = self.attn(x2, x1) out = out + residue residue = out out = self.bn2(out) out = self.relu(out) out = self.mlp(out) out += residue return out ######################################################################## # Transformer components class LinearAttention(nn.Module): def __init__(self, dim, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True): super().__init__() self.inner_dim = dim_head * heads self.heads = heads self.scale = dim_head ** (-0.5) self.dim_head = dim_head self.reduce_size = reduce_size self.projection = projection self.rel_pos = rel_pos # depthwise conv is slightly better than conv1x1 #self.to_qkv = nn.Conv2d(dim, self.inner_dim*3, kernel_size=1, stride=1, padding=0, bias=True) #self.to_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, stride=1, padding=0, bias=True) self.to_qkv = depthwise_separable_conv(dim, self.inner_dim*3) self.to_out = depthwise_separable_conv(self.inner_dim, dim) self.attn_drop = nn.Dropout(attn_drop) self.proj_drop = nn.Dropout(proj_drop) if self.rel_pos: # 2D input-independent relative position encoding is a little 
bit better than # 1D input-denpendent counterpart self.relative_position_encoding = RelativePositionBias(heads, reduce_size, reduce_size) #self.relative_position_encoding = RelativePositionEmbedding(dim_head, reduce_size) def forward(self, x): B, C, H, W = x.shape #B, inner_dim, H, W qkv = self.to_qkv(x) q, k, v = qkv.chunk(3, dim=1) if self.projection == 'interp' and H != self.reduce_size: k, v = map(lambda t: F.interpolate(t, size=self.reduce_size, mode='bilinear', align_corners=True), (k, v)) elif self.projection == 'maxpool' and H != self.reduce_size: k, v = map(lambda t: F.adaptive_max_pool2d(t, output_size=self.reduce_size), (k, v)) q = rearrange(q, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads, h=H, w=W) k, v = map(lambda t: rearrange(t, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads, h=self.reduce_size, w=self.reduce_size), (k, v)) q_k_attn = torch.einsum('bhid,bhjd->bhij', q, k) if self.rel_pos: relative_position_bias = self.relative_position_encoding(H, W) q_k_attn += relative_position_bias #rel_attn_h, rel_attn_w = self.relative_position_encoding(q, self.heads, H, W, self.dim_head) #q_k_attn = q_k_attn + rel_attn_h + rel_attn_w q_k_attn *= self.scale q_k_attn = F.softmax(q_k_attn, dim=-1) q_k_attn = self.attn_drop(q_k_attn) out = torch.einsum('bhij,bhjd->bhid', q_k_attn, v) out = rearrange(out, 'b heads (h w) dim_head -> b (dim_head heads) h w', h=H, w=W, dim_head=self.dim_head, heads=self.heads) out = self.to_out(out) out = self.proj_drop(out) return out, q_k_attn class LinearAttentionDecoder(nn.Module): def __init__(self, in_dim, out_dim, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True): super().__init__() self.inner_dim = dim_head * heads self.heads = heads self.scale = dim_head ** (-0.5) self.dim_head = dim_head self.reduce_size = reduce_size self.projection = projection self.rel_pos = rel_pos # depthwise conv is slightly better than conv1x1 #self.to_kv = nn.Conv2d(dim, self.inner_dim*2, kernel_size=1, stride=1, padding=0, bias=True) #self.to_q = nn.Conv2d(dim, self.inner_dim, kernel_size=1, stride=1, padding=0, bias=True) #self.to_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, stride=1, padding=0, bias=True) self.to_kv = depthwise_separable_conv(in_dim, self.inner_dim*2) self.to_q = depthwise_separable_conv(out_dim, self.inner_dim) self.to_out = depthwise_separable_conv(self.inner_dim, out_dim) self.attn_drop = nn.Dropout(attn_drop) self.proj_drop = nn.Dropout(proj_drop) if self.rel_pos: self.relative_position_encoding = RelativePositionBias(heads, reduce_size, reduce_size) #self.relative_position_encoding = RelativePositionEmbedding(dim_head, reduce_size) def forward(self, q, x): B, C, H, W = x.shape # low-res feature shape BH, CH, HH, WH = q.shape # high-res feature shape k, v = self.to_kv(x).chunk(2, dim=1) #B, inner_dim, H, W q = self.to_q(q) #BH, inner_dim, HH, WH if self.projection == 'interp' and H != self.reduce_size: k, v = map(lambda t: F.interpolate(t, size=self.reduce_size, mode='bilinear', align_corners=True), (k, v)) elif self.projection == 'maxpool' and H != self.reduce_size: k, v = map(lambda t: F.adaptive_max_pool2d(t, output_size=self.reduce_size), (k, v)) q = rearrange(q, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads, h=HH, w=WH) k, v = map(lambda t: rearrange(t, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads, 
h=self.reduce_size, w=self.reduce_size), (k, v)) q_k_attn = torch.einsum('bhid,bhjd->bhij', q, k) if self.rel_pos: relative_position_bias = self.relative_position_encoding(HH, WH) q_k_attn += relative_position_bias #rel_attn_h, rel_attn_w = self.relative_position_encoding(q, self.heads, HH, WH, self.dim_head) #q_k_attn = q_k_attn + rel_attn_h + rel_attn_w q_k_attn *= self.scale q_k_attn = F.softmax(q_k_attn, dim=-1) q_k_attn = self.attn_drop(q_k_attn) out = torch.einsum('bhij,bhjd->bhid', q_k_attn, v) out = rearrange(out, 'b heads (h w) dim_head -> b (dim_head heads) h w', h=HH, w=WH, dim_head=self.dim_head, heads=self.heads) out = self.to_out(out) out = self.proj_drop(out) return out, q_k_attn class RelativePositionEmbedding(nn.Module): # input-dependent relative position def __init__(self, dim, shape): super().__init__() self.dim = dim self.shape = shape self.key_rel_w = nn.Parameter(torch.randn((2*self.shape-1, dim))*0.02) self.key_rel_h = nn.Parameter(torch.randn((2*self.shape-1, dim))*0.02) coords = torch.arange(self.shape) relative_coords = coords[None, :] - coords[:, None] # h, h relative_coords += self.shape - 1 # shift to start from 0 self.register_buffer('relative_position_index', relative_coords) def forward(self, q, Nh, H, W, dim_head): # q: B, Nh, HW, dim B, _, _, dim = q.shape # q: B, Nh, H, W, dim_head q = rearrange(q, 'b heads (h w) dim_head -> b heads h w dim_head', b=B, dim_head=dim_head, heads=Nh, h=H, w=W) rel_logits_w = self.relative_logits_1d(q, self.key_rel_w, 'w') rel_logits_h = self.relative_logits_1d(q.permute(0, 1, 3, 2, 4), self.key_rel_h, 'h') return rel_logits_w, rel_logits_h def relative_logits_1d(self, q, rel_k, case): B, Nh, H, W, dim = q.shape rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k) # B, Nh, H, W, 2*shape-1 if W != self.shape: # self_relative_position_index origin shape: w, w # after repeat: W, w relative_index= torch.repeat_interleave(self.relative_position_index, W//self.shape, dim=0) # W, shape relative_index = relative_index.view(1, 1, 1, W, self.shape) relative_index = relative_index.repeat(B, Nh, H, 1, 1) rel_logits = torch.gather(rel_logits, 4, relative_index) # B, Nh, H, W, shape rel_logits = rel_logits.unsqueeze(3) rel_logits = rel_logits.repeat(1, 1, 1, self.shape, 1, 1) if case == 'w': rel_logits = rearrange(rel_logits, 'b heads H h W w -> b heads (H W) (h w)') elif case == 'h': rel_logits = rearrange(rel_logits, 'b heads W w H h -> b heads (H W) (h w)') return rel_logits class RelativePositionBias(nn.Module): # input-independent relative position attention # As the number of parameters is smaller, so use 2D here # Borrowed some code from SwinTransformer: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py def __init__(self, num_heads, h, w): super().__init__() self.num_heads = num_heads self.h = h self.w = w self.relative_position_bias_table = nn.Parameter( torch.randn((2*h-1) * (2*w-1), num_heads)*0.02) coords_h = torch.arange(self.h) coords_w = torch.arange(self.w) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, h, w coords_flatten = torch.flatten(coords, 1) # 2, hw relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] relative_coords = relative_coords.permute(1, 2, 0).contiguous() relative_coords[:, :, 0] += self.h - 1 relative_coords[:, :, 1] += self.w - 1 relative_coords[:, :, 0] *= 2 * self.h - 1 relative_position_index = relative_coords.sum(-1) # hw, hw self.register_buffer("relative_position_index", relative_position_index) def forward(self, H, W): 
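# Note: the learned bias table lives on a fixed (h, w) grid (reduce_size x reduce_size).
# For an (H, W) grid of queries, each table entry is tiled H//h times along height and
# W//w times along width via repeat_interleave, producing a (1, num_heads, H*W, h*w)
# bias added to the q_k_attn logits; this assumes H and W are multiples of h and w.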
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.h, self.w, self.h*self.w, -1) #h, w, hw, nH relative_position_bias_expand_h = torch.repeat_interleave(relative_position_bias, H//self.h, dim=0) relative_position_bias_expanded = torch.repeat_interleave(relative_position_bias_expand_h, W//self.w, dim=1) #HW, hw, nH relative_position_bias_expanded = relative_position_bias_expanded.view(H*W, self.h*self.w, self.num_heads).permute(2, 0, 1).contiguous().unsqueeze(0) return relative_position_bias_expanded ########################################################################### # Unet Transformer building block class down_block_trans(nn.Module): def __init__(self, in_ch, out_ch, num_block, bottleneck=False, maxpool=True, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True): super().__init__() block_list = [] if bottleneck: block = BottleneckBlock else: block = BasicBlock attn_block = BasicTransBlock if maxpool: block_list.append(nn.MaxPool2d(2)) block_list.append(block(in_ch, out_ch, stride=1)) else: block_list.append(block(in_ch, out_ch, stride=2)) assert num_block > 0 for i in range(num_block): block_list.append(attn_block(out_ch, heads, dim_head, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)) self.blocks = nn.Sequential(*block_list) def forward(self, x): out = self.blocks(x) return out class up_block_trans(nn.Module): def __init__(self, in_ch, out_ch, num_block, bottleneck=False, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True): super().__init__() self.attn_decoder = BasicTransDecoderBlock(in_ch, out_ch, heads=heads, dim_head=dim_head, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) if bottleneck: block = BottleneckBlock else: block = BasicBlock attn_block = BasicTransBlock block_list = [] for i in range(num_block): block_list.append(attn_block(out_ch, heads, dim_head, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)) block_list.append(block(2*out_ch, out_ch, stride=1)) self.blocks = nn.Sequential(*block_list) def forward(self, x1, x2): # x1: low-res feature, x2: high-res feature out = self.attn_decoder(x1, x2) out = torch.cat([out, x2], dim=1) out = self.blocks(out) return out class block_trans(nn.Module): def __init__(self, in_ch, num_block, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True): super().__init__() block_list = [] attn_block = BasicTransBlock assert num_block > 0 for i in range(num_block): block_list.append(attn_block(in_ch, heads, dim_head, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)) self.blocks = nn.Sequential(*block_list) def forward(self, x): out = self.blocks(x) return out ================================================ FILE: model/resnet_utnet.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from .unet_utils import up_block from .transunet import ResNetV2 from .conv_trans_utils import block_trans import pdb class ResNet_UTNet(nn.Module): def __init__(self, in_ch, num_class, reduce_size=8, block_list='234', num_blocks=[1,2,4], projection='interp', num_heads=[4,4,4], attn_drop=0., proj_drop=0., rel_pos=True, block_units=(3,4,9), width_factor=1): 
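# Note: block_list digits select feature levels ('0' = highest-resolution features
# after the stem, '3' = the deepest features); num_blocks and num_heads are indexed
# from the end (level 0 -> num_blocks[-4], ..., level 3 -> num_blocks[-1]), so a
# 3-element list such as [1,1,1] pairs with block_list '123', while '0123' needs 4 entries.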
super().__init__() self.resnet = ResNetV2(block_units, width_factor) if '0' in block_list: self.trans_0 = block_trans(64, num_blocks[-4], 64//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) else: self.trans_0 = nn.Identity() if '1' in block_list: self.trans_1 = block_trans(256, num_blocks[-3], 256//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) else: self.trans_1 = nn.Identity() if '2' in block_list: self.trans_2 = block_trans(512, num_blocks[-2], 512//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) else: self.trans_2 = nn.Identity() if '3' in block_list: self.trans_3 = block_trans(1024, num_blocks[-1], 1024//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) else: self.trans_3 = nn.Identity() self.up1 = up_block(1024, 512, scale=(2,2), num_block=1) self.up2 = up_block(512, 256, scale=(2,2), num_block=1) self.up3 = up_block(256, 64, scale=(2,2), num_block=1) self.up4 = nn.UpsamplingBilinear2d(scale_factor=2) self.output = nn.Conv2d(64, num_class, kernel_size=3, padding=1, bias=True) def forward(self, x): if x.shape[1] == 1: x = x.repeat(1, 3, 1, 1) x, features = self.resnet(x) out3 = self.trans_3(x) out2 = self.trans_2(features[0]) out1 = self.trans_1(features[1]) out0 = self.trans_0(features[2]) out = self.up1(out3, out2) out = self.up2(out, out1) out = self.up3(out, out0) out = self.up4(out) out = self.output(out) return out ================================================ FILE: model/swin_unet.py ================================================ # code is borrowed from the original repo and fit into our training framework # https://github.com/HuCaoFighting/Swin-Unet/tree/4375a8d6fa7d9c38184c5d3194db990a00a3e912 from __future__ import absolute_import from __future__ import division from __future__ import print_function import torch import torch.nn as nn import torch.utils.checkpoint as checkpoint from einops import rearrange from timm.models.layers import DropPath, to_2tuple, trunc_normal_ import copy import logging import math from os.path import join as pjoin import numpy as np from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm from torch.nn.modules.utils import _pair from scipy import ndimage class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x def window_partition(x, window_size): """ Args: x: (B, H, W, C) window_size (int): window size Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows def window_reverse(windows, window_size, H, W): """ Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = 
int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class WindowAttention(nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. Default: 0.0 """ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) trunc_normal_(self.relative_position_bias_table, std=.02) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask=None): """ Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x def extra_repr(self) -> str: 
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' def flops(self, N): # calculate flops for 1 window with token length of N flops = 0 # qkv = self.qkv(x) flops += N * self.dim * 3 * self.dim # attn = (q @ k.transpose(-2, -1)) flops += self.num_heads * N * (self.dim // self.num_heads) * N # x = (attn @ v) flops += self.num_heads * N * N * (self.dim // self.num_heads) # x = self.proj(x) flops += N * self.dim * self.dim return flops class SwinTransformerBlock(nn.Module): r""" Swin Transformer Block. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resulotion. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.input_resolution = input_resolution self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio if min(self.input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) if self.shift_size > 0: # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None self.register_buffer("attn_mask", attn_mask) def forward(self, x): H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x # partition windows x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) x = x + self.drop_path(self.mlp(self.norm2(x))) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" def flops(self): flops = 0 H, W = self.input_resolution # norm1 flops += self.dim * H * W # W-MSA/SW-MSA nW = H * W / self.window_size / self.window_size flops += nW * self.attn.flops(self.window_size * self.window_size) # mlp flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # norm2 flops += self.dim * H * W return flops class PatchMerging(nn.Module): r""" Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): super().__init__() self.input_resolution = input_resolution self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) def forward(self, x): """ x: B, H*W, C """ H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 
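        # Merge each 2x2 neighborhood: the four strided slices below pick out
        # the even/odd rows and columns, so concatenating them channel-wise
        # turns an (H, W, C) map into (H/2, W/2, 4C); the linear reduction
        # defined in __init__ then projects 4C -> 2C, halving spatial
        # resolution while doubling channel width.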
        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops


class PatchExpand(nn.Module):
    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)
        x = x.view(B, -1, C // 4)
        x = self.norm(x)

        return x


class FinalPatchExpand_X4(nn.Module):
    def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.dim_scale = dim_scale
        self.expand = nn.Linear(dim, 16 * dim, bias=False)
        self.output_dim = dim
        self.norm = norm_layer(self.output_dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c',
                      p1=self.dim_scale, p2=self.dim_scale, c=C // (self.dim_scale ** 2))
        x = x.view(B, -1, self.output_dim)
        x = self.norm(x)

        return x


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
""" def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ SwinTransformerBlock(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x): for blk in self.blocks: if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) if self.downsample is not None: x = self.downsample(x) return x def extra_repr(self) -> str: return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" def flops(self): flops = 0 for blk in self.blocks: flops += blk.flops() if self.downsample is not None: flops += self.downsample.flops() return flops class BasicLayer_up(nn.Module): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. input_resolution (tuple[int]): Input resolution. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
""" def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): super().__init__() self.dim = dim self.input_resolution = input_resolution self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ SwinTransformerBlock(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth)]) # patch merging layer if upsample is not None: self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) else: self.upsample = None def forward(self, x): for blk in self.blocks: if self.use_checkpoint: x = checkpoint.checkpoint(blk, x) else: x = blk(x) if self.upsample is not None: x = self.upsample(x) return x class PatchEmbed(nn.Module): r""" Image to Patch Embedding Args: img_size (int): Image size. Default: 224. patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. Default: None """ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C if self.norm is not None: x = self.norm(x) return x def flops(self): Ho, Wo = self.patches_resolution flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops class SwinTransformerSys(nn.Module): r""" Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 Args: img_size (int | tuple(int)): Input image size. Default 224 patch_size (int | tuple(int)): Patch size. Default: 4 in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 embed_dim (int): Patch embedding dimension. Default: 96 depths (tuple(int)): Depth of each Swin Transformer layer. num_heads (tuple(int)): Number of attention heads in different layers. window_size (int): Window size. Default: 7 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 qkv_bias (bool): If True, add a learnable bias to query, key, value. 
Default: True qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None drop_rate (float): Dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False patch_norm (bool): If True, add normalization after patch embedding. Default: True use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False """ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, final_upsample="expand_first", **kwargs): super().__init__() print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths, depths_decoder,drop_path_rate,num_classes)) self.num_classes = num_classes self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) self.num_features_up = int(embed_dim * 2) self.mlp_ratio = mlp_ratio self.final_upsample = final_upsample # split image into non-overlapping patches self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution # absolute position embedding if self.ape: self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) trunc_normal_(self.absolute_pos_embed, std=.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build encoder and bottleneck layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), input_resolution=(patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)), depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint) self.layers.append(layer) # build decoder layers self.layers_up = nn.ModuleList() self.concat_back_dim = nn.ModuleList() for i_layer in range(self.num_layers): concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)), int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity() if i_layer ==0 : layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer) else: layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), 
patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), depth=depths[(self.num_layers-1-i_layer)], num_heads=num_heads[(self.num_layers-1-i_layer)], window_size=window_size, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], norm_layer=norm_layer, upsample=PatchExpand if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint) self.layers_up.append(layer_up) self.concat_back_dim.append(concat_linear) self.norm = norm_layer(self.num_features) self.norm_up= norm_layer(self.embed_dim) if self.final_upsample == "expand_first": print("---final upsample expand_first---") self.up = FinalPatchExpand_X4(input_resolution=(img_size//patch_size,img_size//patch_size),dim_scale=4,dim=embed_dim) self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=1,bias=False) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) @torch.jit.ignore def no_weight_decay(self): return {'absolute_pos_embed'} @torch.jit.ignore def no_weight_decay_keywords(self): return {'relative_position_bias_table'} #Encoder and Bottleneck def forward_features(self, x): x = self.patch_embed(x) if self.ape: x = x + self.absolute_pos_embed x = self.pos_drop(x) x_downsample = [] for layer in self.layers: x_downsample.append(x) x = layer(x) x = self.norm(x) # B L C return x, x_downsample #Dencoder and Skip connection def forward_up_features(self, x, x_downsample): for inx, layer_up in enumerate(self.layers_up): if inx == 0: x = layer_up(x) else: x = torch.cat([x,x_downsample[3-inx]],-1) x = self.concat_back_dim[inx](x) x = layer_up(x) x = self.norm_up(x) # B L C return x def up_x4(self, x): H, W = self.patches_resolution B, L, C = x.shape assert L == H*W, "input features has wrong size" if self.final_upsample=="expand_first": x = self.up(x) x = x.view(B,4*H,4*W,-1) x = x.permute(0,3,1,2) #B,C,H,W x = self.output(x) return x def forward(self, x): x, x_downsample = self.forward_features(x) x = self.forward_up_features(x,x_downsample) x = self.up_x4(x) return x def flops(self): flops = 0 flops += self.patch_embed.flops() for i, layer in enumerate(self.layers): flops += layer.flops() flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) flops += self.num_features * self.num_classes return flops logger = logging.getLogger(__name__) class SwinUnet_config(): def __init__(self): self.patch_size = 4 self.in_chans = 3 self.num_classes = 4 self.embed_dim = 96 self.depths = [2, 2, 6, 2] self.num_heads = [3, 6, 12, 24] self.window_size = 7 self.mlp_ratio = 4. self.qkv_bias = True self.qk_scale = None self.drop_rate = 0. 
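        # The defaults in this config mirror the Swin-T encoder (96-dim patch
        # embedding, depths [2, 2, 6, 2], heads [3, 6, 12, 24], 7x7 attention
        # windows); num_classes = 4 matches background plus the three M&Ms
        # structures segmented in train_deep.py.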
        self.drop_path_rate = 0.1
        self.ape = False
        self.patch_norm = True
        self.use_checkpoint = False


class SwinUnet(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(SwinUnet, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.config = config

        self.swin_unet = SwinTransformerSys(img_size=img_size,
                                            patch_size=config.patch_size,
                                            in_chans=config.in_chans,
                                            num_classes=self.num_classes,
                                            embed_dim=config.embed_dim,
                                            depths=config.depths,
                                            num_heads=config.num_heads,
                                            window_size=config.window_size,
                                            mlp_ratio=config.mlp_ratio,
                                            qkv_bias=config.qkv_bias,
                                            qk_scale=config.qk_scale,
                                            drop_rate=config.drop_rate,
                                            drop_path_rate=config.drop_path_rate,
                                            ape=config.ape,
                                            patch_norm=config.patch_norm,
                                            use_checkpoint=config.use_checkpoint)

    def forward(self, x):
        # grayscale inputs are repeated to the 3 channels the encoder expects
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        logits = self.swin_unet(x)
        return logits

    def load_from(self, pretrained_path):
        if pretrained_path is not None:
            print("pretrained_path:{}".format(pretrained_path))
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            pretrained_dict = torch.load(pretrained_path, map_location=device)
            if "model" not in pretrained_dict:
                print("---start load pretrained model by splitting---")
                pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()}
                for k in list(pretrained_dict.keys()):
                    if "output" in k:
                        print("delete key:{}".format(k))
                        del pretrained_dict[k]
                msg = self.swin_unet.load_state_dict(pretrained_dict, strict=False)
                # print(msg)
                return
            pretrained_dict = pretrained_dict['model']
            print("---start load pretrained model of swin encoder---")
            model_dict = self.swin_unet.state_dict()
            full_dict = copy.deepcopy(pretrained_dict)
            # mirror the pretrained encoder weights into the decoder:
            # encoder stage i initializes decoder stage 3 - i
            for k, v in pretrained_dict.items():
                if "layers." in k:
                    current_layer_num = 3 - int(k[7:8])
                    current_k = "layers_up." + str(current_layer_num) + k[8:]
                    full_dict.update({current_k: v})
            # drop entries whose shapes don't match the current model
            for k in list(full_dict.keys()):
                if k in model_dict:
                    if full_dict[k].shape != model_dict[k].shape:
                        print("delete:{};shape pretrain:{};shape model:{}".format(
                            k, full_dict[k].shape, model_dict[k].shape))
                        del full_dict[k]

            msg = self.swin_unet.load_state_dict(full_dict, strict=False)
            # print(msg)
        else:
            print("none pretrain")



================================================
FILE: model/transunet.py
================================================
# The code is borrowed from original repo: https://github.com/Beckschen/TransUNet/tree/d68a53a2da73ecb496bb7585340eb660ecda1d59
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ml_collections
import copy
import logging
import math

from collections import OrderedDict
from os.path import join as pjoin

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

logger = logging.getLogger(__name__)

ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"


def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)
        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer =
context_layer.view(*new_context_layer_shape) attention_output = self.out(context_layer) attention_output = self.proj_dropout(attention_output) return attention_output, weights class Mlp(nn.Module): def __init__(self, config): super(Mlp, self).__init__() self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) self.act_fn = ACT2FN["gelu"] self.dropout = Dropout(config.transformer["dropout_rate"]) self._init_weights() def _init_weights(self): nn.init.xavier_uniform_(self.fc1.weight) nn.init.xavier_uniform_(self.fc2.weight) nn.init.normal_(self.fc1.bias, std=1e-6) nn.init.normal_(self.fc2.bias, std=1e-6) def forward(self, x): x = self.fc1(x) x = self.act_fn(x) x = self.dropout(x) x = self.fc2(x) x = self.dropout(x) return x class Embeddings(nn.Module): """Construct the embeddings from patch, position embeddings. """ def __init__(self, config, img_size, in_channels=3): super(Embeddings, self).__init__() self.hybrid = None self.config = config img_size = _pair(img_size) if config.patches.get("grid") is not None: # ResNet grid_size = config.patches["grid"] patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) self.hybrid = True else: patch_size = _pair(config.patches["size"]) n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) self.hybrid = False if self.hybrid: self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) in_channels = self.hybrid_model.width * 16 self.patch_embeddings = Conv2d(in_channels=in_channels, out_channels=config.hidden_size, kernel_size=patch_size, stride=patch_size) self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) self.dropout = Dropout(config.transformer["dropout_rate"]) def forward(self, x): if self.hybrid: x, features = self.hybrid_model(x) else: features = None x = self.patch_embeddings(x) # (B, hidden. 
n_patches^(1/2), n_patches^(1/2)) x = x.flatten(2) x = x.transpose(-1, -2) # (B, n_patches, hidden) embeddings = x + self.position_embeddings embeddings = self.dropout(embeddings) return embeddings, features class Block(nn.Module): def __init__(self, config, vis): super(Block, self).__init__() self.hidden_size = config.hidden_size self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) self.ffn = Mlp(config) self.attn = Attention(config, vis) def forward(self, x): h = x x = self.attention_norm(x) x, weights = self.attn(x) x = x + h h = x x = self.ffn_norm(x) x = self.ffn(x) x = x + h return x, weights def load_from(self, weights, n_block): ROOT = f"Transformer/encoderblock_{n_block}" with torch.no_grad(): query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) self.attn.query.weight.copy_(query_weight) self.attn.key.weight.copy_(key_weight) self.attn.value.weight.copy_(value_weight) self.attn.out.weight.copy_(out_weight) self.attn.query.bias.copy_(query_bias) self.attn.key.bias.copy_(key_bias) self.attn.value.bias.copy_(value_bias) self.attn.out.bias.copy_(out_bias) mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() self.ffn.fc1.weight.copy_(mlp_weight_0) self.ffn.fc2.weight.copy_(mlp_weight_1) self.ffn.fc1.bias.copy_(mlp_bias_0) self.ffn.fc2.bias.copy_(mlp_bias_1) self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) class Encoder(nn.Module): def __init__(self, config, vis): super(Encoder, self).__init__() self.vis = vis self.layer = nn.ModuleList() self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) for _ in range(config.transformer["num_layers"]): layer = Block(config, vis) self.layer.append(copy.deepcopy(layer)) def forward(self, hidden_states): attn_weights = [] for layer_block in self.layer: hidden_states, weights = layer_block(hidden_states) if self.vis: attn_weights.append(weights) encoded = self.encoder_norm(hidden_states) return encoded, attn_weights class Transformer(nn.Module): def __init__(self, config, img_size, vis): super(Transformer, self).__init__() self.embeddings = Embeddings(config, img_size=img_size) self.encoder = Encoder(config, vis) def forward(self, input_ids): embedding_output, features = self.embeddings(input_ids) encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) return encoded, attn_weights, features class Conv2dReLU(nn.Sequential): def __init__( self, 
in_channels, out_channels, kernel_size, padding=0, stride=1, use_batchnorm=True, ): conv = nn.Conv2d( in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=not (use_batchnorm), ) relu = nn.ReLU(inplace=True) bn = nn.BatchNorm2d(out_channels) super(Conv2dReLU, self).__init__(conv, bn, relu) class DecoderBlock(nn.Module): def __init__( self, in_channels, out_channels, skip_channels=0, use_batchnorm=True, ): super().__init__() self.conv1 = Conv2dReLU( in_channels + skip_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm, ) self.conv2 = Conv2dReLU( out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm, ) self.up = nn.UpsamplingBilinear2d(scale_factor=2) def forward(self, x, skip=None): x = self.up(x) if skip is not None: x = torch.cat([x, skip], dim=1) x = self.conv1(x) x = self.conv2(x) return x class SegmentationHead(nn.Sequential): def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() super().__init__(conv2d, upsampling) class DecoderCup(nn.Module): def __init__(self, config): super().__init__() self.config = config head_channels = 512 self.conv_more = Conv2dReLU( config.hidden_size, head_channels, kernel_size=3, padding=1, use_batchnorm=True, ) decoder_channels = config.decoder_channels in_channels = [head_channels] + list(decoder_channels[:-1]) out_channels = decoder_channels if self.config.n_skip != 0: skip_channels = self.config.skip_channels for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip skip_channels[3-i]=0 else: skip_channels=[0,0,0,0] blocks = [ DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) ] self.blocks = nn.ModuleList(blocks) def forward(self, hidden_states, features=None): B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) x = hidden_states.permute(0, 2, 1) x = x.contiguous().view(B, hidden, h, w) x = self.conv_more(x) for i, decoder_block in enumerate(self.blocks): if features is not None: skip = features[i] if (i < self.config.n_skip) else None else: skip = None x = decoder_block(x, skip=skip) return x class VisionTransformer(nn.Module): def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): super(VisionTransformer, self).__init__() self.num_classes = num_classes self.zero_head = zero_head self.classifier = config.classifier self.transformer = Transformer(config, img_size, vis) self.decoder = DecoderCup(config) self.segmentation_head = SegmentationHead( in_channels=config['decoder_channels'][-1], out_channels=config['n_classes'], kernel_size=3, ) self.config = config def forward(self, x): if x.size()[1] == 1: x = x.repeat(1,3,1,1) x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) x = self.decoder(x, features) logits = self.segmentation_head(x) return logits def load_from(self, weights): with torch.no_grad(): res_weight = weights self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) 
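            # The block below restores the pretrained position embeddings.
            # When the pretrained token grid differs from the current input
            # grid, the grid embeddings are resized with a first-order
            # scipy.ndimage.zoom, and the leading class-token embedding is
            # dropped for the 'seg' classifier, which uses no [CLS] token.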
self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) posemb_new = self.transformer.embeddings.position_embeddings if posemb.size() == posemb_new.size(): self.transformer.embeddings.position_embeddings.copy_(posemb) elif posemb.size()[1]-1 == posemb_new.size()[1]: posemb = posemb[:, 1:] self.transformer.embeddings.position_embeddings.copy_(posemb) else: logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) ntok_new = posemb_new.size(1) if self.classifier == "seg": _, posemb_grid = posemb[:, :1], posemb[0, 1:] gs_old = int(np.sqrt(len(posemb_grid))) gs_new = int(np.sqrt(ntok_new)) print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) zoom = (gs_new / gs_old, gs_new / gs_old, 1) posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) posemb = posemb_grid self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) # Encoder whole for bname, block in self.transformer.encoder.named_children(): for uname, unit in block.named_children(): unit.load_from(weights, n_block=uname) if self.transformer.embeddings.hybrid: self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True)) gn_weight = np2th(res_weight["gn_root/scale"]).view(-1) gn_bias = np2th(res_weight["gn_root/bias"]).view(-1) self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): for uname, unit in block.named_children(): unit.load_from(res_weight, n_block=bname, n_unit=uname) def get_b16_config(): """Returns the ViT-B/16 configuration.""" config = ml_collections.ConfigDict() config.patches = ml_collections.ConfigDict({'size': (16, 16)}) config.hidden_size = 768 config.transformer = ml_collections.ConfigDict() config.transformer.mlp_dim = 3072 config.transformer.num_heads = 12 config.transformer.num_layers = 12 config.transformer.attention_dropout_rate = 0.0 config.transformer.dropout_rate = 0.1 config.classifier = 'seg' config.representation_size = None config.resnet_pretrained_path = None config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz' config.patch_size = 16 config.skip_channels = [0, 0, 0, 0] config.decoder_channels = (256, 128, 64, 16) config.n_classes = 2 config.activation = 'softmax' return config def get_testing(): """Returns a minimal configuration for testing.""" config = ml_collections.ConfigDict() config.patches = ml_collections.ConfigDict({'size': (16, 16)}) config.hidden_size = 1 config.transformer = ml_collections.ConfigDict() config.transformer.mlp_dim = 1 config.transformer.num_heads = 1 config.transformer.num_layers = 1 config.transformer.attention_dropout_rate = 0.0 config.transformer.dropout_rate = 0.1 config.classifier = 'token' config.representation_size = None return config def get_r50_b16_config(): """Returns the Resnet50 + ViT-B/16 configuration.""" config = get_b16_config() config.patches.grid = (16, 16) config.resnet = ml_collections.ConfigDict() config.resnet.num_layers = (3, 4, 9) config.resnet.width_factor = 1 config.classifier = 'seg' config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' config.decoder_channels = (256, 128, 64, 16) 
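    # skip_channels below are the widths of the ResNet50 feature maps reused
    # as decoder skips; with n_skip = 3, DecoderCup zeroes out the trailing
    # entry. A minimal usage sketch (illustrative only; assumes 224x224 inputs
    # and a 4-class task, with the grid reset as in the original TransUNet
    # training script):
    #
    #   cfg = CONFIGS['R50-ViT-B_16']
    #   cfg.n_classes = 4
    #   cfg.patches.grid = (14, 14)                 # 224 // 16
    #   net = VisionTransformer(cfg, img_size=224)
    #   logits = net(torch.randn(1, 3, 224, 224))   # -> (1, 4, 224, 224)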
config.skip_channels = [512, 256, 64, 16] config.n_classes = 2 config.n_skip = 3 config.activation = 'softmax' return config def get_b32_config(): """Returns the ViT-B/32 configuration.""" config = get_b16_config() config.patches.size = (32, 32) config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz' return config def get_l16_config(): """Returns the ViT-L/16 configuration.""" config = ml_collections.ConfigDict() config.patches = ml_collections.ConfigDict({'size': (16, 16)}) config.hidden_size = 1024 config.transformer = ml_collections.ConfigDict() config.transformer.mlp_dim = 4096 config.transformer.num_heads = 16 config.transformer.num_layers = 24 config.transformer.attention_dropout_rate = 0.0 config.transformer.dropout_rate = 0.1 config.representation_size = None # custom config.classifier = 'seg' config.resnet_pretrained_path = None config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz' config.decoder_channels = (256, 128, 64, 16) config.n_classes = 2 config.activation = 'softmax' return config def get_r50_l16_config(): """Returns the Resnet50 + ViT-L/16 configuration. customized """ config = get_l16_config() config.patches.grid = (16, 16) config.resnet = ml_collections.ConfigDict() config.resnet.num_layers = (3, 4, 9) config.resnet.width_factor = 1 config.classifier = 'seg' config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' config.decoder_channels = (256, 128, 64, 16) config.skip_channels = [512, 256, 64, 16] config.n_classes = 2 config.activation = 'softmax' return config def get_l32_config(): """Returns the ViT-L/32 configuration.""" config = get_l16_config() config.patches.size = (32, 32) return config def get_h14_config(): """Returns the ViT-L/16 configuration.""" config = ml_collections.ConfigDict() config.patches = ml_collections.ConfigDict({'size': (14, 14)}) config.hidden_size = 1280 config.transformer = ml_collections.ConfigDict() config.transformer.mlp_dim = 5120 config.transformer.num_heads = 16 config.transformer.num_layers = 32 config.transformer.attention_dropout_rate = 0.0 config.transformer.dropout_rate = 0.1 config.classifier = 'token' config.representation_size = None return config CONFIGS = { 'ViT-B_16': get_b16_config(), 'ViT-B_32': get_b32_config(), 'ViT-L_16': get_l16_config(), 'ViT-L_32': get_l32_config(), 'ViT-H_14': get_h14_config(), 'R50-ViT-B_16': get_r50_b16_config(), 'R50-ViT-L_16': get_r50_l16_config(), 'testing': get_testing(), } def np2th(weights, conv=False): """Possibly convert HWIO to OIHW.""" if conv: weights = weights.transpose([3, 2, 0, 1]) return torch.from_numpy(weights) class StdConv2d(nn.Conv2d): def forward(self, x): w = self.weight v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) w = (w - m) / torch.sqrt(v + 1e-5) return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups) def conv3x3(cin, cout, stride=1, groups=1, bias=False): return StdConv2d(cin, cout, kernel_size=3, stride=stride, padding=1, bias=bias, groups=groups) def conv1x1(cin, cout, stride=1, bias=False): return StdConv2d(cin, cout, kernel_size=1, stride=stride, padding=0, bias=bias) class PreActBottleneck(nn.Module): """Pre-activation (v2) bottleneck block. 
""" def __init__(self, cin, cout=None, cmid=None, stride=1): super().__init__() cout = cout or cin cmid = cmid or cout//4 self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) self.conv1 = conv1x1(cin, cmid, bias=False) self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) self.conv3 = conv1x1(cmid, cout, bias=False) self.relu = nn.ReLU(inplace=True) if (stride != 1 or cin != cout): # Projection also with pre-activation according to paper. self.downsample = conv1x1(cin, cout, stride, bias=False) self.gn_proj = nn.GroupNorm(cout, cout) def forward(self, x): # Residual branch residual = x if hasattr(self, 'downsample'): residual = self.downsample(x) residual = self.gn_proj(residual) # Unit's branch y = self.relu(self.gn1(self.conv1(x))) y = self.relu(self.gn2(self.conv2(y))) y = self.gn3(self.conv3(y)) y = self.relu(residual + y) return y def load_from(self, weights, n_block, n_unit): conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) self.conv1.weight.copy_(conv1_weight) self.conv2.weight.copy_(conv2_weight) self.conv3.weight.copy_(conv3_weight) self.gn1.weight.copy_(gn1_weight.view(-1)) self.gn1.bias.copy_(gn1_bias.view(-1)) self.gn2.weight.copy_(gn2_weight.view(-1)) self.gn2.bias.copy_(gn2_bias.view(-1)) self.gn3.weight.copy_(gn3_weight.view(-1)) self.gn3.bias.copy_(gn3_bias.view(-1)) if hasattr(self, 'downsample'): proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) self.downsample.weight.copy_(proj_conv_weight) self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) class ResNetV2(nn.Module): """Implementation of Pre-activation (v2) ResNet mode.""" def __init__(self, block_units, width_factor): super().__init__() width = int(64 * width_factor) self.width = width self.root = nn.Sequential(OrderedDict([ ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), ('gn', nn.GroupNorm(32, width, eps=1e-6)), ('relu', nn.ReLU(inplace=True)), # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) ])) self.body = nn.Sequential(OrderedDict([ ('block1', nn.Sequential(OrderedDict( [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], ))), ('block2', nn.Sequential(OrderedDict( [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], ))), ('block3', nn.Sequential(OrderedDict( [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, 
cmid=width*4)) for i in range(2, block_units[2] + 1)], ))), ])) def forward(self, x): features = [] b, c, in_size, _ = x.size() x = self.root(x) features.append(x) x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) for i in range(len(self.body)-1): x = self.body[i](x) right_size = int(in_size / 4 / (i+1)) if x.size()[2] != right_size: pad = right_size - x.size()[2] assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size) feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device) feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] else: feat = x features.append(feat) x = self.body[-1](x) return x, features[::-1] ================================================ FILE: model/unet_utils.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import pdb class DoubleConv(nn.Module): """(convolution => [BN] => ReLU) * 2""" def __init__(self, in_channels, out_channels, mid_channels=None): super().__init__() if not mid_channels: mid_channels = out_channels self.double_conv = nn.Sequential( nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), nn.BatchNorm2d(mid_channels), nn.ReLU(inplace=True), nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x): return self.double_conv(x) class Down(nn.Module): """Downscaling with maxpool then double conv""" def __init__(self, in_channels, out_channels): super().__init__() self.maxpool_conv = nn.Sequential( nn.MaxPool2d(2), DoubleConv(in_channels, out_channels) ) def forward(self, x): return self.maxpool_conv(x) class Up(nn.Module): """Upscaling then double conv""" def __init__(self, in_channels, out_channels, bilinear=True): super().__init__() # if bilinear, use the normal convolutions to reduce the number of channels if bilinear: self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) else: self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) self.conv = DoubleConv(in_channels, out_channels) def forward(self, x1, x2): x1 = self.up(x1) # input is CHW diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) # if you have padding issues, see # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd x = torch.cat([x2, x1], dim=1) return self.conv(x) class OutConv(nn.Module): def __init__(self, in_channels, out_channels): super(OutConv, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) def forward(self, x): return self.conv(x) def conv3x3(in_planes, out_planes, stride=1): return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) def conv1x1(in_planes, out_planes, stride=1): return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) class BasicBlock(nn.Module): def __init__(self, inplanes, planes, stride=1): super().__init__() self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = nn.BatchNorm2d(inplanes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = nn.BatchNorm2d(planes) self.shortcut = nn.Sequential() if stride != 1 or inplanes != planes: self.shortcut = nn.Sequential( 
                nn.BatchNorm2d(inplanes),
                self.relu,
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        # pre-activation residual block: BN -> ReLU -> conv, twice
        residue = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out += self.shortcut(residue)

        return out


class BottleneckBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        self.conv1 = conv1x1(inplanes, planes // 4, stride=1)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes // 4, planes // 4, stride=stride)
        self.bn2 = nn.BatchNorm2d(planes // 4)
        self.conv3 = conv1x1(planes // 4, planes, stride=1)
        self.bn3 = nn.BatchNorm2d(planes // 4)

        self.shortcut = nn.Sequential()
        if stride != 1 or inplanes != planes:
            self.shortcut = nn.Sequential(
                nn.BatchNorm2d(inplanes),
                self.relu,
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        residue = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        out += self.shortcut(residue)

        return out


class inconv(nn.Module):
    def __init__(self, in_ch, out_ch, bottleneck=False):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        if bottleneck:
            self.conv2 = BottleneckBlock(out_ch, out_ch)
        else:
            self.conv2 = BasicBlock(out_ch, out_ch)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        return out


class down_block(nn.Module):
    def __init__(self, in_ch, out_ch, scale, num_block, bottleneck=False, pool=True):
        super().__init__()
        block_list = []

        if bottleneck:
            block = BottleneckBlock
        else:
            block = BasicBlock

        # downsample either with a max-pool followed by a block, or with a
        # stride-2 first block
        if pool:
            block_list.append(nn.MaxPool2d(scale))
            block_list.append(block(in_ch, out_ch))
        else:
            block_list.append(block(in_ch, out_ch, stride=2))

        for i in range(num_block - 1):
            block_list.append(block(out_ch, out_ch, stride=1))

        self.conv = nn.Sequential(*block_list)

    def forward(self, x):
        return self.conv(x)


class up_block(nn.Module):
    def __init__(self, in_ch, out_ch, num_block, scale=(2, 2), bottleneck=False):
        super().__init__()
        self.scale = scale

        self.conv_ch = nn.Conv2d(in_ch, out_ch, kernel_size=1)

        if bottleneck:
            block = BottleneckBlock
        else:
            block = BasicBlock

        block_list = []
        block_list.append(block(2 * out_ch, out_ch))
        for i in range(num_block - 1):
            block_list.append(block(out_ch, out_ch))

        self.conv = nn.Sequential(*block_list)

    def forward(self, x1, x2):
        # upsample x1, reduce its channels, then fuse with the skip feature x2
        x1 = F.interpolate(x1, scale_factor=self.scale, mode='bilinear', align_corners=True)
        x1 = self.conv_ch(x1)

        out = torch.cat([x2, x1], dim=1)
        out = self.conv(out)

        return out



================================================
FILE: model/utnet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed by the F.interpolate calls in the aux-loss heads

from .unet_utils import up_block, down_block
from .conv_trans_utils import *

import pdb


class UTNet(nn.Module):

    def __init__(self, in_chan, base_chan, num_classes=1, reduce_size=8, block_list='234',
                 num_blocks=[1, 2, 4], projection='interp', num_heads=[2, 4, 8],
                 attn_drop=0., proj_drop=0., bottleneck=False, maxpool=True,
                 rel_pos=True, aux_loss=False):
        super().__init__()

        self.aux_loss = aux_loss
        self.inc = [BasicBlock(in_chan, base_chan)]
        if '0' in block_list:
            self.inc.append(BasicTransBlock(base_chan, heads=num_heads[-5],
                                            dim_head=base_chan // num_heads[-5],
                                            attn_drop=attn_drop, proj_drop=proj_drop,
                                            reduce_size=reduce_size, projection=projection,
                                            rel_pos=rel_pos))
            self.up4 =
up_block_trans(2*base_chan, base_chan, num_block=0, bottleneck=bottleneck, heads=num_heads[-4], dim_head=base_chan//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) else: self.inc.append(BasicBlock(base_chan, base_chan)) self.up4 = up_block(2*base_chan, base_chan, scale=(2,2), num_block=2) self.inc = nn.Sequential(*self.inc) if '1' in block_list: self.down1 = down_block_trans(base_chan, 2*base_chan, num_block=num_blocks[-4], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-4], dim_head=2*base_chan//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) self.up3 = up_block_trans(4*base_chan, 2*base_chan, num_block=0, bottleneck=bottleneck, heads=num_heads[-3], dim_head=2*base_chan//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) else: self.down1 = down_block(base_chan, 2*base_chan, (2,2), num_block=2) self.up3 = up_block(4*base_chan, 2*base_chan, scale=(2,2), num_block=2) if '2' in block_list: self.down2 = down_block_trans(2*base_chan, 4*base_chan, num_block=num_blocks[-3], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-3], dim_head=4*base_chan//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) self.up2 = up_block_trans(8*base_chan, 4*base_chan, num_block=0, bottleneck=bottleneck, heads=num_heads[-2], dim_head=4*base_chan//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) else: self.down2 = down_block(2*base_chan, 4*base_chan, (2, 2), num_block=2) self.up2 = up_block(8*base_chan, 4*base_chan, scale=(2,2), num_block=2) if '3' in block_list: self.down3 = down_block_trans(4*base_chan, 8*base_chan, num_block=num_blocks[-2], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-2], dim_head=8*base_chan//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) self.up1 = up_block_trans(16*base_chan, 8*base_chan, num_block=0, bottleneck=bottleneck, heads=num_heads[-1], dim_head=8*base_chan//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) else: self.down3 = down_block(4*base_chan, 8*base_chan, (2,2), num_block=2) self.up1 = up_block(16*base_chan, 8*base_chan, scale=(2,2), num_block=2) if '4' in block_list: self.down4 = down_block_trans(8*base_chan, 16*base_chan, num_block=num_blocks[-1], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-1], dim_head=16*base_chan//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) else: self.down4 = down_block(8*base_chan, 16*base_chan, (2,2), num_block=2) self.outc = nn.Conv2d(base_chan, num_classes, kernel_size=1, bias=True) if aux_loss: self.out1 = nn.Conv2d(8*base_chan, num_classes, kernel_size=1, bias=True) self.out2 = nn.Conv2d(4*base_chan, num_classes, kernel_size=1, bias=True) self.out3 = nn.Conv2d(2*base_chan, num_classes, kernel_size=1, bias=True) def forward(self, x): x1 = self.inc(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) if self.aux_loss: out = self.up1(x5, x4) out1 = F.interpolate(self.out1(out), size=x.shape[-2:], mode='bilinear', align_corners=True) out = self.up2(out, x3) out2 = F.interpolate(self.out2(out), 
size=x.shape[-2:], mode='bilinear', align_corners=True) out = self.up3(out, x2) out3 = F.interpolate(self.out3(out), size=x.shape[-2:], mode='bilinear', align_corners=True) out = self.up4(out, x1) out = self.outc(out) return out, out3, out2, out1 else: out = self.up1(x5, x4) out = self.up2(out, x3) out = self.up3(out, x2) out = self.up4(out, x1) out = self.outc(out) return out class UTNet_Encoderonly(nn.Module): def __init__(self, in_chan, base_chan, num_classes=1, reduce_size=8, block_list='234', num_blocks=[1, 2, 4], projection='interp', num_heads=[2,4,8], attn_drop=0., proj_drop=0., bottleneck=False, maxpool=True, rel_pos=True, aux_loss=False): super().__init__() self.aux_loss = aux_loss self.inc = [BasicBlock(in_chan, base_chan)] if '0' in block_list: self.inc.append(BasicTransBlock(base_chan, heads=num_heads[-5], dim_head=base_chan//num_heads[-5], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)) else: self.inc.append(BasicBlock(base_chan, base_chan)) self.inc = nn.Sequential(*self.inc) if '1' in block_list: self.down1 = down_block_trans(base_chan, 2*base_chan, num_block=num_blocks[-4], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-4], dim_head=2*base_chan//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) else: self.down1 = down_block(base_chan, 2*base_chan, (2,2), num_block=2) if '2' in block_list: self.down2 = down_block_trans(2*base_chan, 4*base_chan, num_block=num_blocks[-3], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-3], dim_head=4*base_chan//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) else: self.down2 = down_block(2*base_chan, 4*base_chan, (2, 2), num_block=2) if '3' in block_list: self.down3 = down_block_trans(4*base_chan, 8*base_chan, num_block=num_blocks[-2], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-2], dim_head=8*base_chan//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) else: self.down3 = down_block(4*base_chan, 8*base_chan, (2,2), num_block=2) if '4' in block_list: self.down4 = down_block_trans(8*base_chan, 16*base_chan, num_block=num_blocks[-1], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-1], dim_head=16*base_chan//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos) else: self.down4 = down_block(8*base_chan, 16*base_chan, (2,2), num_block=2) self.up1 = up_block(16*base_chan, 8*base_chan, scale=(2,2), num_block=2) self.up2 = up_block(8*base_chan, 4*base_chan, scale=(2,2), num_block=2) self.up3 = up_block(4*base_chan, 2*base_chan, scale=(2,2), num_block=2) self.up4 = up_block(2*base_chan, base_chan, scale=(2,2), num_block=2) self.outc = nn.Conv2d(base_chan, num_classes, kernel_size=1, bias=True) if aux_loss: self.out1 = nn.Conv2d(8*base_chan, num_classes, kernel_size=1, bias=True) self.out2 = nn.Conv2d(4*base_chan, num_classes, kernel_size=1, bias=True) self.out3 = nn.Conv2d(2*base_chan, num_classes, kernel_size=1, bias=True) def forward(self, x): x1 = self.inc(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) if self.aux_loss: out = self.up1(x5, x4) out1 = F.interpolate(self.out1(out), size=x.shape[-2:], mode='bilinear', align_corners=True) out = self.up2(out, x3) out2 = F.interpolate(self.out2(out), size=x.shape[-2:], mode='bilinear', 
class UTNet_Encoderonly(nn.Module):

    def __init__(self, in_chan, base_chan, num_classes=1, reduce_size=8, block_list='234', num_blocks=[1, 2, 4], projection='interp', num_heads=[2,4,8], attn_drop=0., proj_drop=0., bottleneck=False, maxpool=True, rel_pos=True, aux_loss=False):
        super().__init__()

        self.aux_loss = aux_loss
        self.inc = [BasicBlock(in_chan, base_chan)]
        if '0' in block_list:
            self.inc.append(BasicTransBlock(base_chan, heads=num_heads[-5], dim_head=base_chan//num_heads[-5], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos))
        else:
            self.inc.append(BasicBlock(base_chan, base_chan))
        self.inc = nn.Sequential(*self.inc)

        if '1' in block_list:
            self.down1 = down_block_trans(base_chan, 2*base_chan, num_block=num_blocks[-4], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-4], dim_head=2*base_chan//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        else:
            self.down1 = down_block(base_chan, 2*base_chan, (2,2), num_block=2)

        if '2' in block_list:
            self.down2 = down_block_trans(2*base_chan, 4*base_chan, num_block=num_blocks[-3], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-3], dim_head=4*base_chan//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        else:
            self.down2 = down_block(2*base_chan, 4*base_chan, (2, 2), num_block=2)

        if '3' in block_list:
            self.down3 = down_block_trans(4*base_chan, 8*base_chan, num_block=num_blocks[-2], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-2], dim_head=8*base_chan//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        else:
            self.down3 = down_block(4*base_chan, 8*base_chan, (2,2), num_block=2)

        if '4' in block_list:
            self.down4 = down_block_trans(8*base_chan, 16*base_chan, num_block=num_blocks[-1], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-1], dim_head=16*base_chan//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        else:
            self.down4 = down_block(8*base_chan, 16*base_chan, (2,2), num_block=2)

        self.up1 = up_block(16*base_chan, 8*base_chan, scale=(2,2), num_block=2)
        self.up2 = up_block(8*base_chan, 4*base_chan, scale=(2,2), num_block=2)
        self.up3 = up_block(4*base_chan, 2*base_chan, scale=(2,2), num_block=2)
        self.up4 = up_block(2*base_chan, base_chan, scale=(2,2), num_block=2)

        self.outc = nn.Conv2d(base_chan, num_classes, kernel_size=1, bias=True)

        if aux_loss:
            self.out1 = nn.Conv2d(8*base_chan, num_classes, kernel_size=1, bias=True)
            self.out2 = nn.Conv2d(4*base_chan, num_classes, kernel_size=1, bias=True)
            self.out3 = nn.Conv2d(2*base_chan, num_classes, kernel_size=1, bias=True)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        if self.aux_loss:
            out = self.up1(x5, x4)
            out1 = F.interpolate(self.out1(out), size=x.shape[-2:], mode='bilinear', align_corners=True)

            out = self.up2(out, x3)
            out2 = F.interpolate(self.out2(out), size=x.shape[-2:], mode='bilinear', align_corners=True)

            out = self.up3(out, x2)
            out3 = F.interpolate(self.out3(out), size=x.shape[-2:], mode='bilinear', align_corners=True)

            out = self.up4(out, x1)
            out = self.outc(out)

            return out, out3, out2, out1
        else:
            out = self.up1(x5, x4)
            out = self.up2(out, x3)
            out = self.up3(out, x2)
            out = self.up4(out, x1)
            out = self.outc(out)

            return out
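The training script below weights each of the four outputs by a factor from `--aux_weight` (default `[1, 0.4, 0.2, 0.1]`) and sums a CrossEntropy + Dice term per output. A minimal sketch of that combination, with `ce` and `dice_loss` standing in for the script's `criterion` and `criterion_dl`:

```
def deep_supervision_loss(outputs, label, ce, dice_loss, aux_weight=(1, 0.4, 0.2, 0.1)):
    # outputs: (out, out3, out2, out1) as returned by UTNet(..., aux_loss=True)
    loss = 0
    for w, out in zip(aux_weight, outputs):
        loss = loss + w * (ce(out, label.squeeze(1)) + dice_loss(out, label))
    return loss
```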
================================================
FILE: train_deep.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np

from model.utnet import UTNet, UTNet_Encoderonly
from dataset_domain import CMRDataset

from torch.utils import data
from losses import DiceLoss
from utils.utils import *
from utils import metrics

from optparse import OptionParser
import SimpleITK as sitk
from torch.utils.tensorboard import SummaryWriter

import time
import math
import os
import sys
import pdb

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

DEBUG = False


def train_net(net, options):
    data_path = options.data_path

    trainset = CMRDataset(data_path, mode='train', domain=options.domain, debug=DEBUG, scale=options.scale, rotate=options.rotate, crop_size=options.crop_size)
    trainLoader = data.DataLoader(trainset, batch_size=options.batch_size, shuffle=True, num_workers=16)

    testset_A = CMRDataset(data_path, mode='test', domain='A', debug=DEBUG, crop_size=options.crop_size)
    testLoader_A = data.DataLoader(testset_A, batch_size=1, shuffle=False, num_workers=2)
    testset_B = CMRDataset(data_path, mode='test', domain='B', debug=DEBUG, crop_size=options.crop_size)
    testLoader_B = data.DataLoader(testset_B, batch_size=1, shuffle=False, num_workers=2)
    testset_C = CMRDataset(data_path, mode='test', domain='C', debug=DEBUG, crop_size=options.crop_size)
    testLoader_C = data.DataLoader(testset_C, batch_size=1, shuffle=False, num_workers=2)
    testset_D = CMRDataset(data_path, mode='test', domain='D', debug=DEBUG, crop_size=options.crop_size)
    testLoader_D = data.DataLoader(testset_D, batch_size=1, shuffle=False, num_workers=2)

    writer = SummaryWriter(options.log_path + options.unique_name)

    optimizer = optim.SGD(net.parameters(), lr=options.lr, momentum=0.9, weight_decay=options.weight_decay)
    criterion = nn.CrossEntropyLoss(weight=torch.tensor(options.weight).cuda())
    criterion_dl = DiceLoss()

    best_dice = 0
    for epoch in range(options.epochs):
        print('Starting epoch {}/{}'.format(epoch+1, options.epochs))
        epoch_loss = 0

        exp_scheduler = exp_lr_scheduler_with_warmup(optimizer, init_lr=options.lr, epoch=epoch, warmup_epoch=5, max_epoch=options.epochs)
        print('current lr:', exp_scheduler)

        for i, (img, label) in enumerate(trainLoader, 0):
            img = img.cuda()
            label = label.cuda()

            end = time.time()
            net.train()

            optimizer.zero_grad()
            result = net(img)

            loss = 0
            if isinstance(result, tuple) or isinstance(result, list):
                # deep supervision: each output contributes a weighted CE + Dice term
                for j in range(len(result)):
                    loss += options.aux_weight[j] * (criterion(result[j], label.squeeze(1)) + criterion_dl(result[j], label))
            else:
                loss = criterion(result, label.squeeze(1)) + criterion_dl(result, label)

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            batch_time = time.time() - end
            print('batch loss: %.5f, batch_time:%.5f'%(loss.item(), batch_time))

        print('[epoch %d] epoch loss: %.5f'%(epoch+1, epoch_loss/(i+1)))

        writer.add_scalar('Train/Loss', epoch_loss/(i+1), epoch+1)
        writer.add_scalar('LR', exp_scheduler, epoch+1)

        if not os.path.isdir('%s%s/'%(options.cp_path, options.unique_name)):
            os.mkdir('%s%s/'%(options.cp_path, options.unique_name))

        if epoch % 20 == 0 or epoch > options.epochs-10:
            torch.save(net.state_dict(), '%s%s/CP%d.pth'%(options.cp_path, options.unique_name, epoch))

        if (epoch+1) > 90 or (epoch+1) % 10 == 0:
            dice_list_A, ASD_list_A, HD_list_A = validation(net, testLoader_A, options)
            log_evaluation_result(writer, dice_list_A, ASD_list_A, HD_list_A, 'A', epoch)
            dice_list_B, ASD_list_B, HD_list_B = validation(net, testLoader_B, options)
            log_evaluation_result(writer, dice_list_B, ASD_list_B, HD_list_B, 'B', epoch)
            dice_list_C, ASD_list_C, HD_list_C = validation(net, testLoader_C, options)
            log_evaluation_result(writer, dice_list_C, ASD_list_C, HD_list_C, 'C', epoch)
            dice_list_D, ASD_list_D, HD_list_D = validation(net, testLoader_D, options)
            log_evaluation_result(writer, dice_list_D, ASD_list_D, HD_list_D, 'D', epoch)

            # weighted average over the four vendors (weights 20/50/50/50, normalized by 170)
            AVG_dice_list = 20 * dice_list_A + 50 * dice_list_B + 50 * dice_list_C + 50 * dice_list_D
            AVG_dice_list /= 170
            AVG_ASD_list = 20 * ASD_list_A + 50 * ASD_list_B + 50 * ASD_list_C + 50 * ASD_list_D
            AVG_ASD_list /= 170
            AVG_HD_list = 20 * HD_list_A + 50 * HD_list_B + 50 * HD_list_C + 50 * HD_list_D
            AVG_HD_list /= 170

            log_evaluation_result(writer, AVG_dice_list, AVG_ASD_list, AVG_HD_list, 'mean', epoch)

            if dice_list_A.mean() >= best_dice:
                best_dice = dice_list_A.mean()
                torch.save(net.state_dict(), '%s%s/best.pth'%(options.cp_path, options.unique_name))
                print('save done')
            print('dice: %.5f/best dice: %.5f'%(dice_list_A.mean(), best_dice))
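# Editor's sketch (not called anywhere in this script): the schedule that
# exp_lr_scheduler_with_warmup (defined in utils/utils.py) produces above, an
# exponential ramp over the first `warmup_epoch` epochs followed by a
# power-0.9 polynomial decay.
def _demo_lr_curve(init_lr=0.05, warmup_epoch=5, max_epoch=150):
    lrs = []
    for epoch in range(max_epoch):
        if epoch <= warmup_epoch:
            lr = init_lr * 2.718 ** (10 * (float(epoch) / float(warmup_epoch) - 1.))
            if epoch == warmup_epoch:
                lr = init_lr
        else:
            lr = init_lr * (1 - epoch / max_epoch) ** 0.9
        lrs.append(lr)
    return lrs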
def validation(net, test_loader, options):
    net.eval()

    dice_list = np.zeros(3)
    ASD_list = np.zeros(3)
    HD_list = np.zeros(3)

    counter = 0
    with torch.no_grad():
        for i, (data, label, spacing) in enumerate(test_loader):
            inputs, labels = data.float().cuda(), label.long().cuda()
            # treat the z-slices of the 3D volume as the batch dimension of the 2D network
            inputs = inputs.permute(1, 0, 2, 3)
            labels = labels.permute(1, 0, 2, 3)
            pred = net(inputs)

            if options.model == 'FCN_Res50' or options.model == 'FCN_Res101':
                pred = pred['out']
            elif isinstance(pred, tuple):
                pred = pred[0]

            pred = F.softmax(pred, dim=1)
            _, label_pred = torch.max(pred, dim=1)

            tmp_ASD_list, tmp_HD_list = cal_distance(label_pred, labels, spacing)
            ASD_list += np.clip(np.nan_to_num(tmp_ASD_list, nan=500), 0, 500)
            HD_list += np.clip(np.nan_to_num(tmp_HD_list, nan=500), 0, 500)

            label_pred = label_pred.view(-1, 1)
            label_true = labels.view(-1, 1)
            dice, _, _ = cal_dice(label_pred, label_true, 4)
            dice_list += dice.cpu().numpy()[1:]

            counter += 1

    dice_list /= counter
    ASD_list /= counter
    HD_list /= counter

    return dice_list, ASD_list, HD_list


def cal_distance(label_pred, label_true, spacing):
    label_pred = label_pred.squeeze(1).cpu().numpy()
    label_true = label_true.squeeze(1).cpu().numpy()
    spacing = spacing.numpy()[0]

    ASD_list = np.zeros(3)
    HD_list = np.zeros(3)

    for i in range(3):
        tmp_surface = metrics.compute_surface_distances(label_true==(i+1), label_pred==(i+1), spacing)
        dis_gt_to_pred, dis_pred_to_gt = metrics.compute_average_surface_distance(tmp_surface)
        ASD_list[i] = (dis_gt_to_pred + dis_pred_to_gt) / 2

        HD = metrics.compute_robust_hausdorff(tmp_surface, 100)
        HD_list[i] = HD

    return ASD_list, HD_list
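# Editor's sketch: validation() above evaluates one 3D volume per batch by
# permuting the (1, D, H, W) tensor into (D, 1, H, W), so the z-slices become
# the batch dimension of the 2D network. With dummy tensors:
def _demo_slice_batching():
    volume = torch.zeros(1, 12, 256, 256)   # (batch=1, depth, H, W) from the loader
    slices = volume.permute(1, 0, 2, 3)     # (depth, channel=1, H, W) fed to the net
    return slices.shape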
if __name__ == '__main__':
    parser = OptionParser()

    def get_comma_separated_int_args(option, opt, value, parser):
        value_list = value.split(',')
        value_list = [int(i) for i in value_list]
        setattr(parser.values, option.dest, value_list)

    parser.add_option('-e', '--epochs', dest='epochs', default=150, type='int', help='number of epochs')
    parser.add_option('-b', '--batch_size', dest='batch_size', default=32, type='int', help='batch size')
    parser.add_option('-l', '--learning-rate', dest='lr', default=0.05, type='float', help='learning rate')
    parser.add_option('-c', '--resume', type='str', dest='load', default=False, help='load pretrained model')
    parser.add_option('-p', '--checkpoint-path', type='str', dest='cp_path', default='./checkpoint/', help='checkpoint path')
    parser.add_option('--data_path', type='str', dest='data_path', default='/research/cbim/vast/yg397/vision_transformer/dataset/resampled_dataset/', help='dataset path')
    parser.add_option('-o', '--log-path', type='str', dest='log_path', default='./log/', help='log path')
    parser.add_option('-m', type='str', dest='model', default='UTNet', help='use which model')
    parser.add_option('--num_class', type='int', dest='num_class', default=4, help='number of segmentation classes')
    parser.add_option('--base_chan', type='int', dest='base_chan', default=32, help='number of channels of first expansion in UNet')
    parser.add_option('-u', '--unique_name', type='str', dest='unique_name', default='test', help='unique experiment name')
    parser.add_option('--rlt', type='float', dest='rlt', default=1, help='relation between CE/FL and dice')
    parser.add_option('--weight', type='float', dest='weight', default=[0.5,1,1,1], help='weight each class in loss function')
    parser.add_option('--weight_decay', type='float', dest='weight_decay', default=0.0001)
    parser.add_option('--scale', type='float', dest='scale', default=0.30)
    parser.add_option('--rotate', type='float', dest='rotate', default=180)
    parser.add_option('--crop_size', type='int', dest='crop_size', default=256)
    parser.add_option('--domain', type='str', dest='domain', default='A')
    parser.add_option('--aux_weight', type='float', dest='aux_weight', default=[1, 0.4, 0.2, 0.1])
    parser.add_option('--reduce_size', dest='reduce_size', default=8, type='int')
    parser.add_option('--block_list', dest='block_list', default='1234', type='str')
    parser.add_option('--num_blocks', dest='num_blocks', default=[1,1,1,1], type='string', action='callback', callback=get_comma_separated_int_args)
    parser.add_option('--aux_loss', dest='aux_loss', action='store_true', help='using aux loss for deep supervision')
    parser.add_option('--gpu', type='str', dest='gpu', default='0')

    options, args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = options.gpu

    print('Using model:', options.model)

    if options.model == 'UTNet':
        net = UTNet(1, options.base_chan, options.num_class, reduce_size=options.reduce_size, block_list=options.block_list, num_blocks=options.num_blocks, num_heads=[4,4,4,4], projection='interp', attn_drop=0.1, proj_drop=0.1, rel_pos=True, aux_loss=options.aux_loss, maxpool=True)
    elif options.model == 'UTNet_encoder':
        # Apply transformer blocks only in the encoder
        net = UTNet_Encoderonly(1, options.base_chan, options.num_class, reduce_size=options.reduce_size, block_list=options.block_list, num_blocks=options.num_blocks, num_heads=[4,4,4,4], projection='interp', attn_drop=0.1, proj_drop=0.1, rel_pos=True, aux_loss=options.aux_loss, maxpool=True)
    elif options.model == 'TransUNet':
        from model.transunet import VisionTransformer as ViT_seg
        from model.transunet import CONFIGS as CONFIGS_ViT_seg

        config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']
        config_vit.n_classes = 4
        config_vit.n_skip = 3
        config_vit.patches.grid = (int(256/16), int(256/16))
        net = ViT_seg(config_vit, img_size=256, num_classes=4)
        #net.load_from(weights=np.load('./initmodel/R50+ViT-B_16.npz')) # uncomment this to use pretrain model download from TransUnet git repo
    elif options.model == 'ResNet_UTNet':
        from model.resnet_utnet import ResNet_UTNet
        net = ResNet_UTNet(1, options.num_class, reduce_size=options.reduce_size, block_list=options.block_list, num_blocks=options.num_blocks, num_heads=[4,4,4,4], projection='interp', attn_drop=0.1, proj_drop=0.1, rel_pos=True)
    elif options.model == 'SwinUNet':
        from model.swin_unet import SwinUnet, SwinUnet_config
        config = SwinUnet_config()
        net = SwinUnet(config, img_size=224, num_classes=options.num_class)
        net.load_from('./initmodel/swin_tiny_patch4_window7_224.pth')
    else:
        raise NotImplementedError(options.model + " has not been implemented")

    if options.load:
        net.load_state_dict(torch.load(options.load))
        print('Model loaded from {}'.format(options.load))

    param_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(net)
    print(param_num)

    net.cuda()

    train_net(net, options)

    print('done')
    sys.exit(0)


================================================
FILE: utils/__init__.py
================================================


================================================
FILE: utils/lookup_tables.py
================================================
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Lookup tables used by surface distance metrics."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numpy as np

ENCODE_NEIGHBOURHOOD_3D_KERNEL = np.array([[[128, 64], [32, 16]], [[8, 4], [2, 1]]])

# _NEIGHBOUR_CODE_TO_NORMALS is a lookup table.
# For every binary neighbour code
# (2x2x2 neighbourhood = 8 neighbours = 8 bits = 256 codes)
# it contains the surface normals of the triangles (called "surfel" for
# "surface element" in the following). The length of the normal
# vector encodes the surfel area.
#
# created using the marching_cube algorithm
# see e.g.
https://en.wikipedia.org/wiki/Marching_cubes # pylint: disable=line-too-long _NEIGHBOUR_CODE_TO_NORMALS = [ [[0, 0, 0]], [[0.125, 0.125, 0.125]], [[-0.125, -0.125, 0.125]], [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]], [[0.125, -0.125, 0.125]], [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]], [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]], [[-0.125, 0.125, 0.125]], [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125]], [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]], [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]], [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0]], [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125]], [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]], [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0]], [[0.125, -0.125, -0.125]], [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25]], [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]], [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]], [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125]], [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125]], [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]], [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]], [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125]], [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125]], [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125]], [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]], [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]], [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]], [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]], [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25]], [[0.125, -0.125, 0.125]], [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]], [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25]], [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25]], [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]], [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125]], [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]], [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125]], [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]], [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]], [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]], [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]], [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]], [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]], [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]], [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]], [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]], [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25]], [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0]], [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125]], [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]], [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]], [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]], [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]], [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, 
-0.125, 0.125], [-0.125, -0.125, 0.125]], [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]], [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]], [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]], [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125]], [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]], [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]], [[-0.125, -0.125, 0.125]], [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]], [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]], [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25]], [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]], [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125]], [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]], [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]], [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]], [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]], [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]], [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]], [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]], [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]], [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]], [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]], [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]], [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]], [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]], [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5]], [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]], [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]], [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125]], [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]], [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]], [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]], [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]], [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]], [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25]], [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]], [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]], [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125]], [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]], [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125]], [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]], [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]], [[-0.0, 0.5, 0.0], [-0.25, 
0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]], [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125]], [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125]], [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125]], [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]], [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]], [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]], [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]], [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5]], [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]], [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25]], [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]], [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]], [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125]], [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]], [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125]], [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]], [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125]], [[0.125, 0.125, 0.125]], [[0.125, 0.125, 0.125]], [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125]], [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]], [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125]], [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]], [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125]], [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]], [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]], [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25]], [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]], [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5]], [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]], [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]], [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]], [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]], [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125]], [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125]], [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125]], [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]], [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]], [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]], [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125]], [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]], [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125]], [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]], [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]], [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], [[0.25, 0.0, 0.25], [0.25, 
0.0, 0.25]], [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]], [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]], [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]], [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]], [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]], [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125]], [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]], [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]], [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5]], [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]], [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]], [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]], [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]], [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]], [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]], [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]], [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]], [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]], [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]], [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]], [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]], [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]], [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125]], [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]], [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25]], [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]], [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]], [[-0.125, -0.125, 0.125]], [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]], [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]], [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125]], [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]], [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]], [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]], [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]], [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]], [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]], [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]], [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125]], [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0]], [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25]], [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]], [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]], [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]], [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]], [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]], [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]], [[0.25, 0.25, -0.25], 
[0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]], [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]], [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]], [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125]], [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]], [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125]], [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]], [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25]], [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25]], [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]], [[0.125, -0.125, 0.125]], [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25]], [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]], [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]], [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]], [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]], [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125]], [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125]], [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125]], [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]], [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]], [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125]], [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125]], [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]], [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]], [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25]], [[0.125, -0.125, -0.125]], [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0]], [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]], [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125]], [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0]], [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]], [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]], [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125]], [[-0.125, 0.125, 0.125]], [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]], [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]], [[0.125, -0.125, 0.125]], [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]], [[-0.125, -0.125, 0.125]], [[0.125, 0.125, 0.125]], [[0, 0, 0]]] # pylint: enable=line-too-long def create_table_neighbour_code_to_surface_area(spacing_mm): """Returns an array mapping neighbourhood code to the surface elements area. Note that the normals encode the initial surface area. This function computes the area corresponding to the given `spacing_mm`. Args: spacing_mm: 3-element list-like structure. Voxel spacing in x0, x1 and x2 direction. 
""" # compute the area for all 256 possible surface elements # (given a 2x2x2 neighbourhood) according to the spacing_mm neighbour_code_to_surface_area = np.zeros([256]) for code in range(256): normals = np.array(_NEIGHBOUR_CODE_TO_NORMALS[code]) sum_area = 0 for normal_idx in range(normals.shape[0]): # normal vector n = np.zeros([3]) n[0] = normals[normal_idx, 0] * spacing_mm[1] * spacing_mm[2] n[1] = normals[normal_idx, 1] * spacing_mm[0] * spacing_mm[2] n[2] = normals[normal_idx, 2] * spacing_mm[0] * spacing_mm[1] area = np.linalg.norm(n) sum_area += area neighbour_code_to_surface_area[code] = sum_area return neighbour_code_to_surface_area # In the neighbourhood, points are ordered: top left, top right, bottom left, # bottom right. ENCODE_NEIGHBOURHOOD_2D_KERNEL = np.array([[8, 4], [2, 1]]) def create_table_neighbour_code_to_contour_length(spacing_mm): """Returns an array mapping neighbourhood code to the contour length. For the list of possible cases and their figures, see page 38 from: https://nccastaff.bournemouth.ac.uk/jmacey/MastersProjects/MSc14/06/thesis.pdf In 2D, each point has 4 neighbors. Thus, are 16 configurations. A configuration is encoded with '1' meaning "inside the object" and '0' "outside the object". The points are ordered: top left, top right, bottom left, bottom right. The x0 axis is assumed vertical downward, and the x1 axis is horizontal to the right: (0, 0) --> (0, 1) | (1, 0) Args: spacing_mm: 2-element list-like structure. Voxel spacing in x0 and x1 directions. """ neighbour_code_to_contour_length = np.zeros([16]) vertical = spacing_mm[0] horizontal = spacing_mm[1] diag = 0.5 * math.sqrt(spacing_mm[0]**2 + spacing_mm[1]**2) # pyformat: disable neighbour_code_to_contour_length[int("00" "01", 2)] = diag neighbour_code_to_contour_length[int("00" "10", 2)] = diag neighbour_code_to_contour_length[int("00" "11", 2)] = horizontal neighbour_code_to_contour_length[int("01" "00", 2)] = diag neighbour_code_to_contour_length[int("01" "01", 2)] = vertical neighbour_code_to_contour_length[int("01" "10", 2)] = 2*diag neighbour_code_to_contour_length[int("01" "11", 2)] = diag neighbour_code_to_contour_length[int("10" "00", 2)] = diag neighbour_code_to_contour_length[int("10" "01", 2)] = 2*diag neighbour_code_to_contour_length[int("10" "10", 2)] = vertical neighbour_code_to_contour_length[int("10" "11", 2)] = diag neighbour_code_to_contour_length[int("11" "00", 2)] = horizontal neighbour_code_to_contour_length[int("11" "01", 2)] = diag neighbour_code_to_contour_length[int("11" "10", 2)] = diag # pyformat: enable return neighbour_code_to_contour_length ================================================ FILE: utils/metrics.py ================================================ # Copyright 2018 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS-IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Module exposing surface distance based measures.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from . 
================================================
FILE: utils/metrics.py
================================================
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module exposing surface distance based measures."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from . import lookup_tables  # pylint: disable=relative-beyond-top-level
import numpy as np
from scipy import ndimage


def _assert_is_numpy_array(name, array):
  """Raises an exception if `array` is not a numpy array."""
  if not isinstance(array, np.ndarray):
    raise ValueError("The argument {!r} should be a numpy array, not a "
                     "{}".format(name, type(array)))


def _check_nd_numpy_array(name, array, num_dims):
  """Raises an exception if `array` is not a `num_dims`-D numpy array."""
  if len(array.shape) != num_dims:
    raise ValueError("The argument {!r} should be a {}D array, not of "
                     "shape {}".format(name, num_dims, array.shape))


def _check_2d_numpy_array(name, array):
  _check_nd_numpy_array(name, array, num_dims=2)


def _check_3d_numpy_array(name, array):
  _check_nd_numpy_array(name, array, num_dims=3)


def _assert_is_bool_numpy_array(name, array):
  _assert_is_numpy_array(name, array)
  if array.dtype != bool:  # plain bool: np.bool was removed in NumPy 1.24
    raise ValueError("The argument {!r} should be a numpy array of type bool, "
                     "not {}".format(name, array.dtype))


def _compute_bounding_box(mask):
  """Computes the bounding box of the masks.

  This function generalizes to arbitrary number of dimensions greater than or
  equal to 1.

  Args:
    mask: The 2D or 3D numpy mask, where '0' means background and non-zero
      means foreground.

  Returns:
    A tuple:
     - The coordinates of the first point of the bounding box (smallest on all
       axes), or `None` if the mask contains only zeros.
     - The coordinates of the second point of the bounding box (greatest on all
       axes), or `None` if the mask contains only zeros.
  """
  num_dims = len(mask.shape)
  bbox_min = np.zeros(num_dims, np.int64)
  bbox_max = np.zeros(num_dims, np.int64)

  # max projection to the x0-axis
  proj_0 = np.amax(mask, axis=tuple(range(num_dims))[1:])
  idx_nonzero_0 = np.nonzero(proj_0)[0]
  if len(idx_nonzero_0) == 0:  # pylint: disable=g-explicit-length-test
    return None, None

  bbox_min[0] = np.min(idx_nonzero_0)
  bbox_max[0] = np.max(idx_nonzero_0)

  # max projection to the i-th-axis for i in {1, ..., num_dims - 1}
  for axis in range(1, num_dims):
    max_over_axes = list(range(num_dims))  # Python 3 compatible
    max_over_axes.pop(axis)  # Remove the i-th dimension from the max
    max_over_axes = tuple(max_over_axes)  # numpy expects a tuple of ints
    proj = np.amax(mask, axis=max_over_axes)
    idx_nonzero = np.nonzero(proj)[0]
    bbox_min[axis] = np.min(idx_nonzero)
    bbox_max[axis] = np.max(idx_nonzero)

  return bbox_min, bbox_max


def _crop_to_bounding_box(mask, bbox_min, bbox_max):
  """Crops a 2D or 3D mask to the bounding box specified by `bbox_{min,max}`."""
  # we need to zeropad the cropped region with 1 voxel at the lower,
  # the right (and the back on 3D) sides. This is required to obtain the
  # "full" convolution result with the 2x2 (or 2x2x2 in 3D) kernel.
  # TODO: This is correct only if the object is interior to the
  # bounding box.
  cropmask = np.zeros((bbox_max - bbox_min) + 2, np.uint8)

  num_dims = len(mask.shape)
  # pyformat: disable
  if num_dims == 2:
    cropmask[0:-1, 0:-1] = mask[bbox_min[0]:bbox_max[0] + 1,
                                bbox_min[1]:bbox_max[1] + 1]
  elif num_dims == 3:
    cropmask[0:-1, 0:-1, 0:-1] = mask[bbox_min[0]:bbox_max[0] + 1,
                                      bbox_min[1]:bbox_max[1] + 1,
                                      bbox_min[2]:bbox_max[2] + 1]
  # pyformat: enable
  else:
    assert False

  return cropmask


def _sort_distances_surfels(distances, surfel_areas):
  """Sorts the two lists with respect to the tuple of (distance, surfel_area).

  Args:
    distances: The distances from A to B (e.g. `distances_gt_to_pred`).
    surfel_areas: The surfel areas for A (e.g. `surfel_areas_gt`).

  Returns:
    A tuple of the sorted (distances, surfel_areas).
  """
  sorted_surfels = np.array(sorted(zip(distances, surfel_areas)))
  return sorted_surfels[:, 0], sorted_surfels[:, 1]
""" sorted_surfels = np.array(sorted(zip(distances, surfel_areas))) return sorted_surfels[:, 0], sorted_surfels[:, 1] def compute_surface_distances(mask_gt, mask_pred, spacing_mm): """Computes closest distances from all surface points to the other surface. This function can be applied to 2D or 3D tensors. For 2D, both masks must be 2D and `spacing_mm` must be a 2-element list. For 3D, both masks must be 3D and `spacing_mm` must be a 3-element list. The description is done for the 2D case, and the formulation for the 3D case is present is parenthesis, introduced by "resp.". Finds all contour elements (resp surface elements "surfels" in 3D) in the ground truth mask `mask_gt` and the predicted mask `mask_pred`, computes their length in mm (resp. area in mm^2) and the distance to the closest point on the other contour (resp. surface). It returns two sorted lists of distances together with the corresponding contour lengths (resp. surfel areas). If one of the masks is empty, the corresponding lists are empty and all distances in the other list are `inf`. Args: mask_gt: 2-dim (resp. 3-dim) bool Numpy array. The ground truth mask. mask_pred: 2-dim (resp. 3-dim) bool Numpy array. The predicted mask. spacing_mm: 2-element (resp. 3-element) list-like structure. Voxel spacing in x0 anx x1 (resp. x0, x1 and x2) directions. Returns: A dict with: "distances_gt_to_pred": 1-dim numpy array of type float. The distances in mm from all ground truth surface elements to the predicted surface, sorted from smallest to largest. "distances_pred_to_gt": 1-dim numpy array of type float. The distances in mm from all predicted surface elements to the ground truth surface, sorted from smallest to largest. "surfel_areas_gt": 1-dim numpy array of type float. The length of the of the ground truth contours in mm (resp. the surface elements area in mm^2) in the same order as distances_gt_to_pred. "surfel_areas_pred": 1-dim numpy array of type float. The length of the of the predicted contours in mm (resp. the surface elements area in mm^2) in the same order as distances_gt_to_pred. Raises: ValueError: If the masks and the `spacing_mm` arguments are of incompatible shape or type. Or if the masks are not 2D or 3D. """ # The terms used in this function are for the 3D case. In particular, surface # in 2D stands for contours in 3D. The surface elements in 3D correspond to # the line elements in 2D. _assert_is_bool_numpy_array("mask_gt", mask_gt) _assert_is_bool_numpy_array("mask_pred", mask_pred) if not len(mask_gt.shape) == len(mask_pred.shape) == len(spacing_mm): raise ValueError("The arguments must be of compatible shape. 
  num_dims = len(spacing_mm)
  if num_dims == 2:
    _check_2d_numpy_array("mask_gt", mask_gt)
    _check_2d_numpy_array("mask_pred", mask_pred)

    # compute the area for all 16 possible surface elements
    # (given a 2x2 neighbourhood) according to the spacing_mm
    neighbour_code_to_surface_area = (
        lookup_tables.create_table_neighbour_code_to_contour_length(spacing_mm))
    kernel = lookup_tables.ENCODE_NEIGHBOURHOOD_2D_KERNEL
    full_true_neighbours = 0b1111
  elif num_dims == 3:
    _check_3d_numpy_array("mask_gt", mask_gt)
    _check_3d_numpy_array("mask_pred", mask_pred)

    # compute the area for all 256 possible surface elements
    # (given a 2x2x2 neighbourhood) according to the spacing_mm
    neighbour_code_to_surface_area = (
        lookup_tables.create_table_neighbour_code_to_surface_area(spacing_mm))
    kernel = lookup_tables.ENCODE_NEIGHBOURHOOD_3D_KERNEL
    full_true_neighbours = 0b11111111
  else:
    raise ValueError("Only 2D and 3D masks are supported, not "
                     "{}D.".format(num_dims))

  # compute the bounding box of the masks to trim the volume to the smallest
  # possible processing subvolume
  bbox_min, bbox_max = _compute_bounding_box(mask_gt | mask_pred)
  # Both the min/max bbox are None at the same time, so we only check one.
  if bbox_min is None:
    return {
        "distances_gt_to_pred": np.array([]),
        "distances_pred_to_gt": np.array([]),
        "surfel_areas_gt": np.array([]),
        "surfel_areas_pred": np.array([]),
    }

  # crop the processing subvolume.
  cropmask_gt = _crop_to_bounding_box(mask_gt, bbox_min, bbox_max)
  cropmask_pred = _crop_to_bounding_box(mask_pred, bbox_min, bbox_max)
  # compute the neighbour code (local binary pattern) for each voxel;
  # the resulting arrays are spatially shifted by minus half a voxel in each
  # axis, i.e. the points are located at the corners of the original voxels
  # (the ndimage.filters / ndimage.morphology submodules are deprecated; the
  # same functions are available directly on scipy.ndimage)
  neighbour_code_map_gt = ndimage.correlate(
      cropmask_gt.astype(np.uint8), kernel, mode="constant", cval=0)
  neighbour_code_map_pred = ndimage.correlate(
      cropmask_pred.astype(np.uint8), kernel, mode="constant", cval=0)

  # create masks with the surface voxels
  borders_gt = ((neighbour_code_map_gt != 0) &
                (neighbour_code_map_gt != full_true_neighbours))
  borders_pred = ((neighbour_code_map_pred != 0) &
                  (neighbour_code_map_pred != full_true_neighbours))

  # compute the distance transform (closest distance of each voxel to the
  # surface voxels)
  if borders_gt.any():
    distmap_gt = ndimage.distance_transform_edt(~borders_gt, sampling=spacing_mm)
  else:
    distmap_gt = np.inf * np.ones(borders_gt.shape)

  if borders_pred.any():
    distmap_pred = ndimage.distance_transform_edt(~borders_pred, sampling=spacing_mm)
  else:
    distmap_pred = np.inf * np.ones(borders_pred.shape)

  # compute the area of each surface element
  surface_area_map_gt = neighbour_code_to_surface_area[neighbour_code_map_gt]
  surface_area_map_pred = neighbour_code_to_surface_area[neighbour_code_map_pred]

  # create a list of all surface elements with distance and area
  distances_gt_to_pred = distmap_pred[borders_gt]
  distances_pred_to_gt = distmap_gt[borders_pred]
  surfel_areas_gt = surface_area_map_gt[borders_gt]
  surfel_areas_pred = surface_area_map_pred[borders_pred]

  # sort them by distance
  if distances_gt_to_pred.shape != (0,):
    distances_gt_to_pred, surfel_areas_gt = _sort_distances_surfels(
        distances_gt_to_pred, surfel_areas_gt)
  if distances_pred_to_gt.shape != (0,):
    distances_pred_to_gt, surfel_areas_pred = _sort_distances_surfels(
        distances_pred_to_gt, surfel_areas_pred)

  return {
      "distances_gt_to_pred": distances_gt_to_pred,
      "distances_pred_to_gt": distances_pred_to_gt,
      "surfel_areas_gt": surfel_areas_gt,
      "surfel_areas_pred": surfel_areas_pred,
  }


def compute_average_surface_distance(surface_distances):
  """Returns the average surface distance.

  Computes the average surface distances by correctly taking the area of each
  surface element into account. Call compute_surface_distances(...) before, to
  obtain the `surface_distances` dict.

  Args:
    surface_distances: dict with "distances_gt_to_pred", "distances_pred_to_gt"
      "surfel_areas_gt", "surfel_areas_pred" created by
      compute_surface_distances()

  Returns:
    A tuple with two float values:
      - the average distance (in mm) from the ground truth surface to the
        predicted surface
      - the average distance from the predicted surface to the ground truth
        surface.
  """
  distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
  distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
  surfel_areas_gt = surface_distances["surfel_areas_gt"]
  surfel_areas_pred = surface_distances["surfel_areas_pred"]
  average_distance_gt_to_pred = (
      np.sum(distances_gt_to_pred * surfel_areas_gt) / np.sum(surfel_areas_gt))
  average_distance_pred_to_gt = (
      np.sum(distances_pred_to_gt * surfel_areas_pred) /
      np.sum(surfel_areas_pred))
  return (average_distance_gt_to_pred, average_distance_pred_to_gt)
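# Editor's sketch (not part of the original library): exercising the two
# functions above on a pair of 3D boxes offset by one voxel, at 1 mm
# isotropic spacing.
def _demo_average_surface_distance():
  mask_gt = np.zeros((32, 32, 32), dtype=bool)
  mask_pred = np.zeros((32, 32, 32), dtype=bool)
  mask_gt[8:16, 8:16, 8:16] = True
  mask_pred[9:17, 8:16, 8:16] = True
  surface_distances = compute_surface_distances(mask_gt, mask_pred, spacing_mm=(1.0, 1.0, 1.0))
  # returns (average gt-to-pred distance, average pred-to-gt distance) in mm
  return compute_average_surface_distance(surface_distances)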
def compute_robust_hausdorff(surface_distances, percent):
  """Computes the robust Hausdorff distance.

  Computes the robust Hausdorff distance. "Robust", because it uses the
  `percent` percentile of the distances instead of the maximum distance. The
  percentage is computed by correctly taking the area of each surface element
  into account.

  Args:
    surface_distances: dict with "distances_gt_to_pred", "distances_pred_to_gt"
      "surfel_areas_gt", "surfel_areas_pred" created by
      compute_surface_distances()
    percent: a float value between 0 and 100.

  Returns:
    a float value. The robust Hausdorff distance in mm.
  """
  distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
  distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
  surfel_areas_gt = surface_distances["surfel_areas_gt"]
  surfel_areas_pred = surface_distances["surfel_areas_pred"]
  if len(distances_gt_to_pred) > 0:  # pylint: disable=g-explicit-length-test
    surfel_areas_cum_gt = np.cumsum(surfel_areas_gt) / np.sum(surfel_areas_gt)
    idx = np.searchsorted(surfel_areas_cum_gt, percent/100.0)
    perc_distance_gt_to_pred = distances_gt_to_pred[
        min(idx, len(distances_gt_to_pred)-1)]
  else:
    perc_distance_gt_to_pred = np.inf

  if len(distances_pred_to_gt) > 0:  # pylint: disable=g-explicit-length-test
    surfel_areas_cum_pred = (np.cumsum(surfel_areas_pred) /
                             np.sum(surfel_areas_pred))
    idx = np.searchsorted(surfel_areas_cum_pred, percent/100.0)
    perc_distance_pred_to_gt = distances_pred_to_gt[
        min(idx, len(distances_pred_to_gt)-1)]
  else:
    perc_distance_pred_to_gt = np.inf

  return max(perc_distance_gt_to_pred, perc_distance_pred_to_gt)


def compute_surface_overlap_at_tolerance(surface_distances, tolerance_mm):
  """Computes the overlap of the surfaces at a specified tolerance.

  Computes the overlap of the ground truth surface with the predicted surface
  and vice versa allowing a specified tolerance (maximum surface-to-surface
  distance that is regarded as overlapping). The overlapping fraction is
  computed by correctly taking the area of each surface element into account.

  Args:
    surface_distances: dict with "distances_gt_to_pred", "distances_pred_to_gt"
      "surfel_areas_gt", "surfel_areas_pred" created by
      compute_surface_distances()
    tolerance_mm: a float value. The tolerance in mm

  Returns:
    A tuple of two float values. The overlap fraction in [0.0, 1.0] of the
    ground truth surface with the predicted surface and vice versa.
  """
  distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
  distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
  surfel_areas_gt = surface_distances["surfel_areas_gt"]
  surfel_areas_pred = surface_distances["surfel_areas_pred"]
  rel_overlap_gt = (
      np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm]) /
      np.sum(surfel_areas_gt))
  rel_overlap_pred = (
      np.sum(surfel_areas_pred[distances_pred_to_gt <= tolerance_mm]) /
      np.sum(surfel_areas_pred))
  return (rel_overlap_gt, rel_overlap_pred)


def compute_surface_dice_at_tolerance(surface_distances, tolerance_mm):
  """Computes the _surface_ DICE coefficient at a specified tolerance.

  Computes the _surface_ DICE coefficient at a specified tolerance. Not to be
  confused with the standard _volumetric_ DICE coefficient. The surface DICE
  measures the overlap of two surfaces instead of two volumes. A surface
  element is counted as overlapping (or touching), when the closest distance
  to the other surface is less or equal to the specified tolerance. The DICE
  coefficient is in the range between 0.0 (no overlap) to 1.0 (perfect
  overlap).

  Args:
    surface_distances: dict with "distances_gt_to_pred", "distances_pred_to_gt"
      "surfel_areas_gt", "surfel_areas_pred" created by
      compute_surface_distances()
    tolerance_mm: a float value. The tolerance in mm

  Returns:
    A float value. The surface DICE coefficient in [0.0, 1.0].
""" distances_gt_to_pred = surface_distances["distances_gt_to_pred"] distances_pred_to_gt = surface_distances["distances_pred_to_gt"] surfel_areas_gt = surface_distances["surfel_areas_gt"] surfel_areas_pred = surface_distances["surfel_areas_pred"] overlap_gt = np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm]) overlap_pred = np.sum(surfel_areas_pred[distances_pred_to_gt <= tolerance_mm]) surface_dice = (overlap_gt + overlap_pred) / ( np.sum(surfel_areas_gt) + np.sum(surfel_areas_pred)) return surface_dice def compute_dice_coefficient(mask_gt, mask_pred): """Computes soerensen-dice coefficient. compute the soerensen-dice coefficient between the ground truth mask `mask_gt` and the predicted mask `mask_pred`. Args: mask_gt: 3-dim Numpy array of type bool. The ground truth mask. mask_pred: 3-dim Numpy array of type bool. The predicted mask. Returns: the dice coeffcient as float. If both masks are empty, the result is NaN. """ volume_sum = mask_gt.sum() + mask_pred.sum() if volume_sum == 0: return np.NaN volume_intersect = (mask_gt & mask_pred).sum() return 2*volume_intersect / volume_sum ================================================ FILE: utils/utils.py ================================================ import torch import numpy as np import torch.nn as nn import torch.nn.functional as F import SimpleITK as sitk import pdb def log_evaluation_result(writer, dice_list, ASD_list, HD_list, name, epoch): writer.add_scalar('Test_Dice/%s_AVG'%name, dice_list.mean(), epoch+1) for idx in range(3): writer.add_scalar('Test_Dice/%s_Dice%d'%(name, idx+1), dice_list[idx], epoch+1) writer.add_scalar('Test_ASD/%s_AVG'%name, ASD_list.mean(), epoch+1) for idx in range(3): writer.add_scalar('Test_ASD/%s_ASD%d'%(name, idx+1), ASD_list[idx], epoch+1) writer.add_scalar('Test_HD/%s_AVG'%name, HD_list.mean(), epoch+1) for idx in range(3): writer.add_scalar('Test_HD/%s_HD%d'%(name, idx+1), HD_list[idx], epoch+1) def multistep_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch, lr_decay_epoch, max_epoch, gamma=0.1): if epoch >= 0 and epoch <= warmup_epoch: lr = init_lr * 2.718 ** (10*(float(epoch) / float(warmup_epoch) - 1.)) if epoch == warmup_epoch: lr = init_lr for param_group in optimizer.param_groups: param_group['lr'] = lr return lr flag = False for i in range(len(lr_decay_epoch)): if epoch == lr_decay_epoch[i]: flag = True break if flag == True: lr = init_lr * gamma**(i+1) for param_group in optimizer.param_groups: param_group['lr'] = lr else: return optimizer.param_groups[0]['lr'] return lr def exp_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch, max_epoch): if epoch >= 0 and epoch <= warmup_epoch: lr = init_lr * 2.718 ** (10*(float(epoch) / float(warmup_epoch) - 1.)) if epoch == warmup_epoch: lr = init_lr for param_group in optimizer.param_groups: param_group['lr'] = lr return lr else: lr = init_lr * (1 - epoch / max_epoch)**0.9 for param_group in optimizer.param_groups: param_group['lr'] = lr return lr def cal_dice(pred, target, C): N = pred.shape[0] target_mask = target.data.new(N, C).fill_(0) target_mask.scatter_(1, target, 1.) pred_mask = pred.data.new(N, C).fill_(0) pred_mask.scatter_(1, pred, 1.) 
================================================
FILE: utils/utils.py
================================================
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import SimpleITK as sitk
import pdb


def log_evaluation_result(writer, dice_list, ASD_list, HD_list, name, epoch):
    writer.add_scalar('Test_Dice/%s_AVG'%name, dice_list.mean(), epoch+1)
    for idx in range(3):
        writer.add_scalar('Test_Dice/%s_Dice%d'%(name, idx+1), dice_list[idx], epoch+1)
    writer.add_scalar('Test_ASD/%s_AVG'%name, ASD_list.mean(), epoch+1)
    for idx in range(3):
        writer.add_scalar('Test_ASD/%s_ASD%d'%(name, idx+1), ASD_list[idx], epoch+1)
    writer.add_scalar('Test_HD/%s_AVG'%name, HD_list.mean(), epoch+1)
    for idx in range(3):
        writer.add_scalar('Test_HD/%s_HD%d'%(name, idx+1), HD_list[idx], epoch+1)


def multistep_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch, lr_decay_epoch, max_epoch, gamma=0.1):
    if epoch >= 0 and epoch <= warmup_epoch:
        # exponential warmup: lr ramps from roughly init_lr * e^-10 up to init_lr
        lr = init_lr * 2.718 ** (10*(float(epoch) / float(warmup_epoch) - 1.))
        if epoch == warmup_epoch:
            lr = init_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    flag = False
    for i in range(len(lr_decay_epoch)):
        if epoch == lr_decay_epoch[i]:
            flag = True
            break

    if flag:
        lr = init_lr * gamma**(i+1)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    else:
        return optimizer.param_groups[0]['lr']

    return lr


def exp_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch, max_epoch):
    if epoch >= 0 and epoch <= warmup_epoch:
        # exponential warmup, then power-0.9 polynomial decay below
        lr = init_lr * 2.718 ** (10*(float(epoch) / float(warmup_epoch) - 1.))
        if epoch == warmup_epoch:
            lr = init_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return lr
    else:
        lr = init_lr * (1 - epoch / max_epoch)**0.9
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return lr


def cal_dice(pred, target, C):
    # pred/target: (N, 1) integer class indices, scattered into one-hot masks
    N = pred.shape[0]
    target_mask = target.data.new(N, C).fill_(0)
    target_mask.scatter_(1, target, 1.)

    pred_mask = pred.data.new(N, C).fill_(0)
    pred_mask.scatter_(1, pred, 1.)

    intersection = pred_mask * target_mask
    summ = pred_mask + target_mask

    intersection = intersection.sum(0).type(torch.float32)
    summ = summ.sum(0).type(torch.float32)

    # small epsilon to avoid division by zero (the original built this with
    # torch.rand(C).fill_(1e-7).cuda(); torch.full is equivalent and device-safe)
    summ += torch.full((C,), 1e-7, dtype=torch.float32, device=summ.device)
    dice = 2 * intersection / summ

    return dice, intersection, summ


def cal_asd(itkPred, itkGT):
    reference_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(itkGT, squaredDistance=False))
    reference_surface = sitk.LabelContour(itkGT)

    statistics_image_filter = sitk.StatisticsImageFilter()
    statistics_image_filter.Execute(reference_surface)
    num_reference_surface_pixels = int(statistics_image_filter.GetSum())

    segmented_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(itkPred, squaredDistance=False))
    segmented_surface = sitk.LabelContour(itkPred)

    seg2ref_distance_map = reference_distance_map * sitk.Cast(segmented_surface, sitk.sitkFloat32)
    ref2seg_distance_map = segmented_distance_map * sitk.Cast(reference_surface, sitk.sitkFloat32)

    statistics_image_filter.Execute(segmented_surface)
    num_segmented_surface_pixels = int(statistics_image_filter.GetSum())

    seg2ref_distance_map_arr = sitk.GetArrayViewFromImage(seg2ref_distance_map)
    seg2ref_distances = list(seg2ref_distance_map_arr[seg2ref_distance_map_arr != 0])
    # pad with zeros for surface pixels whose distance to the other surface is exactly zero
    seg2ref_distances = seg2ref_distances + list(np.zeros(num_segmented_surface_pixels - len(seg2ref_distances)))

    ref2seg_distance_map_arr = sitk.GetArrayViewFromImage(ref2seg_distance_map)
    ref2seg_distances = list(ref2seg_distance_map_arr[ref2seg_distance_map_arr != 0])
    ref2seg_distances = ref2seg_distances + list(np.zeros(num_reference_surface_pixels - len(ref2seg_distances)))

    all_surface_distances = seg2ref_distances + ref2seg_distances
    ASD = np.mean(all_surface_distances)

    return ASD
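Finally, a sketch of cal_dice on dummy data (editor's illustration; with the device-safe epsilon above it also runs on CPU). The function scatters (N, 1) integer index tensors into one-hot masks; class 0 is background, which is why validation() in train_deep.py keeps only dice[1:]:

```
import torch
from utils.utils import cal_dice

pred = torch.tensor([[0], [1], [1], [2], [3]])
true = torch.tensor([[0], [1], [2], [2], [3]])
dice, intersection, summ = cal_dice(pred, true, C=4)
print(dice)  # per-class Dice over the flattened pixels, roughly [1.0, 0.67, 0.67, 1.0]
```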