Repository: yhygao/UTNet
Branch: main
Commit: 6bdcc97e2364
Files: 27
Total size: 190.5 KB
Directory structure:
gitextract_vjhzdth0/
├── .gitignore
├── LICENSE
├── README.md
├── data_preprocess.py
├── dataset/
│   ├── Testing_A.csv
│   ├── Testing_B.csv
│   ├── Testing_C.csv
│   ├── Testing_D.csv
│   ├── Training_A.csv
│   ├── Training_B.csv
│   ├── Training_C.csv
│   ├── Training_D.csv
│   └── info.csv
├── dataset_domain.py
├── losses.py
├── model/
│   ├── __init__.py
│   ├── conv_trans_utils.py
│   ├── resnet_utnet.py
│   ├── swin_unet.py
│   ├── transunet.py
│   ├── unet_utils.py
│   └── utnet.py
├── train_deep.py
└── utils/
    ├── __init__.py
    ├── lookup_tables.py
    ├── metrics.py
    └── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__/
*.pyc
*_bkp.py
checkpoint/
log/
model/backup/
initmodel/
*.swp
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2021 yhygao
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# UTNet (Accepted at MICCAI 2021)
Official implementation of [UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation](https://arxiv.org/abs/2107.00781)
## Update
* Our new work, Hermes, has been released on arXiv: [Training Like a Medical Resident: Universal Medical Image Segmentation via Context Prior Learning](https://arxiv.org/pdf/2306.02416.pdf). Inspired by the training of medical residents, we explore universal medical image segmentation, whose goal is to learn from diverse medical imaging sources covering a range of clinical targets, body regions, and image modalities. Following this paradigm, we propose Hermes, a context prior learning approach that addresses the challenges related to the heterogeneity on data, modality, and annotations in the proposed universal paradigm. Code will be released at https://github.com/yhygao/universal-medical-image-segmentation.
* Our new paper, the improved version of UTNet, UTNetV2, has been released on arXiv: [A Multi-scale Transformer for Medical Image Segmentation: Architectures, Model Efficiency, and Benchmarks](https://arxiv.org/abs/2203.00131). UTNetV2 has an improved architecture for both 2D and 3D settings. We also provide a more general framework for comparing CNNs and Transformers, with more dataset support and more SOTA models; see our new repo: https://github.com/yhygao/CBIM-Medical-Image-Segmentation
* Data preprocessing code has been uploaded.
## Introduction
The Transformer architecture has been successful in a number of natural language processing tasks. However, its applications to medical vision remain largely unexplored. In this study, we present UTNet, a simple yet powerful hybrid Transformer architecture that integrates self-attention into a convolutional neural network to enhance medical image segmentation. UTNet applies self-attention modules in both the encoder and decoder to capture long-range dependencies at different scales with minimal overhead. To this end, we propose an efficient self-attention mechanism, along with relative position encoding, that reduces the complexity of the self-attention operation significantly, from O(n²) to approximately O(n). A new self-attention decoder is also proposed to recover fine-grained details from the skip connections of the encoder. Our approach addresses the dilemma that Transformers require huge amounts of data to learn vision inductive biases: our hybrid layer design allows the Transformer to be initialized as a convolutional network, without any need for pre-training. We have evaluated UTNet on a multi-label, multi-vendor cardiac magnetic resonance imaging cohort. UTNet demonstrates superior segmentation performance and robustness against state-of-the-art approaches, holding the promise to generalize well to other medical image segmentation tasks.
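
The efficient attention described above can be illustrated with a minimal NumPy sketch. This is a simplified, hypothetical version (the actual PyTorch implementation lives in `model/conv_trans_utils.py`, and uses learned projections rather than the plain average pooling assumed here): keys and values are downsampled to a fixed `reduce_size`, so the attention matrix has shape n × reduce_size² instead of n × n.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def efficient_attention(q, k, v, reduce_size=8):
    """Single-head efficient attention sketch.

    q: (n, d) queries at full resolution, n = H * W
    k, v: (H, W, d) key/value maps, average-pooled down to
    (reduce_size**2, d) before attention, so the attention
    matrix is (n, reduce_size**2) instead of (n, n).
    """
    H, W, d = k.shape
    hs, ws = H // reduce_size, W // reduce_size
    # Block-average-pool keys/values to reduce_size x reduce_size
    k_small = k[:hs*reduce_size, :ws*reduce_size].reshape(reduce_size, hs, reduce_size, ws, d).mean(axis=(1, 3))
    v_small = v[:hs*reduce_size, :ws*reduce_size].reshape(reduce_size, hs, reduce_size, ws, d).mean(axis=(1, 3))
    k_small = k_small.reshape(-1, d)  # (reduce_size**2, d)
    v_small = v_small.reshape(-1, d)

    attn = softmax((q @ k_small.T) / np.sqrt(d))  # (n, reduce_size**2)
    return attn @ v_small                         # (n, d)

# Demo: a 32x32 feature map with 16 channels
rng = np.random.default_rng(0)
q = rng.standard_normal((32 * 32, 16))
k = rng.standard_normal((32, 32, 16))
v = rng.standard_normal((32, 32, 16))
out = efficient_attention(q, k, v, reduce_size=8)
print(out.shape)  # (1024, 16)
```

The key point is that the attention matrix scales linearly with the number of query positions, since the key/value dimension is a fixed constant (reduce_size²) regardless of input resolution.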


## Supported models
* UTNet
* TransUNet
* ResNet50-UTNet
* ResNet50-UNet
* SwinUNet
* To be continued ...
## Getting Started
Currently, we only support the [M&Ms dataset](https://www.ub.edu/mnms/).
### Prerequisites
```
Python >= 3.6
pytorch = 1.8.1
SimpleITK = 2.0.2
numpy = 1.19.5
einops = 0.3.2
```
### Preprocess
Resample all data to a spacing of 1.2x1.2 mm in the x-y plane. We don't change the z-axis spacing, as UTNet is a 2D network. Then put all data into 'dataset/'.
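
As a rough illustration of the resampling step, the new in-plane size is chosen so that the physical extent (voxel count times spacing) is preserved. The sketch below mirrors the size computation in `data_preprocess.py`; `resampled_size` is a hypothetical helper, not part of the repo.

```python
def resampled_size(size, spacing, target_spacing):
    """New voxel counts that keep the physical extent (size * spacing) fixed."""
    return tuple(int(round(s * sp / tsp))
                 for s, sp, tsp in zip(size, spacing, target_spacing))

# e.g. a 256x256 slice at 1.45x1.45 mm, resampled to 1.2x1.2 mm
print(resampled_size((256, 256), (1.45, 1.45), (1.2, 1.2)))  # (309, 309)
```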
### Training
The M&Ms dataset provides data from 4 vendors, where vendors A and B are provided for training while all of A, B, C, and D are used for testing. The '--domain' flag controls which vendors are used for training: '--domain A' uses vendor A only, '--domain B' uses vendor B only, and '--domain AB' uses both vendors A and B. For testing, all 4 vendors are used.
#### UTNet
For default UTNet setting, training with:
```
python train_deep.py -m UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 1234 --num_blocks 1,1,1,1 --domain AB --gpu 0 --aux_loss
```
Or you can use '-m UTNet_encoder' to apply transformer blocks in the encoder only. This setting is more stable than the default setting in some cases.
To optimize UTNet for your own task, there are several hyperparameters to tune:
* '--block_list': controls the resolutions at which transformer blocks are applied. Each digit is a number of downsamplings, e.g. '34' applies transformer blocks to features after 3 and 4 downsamplings. Applying transformer blocks to higher-resolution feature maps introduces much more computation.
* '--num_blocks': the number of transformer blocks applied at each level, e.g. block_list='34' with num_blocks='2,4' applies 2 transformer blocks at the 3-downsampling level and 4 blocks at the 4-downsampling level.
* '--reduce_size': the size to which keys and values are downsampled in the efficient attention. In our experiments, reduce_size 8 and 16 don't make much difference, but 16 introduces more computation, so we chose 8 as the default. 16 might give better performance in other applications.
* '--aux_loss': applies deep supervision during training. It introduces some computational overhead but gives slightly better performance.
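
As a rough sketch of what deep supervision does, auxiliary predictions from coarser decoder levels each contribute a loss term, and the terms are combined with decaying weights. The weighting scheme below is illustrative only; `deep_supervision_loss` is a hypothetical helper, and the actual loss combination used with '--aux_loss' is defined in `train_deep.py`.

```python
def deep_supervision_loss(losses, decay=0.5):
    """Normalized weighted sum of per-scale losses.

    losses[0] is the full-resolution output's loss; coarser
    auxiliary outputs (losses[1], losses[2], ...) get
    geometrically smaller weights.
    """
    weights = [decay ** i for i in range(len(losses))]
    total = sum(weights)
    return sum(w * l for w, l in zip(weights, losses)) / total

# Three decoder scales, finest first
print(deep_supervision_loss([0.4, 0.6, 0.8]))  # ≈ 0.514
```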
Here are some recommended parameter settings:
```
--block_list 1234 --num_blocks 1,1,1,1
```
Our default and most efficient setting. Suitable for tasks with limited training data where most errors occur at the boundary of the ROI, where high-resolution information is important.
```
--block_list 1234 --num_blocks 1,1,4,8
```
Similar to the previous one. The model capacity is larger as more transformer blocks are included, but it needs a larger dataset for training.
```
--block_list 234 --num_blocks 2,4,8
```
Suitable for tasks with complex contexts where errors occur inside the ROI. More transformer blocks can help learn higher-level relationships.
Feel free to try other combinations of hyperparameters, like base_chan, reduce_size, and the number of blocks at each level, to trade off capacity against efficiency for your own tasks and datasets.
#### TransUNet
We borrow code from the original TransUNet repo and fit it into our training framework. If you want to use pre-trained weights, please download them from the [original repo](https://github.com/Beckschen/TransUNet). The configuration is not parsed from the command line, so if you want to change the configuration of TransUNet, you need to change it inside train_deep.py.
```
python train_deep.py -m TransUNet -u EXP_NAME --data_path YOUR_OWN_PATH --gpu 0
```
#### ResNet50-UTNet
For a fair comparison with TransUNet, we implement the efficient attention proposed in UTNet on a ResNet50 backbone, which basically appends transformer blocks to the specified levels after the ResNet blocks. ResNet50-UTNet performs slightly better than the default UTNet on the M&Ms dataset.
```
python train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 123 --num_blocks 1,1,1 --gpu 0
```
Similar to UTNet, this is the most efficient setting, suitable for tasks with limited training data.
```
--block_list 23 --num_blocks 2,4
```
Suitable for tasks with complex contexts where errors occur inside the ROI. More transformer blocks can help learn higher-level relationships.
#### ResNet50-UNet
If you don't use transformer blocks in ResNet50-UTNet, it is effectively ResNet50-UNet. You can use this as a baseline to measure the performance improvement from the Transformer, for a fair comparison with TransUNet and our UTNet.
```
python train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --block_list '' --gpu 0
```
#### SwinUNet
Download the pre-trained model from the [original repo](https://github.com/HuCaoFighting/Swin-Unet/tree/4375a8d6fa7d9c38184c5d3194db990a00a3e912).
Since the Swin Transformer's input size is tied to its window size and is hard to change after pre-training, we adapt our input size to 224. Without pre-training, SwinUNet's performance is very low.
```
python train_deep.py -m SwinUNet -u EXP_NAME --data_path YOUR_OWN_PATH --crop_size 224
```
## Citation
If you find this repo helpful, please kindly cite our papers. Thanks!
```
@inproceedings{gao2021utnet,
  title={UTNet: a hybrid transformer architecture for medical image segmentation},
  author={Gao, Yunhe and Zhou, Mu and Metaxas, Dimitris N},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={61--71},
  year={2021},
  organization={Springer}
}

@misc{gao2022datascalable,
  title={A Data-scalable Transformer for Medical Image Segmentation: Architecture, Model Efficiency, and Benchmark},
  author={Yunhe Gao and Mu Zhou and Di Liu and Zhennan Yan and Shaoting Zhang and Dimitris N. Metaxas},
  year={2022},
  eprint={2203.00131},
  archivePrefix={arXiv},
  primaryClass={eess.IV}
}
```
================================================
FILE: data_preprocess.py
================================================
import os

import numpy as np
import SimpleITK as sitk


def ResampleXYZAxis(imImage, space=(1., 1., 1.), interp=sitk.sitkLinear):
    identity1 = sitk.Transform(3, sitk.sitkIdentity)
    sp1 = imImage.GetSpacing()
    sz1 = imImage.GetSize()

    # The new size preserves the physical extent: size * spacing stays constant
    sz2 = (int(round(sz1[0]*sp1[0]/space[0])), int(round(sz1[1]*sp1[1]/space[1])), int(round(sz1[2]*sp1[2]/space[2])))

    imRefImage = sitk.Image(sz2, imImage.GetPixelIDValue())
    imRefImage.SetSpacing(space)
    imRefImage.SetOrigin(imImage.GetOrigin())
    imRefImage.SetDirection(imImage.GetDirection())
    imOutImage = sitk.Resample(imImage, imRefImage, identity1, interp)

    return imOutImage


def ResampleFullImageToRef(imImage, imRef, interp=sitk.sitkNearestNeighbor):
    identity1 = sitk.Transform(3, sitk.sitkIdentity)

    imRefImage = sitk.Image(imRef.GetSize(), imImage.GetPixelIDValue())
    imRefImage.SetSpacing(imRef.GetSpacing())
    imRefImage.SetOrigin(imRef.GetOrigin())
    imRefImage.SetDirection(imRef.GetDirection())

    imOutImage = sitk.Resample(imImage, imRefImage, identity1, interp)

    return imOutImage


def ResampleCMRImage(imImage, imLabel, save_path, patient_name, name, target_space=(1., 1.)):
    assert imImage.GetSpacing() == imLabel.GetSpacing()
    assert imImage.GetSize() == imLabel.GetSize()

    spacing = imImage.GetSpacing()
    origin = imImage.GetOrigin()

    npimg = sitk.GetArrayFromImage(imImage)
    nplab = sitk.GetArrayFromImage(imLabel)
    t, z, y, x = npimg.shape  # cine CMR: time, slice, height, width

    if not os.path.exists('%s/%s'%(save_path, patient_name)):
        os.makedirs('%s/%s'%(save_path, patient_name))

    # Only keep time frames that have annotations (typically ED and ES)
    flag = 0
    for i in range(t):
        tmp_img = npimg[i]
        tmp_lab = nplab[i]
        if tmp_lab.max() == 0:
            continue
        print(i)
        flag += 1

        tmp_itkimg = sitk.GetImageFromArray(tmp_img)
        tmp_itkimg.SetSpacing(spacing[0:3])
        tmp_itkimg.SetOrigin(origin[0:3])

        tmp_itklab = sitk.GetImageFromArray(tmp_lab)
        tmp_itklab.SetSpacing(spacing[0:3])
        tmp_itklab.SetOrigin(origin[0:3])

        # Resample in-plane only; keep the original z spacing
        re_img = ResampleXYZAxis(tmp_itkimg, space=(target_space[0], target_space[1], spacing[2]))
        re_lab = ResampleFullImageToRef(tmp_itklab, re_img)

        sitk.WriteImage(re_img, '%s/%s/%s_%d.nii.gz'%(save_path, patient_name, name, flag))
        sitk.WriteImage(re_lab, '%s/%s/%s_gt_%d.nii.gz'%(save_path, patient_name, name, flag))

    return flag


if __name__ == '__main__':
    src_path = 'OpenDataset/Training/Labeled'
    # Resolve the target path to an absolute path before chdir-ing below,
    # so that the relative 'dataset/Training' still points to the right place
    tgt_path = os.path.abspath('dataset/Training')

    os.chdir(src_path)
    for name in os.listdir('.'):
        os.chdir(name)
        for i in os.listdir('.'):
            if 'gt' in i:
                tmp = i.split('_')
                img_name = tmp[0] + '_' + tmp[1]
                patient_name = tmp[0]

                img = sitk.ReadImage('%s.nii.gz'%img_name)
                lab = sitk.ReadImage('%s_gt.nii.gz'%img_name)
                flag = ResampleCMRImage(img, lab, tgt_path, patient_name, img_name, (1.2, 1.2))
                print(name, 'done', flag)
        os.chdir('..')
================================================
FILE: dataset/Testing_A.csv
================================================
name
A2L1N6
B0H7V0
B3F0V9
B5F8L9
B9H8N8
C4E9I1
C6E0F9
C6U4W8
D1H6U2
D9F5P1
E0J7L9
F0K4T6
G7S6V0
G9N5V9
H1N8S6
I2J6Z6
I5L3S2
J4J8Q3
K6N4N7
P3P9S5
================================================
FILE: dataset/Testing_B.csv
================================================
name
A2H5K9
A4A8V9
A5C2D2
A5D0G0
A6J0Y2
A7E4J0
B1G9J3
B3S2Z4
B4E1K1
B4S1Y2
B5L5Y4
B5T6V0
B7F5P0
C7L8Z8
C8I7P7
D1R0Y5
D3Q0W9
D6E9U8
D8O0W2
D9I8O7
E1L8Y4
E3L8U8
E4I9O7
E4O8P3
E5S7W7
E6J4N8
G1K1V3
G3M5S4
G7N8R7
G7Q2W0
H7K5U5
H8K2K7
I0I2J8
I6T4W8
I7W4Y8
J9L4S2
K5K6N1
K7L2Y6
K9N0W0
L7Y7Z2
L8N7Z0
M4T4V6
M6V2Y0
N9P5Z0
O4T6Y7
O9V8W5
P3R6Y5
P8W4Z0
Q3Q6R8
Y6Y9Z2
================================================
FILE: dataset/Testing_C.csv
================================================
name
A3H5R1
A5H1Q2
A8C5E9
A9F3T5
A9L7Y7
B0L3Y2
B2L0L2
B8P5Q9
C0N8P4
C7M6W0
C8J7L5
C8O0P2
D1H2O9
D2U0V0
D3K5Q2
D5G3W8
D6N7Q8
D7M8P9
D7T3V8
E3F2U7
E3F5U2
E6M6P2
E7L0N6
E9V9Z2
F0I6U8
F1K2S9
F5I1Z8
G8R0Z9
H2M9S1
H3R6S9
H7L8R8
H7P5Z4
I4R8V6
I6P4R0
I8Z0Z6
J6K4V3
K3R0Y7
K7O3Q0
L2V5Z0
L8N7P0
M6M9N1
O4O6U5
P3T5U1
Q1Q3T1
Q5V8W3
R1R6Y8
R2R7Z5
R6V5W3
R8V0Y4
V4W8Z5
================================================
FILE: dataset/Testing_D.csv
================================================
name
A1K2P5
A3P9V7
A4B9O6
A4K8R4
A4R4T0
A5P5W0
A5Q1W8
A6A8H0
A6B7Y4
A7F4G2
B3E2W8
B6I0T4
B9G4U2
C0L7V1
C5L0R0
D1L6T4
D1S5T8
E0J2Z9
E1L7M3
E4H7L4
E5J6L2
E6H0V9
F6J9L9
F9M1R2
G1J5K3
G4I7V2
G8K0M3
G8L0Z0
H6P7T1
I4L4V7
I8N8Y1
J6M5O2
K3P3Y6
K5L4S1
K5M7V5
K7N0R7
L5U7Y4
L6T2T5
L8M2U8
M2P5T8
N2O7U5
N7P3T8
N7W6Z8
N9Q4T8
O5U2U7
O7Q7U3
P8V0Y7
Q4W5Z8
R3V5W7
T2Z1Z9
================================================
FILE: dataset/Training_A.csv
================================================
name
A0S9V9
A1D9Z7
A1E9Q1
A2C0I1
A2N8V0
A3H1O5
A4B5U4
A4J4S4
A4U9V5
A6B5G9
A6D5F9
A7D9L8
A7M7P8
A7O4T6
B0I2Z0
B2C2Z7
B2D9M2
B2G5R2
B3O1S0
B3P3R1
B4O3V3
B8H5H6
B8J7R4
B9E0Q1
C0K1P0
C1K8P5
C2J0K3
C5M4S2
C6J5P1
C8P3S7
D0R0R9
D1J5P6
D1M1S6
D3D4Y5
D3F3O5
D3F9H9
D3O9U9
D4M3Q2
D4N6W6
D8E4F4
D9L1Z3
E0M3U7
E0O0S0
E9H1U4
E9V4Z8
F2H5S1
F3G5K5
G4S9U3
G5P4U3
G7I5V7
H1I3W0
H5N0P0
I7T3U1
J1T9Y1
J6K6P5
J6P5T8
J8R5W2
K4T7Y0
K5L2U3
K5P0Y1
L1Q9V8
L4Q2U3
M1R4S1
M2P1R1
N5S7Y1
N8N9U0
O0S9V7
P0S5Y0
P5R1Y4
Q0U0V5
Q3R9W7
R4Y1Z9
S1S3Z7
T2T9Z9
T9U9W2
================================================
FILE: dataset/Training_B.csv
================================================
name
A1D0Q7
A1O8Z3
A3B7E5
A5E0T8
A6M1Q7
A7G0P5
A8C9U8
A8E1F4
A9C5P4
A9E3G9
A9J5Q7
A9J8W7
B0N3W8
B2D9O2
B2F4K5
B3D0N1
B6D0U7
B9O1Q0
C0S7W0
C1G5Q0
C2L5P7
C2M6P8
C3I2K3
C4R8T7
C4S8W9
D0H9I4
D1L4Q9
D6H6O2
E3T0Z2
E4M2Q7
E4W8Z7
E5E6O8
E5F5V7
E9H2K7
E9L1W5
F0J2R8
F1F3I6
F4K3S1
F5I9Q2
F8N2S1
G0H4J3
G0I6P3
G1N6S7
G2J1M5
G2M7W4
G2O2S6
G4L8Z7
G8N2U5
G9L0O9
H0K3Q4
H1J5W8
H1M5Y6
H1W2Y1
H3U1Y1
H4I2T8
H6I0I6
H7I4J3
H7N4V9
I0J5U3
I2K2Y8
I6N3P3
J4J9W6
J9L6N9
K2S1U6
L1Q1Z5
L5Q6T7
M0P8U8
M4P7Q6
N1P8Q9
N7V9W9
O3R8Y5
P6U0Y0
P9S7W2
Q7V1Y5
W5Z4Z8
================================================
FILE: dataset/Training_C.csv
================================================
name
================================================
FILE: dataset/Training_D.csv
================================================
name
================================================
FILE: dataset/info.csv
================================================
External code,VendorName,Vendor,Centre,ED,ES
A0S9V9,Siemens,A,1,0,9
A1D0Q7,Philips,B,2,0,9
A1D9Z7,Siemens,A,1,22,11
A1E9Q1,Siemens,A,1,0,9
A1K2P5,Canon,D,5,33,11
A1O8Z3,Philips,B,3,23,10
A2C0I1,Siemens,A,1,0,7
A2E3W4,GE,C,4,0,0
A2H5K9,Philips,B,2,29,8
A2L1N6,Siemens,A,1,0,12
A2N8V0,Siemens,A,1,0,9
A3B7E5,Philips,B,2,29,12
A3H1O5,Siemens,A,1,0,12
A3H5R1,GE,C,4,24,6
A3P9V7,Canon,D,5,27,13
A4A8V9,Philips,B,3,0,10
A4B5U4,Siemens,A,1,0,10
A4B9O6,Canon,D,5,0,11
A4J4S4,Siemens,A,1,0,7
A4K8R4,Canon,D,5,24,11
A4R4T0,Canon,D,5,21,8
A4U9V5,Siemens,A,1,0,8
A5C2D2,Philips,B,3,24,9
A5D0G0,Philips,B,2,0,10
A5E0T8,Philips,B,3,24,8
A5H1Q2,GE,C,4,0,9
A5P5W0,Canon,D,5,33,10
A5Q1W8,Canon,D,5,27,10
A6A8H0,Canon,D,5,27,8
A6B5G9,Siemens,A,1,0,11
A6B7Y4,Canon,D,5,1,10
A6D5F9,Siemens,A,1,0,11
A6J0Y2,Philips,B,2,27,13
A6M1Q7,Philips,B,2,29,11
A7D9L8,Siemens,A,1,0,11
A7E4J0,Philips,B,3,1,12
A7F4G2,Canon,D,5,25,10
A7G0P5,Philips,B,2,28,9
A7M7P8,Siemens,A,1,0,9
A7O4T6,Siemens,A,1,0,10
A8C5E9,GE,C,4,24,9
A8C9H8,GE,C,4,0,0
A8C9U8,Philips,B,2,28,9
A8E1F4,Philips,B,3,24,9
A8I1U6,GE,C,4,0,0
A9C5P4,Philips,B,2,29,8
A9E3G9,Philips,B,3,23,8
A9F3T5,GE,C,4,24,8
A9J5Q7,Philips,B,3,24,7
A9J8W7,Philips,B,2,29,10
A9L7Y7,GE,C,4,24,5
B0H7V0,Siemens,A,1,0,9
B0I2Z0,Siemens,A,1,0,8
B0L3Y2,GE,C,4,24,10
B0N3W8,Philips,B,2,28,15
B1G9J3,Philips,B,2,29,10
B1K7U1,GE,C,4,0,0
B2C2Z7,Siemens,A,1,0,8
B2D9M2,Siemens,A,1,0,8
B2D9O2,Philips,B,2,29,13
B2F4K5,Philips,B,2,29,10
B2G5R2,Siemens,A,1,0,7
B2L0L2,GE,C,4,24,8
B3D0N1,Philips,B,3,24,8
B3E2W8,Canon,D,5,23,9
B3F0V9,Siemens,A,1,0,12
B3O1S0,Siemens,A,1,0,8
B3P3R1,Siemens,A,1,0,11
B3S2Z4,Philips,B,2,28,11
B4E1K1,Philips,B,3,0,9
B4I8Z7,GE,C,4,0,0
B4O3V3,Siemens,A,1,0,6
B4S1Y2,Philips,B,2,29,10
B5F8L9,Siemens,A,1,0,8
B5L5Y4,Philips,B,3,24,10
B5T6V0,Philips,B,3,0,10
B6D0U7,Philips,B,2,29,9
B6I0T4,Canon,D,5,33,10
B7F5P0,Philips,B,3,29,12
B8H5H6,Siemens,A,1,24,8
B8J7R4,Siemens,A,1,0,12
B8P5Q9,GE,C,4,23,9
B9E0Q1,Siemens,A,1,0,11
B9G4U2,Canon,D,5,28,8
B9H8N8,Siemens,A,1,0,10
B9O1Q0,Philips,B,2,29,11
C0K1P0,Siemens,A,1,0,9
C0L7V1,Canon,D,5,33,14
C0N8P4,GE,C,4,24,9
C0S7W0,Philips,B,3,24,7
C0W2Y5,GE,C,4,0,0
C1G5Q0,Philips,B,3,24,8
C1K8P5,Siemens,A,1,0,8
C2J0K3,Siemens,A,1,0,9
C2L5P7,Philips,B,2,28,12
C2M6P8,Philips,B,3,24,8
C3I2K3,Philips,B,2,29,11
C4E9I1,Siemens,A,1,0,8
C4K8M1,GE,C,4,0,0
C4R8T7,Philips,B,2,28,8
C4S8W9,Philips,B,3,24,8
C5L0R0,Canon,D,5,29,12
C5M4S2,Siemens,A,1,0,9
C5Q2Y5,GE,C,4,0,0
C6E0F9,Siemens,A,1,0,8
C6J5P1,Siemens,A,1,0,10
C6U4W8,Siemens,A,1,0,9
C7L8Z8,Philips,B,2,29,9
C7M6W0,GE,C,4,24,8
C8I7P7,Philips,B,2,29,11
C8J7L5,GE,C,4,24,7
C8O0P2,GE,C,4,24,8
C8P3S7,Siemens,A,1,0,8
C8V5W8,GE,C,4,0,0
D0E2W0,GE,C,4,0,0
D0H9I4,Philips,B,2,0,10
D0R0R9,Siemens,A,1,0,11
D1H2O9,GE,C,4,24,8
D1H6U2,Siemens,A,1,0,8
D1J5P6,Siemens,A,1,0,13
D1L4Q9,Philips,B,2,29,10
D1L6T4,Canon,D,5,27,10
D1M1S6,Siemens,A,1,0,10
D1R0Y5,Philips,B,3,24,9
D1S5T8,Canon,D,5,23,14
D2U0V0,GE,C,4,24,10
D3D4Y5,Siemens,A,1,0,9
D3F3O5,Siemens,A,1,24,10
D3F9H9,Siemens,A,1,24,8
D3K5Q2,GE,C,4,11,0
D3O9U9,Siemens,A,1,0,10
D3Q0W9,Philips,B,3,24,10
D4M3Q2,Siemens,A,1,0,9
D4N6W6,Siemens,A,1,0,10
D5G3W8,GE,C,4,23,8
D6E9U8,Philips,B,2,0,12
D6H6O2,Philips,B,2,29,9
D6N7Q8,GE,C,4,24,9
D7M8P9,GE,C,4,24,8
D7T3V8,GE,C,4,24,7
D8E4F4,Siemens,A,1,0,10
D8O0W2,Philips,B,3,24,8
D9F5P1,Siemens,A,1,0,10
D9I8O7,Philips,B,3,29,11
D9L1Z3,Siemens,A,1,0,12
E0J2Z9,Canon,D,5,25,11
E0J7L9,Siemens,A,1,0,9
E0M3U7,Siemens,A,1,0,9
E0O0S0,Siemens,A,1,0,11
E1L7M3,Canon,D,5,1,12
E1L8Y4,Philips,B,3,24,10
E3F2U7,GE,C,4,0,7
E3F5U2,GE,C,4,24,8
E3I4V1,GE,C,4,0,0
E3L8U8,Philips,B,3,29,9
E3T0Z2,Philips,B,2,29,10
E4H7L4,Canon,D,5,28,10
E4I9O7,Philips,B,3,29,10
E4M2Q7,Philips,B,3,0,8
E4O8P3,Philips,B,2,29,12
E4W8Z7,Philips,B,2,29,10
E5E6O8,Philips,B,2,29,10
E5F5V7,Philips,B,3,0,9
E5J6L2,Canon,D,5,29,11
E5S7W7,Philips,B,2,28,12
E6H0V9,Canon,D,5,33,12
E6J4N8,Philips,B,2,29,11
E6M6P2,GE,C,4,24,7
E7L0N6,GE,C,4,22,8
E9H1U4,Siemens,A,1,0,10
E9H2K7,Philips,B,2,29,11
E9L1W5,Philips,B,3,24,9
E9L4N2,GE,C,4,0,0
E9V4Z8,Siemens,A,1,0,8
E9V9Z2,GE,C,4,28,10
F0I6U8,GE,C,4,10,23
F0J2R8,Philips,B,2,29,9
F0K4T6,Siemens,A,1,0,8
F1F3I6,Philips,B,2,28,8
F1K2S9,GE,C,4,23,7
F2H5S1,Siemens,A,1,24,10
F3G5K5,Siemens,A,1,24,12
F3L0M1,GE,C,4,0,0
F4K3S1,Philips,B,3,24,8
F5I1Z8,GE,C,4,23,8
F5I9Q2,Philips,B,2,29,11
F6J9L9,Canon,D,5,29,10
F8N2S1,Philips,B,3,24,8
F9M1R2,Canon,D,5,29,11
G0H4J3,Philips,B,2,29,10
G0I6P3,Philips,B,3,29,12
G0I7Z6,GE,C,4,0,0
G1J5K3,Canon,D,5,24,13
G1K1V3,Philips,B,2,29,11
G1N6S7,Philips,B,2,0,12
G2J1M5,Philips,B,2,29,9
G2M7W4,Philips,B,3,24,9
G2O2S6,Philips,B,2,29,10
G3M5S4,Philips,B,3,29,10
G4I7V2,Canon,D,5,33,9
G4K8P3,GE,C,4,0,0
G4L8Z7,Philips,B,2,29,8
G4S9U3,Siemens,A,1,0,11
G4U3U8,GE,C,4,0,0
G5P4U3,Siemens,A,1,0,9
G6T0Z6,GE,C,4,0,0
G7I5V7,Siemens,A,1,0,10
G7N8R7,Philips,B,2,28,9
G7Q2W0,Philips,B,2,0,8
G7S6V0,Siemens,A,1,0,8
G8K0M3,Canon,D,5,25,10
G8L0Z0,Canon,D,5,3,10
G8N2U5,Philips,B,3,23,9
G8R0Z9,GE,C,4,21,8
G9L0O9,Philips,B,2,29,10
G9N5V9,Siemens,A,1,0,11
H0K3Q4,Philips,B,3,24,9
H1I3W0,Siemens,A,1,0,9
H1J5W8,Philips,B,2,3,15
H1M5Y6,Philips,B,2,29,8
H1N7P7,GE,C,4,0,0
H1N8S6,Siemens,A,1,0,10
H1W2Y1,Philips,B,2,29,10
H2M9S1,GE,C,4,23,8
H3R6S9,GE,C,4,24,9
H3U1Y1,Philips,B,3,24,8
H4I2T8,Philips,B,3,24,8
H5N0P0,Siemens,A,1,0,8
H6I0I6,Philips,B,2,28,10
H6P7T1,Canon,D,5,25,12
H7I4J3,Philips,B,3,24,8
H7K5U5,Philips,B,3,24,8
H7L8R8,GE,C,4,23,9
H7N4V9,Philips,B,2,0,12
H7P5Z4,GE,C,4,24,9
H8K2K7,Philips,B,3,23,7
H9J6L5,GE,C,4,0,0
I0I2J8,Philips,B,2,21,7
I0J5U3,Philips,B,2,28,10
I2J6Z6,Siemens,A,1,0,7
I2K2Y8,Philips,B,2,29,8
I4J8P4,GE,C,4,0,0
I4L4V7,Canon,D,5,22,9
I4R8V6,GE,C,4,24,8
I5L3S2,Siemens,A,1,0,10
I6N3P3,Philips,B,2,29,10
I6P4R0,GE,C,4,18,11
I6T4W8,Philips,B,2,28,12
I7T3U1,Siemens,A,1,0,13
I7W4Y8,Philips,B,2,27,10
I8N8Y1,Canon,D,5,33,14
I8Z0Z6,GE,C,4,24,8
J1T9Y1,Siemens,A,1,0,12
J2S5T1,GE,C,4,0,0
J4J8Q3,Siemens,A,1,0,8
J4J9W6,Philips,B,2,29,11
J6K4V3,GE,C,4,21,3
J6K6P5,Siemens,A,1,0,9
J6M5O2,Canon,D,5,33,11
J6P5T8,Siemens,A,1,0,9
J8R5W2,Siemens,A,1,0,9
J9L4S2,Philips,B,3,29,13
J9L6N9,Philips,B,2,29,9
K2S1U6,Philips,B,2,29,12
K3P3Y6,Canon,D,5,32,11
K3R0Y7,GE,C,4,28,9
K4T7Y0,Siemens,A,1,24,12
K5K6N1,Philips,B,3,24,9
K5L2U3,Siemens,A,1,0,7
K5L4S1,Canon,D,5,27,10
K5M7V5,Canon,D,5,23,10
K5P0Y1,Siemens,A,1,0,9
K6N4N7,Siemens,A,1,0,8
K7L2Y6,Philips,B,2,29,10
K7N0R7,Canon,D,5,29,9
K7O3Q0,GE,C,4,23,10
K9N0W0,Philips,B,2,29,12
L1Q1Z5,Philips,B,2,0,15
L1Q9V8,Siemens,A,1,0,10
L2V5Z0,GE,C,4,3,17
L4Q2U3,Siemens,A,1,0,8
L5Q6T7,Philips,B,2,28,11
L5U7Y4,Canon,D,5,33,10
L6T2T5,Canon,D,5,1,18
L7Y7Z2,Philips,B,2,0,11
L8M2U8,Canon,D,5,0,11
L8N7P0,GE,C,4,23,9
L8N7Z0,Philips,B,3,29,9
M0P8U8,Philips,B,3,0,9
M1R4S1,Siemens,A,1,0,9
M2P1R1,Siemens,A,1,0,10
M2P5T8,Canon,D,5,0,11
M4P7Q6,Philips,B,2,2,13
M4T4V6,Philips,B,2,29,12
M6M9N1,GE,C,4,24,11
M6V2Y0,Philips,B,3,0,9
M8N3Z3,GE,C,4,0,0
N1P8Q9,Philips,B,2,29,10
N1S7Z2,GE,C,4,0,0
N2O7U5,Canon,D,5,18,7
N5S7Y1,Siemens,A,1,0,9
N7P3T8,Canon,D,5,25,10
N7V9W9,Philips,B,3,24,7
N7W6Z8,Canon,D,5,23,11
N8N9U0,Siemens,A,1,0,11
N9P5Z0,Philips,B,3,0,13
N9Q4T8,Canon,D,5,32,10
O0S9V7,Siemens,A,1,0,9
O1O9Y6,GE,C,4,0,0
O3R8Y5,Philips,B,2,0,8
O4O6U5,GE,C,4,24,8
O4T6Y7,Philips,B,2,0,13
O5U2U7,Canon,D,5,1,9
O7Q7U3,Canon,D,5,29,11
O9V8W5,Philips,B,2,29,12
P0S5Y0,Siemens,A,1,0,10
P3P9S5,Siemens,A,1,0,10
P3R6Y5,Philips,B,3,24,9
P3T5U1,GE,C,4,27,9
P5R1Y4,Siemens,A,1,0,9
P6U0Y0,Philips,B,2,29,10
P8V0Y7,Canon,D,5,27,9
P8W4Z0,Philips,B,3,0,12
P9S7W2,Philips,B,3,23,10
Q0Q1Y4,GE,C,4,0,0
Q0U0V5,Siemens,A,1,0,10
Q1Q3T1,GE,C,4,28,10
Q3Q6R8,Philips,B,3,24,9
Q3R9W7,Siemens,A,1,0,10
Q4W5Z8,Canon,D,5,33,10
Q5V8W3,GE,C,4,24,10
Q7V1Y5,Philips,B,2,0,11
R1R6Y8,GE,C,4,23,10
R2R7Z5,GE,C,4,23,8
R3V5W7,Canon,D,5,27,13
R4Y1Z9,Siemens,A,1,24,7
R6V5W3,GE,C,4,28,10
R8V0Y4,GE,C,4,24,7
S1S3Z7,Siemens,A,1,0,9
T2T9Z9,Siemens,A,1,0,13
T2Z1Z9,Canon,D,5,29,9
T9U9W2,Siemens,A,1,0,10
V4W8Z5,GE,C,4,19,9
W5Z4Z8,Philips,B,2,29,11
Y6Y9Z2,Philips,B,3,29,9
================================================
FILE: dataset_domain.py
================================================
import os
import math

import numpy as np
import pandas as pd
import SimpleITK as sitk

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset


class CMRDataset(Dataset):
    def __init__(self, dataset_dir, mode='train', domain='A', crop_size=256, scale=0.1, rotate=10, debug=False):
        self.mode = mode
        self.dataset_dir = dataset_dir
        self.crop_size = crop_size
        self.scale = scale
        self.rotate = rotate

        if self.mode == 'train':
            pre_face = 'Training'
            if 'C' in domain or 'D' in domain:
                # Vendors C and D are only available for testing
                raise ValueError('No domain C or D in Training set')
        elif self.mode == 'test':
            pre_face = 'Testing'
        else:
            raise ValueError('Wrong mode')

        if debug:
            # The testing set is the smallest, so it loads fastest
            pre_face = 'Testing'

        path = self.dataset_dir + pre_face + '/'
        print('start loading data')

        name_list = []
        if 'A' in domain:
            df = pd.read_csv(self.dataset_dir+pre_face+'_A.csv')
            name_list += np.array(df['name']).tolist()
        if 'B' in domain:
            df = pd.read_csv(self.dataset_dir+pre_face+'_B.csv')
            name_list += np.array(df['name']).tolist()
        if 'C' in domain:
            df = pd.read_csv(self.dataset_dir+pre_face+'_C.csv')
            name_list += np.array(df['name']).tolist()
        if 'D' in domain:
            df = pd.read_csv(self.dataset_dir+pre_face+'_D.csv')
            name_list += np.array(df['name']).tolist()

        img_list = []
        lab_list = []
        spacing_list = []

        for name in name_list:
            for name_idx in os.listdir(path+name):
                if 'gt' in name_idx:
                    continue
                idx = name_idx.split('_')[2].split('.')[0]

                itk_img = sitk.ReadImage(path+name+'/%s_sa_%s.nii.gz'%(name, idx))
                itk_lab = sitk.ReadImage(path+name+'/%s_sa_gt_%s.nii.gz'%(name, idx))

                spacing = np.array(itk_lab.GetSpacing()).tolist()
                spacing_list.append(spacing[::-1])

                assert itk_img.GetSize() == itk_lab.GetSize()

                img, lab = self.preprocess(itk_img, itk_lab)
                img_list.append(img)
                lab_list.append(lab)

        self.img_slice_list = []
        self.lab_slice_list = []
        if self.mode == 'train':
            # Training samples are individual 2D slices
            for i in range(len(img_list)):
                tmp_img = img_list[i]
                tmp_lab = lab_list[i]
                z, x, y = tmp_img.shape
                for j in range(z):
                    self.img_slice_list.append(tmp_img[j])
                    self.lab_slice_list.append(tmp_lab[j])
        else:
            # Testing samples are whole 3D volumes
            self.img_slice_list = img_list
            self.lab_slice_list = lab_list
        self.spacing_list = spacing_list

        print('load done, length of dataset:', len(self.img_slice_list))

    def __len__(self):
        return len(self.img_slice_list)

    def preprocess(self, itk_img, itk_lab):
        img = sitk.GetArrayFromImage(itk_img)
        lab = sitk.GetArrayFromImage(itk_lab)

        # Clip intensities at the 98th percentile, then normalize to [0, 1]
        max98 = np.percentile(img, 98)
        img = np.clip(img, 0, max98)

        z, y, x = img.shape

        # Pad small images so that cropping to crop_size is always possible
        if x < self.crop_size:
            diff = (self.crop_size + 10 - x) // 2
            img = np.pad(img, ((0,0), (0,0), (diff, diff)))
            lab = np.pad(lab, ((0,0), (0,0), (diff, diff)))
        if y < self.crop_size:
            diff = (self.crop_size + 10 - y) // 2
            img = np.pad(img, ((0,0), (diff, diff), (0,0)))
            lab = np.pad(lab, ((0,0), (diff, diff), (0,0)))

        img = img / max98

        tensor_img = torch.from_numpy(img).float()
        tensor_lab = torch.from_numpy(lab).long()

        return tensor_img, tensor_lab

    def __getitem__(self, idx):
        tensor_image = self.img_slice_list[idx]
        tensor_label = self.lab_slice_list[idx]

        if self.mode == 'train':
            tensor_image = tensor_image.unsqueeze(0).unsqueeze(0)
            tensor_label = tensor_label.unsqueeze(0).unsqueeze(0)

            # Gaussian noise
            tensor_image += torch.randn(tensor_image.shape) * 0.02

            # Additive brightness shift
            rnd_bn = np.random.normal(0, 0.7)
            tensor_image += rnd_bn

            # Gamma transform
            minm = tensor_image.min()
            rng = tensor_image.max() - minm
            gamma = np.random.uniform(0.5, 1.6)
            tensor_image = torch.pow((tensor_image-minm)/rng, gamma)*rng + minm

            tensor_image, tensor_label = self.random_zoom_rotate(tensor_image, tensor_label)
            tensor_image, tensor_label = self.randcrop(tensor_image, tensor_label)
        else:
            tensor_image, tensor_label = self.center_crop(tensor_image, tensor_label)

        assert tensor_image.shape == tensor_label.shape

        if self.mode == 'train':
            return tensor_image, tensor_label
        else:
            return tensor_image, tensor_label, np.array(self.spacing_list[idx])

    def randcrop(self, img, label):
        _, _, H, W = img.shape

        diff_H = H - self.crop_size
        diff_W = W - self.crop_size

        rand_x = np.random.randint(0, diff_H)
        rand_y = np.random.randint(0, diff_W)

        croped_img = img[0, :, rand_x:rand_x+self.crop_size, rand_y:rand_y+self.crop_size]
        croped_lab = label[0, :, rand_x:rand_x+self.crop_size, rand_y:rand_y+self.crop_size]

        return croped_img, croped_lab

    def center_crop(self, img, label):
        D, H, W = img.shape

        diff_H = H - self.crop_size
        diff_W = W - self.crop_size

        rand_x = diff_H // 2
        rand_y = diff_W // 2

        croped_img = img[:, rand_x:rand_x+self.crop_size, rand_y:rand_y+self.crop_size]
        croped_lab = label[:, rand_x:rand_x+self.crop_size, rand_y:rand_y+self.crop_size]

        return croped_img, croped_lab

    def random_zoom_rotate(self, img, label):
        scale_x = np.random.random() * 2 * self.scale + (1 - self.scale)
        scale_y = np.random.random() * 2 * self.scale + (1 - self.scale)

        theta_scale = torch.tensor([[scale_x, 0, 0],
                                    [0, scale_y, 0],
                                    [0, 0, 1]]).float()

        angle = (float(np.random.randint(-self.rotate, self.rotate)) / 180.) * math.pi
        theta_rotate = torch.tensor([[math.cos(angle), -math.sin(angle), 0],
                                     [math.sin(angle), math.cos(angle), 0],
                                     [0, 0, 1]]).float()

        # Compose zoom and rotation into a single 2x3 affine matrix
        theta = torch.mm(theta_rotate, theta_scale)[0:2].unsqueeze(0)

        grid = F.affine_grid(theta, img.size(), align_corners=False)
        img = F.grid_sample(img, grid, mode='bilinear', align_corners=False)
        label = F.grid_sample(label.float(), grid, mode='nearest', align_corners=False).long()

        return img, label
================================================
FILE: losses.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pdb
class DiceLoss(nn.Module):
def __init__(self, alpha=0.5, beta=0.5, size_average=True, reduce=True):
super(DiceLoss, self).__init__()
self.alpha = alpha
self.beta = beta
self.size_average = size_average
self.reduce = reduce
def forward(self, preds, targets):
N = preds.size(0)
C = preds.size(1)
P = F.softmax(preds, dim=1)
smooth = torch.zeros(C, dtype=torch.float32).fill_(0.00001)
class_mask = torch.zeros(preds.shape).to(preds.device)
class_mask.scatter_(1, targets, 1.)
ones = torch.ones(preds.shape).to(preds.device)
P_ = ones - P
class_mask_ = ones - class_mask
TP = P * class_mask
FP = P * class_mask_
FN = P_ * class_mask
smooth = smooth.to(preds.device)
self.alpha = FP.transpose(0, 1).reshape(C, -1).sum(dim=(1)) / ((FP.transpose(0, 1).reshape(C, -1).sum(dim=(1)) + FN.transpose(0, 1).reshape(C, -1).sum(dim=(1))) + smooth)
self.alpha = torch.clamp(self.alpha, min=0.2, max=0.8)
#print('alpha:', self.alpha)
self.beta = 1 - self.alpha
num = torch.sum(TP.transpose(0, 1).reshape(C, -1), dim=(1)).float()
den = num + self.alpha * torch.sum(FP.transpose(0, 1).reshape(C, -1), dim=(1)).float() + self.beta * torch.sum(FN.transpose(0, 1).reshape(C, -1), dim=(1)).float()
dice = num / (den + smooth)
if not self.reduce:
loss = torch.ones(C).to(dice.device) - dice
return loss
loss = 1 - dice
loss = loss.sum()
if self.size_average:
loss /= C
return loss
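`DiceLoss.forward` computes a Tversky-style score per class whose `alpha` is set adaptively from the false-positive share of the errors, then clamped to [0.2, 0.8]. The per-class arithmetic reduces to this scalar sketch (`adaptive_tversky` is a hypothetical helper, not part of the repo):

```python
def adaptive_tversky(tp, fp, fn, smooth=1e-5):
    # alpha follows the FP share of the errors, clamped to [0.2, 0.8],
    # so classes dominated by false positives get a heavier FP penalty.
    alpha = min(max(fp / (fp + fn + smooth), 0.2), 0.8)
    beta = 1.0 - alpha
    return tp / (tp + alpha * fp + beta * fn + smooth)
```

For balanced errors (fp == fn) this recovers the usual soft Dice with alpha = beta = 0.5; the final loss is 1 minus this score, averaged over classes.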
class FocalLoss(nn.Module):
def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
super(FocalLoss, self).__init__()
if alpha is None:
self.alpha = torch.ones(class_num)
else:
self.alpha = alpha
self.gamma = gamma
self.size_average = size_average
def forward(self, preds, targets):
N = preds.size(0)
C = preds.size(1)
if targets.dim() == preds.dim() - 1:
targets = targets.unsqueeze(1)  # add a channel dimension only if it is missing
P = F.softmax(preds, dim=1)
log_P = F.log_softmax(preds, dim=1)
class_mask = torch.zeros(preds.shape).to(preds.device)
class_mask.scatter_(1, targets, 1.)
if targets.size(1) == 1:
# squeeze the channel dimension of the target
targets = targets.squeeze(1)
alpha = self.alpha[targets.data].to(preds.device)
probs = (P * class_mask).sum(1)
log_probs = (log_P * class_mask).sum(1)
batch_loss = -alpha * (1-probs).pow(self.gamma)*log_probs
if self.size_average:
loss = batch_loss.mean()
else:
loss = batch_loss.sum()
return loss
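For a single pixel with true-class probability p, `FocalLoss` reduces to the familiar scalar form, down-weighting well-classified pixels by the factor (1 - p)^gamma. A sketch (`focal_term` is a hypothetical name, not part of the repo):

```python
import math

def focal_term(p, gamma=2.0, alpha=1.0):
    # Focal loss for one pixel given the probability p assigned to the
    # true class: easy pixels (p near 1) contribute almost nothing.
    return -alpha * (1.0 - p) ** gamma * math.log(p)
```

With gamma = 0 this degenerates to plain cross-entropy; larger gamma focuses training on hard, misclassified pixels.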
if __name__ == '__main__':
DL = DiceLoss()
FL = FocalLoss(10)
pred = torch.randn(2, 10, 128, 128)
target = torch.zeros((2, 1, 128, 128)).long()
dl_loss = DL(pred, target)
fl_loss = FL(pred, target)
print('2D:', dl_loss.item(), fl_loss.item())
pred = torch.randn(2, 10, 64, 128, 128)
target = torch.zeros(2, 1, 64, 128, 128).long()
dl_loss = DL(pred, target)
fl_loss = FL(pred, target)
print('3D:', dl_loss.item(), fl_loss.item())
================================================
FILE: model/__init__.py
================================================
================================================
FILE: model/conv_trans_utils.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import pdb
def conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
class depthwise_separable_conv(nn.Module):
def __init__(self, in_ch, out_ch, stride=1, kernel_size=3, padding=1, bias=False):
super().__init__()
self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size=kernel_size, padding=padding, groups=in_ch, bias=bias, stride=stride)
self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=bias)
def forward(self, x):
out = self.depthwise(x)
out = self.pointwise(out)
return out
class Mlp(nn.Module):
def __init__(self, in_ch, hid_ch=None, out_ch=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_ch = out_ch or in_ch
hid_ch = hid_ch or in_ch
self.fc1 = nn.Conv2d(in_ch, hid_ch, kernel_size=1)
self.act = act_layer()
self.fc2 = nn.Conv2d(hid_ch, out_ch, kernel_size=1)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class BasicBlock(nn.Module):
def __init__(self, inplanes, planes, stride=1):
super().__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(inplanes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or inplanes != planes:
self.shortcut = nn.Sequential(
nn.BatchNorm2d(inplanes),
self.relu,
nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
)
def forward(self, x):
residue = x
out = self.bn1(x)
out = self.relu(out)
out = self.conv1(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)
out += self.shortcut(residue)
return out
class BasicTransBlock(nn.Module):
def __init__(self, in_ch, heads, dim_head, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):
super().__init__()
self.bn1 = nn.BatchNorm2d(in_ch)
self.attn = LinearAttention(in_ch, heads=heads, dim_head=in_ch//heads, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
self.bn2 = nn.BatchNorm2d(in_ch)
self.relu = nn.ReLU(inplace=True)
self.mlp = nn.Conv2d(in_ch, in_ch, kernel_size=1, bias=False)
# a 1x1 conv shows no performance difference from an MLP here
def forward(self, x):
out = self.bn1(x)
out, q_k_attn = self.attn(out)
out = out + x
residue = out
out = self.bn2(out)
out = self.relu(out)
out = self.mlp(out)
out += residue
return out
class BasicTransDecoderBlock(nn.Module):
def __init__(self, in_ch, out_ch, heads, dim_head, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):
super().__init__()
self.bn_l = nn.BatchNorm2d(in_ch)
self.bn_h = nn.BatchNorm2d(out_ch)
self.conv_ch = nn.Conv2d(in_ch, out_ch, kernel_size=1)
self.attn = LinearAttentionDecoder(in_ch, out_ch, heads=heads, dim_head=out_ch//heads, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
self.bn2 = nn.BatchNorm2d(out_ch)
self.relu = nn.ReLU(inplace=True)
self.mlp = nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False)
def forward(self, x1, x2):
# x1: low-res feature, x2: high-res feature
residue = F.interpolate(self.conv_ch(x1), size=x2.shape[-2:], mode='bilinear', align_corners=True)
x1 = self.bn_l(x1)
x2 = self.bn_h(x2)
out, q_k_attn = self.attn(x2, x1)
out = out + residue
residue = out
out = self.bn2(out)
out = self.relu(out)
out = self.mlp(out)
out += residue
return out
########################################################################
# Transformer components
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):
super().__init__()
self.inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** (-0.5)
self.dim_head = dim_head
self.reduce_size = reduce_size
self.projection = projection
self.rel_pos = rel_pos
# depthwise conv is slightly better than conv1x1
#self.to_qkv = nn.Conv2d(dim, self.inner_dim*3, kernel_size=1, stride=1, padding=0, bias=True)
#self.to_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, stride=1, padding=0, bias=True)
self.to_qkv = depthwise_separable_conv(dim, self.inner_dim*3)
self.to_out = depthwise_separable_conv(self.inner_dim, dim)
self.attn_drop = nn.Dropout(attn_drop)
self.proj_drop = nn.Dropout(proj_drop)
if self.rel_pos:
# 2D input-independent relative position encoding is slightly better than
# its 1D input-dependent counterpart
self.relative_position_encoding = RelativePositionBias(heads, reduce_size, reduce_size)
#self.relative_position_encoding = RelativePositionEmbedding(dim_head, reduce_size)
def forward(self, x):
B, C, H, W = x.shape
#B, inner_dim, H, W
qkv = self.to_qkv(x)
q, k, v = qkv.chunk(3, dim=1)
if self.projection == 'interp' and H != self.reduce_size:
k, v = map(lambda t: F.interpolate(t, size=self.reduce_size, mode='bilinear', align_corners=True), (k, v))
elif self.projection == 'maxpool' and H != self.reduce_size:
k, v = map(lambda t: F.adaptive_max_pool2d(t, output_size=self.reduce_size), (k, v))
q = rearrange(q, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads, h=H, w=W)
k, v = map(lambda t: rearrange(t, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads, h=self.reduce_size, w=self.reduce_size), (k, v))
q_k_attn = torch.einsum('bhid,bhjd->bhij', q, k)
if self.rel_pos:
relative_position_bias = self.relative_position_encoding(H, W)
q_k_attn += relative_position_bias
#rel_attn_h, rel_attn_w = self.relative_position_encoding(q, self.heads, H, W, self.dim_head)
#q_k_attn = q_k_attn + rel_attn_h + rel_attn_w
q_k_attn *= self.scale
q_k_attn = F.softmax(q_k_attn, dim=-1)
q_k_attn = self.attn_drop(q_k_attn)
out = torch.einsum('bhij,bhjd->bhid', q_k_attn, v)
out = rearrange(out, 'b heads (h w) dim_head -> b (dim_head heads) h w', h=H, w=W, dim_head=self.dim_head, heads=self.heads)
out = self.to_out(out)
out = self.proj_drop(out)
return out, q_k_attn
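`LinearAttention` keeps its cost linear in the number of query tokens by downsampling k and v to `reduce_size**2` tokens before the dot product, so the attention map is (HW, r^2) rather than a full (HW, HW). A single-head numpy sketch of the core idea (here the scale is applied before the softmax; the module applies it after adding the relative-position bias):

```python
import numpy as np

def reduced_kv_attention(q, k, v):
    # q: (n_q, d); k, v: (n_kv, d) with n_kv = reduce_size**2 << n_q,
    # so the attention matrix is (n_q, n_kv) instead of (n_q, n_q).
    scale = q.shape[1] ** -0.5
    logits = q @ k.T * scale
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    attn = np.exp(logits)
    attn /= attn.sum(axis=1, keepdims=True)      # softmax over the n_kv keys
    return attn @ v                              # (n_q, d)
```

For a 128x128 feature map with reduce_size 16, this shrinks the attention matrix from 16384x16384 to 16384x256.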
class LinearAttentionDecoder(nn.Module):
def __init__(self, in_dim, out_dim, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):
super().__init__()
self.inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** (-0.5)
self.dim_head = dim_head
self.reduce_size = reduce_size
self.projection = projection
self.rel_pos = rel_pos
# depthwise conv is slightly better than conv1x1
#self.to_kv = nn.Conv2d(dim, self.inner_dim*2, kernel_size=1, stride=1, padding=0, bias=True)
#self.to_q = nn.Conv2d(dim, self.inner_dim, kernel_size=1, stride=1, padding=0, bias=True)
#self.to_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, stride=1, padding=0, bias=True)
self.to_kv = depthwise_separable_conv(in_dim, self.inner_dim*2)
self.to_q = depthwise_separable_conv(out_dim, self.inner_dim)
self.to_out = depthwise_separable_conv(self.inner_dim, out_dim)
self.attn_drop = nn.Dropout(attn_drop)
self.proj_drop = nn.Dropout(proj_drop)
if self.rel_pos:
self.relative_position_encoding = RelativePositionBias(heads, reduce_size, reduce_size)
#self.relative_position_encoding = RelativePositionEmbedding(dim_head, reduce_size)
def forward(self, q, x):
B, C, H, W = x.shape # low-res feature shape
BH, CH, HH, WH = q.shape # high-res feature shape
k, v = self.to_kv(x).chunk(2, dim=1) #B, inner_dim, H, W
q = self.to_q(q) #BH, inner_dim, HH, WH
if self.projection == 'interp' and H != self.reduce_size:
k, v = map(lambda t: F.interpolate(t, size=self.reduce_size, mode='bilinear', align_corners=True), (k, v))
elif self.projection == 'maxpool' and H != self.reduce_size:
k, v = map(lambda t: F.adaptive_max_pool2d(t, output_size=self.reduce_size), (k, v))
q = rearrange(q, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads, h=HH, w=WH)
k, v = map(lambda t: rearrange(t, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads, h=self.reduce_size, w=self.reduce_size), (k, v))
q_k_attn = torch.einsum('bhid,bhjd->bhij', q, k)
if self.rel_pos:
relative_position_bias = self.relative_position_encoding(HH, WH)
q_k_attn += relative_position_bias
#rel_attn_h, rel_attn_w = self.relative_position_encoding(q, self.heads, HH, WH, self.dim_head)
#q_k_attn = q_k_attn + rel_attn_h + rel_attn_w
q_k_attn *= self.scale
q_k_attn = F.softmax(q_k_attn, dim=-1)
q_k_attn = self.attn_drop(q_k_attn)
out = torch.einsum('bhij,bhjd->bhid', q_k_attn, v)
out = rearrange(out, 'b heads (h w) dim_head -> b (dim_head heads) h w', h=HH, w=WH, dim_head=self.dim_head, heads=self.heads)
out = self.to_out(out)
out = self.proj_drop(out)
return out, q_k_attn
class RelativePositionEmbedding(nn.Module):
# input-dependent relative position
def __init__(self, dim, shape):
super().__init__()
self.dim = dim
self.shape = shape
self.key_rel_w = nn.Parameter(torch.randn((2*self.shape-1, dim))*0.02)
self.key_rel_h = nn.Parameter(torch.randn((2*self.shape-1, dim))*0.02)
coords = torch.arange(self.shape)
relative_coords = coords[None, :] - coords[:, None] # h, h
relative_coords += self.shape - 1 # shift to start from 0
self.register_buffer('relative_position_index', relative_coords)
def forward(self, q, Nh, H, W, dim_head):
# q: B, Nh, HW, dim
B, _, _, dim = q.shape
# q: B, Nh, H, W, dim_head
q = rearrange(q, 'b heads (h w) dim_head -> b heads h w dim_head', b=B, dim_head=dim_head, heads=Nh, h=H, w=W)
rel_logits_w = self.relative_logits_1d(q, self.key_rel_w, 'w')
rel_logits_h = self.relative_logits_1d(q.permute(0, 1, 3, 2, 4), self.key_rel_h, 'h')
return rel_logits_w, rel_logits_h
def relative_logits_1d(self, q, rel_k, case):
B, Nh, H, W, dim = q.shape
rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k) # B, Nh, H, W, 2*shape-1
if W != self.shape:
# relative_position_index original shape: (w, w)
# after repeat_interleave: (W, w)
relative_index = torch.repeat_interleave(self.relative_position_index, W//self.shape, dim=0) # W, shape
relative_index = relative_index.view(1, 1, 1, W, self.shape)
relative_index = relative_index.repeat(B, Nh, H, 1, 1)
rel_logits = torch.gather(rel_logits, 4, relative_index) # B, Nh, H, W, shape
rel_logits = rel_logits.unsqueeze(3)
rel_logits = rel_logits.repeat(1, 1, 1, self.shape, 1, 1)
if case == 'w':
rel_logits = rearrange(rel_logits, 'b heads H h W w -> b heads (H W) (h w)')
elif case == 'h':
rel_logits = rearrange(rel_logits, 'b heads W w H h -> b heads (H W) (h w)')
return rel_logits
class RelativePositionBias(nn.Module):
# input-independent relative position bias
# the 2D version is used here because it has fewer parameters
# Borrowed some code from SwinTransformer: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
def __init__(self, num_heads, h, w):
super().__init__()
self.num_heads = num_heads
self.h = h
self.w = w
self.relative_position_bias_table = nn.Parameter(
torch.randn((2*h-1) * (2*w-1), num_heads)*0.02)
coords_h = torch.arange(self.h)
coords_w = torch.arange(self.w)
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, h, w
coords_flatten = torch.flatten(coords, 1) # 2, hw
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.h - 1
relative_coords[:, :, 1] += self.w - 1
relative_coords[:, :, 0] *= 2 * self.h - 1
relative_position_index = relative_coords.sum(-1) # hw, hw
self.register_buffer("relative_position_index", relative_position_index)
def forward(self, H, W):
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.h, self.w, self.h*self.w, -1) #h, w, hw, nH
relative_position_bias_expand_h = torch.repeat_interleave(relative_position_bias, H//self.h, dim=0)
relative_position_bias_expanded = torch.repeat_interleave(relative_position_bias_expand_h, W//self.w, dim=1) #HW, hw, nH
relative_position_bias_expanded = relative_position_bias_expanded.view(H*W, self.h*self.w, self.num_heads).permute(2, 0, 1).contiguous().unsqueeze(0)
return relative_position_bias_expanded
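The `relative_position_index` buffer maps each (query, key) position pair to a row of the learnable bias table. A numpy mirror of the construction in `RelativePositionBias.__init__` (assuming h == w, as the module instantiates it with `reduce_size` for both; the row offset is scaled by 2h-1 exactly as in the module, which coincides with the usual 2w-1 factor when the grid is square):

```python
import numpy as np

def relative_position_index(h, w):
    # Each entry indexes one of the (2h-1)*(2w-1) bias-table rows.
    coords = np.stack(np.meshgrid(np.arange(h), np.arange(w), indexing='ij'))
    flat = coords.reshape(2, -1)                  # 2, h*w
    rel = flat[:, :, None] - flat[:, None, :]     # 2, hw, hw
    rel = rel.transpose(1, 2, 0)                  # hw, hw, 2
    rel[..., 0] += h - 1                          # shift offsets to start at 0
    rel[..., 1] += w - 1
    rel[..., 0] *= 2 * h - 1                      # flatten the 2D offset pair
    return rel.sum(-1)                            # hw, hw
```

Diagonal entries (a token attending to itself) all share one index, so a single learned bias covers every self-pair.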
###########################################################################
# Unet Transformer building block
class down_block_trans(nn.Module):
def __init__(self, in_ch, out_ch, num_block, bottleneck=False, maxpool=True, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):
super().__init__()
block_list = []
if bottleneck:
block = BottleneckBlock
else:
block = BasicBlock
attn_block = BasicTransBlock
if maxpool:
block_list.append(nn.MaxPool2d(2))
block_list.append(block(in_ch, out_ch, stride=1))
else:
block_list.append(block(in_ch, out_ch, stride=2))
assert num_block > 0
for i in range(num_block):
block_list.append(attn_block(out_ch, heads, dim_head, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos))
self.blocks = nn.Sequential(*block_list)
def forward(self, x):
out = self.blocks(x)
return out
class up_block_trans(nn.Module):
def __init__(self, in_ch, out_ch, num_block, bottleneck=False, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):
super().__init__()
self.attn_decoder = BasicTransDecoderBlock(in_ch, out_ch, heads=heads, dim_head=dim_head, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
if bottleneck:
block = BottleneckBlock
else:
block = BasicBlock
attn_block = BasicTransBlock
block_list = []
for i in range(num_block):
block_list.append(attn_block(out_ch, heads, dim_head, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos))
block_list.append(block(2*out_ch, out_ch, stride=1))
self.blocks = nn.Sequential(*block_list)
def forward(self, x1, x2):
# x1: low-res feature, x2: high-res feature
out = self.attn_decoder(x1, x2)
out = torch.cat([out, x2], dim=1)
out = self.blocks(out)
return out
class block_trans(nn.Module):
def __init__(self, in_ch, num_block, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):
super().__init__()
block_list = []
attn_block = BasicTransBlock
assert num_block > 0
for i in range(num_block):
block_list.append(attn_block(in_ch, heads, dim_head, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos))
self.blocks = nn.Sequential(*block_list)
def forward(self, x):
out = self.blocks(x)
return out
================================================
FILE: model/resnet_utnet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet_utils import up_block
from .transunet import ResNetV2
from .conv_trans_utils import block_trans
import pdb
class ResNet_UTNet(nn.Module):
def __init__(self, in_ch, num_class, reduce_size=8, block_list='234', num_blocks=[1,2,4], projection='interp', num_heads=[4,4,4], attn_drop=0., proj_drop=0., rel_pos=True, block_units=(3,4,9), width_factor=1):
super().__init__()
self.resnet = ResNetV2(block_units, width_factor)
if '0' in block_list:
self.trans_0 = block_trans(64, num_blocks[-4], 64//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
else:
self.trans_0 = nn.Identity()
if '1' in block_list:
self.trans_1 = block_trans(256, num_blocks[-3], 256//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
else:
self.trans_1 = nn.Identity()
if '2' in block_list:
self.trans_2 = block_trans(512, num_blocks[-2], 512//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
else:
self.trans_2 = nn.Identity()
if '3' in block_list:
self.trans_3 = block_trans(1024, num_blocks[-1], 1024//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
else:
self.trans_3 = nn.Identity()
self.up1 = up_block(1024, 512, scale=(2,2), num_block=1)
self.up2 = up_block(512, 256, scale=(2,2), num_block=1)
self.up3 = up_block(256, 64, scale=(2,2), num_block=1)
self.up4 = nn.UpsamplingBilinear2d(scale_factor=2)
self.output = nn.Conv2d(64, num_class, kernel_size=3, padding=1, bias=True)
def forward(self, x):
if x.shape[1] == 1:
x = x.repeat(1, 3, 1, 1)
x, features = self.resnet(x)
out3 = self.trans_3(x)
out2 = self.trans_2(features[0])
out1 = self.trans_1(features[1])
out0 = self.trans_0(features[2])
out = self.up1(out3, out2)
out = self.up2(out, out1)
out = self.up3(out, out0)
out = self.up4(out)
out = self.output(out)
return out
================================================
FILE: model/swin_unet.py
================================================
# code is borrowed from the original repo and adapted to our training framework
# https://github.com/HuCaoFighting/Swin-Unet/tree/4375a8d6fa7d9c38184c5d3194db990a00a3e912
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import copy
import logging
import math
from os.path import join as pjoin
import numpy as np
from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
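`window_partition` and `window_reverse` are exact inverses: the same reshape/permute pattern run forwards and backwards. A numpy mirror that makes the round trip explicit:

```python
import numpy as np

def window_partition_np(x, ws):
    # (B, H, W, C) -> (num_windows*B, ws, ws, C)
    B, H, W, C = x.shape
    x = x.reshape(B, H // ws, ws, W // ws, ws, C)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)

def window_reverse_np(windows, ws, H, W):
    # (num_windows*B, ws, ws, C) -> (B, H, W, C)
    B = windows.shape[0] // ((H // ws) * (W // ws))
    x = windows.reshape(B, H // ws, W // ws, ws, ws, -1)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)
```

Because both are pure reshapes/permutations, no values are copied or lost, which is what lets W-MSA operate per-window and then stitch the result back.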
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
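The SW-MSA `attn_mask` built in `SwinTransformerBlock.__init__` labels the nine regions produced by the cyclic shift, partitions the label map into windows, and puts -100 wherever two tokens in the same window came from different regions, so the softmax effectively zeroes cross-region attention. A numpy mirror (`sw_msa_mask` is a hypothetical name):

```python
import numpy as np

def sw_msa_mask(H, W, ws, ss):
    # Label the 3x3 regions created by shifting by ss with window size ws.
    img = np.zeros((H, W))
    cnt = 0
    for h in (slice(0, -ws), slice(-ws, -ss), slice(-ss, None)):
        for w in (slice(0, -ws), slice(-ws, -ss), slice(-ss, None)):
            img[h, w] = cnt
            cnt += 1
    # Partition the label map into windows, one row of labels per window.
    win = img.reshape(H // ws, ws, W // ws, ws).transpose(0, 2, 1, 3)
    win = win.reshape(-1, ws * ws)
    # Token pairs with different labels get a large negative bias.
    diff = win[:, None, :] - win[:, :, None]
    return np.where(diff != 0, -100.0, 0.0)       # nW, ws*ws, ws*ws
```

Interior windows (entirely inside one region) get an all-zero mask; only the windows straddling the wrapped-around borders are restricted.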
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
class PatchExpand(nn.Module):
def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity()
self.norm = norm_layer(dim // dim_scale)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
x = self.expand(x)
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
x = x.view(B, -1, C//4)
x = self.norm(x)
return x
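`PatchExpand`'s rearrange trades channels for resolution like a pixel shuffle: an (H, W, 4C') map becomes (2H, 2W, C'). A numpy mirror of the index mapping used by `'h w (p1 p2 c) -> (h p1) (w p2) c'`:

```python
import numpy as np

def patch_expand_np(x, p=2):
    # Split channels into (p, p, c) and interleave them spatially,
    # multiplying resolution by p and dividing channels by p*p.
    H, W, C = x.shape
    c = C // (p * p)
    x = x.reshape(H, W, p, p, c)
    return x.transpose(0, 2, 1, 3, 4).reshape(H * p, W * p, c)
```

`FinalPatchExpand_X4` is the same mapping with p = 4 (after first expanding channels 16x with a linear layer), recovering the full input resolution in one step.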
class FinalPatchExpand_X4(nn.Module):
def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.dim_scale = dim_scale
self.expand = nn.Linear(dim, 16*dim, bias=False)
self.output_dim = dim
self.norm = norm_layer(self.output_dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
x = self.expand(x)
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))
x = x.view(B, -1, self.output_dim)
x = self.norm(x)
return x
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class BasicLayer_up(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
upsample (nn.Module | None, optional): Upsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch expanding layer (note: `upsample` is used only as a flag here; PatchExpand is hard-coded)
if upsample is not None:
self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)
else:
self.upsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.upsample is not None:
x = self.upsample(x)
return x
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
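A minimal standalone sketch (toy sizes assumed, matching the defaults above) of what `PatchEmbed` computes: a strided `Conv2d` is equivalent to slicing the image into non-overlapping patches and linearly projecting each one, so a 224x224 input with patch size 4 becomes a sequence of (224/4)^2 = 3136 tokens of width 96.

```python
import torch
import torch.nn as nn

# Strided convolution as patch embedding: kernel_size == stride == patch_size
proj = nn.Conv2d(3, 96, kernel_size=4, stride=4)
img = torch.randn(1, 3, 224, 224)
tokens = proj(img).flatten(2).transpose(1, 2)  # (B, Ph*Pw, C)
print(tokens.shape)  # torch.Size([1, 3136, 96])
```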
class SwinTransformerSys(nn.Module):
r""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, final_upsample="expand_first", **kwargs):
super().__init__()
print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths,
depths_decoder,drop_path_rate,num_classes))
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.num_features_up = int(embed_dim * 2)
self.mlp_ratio = mlp_ratio
self.final_upsample = final_upsample
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build encoder and bottleneck layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers.append(layer)
# build decoder layers
self.layers_up = nn.ModuleList()
self.concat_back_dim = nn.ModuleList()
for i_layer in range(self.num_layers):
concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)),
int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity()
if i_layer == 0:
layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),
patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer)
else:
layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)),
input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),
patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))),
depth=depths[(self.num_layers-1-i_layer)],
num_heads=num_heads[(self.num_layers-1-i_layer)],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])],
norm_layer=norm_layer,
upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers_up.append(layer_up)
self.concat_back_dim.append(concat_linear)
self.norm = norm_layer(self.num_features)
self.norm_up= norm_layer(self.embed_dim)
if self.final_upsample == "expand_first":
print("---final upsample expand_first---")
self.up = FinalPatchExpand_X4(input_resolution=(img_size//patch_size,img_size//patch_size),dim_scale=4,dim=embed_dim)
self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=1,bias=False)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
#Encoder and Bottleneck
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
x_downsample = []
for layer in self.layers:
x_downsample.append(x)
x = layer(x)
x = self.norm(x) # B L C
return x, x_downsample
# Decoder and skip connection
def forward_up_features(self, x, x_downsample):
for inx, layer_up in enumerate(self.layers_up):
if inx == 0:
x = layer_up(x)
else:
x = torch.cat([x,x_downsample[3-inx]],-1)
x = self.concat_back_dim[inx](x)
x = layer_up(x)
x = self.norm_up(x) # B L C
return x
def up_x4(self, x):
H, W = self.patches_resolution
B, L, C = x.shape
assert L == H*W, "input feature has wrong size"
if self.final_upsample=="expand_first":
x = self.up(x)
x = x.view(B,4*H,4*W,-1)
x = x.permute(0,3,1,2) #B,C,H,W
x = self.output(x)
return x
def forward(self, x):
x, x_downsample = self.forward_features(x)
x = self.forward_up_features(x,x_downsample)
x = self.up_x4(x)
return x
def flops(self):
flops = 0
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
flops += self.num_features * self.num_classes
return flops
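A small self-contained sketch of the stochastic-depth schedule built in `__init__` above: drop-path rates grow linearly from 0 to `drop_path_rate` across all encoder blocks, and each stage receives its slice exactly as in the `BasicLayer` construction loop (the depths/rate values below mirror the defaults).

```python
import torch

depths = [2, 2, 6, 2]
drop_path_rate = 0.1
# one rate per block, linearly spaced over the whole encoder
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
# per-stage slices, as passed to each BasicLayer via drop_path=dpr[...]
per_stage = [dpr[sum(depths[:i]):sum(depths[:i + 1])] for i in range(len(depths))]
print([len(s) for s in per_stage])  # [2, 2, 6, 2]
```

Deeper blocks therefore get a higher drop probability, which is the usual stochastic-depth decay rule.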
logger = logging.getLogger(__name__)
class SwinUnet_config():
def __init__(self):
self.patch_size = 4
self.in_chans = 3
self.num_classes = 4
self.embed_dim = 96
self.depths = [2, 2, 6, 2]
self.num_heads = [3, 6, 12, 24]
self.window_size = 7
self.mlp_ratio = 4.
self.qkv_bias = True
self.qk_scale = None
self.drop_rate = 0.
self.drop_path_rate = 0.1
self.ape = False
self.patch_norm = True
self.use_checkpoint = False
class SwinUnet(nn.Module):
def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
super(SwinUnet, self).__init__()
self.num_classes = num_classes
self.zero_head = zero_head
self.config = config
self.swin_unet = SwinTransformerSys(img_size=img_size,
patch_size=config.patch_size,
in_chans=config.in_chans,
num_classes=self.num_classes,
embed_dim=config.embed_dim,
depths=config.depths,
num_heads=config.num_heads,
window_size=config.window_size,
mlp_ratio=config.mlp_ratio,
qkv_bias=config.qkv_bias,
qk_scale=config.qk_scale,
drop_rate=config.drop_rate,
drop_path_rate=config.drop_path_rate,
ape=config.ape,
patch_norm=config.patch_norm,
use_checkpoint=config.use_checkpoint)
def forward(self, x):
if x.size()[1] == 1:
x = x.repeat(1,3,1,1)
logits = self.swin_unet(x)
return logits
def load_from(self, pretrained_path):
if pretrained_path is not None:
print("pretrained_path:{}".format(pretrained_path))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pretrained_dict = torch.load(pretrained_path, map_location=device)
if "model" not in pretrained_dict:
print("---start load pretrained modle by splitting---")
pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()}
for k in list(pretrained_dict.keys()):
if "output" in k:
print("delete key:{}".format(k))
del pretrained_dict[k]
msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False)
# print(msg)
return
pretrained_dict = pretrained_dict['model']
print("---start load pretrained modle of swin encoder---")
model_dict = self.swin_unet.state_dict()
full_dict = copy.deepcopy(pretrained_dict)
for k, v in pretrained_dict.items():
if "layers." in k:
current_layer_num = 3-int(k[7:8])
current_k = "layers_up." + str(current_layer_num) + k[8:]
full_dict.update({current_k:v})
for k in list(full_dict.keys()):
if k in model_dict:
if full_dict[k].shape != model_dict[k].shape:
print("delete:{};shape pretrain:{};shape model:{}".format(k,v.shape,model_dict[k].shape))
del full_dict[k]
msg = self.swin_unet.load_state_dict(full_dict, strict=False)
# print(msg)
else:
print("none pretrain")
================================================
FILE: model/transunet.py
================================================
# The code is borrowed from original repo: https://github.com/Beckschen/TransUNet/tree/d68a53a2da73ecb496bb7585340eb660ecda1d59
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ml_collections
import copy
import logging
import math
from collections import OrderedDict
from os.path import join as pjoin
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
logger = logging.getLogger(__name__)
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"
def np2th(weights, conv=False):
"""Possibly convert HWIO to OIHW."""
if conv:
weights = weights.transpose([3, 2, 0, 1])
return torch.from_numpy(weights)
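A tiny standalone illustration (dummy kernel shape assumed) of the axis conversion `np2th` performs for conv weights: TF/JAX checkpoints store kernels as HWIO `(H, W, in, out)`, while PyTorch's `Conv2d.weight` is OIHW `(out, in, H, W)`.

```python
import numpy as np
import torch

# dummy 3x3 kernel mapping 64 -> 128 channels, stored TF-style (HWIO)
hwio = np.zeros((3, 3, 64, 128), dtype=np.float32)
oihw = torch.from_numpy(hwio.transpose([3, 2, 0, 1]))  # -> OIHW
print(oihw.shape)  # torch.Size([128, 64, 3, 3])
```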
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
class Attention(nn.Module):
def __init__(self, config, vis):
super(Attention, self).__init__()
self.vis = vis
self.num_attention_heads = config.transformer["num_heads"]
self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = Linear(config.hidden_size, self.all_head_size)
self.key = Linear(config.hidden_size, self.all_head_size)
self.value = Linear(config.hidden_size, self.all_head_size)
self.out = Linear(config.hidden_size, config.hidden_size)
self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
self.softmax = Softmax(dim=-1)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
weights = attention_probs if self.vis else None
attention_probs = self.attn_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
attention_output = self.out(context_layer)
attention_output = self.proj_dropout(attention_output)
return attention_output, weights
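The shape bookkeeping in `Attention.forward` can be sketched standalone (toy batch, ViT-B/16-style widths assumed): `hidden_size` 768 is split into 12 heads of size 64, scores are scaled by `sqrt(head_dim)`, and the per-head contexts are merged back to the hidden width.

```python
import math
import torch

B, n_patch, hidden, heads = 1, 196, 768, 12
head_dim = hidden // heads                              # 64
q = k = v = torch.randn(B, heads, n_patch, head_dim)

scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_dim)
probs = scores.softmax(dim=-1)                          # rows sum to 1
context = torch.matmul(probs, v)                        # (B, heads, n_patch, head_dim)
context = context.permute(0, 2, 1, 3).reshape(B, n_patch, hidden)
print(context.shape)  # torch.Size([1, 196, 768])
```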
class Mlp(nn.Module):
def __init__(self, config):
super(Mlp, self).__init__()
self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
self.act_fn = ACT2FN["gelu"]
self.dropout = Dropout(config.transformer["dropout_rate"])
self._init_weights()
def _init_weights(self):
nn.init.xavier_uniform_(self.fc1.weight)
nn.init.xavier_uniform_(self.fc2.weight)
nn.init.normal_(self.fc1.bias, std=1e-6)
nn.init.normal_(self.fc2.bias, std=1e-6)
def forward(self, x):
x = self.fc1(x)
x = self.act_fn(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class Embeddings(nn.Module):
"""Construct the embeddings from patch, position embeddings.
"""
def __init__(self, config, img_size, in_channels=3):
super(Embeddings, self).__init__()
self.hybrid = None
self.config = config
img_size = _pair(img_size)
if config.patches.get("grid") is not None: # ResNet
grid_size = config.patches["grid"]
patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
self.hybrid = True
else:
patch_size = _pair(config.patches["size"])
n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
self.hybrid = False
if self.hybrid:
self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)
in_channels = self.hybrid_model.width * 16
self.patch_embeddings = Conv2d(in_channels=in_channels,
out_channels=config.hidden_size,
kernel_size=patch_size,
stride=patch_size)
self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))
self.dropout = Dropout(config.transformer["dropout_rate"])
def forward(self, x):
if self.hybrid:
x, features = self.hybrid_model(x)
else:
features = None
x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/2), n_patches^(1/2))
x = x.flatten(2)
x = x.transpose(-1, -2) # (B, n_patches, hidden)
embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings, features
class Block(nn.Module):
def __init__(self, config, vis):
super(Block, self).__init__()
self.hidden_size = config.hidden_size
self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
self.ffn = Mlp(config)
self.attn = Attention(config, vis)
def forward(self, x):
h = x
x = self.attention_norm(x)
x, weights = self.attn(x)
x = x + h
h = x
x = self.ffn_norm(x)
x = self.ffn(x)
x = x + h
return x, weights
def load_from(self, weights, n_block):
ROOT = f"Transformer/encoderblock_{n_block}"
with torch.no_grad():
query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()
query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)
self.attn.query.weight.copy_(query_weight)
self.attn.key.weight.copy_(key_weight)
self.attn.value.weight.copy_(value_weight)
self.attn.out.weight.copy_(out_weight)
self.attn.query.bias.copy_(query_bias)
self.attn.key.bias.copy_(key_bias)
self.attn.value.bias.copy_(value_bias)
self.attn.out.bias.copy_(out_bias)
mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()
self.ffn.fc1.weight.copy_(mlp_weight_0)
self.ffn.fc2.weight.copy_(mlp_weight_1)
self.ffn.fc1.bias.copy_(mlp_bias_0)
self.ffn.fc2.bias.copy_(mlp_bias_1)
self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
class Encoder(nn.Module):
def __init__(self, config, vis):
super(Encoder, self).__init__()
self.vis = vis
self.layer = nn.ModuleList()
self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
for _ in range(config.transformer["num_layers"]):
layer = Block(config, vis)
self.layer.append(copy.deepcopy(layer))
def forward(self, hidden_states):
attn_weights = []
for layer_block in self.layer:
hidden_states, weights = layer_block(hidden_states)
if self.vis:
attn_weights.append(weights)
encoded = self.encoder_norm(hidden_states)
return encoded, attn_weights
class Transformer(nn.Module):
def __init__(self, config, img_size, vis):
super(Transformer, self).__init__()
self.embeddings = Embeddings(config, img_size=img_size)
self.encoder = Encoder(config, vis)
def forward(self, input_ids):
embedding_output, features = self.embeddings(input_ids)
encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden)
return encoded, attn_weights, features
class Conv2dReLU(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding=0,
stride=1,
use_batchnorm=True,
):
conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=not (use_batchnorm),
)
relu = nn.ReLU(inplace=True)
bn = nn.BatchNorm2d(out_channels)
super(Conv2dReLU, self).__init__(conv, bn, relu)
class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
skip_channels=0,
use_batchnorm=True,
):
super().__init__()
self.conv1 = Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.conv2 = Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.up = nn.UpsamplingBilinear2d(scale_factor=2)
def forward(self, x, skip=None):
x = self.up(x)
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.conv1(x)
x = self.conv2(x)
return x
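A standalone shape sketch (toy feature sizes assumed) of one `DecoderBlock` step: bilinear 2x upsampling followed by channel-wise concatenation with the skip feature, which is why `conv1` takes `in_channels + skip_channels`.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 512, 14, 14)      # decoder feature
skip = torch.randn(1, 512, 28, 28)   # encoder skip at the target resolution
x = nn.UpsamplingBilinear2d(scale_factor=2)(x)
x = torch.cat([x, skip], dim=1)      # channels add up: 512 + 512
print(x.shape)  # torch.Size([1, 1024, 28, 28])
```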
class SegmentationHead(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
super().__init__(conv2d, upsampling)
class DecoderCup(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
head_channels = 512
self.conv_more = Conv2dReLU(
config.hidden_size,
head_channels,
kernel_size=3,
padding=1,
use_batchnorm=True,
)
decoder_channels = config.decoder_channels
in_channels = [head_channels] + list(decoder_channels[:-1])
out_channels = decoder_channels
if self.config.n_skip != 0:
skip_channels = self.config.skip_channels
for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip
skip_channels[3-i]=0
else:
skip_channels=[0,0,0,0]
blocks = [
DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
]
self.blocks = nn.ModuleList(blocks)
def forward(self, hidden_states, features=None):
B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
x = hidden_states.permute(0, 2, 1)
x = x.contiguous().view(B, hidden, h, w)
x = self.conv_more(x)
for i, decoder_block in enumerate(self.blocks):
if features is not None:
skip = features[i] if (i < self.config.n_skip) else None
else:
skip = None
x = decoder_block(x, skip=skip)
return x
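The first step of `DecoderCup.forward` can be sketched standalone (toy sizes assumed): the `(B, n_patch, hidden)` transformer output is folded back into a square `(B, hidden, h, w)` feature map before any convolution, which assumes `n_patch` is a perfect square.

```python
import numpy as np
import torch

B, n_patch, hidden = 1, 196, 768
hidden_states = torch.randn(B, n_patch, hidden)
h = w = int(np.sqrt(n_patch))        # 14, since 196 = 14 * 14
x = hidden_states.permute(0, 2, 1).contiguous().view(B, hidden, h, w)
print(x.shape)  # torch.Size([1, 768, 14, 14])
```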
class VisionTransformer(nn.Module):
def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
super(VisionTransformer, self).__init__()
self.num_classes = num_classes
self.zero_head = zero_head
self.classifier = config.classifier
self.transformer = Transformer(config, img_size, vis)
self.decoder = DecoderCup(config)
self.segmentation_head = SegmentationHead(
in_channels=config['decoder_channels'][-1],
out_channels=config['n_classes'],
kernel_size=3,
)
self.config = config
def forward(self, x):
if x.size()[1] == 1:
x = x.repeat(1,3,1,1)
x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden)
x = self.decoder(x, features)
logits = self.segmentation_head(x)
return logits
def load_from(self, weights):
with torch.no_grad():
res_weight = weights
self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))
posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
posemb_new = self.transformer.embeddings.position_embeddings
if posemb.size() == posemb_new.size():
self.transformer.embeddings.position_embeddings.copy_(posemb)
elif posemb.size()[1]-1 == posemb_new.size()[1]:
posemb = posemb[:, 1:]
self.transformer.embeddings.position_embeddings.copy_(posemb)
else:
logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
ntok_new = posemb_new.size(1)
if self.classifier == "seg":
_, posemb_grid = posemb[:, :1], posemb[0, 1:]
gs_old = int(np.sqrt(len(posemb_grid)))
gs_new = int(np.sqrt(ntok_new))
print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
zoom = (gs_new / gs_old, gs_new / gs_old, 1)
posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np
posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
posemb = posemb_grid
self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))
# Encoder whole
for bname, block in self.transformer.encoder.named_children():
for uname, unit in block.named_children():
unit.load_from(weights, n_block=uname)
if self.transformer.embeddings.hybrid:
self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))
gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)
gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)
self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)
for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
for uname, unit in block.named_children():
unit.load_from(res_weight, n_block=bname, n_unit=uname)
def get_b16_config():
"""Returns the ViT-B/16 configuration."""
config = ml_collections.ConfigDict()
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_size = 768
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 3072
config.transformer.num_heads = 12
config.transformer.num_layers = 12
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.classifier = 'seg'
config.representation_size = None
config.resnet_pretrained_path = None
config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz'
config.patch_size = 16
config.skip_channels = [0, 0, 0, 0]
config.decoder_channels = (256, 128, 64, 16)
config.n_classes = 2
config.activation = 'softmax'
return config
def get_testing():
"""Returns a minimal configuration for testing."""
config = ml_collections.ConfigDict()
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_size = 1
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 1
config.transformer.num_heads = 1
config.transformer.num_layers = 1
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.classifier = 'token'
config.representation_size = None
return config
def get_r50_b16_config():
"""Returns the Resnet50 + ViT-B/16 configuration."""
config = get_b16_config()
config.patches.grid = (16, 16)
config.resnet = ml_collections.ConfigDict()
config.resnet.num_layers = (3, 4, 9)
config.resnet.width_factor = 1
config.classifier = 'seg'
config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
config.decoder_channels = (256, 128, 64, 16)
config.skip_channels = [512, 256, 64, 16]
config.n_classes = 2
config.n_skip = 3
config.activation = 'softmax'
return config
def get_b32_config():
"""Returns the ViT-B/32 configuration."""
config = get_b16_config()
config.patches.size = (32, 32)
config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz'
return config
def get_l16_config():
"""Returns the ViT-L/16 configuration."""
config = ml_collections.ConfigDict()
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_size = 1024
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 4096
config.transformer.num_heads = 16
config.transformer.num_layers = 24
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.representation_size = None
# custom
config.classifier = 'seg'
config.resnet_pretrained_path = None
config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz'
config.decoder_channels = (256, 128, 64, 16)
config.n_classes = 2
config.activation = 'softmax'
return config
def get_r50_l16_config():
    """Returns the customized ResNet50 + ViT-L/16 configuration."""
config = get_l16_config()
config.patches.grid = (16, 16)
config.resnet = ml_collections.ConfigDict()
config.resnet.num_layers = (3, 4, 9)
config.resnet.width_factor = 1
config.classifier = 'seg'
config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
config.decoder_channels = (256, 128, 64, 16)
config.skip_channels = [512, 256, 64, 16]
config.n_classes = 2
config.activation = 'softmax'
return config
def get_l32_config():
"""Returns the ViT-L/32 configuration."""
config = get_l16_config()
config.patches.size = (32, 32)
return config
def get_h14_config():
    """Returns the ViT-H/14 configuration."""
config = ml_collections.ConfigDict()
config.patches = ml_collections.ConfigDict({'size': (14, 14)})
config.hidden_size = 1280
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 5120
config.transformer.num_heads = 16
config.transformer.num_layers = 32
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.classifier = 'token'
config.representation_size = None
return config
CONFIGS = {
'ViT-B_16': get_b16_config(),
'ViT-B_32': get_b32_config(),
'ViT-L_16': get_l16_config(),
'ViT-L_32': get_l32_config(),
'ViT-H_14': get_h14_config(),
'R50-ViT-B_16': get_r50_b16_config(),
'R50-ViT-L_16': get_r50_l16_config(),
'testing': get_testing(),
}
def np2th(weights, conv=False):
"""Possibly convert HWIO to OIHW."""
if conv:
weights = weights.transpose([3, 2, 0, 1])
return torch.from_numpy(weights)
class StdConv2d(nn.Conv2d):
def forward(self, x):
w = self.weight
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
w = (w - m) / torch.sqrt(v + 1e-5)
return F.conv2d(x, w, self.bias, self.stride, self.padding,
self.dilation, self.groups)
def conv3x3(cin, cout, stride=1, groups=1, bias=False):
return StdConv2d(cin, cout, kernel_size=3, stride=stride,
padding=1, bias=bias, groups=groups)
def conv1x1(cin, cout, stride=1, bias=False):
return StdConv2d(cin, cout, kernel_size=1, stride=stride,
padding=0, bias=bias)
class PreActBottleneck(nn.Module):
"""Pre-activation (v2) bottleneck block.
"""
def __init__(self, cin, cout=None, cmid=None, stride=1):
super().__init__()
cout = cout or cin
cmid = cmid or cout//4
self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
self.conv1 = conv1x1(cin, cmid, bias=False)
self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!!
self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
self.conv3 = conv1x1(cmid, cout, bias=False)
self.relu = nn.ReLU(inplace=True)
if (stride != 1 or cin != cout):
# Projection also with pre-activation according to paper.
self.downsample = conv1x1(cin, cout, stride, bias=False)
self.gn_proj = nn.GroupNorm(cout, cout)
def forward(self, x):
# Residual branch
residual = x
if hasattr(self, 'downsample'):
residual = self.downsample(x)
residual = self.gn_proj(residual)
# Unit's branch
y = self.relu(self.gn1(self.conv1(x)))
y = self.relu(self.gn2(self.conv2(y)))
y = self.gn3(self.conv3(y))
y = self.relu(residual + y)
return y
def load_from(self, weights, n_block, n_unit):
conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)
gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])
gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])
gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])
self.conv1.weight.copy_(conv1_weight)
self.conv2.weight.copy_(conv2_weight)
self.conv3.weight.copy_(conv3_weight)
self.gn1.weight.copy_(gn1_weight.view(-1))
self.gn1.bias.copy_(gn1_bias.view(-1))
self.gn2.weight.copy_(gn2_weight.view(-1))
self.gn2.bias.copy_(gn2_bias.view(-1))
self.gn3.weight.copy_(gn3_weight.view(-1))
self.gn3.bias.copy_(gn3_bias.view(-1))
if hasattr(self, 'downsample'):
proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])
self.downsample.weight.copy_(proj_conv_weight)
self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
self.gn_proj.bias.copy_(proj_gn_bias.view(-1))
class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet model."""
def __init__(self, block_units, width_factor):
super().__init__()
width = int(64 * width_factor)
self.width = width
self.root = nn.Sequential(OrderedDict([
('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
('gn', nn.GroupNorm(32, width, eps=1e-6)),
('relu', nn.ReLU(inplace=True)),
# ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
]))
self.body = nn.Sequential(OrderedDict([
('block1', nn.Sequential(OrderedDict(
[('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
[(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
))),
('block2', nn.Sequential(OrderedDict(
[('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
[(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
))),
('block3', nn.Sequential(OrderedDict(
[('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
[(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
))),
]))
def forward(self, x):
features = []
b, c, in_size, _ = x.size()
x = self.root(x)
features.append(x)
x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
for i in range(len(self.body)-1):
x = self.body[i](x)
right_size = int(in_size / 4 / (i+1))
if x.size()[2] != right_size:
pad = right_size - x.size()[2]
                assert 0 < pad < 3, "x {} should {}".format(x.size(), right_size)
feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)
feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
else:
feat = x
features.append(feat)
x = self.body[-1](x)
return x, features[::-1]
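The weight-standardized convolution (`StdConv2d`) defined above normalizes each filter to zero mean and unit variance on every forward pass, which pairs well with the GroupNorm layers used throughout this backbone. A minimal standalone sketch of the same technique (self-contained, independent of the rest of this file):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StdConv2d(nn.Conv2d):
    """Conv2d with weight standardization, mirroring the class above."""
    def forward(self, x):
        w = self.weight
        # Per-output-filter mean/variance over (in_ch, kH, kW)
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

conv = StdConv2d(3, 8, kernel_size=3, padding=1)
y = conv(torch.randn(2, 3, 16, 16))
# The standardized weights actually used in the convolution have ~zero mean
v, m = torch.var_mean(conv.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
w_std = (conv.weight - m) / torch.sqrt(v + 1e-5)
```

Note that standardization is applied to a temporary copy of the weights at forward time; the stored parameters themselves remain unstandardized for the optimizer.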
================================================
FILE: model/unet_utils.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
def conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
class BasicBlock(nn.Module):
def __init__(self, inplanes, planes, stride=1):
super().__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(inplanes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or inplanes != planes:
self.shortcut = nn.Sequential(
nn.BatchNorm2d(inplanes),
self.relu,
nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
)
def forward(self, x):
residue = x
out = self.bn1(x)
out = self.relu(out)
out = self.conv1(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)
out += self.shortcut(residue)
return out
class BottleneckBlock(nn.Module):
def __init__(self, inplanes, planes, stride=1):
super().__init__()
self.conv1 = conv1x1(inplanes, planes//4, stride=1)
self.bn1 = nn.BatchNorm2d(inplanes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes//4, planes//4, stride=stride)
self.bn2 = nn.BatchNorm2d(planes//4)
self.conv3 = conv1x1(planes//4, planes, stride=1)
self.bn3 = nn.BatchNorm2d(planes//4)
self.shortcut = nn.Sequential()
if stride != 1 or inplanes != planes:
self.shortcut = nn.Sequential(
nn.BatchNorm2d(inplanes),
self.relu,
nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
)
def forward(self, x):
residue = x
out = self.bn1(x)
out = self.relu(out)
out = self.conv1(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn3(out)
out = self.relu(out)
out = self.conv3(out)
out += self.shortcut(residue)
return out
class inconv(nn.Module):
def __init__(self, in_ch, out_ch, bottleneck=False):
super().__init__()
self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
self.relu = nn.ReLU(inplace=True)
if bottleneck:
self.conv2 = BottleneckBlock(out_ch, out_ch)
else:
self.conv2 = BasicBlock(out_ch, out_ch)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
return out
class down_block(nn.Module):
def __init__(self, in_ch, out_ch, scale, num_block, bottleneck=False, pool=True):
super().__init__()
block_list = []
if bottleneck:
block = BottleneckBlock
else:
block = BasicBlock
if pool:
block_list.append(nn.MaxPool2d(scale))
block_list.append(block(in_ch, out_ch))
else:
block_list.append(block(in_ch, out_ch, stride=2))
for i in range(num_block-1):
block_list.append(block(out_ch, out_ch, stride=1))
self.conv = nn.Sequential(*block_list)
def forward(self, x):
return self.conv(x)
class up_block(nn.Module):
    def __init__(self, in_ch, out_ch, num_block, scale=(2, 2), bottleneck=False):
        super().__init__()
        self.scale = scale
self.conv_ch = nn.Conv2d(in_ch, out_ch, kernel_size=1)
if bottleneck:
block = BottleneckBlock
else:
block = BasicBlock
block_list = []
block_list.append(block(2*out_ch, out_ch))
for i in range(num_block-1):
block_list.append(block(out_ch, out_ch))
self.conv = nn.Sequential(*block_list)
def forward(self, x1, x2):
x1 = F.interpolate(x1, scale_factor=self.scale, mode='bilinear', align_corners=True)
x1 = self.conv_ch(x1)
out = torch.cat([x2, x1], dim=1)
out = self.conv(out)
return out
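The padding step in `Up.forward` above guards against odd spatial sizes: the upsampled decoder map is padded to match the encoder skip feature before channel-wise concatenation. A minimal sketch of just that step, with hypothetical tensor sizes:

```python
import torch
import torch.nn.functional as F

x1 = torch.randn(1, 8, 13, 13)  # upsampled decoder feature (odd size)
x2 = torch.randn(1, 8, 14, 14)  # encoder skip feature
diffY = x2.size(2) - x1.size(2)
diffX = x2.size(3) - x1.size(3)
# Pad left/right then top/bottom so x1 matches x2 spatially
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                diffY // 2, diffY - diffY // 2])
merged = torch.cat([x2, x1], dim=1)  # channel-wise concat, as in Up.forward
```

Splitting each difference as `diff // 2` and `diff - diff // 2` handles odd mismatches without dropping pixels.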
================================================
FILE: model/utnet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet_utils import up_block, down_block
from .conv_trans_utils import *
import pdb
class UTNet(nn.Module):
def __init__(self, in_chan, base_chan, num_classes=1, reduce_size=8, block_list='234', num_blocks=[1, 2, 4], projection='interp', num_heads=[2,4,8], attn_drop=0., proj_drop=0., bottleneck=False, maxpool=True, rel_pos=True, aux_loss=False):
super().__init__()
self.aux_loss = aux_loss
self.inc = [BasicBlock(in_chan, base_chan)]
if '0' in block_list:
self.inc.append(BasicTransBlock(base_chan, heads=num_heads[-5], dim_head=base_chan//num_heads[-5], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos))
self.up4 = up_block_trans(2*base_chan, base_chan, num_block=0, bottleneck=bottleneck, heads=num_heads[-4], dim_head=base_chan//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
else:
self.inc.append(BasicBlock(base_chan, base_chan))
self.up4 = up_block(2*base_chan, base_chan, scale=(2,2), num_block=2)
self.inc = nn.Sequential(*self.inc)
if '1' in block_list:
self.down1 = down_block_trans(base_chan, 2*base_chan, num_block=num_blocks[-4], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-4], dim_head=2*base_chan//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
self.up3 = up_block_trans(4*base_chan, 2*base_chan, num_block=0, bottleneck=bottleneck, heads=num_heads[-3], dim_head=2*base_chan//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
else:
self.down1 = down_block(base_chan, 2*base_chan, (2,2), num_block=2)
self.up3 = up_block(4*base_chan, 2*base_chan, scale=(2,2), num_block=2)
if '2' in block_list:
self.down2 = down_block_trans(2*base_chan, 4*base_chan, num_block=num_blocks[-3], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-3], dim_head=4*base_chan//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
self.up2 = up_block_trans(8*base_chan, 4*base_chan, num_block=0, bottleneck=bottleneck, heads=num_heads[-2], dim_head=4*base_chan//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
else:
self.down2 = down_block(2*base_chan, 4*base_chan, (2, 2), num_block=2)
self.up2 = up_block(8*base_chan, 4*base_chan, scale=(2,2), num_block=2)
if '3' in block_list:
self.down3 = down_block_trans(4*base_chan, 8*base_chan, num_block=num_blocks[-2], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-2], dim_head=8*base_chan//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
self.up1 = up_block_trans(16*base_chan, 8*base_chan, num_block=0, bottleneck=bottleneck, heads=num_heads[-1], dim_head=8*base_chan//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
else:
self.down3 = down_block(4*base_chan, 8*base_chan, (2,2), num_block=2)
self.up1 = up_block(16*base_chan, 8*base_chan, scale=(2,2), num_block=2)
if '4' in block_list:
self.down4 = down_block_trans(8*base_chan, 16*base_chan, num_block=num_blocks[-1], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-1], dim_head=16*base_chan//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
else:
self.down4 = down_block(8*base_chan, 16*base_chan, (2,2), num_block=2)
self.outc = nn.Conv2d(base_chan, num_classes, kernel_size=1, bias=True)
if aux_loss:
self.out1 = nn.Conv2d(8*base_chan, num_classes, kernel_size=1, bias=True)
self.out2 = nn.Conv2d(4*base_chan, num_classes, kernel_size=1, bias=True)
self.out3 = nn.Conv2d(2*base_chan, num_classes, kernel_size=1, bias=True)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
if self.aux_loss:
out = self.up1(x5, x4)
out1 = F.interpolate(self.out1(out), size=x.shape[-2:], mode='bilinear', align_corners=True)
out = self.up2(out, x3)
out2 = F.interpolate(self.out2(out), size=x.shape[-2:], mode='bilinear', align_corners=True)
out = self.up3(out, x2)
out3 = F.interpolate(self.out3(out), size=x.shape[-2:], mode='bilinear', align_corners=True)
out = self.up4(out, x1)
out = self.outc(out)
return out, out3, out2, out1
else:
out = self.up1(x5, x4)
out = self.up2(out, x3)
out = self.up3(out, x2)
out = self.up4(out, x1)
out = self.outc(out)
return out
class UTNet_Encoderonly(nn.Module):
def __init__(self, in_chan, base_chan, num_classes=1, reduce_size=8, block_list='234', num_blocks=[1, 2, 4], projection='interp', num_heads=[2,4,8], attn_drop=0., proj_drop=0., bottleneck=False, maxpool=True, rel_pos=True, aux_loss=False):
super().__init__()
self.aux_loss = aux_loss
self.inc = [BasicBlock(in_chan, base_chan)]
if '0' in block_list:
self.inc.append(BasicTransBlock(base_chan, heads=num_heads[-5], dim_head=base_chan//num_heads[-5], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos))
else:
self.inc.append(BasicBlock(base_chan, base_chan))
self.inc = nn.Sequential(*self.inc)
if '1' in block_list:
self.down1 = down_block_trans(base_chan, 2*base_chan, num_block=num_blocks[-4], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-4], dim_head=2*base_chan//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
else:
self.down1 = down_block(base_chan, 2*base_chan, (2,2), num_block=2)
if '2' in block_list:
self.down2 = down_block_trans(2*base_chan, 4*base_chan, num_block=num_blocks[-3], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-3], dim_head=4*base_chan//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
else:
self.down2 = down_block(2*base_chan, 4*base_chan, (2, 2), num_block=2)
if '3' in block_list:
self.down3 = down_block_trans(4*base_chan, 8*base_chan, num_block=num_blocks[-2], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-2], dim_head=8*base_chan//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
else:
self.down3 = down_block(4*base_chan, 8*base_chan, (2,2), num_block=2)
if '4' in block_list:
self.down4 = down_block_trans(8*base_chan, 16*base_chan, num_block=num_blocks[-1], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-1], dim_head=16*base_chan//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
else:
self.down4 = down_block(8*base_chan, 16*base_chan, (2,2), num_block=2)
self.up1 = up_block(16*base_chan, 8*base_chan, scale=(2,2), num_block=2)
self.up2 = up_block(8*base_chan, 4*base_chan, scale=(2,2), num_block=2)
self.up3 = up_block(4*base_chan, 2*base_chan, scale=(2,2), num_block=2)
self.up4 = up_block(2*base_chan, base_chan, scale=(2,2), num_block=2)
self.outc = nn.Conv2d(base_chan, num_classes, kernel_size=1, bias=True)
if aux_loss:
self.out1 = nn.Conv2d(8*base_chan, num_classes, kernel_size=1, bias=True)
self.out2 = nn.Conv2d(4*base_chan, num_classes, kernel_size=1, bias=True)
self.out3 = nn.Conv2d(2*base_chan, num_classes, kernel_size=1, bias=True)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
if self.aux_loss:
out = self.up1(x5, x4)
out1 = F.interpolate(self.out1(out), size=x.shape[-2:], mode='bilinear', align_corners=True)
out = self.up2(out, x3)
out2 = F.interpolate(self.out2(out), size=x.shape[-2:], mode='bilinear', align_corners=True)
out = self.up3(out, x2)
out3 = F.interpolate(self.out3(out), size=x.shape[-2:], mode='bilinear', align_corners=True)
out = self.up4(out, x1)
out = self.outc(out)
return out, out3, out2, out1
else:
out = self.up1(x5, x4)
out = self.up2(out, x3)
out = self.up3(out, x2)
out = self.up4(out, x1)
out = self.outc(out)
return out
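`UTNet` computes `dim_head` as the level's channel count divided by its head count, so channels must divide evenly across heads at every level. A small pure-Python sketch of that arithmetic, using the defaults from `train_deep.py` (`base_chan=32`, `num_heads=[4, 4, 4, 4]`, `block_list='1234'`):

```python
base_chan = 32
num_heads = [4, 4, 4, 4]  # heads at encoder levels 1..4 (block_list='1234')

dim_heads = {}
for lvl, heads in enumerate(num_heads, start=1):
    ch = base_chan * 2 ** lvl        # channels double at each down level
    assert ch % heads == 0, "channels must divide evenly across heads"
    dim_heads[lvl] = ch // heads     # per-head dim passed to down_block_trans
```

With these defaults the per-head dimensions are 16, 32, 64, and 128 at levels 1 through 4. Note that using `'0'` in `block_list` additionally indexes `num_heads[-5]`, which requires a five-element `num_heads` list.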
================================================
FILE: train_deep.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
from model.utnet import UTNet, UTNet_Encoderonly
from dataset_domain import CMRDataset
from torch.utils import data
from losses import DiceLoss
from utils.utils import *
from utils import metrics
from optparse import OptionParser
import SimpleITK as sitk
from torch.utils.tensorboard import SummaryWriter
import time
import math
import os
import sys
import pdb
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
DEBUG = False
def train_net(net, options):
data_path = options.data_path
trainset = CMRDataset(data_path, mode='train', domain=options.domain, debug=DEBUG, scale=options.scale, rotate=options.rotate, crop_size=options.crop_size)
trainLoader = data.DataLoader(trainset, batch_size=options.batch_size, shuffle=True, num_workers=16)
testset_A = CMRDataset(data_path, mode='test', domain='A', debug=DEBUG, crop_size=options.crop_size)
testLoader_A = data.DataLoader(testset_A, batch_size=1, shuffle=False, num_workers=2)
testset_B = CMRDataset(data_path, mode='test', domain='B', debug=DEBUG, crop_size=options.crop_size)
testLoader_B = data.DataLoader(testset_B, batch_size=1, shuffle=False, num_workers=2)
testset_C = CMRDataset(data_path, mode='test', domain='C', debug=DEBUG, crop_size=options.crop_size)
testLoader_C = data.DataLoader(testset_C, batch_size=1, shuffle=False, num_workers=2)
testset_D = CMRDataset(data_path, mode='test', domain='D', debug=DEBUG, crop_size=options.crop_size)
testLoader_D = data.DataLoader(testset_D, batch_size=1, shuffle=False, num_workers=2)
writer = SummaryWriter(options.log_path + options.unique_name)
optimizer = optim.SGD(net.parameters(), lr=options.lr, momentum=0.9, weight_decay=options.weight_decay)
criterion = nn.CrossEntropyLoss(weight=torch.tensor(options.weight).cuda())
criterion_dl = DiceLoss()
best_dice = 0
for epoch in range(options.epochs):
print('Starting epoch {}/{}'.format(epoch+1, options.epochs))
epoch_loss = 0
exp_scheduler = exp_lr_scheduler_with_warmup(optimizer, init_lr=options.lr, epoch=epoch, warmup_epoch=5, max_epoch=options.epochs)
print('current lr:', exp_scheduler)
for i, (img, label) in enumerate(trainLoader, 0):
img = img.cuda()
label = label.cuda()
end = time.time()
net.train()
optimizer.zero_grad()
result = net(img)
loss = 0
            if isinstance(result, (tuple, list)):
for j in range(len(result)):
loss += options.aux_weight[j] * (criterion(result[j], label.squeeze(1)) + criterion_dl(result[j], label))
else:
loss = criterion(result, label.squeeze(1)) + criterion_dl(result, label)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
batch_time = time.time() - end
print('batch loss: %.5f, batch_time:%.5f'%(loss.item(), batch_time))
print('[epoch %d] epoch loss: %.5f'%(epoch+1, epoch_loss/(i+1)))
writer.add_scalar('Train/Loss', epoch_loss/(i+1), epoch+1)
writer.add_scalar('LR', exp_scheduler, epoch+1)
        os.makedirs('%s%s/'%(options.cp_path, options.unique_name), exist_ok=True)
if epoch % 20 == 0 or epoch > options.epochs-10:
torch.save(net.state_dict(), '%s%s/CP%d.pth'%(options.cp_path, options.unique_name, epoch))
        if (epoch+1) > 90 or (epoch+1) % 10 == 0:
dice_list_A, ASD_list_A, HD_list_A = validation(net, testLoader_A, options)
log_evaluation_result(writer, dice_list_A, ASD_list_A, HD_list_A, 'A', epoch)
dice_list_B, ASD_list_B, HD_list_B = validation(net, testLoader_B, options)
log_evaluation_result(writer, dice_list_B, ASD_list_B, HD_list_B, 'B', epoch)
dice_list_C, ASD_list_C, HD_list_C = validation(net, testLoader_C, options)
log_evaluation_result(writer, dice_list_C, ASD_list_C, HD_list_C, 'C', epoch)
dice_list_D, ASD_list_D, HD_list_D = validation(net, testLoader_D, options)
log_evaluation_result(writer, dice_list_D, ASD_list_D, HD_list_D, 'D', epoch)
AVG_dice_list = 20 * dice_list_A + 50 * dice_list_B + 50 * dice_list_C + 50 * dice_list_D
AVG_dice_list /= 170
AVG_ASD_list = 20 * ASD_list_A + 50 * ASD_list_B + 50 * ASD_list_C + 50 * ASD_list_D
AVG_ASD_list /= 170
AVG_HD_list = 20 * HD_list_A + 50 * HD_list_B + 50 * HD_list_C + 50 * HD_list_D
AVG_HD_list /= 170
log_evaluation_result(writer, AVG_dice_list, AVG_ASD_list, AVG_HD_list, 'mean', epoch)
if dice_list_A.mean() >= best_dice:
best_dice = dice_list_A.mean()
torch.save(net.state_dict(), '%s%s/best.pth'%(options.cp_path, options.unique_name))
print('save done')
print('dice: %.5f/best dice: %.5f'%(dice_list_A.mean(), best_dice))
def validation(net, test_loader, options):
net.eval()
dice_list = np.zeros(3)
ASD_list = np.zeros(3)
HD_list = np.zeros(3)
counter = 0
with torch.no_grad():
for i, (data, label, spacing) in enumerate(test_loader):
inputs, labels = data.float().cuda(), label.long().cuda()
inputs = inputs.permute(1, 0, 2, 3)
labels = labels.permute(1, 0, 2, 3)
pred = net(inputs)
            if options.model in ('FCN_Res50', 'FCN_Res101'):
pred = pred['out']
elif isinstance(pred, tuple):
pred = pred[0]
pred = F.softmax(pred, dim=1)
_, label_pred = torch.max(pred, dim=1)
tmp_ASD_list, tmp_HD_list = cal_distance(label_pred, labels, spacing)
ASD_list += np.clip(np.nan_to_num(tmp_ASD_list, nan=500), 0, 500)
HD_list += np.clip(np.nan_to_num(tmp_HD_list, nan=500), 0, 500)
label_pred = label_pred.view(-1, 1)
label_true = labels.view(-1, 1)
dice, _, _ = cal_dice(label_pred, label_true, 4)
dice_list += dice.cpu().numpy()[1:]
counter += 1
dice_list /= counter
avg_dice = dice_list.mean()
ASD_list /= counter
HD_list /= counter
    return dice_list, ASD_list, HD_list
def cal_distance(label_pred, label_true, spacing):
label_pred = label_pred.squeeze(1).cpu().numpy()
label_true = label_true.squeeze(1).cpu().numpy()
spacing = spacing.numpy()[0]
ASD_list = np.zeros(3)
HD_list = np.zeros(3)
for i in range(3):
tmp_surface = metrics.compute_surface_distances(label_true==(i+1), label_pred==(i+1), spacing)
dis_gt_to_pred, dis_pred_to_gt = metrics.compute_average_surface_distance(tmp_surface)
ASD_list[i] = (dis_gt_to_pred + dis_pred_to_gt) / 2
HD = metrics.compute_robust_hausdorff(tmp_surface, 100)
HD_list[i] = HD
return ASD_list, HD_list
if __name__ == '__main__':
parser = OptionParser()
def get_comma_separated_int_args(option, opt, value, parser):
value_list = value.split(',')
value_list = [int(i) for i in value_list]
setattr(parser.values, option.dest, value_list)
parser.add_option('-e', '--epochs', dest='epochs', default=150, type='int', help='number of epochs')
parser.add_option('-b', '--batch_size', dest='batch_size', default=32, type='int', help='batch size')
parser.add_option('-l', '--learning-rate', dest='lr', default=0.05, type='float', help='learning rate')
parser.add_option('-c', '--resume', type='str', dest='load', default=False, help='load pretrained model')
parser.add_option('-p', '--checkpoint-path', type='str', dest='cp_path', default='./checkpoint/', help='checkpoint path')
parser.add_option('--data_path', type='str', dest='data_path', default='/research/cbim/vast/yg397/vision_transformer/dataset/resampled_dataset/', help='dataset path')
parser.add_option('-o', '--log-path', type='str', dest='log_path', default='./log/', help='log path')
parser.add_option('-m', type='str', dest='model', default='UTNet', help='use which model')
parser.add_option('--num_class', type='int', dest='num_class', default=4, help='number of segmentation classes')
parser.add_option('--base_chan', type='int', dest='base_chan', default=32, help='number of channels of first expansion in UNet')
parser.add_option('-u', '--unique_name', type='str', dest='unique_name', default='test', help='unique experiment name')
parser.add_option('--rlt', type='float', dest='rlt', default=1, help='relation between CE/FL and dice')
parser.add_option('--weight', type='float', dest='weight',
default=[0.5,1,1,1] , help='weight each class in loss function')
parser.add_option('--weight_decay', type='float', dest='weight_decay',
default=0.0001)
parser.add_option('--scale', type='float', dest='scale', default=0.30)
parser.add_option('--rotate', type='float', dest='rotate', default=180)
parser.add_option('--crop_size', type='int', dest='crop_size', default=256)
parser.add_option('--domain', type='str', dest='domain', default='A')
parser.add_option('--aux_weight', type='float', dest='aux_weight', default=[1, 0.4, 0.2, 0.1])
parser.add_option('--reduce_size', dest='reduce_size', default=8, type='int')
parser.add_option('--block_list', dest='block_list', default='1234', type='str')
parser.add_option('--num_blocks', dest='num_blocks', default=[1,1,1,1], type='string', action='callback', callback=get_comma_separated_int_args)
parser.add_option('--aux_loss', dest='aux_loss', action='store_true', help='using aux loss for deep supervision')
parser.add_option('--gpu', type='str', dest='gpu', default='0')
options, args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = options.gpu
print('Using model:', options.model)
if options.model == 'UTNet':
net = UTNet(1, options.base_chan, options.num_class, reduce_size=options.reduce_size, block_list=options.block_list, num_blocks=options.num_blocks, num_heads=[4,4,4,4], projection='interp', attn_drop=0.1, proj_drop=0.1, rel_pos=True, aux_loss=options.aux_loss, maxpool=True)
elif options.model == 'UTNet_encoder':
# Apply transformer blocks only in the encoder
net = UTNet_Encoderonly(1, options.base_chan, options.num_class, reduce_size=options.reduce_size, block_list=options.block_list, num_blocks=options.num_blocks, num_heads=[4,4,4,4], projection='interp', attn_drop=0.1, proj_drop=0.1, rel_pos=True, aux_loss=options.aux_loss, maxpool=True)
elif options.model =='TransUNet':
from model.transunet import VisionTransformer as ViT_seg
from model.transunet import CONFIGS as CONFIGS_ViT_seg
config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']
config_vit.n_classes = 4
config_vit.n_skip = 3
config_vit.patches.grid = (int(256/16), int(256/16))
net = ViT_seg(config_vit, img_size=256, num_classes=4)
        #net.load_from(weights=np.load('./initmodel/R50+ViT-B_16.npz')) # uncomment to load the pretrained weights downloaded from the TransUNet repo
elif options.model == 'ResNet_UTNet':
from model.resnet_utnet import ResNet_UTNet
net = ResNet_UTNet(1, options.num_class, reduce_size=options.reduce_size, block_list=options.block_list, num_blocks=options.num_blocks, num_heads=[4,4,4,4], projection='interp', attn_drop=0.1, proj_drop=0.1, rel_pos=True)
elif options.model == 'SwinUNet':
from model.swin_unet import SwinUnet, SwinUnet_config
config = SwinUnet_config()
net = SwinUnet(config, img_size=224, num_classes=options.num_class)
net.load_from('./initmodel/swin_tiny_patch4_window7_224.pth')
else:
raise NotImplementedError(options.model + " has not been implemented")
if options.load:
net.load_state_dict(torch.load(options.load))
print('Model loaded from {}'.format(options.load))
param_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(net)
print(param_num)
net.cuda()
train_net(net, options)
print('done')
sys.exit(0)
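When `--aux_loss` is enabled, the training loop above weights the multi-scale outputs by `aux_weight=[1, 0.4, 0.2, 0.1]` (full-resolution head first, then progressively deeper auxiliary heads). A minimal sketch of that weighted sum with random stand-in tensors; the real loop also adds a `DiceLoss` term per scale, omitted here for brevity:

```python
import torch
import torch.nn.functional as F

aux_weight = [1.0, 0.4, 0.2, 0.1]
label = torch.randint(0, 4, (2, 1, 32, 32))                # (B, 1, H, W) int labels
outputs = [torch.randn(2, 4, 32, 32) for _ in aux_weight]  # stand-ins for net(img)

# Weighted sum of per-scale cross-entropy, mirroring the aux_loss branch
loss = sum(w * F.cross_entropy(o, label.squeeze(1))
           for w, o in zip(aux_weight, outputs))
```

Deep supervision of this form gives the intermediate decoder stages a direct gradient signal, which tends to stabilize training of the deeper transformer blocks.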
================================================
FILE: utils/__init__.py
================================================
================================================
FILE: utils/lookup_tables.py
================================================
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Lookup tables used by surface distance metrics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
ENCODE_NEIGHBOURHOOD_3D_KERNEL = np.array([[[128, 64], [32, 16]], [[8, 4],
[2, 1]]])
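A minimal sketch of how this kernel is used downstream (the encoding itself, not code from this file): correlating a binary mask with the kernel packs each 2x2x2 neighbourhood into one byte. This example uses the modern `scipy.ndimage.correlate` namespace.

```python
import numpy as np
from scipy import ndimage

# Each voxel in a 2x2x2 window contributes one bit; correlating a mask
# with this kernel yields the neighbour code at every (shifted) position.
kernel = np.array([[[128, 64], [32, 16]],
                   [[8, 4], [2, 1]]])
mask = np.zeros((2, 2, 2), np.uint8)
mask[0, 0, 0] = 1  # a single foreground voxel
codes = ndimage.correlate(mask, kernel, mode="constant", cval=0)
# the voxel sets exactly one bit in each of its 8 corner codes, so all
# 8 powers of two appear once and the codes sum to 255
```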
# _NEIGHBOUR_CODE_TO_NORMALS is a lookup table.
# For every binary neighbour code
# (2x2x2 neighbourhood = 8 neighbours = 8 bits = 256 codes)
# it contains the surface normals of the triangles (called "surfel" for
# "surface element" in the following). The length of the normal
# vector encodes the surfel area.
#
# created using the marching cubes algorithm
# see e.g. https://en.wikipedia.org/wiki/Marching_cubes
# pylint: disable=line-too-long
_NEIGHBOUR_CODE_TO_NORMALS = [
[[0, 0, 0]],
[[0.125, 0.125, 0.125]],
[[-0.125, -0.125, 0.125]],
[[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],
[[0.125, -0.125, 0.125]],
[[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]],
[[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
[[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],
[[-0.125, 0.125, 0.125]],
[[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125]],
[[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]],
[[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],
[[0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],
[[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125]],
[[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],
[[0.5, 0.0, 0.0], [0.5, 0.0, 0.0]],
[[0.125, -0.125, -0.125]],
[[0.0, -0.25, -0.25], [0.0, 0.25, 0.25]],
[[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
[[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],
[[0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
[[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125]],
[[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
[[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],
[[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125]],
[[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125]],
[[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125]],
[[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]],
[[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],
[[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],
[[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]],
[[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25]],
[[0.125, -0.125, 0.125]],
[[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],
[[0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],
[[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25]],
[[0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
[[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]],
[[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125]],
[[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]],
[[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],
[[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],
[[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],
[[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]],
[[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],
[[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]],
[[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]],
[[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],
[[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]],
[[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],
[[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25]],
[[0.0, 0.5, 0.0], [0.0, -0.5, 0.0]],
[[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125]],
[[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],
[[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],
[[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],
[[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]],
[[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
[[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],
[[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],
[[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],
[[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125]],
[[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]],
[[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]],
[[-0.125, -0.125, 0.125]],
[[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]],
[[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
[[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],
[[0.0, -0.25, 0.25], [0.0, -0.25, 0.25]],
[[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]],
[[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125]],
[[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]],
[[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],
[[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],
[[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]],
[[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
[[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]],
[[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]],
[[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]],
[[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]],
[[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],
[[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],
[[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],
[[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]],
[[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],
[[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5]],
[[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],
[[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],
[[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125]],
[[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
[[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],
[[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],
[[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]],
[[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]],
[[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],
[[0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],
[[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
[[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
[[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],
[[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],
[[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125]],
[[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
[[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],
[[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125]],
[[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
[[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],
[[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
[[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
[[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
[[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
[[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125]],
[[0.125, 0.125, 0.125], [0.125, -0.125, -0.125]],
[[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125]],
[[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],
[[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],
[[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],
[[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]],
[[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],
[[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],
[[0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],
[[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],
[[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],
[[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125]],
[[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],
[[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125]],
[[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]],
[[0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],
[[0.125, 0.125, 0.125]],
[[0.125, 0.125, 0.125]],
[[0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],
[[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]],
[[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125]],
[[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],
[[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125]],
[[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],
[[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],
[[0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],
[[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],
[[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],
[[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]],
[[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],
[[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],
[[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],
[[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125]],
[[0.125, 0.125, 0.125], [0.125, -0.125, -0.125]],
[[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125]],
[[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
[[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
[[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
[[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
[[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],
[[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
[[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125]],
[[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],
[[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
[[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125]],
[[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],
[[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],
[[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
[[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
[[0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],
[[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],
[[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]],
[[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]],
[[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],
[[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],
[[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
[[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125]],
[[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],
[[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],
[[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5]],
[[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],
[[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]],
[[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],
[[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],
[[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],
[[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]],
[[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]],
[[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]],
[[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]],
[[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
[[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]],
[[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],
[[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],
[[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]],
[[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125]],
[[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]],
[[0.0, -0.25, 0.25], [0.0, -0.25, 0.25]],
[[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],
[[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
[[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]],
[[-0.125, -0.125, 0.125]],
[[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]],
[[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]],
[[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125]],
[[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],
[[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],
[[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],
[[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
[[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]],
[[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],
[[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],
[[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],
[[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125]],
[[0.0, 0.5, 0.0], [0.0, -0.5, 0.0]],
[[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25]],
[[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],
[[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]],
[[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],
[[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]],
[[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]],
[[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],
[[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]],
[[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],
[[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],
[[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],
[[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]],
[[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125]],
[[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]],
[[0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
[[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25]],
[[0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],
[[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],
[[0.125, -0.125, 0.125]],
[[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25]],
[[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]],
[[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],
[[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],
[[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]],
[[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125]],
[[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125]],
[[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125]],
[[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],
[[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
[[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125]],
[[0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
[[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],
[[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
[[0.0, -0.25, -0.25], [0.0, 0.25, 0.25]],
[[0.125, -0.125, -0.125]],
[[0.5, 0.0, 0.0], [0.5, 0.0, 0.0]],
[[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],
[[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125]],
[[0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],
[[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],
[[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]],
[[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125]],
[[-0.125, 0.125, 0.125]],
[[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],
[[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
[[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]],
[[0.125, -0.125, 0.125]],
[[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],
[[-0.125, -0.125, 0.125]],
[[0.125, 0.125, 0.125]],
[[0, 0, 0]]]
# pylint: enable=line-too-long
def create_table_neighbour_code_to_surface_area(spacing_mm):
"""Returns an array mapping neighbourhood code to the surface elements area.
Note that the normals encode the initial surface area. This function computes
the area corresponding to the given `spacing_mm`.
Args:
spacing_mm: 3-element list-like structure. Voxel spacing in x0, x1 and x2
direction.
"""
# compute the area for all 256 possible surface elements
# (given a 2x2x2 neighbourhood) according to the spacing_mm
neighbour_code_to_surface_area = np.zeros([256])
for code in range(256):
normals = np.array(_NEIGHBOUR_CODE_TO_NORMALS[code])
sum_area = 0
for normal_idx in range(normals.shape[0]):
# normal vector
n = np.zeros([3])
n[0] = normals[normal_idx, 0] * spacing_mm[1] * spacing_mm[2]
n[1] = normals[normal_idx, 1] * spacing_mm[0] * spacing_mm[2]
n[2] = normals[normal_idx, 2] * spacing_mm[0] * spacing_mm[1]
area = np.linalg.norm(n)
sum_area += area
neighbour_code_to_surface_area[code] = sum_area
return neighbour_code_to_surface_area
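As a worked check of the loop above: for neighbour code 1 (a single corner voxel inside), the table holds one normal `[0.125, 0.125, 0.125]`, and under anisotropic spacing each component is scaled by the product of the other two spacings.

```python
import numpy as np

normals = np.array([[0.125, 0.125, 0.125]])  # _NEIGHBOUR_CODE_TO_NORMALS[1]
spacing_mm = (1.0, 2.0, 3.0)
scale = (spacing_mm[1] * spacing_mm[2],  # x0 component scales with the x1*x2 face
         spacing_mm[0] * spacing_mm[2],
         spacing_mm[0] * spacing_mm[1])
area = np.linalg.norm(normals * scale, axis=1).sum()
# 0.125 * sqrt(6**2 + 3**2 + 2**2) = 0.125 * 7 = 0.875
```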
# In the neighbourhood, points are ordered: top left, top right, bottom left,
# bottom right.
ENCODE_NEIGHBOURHOOD_2D_KERNEL = np.array([[8, 4], [2, 1]])
def create_table_neighbour_code_to_contour_length(spacing_mm):
"""Returns an array mapping neighbourhood code to the contour length.
For the list of possible cases and their figures, see page 38 from:
https://nccastaff.bournemouth.ac.uk/jmacey/MastersProjects/MSc14/06/thesis.pdf
In 2D, each point has 4 neighbours. Thus, there are 16 configurations. A
configuration is encoded with '1' meaning "inside the object" and '0' meaning
"outside the object". The points are ordered: top left, top right, bottom left,
bottom right.
The x0 axis is assumed vertical downward, and the x1 axis is horizontal to the
right:
(0, 0) --> (0, 1)
|
(1, 0)
Args:
spacing_mm: 2-element list-like structure. Voxel spacing in x0 and x1
directions.
"""
neighbour_code_to_contour_length = np.zeros([16])
vertical = spacing_mm[0]
horizontal = spacing_mm[1]
diag = 0.5 * math.sqrt(spacing_mm[0]**2 + spacing_mm[1]**2)
# pyformat: disable
neighbour_code_to_contour_length[int("00"
"01", 2)] = diag
neighbour_code_to_contour_length[int("00"
"10", 2)] = diag
neighbour_code_to_contour_length[int("00"
"11", 2)] = horizontal
neighbour_code_to_contour_length[int("01"
"00", 2)] = diag
neighbour_code_to_contour_length[int("01"
"01", 2)] = vertical
neighbour_code_to_contour_length[int("01"
"10", 2)] = 2*diag
neighbour_code_to_contour_length[int("01"
"11", 2)] = diag
neighbour_code_to_contour_length[int("10"
"00", 2)] = diag
neighbour_code_to_contour_length[int("10"
"01", 2)] = 2*diag
neighbour_code_to_contour_length[int("10"
"10", 2)] = vertical
neighbour_code_to_contour_length[int("10"
"11", 2)] = diag
neighbour_code_to_contour_length[int("11"
"00", 2)] = horizontal
neighbour_code_to_contour_length[int("11"
"01", 2)] = diag
neighbour_code_to_contour_length[int("11"
"10", 2)] = diag
# pyformat: enable
return neighbour_code_to_contour_length
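Three of the sixteen entries worked out by hand for a spacing of (3, 4), matching the assignments above:

```python
import math

vertical, horizontal = 3.0, 4.0  # spacing_mm[0], spacing_mm[1]
diag = 0.5 * math.sqrt(vertical**2 + horizontal**2)  # 2.5
entries = {
    0b0011: horizontal,  # bottom row inside -> one horizontal edge (4.0)
    0b0101: vertical,    # right column inside -> one vertical edge (3.0)
    0b0110: 2 * diag,    # checkerboard pair -> two half-diagonals (5.0)
}
```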
================================================
FILE: utils/metrics.py
================================================
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module exposing surface distance based measures."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from . import lookup_tables # pylint: disable=relative-beyond-top-level
import numpy as np
from scipy import ndimage
def _assert_is_numpy_array(name, array):
"""Raises an exception if `array` is not a numpy array."""
if not isinstance(array, np.ndarray):
raise ValueError("The argument {!r} should be a numpy array, not a "
"{}".format(name, type(array)))
def _check_nd_numpy_array(name, array, num_dims):
"""Raises an exception if `array` is not a `num_dims`-D numpy array."""
if len(array.shape) != num_dims:
raise ValueError("The argument {!r} should be a {}D array, not of "
"shape {}".format(name, num_dims, array.shape))
def _check_2d_numpy_array(name, array):
_check_nd_numpy_array(name, array, num_dims=2)
def _check_3d_numpy_array(name, array):
_check_nd_numpy_array(name, array, num_dims=3)
def _assert_is_bool_numpy_array(name, array):
_assert_is_numpy_array(name, array)
if array.dtype != np.bool_:
raise ValueError("The argument {!r} should be a numpy array of type bool, "
"not {}".format(name, array.dtype))
def _compute_bounding_box(mask):
"""Computes the bounding box of the masks.
This function generalizes to arbitrary number of dimensions great or equal
to 1.
Args:
mask: The 2D or 3D numpy mask, where '0' means background and non-zero means
foreground.
Returns:
A tuple:
- The coordinates of the first point of the bounding box (smallest on all
axes), or `None` if the mask contains only zeros.
- The coordinates of the second point of the bounding box (greatest on all
axes), or `None` if the mask contains only zeros.
"""
num_dims = len(mask.shape)
bbox_min = np.zeros(num_dims, np.int64)
bbox_max = np.zeros(num_dims, np.int64)
# max projection to the x0-axis
proj_0 = np.amax(mask, axis=tuple(range(num_dims))[1:])
idx_nonzero_0 = np.nonzero(proj_0)[0]
if len(idx_nonzero_0) == 0: # pylint: disable=g-explicit-length-test
return None, None
bbox_min[0] = np.min(idx_nonzero_0)
bbox_max[0] = np.max(idx_nonzero_0)
# max projection to the i-th-axis for i in {1, ..., num_dims - 1}
for axis in range(1, num_dims):
max_over_axes = list(range(num_dims)) # Python 3 compatible
max_over_axes.pop(axis) # Remove the i-th dimension from the max
max_over_axes = tuple(max_over_axes) # numpy expects a tuple of ints
proj = np.amax(mask, axis=max_over_axes)
idx_nonzero = np.nonzero(proj)[0]
bbox_min[axis] = np.min(idx_nonzero)
bbox_max[axis] = np.max(idx_nonzero)
return bbox_min, bbox_max
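The per-axis projection loop is equivalent to a single `np.nonzero` pass. A compact sketch (the helper name `bbox_via_nonzero` is hypothetical, not part of this module) that reproduces the result on a toy mask:

```python
import numpy as np

def bbox_via_nonzero(mask):
    # min/max index of any nonzero voxel along each axis
    coords = np.nonzero(mask)
    if coords[0].size == 0:
        return None, None
    return (np.array([c.min() for c in coords]),
            np.array([c.max() for c in coords]))

mask = np.zeros((5, 6, 7), np.uint8)
mask[1:3, 2:4, 3:5] = 1
bb_min, bb_max = bbox_via_nonzero(mask)
# bb_min -> [1, 2, 3], bb_max -> [2, 3, 4]
```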
def _crop_to_bounding_box(mask, bbox_min, bbox_max):
"""Crops a 2D or 3D mask to the bounding box specified by `bbox_{min,max}`."""
# we need to zeropad the cropped region with 1 voxel at the lower,
# the right (and the back on 3D) sides. This is required to obtain the
# "full" convolution result with the 2x2 (or 2x2x2 in 3D) kernel.
#TODO: This is correct only if the object is interior to the
# bounding box.
cropmask = np.zeros((bbox_max - bbox_min) + 2, np.uint8)
num_dims = len(mask.shape)
# pyformat: disable
if num_dims == 2:
cropmask[0:-1, 0:-1] = mask[bbox_min[0]:bbox_max[0] + 1,
bbox_min[1]:bbox_max[1] + 1]
elif num_dims == 3:
cropmask[0:-1, 0:-1, 0:-1] = mask[bbox_min[0]:bbox_max[0] + 1,
bbox_min[1]:bbox_max[1] + 1,
bbox_min[2]:bbox_max[2] + 1]
# pyformat: enable
else:
assert False
return cropmask
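The one-voxel zero border matters: it gives the 2x2 (or 2x2x2) kernel an "outside" transition to see at the far edges. A toy 2D run of the same padding scheme:

```python
import numpy as np

mask = np.zeros((5, 5), np.uint8)
mask[1:3, 2:4] = 1
bbox_min, bbox_max = np.array([1, 2]), np.array([2, 3])
cropmask = np.zeros((bbox_max - bbox_min) + 2, np.uint8)
cropmask[0:-1, 0:-1] = mask[1:3, 2:4]
# cropmask.shape -> (3, 3): the trailing zero row/column supplies the
# background border needed for the "full" correlation result
```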
def _sort_distances_surfels(distances, surfel_areas):
"""Sorts the two list with respect to the tuple of (distance, surfel_area).
Args:
distances: The distances from A to B (e.g. `distances_gt_to_pred`).
surfel_areas: The surfel areas for A (e.g. `surfel_areas_gt`).
Returns:
A tuple of the sorted (distances, surfel_areas).
"""
sorted_surfels = np.array(sorted(zip(distances, surfel_areas)))
return sorted_surfels[:, 0], sorted_surfels[:, 1]
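Keeping each surfel's area paired with its distance is what makes the area-weighted statistics below correct; sorting tuples preserves the pairing:

```python
import numpy as np

distances = np.array([3.0, 1.0, 2.0])
surfel_areas = np.array([0.3, 0.1, 0.2])
pairs = np.array(sorted(zip(distances, surfel_areas)))
sorted_d, sorted_a = pairs[:, 0], pairs[:, 1]
# sorted_d -> [1., 2., 3.]; sorted_a follows along: [0.1, 0.2, 0.3]
```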
def compute_surface_distances(mask_gt,
mask_pred,
spacing_mm):
"""Computes closest distances from all surface points to the other surface.
This function can be applied to 2D or 3D tensors. For 2D, both masks must be
2D and `spacing_mm` must be a 2-element list. For 3D, both masks must be 3D
and `spacing_mm` must be a 3-element list. The description below is for the 2D
case; the formulation for the 3D case is given in parentheses, introduced by
"resp.".
Finds all contour elements (resp surface elements "surfels" in 3D) in the
ground truth mask `mask_gt` and the predicted mask `mask_pred`, computes their
length in mm (resp. area in mm^2) and the distance to the closest point on the
other contour (resp. surface). It returns two sorted lists of distances
together with the corresponding contour lengths (resp. surfel areas). If one
of the masks is empty, the corresponding lists are empty and all distances in
the other list are `inf`.
Args:
mask_gt: 2-dim (resp. 3-dim) bool Numpy array. The ground truth mask.
mask_pred: 2-dim (resp. 3-dim) bool Numpy array. The predicted mask.
spacing_mm: 2-element (resp. 3-element) list-like structure. Voxel spacing
in x0 and x1 (resp. x0, x1 and x2) directions.
Returns:
A dict with:
"distances_gt_to_pred": 1-dim numpy array of type float. The distances in mm
from all ground truth surface elements to the predicted surface,
sorted from smallest to largest.
"distances_pred_to_gt": 1-dim numpy array of type float. The distances in mm
from all predicted surface elements to the ground truth surface,
sorted from smallest to largest.
"surfel_areas_gt": 1-dim numpy array of type float. The length of the
of the ground truth contours in mm (resp. the surface elements area in
mm^2) in the same order as distances_gt_to_pred.
"surfel_areas_pred": 1-dim numpy array of type float. The length of the
of the predicted contours in mm (resp. the surface elements area in
mm^2) in the same order as distances_gt_to_pred.
Raises:
ValueError: If the masks and the `spacing_mm` arguments are of incompatible
shape or type. Or if the masks are not 2D or 3D.
"""
# The terms used in this function are for the 3D case. In particular, "surface"
# stands for the contour in the 2D case, and surface elements in 3D correspond
# to line elements in 2D.
_assert_is_bool_numpy_array("mask_gt", mask_gt)
_assert_is_bool_numpy_array("mask_pred", mask_pred)
if not len(mask_gt.shape) == len(mask_pred.shape) == len(spacing_mm):
raise ValueError("The arguments must be of compatible shape. Got mask_gt "
"with {} dimensions ({}) and mask_pred with {} dimensions "
"({}), while the spacing_mm was {} elements.".format(
len(mask_gt.shape),
mask_gt.shape, len(mask_pred.shape), mask_pred.shape,
len(spacing_mm)))
num_dims = len(spacing_mm)
if num_dims == 2:
_check_2d_numpy_array("mask_gt", mask_gt)
_check_2d_numpy_array("mask_pred", mask_pred)
# compute the area for all 16 possible surface elements
# (given a 2x2 neighbourhood) according to the spacing_mm
neighbour_code_to_surface_area = (
lookup_tables.create_table_neighbour_code_to_contour_length(spacing_mm))
kernel = lookup_tables.ENCODE_NEIGHBOURHOOD_2D_KERNEL
full_true_neighbours = 0b1111
elif num_dims == 3:
_check_3d_numpy_array("mask_gt", mask_gt)
_check_3d_numpy_array("mask_pred", mask_pred)
# compute the area for all 256 possible surface elements
# (given a 2x2x2 neighbourhood) according to the spacing_mm
neighbour_code_to_surface_area = (
lookup_tables.create_table_neighbour_code_to_surface_area(spacing_mm))
kernel = lookup_tables.ENCODE_NEIGHBOURHOOD_3D_KERNEL
full_true_neighbours = 0b11111111
else:
raise ValueError("Only 2D and 3D masks are supported, not "
"{}D.".format(num_dims))
# compute the bounding box of the masks to trim the volume to the smallest
# possible processing subvolume
bbox_min, bbox_max = _compute_bounding_box(mask_gt | mask_pred)
# Both the min/max bbox are None at the same time, so we only check one.
if bbox_min is None:
return {
"distances_gt_to_pred": np.array([]),
"distances_pred_to_gt": np.array([]),
"surfel_areas_gt": np.array([]),
"surfel_areas_pred": np.array([]),
}
# crop the processing subvolume.
cropmask_gt = _crop_to_bounding_box(mask_gt, bbox_min, bbox_max)
cropmask_pred = _crop_to_bounding_box(mask_pred, bbox_min, bbox_max)
# compute the neighbour code (local binary pattern) for each voxel
# the resulting arrays are spacially shifted by minus half a voxel in each
# axis.
# i.e. the points are located at the corners of the original voxels
neighbour_code_map_gt = ndimage.correlate(
cropmask_gt.astype(np.uint8), kernel, mode="constant", cval=0)
neighbour_code_map_pred = ndimage.correlate(
cropmask_pred.astype(np.uint8), kernel, mode="constant", cval=0)
# create masks with the surface voxels
borders_gt = ((neighbour_code_map_gt != 0) &
(neighbour_code_map_gt != full_true_neighbours))
borders_pred = ((neighbour_code_map_pred != 0) &
(neighbour_code_map_pred != full_true_neighbours))
# compute the distance transform (closest distance of each voxel to the
# surface voxels)
if borders_gt.any():
distmap_gt = ndimage.distance_transform_edt(
~borders_gt, sampling=spacing_mm)
else:
distmap_gt = np.inf * np.ones(borders_gt.shape)
if borders_pred.any():
distmap_pred = ndimage.distance_transform_edt(
~borders_pred, sampling=spacing_mm)
else:
distmap_pred = np.inf * np.ones(borders_pred.shape)
# compute the area of each surface element
surface_area_map_gt = neighbour_code_to_surface_area[neighbour_code_map_gt]
surface_area_map_pred = neighbour_code_to_surface_area[
neighbour_code_map_pred]
# create a list of all surface elements with distance and area
distances_gt_to_pred = distmap_pred[borders_gt]
distances_pred_to_gt = distmap_gt[borders_pred]
surfel_areas_gt = surface_area_map_gt[borders_gt]
surfel_areas_pred = surface_area_map_pred[borders_pred]
# sort them by distance
if distances_gt_to_pred.shape != (0,):
distances_gt_to_pred, surfel_areas_gt = _sort_distances_surfels(
distances_gt_to_pred, surfel_areas_gt)
if distances_pred_to_gt.shape != (0,):
distances_pred_to_gt, surfel_areas_pred = _sort_distances_surfels(
distances_pred_to_gt, surfel_areas_pred)
return {
"distances_gt_to_pred": distances_gt_to_pred,
"distances_pred_to_gt": distances_pred_to_gt,
"surfel_areas_gt": surfel_areas_gt,
"surfel_areas_pred": surfel_areas_pred,
}
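A self-contained 2D illustration of the core trick above (distance transform to one border, read off at the other border's voxels). This is a voxel-level sketch only; the real function works at sub-voxel precision via the neighbour-code lookup tables.

```python
import numpy as np
from scipy import ndimage

def border(mask):
    # foreground voxels that touch background
    return mask & ~ndimage.binary_erosion(mask)

mask_gt = np.zeros((9, 9), bool)
mask_gt[2:7, 2:7] = True
mask_pred = np.roll(mask_gt, 1, axis=1)  # prediction shifted one voxel right

borders_gt, borders_pred = border(mask_gt), border(mask_pred)
distmap_pred = ndimage.distance_transform_edt(~borders_pred,
                                              sampling=(1.0, 1.0))
distances_gt_to_pred = np.sort(distmap_pred[borders_gt])
# every ground-truth border voxel is at most 1 voxel from the shifted border
```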
def compute_average_surface_distance(surface_distances):
"""Returns the average surface distance.
Computes the average surface distances by correctly taking the area of each
surface element into account. Call compute_surface_distances(...) first to
obtain the `surface_distances` dict.
Args:
surface_distances: dict with "distances_gt_to_pred", "distances_pred_to_gt"
"surfel_areas_gt", "surfel_areas_pred" created by
compute_surface_distances()
Returns:
A tuple with two float values:
- the average distance (in mm) from the ground truth surface to the
predicted surface
- the average distance from the predicted surface to the ground truth
surface.
"""
distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
surfel_areas_gt = surface_distances["surfel_areas_gt"]
surfel_areas_pred = surface_distances["surfel_areas_pred"]
average_distance_gt_to_pred = (
np.sum(distances_gt_to_pred * surfel_areas_gt) / np.sum(surfel_areas_gt))
average_distance_pred_to_gt = (
np.sum(distances_pred_to_gt * surfel_areas_pred) /
np.sum(surfel_areas_pred))
return (average_distance_gt_to_pred, average_distance_pred_to_gt)

def compute_robust_hausdorff(surface_distances, percent):
  """Computes the robust Hausdorff distance.

  Computes the robust Hausdorff distance. "Robust", because it uses the
  `percent` percentile of the distances instead of the maximum distance. The
  percentage is computed by correctly taking the area of each surface element
  into account.

  Args:
    surface_distances: dict with "distances_gt_to_pred", "distances_pred_to_gt"
      "surfel_areas_gt", "surfel_areas_pred" created by
      compute_surface_distances()
    percent: a float value between 0 and 100.

  Returns:
    a float value. The robust Hausdorff distance in mm.
  """
  distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
  distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
  surfel_areas_gt = surface_distances["surfel_areas_gt"]
  surfel_areas_pred = surface_distances["surfel_areas_pred"]
  if len(distances_gt_to_pred) > 0:  # pylint: disable=g-explicit-length-test
    surfel_areas_cum_gt = np.cumsum(surfel_areas_gt) / np.sum(surfel_areas_gt)
    idx = np.searchsorted(surfel_areas_cum_gt, percent/100.0)
    perc_distance_gt_to_pred = distances_gt_to_pred[
        min(idx, len(distances_gt_to_pred)-1)]
  else:
    perc_distance_gt_to_pred = np.inf

  if len(distances_pred_to_gt) > 0:  # pylint: disable=g-explicit-length-test
    surfel_areas_cum_pred = (np.cumsum(surfel_areas_pred) /
                             np.sum(surfel_areas_pred))
    idx = np.searchsorted(surfel_areas_cum_pred, percent/100.0)
    perc_distance_pred_to_gt = distances_pred_to_gt[
        min(idx, len(distances_pred_to_gt)-1)]
  else:
    perc_distance_pred_to_gt = np.inf

  return max(perc_distance_gt_to_pred, perc_distance_pred_to_gt)

def compute_surface_overlap_at_tolerance(surface_distances, tolerance_mm):
  """Computes the overlap of the surfaces at a specified tolerance.

  Computes the overlap of the ground truth surface with the predicted surface
  and vice versa allowing a specified tolerance (maximum surface-to-surface
  distance that is regarded as overlapping). The overlapping fraction is
  computed by correctly taking the area of each surface element into account.

  Args:
    surface_distances: dict with "distances_gt_to_pred", "distances_pred_to_gt"
      "surfel_areas_gt", "surfel_areas_pred" created by
      compute_surface_distances()
    tolerance_mm: a float value. The tolerance in mm

  Returns:
    A tuple of two float values. The overlap fraction in [0.0, 1.0] of the
    ground truth surface with the predicted surface and vice versa.
  """
  distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
  distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
  surfel_areas_gt = surface_distances["surfel_areas_gt"]
  surfel_areas_pred = surface_distances["surfel_areas_pred"]
  rel_overlap_gt = (
      np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm]) /
      np.sum(surfel_areas_gt))
  rel_overlap_pred = (
      np.sum(surfel_areas_pred[distances_pred_to_gt <= tolerance_mm]) /
      np.sum(surfel_areas_pred))
  return (rel_overlap_gt, rel_overlap_pred)

def compute_surface_dice_at_tolerance(surface_distances, tolerance_mm):
  """Computes the _surface_ DICE coefficient at a specified tolerance.

  Computes the _surface_ DICE coefficient at a specified tolerance. Not to be
  confused with the standard _volumetric_ DICE coefficient. The surface DICE
  measures the overlap of two surfaces instead of two volumes. A surface
  element is counted as overlapping (or touching), when the closest distance to
  the other surface is less or equal to the specified tolerance. The DICE
  coefficient is in the range between 0.0 (no overlap) to 1.0 (perfect overlap).

  Args:
    surface_distances: dict with "distances_gt_to_pred", "distances_pred_to_gt"
      "surfel_areas_gt", "surfel_areas_pred" created by
      compute_surface_distances()
    tolerance_mm: a float value. The tolerance in mm

  Returns:
    A float value. The surface DICE coefficient in [0.0, 1.0].
  """
  distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
  distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
  surfel_areas_gt = surface_distances["surfel_areas_gt"]
  surfel_areas_pred = surface_distances["surfel_areas_pred"]
  overlap_gt = np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm])
  overlap_pred = np.sum(surfel_areas_pred[distances_pred_to_gt <= tolerance_mm])
  surface_dice = (overlap_gt + overlap_pred) / (
      np.sum(surfel_areas_gt) + np.sum(surfel_areas_pred))
  return surface_dice

def compute_dice_coefficient(mask_gt, mask_pred):
  """Computes the Soerensen-Dice coefficient.

  Computes the Soerensen-Dice coefficient between the ground truth mask
  `mask_gt` and the predicted mask `mask_pred`.

  Args:
    mask_gt: 3-dim Numpy array of type bool. The ground truth mask.
    mask_pred: 3-dim Numpy array of type bool. The predicted mask.

  Returns:
    the dice coefficient as float. If both masks are empty, the result is NaN.
  """
  volume_sum = mask_gt.sum() + mask_pred.sum()
  if volume_sum == 0:
    return np.nan
  volume_intersect = (mask_gt & mask_pred).sum()
  return 2*volume_intersect / volume_sum
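
The area-weighted percentile lookup inside `compute_robust_hausdorff` is the subtle step: distances are sorted, and the percentile is taken over cumulative surface area rather than element count. A minimal self-contained NumPy sketch of just that step (the function name and toy inputs are illustrative, not part of the repository):

```python
import numpy as np

def weighted_percentile_distance(distances, areas, percent):
    """Area-weighted percentile of surface distances, mirroring the
    cumulative-area lookup used by compute_robust_hausdorff."""
    order = np.argsort(distances)
    distances, areas = distances[order], areas[order]
    # fraction of total surface area covered up to each sorted distance
    cum = np.cumsum(areas) / np.sum(areas)
    idx = np.searchsorted(cum, percent / 100.0)
    return distances[min(idx, len(distances) - 1)]

# toy example: four surface elements with equal areas
d = np.array([0.0, 1.0, 2.0, 10.0])
a = np.ones(4)
print(weighted_percentile_distance(d, a, 95))  # -> 10.0
```

With unequal areas the same distance array can give a different percentile, which is exactly why the metric weights by surface element area instead of counting voxels.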
================================================
FILE: utils/utils.py
================================================
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import SimpleITK as sitk
import pdb
def log_evaluation_result(writer, dice_list, ASD_list, HD_list, name, epoch):
    writer.add_scalar('Test_Dice/%s_AVG'%name, dice_list.mean(), epoch+1)
    for idx in range(3):
        writer.add_scalar('Test_Dice/%s_Dice%d'%(name, idx+1), dice_list[idx], epoch+1)
    writer.add_scalar('Test_ASD/%s_AVG'%name, ASD_list.mean(), epoch+1)
    for idx in range(3):
        writer.add_scalar('Test_ASD/%s_ASD%d'%(name, idx+1), ASD_list[idx], epoch+1)
    writer.add_scalar('Test_HD/%s_AVG'%name, HD_list.mean(), epoch+1)
    for idx in range(3):
        writer.add_scalar('Test_HD/%s_HD%d'%(name, idx+1), HD_list[idx], epoch+1)

def multistep_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch, lr_decay_epoch, max_epoch, gamma=0.1):
    # exponential warmup: lr rises toward init_lr, pinned to exactly
    # init_lr at the last warmup epoch
    if epoch >= 0 and epoch <= warmup_epoch:
        lr = init_lr * 2.718 ** (10*(float(epoch) / float(warmup_epoch) - 1.))
        if epoch == warmup_epoch:
            lr = init_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    # after warmup: decay by gamma at each milestone in lr_decay_epoch
    flag = False
    for i in range(len(lr_decay_epoch)):
        if epoch == lr_decay_epoch[i]:
            flag = True
            break

    if flag:
        lr = init_lr * gamma**(i+1)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    else:
        return optimizer.param_groups[0]['lr']

    return lr

def exp_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch, max_epoch):
    # exponential warmup, then polynomial decay with power 0.9
    if epoch >= 0 and epoch <= warmup_epoch:
        lr = init_lr * 2.718 ** (10*(float(epoch) / float(warmup_epoch) - 1.))
        if epoch == warmup_epoch:
            lr = init_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return lr
    else:
        lr = init_lr * (1 - epoch / max_epoch)**0.9
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return lr

def cal_dice(pred, target, C):
    # scatter flat prediction/target indices into N x C one-hot masks
    N = pred.shape[0]
    target_mask = target.data.new(N, C).fill_(0)
    target_mask.scatter_(1, target, 1.)

    pred_mask = pred.data.new(N, C).fill_(0)
    pred_mask.scatter_(1, pred, 1.)

    intersection = pred_mask * target_mask
    summ = pred_mask + target_mask

    intersection = intersection.sum(0).type(torch.float32)
    summ = summ.sum(0).type(torch.float32)

    # small per-class epsilon to avoid division by zero
    eps = torch.full((C,), 1e-7, dtype=torch.float32)
    summ += eps.cuda()

    dice = 2 * intersection / summ
    return dice, intersection, summ

def cal_asd(itkPred, itkGT):
    reference_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(itkGT, squaredDistance=False))
    reference_surface = sitk.LabelContour(itkGT)

    statistics_image_filter = sitk.StatisticsImageFilter()
    statistics_image_filter.Execute(reference_surface)
    num_reference_surface_pixels = int(statistics_image_filter.GetSum())

    segmented_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(itkPred, squaredDistance=False))
    segmented_surface = sitk.LabelContour(itkPred)

    seg2ref_distance_map = reference_distance_map * sitk.Cast(segmented_surface, sitk.sitkFloat32)
    ref2seg_distance_map = segmented_distance_map * sitk.Cast(reference_surface, sitk.sitkFloat32)

    statistics_image_filter.Execute(segmented_surface)
    num_segmented_surface_pixels = int(statistics_image_filter.GetSum())

    seg2ref_distance_map_arr = sitk.GetArrayViewFromImage(seg2ref_distance_map)
    seg2ref_distances = list(seg2ref_distance_map_arr[seg2ref_distance_map_arr != 0])
    # surface pixels lying exactly on the other surface (distance 0) are
    # masked out above; add them back as zero distances
    seg2ref_distances = seg2ref_distances + \
        list(np.zeros(num_segmented_surface_pixels - len(seg2ref_distances)))

    ref2seg_distance_map_arr = sitk.GetArrayViewFromImage(ref2seg_distance_map)
    ref2seg_distances = list(ref2seg_distance_map_arr[ref2seg_distance_map_arr != 0])
    ref2seg_distances = ref2seg_distances + \
        list(np.zeros(num_reference_surface_pixels - len(ref2seg_distances)))

    all_surface_distances = seg2ref_distances + ref2seg_distances
    ASD = np.mean(all_surface_distances)
    return ASD
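
Both schedulers above share the same warmup rule, lr = init_lr * e^(10 * (epoch/warmup_epoch - 1)) (with e approximated as 2.718 in the source), pinned to exactly init_lr at the last warmup epoch; `exp_lr_scheduler_with_warmup` then decays polynomially with power 0.9. A standalone sketch of the resulting schedule values, without an optimizer (the function name is illustrative, not part of the repository):

```python
import math

def warmup_poly_lr(init_lr, epoch, warmup_epoch, max_epoch):
    """Learning-rate value per epoch: exponential warmup, then polynomial
    decay with power 0.9 (mirrors exp_lr_scheduler_with_warmup, minus the
    optimizer parameter-group updates)."""
    if epoch <= warmup_epoch:
        lr = init_lr * math.e ** (10 * (epoch / warmup_epoch - 1.0))
        # the source pins lr to exactly init_lr at the last warmup epoch
        return init_lr if epoch == warmup_epoch else lr
    return init_lr * (1 - epoch / max_epoch) ** 0.9

# warmup ends exactly at init_lr, then the rate decays toward zero
print(warmup_poly_lr(0.05, 5, 5, 150))  # -> 0.05
```

At epoch 0 the warmup factor is e^-10 (about 4.5e-5 of init_lr), so the first epochs run at a near-zero rate, which is the usual motivation for warmup with transformer blocks.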
SYMBOL INDEX (230 symbols across 13 files)
FILE: data_preprocess.py
function ResampleXYZAxis (line 7) | def ResampleXYZAxis(imImage, space=(1., 1., 1.), interp=sitk.sitkLinear):
function ResampleFullImageToRef (line 23) | def ResampleFullImageToRef(imImage, imRef, interp=sitk.sitkNearestNeighb...
function ResampleCMRImage (line 37) | def ResampleCMRImage(imImage, imLabel, save_path, patient_name, name, ta...
FILE: dataset_domain.py
class CMRDataset (line 13) | class CMRDataset(Dataset):
method __init__ (line 14) | def __init__(self, dataset_dir, mode='train', domain='A', crop_size=25...
method __len__ (line 103) | def __len__(self):
method preprocess (line 106) | def preprocess(self, itk_img, itk_lab):
method __getitem__ (line 131) | def __getitem__(self, idx):
method randcrop (line 162) | def randcrop(self, img, label):
method center_crop (line 177) | def center_crop(self, img, label):
method random_zoom_rotate (line 191) | def random_zoom_rotate(self, img, label):
FILE: losses.py
class DiceLoss (line 8) | class DiceLoss(nn.Module):
method __init__ (line 10) | def __init__(self, alpha=0.5, beta=0.5, size_average=True, reduce=True):
method forward (line 18) | def forward(self, preds, targets):
class FocalLoss (line 60) | class FocalLoss(nn.Module):
method __init__ (line 61) | def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
method forward (line 72) | def forward(self, preds, targets):
FILE: model/conv_trans_utils.py
function conv3x3 (line 9) | def conv3x3(in_planes, out_planes, stride=1):
function conv1x1 (line 11) | def conv1x1(in_planes, out_planes, stride=1):
class depthwise_separable_conv (line 14) | class depthwise_separable_conv(nn.Module):
method __init__ (line 15) | def __init__(self, in_ch, out_ch, stride=1, kernel_size=3, padding=1, ...
method forward (line 20) | def forward(self, x):
class Mlp (line 26) | class Mlp(nn.Module):
method __init__ (line 27) | def __init__(self, in_ch, hid_ch=None, out_ch=None, act_layer=nn.GELU,...
method forward (line 37) | def forward(self, x):
class BasicBlock (line 46) | class BasicBlock(nn.Module):
method __init__ (line 48) | def __init__(self, inplanes, planes, stride=1):
method forward (line 65) | def forward(self, x):
class BasicTransBlock (line 80) | class BasicTransBlock(nn.Module):
method __init__ (line 82) | def __init__(self, in_ch, heads, dim_head, attn_drop=0., proj_drop=0.,...
method forward (line 93) | def forward(self, x):
class BasicTransDecoderBlock (line 109) | class BasicTransDecoderBlock(nn.Module):
method __init__ (line 111) | def __init__(self, in_ch, out_ch, heads, dim_head, attn_drop=0., proj_...
method forward (line 124) | def forward(self, x1, x2):
class LinearAttention (line 150) | class LinearAttention(nn.Module):
method __init__ (line 152) | def __init__(self, dim, heads=4, dim_head=64, attn_drop=0., proj_drop=...
method forward (line 180) | def forward(self, x):
class LinearAttentionDecoder (line 217) | class LinearAttentionDecoder(nn.Module):
method __init__ (line 219) | def __init__(self, in_dim, out_dim, heads=4, dim_head=64, attn_drop=0....
method forward (line 247) | def forward(self, q, x):
class RelativePositionEmbedding (line 284) | class RelativePositionEmbedding(nn.Module):
method __init__ (line 286) | def __init__(self, dim, shape):
method forward (line 303) | def forward(self, q, Nh, H, W, dim_head):
method relative_logits_1d (line 316) | def relative_logits_1d(self, q, rel_k, case):
class RelativePositionBias (line 344) | class RelativePositionBias(nn.Module):
method __init__ (line 348) | def __init__(self, num_heads, h, w):
method forward (line 371) | def forward(self, H, W):
class down_block_trans (line 385) | class down_block_trans(nn.Module):
method __init__ (line 386) | def __init__(self, in_ch, out_ch, num_block, bottleneck=False, maxpool...
method forward (line 411) | def forward(self, x):
class up_block_trans (line 418) | class up_block_trans(nn.Module):
method __init__ (line 419) | def __init__(self, in_ch, out_ch, num_block, bottleneck=False, heads=4...
method forward (line 439) | def forward(self, x1, x2):
class block_trans (line 447) | class block_trans(nn.Module):
method __init__ (line 448) | def __init__(self, in_ch, num_block, heads=4, dim_head=64, attn_drop=0...
method forward (line 462) | def forward(self, x):
FILE: model/resnet_utnet.py
class ResNet_UTNet (line 13) | class ResNet_UTNet(nn.Module):
method __init__ (line 14) | def __init__(self, in_ch, num_class, reduce_size=8, block_list='234', ...
method forward (line 51) | def forward(self, x):
FILE: model/swin_unet.py
class Mlp (line 48) | class Mlp(nn.Module):
method __init__ (line 50) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 68) | def forward(self, x):
function window_partition (line 86) | def window_partition(x, window_size):
function window_reverse (line 116) | def window_reverse(windows, window_size, H, W):
class WindowAttention (line 150) | class WindowAttention(nn.Module):
method __init__ (line 178) | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scal...
method forward (line 246) | def forward(self, x, mask=None):
method extra_repr (line 312) | def extra_repr(self) -> str:
method flops (line 318) | def flops(self, N):
class SwinTransformerBlock (line 346) | class SwinTransformerBlock(nn.Module):
method __init__ (line 384) | def __init__(self, dim, input_resolution, num_heads, window_size=7, sh...
method forward (line 486) | def forward(self, x):
method extra_repr (line 564) | def extra_repr(self) -> str:
method flops (line 571) | def flops(self):
class PatchMerging (line 601) | class PatchMerging(nn.Module):
method __init__ (line 619) | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
method forward (line 633) | def forward(self, x):
method extra_repr (line 679) | def extra_repr(self) -> str:
method flops (line 685) | def flops(self):
class PatchExpand (line 697) | class PatchExpand(nn.Module):
method __init__ (line 699) | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.L...
method forward (line 713) | def forward(self, x):
class FinalPatchExpand_X4 (line 745) | class FinalPatchExpand_X4(nn.Module):
method __init__ (line 747) | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.L...
method forward (line 765) | def forward(self, x):
class BasicLayer (line 797) | class BasicLayer(nn.Module):
method __init__ (line 837) | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
method forward (line 893) | def forward(self, x):
method extra_repr (line 913) | def extra_repr(self) -> str:
method flops (line 919) | def flops(self):
class BasicLayer_up (line 935) | class BasicLayer_up(nn.Module):
method __init__ (line 975) | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
method forward (line 1031) | def forward(self, x):
class PatchEmbed (line 1051) | class PatchEmbed(nn.Module):
method __init__ (line 1073) | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=9...
method forward (line 1111) | def forward(self, x):
method flops (line 1130) | def flops(self):
class SwinTransformerSys (line 1146) | class SwinTransformerSys(nn.Module):
method __init__ (line 1198) | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes...
method _init_weights (line 1384) | def _init_weights(self, m):
method no_weight_decay (line 1404) | def no_weight_decay(self):
method no_weight_decay_keywords (line 1412) | def no_weight_decay_keywords(self):
method forward_features (line 1420) | def forward_features(self, x):
method forward_up_features (line 1452) | def forward_up_features(self, x, x_downsample):
method up_x4 (line 1478) | def up_x4(self, x):
method forward (line 1504) | def forward(self, x):
method flops (line 1518) | def flops(self):
class SwinUnet_config (line 1541) | class SwinUnet_config():
method __init__ (line 1542) | def __init__(self):
class SwinUnet (line 1559) | class SwinUnet(nn.Module):
method __init__ (line 1561) | def __init__(self, config, img_size=224, num_classes=21843, zero_head=...
method forward (line 1607) | def forward(self, x):
method load_from (line 1619) | def load_from(self, pretrained_path):
FILE: model/transunet.py
function np2th (line 72) | def np2th(weights, conv=False):
function swish (line 86) | def swish(x):
class Attention (line 100) | class Attention(nn.Module):
method __init__ (line 102) | def __init__(self, config, vis):
method transpose_for_scores (line 136) | def transpose_for_scores(self, x):
method forward (line 146) | def forward(self, hidden_states):
class Mlp (line 194) | class Mlp(nn.Module):
method __init__ (line 196) | def __init__(self, config):
method _init_weights (line 214) | def _init_weights(self):
method forward (line 226) | def forward(self, x):
class Embeddings (line 244) | class Embeddings(nn.Module):
method __init__ (line 250) | def __init__(self, config, img_size, in_channels=3):
method forward (line 308) | def forward(self, x):
class Block (line 336) | class Block(nn.Module):
method __init__ (line 338) | def __init__(self, config, vis):
method forward (line 354) | def forward(self, x):
method load_from (line 378) | def load_from(self, weights, n_block):
class Encoder (line 454) | class Encoder(nn.Module):
method __init__ (line 456) | def __init__(self, config, vis):
method forward (line 474) | def forward(self, hidden_states):
class Transformer (line 494) | class Transformer(nn.Module):
method __init__ (line 496) | def __init__(self, config, img_size, vis):
method forward (line 506) | def forward(self, input_ids):
class Conv2dReLU (line 518) | class Conv2dReLU(nn.Sequential):
method __init__ (line 520) | def __init__(
class DecoderBlock (line 568) | class DecoderBlock(nn.Module):
method __init__ (line 570) | def __init__(
method forward (line 618) | def forward(self, x, skip=None):
class SegmentationHead (line 636) | class SegmentationHead(nn.Sequential):
method __init__ (line 640) | def __init__(self, in_channels, out_channels, kernel_size=3, upsamplin...
class DecoderCup (line 652) | class DecoderCup(nn.Module):
method __init__ (line 654) | def __init__(self, config):
method forward (line 710) | def forward(self, hidden_states, features=None):
class VisionTransformer (line 740) | class VisionTransformer(nn.Module):
method __init__ (line 742) | def __init__(self, config, img_size=224, num_classes=21843, zero_head=...
method forward (line 770) | def forward(self, x):
method load_from (line 786) | def load_from(self, weights):
function get_b16_config (line 885) | def get_b16_config():
function get_testing (line 934) | def get_testing():
function get_r50_b16_config (line 964) | def get_r50_b16_config():
function get_b32_config (line 1002) | def get_b32_config():
function get_l16_config (line 1018) | def get_l16_config():
function get_r50_l16_config (line 1064) | def get_r50_l16_config():
function get_l32_config (line 1098) | def get_l32_config():
function get_h14_config (line 1112) | def get_h14_config():
function np2th (line 1176) | def np2th(weights, conv=False):
class StdConv2d (line 1190) | class StdConv2d(nn.Conv2d):
method forward (line 1194) | def forward(self, x):
function conv3x3 (line 1210) | def conv3x3(cin, cout, stride=1, groups=1, bias=False):
function conv1x1 (line 1220) | def conv1x1(cin, cout, stride=1, bias=False):
class PreActBottleneck (line 1230) | class PreActBottleneck(nn.Module):
method __init__ (line 1238) | def __init__(self, cin, cout=None, cmid=None, stride=1):
method forward (line 1274) | def forward(self, x):
method load_from (line 1306) | def load_from(self, weights, n_block, n_unit):
class ResNetV2 (line 1378) | class ResNetV2(nn.Module):
method __init__ (line 1384) | def __init__(self, block_units, width_factor):
method forward (line 1438) | def forward(self, x):
FILE: model/unet_utils.py
class DoubleConv (line 6) | class DoubleConv(nn.Module):
method __init__ (line 12) | def __init__(self, in_channels, out_channels, mid_channels=None):
method forward (line 38) | def forward(self, x):
class Down (line 46) | class Down(nn.Module):
method __init__ (line 52) | def __init__(self, in_channels, out_channels):
method forward (line 66) | def forward(self, x):
class Up (line 74) | class Up(nn.Module):
method __init__ (line 80) | def __init__(self, in_channels, out_channels, bilinear=True):
method forward (line 104) | def forward(self, x1, x2):
class OutConv (line 134) | class OutConv(nn.Module):
method __init__ (line 136) | def __init__(self, in_channels, out_channels):
method forward (line 144) | def forward(self, x):
function conv3x3 (line 149) | def conv3x3(in_planes, out_planes, stride=1):
function conv1x1 (line 151) | def conv1x1(in_planes, out_planes, stride=1):
class BasicBlock (line 156) | class BasicBlock(nn.Module):
method __init__ (line 158) | def __init__(self, inplanes, planes, stride=1):
method forward (line 175) | def forward(self, x):
class BottleneckBlock (line 190) | class BottleneckBlock(nn.Module):
method __init__ (line 192) | def __init__(self, inplanes, planes, stride=1):
method forward (line 212) | def forward(self, x):
class inconv (line 233) | class inconv(nn.Module):
method __init__ (line 234) | def __init__(self, in_ch, out_ch, bottleneck=False):
method forward (line 244) | def forward(self, x):
class down_block (line 251) | class down_block(nn.Module):
method __init__ (line 252) | def __init__(self, in_ch, out_ch, scale, num_block, bottleneck=False, ...
method forward (line 274) | def forward(self, x):
class up_block (line 280) | class up_block(nn.Module):
method __init__ (line 281) | def __init__(self, in_ch, out_ch, num_block, scale=(2,2),bottleneck=Fa...
method forward (line 301) | def forward(self, x1, x2):
FILE: model/utnet.py
class UTNet (line 11) | class UTNet(nn.Module):
method __init__ (line 13) | def __init__(self, in_chan, base_chan, num_classes=1, reduce_size=8, b...
method forward (line 66) | def forward(self, x):
class UTNet_Encoderonly (line 105) | class UTNet_Encoderonly(nn.Module):
method __init__ (line 107) | def __init__(self, in_chan, base_chan, num_classes=1, reduce_size=8, b...
method forward (line 158) | def forward(self, x):
FILE: train_deep.py
function train_net (line 28) | def train_net(net, options):
function validation (line 141) | def validation(net, test_loader, options):
function cal_distance (line 187) | def cal_distance(label_pred, label_true, spacing):
function get_comma_separated_int_args (line 210) | def get_comma_separated_int_args(option, opt, value, parser):
FILE: utils/lookup_tables.py
function create_table_neighbour_code_to_surface_area (line 591) | def create_table_neighbour_code_to_surface_area(spacing_mm):
function create_table_neighbour_code_to_contour_length (line 655) | def create_table_neighbour_code_to_contour_length(spacing_mm):
FILE: utils/metrics.py
function _assert_is_numpy_array (line 49) | def _assert_is_numpy_array(name, array):
function _check_nd_numpy_array (line 63) | def _check_nd_numpy_array(name, array, num_dims):
function _check_2d_numpy_array (line 77) | def _check_2d_numpy_array(name, array):
function _check_3d_numpy_array (line 85) | def _check_3d_numpy_array(name, array):
function _assert_is_bool_numpy_array (line 93) | def _assert_is_bool_numpy_array(name, array):
function _compute_bounding_box (line 107) | def _compute_bounding_box(mask):
function _crop_to_bounding_box (line 187) | def _crop_to_bounding_box(mask, bbox_min, bbox_max):
function _sort_distances_surfels (line 237) | def _sort_distances_surfels(distances, surfel_areas):
function compute_surface_distances (line 265) | def compute_surface_distances(mask_gt,
function compute_average_surface_distance (line 579) | def compute_average_surface_distance(surface_distances):
function compute_robust_hausdorff (line 641) | def compute_robust_hausdorff(surface_distances, percent):
function compute_surface_overlap_at_tolerance (line 723) | def compute_surface_overlap_at_tolerance(surface_distances, tolerance_mm):
function compute_surface_dice_at_tolerance (line 785) | def compute_surface_dice_at_tolerance(surface_distances, tolerance_mm):
function compute_dice_coefficient (line 845) | def compute_dice_coefficient(mask_gt, mask_pred):
FILE: utils/utils.py
function log_evaluation_result (line 9) | def log_evaluation_result(writer, dice_list, ASD_list, HD_list, name, ep...
function multistep_lr_scheduler_with_warmup (line 22) | def multistep_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup...
function exp_lr_scheduler_with_warmup (line 49) | def exp_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch...
function cal_dice (line 69) | def cal_dice(pred, target, C):
function cal_asd (line 91) | def cal_asd(itkPred, itkGT):