Repository: yhygao/UTNet
Branch: main
Commit: 6bdcc97e2364
Files: 27
Total size: 190.5 KB

Directory structure:
gitextract_vjhzdth0/

├── .gitignore
├── LICENSE
├── README.md
├── data_preprocess.py
├── dataset/
│   ├── Testing_A.csv
│   ├── Testing_B.csv
│   ├── Testing_C.csv
│   ├── Testing_D.csv
│   ├── Training_A.csv
│   ├── Training_B.csv
│   ├── Training_C.csv
│   ├── Training_D.csv
│   └── info.csv
├── dataset_domain.py
├── losses.py
├── model/
│   ├── __init__.py
│   ├── conv_trans_utils.py
│   ├── resnet_utnet.py
│   ├── swin_unet.py
│   ├── transunet.py
│   ├── unet_utils.py
│   └── utnet.py
├── train_deep.py
└── utils/
    ├── __init__.py
    ├── lookup_tables.py
    ├── metrics.py
    └── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
__pycache__/
*.pyc
*_bkp.py
checkpoint/
log/
model/backup/
initmodel/
*.swp


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2021 yhygao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# UTNet (Accepted at MICCAI 2021)
Official implementation of [UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation](https://arxiv.org/abs/2107.00781)


## Update
* Our new work, Hermes, has been released on arXiv: [Training Like a Medical Resident: Universal Medical Image Segmentation via Context Prior Learning](https://arxiv.org/pdf/2306.02416.pdf). Inspired by the training of medical residents, we explore universal medical image segmentation, whose goal is to learn from diverse medical imaging sources covering a range of clinical targets, body regions, and image modalities. Following this paradigm, we propose Hermes, a context prior learning approach that addresses the challenges related to heterogeneity in data, modality, and annotations in the proposed universal paradigm. Code will be released at https://github.com/yhygao/universal-medical-image-segmentation.
* Our new paper on UTNetV2, an improved version of UTNet, has been released on arXiv: [A Multi-scale Transformer for Medical Image Segmentation: Architectures, Model Efficiency, and Benchmarks](https://arxiv.org/abs/2203.00131). UTNetV2 has an improved architecture for both 2D and 3D settings. We also provide a more general framework for comparing CNNs and Transformers, with support for more datasets and more SOTA models, in our new repo: https://github.com/yhygao/CBIM-Medical-Image-Segmentation

The data preprocessing code has been uploaded.


## Introduction 

Transformer architecture has proven successful in a number of natural language processing tasks. However, its applications to medical vision remain largely unexplored. In this study, we present UTNet, a simple yet powerful hybrid Transformer architecture that integrates self-attention into a convolutional neural network for enhancing medical image segmentation. UTNet applies self-attention modules in both encoder and decoder for capturing long-range dependency at different scales with minimal overhead. To this end, we propose an efficient self-attention mechanism along with relative position encoding that reduces the complexity of the self-attention operation significantly, from O(n^2) to approximately O(n). A new self-attention decoder is also proposed to recover fine-grained details from the skip connections in the encoder. Our approach addresses the dilemma that Transformer requires huge amounts of data to learn vision inductive bias. Our hybrid layer design allows the initialization of Transformer into convolutional networks without the need for pre-training. We have evaluated UTNet on the multi-label, multi-vendor cardiac magnetic resonance imaging cohort. UTNet demonstrates superior segmentation performance and robustness against the state-of-the-art approaches, holding the promise to generalize well on other medical image segmentation tasks.
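To illustrate where the O(n^2) to O(n) reduction comes from, here is a minimal NumPy sketch, assuming the keys and values are computed from a feature map sub-sampled to a fixed r x r grid. The sub-sampling here is plain average pooling, which may differ from what `model/utnet.py` actually does, and the relative position encoding, multi-head split, and output projection are omitted. `efficient_attention` and `avg_pool_to` are hypothetical names for this sketch only. The attention matrix has shape n x r^2 instead of n x n, so for a fixed r the cost grows linearly with the number of pixels n.

```python
import numpy as np


def avg_pool_to(x, r):
    """Average-pool a (C, H, W) feature map down to (C, r, r).

    Assumes H and W are divisible by r (a simplification for this sketch).
    """
    c, h, w = x.shape
    return x.reshape(c, r, h // r, r, w // r).mean(axis=(2, 4))


def efficient_attention(feat, wq, wk, wv, r=8):
    """Single-head attention whose keys/values come from an r x r pooled map.

    feat: (C, H, W) feature map; wq, wk, wv: (C, D) projection matrices.
    Returns the (H*W, D) output and the (H*W, r*r) attention matrix.
    """
    c, h, w = feat.shape
    tokens = feat.reshape(c, h * w).T                  # (n, C): one query per pixel
    pooled = avg_pool_to(feat, r).reshape(c, r * r).T  # (r*r, C): reduced keys/values
    q = tokens @ wq                                    # (n, D)
    k = pooled @ wk                                    # (r*r, D)
    v = pooled @ wv                                    # (r*r, D)
    scores = q @ k.T / np.sqrt(q.shape[1])             # (n, r*r) instead of (n, n)
    scores -= scores.max(axis=1, keepdims=True)        # numerical stability
    attn = np.exp(scores)
    attn /= attn.sum(axis=1, keepdims=True)            # row-wise softmax
    return attn @ v, attn


rng = np.random.default_rng(0)
feat = rng.standard_normal((32, 64, 64))
wq, wk, wv = (rng.standard_normal((32, 16)) for _ in range(3))
out, attn = efficient_attention(feat, wq, wk, wv, r=8)
print(out.shape, attn.shape)  # (4096, 16) (4096, 64)
```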

![image](https://user-images.githubusercontent.com/55367673/134997310-69c3576d-bbf2-40c8-ad5a-9b3bf3e9e97d.png)
![image](https://user-images.githubusercontent.com/55367673/134997347-a581cda7-7050-48ef-9af3-d4628fefac9a.png)

## Supported models
UTNet

TransUNet

ResNet50-UTNet

ResNet50-UNet

SwinUNet

To be continued...

## Getting Started

Currently, we only support the [M&Ms dataset](https://www.ub.edu/mnms/).

### Prerequisites
```
Python >= 3.6
pytorch = 1.8.1
SimpleITK = 2.0.2
numpy = 1.19.5
einops = 0.3.2
```

### Preprocess

Resample all data to a spacing of 1.2 x 1.2 mm in the x-y plane (see `data_preprocess.py`). We do not change the z-axis spacing, as UTNet is a 2D network. Then put all data into 'dataset/'.
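The output grid size in `ResampleXYZAxis` is computed per axis as `round(size * spacing / target_spacing)`. A small sketch of that arithmetic, keeping the original z spacing as `ResampleCMRImage` does; `resampled_size` is a hypothetical helper and the input frame below is made-up example data, not taken from M&Ms.

```python
def resampled_size(size, spacing, target_spacing):
    """Voxel count per axis after resampling to target_spacing.

    Mirrors the size computation in ResampleXYZAxis:
    round(size * spacing / target_spacing) for each axis.
    """
    return tuple(int(round(n * sp / t))
                 for n, sp, t in zip(size, spacing, target_spacing))


# Hypothetical short-axis frame: 240 x 196 pixels at 1.5 mm, 12 slices 10 mm apart.
size, spacing = (240, 196, 12), (1.5, 1.5, 10.0)
# Resample x-y to 1.2 mm and keep the original z spacing.
target = (1.2, 1.2, spacing[2])
print(resampled_size(size, spacing, target))  # (300, 245, 12)
```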

### Training

The M&Ms dataset provides data from 4 vendors: vendors A and B are provided for training, while A, B, C, and D are all used for testing. The '--domain' flag controls which vendors are used for training: '--domain A' uses vendor A only, '--domain B' uses vendor B only, and '--domain AB' uses both vendors A and B. All 4 vendors are used for testing.
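The mapping from a '--domain' string to the vendor CSV files that are read follows the logic in `CMRDataset.__init__` (`dataset_domain.py`): one `<Training|Testing>_<vendor>.csv` file per vendor letter. The helper `csv_files_for` below is a hypothetical stand-alone version of that logic; it raises `ValueError` where the repository code raises `StandardError`, and the default `dataset_dir` is an assumption.

```python
def csv_files_for(domain, mode='train', dataset_dir='dataset/'):
    """List the vendor CSV files read for a given --domain string.

    Follows dataset_domain.py: the prefix is 'Training' or 'Testing', and one
    '<prefix>_<vendor>.csv' file is read for each vendor letter in `domain`.
    """
    if mode == 'train':
        prefix = 'Training'
        if 'C' in domain or 'D' in domain:
            raise ValueError('No domain C or D in Training set')
    elif mode == 'test':
        prefix = 'Testing'
    else:
        raise ValueError('Wrong mode: %s' % mode)
    return [dataset_dir + prefix + '_' + v + '.csv' for v in 'ABCD' if v in domain]


print(csv_files_for('AB'))
# ['dataset/Training_A.csv', 'dataset/Training_B.csv']
print(csv_files_for('ABCD', mode='test'))
# ['dataset/Testing_A.csv', 'dataset/Testing_B.csv', 'dataset/Testing_C.csv', 'dataset/Testing_D.csv']
```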

#### UTNet
For default UTNet setting, training with:
```
python train_deep.py -m UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 1234 --num_blocks 1,1,1,1 --domain AB --gpu 0 --aux_loss
```
Alternatively, use '-m UTNet_encoder' to apply transformer blocks in the encoder only. This setting is more stable than the default setting in some cases.

To optimize UTNet for your own task, there are several hyperparameters to tune:

'--block_list': specifies the resolutions at which transformer blocks are applied. Each number is a number of downsamplings, e.g. 3,4 means transformer blocks are applied to the features after 3 and 4 downsamplings. Applying transformer blocks to higher-resolution feature maps introduces much more computation.

'--num_blocks': specifies the number of transformer blocks applied at each level, e.g. block_list='3,4', num_blocks=2,4 means 2 transformer blocks are applied at the level after 3 downsamplings and 4 transformer blocks at the level after 4 downsamplings.

'--reduce_size': specifies the downsampled size used in the efficient attention. In our experiments, reduce_size 8 and 16 perform similarly, but 16 introduces more computation, so we chose 8 as our default setting. 16 might perform better in other applications.

'--aux_loss': applies deep supervision during training, which introduces some computational overhead but gives slightly better performance.
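To get a feel for why '--block_list' and '--reduce_size' affect computation, the sketch below counts attention-matrix entries per transformer block at each level. It assumes each downsampling halves the feature map side, keys and values are reduced to reduce_size x reduce_size, and block_list is given as digits (either '1234' or '3,4'). `attention_cost` is a hypothetical helper, and total cost also scales with the number of blocks set by '--num_blocks'.

```python
def attention_cost(crop_size, block_list, reduce_size):
    """Attention-matrix entries per transformer block at each level.

    Assumes the feature map after L downsamplings is crop_size / 2**L per side
    and keys/values are reduced to reduce_size x reduce_size.
    Returns {level: (efficient_entries, full_attention_entries)}.
    """
    costs = {}
    for level in (int(c) for c in block_list if c.isdigit()):
        side = crop_size // 2 ** level
        n = side * side                                 # number of query tokens
        costs[level] = (n * reduce_size ** 2, n * n)    # (efficient, full)
    return costs


for level, (eff, full) in attention_cost(256, '1234', 8).items():
    print(level, eff, full)
# 1 1048576 268435456
# 2 262144 16777216
# 3 65536 1048576
# 4 16384 65536
```

Under these assumptions, raising reduce_size from 8 to 16 multiplies the efficient-attention cost at every level by 4, and level 1 dominates, which matches the advice that blocks at higher resolutions are more expensive.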

Here are some recommended parameter settings:
```
--block_list 1234 --num_blocks 1,1,1,1
```
Our default and most efficient setting. Suitable for tasks with limited training data, where most errors occur at the boundary of the ROI and high-resolution information is important.

```
--block_list 1234 --num_blocks 1,1,4,8
```
Similar to the previous one, but with larger model capacity since more transformer blocks are included; it needs a larger dataset for training.

```
--block_list 234 --num_blocks 2,4,8
```
Suitable for tasks that have complex context and where errors occur inside the ROI. More transformer blocks can help learn higher-level relationships.


Feel free to try other combinations of hyperparameters such as base_chan, reduce_size, and the number of blocks at each level to trade off capacity against efficiency for your own tasks and datasets.

#### TransUNet
We borrow code from the original TransUNet repo and fit it into our training framework. If you want to use pre-trained weights, please download them from the [original repo](https://github.com/Beckschen/TransUNet). The configuration is not parsed from the command line, so to change the TransUNet configuration, you need to change it inside train_deep.py.
```
python train_deep.py -m TransUNet -u EXP_NAME --data_path YOUR_OWN_PATH --gpu 0
```

#### ResNet50-UTNet
For a fair comparison with TransUNet, we implement the efficient attention proposed in UTNet on a ResNet50 backbone, which basically appends transformer blocks to specified levels after the ResNet blocks. ResNet50-UTNet performs slightly better than the default UTNet on the M&Ms dataset.
```
python train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 123 --num_blocks 1,1,1 --gpu 0
```
Similar to UTNet, this is the most efficient setting, suitable for tasks with limited training data.
```
--block_list 23 --num_blocks 2,4
```
Suitable for tasks that have complex context and where errors occur inside the ROI. More transformer blocks can help learn higher-level relationships.


#### ResNet50-UNet
If you don't use Transformer blocks in ResNet50-UTNet, it is effectively ResNet50-UNet. You can use it as a baseline to measure the performance improvement from Transformer blocks, for a fair comparison with TransUNet and our UTNet.

```
python train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --block_list ''  --gpu 0
```

#### SwinUNet
Download the pre-trained model from the [original repo](https://github.com/HuCaoFighting/Swin-Unet/tree/4375a8d6fa7d9c38184c5d3194db990a00a3e912).
As Swin-Transformer's input size is tied to the window size and is hard to change after pre-training, we set our input size to 224. Without pre-training, SwinUNet's performance is very low.
```
python train_deep.py -m SwinUNet -u EXP_NAME --data_path YOUR_OWN_PATH --crop_size 224
```
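One way to see why the input size is constrained: in a Swin-style encoder, the token grid at every stage must be divisible by the window size. The sketch below assumes Swin-T-style defaults (patch size 4, window size 7, four stages with 2x patch merging); these values are assumptions and should be checked against the configuration used by `model/swin_unet.py`. `swin_size_ok` is a hypothetical helper.

```python
def swin_size_ok(img_size, patch_size=4, window_size=7, num_stages=4):
    """Check that every stage's token grid is divisible by the window size.

    Assumes patch embedding by patch_size and 2x patch merging between stages.
    """
    if img_size % patch_size:
        return False
    side = img_size // patch_size        # token grid side after patch embedding
    for _ in range(num_stages):
        if side % window_size:
            return False
        side //= 2                       # patch merging halves the grid
    return True


print(swin_size_ok(224))  # True: grids of 56, 28, 14, 7
print(swin_size_ok(256))  # False: a 64 x 64 grid is not divisible by 7
```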

## Citation
If you find this repo helpful, please kindly cite our paper, thanks!
```
@inproceedings{gao2021utnet,
  title={UTNet: a hybrid transformer architecture for medical image segmentation},
  author={Gao, Yunhe and Zhou, Mu and Metaxas, Dimitris N},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={61--71},
  year={2021},
  organization={Springer}
}

@misc{gao2022datascalable,
      title={A Data-scalable Transformer for Medical Image Segmentation: Architecture, Model Efficiency, and Benchmark}, 
      author={Yunhe Gao and Mu Zhou and Di Liu and Zhennan Yan and Shaoting Zhang and Dimitris N. Metaxas},
      year={2022},
      eprint={2203.00131},
      archivePrefix={arXiv},
      primaryClass={eess.IV}
}

```



================================================
FILE: data_preprocess.py
================================================
import numpy as np
import SimpleITK as sitk

import os
import pdb 

def ResampleXYZAxis(imImage, space=(1., 1., 1.), interp=sitk.sitkLinear):
    identity1 = sitk.Transform(3, sitk.sitkIdentity)
    sp1 = imImage.GetSpacing()
    sz1 = imImage.GetSize()

    sz2 = (int(round(sz1[0]*sp1[0]*1.0/space[0])), int(round(sz1[1]*sp1[1]*1.0/space[1])), int(round(sz1[2]*sp1[2]*1.0/space[2])))

    imRefImage = sitk.Image(sz2, imImage.GetPixelIDValue())
    imRefImage.SetSpacing(space)
    imRefImage.SetOrigin(imImage.GetOrigin())
    imRefImage.SetDirection(imImage.GetDirection())

    imOutImage = sitk.Resample(imImage, imRefImage, identity1, interp)

    return imOutImage

def ResampleFullImageToRef(imImage, imRef, interp=sitk.sitkNearestNeighbor):
    identity1 = sitk.Transform(3, sitk.sitkIdentity)

    imRefImage = sitk.Image(imRef.GetSize(), imImage.GetPixelIDValue())
    imRefImage.SetSpacing(imRef.GetSpacing())
    imRefImage.SetOrigin(imRef.GetOrigin())
    imRefImage.SetDirection(imRef.GetDirection())


    imOutImage = sitk.Resample(imImage, imRefImage, identity1, interp)

    return imOutImage


def ResampleCMRImage(imImage, imLabel, save_path, patient_name, name, target_space=(1., 1.)):
    
    assert imImage.GetSpacing() == imLabel.GetSpacing()
    assert imImage.GetSize() == imLabel.GetSize()


    spacing = imImage.GetSpacing()
    origin = imImage.GetOrigin()


    npimg = sitk.GetArrayFromImage(imImage)
    nplab = sitk.GetArrayFromImage(imLabel)
    t, z, y, x = npimg.shape
    
    # makedirs also creates missing parent directories (e.g. dataset/Training)
    os.makedirs('%s/%s'%(save_path, patient_name), exist_ok=True)
    flag = 0
    for i in range(t):
        tmp_img = npimg[i]
        tmp_lab = nplab[i]
        
        if tmp_lab.max() == 0:
            continue

        
        print(i)
        flag += 1
        tmp_itkimg = sitk.GetImageFromArray(tmp_img)
        tmp_itkimg.SetSpacing(spacing[0:3])
        tmp_itkimg.SetOrigin(origin[0:3])
            
        tmp_itklab = sitk.GetImageFromArray(tmp_lab)
        tmp_itklab.SetSpacing(spacing[0:3])
        tmp_itklab.SetOrigin(origin[0:3])

        
        re_img = ResampleXYZAxis(tmp_itkimg, space=(target_space[0], target_space[1], spacing[2]))
        re_lab = ResampleFullImageToRef(tmp_itklab, re_img)

        
        sitk.WriteImage(re_img, '%s/%s/%s_%d.nii.gz'%(save_path, patient_name, name, flag))
        sitk.WriteImage(re_lab, '%s/%s/%s_gt_%d.nii.gz'%(save_path, patient_name, name, flag))
        
    return flag


if __name__ == '__main__':
    

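    src_path = 'OpenDataset/Training/Labeled'
    # Resolve the target path before changing directories: the loop below
    # chdirs into each patient folder, so a relative path would point there.
    tgt_path = os.path.abspath('dataset/Training')

    os.chdir(src_path)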
    for name in os.listdir('.'):
        os.chdir(name)

        for i in os.listdir('.'):
            if 'gt' in i:
                tmp = i.split('_')
                img_name = tmp[0] + '_' + tmp[1]
                patient_name = tmp[0]
                
                img = sitk.ReadImage('%s.nii.gz'%img_name)
                lab = sitk.ReadImage('%s_gt.nii.gz'%img_name)
                
                flag = ResampleCMRImage(img, lab, tgt_path, patient_name, img_name, (1.2, 1.2))
                
                print(name, 'done', flag)

        os.chdir('..')






================================================
FILE: dataset/Testing_A.csv
================================================
name
A2L1N6
B0H7V0
B3F0V9
B5F8L9
B9H8N8
C4E9I1
C6E0F9
C6U4W8
D1H6U2
D9F5P1
E0J7L9
F0K4T6
G7S6V0
G9N5V9
H1N8S6
I2J6Z6
I5L3S2
J4J8Q3
K6N4N7
P3P9S5


================================================
FILE: dataset/Testing_B.csv
================================================
name
A2H5K9
A4A8V9
A5C2D2
A5D0G0
A6J0Y2
A7E4J0
B1G9J3
B3S2Z4
B4E1K1
B4S1Y2
B5L5Y4
B5T6V0
B7F5P0
C7L8Z8
C8I7P7
D1R0Y5
D3Q0W9
D6E9U8
D8O0W2
D9I8O7
E1L8Y4
E3L8U8
E4I9O7
E4O8P3
E5S7W7
E6J4N8
G1K1V3
G3M5S4
G7N8R7
G7Q2W0
H7K5U5
H8K2K7
I0I2J8
I6T4W8
I7W4Y8
J9L4S2
K5K6N1
K7L2Y6
K9N0W0
L7Y7Z2
L8N7Z0
M4T4V6
M6V2Y0
N9P5Z0
O4T6Y7
O9V8W5
P3R6Y5
P8W4Z0
Q3Q6R8
Y6Y9Z2


================================================
FILE: dataset/Testing_C.csv
================================================
name
A3H5R1
A5H1Q2
A8C5E9
A9F3T5
A9L7Y7
B0L3Y2
B2L0L2
B8P5Q9
C0N8P4
C7M6W0
C8J7L5
C8O0P2
D1H2O9
D2U0V0
D3K5Q2
D5G3W8
D6N7Q8
D7M8P9
D7T3V8
E3F2U7
E3F5U2
E6M6P2
E7L0N6
E9V9Z2
F0I6U8
F1K2S9
F5I1Z8
G8R0Z9
H2M9S1
H3R6S9
H7L8R8
H7P5Z4
I4R8V6
I6P4R0
I8Z0Z6
J6K4V3
K3R0Y7
K7O3Q0
L2V5Z0
L8N7P0
M6M9N1
O4O6U5
P3T5U1
Q1Q3T1
Q5V8W3
R1R6Y8
R2R7Z5
R6V5W3
R8V0Y4
V4W8Z5


================================================
FILE: dataset/Testing_D.csv
================================================
name
A1K2P5
A3P9V7
A4B9O6
A4K8R4
A4R4T0
A5P5W0
A5Q1W8
A6A8H0
A6B7Y4
A7F4G2
B3E2W8
B6I0T4
B9G4U2
C0L7V1
C5L0R0
D1L6T4
D1S5T8
E0J2Z9
E1L7M3
E4H7L4
E5J6L2
E6H0V9
F6J9L9
F9M1R2
G1J5K3
G4I7V2
G8K0M3
G8L0Z0
H6P7T1
I4L4V7
I8N8Y1
J6M5O2
K3P3Y6
K5L4S1
K5M7V5
K7N0R7
L5U7Y4
L6T2T5
L8M2U8
M2P5T8
N2O7U5
N7P3T8
N7W6Z8
N9Q4T8
O5U2U7
O7Q7U3
P8V0Y7
Q4W5Z8
R3V5W7
T2Z1Z9


================================================
FILE: dataset/Training_A.csv
================================================
name
A0S9V9
A1D9Z7
A1E9Q1
A2C0I1
A2N8V0
A3H1O5
A4B5U4
A4J4S4
A4U9V5
A6B5G9
A6D5F9
A7D9L8
A7M7P8
A7O4T6
B0I2Z0
B2C2Z7
B2D9M2
B2G5R2
B3O1S0
B3P3R1
B4O3V3
B8H5H6
B8J7R4
B9E0Q1
C0K1P0
C1K8P5
C2J0K3
C5M4S2
C6J5P1
C8P3S7
D0R0R9
D1J5P6
D1M1S6
D3D4Y5
D3F3O5
D3F9H9
D3O9U9
D4M3Q2
D4N6W6
D8E4F4
D9L1Z3
E0M3U7
E0O0S0
E9H1U4
E9V4Z8
F2H5S1
F3G5K5
G4S9U3
G5P4U3
G7I5V7
H1I3W0
H5N0P0
I7T3U1
J1T9Y1
J6K6P5
J6P5T8
J8R5W2
K4T7Y0
K5L2U3
K5P0Y1
L1Q9V8
L4Q2U3
M1R4S1
M2P1R1
N5S7Y1
N8N9U0
O0S9V7
P0S5Y0
P5R1Y4
Q0U0V5
Q3R9W7
R4Y1Z9
S1S3Z7
T2T9Z9
T9U9W2


================================================
FILE: dataset/Training_B.csv
================================================
name
A1D0Q7
A1O8Z3
A3B7E5
A5E0T8
A6M1Q7
A7G0P5
A8C9U8
A8E1F4
A9C5P4
A9E3G9
A9J5Q7
A9J8W7
B0N3W8
B2D9O2
B2F4K5
B3D0N1
B6D0U7
B9O1Q0
C0S7W0
C1G5Q0
C2L5P7
C2M6P8
C3I2K3
C4R8T7
C4S8W9
D0H9I4
D1L4Q9
D6H6O2
E3T0Z2
E4M2Q7
E4W8Z7
E5E6O8
E5F5V7
E9H2K7
E9L1W5
F0J2R8
F1F3I6
F4K3S1
F5I9Q2
F8N2S1
G0H4J3
G0I6P3
G1N6S7
G2J1M5
G2M7W4
G2O2S6
G4L8Z7
G8N2U5
G9L0O9
H0K3Q4
H1J5W8
H1M5Y6
H1W2Y1
H3U1Y1
H4I2T8
H6I0I6
H7I4J3
H7N4V9
I0J5U3
I2K2Y8
I6N3P3
J4J9W6
J9L6N9
K2S1U6
L1Q1Z5
L5Q6T7
M0P8U8
M4P7Q6
N1P8Q9
N7V9W9
O3R8Y5
P6U0Y0
P9S7W2
Q7V1Y5
W5Z4Z8


================================================
FILE: dataset/Training_C.csv
================================================
name


================================================
FILE: dataset/Training_D.csv
================================================
name


================================================
FILE: dataset/info.csv
================================================
External code,VendorName,Vendor,Centre,ED,ES
A0S9V9,Siemens,A,1,0,9
A1D0Q7,Philips,B,2,0,9
A1D9Z7,Siemens,A,1,22,11
A1E9Q1,Siemens,A,1,0,9
A1K2P5,Canon,D,5,33,11
A1O8Z3,Philips,B,3,23,10
A2C0I1,Siemens,A,1,0,7
A2E3W4,GE,C,4,0,0
A2H5K9,Philips,B,2,29,8
A2L1N6,Siemens,A,1,0,12
A2N8V0,Siemens,A,1,0,9
A3B7E5,Philips,B,2,29,12
A3H1O5,Siemens,A,1,0,12
A3H5R1,GE,C,4,24,6
A3P9V7,Canon,D,5,27,13
A4A8V9,Philips,B,3,0,10
A4B5U4,Siemens,A,1,0,10
A4B9O6,Canon,D,5,0,11
A4J4S4,Siemens,A,1,0,7
A4K8R4,Canon,D,5,24,11
A4R4T0,Canon,D,5,21,8
A4U9V5,Siemens,A,1,0,8
A5C2D2,Philips,B,3,24,9
A5D0G0,Philips,B,2,0,10
A5E0T8,Philips,B,3,24,8
A5H1Q2,GE,C,4,0,9
A5P5W0,Canon,D,5,33,10
A5Q1W8,Canon,D,5,27,10
A6A8H0,Canon,D,5,27,8
A6B5G9,Siemens,A,1,0,11
A6B7Y4,Canon,D,5,1,10
A6D5F9,Siemens,A,1,0,11
A6J0Y2,Philips,B,2,27,13
A6M1Q7,Philips,B,2,29,11
A7D9L8,Siemens,A,1,0,11
A7E4J0,Philips,B,3,1,12
A7F4G2,Canon,D,5,25,10
A7G0P5,Philips,B,2,28,9
A7M7P8,Siemens,A,1,0,9
A7O4T6,Siemens,A,1,0,10
A8C5E9,GE,C,4,24,9
A8C9H8,GE,C,4,0,0
A8C9U8,Philips,B,2,28,9
A8E1F4,Philips,B,3,24,9
A8I1U6,GE,C,4,0,0
A9C5P4,Philips,B,2,29,8
A9E3G9,Philips,B,3,23,8
A9F3T5,GE,C,4,24,8
A9J5Q7,Philips,B,3,24,7
A9J8W7,Philips,B,2,29,10
A9L7Y7,GE,C,4,24,5
B0H7V0,Siemens,A,1,0,9
B0I2Z0,Siemens,A,1,0,8
B0L3Y2,GE,C,4,24,10
B0N3W8,Philips,B,2,28,15
B1G9J3,Philips,B,2,29,10
B1K7U1,GE,C,4,0,0
B2C2Z7,Siemens,A,1,0,8
B2D9M2,Siemens,A,1,0,8
B2D9O2,Philips,B,2,29,13
B2F4K5,Philips,B,2,29,10
B2G5R2,Siemens,A,1,0,7
B2L0L2,GE,C,4,24,8
B3D0N1,Philips,B,3,24,8
B3E2W8,Canon,D,5,23,9
B3F0V9,Siemens,A,1,0,12
B3O1S0,Siemens,A,1,0,8
B3P3R1,Siemens,A,1,0,11
B3S2Z4,Philips,B,2,28,11
B4E1K1,Philips,B,3,0,9
B4I8Z7,GE,C,4,0,0
B4O3V3,Siemens,A,1,0,6
B4S1Y2,Philips,B,2,29,10
B5F8L9,Siemens,A,1,0,8
B5L5Y4,Philips,B,3,24,10
B5T6V0,Philips,B,3,0,10
B6D0U7,Philips,B,2,29,9
B6I0T4,Canon,D,5,33,10
B7F5P0,Philips,B,3,29,12
B8H5H6,Siemens,A,1,24,8
B8J7R4,Siemens,A,1,0,12
B8P5Q9,GE,C,4,23,9
B9E0Q1,Siemens,A,1,0,11
B9G4U2,Canon,D,5,28,8
B9H8N8,Siemens,A,1,0,10
B9O1Q0,Philips,B,2,29,11
C0K1P0,Siemens,A,1,0,9
C0L7V1,Canon,D,5,33,14
C0N8P4,GE,C,4,24,9
C0S7W0,Philips,B,3,24,7
C0W2Y5,GE,C,4,0,0
C1G5Q0,Philips,B,3,24,8
C1K8P5,Siemens,A,1,0,8
C2J0K3,Siemens,A,1,0,9
C2L5P7,Philips,B,2,28,12
C2M6P8,Philips,B,3,24,8
C3I2K3,Philips,B,2,29,11
C4E9I1,Siemens,A,1,0,8
C4K8M1,GE,C,4,0,0
C4R8T7,Philips,B,2,28,8
C4S8W9,Philips,B,3,24,8
C5L0R0,Canon,D,5,29,12
C5M4S2,Siemens,A,1,0,9
C5Q2Y5,GE,C,4,0,0
C6E0F9,Siemens,A,1,0,8
C6J5P1,Siemens,A,1,0,10
C6U4W8,Siemens,A,1,0,9
C7L8Z8,Philips,B,2,29,9
C7M6W0,GE,C,4,24,8
C8I7P7,Philips,B,2,29,11
C8J7L5,GE,C,4,24,7
C8O0P2,GE,C,4,24,8
C8P3S7,Siemens,A,1,0,8
C8V5W8,GE,C,4,0,0
D0E2W0,GE,C,4,0,0
D0H9I4,Philips,B,2,0,10
D0R0R9,Siemens,A,1,0,11
D1H2O9,GE,C,4,24,8
D1H6U2,Siemens,A,1,0,8
D1J5P6,Siemens,A,1,0,13
D1L4Q9,Philips,B,2,29,10
D1L6T4,Canon,D,5,27,10
D1M1S6,Siemens,A,1,0,10
D1R0Y5,Philips,B,3,24,9
D1S5T8,Canon,D,5,23,14
D2U0V0,GE,C,4,24,10
D3D4Y5,Siemens,A,1,0,9
D3F3O5,Siemens,A,1,24,10
D3F9H9,Siemens,A,1,24,8
D3K5Q2,GE,C,4,11,0
D3O9U9,Siemens,A,1,0,10
D3Q0W9,Philips,B,3,24,10
D4M3Q2,Siemens,A,1,0,9
D4N6W6,Siemens,A,1,0,10
D5G3W8,GE,C,4,23,8
D6E9U8,Philips,B,2,0,12
D6H6O2,Philips,B,2,29,9
D6N7Q8,GE,C,4,24,9
D7M8P9,GE,C,4,24,8
D7T3V8,GE,C,4,24,7
D8E4F4,Siemens,A,1,0,10
D8O0W2,Philips,B,3,24,8
D9F5P1,Siemens,A,1,0,10
D9I8O7,Philips,B,3,29,11
D9L1Z3,Siemens,A,1,0,12
E0J2Z9,Canon,D,5,25,11
E0J7L9,Siemens,A,1,0,9
E0M3U7,Siemens,A,1,0,9
E0O0S0,Siemens,A,1,0,11
E1L7M3,Canon,D,5,1,12
E1L8Y4,Philips,B,3,24,10
E3F2U7,GE,C,4,0,7
E3F5U2,GE,C,4,24,8
E3I4V1,GE,C,4,0,0
E3L8U8,Philips,B,3,29,9
E3T0Z2,Philips,B,2,29,10
E4H7L4,Canon,D,5,28,10
E4I9O7,Philips,B,3,29,10
E4M2Q7,Philips,B,3,0,8
E4O8P3,Philips,B,2,29,12
E4W8Z7,Philips,B,2,29,10
E5E6O8,Philips,B,2,29,10
E5F5V7,Philips,B,3,0,9
E5J6L2,Canon,D,5,29,11
E5S7W7,Philips,B,2,28,12
E6H0V9,Canon,D,5,33,12
E6J4N8,Philips,B,2,29,11
E6M6P2,GE,C,4,24,7
E7L0N6,GE,C,4,22,8
E9H1U4,Siemens,A,1,0,10
E9H2K7,Philips,B,2,29,11
E9L1W5,Philips,B,3,24,9
E9L4N2,GE,C,4,0,0
E9V4Z8,Siemens,A,1,0,8
E9V9Z2,GE,C,4,28,10
F0I6U8,GE,C,4,10,23
F0J2R8,Philips,B,2,29,9
F0K4T6,Siemens,A,1,0,8
F1F3I6,Philips,B,2,28,8
F1K2S9,GE,C,4,23,7
F2H5S1,Siemens,A,1,24,10
F3G5K5,Siemens,A,1,24,12
F3L0M1,GE,C,4,0,0
F4K3S1,Philips,B,3,24,8
F5I1Z8,GE,C,4,23,8
F5I9Q2,Philips,B,2,29,11
F6J9L9,Canon,D,5,29,10
F8N2S1,Philips,B,3,24,8
F9M1R2,Canon,D,5,29,11
G0H4J3,Philips,B,2,29,10
G0I6P3,Philips,B,3,29,12
G0I7Z6,GE,C,4,0,0
G1J5K3,Canon,D,5,24,13
G1K1V3,Philips,B,2,29,11
G1N6S7,Philips,B,2,0,12
G2J1M5,Philips,B,2,29,9
G2M7W4,Philips,B,3,24,9
G2O2S6,Philips,B,2,29,10
G3M5S4,Philips,B,3,29,10
G4I7V2,Canon,D,5,33,9
G4K8P3,GE,C,4,0,0
G4L8Z7,Philips,B,2,29,8
G4S9U3,Siemens,A,1,0,11
G4U3U8,GE,C,4,0,0
G5P4U3,Siemens,A,1,0,9
G6T0Z6,GE,C,4,0,0
G7I5V7,Siemens,A,1,0,10
G7N8R7,Philips,B,2,28,9
G7Q2W0,Philips,B,2,0,8
G7S6V0,Siemens,A,1,0,8
G8K0M3,Canon,D,5,25,10
G8L0Z0,Canon,D,5,3,10
G8N2U5,Philips,B,3,23,9
G8R0Z9,GE,C,4,21,8
G9L0O9,Philips,B,2,29,10
G9N5V9,Siemens,A,1,0,11
H0K3Q4,Philips,B,3,24,9
H1I3W0,Siemens,A,1,0,9
H1J5W8,Philips,B,2,3,15
H1M5Y6,Philips,B,2,29,8
H1N7P7,GE,C,4,0,0
H1N8S6,Siemens,A,1,0,10
H1W2Y1,Philips,B,2,29,10
H2M9S1,GE,C,4,23,8
H3R6S9,GE,C,4,24,9
H3U1Y1,Philips,B,3,24,8
H4I2T8,Philips,B,3,24,8
H5N0P0,Siemens,A,1,0,8
H6I0I6,Philips,B,2,28,10
H6P7T1,Canon,D,5,25,12
H7I4J3,Philips,B,3,24,8
H7K5U5,Philips,B,3,24,8
H7L8R8,GE,C,4,23,9
H7N4V9,Philips,B,2,0,12
H7P5Z4,GE,C,4,24,9
H8K2K7,Philips,B,3,23,7
H9J6L5,GE,C,4,0,0
I0I2J8,Philips,B,2,21,7
I0J5U3,Philips,B,2,28,10
I2J6Z6,Siemens,A,1,0,7
I2K2Y8,Philips,B,2,29,8
I4J8P4,GE,C,4,0,0
I4L4V7,Canon,D,5,22,9
I4R8V6,GE,C,4,24,8
I5L3S2,Siemens,A,1,0,10
I6N3P3,Philips,B,2,29,10
I6P4R0,GE,C,4,18,11
I6T4W8,Philips,B,2,28,12
I7T3U1,Siemens,A,1,0,13
I7W4Y8,Philips,B,2,27,10
I8N8Y1,Canon,D,5,33,14
I8Z0Z6,GE,C,4,24,8
J1T9Y1,Siemens,A,1,0,12
J2S5T1,GE,C,4,0,0
J4J8Q3,Siemens,A,1,0,8
J4J9W6,Philips,B,2,29,11
J6K4V3,GE,C,4,21,3
J6K6P5,Siemens,A,1,0,9
J6M5O2,Canon,D,5,33,11
J6P5T8,Siemens,A,1,0,9
J8R5W2,Siemens,A,1,0,9
J9L4S2,Philips,B,3,29,13
J9L6N9,Philips,B,2,29,9
K2S1U6,Philips,B,2,29,12
K3P3Y6,Canon,D,5,32,11
K3R0Y7,GE,C,4,28,9
K4T7Y0,Siemens,A,1,24,12
K5K6N1,Philips,B,3,24,9
K5L2U3,Siemens,A,1,0,7
K5L4S1,Canon,D,5,27,10
K5M7V5,Canon,D,5,23,10
K5P0Y1,Siemens,A,1,0,9
K6N4N7,Siemens,A,1,0,8
K7L2Y6,Philips,B,2,29,10
K7N0R7,Canon,D,5,29,9
K7O3Q0,GE,C,4,23,10
K9N0W0,Philips,B,2,29,12
L1Q1Z5,Philips,B,2,0,15
L1Q9V8,Siemens,A,1,0,10
L2V5Z0,GE,C,4,3,17
L4Q2U3,Siemens,A,1,0,8
L5Q6T7,Philips,B,2,28,11
L5U7Y4,Canon,D,5,33,10
L6T2T5,Canon,D,5,1,18
L7Y7Z2,Philips,B,2,0,11
L8M2U8,Canon,D,5,0,11
L8N7P0,GE,C,4,23,9
L8N7Z0,Philips,B,3,29,9
M0P8U8,Philips,B,3,0,9
M1R4S1,Siemens,A,1,0,9
M2P1R1,Siemens,A,1,0,10
M2P5T8,Canon,D,5,0,11
M4P7Q6,Philips,B,2,2,13
M4T4V6,Philips,B,2,29,12
M6M9N1,GE,C,4,24,11
M6V2Y0,Philips,B,3,0,9
M8N3Z3,GE,C,4,0,0
N1P8Q9,Philips,B,2,29,10
N1S7Z2,GE,C,4,0,0
N2O7U5,Canon,D,5,18,7
N5S7Y1,Siemens,A,1,0,9
N7P3T8,Canon,D,5,25,10
N7V9W9,Philips,B,3,24,7
N7W6Z8,Canon,D,5,23,11
N8N9U0,Siemens,A,1,0,11
N9P5Z0,Philips,B,3,0,13
N9Q4T8,Canon,D,5,32,10
O0S9V7,Siemens,A,1,0,9
O1O9Y6,GE,C,4,0,0
O3R8Y5,Philips,B,2,0,8
O4O6U5,GE,C,4,24,8
O4T6Y7,Philips,B,2,0,13
O5U2U7,Canon,D,5,1,9
O7Q7U3,Canon,D,5,29,11
O9V8W5,Philips,B,2,29,12
P0S5Y0,Siemens,A,1,0,10
P3P9S5,Siemens,A,1,0,10
P3R6Y5,Philips,B,3,24,9
P3T5U1,GE,C,4,27,9
P5R1Y4,Siemens,A,1,0,9
P6U0Y0,Philips,B,2,29,10
P8V0Y7,Canon,D,5,27,9
P8W4Z0,Philips,B,3,0,12
P9S7W2,Philips,B,3,23,10
Q0Q1Y4,GE,C,4,0,0
Q0U0V5,Siemens,A,1,0,10
Q1Q3T1,GE,C,4,28,10
Q3Q6R8,Philips,B,3,24,9
Q3R9W7,Siemens,A,1,0,10
Q4W5Z8,Canon,D,5,33,10
Q5V8W3,GE,C,4,24,10
Q7V1Y5,Philips,B,2,0,11
R1R6Y8,GE,C,4,23,10
R2R7Z5,GE,C,4,23,8
R3V5W7,Canon,D,5,27,13
R4Y1Z9,Siemens,A,1,24,7
R6V5W3,GE,C,4,28,10
R8V0Y4,GE,C,4,24,7
S1S3Z7,Siemens,A,1,0,9
T2T9Z9,Siemens,A,1,0,13
T2Z1Z9,Canon,D,5,29,9
T9U9W2,Siemens,A,1,0,10
V4W8Z5,GE,C,4,19,9
W5Z4Z8,Philips,B,2,29,11
Y6Y9Z2,Philips,B,3,29,9


================================================
FILE: dataset_domain.py
================================================
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import SimpleITK as sitk
from skimage.measure import label, regionprops
import math
import pdb

class CMRDataset(Dataset):
    def __init__(self, dataset_dir, mode='train', domain='A', crop_size=256, scale=0.1, rotate=10, debug=False):

        self.mode = mode
        self.dataset_dir = dataset_dir
        self.crop_size = crop_size
        self.scale = scale
        self.rotate = rotate

        if self.mode == 'train':
            pre_face = 'Training'
            if 'C' in domain or 'D' in domain:
                # StandardError does not exist in Python 3; raise a real exception
                raise ValueError('No domain C or D in Training set')

        elif self.mode == 'test':
            pre_face = 'Testing'

        else:
            raise ValueError('Wrong mode: %s' % mode)
        if debug:
            # The testing set is smaller, so it loads faster when debugging.
            pre_face = 'Testing'

        path = self.dataset_dir + pre_face + '/'
        print('start loading data')
        
        name_list = []

        if 'A' in domain:
            df = pd.read_csv(self.dataset_dir+pre_face+'_A.csv')
            name_list += np.array(df['name']).tolist()
        if 'B' in domain:
            df = pd.read_csv(self.dataset_dir+pre_face+'_B.csv')
            name_list += np.array(df['name']).tolist()
        if 'C' in domain:
            df = pd.read_csv(self.dataset_dir+pre_face+'_C.csv')
            name_list += np.array(df['name']).tolist()
        if 'D' in domain:
            df = pd.read_csv(self.dataset_dir+pre_face+'_D.csv')
            name_list += np.array(df['name']).tolist()



        
        img_list = []
        lab_list = []
        spacing_list = []

        for name in name_list:
            for name_idx in os.listdir(path+name):
                if 'gt' in name_idx:
                    continue
                else:
                    idx = name_idx.split('_')[2].split('.')[0]
                    
                    itk_img = sitk.ReadImage(path+name+'/%s_sa_%s.nii.gz'%(name, idx))
                    itk_lab = sitk.ReadImage(path+name+'/%s_sa_gt_%s.nii.gz'%(name, idx))
                    
                    spacing = np.array(itk_lab.GetSpacing()).tolist()
                    spacing_list.append(spacing[::-1])

                    assert itk_img.GetSize() == itk_lab.GetSize()
                    img, lab = self.preprocess(itk_img, itk_lab)

                    img_list.append(img)
                    lab_list.append(lab)

       
        self.img_slice_list = []
        self.lab_slice_list = []
        if self.mode == 'train':
            for i in range(len(img_list)):
                tmp_img = img_list[i]
                tmp_lab = lab_list[i]

                z, x, y = tmp_img.shape

                for j in range(z):
                    self.img_slice_list.append(tmp_img[j])
                    self.lab_slice_list.append(tmp_lab[j])

        else:
            self.img_slice_list = img_list
            self.lab_slice_list = lab_list
            self.spacing_list = spacing_list

        print('load done, length of dataset:', len(self.img_slice_list))
        
    def __len__(self):
        return len(self.img_slice_list)

    def preprocess(self, itk_img, itk_lab):
        
        img = sitk.GetArrayFromImage(itk_img)
        lab = sitk.GetArrayFromImage(itk_lab)

        max98 = np.percentile(img, 98)
        img = np.clip(img, 0, max98)
            
        z, y, x = img.shape
        if x < self.crop_size:
            diff = (self.crop_size + 10 - x) // 2
            img = np.pad(img, ((0,0), (0,0), (diff, diff)))
            lab = np.pad(lab, ((0,0), (0,0), (diff,diff)))
        if y < self.crop_size:
            diff = (self.crop_size + 10 -y) // 2
            img = np.pad(img, ((0,0), (diff, diff), (0,0)))
            lab = np.pad(lab, ((0,0), (diff, diff), (0,0)))

        img = img / max98
        tensor_img = torch.from_numpy(img).float()
        tensor_lab = torch.from_numpy(lab).long()

        return tensor_img, tensor_lab


    def __getitem__(self, idx):
        tensor_image = self.img_slice_list[idx]
        tensor_label = self.lab_slice_list[idx]
       
        if self.mode == 'train':
            tensor_image = tensor_image.unsqueeze(0).unsqueeze(0)
            tensor_label = tensor_label.unsqueeze(0).unsqueeze(0)
            
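            # Gaussian noise. Out-of-place ops are used because tensor_image is a
            # view of the cached slice; "+=" would accumulate noise across epochs.
            tensor_image = tensor_image + torch.randn(tensor_image.shape) * 0.02
            # Additive brightness
            rnd_bn = np.random.normal(0, 0.7)#0.03
            tensor_image = tensor_image + rnd_bn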
            # gamma
            minm = tensor_image.min()
            rng = tensor_image.max() - minm
            gamma = np.random.uniform(0.5, 1.6)
            tensor_image = torch.pow((tensor_image-minm)/rng, gamma)*rng + minm

            tensor_image, tensor_label = self.random_zoom_rotate(tensor_image, tensor_label)
            tensor_image, tensor_label = self.randcrop(tensor_image, tensor_label)
        else:
            tensor_image, tensor_label = self.center_crop(tensor_image, tensor_label)
        
        assert tensor_image.shape == tensor_label.shape
        
        if self.mode == 'train':
            return tensor_image, tensor_label
        else:
            return tensor_image, tensor_label, np.array(self.spacing_list[idx])

    def randcrop(self, img, label):
        _, _, H, W = img.shape
        
        diff_H = H - self.crop_size
        diff_W = W - self.crop_size
        
        # randint's upper bound is exclusive; +1 also avoids an error when diff is 0
        rand_x = np.random.randint(0, diff_H + 1)
        rand_y = np.random.randint(0, diff_W + 1)
        
        croped_img = img[0, :, rand_x:rand_x+self.crop_size, rand_y:rand_y+self.crop_size]
        croped_lab = label[0, :, rand_x:rand_x+self.crop_size, rand_y:rand_y+self.crop_size]

        return croped_img, croped_lab


    def center_crop(self, img, label):
        D, H, W = img.shape
        
        diff_H = H - self.crop_size
        diff_W = W - self.crop_size
        
        rand_x = diff_H // 2
        rand_y = diff_W // 2

        croped_img = img[:, rand_x:rand_x+self.crop_size, rand_y:rand_y+self.crop_size]
        croped_lab = label[:, rand_x:rand_x+self.crop_size, rand_y:rand_y+self.crop_size]

        return croped_img, croped_lab

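    def random_zoom_rotate(self, img, label):
        # img, label: (1, 1, H, W); scale factors are drawn from [1 - scale, 1 + scale)
        scale_x = np.random.random() * 2 * self.scale + (1 - self.scale)
        scale_y = np.random.random() * 2 * self.scale + (1 - self.scale)

        theta_scale = torch.tensor([[scale_x, 0, 0],
                                    [0, scale_y, 0],
                                    [0, 0, 1]]).float()
        # rotation angle in whole degrees from [-rotate, rotate), converted to radians
        angle = (float(np.random.randint(-self.rotate, self.rotate)) / 180.) * math.pi

        theta_rotate = torch.tensor([[math.cos(angle), -math.sin(angle), 0],
                                     [math.sin(angle), math.cos(angle), 0],
                                     ]).float()

        # compose rotation (2x3) and scaling (3x3) into a single 2x3 affine matrix
        theta = torch.matmul(theta_rotate, theta_scale).unsqueeze(0)
        grid = F.affine_grid(theta, img.size(), align_corners=False)
        img = F.grid_sample(img, grid, mode='bilinear', align_corners=False)
        # nearest sampling keeps label values valid class indices
        label = F.grid_sample(label.float(), grid, mode='nearest', align_corners=False).long()
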
        return img, label
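
# Sketch, not part of the original UTNet code: how the 2x3 rotation matrix and
# the 3x3 scaling matrix used in random_zoom_rotate compose into the single
# affine `theta` that F.affine_grid expects. theta = R @ S maps output grid
# coordinates to input coordinates. The helper names compose_affine_theta and
# warp_image_and_label are hypothetical.
import math

import torch
import torch.nn.functional as F


def compose_affine_theta(scale_x, scale_y, angle_rad):
    """Return a (1, 2, 3) theta combining a scaling and a rotation about the centre."""
    scale = torch.tensor([[scale_x, 0.0, 0.0],
                          [0.0, scale_y, 0.0],
                          [0.0, 0.0, 1.0]], dtype=torch.float32)
    rotate = torch.tensor([[math.cos(angle_rad), -math.sin(angle_rad), 0.0],
                           [math.sin(angle_rad), math.cos(angle_rad), 0.0]], dtype=torch.float32)
    # (2, 3) @ (3, 3) -> (2, 3), plus a batch dimension
    return (rotate @ scale).unsqueeze(0)


def warp_image_and_label(img, label, theta):
    """Warp a (1, C, H, W) image bilinearly and its label map with nearest sampling."""
    grid = F.affine_grid(theta, img.size(), align_corners=False)
    img_out = F.grid_sample(img, grid, mode='bilinear', align_corners=False)
    lab_out = F.grid_sample(label.float(), grid, mode='nearest', align_corners=False).long()
    return img_out, lab_out

# With scale 1 and angle 0, theta is the identity and both tensors come back unchanged.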




================================================
FILE: losses.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import pdb 

class DiceLoss(nn.Module):
    # Adaptive Tversky-style Dice loss. For each class c, the false-positive weight
    # alpha_c = clamp(FP_c / (FP_c + FN_c), 0.2, 0.8) and the false-negative weight
    # beta_c = 1 - alpha_c are recomputed on every forward pass, so the alpha/beta
    # constructor arguments do not affect the loss.

    def __init__(self, alpha=0.5, beta=0.5, size_average=True, reduce=True):
        super(DiceLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta

        self.size_average = size_average
        self.reduce = reduce

    def forward(self, preds, targets):
        # preds: (N, C, ...) logits; targets: (N, 1, ...) integer class map
        C = preds.size(1)

        P = F.softmax(preds, dim=1)
        smooth = torch.full((C,), 1e-5, dtype=torch.float32, device=preds.device)

        # one-hot encode the targets along the channel dimension
        class_mask = torch.zeros_like(P)
        class_mask.scatter_(1, targets, 1.)

        TP = P * class_mask
        FP = P * (1 - class_mask)
        FN = (1 - P) * class_mask

        # per-class soft counts, each of shape (C,)
        tp = TP.transpose(0, 1).reshape(C, -1).sum(dim=1)
        fp = FP.transpose(0, 1).reshape(C, -1).sum(dim=1)
        fn = FN.transpose(0, 1).reshape(C, -1).sum(dim=1)

        # adaptive per-class weights (local, so the module keeps no per-batch state)
        alpha = torch.clamp(fp / (fp + fn + smooth), min=0.2, max=0.8)
        beta = 1 - alpha

        dice = tp / (tp + alpha * fp + beta * fn + smooth)

        if not self.reduce:
            return 1 - dice

        loss = (1 - dice).sum()

        if self.size_average:
            loss = loss / C

        return loss
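
# Sketch, not part of the original repo: the per-class quantity DiceLoss computes
# is an adaptive Tversky index
#     TI_c = TP_c / (TP_c + alpha_c * FP_c + beta_c * FN_c),
# with alpha_c = clamp(FP_c / (FP_c + FN_c), 0.2, 0.8) and beta_c = 1 - alpha_c.
# With alpha_c = beta_c = 0.5 it reduces to the soft Dice coefficient
# 2*TP / (2*TP + FP + FN). `adaptive_tversky_index` is a hypothetical helper name.
import torch


def adaptive_tversky_index(tp, fp, fn, smooth=1e-5, lo=0.2, hi=0.8):
    """tp, fp, fn: 1-D tensors of per-class soft counts; returns TI per class."""
    alpha = torch.clamp(fp / (fp + fn + smooth), min=lo, max=hi)
    beta = 1 - alpha
    return tp / (tp + alpha * fp + beta * fn + smooth)

# Class 0 below (TP=8, FP=2, FN=6) gets alpha=0.25, beta=0.75, so TI = 8 / 13;
# class 1 (TP=FP=FN=5) gets alpha=beta=0.5, so TI equals the Dice score 0.5.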

class FocalLoss(nn.Module):
    # Multi-class focal loss: -alpha_t * (1 - p_t)^gamma * log(p_t) per pixel/voxel
    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()

        if alpha is None:
            self.alpha = torch.ones(class_num)
        else:
            self.alpha = torch.as_tensor(alpha, dtype=torch.float32)

        self.gamma = gamma
        self.size_average = size_average

    def forward(self, preds, targets):
        # preds: (N, C, ...) logits; targets: (N, ...) or (N, 1, ...) integer class map
        if targets.dim() == preds.dim() - 1:
            targets = targets.unsqueeze(1)
        P = F.softmax(preds, dim=1)
        log_P = F.log_softmax(preds, dim=1)

        class_mask = torch.zeros_like(preds)
        class_mask.scatter_(1, targets, 1.)

        # squeeze the channel dimension of the target: (N, 1, ...) -> (N, ...)
        targets = targets.squeeze(1)
        # per-pixel class weight; move the weight table to the device before indexing
        alpha = self.alpha.to(preds.device)[targets]

        probs = (P * class_mask).sum(1)
        log_probs = (log_P * class_mask).sum(1)
        
        batch_loss = -alpha * (1-probs).pow(self.gamma)*log_probs

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()

        return loss
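
# Sketch, not part of the original repo: FocalLoss scales the per-pixel
# cross-entropy -log(p_t) by (1 - p_t)^gamma, so confidently classified pixels
# contribute less. With gamma = 0 and unit class weights it reduces to plain
# cross-entropy. `focal_loss_reference` is a hypothetical helper name.
import torch
import torch.nn.functional as F


def focal_loss_reference(logits, target, gamma=2.0):
    """logits: (N, C, ...) raw scores; target: (N, ...) integer class map."""
    log_p = F.log_softmax(logits, dim=1)
    log_pt = log_p.gather(1, target.unsqueeze(1)).squeeze(1)  # log-prob of the true class
    pt = log_pt.exp()
    return (-(1 - pt) ** gamma * log_pt).mean()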

if __name__ == '__main__':
    
    DL = DiceLoss()
    FL = FocalLoss(10)
    
    pred = torch.randn(2, 10, 128, 128)
    target = torch.zeros((2, 1, 128, 128)).long()

    dl_loss = DL(pred, target)
    fl_loss = FL(pred, target)

    print('2D:', dl_loss.item(), fl_loss.item())

    pred = torch.randn(2, 10, 64, 128, 128)
    target = torch.zeros(2, 1, 64, 128, 128).long()

    dl_loss = DL(pred, target)
    fl_loss = FL(pred, target)

    print('3D:', dl_loss.item(), fl_loss.item())

    


================================================
FILE: model/__init__.py
================================================


================================================
FILE: model/conv_trans_utils.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import pdb



def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)

class depthwise_separable_conv(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1, kernel_size=3, padding=1, bias=False):
        super().__init__()
        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size=kernel_size, padding=padding, groups=in_ch, bias=bias, stride=stride)
        self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=bias)

    def forward(self, x): 
        out = self.depthwise(x)
        out = self.pointwise(out)

        return out 
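
# Sketch, not part of the original repo: how depthwise_separable_conv compares
# in size with a standard conv. For a k x k kernel (bias disabled), the
# depthwise conv has in_ch*k*k weights and the pointwise 1x1 conv has
# in_ch*out_ch, versus in_ch*out_ch*k*k for a standard conv.
# `separable_vs_standard_params` is a hypothetical helper name.
import torch.nn as nn


def separable_vs_standard_params(in_ch, out_ch, kernel_size=3):
    """Return (separable, standard) weight counts, both without bias."""
    depthwise = nn.Conv2d(in_ch, in_ch, kernel_size, padding=kernel_size // 2, groups=in_ch, bias=False)
    pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
    standard = nn.Conv2d(in_ch, out_ch, kernel_size, padding=kernel_size // 2, bias=False)

    def count(module):
        return sum(p.numel() for p in module.parameters())

    return count(depthwise) + count(pointwise), count(standard)

# e.g. in_ch=64, out_ch=192 (a to_qkv projection with 3 * 64 outputs):
# 576 + 12288 = 12864 weights versus 110592 for a standard 3x3 conv.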

class Mlp(nn.Module):
    def __init__(self, in_ch, hid_ch=None, out_ch=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_ch = out_ch or in_ch
        hid_ch = hid_ch or in_ch

        self.fc1 = nn.Conv2d(in_ch, hid_ch, kernel_size=1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hid_ch, out_ch, kernel_size=1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)

        return x

class BasicBlock(nn.Module):

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or inplanes != planes:
            self.shortcut = nn.Sequential(
                    nn.BatchNorm2d(inplanes),
                    self.relu, 
                    nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
                    )

    def forward(self, x): 
        residue = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out += self.shortcut(residue)

        return out

class BasicTransBlock(nn.Module):

    def __init__(self, in_ch, heads, dim_head, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_ch)

        self.attn = LinearAttention(in_ch, heads=heads, dim_head=in_ch//heads, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        
        self.bn2 = nn.BatchNorm2d(in_ch)
        self.relu = nn.ReLU(inplace=True)
        self.mlp = nn.Conv2d(in_ch, in_ch, kernel_size=1, bias=False)
        # a single conv1x1 performs on par with an MLP here

    def forward(self, x):

        out = self.bn1(x)
        out, q_k_attn = self.attn(out)
        
        out = out + x
        residue = out

        out = self.bn2(out)
        out = self.relu(out)
        out = self.mlp(out)

        out += residue

        return out

class BasicTransDecoderBlock(nn.Module):

    def __init__(self, in_ch, out_ch, heads, dim_head, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):
        super().__init__()

        self.bn_l = nn.BatchNorm2d(in_ch)
        self.bn_h = nn.BatchNorm2d(out_ch)

        self.conv_ch = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        self.attn = LinearAttentionDecoder(in_ch, out_ch, heads=heads, dim_head=out_ch//heads, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.mlp = nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False)

    def forward(self, x1, x2):

        residue = F.interpolate(self.conv_ch(x1), size=x2.shape[-2:], mode='bilinear', align_corners=True)
        #x1: low-res, x2: high-res
        x1 = self.bn_l(x1)
        x2 = self.bn_h(x2)

        out, q_k_attn = self.attn(x2, x1)
        
        out = out + residue
        residue = out

        out = self.bn2(out)
        out = self.relu(out)
        out = self.mlp(out)

        out += residue

        return out




########################################################################
# Transformer components

class LinearAttention(nn.Module):
    
    def __init__(self, dim, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):
        super().__init__()

        self.inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** (-0.5)
        self.dim_head = dim_head
        self.reduce_size = reduce_size
        self.projection = projection
        self.rel_pos = rel_pos
        
        # depthwise conv is slightly better than conv1x1
        #self.to_qkv = nn.Conv2d(dim, self.inner_dim*3, kernel_size=1, stride=1, padding=0, bias=True)
        #self.to_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, stride=1, padding=0, bias=True)
       
        self.to_qkv = depthwise_separable_conv(dim, self.inner_dim*3)
        self.to_out = depthwise_separable_conv(self.inner_dim, dim)


        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        if self.rel_pos:
            # 2D input-independent relative position encoding is slightly better than
            # the 1D input-dependent counterpart
            self.relative_position_encoding = RelativePositionBias(heads, reduce_size, reduce_size)
            #self.relative_position_encoding = RelativePositionEmbedding(dim_head, reduce_size)

    def forward(self, x):

        B, C, H, W = x.shape

        #B, inner_dim, H, W
        qkv = self.to_qkv(x)
        q, k, v = qkv.chunk(3, dim=1)

        if self.projection == 'interp' and H != self.reduce_size:
            k, v = map(lambda t: F.interpolate(t, size=self.reduce_size, mode='bilinear', align_corners=True), (k, v))

        elif self.projection == 'maxpool' and H != self.reduce_size:
            k, v = map(lambda t: F.adaptive_max_pool2d(t, output_size=self.reduce_size), (k, v))
        
        q = rearrange(q, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads, h=H, w=W)
        k, v = map(lambda t: rearrange(t, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads, h=self.reduce_size, w=self.reduce_size), (k, v))

        q_k_attn = torch.einsum('bhid,bhjd->bhij', q, k)
        
        if self.rel_pos:
            relative_position_bias = self.relative_position_encoding(H, W)
            q_k_attn += relative_position_bias
            #rel_attn_h, rel_attn_w = self.relative_position_encoding(q, self.heads, H, W, self.dim_head)
            #q_k_attn = q_k_attn + rel_attn_h + rel_attn_w

        q_k_attn *= self.scale
        q_k_attn = F.softmax(q_k_attn, dim=-1)
        q_k_attn = self.attn_drop(q_k_attn)

        out = torch.einsum('bhij,bhjd->bhid', q_k_attn, v)
        out = rearrange(out, 'b heads (h w) dim_head -> b (dim_head heads) h w', h=H, w=W, dim_head=self.dim_head, heads=self.heads)

        out = self.to_out(out)
        out = self.proj_drop(out)

        return out, q_k_attn
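
# Sketch, not part of the original repo: the core idea of LinearAttention in
# single-head form. Keys and values are resized to a fixed reduce_size x
# reduce_size grid, so the attention map is (H*W) x reduce_size^2 instead of
# (H*W) x (H*W). This sketch omits the multi-head split, the depthwise
# projections and the relative position bias; `reduced_kv_attention` is a
# hypothetical helper name.
import torch
import torch.nn.functional as F


def reduced_kv_attention(q, k, v, reduce_size):
    """q, k, v: (B, D, H, W) feature maps; returns (out, attn)."""
    B, D, H, W = q.shape
    k = F.interpolate(k, size=reduce_size, mode='bilinear', align_corners=True)
    v = F.interpolate(v, size=reduce_size, mode='bilinear', align_corners=True)
    q = q.flatten(2).transpose(1, 2)  # (B, H*W, D)
    k = k.flatten(2).transpose(1, 2)  # (B, r*r, D)
    v = v.flatten(2).transpose(1, 2)  # (B, r*r, D)
    attn = torch.softmax(q @ k.transpose(1, 2) * D ** -0.5, dim=-1)  # (B, H*W, r*r)
    out = (attn @ v).transpose(1, 2).reshape(B, D, H, W)
    return out, attn

# For a 32x32 map with reduce_size=8 the attention map holds 1024 x 64 entries
# per image instead of 1024 x 1024.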

class LinearAttentionDecoder(nn.Module):
    
    def __init__(self, in_dim, out_dim, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):
        super().__init__()

        self.inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** (-0.5)
        self.dim_head = dim_head
        self.reduce_size = reduce_size
        self.projection = projection
        self.rel_pos = rel_pos
        
        # depthwise conv is slightly better than conv1x1
        #self.to_kv = nn.Conv2d(dim, self.inner_dim*2, kernel_size=1, stride=1, padding=0, bias=True)
        #self.to_q = nn.Conv2d(dim, self.inner_dim, kernel_size=1, stride=1, padding=0, bias=True)
        #self.to_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, stride=1, padding=0, bias=True)
       
        self.to_kv = depthwise_separable_conv(in_dim, self.inner_dim*2)
        self.to_q = depthwise_separable_conv(out_dim, self.inner_dim)
        self.to_out = depthwise_separable_conv(self.inner_dim, out_dim)


        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        if self.rel_pos:
            self.relative_position_encoding = RelativePositionBias(heads, reduce_size, reduce_size)
            #self.relative_position_encoding = RelativePositionEmbedding(dim_head, reduce_size)

    def forward(self, q, x):

        B, C, H, W = x.shape # low-res feature shape
        BH, CH, HH, WH = q.shape # high-res feature shape

        k, v = self.to_kv(x).chunk(2, dim=1) #B, inner_dim, H, W
        q = self.to_q(q) #BH, inner_dim, HH, WH

        if self.projection == 'interp' and H != self.reduce_size:
            k, v = map(lambda t: F.interpolate(t, size=self.reduce_size, mode='bilinear', align_corners=True), (k, v))

        elif self.projection == 'maxpool' and H != self.reduce_size:
            k, v = map(lambda t: F.adaptive_max_pool2d(t, output_size=self.reduce_size), (k, v))
        
        q = rearrange(q, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads, h=HH, w=WH)
        k, v = map(lambda t: rearrange(t, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads, h=self.reduce_size, w=self.reduce_size), (k, v))

        q_k_attn = torch.einsum('bhid,bhjd->bhij', q, k)
        
        if self.rel_pos:
            relative_position_bias = self.relative_position_encoding(HH, WH)
            q_k_attn += relative_position_bias
            #rel_attn_h, rel_attn_w = self.relative_position_encoding(q, self.heads, HH, WH, self.dim_head)
            #q_k_attn = q_k_attn + rel_attn_h + rel_attn_w
       
        q_k_attn *= self.scale
        q_k_attn = F.softmax(q_k_attn, dim=-1)
        q_k_attn = self.attn_drop(q_k_attn)

        out = torch.einsum('bhij,bhjd->bhid', q_k_attn, v)
        out = rearrange(out, 'b heads (h w) dim_head -> b (dim_head heads) h w', h=HH, w=WH, dim_head=self.dim_head, heads=self.heads)

        out = self.to_out(out)
        out = self.proj_drop(out)

        return out, q_k_attn

class RelativePositionEmbedding(nn.Module):
    # input-dependent relative position
    def __init__(self, dim, shape):
        super().__init__()

        self.dim = dim
        self.shape = shape

        self.key_rel_w = nn.Parameter(torch.randn((2*self.shape-1, dim))*0.02)
        self.key_rel_h = nn.Parameter(torch.randn((2*self.shape-1, dim))*0.02)

        coords = torch.arange(self.shape)
        relative_coords = coords[None, :] - coords[:, None] # h, h
        relative_coords += self.shape - 1 # shift to start from 0
        
        self.register_buffer('relative_position_index', relative_coords)



    def forward(self, q, Nh, H, W, dim_head):
        # q: B, Nh, HW, dim
        B, _, _, dim = q.shape

        # q: B, Nh, H, W, dim_head
        q = rearrange(q, 'b heads (h w) dim_head -> b heads h w dim_head', b=B, dim_head=dim_head, heads=Nh, h=H, w=W)

        rel_logits_w = self.relative_logits_1d(q, self.key_rel_w, 'w')

        rel_logits_h = self.relative_logits_1d(q.permute(0, 1, 3, 2, 4), self.key_rel_h, 'h')

        return rel_logits_w, rel_logits_h

    def relative_logits_1d(self, q, rel_k, case):
        
        B, Nh, H, W, dim = q.shape

        rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k) # B, Nh, H, W, 2*shape-1

        relative_index = self.relative_position_index # shape, shape
        if W != self.shape:
            # self.relative_position_index original shape: (shape, shape)
            # after repeat_interleave: (W, shape)
            relative_index = torch.repeat_interleave(self.relative_position_index, W//self.shape, dim=0) # W, shape
        relative_index = relative_index.view(1, 1, 1, W, self.shape)
        relative_index = relative_index.repeat(B, Nh, H, 1, 1)

        rel_logits = torch.gather(rel_logits, 4, relative_index) # B, Nh, H, W, shape
        rel_logits = rel_logits.unsqueeze(3)
        rel_logits = rel_logits.repeat(1, 1, 1, self.shape, 1, 1)

        if case == 'w':
            rel_logits = rearrange(rel_logits, 'b heads H h W w -> b heads (H W) (h w)')

        elif case == 'h':
            rel_logits = rearrange(rel_logits, 'b heads W w H h -> b heads (H W) (h w)')

        return rel_logits




class RelativePositionBias(nn.Module):
    # input-independent relative position attention
    # As the number of parameters is smaller, so use 2D here
    # Borrowed some code from SwinTransformer: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    def __init__(self, num_heads, h, w):
        super().__init__()
        self.num_heads = num_heads
        self.h = h
        self.w = w

        self.relative_position_bias_table = nn.Parameter(
                torch.randn((2*h-1) * (2*w-1), num_heads)*0.02)

        coords_h = torch.arange(self.h)
        coords_w = torch.arange(self.w)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, h, w
        coords_flatten = torch.flatten(coords, 1) # 2, hw

        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.h - 1 # shift to start from 0
        relative_coords[:, :, 1] += self.w - 1
        # row stride of the (2h-1) x (2w-1) bias table is 2w-1 (as in Swin)
        relative_coords[:, :, 0] *= 2 * self.w - 1
        relative_position_index = relative_coords.sum(-1) # hw, hw
    
        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self, H, W):
        
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.h, self.w, self.h*self.w, -1) #h, w, hw, nH
        relative_position_bias_expand_h = torch.repeat_interleave(relative_position_bias, H//self.h, dim=0)
        relative_position_bias_expanded = torch.repeat_interleave(relative_position_bias_expand_h, W//self.w, dim=1) # H, W, hw, nH
        
        relative_position_bias_expanded = relative_position_bias_expanded.view(H*W, self.h*self.w, self.num_heads).permute(2, 0, 1).contiguous().unsqueeze(0)

        return relative_position_bias_expanded
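
# Sketch, not part of the original repo: how RelativePositionBias indexes its
# (2h-1)*(2w-1) bias table. Each (dy, dx) offset between two cells of the h x w
# grid is shifted to be non-negative and flattened row-major with row stride
# 2w-1, so every distinct offset gets its own table row. This sketch assumes
# PyTorch 1.10+ for torch.meshgrid(..., indexing='ij');
# `build_relative_position_index` is a hypothetical standalone function.
import torch


def build_relative_position_index(h, w):
    coords = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij'))  # 2, h, w
    flat = torch.flatten(coords, 1)  # 2, h*w
    rel = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()  # h*w, h*w, 2
    rel[:, :, 0] += h - 1  # dy now in [0, 2h-2]
    rel[:, :, 1] += w - 1  # dx now in [0, 2w-2]
    rel[:, :, 0] *= 2 * w - 1  # row stride of the flattened table
    return rel.sum(-1)  # h*w, h*w

# For h=2, w=3 there are 3 * 5 = 15 distinct offsets, indexed 0..14, and every
# cell's offset to itself maps to the centre entry (1 * 5 + 2 = 7).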


###########################################################################
# Unet Transformer building block

class down_block_trans(nn.Module):
    def __init__(self, in_ch, out_ch, num_block, bottleneck=False, maxpool=True, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):

        super().__init__()

        block_list = []

        if bottleneck:
            block = BottleneckBlock
        else:
            block = BasicBlock

        attn_block = BasicTransBlock

        if maxpool:
            block_list.append(nn.MaxPool2d(2))
            block_list.append(block(in_ch, out_ch, stride=1))
        else:
            block_list.append(block(in_ch, out_ch, stride=2))
        
        assert num_block > 0
        for i in range(num_block):
            block_list.append(attn_block(out_ch, heads, dim_head, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos))
        self.blocks = nn.Sequential(*block_list)

        
    def forward(self, x):
        
        out = self.blocks(x)


        return out

class up_block_trans(nn.Module):
    def __init__(self, in_ch, out_ch, num_block, bottleneck=False, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):
        super().__init__()
 
        self.attn_decoder = BasicTransDecoderBlock(in_ch, out_ch, heads=heads, dim_head=dim_head, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)

        if bottleneck:
            block = BottleneckBlock
        else:
            block = BasicBlock
        attn_block = BasicTransBlock
        
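        # self.blocks receives cat([decoder output, skip]) with 2*out_ch channels,
        # so the block that reduces it to out_ch comes first; the attention
        # blocks then operate on out_ch channels
        block_list = [block(2*out_ch, out_ch, stride=1)]

        for i in range(num_block):
            block_list.append(attn_block(out_ch, heads, dim_head, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos))

        self.blocks = nn.Sequential(*block_list)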

    def forward(self, x1, x2):
        # x1: low-res feature, x2: high-res feature
        out = self.attn_decoder(x1, x2)
        out = torch.cat([out, x2], dim=1)
        out = self.blocks(out)

        return out

class block_trans(nn.Module):
    def __init__(self, in_ch, num_block, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='interp', rel_pos=True):

        super().__init__()

        block_list = []

        attn_block = BasicTransBlock

        assert num_block > 0
        for i in range(num_block):
            block_list.append(attn_block(in_ch, heads, dim_head, attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos))
        self.blocks = nn.Sequential(*block_list)

        
    def forward(self, x):
        
        out = self.blocks(x)


        return out



================================================
FILE: model/resnet_utnet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

from .unet_utils import up_block
from .transunet import ResNetV2
from .conv_trans_utils import block_trans
import pdb




class ResNet_UTNet(nn.Module):
    def __init__(self, in_ch, num_class, reduce_size=8, block_list='234', num_blocks=[1,2,4], projection='interp', num_heads=[4,4,4], attn_drop=0., proj_drop=0., rel_pos=True, block_units=(3,4,9), width_factor=1):
        
        super().__init__()
        self.resnet = ResNetV2(block_units, width_factor)


        # block_trans(in_ch, num_block, heads, dim_head, ...): pass the head count
        # and the per-head dimension as separate arguments.
        # '0' in block_list requires 4-entry num_blocks and num_heads lists.
        if '0' in block_list:
            self.trans_0 = block_trans(64, num_blocks[-4], num_heads[-4], 64//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        else:
            self.trans_0 = nn.Identity()

        if '1' in block_list:
            self.trans_1 = block_trans(256, num_blocks[-3], num_heads[-3], 256//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        else:
            self.trans_1 = nn.Identity()

        if '2' in block_list:
            self.trans_2 = block_trans(512, num_blocks[-2], num_heads[-2], 512//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        else:
            self.trans_2 = nn.Identity()

        if '3' in block_list:
            self.trans_3 = block_trans(1024, num_blocks[-1], num_heads[-1], 1024//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        else:
            self.trans_3 = nn.Identity()
        
        self.up1 = up_block(1024, 512, scale=(2,2), num_block=1)
        self.up2 = up_block(512, 256, scale=(2,2), num_block=1)
        self.up3 = up_block(256, 64, scale=(2,2), num_block=1)
        self.up4 = nn.UpsamplingBilinear2d(scale_factor=2)

        self.output = nn.Conv2d(64, num_class, kernel_size=3, padding=1, bias=True)

    def forward(self, x):
        if x.shape[1] == 1: 
            x = x.repeat(1, 3, 1, 1)
        x, features = self.resnet(x)

        out3 = self.trans_3(x)
        out2 = self.trans_2(features[0])
        out1 = self.trans_1(features[1])
        out0 = self.trans_0(features[2])

        out = self.up1(out3, out2)
        out = self.up2(out, out1)
        out = self.up3(out, out0)
        out = self.up4(out)

        out = self.output(out)

        return out
            




================================================
FILE: model/swin_unet.py
================================================
# Code borrowed from the original repo and adapted to our training framework:
# https://github.com/HuCaoFighting/Swin-Unet/tree/4375a8d6fa7d9c38184c5d3194db990a00a3e912

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

import copy
import logging
import math

from os.path import join as pjoin

import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage




class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
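

# Example (hypothetical helper, not part of the original repository):
# WindowAttention.forward gathers relative_position_bias_table with
# relative_position_index, which is built in __init__ (not shown in this
# excerpt). This sketch assumes the construction used by the reference
# Swin Transformer code: every (query, key) pair in a Wh x Ww window maps to
# one of (2*Wh-1)*(2*Ww-1) table rows according to its 2-D offset.
# torch.meshgrid(..., indexing="ij") needs torch >= 1.10.
import torch


def _example_relative_position_index(window_size):
    Wh, Ww = window_size
    coords = torch.stack(torch.meshgrid(torch.arange(Wh), torch.arange(Ww), indexing="ij"))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    # pairwise (row, col) offsets between every query token and every key token
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, N, N
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # N, N, 2
    relative_coords[:, :, 0] += Wh - 1  # shift row offsets to start at 0
    relative_coords[:, :, 1] += Ww - 1  # shift column offsets to start at 0
    relative_coords[:, :, 0] *= 2 * Ww - 1  # row-major flattening of the 2-D offset
    return relative_coords.sum(-1)  # N, N indices into the bias table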





class SwinTransformerBlock(nn.Module):

    r""" Swin Transformer Block.



    Args:

        dim (int): Number of input channels.

        input_resolution (tuple[int]): Input resolution.

        num_heads (int): Number of attention heads.

        window_size (int): Window size.

        shift_size (int): Shift size for SW-MSA.

        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.

        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True

        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.

        drop (float, optional): Dropout rate. Default: 0.0

        attn_drop (float, optional): Attention dropout rate. Default: 0.0

        drop_path (float, optional): Stochastic depth rate. Default: 0.0

        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU

        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm

    """



    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,

                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,

                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):

        super().__init__()

        self.dim = dim

        self.input_resolution = input_resolution

        self.num_heads = num_heads

        self.window_size = window_size

        self.shift_size = shift_size

        self.mlp_ratio = mlp_ratio

        if min(self.input_resolution) <= self.window_size:

            # if window size is larger than input resolution, we don't partition windows

            self.shift_size = 0

            self.window_size = min(self.input_resolution)

        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"



        self.norm1 = norm_layer(dim)

        self.attn = WindowAttention(

            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,

            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)



        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)

        mlp_hidden_dim = int(dim * mlp_ratio)

        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)



        if self.shift_size > 0:

            # calculate attention mask for SW-MSA

            H, W = self.input_resolution

            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1

            h_slices = (slice(0, -self.window_size),

                        slice(-self.window_size, -self.shift_size),

                        slice(-self.shift_size, None))

            w_slices = (slice(0, -self.window_size),

                        slice(-self.window_size, -self.shift_size),

                        slice(-self.shift_size, None))

            cnt = 0

            for h in h_slices:

                for w in w_slices:

                    img_mask[:, h, w, :] = cnt

                    cnt += 1



            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1

            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)

            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)

            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        else:

            attn_mask = None



        self.register_buffer("attn_mask", attn_mask)



    def forward(self, x):

        H, W = self.input_resolution

        B, L, C = x.shape

        assert L == H * W, "input feature has wrong size"



        shortcut = x

        x = self.norm1(x)

        x = x.view(B, H, W, C)



        # cyclic shift

        if self.shift_size > 0:

            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

        else:

            shifted_x = x



        # partition windows

        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C

        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C



        # W-MSA/SW-MSA

        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C



        # merge windows

        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C



        # reverse cyclic shift

        if self.shift_size > 0:

            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))

        else:

            x = shifted_x

        x = x.view(B, H * W, C)



        # FFN

        x = shortcut + self.drop_path(x)

        x = x + self.drop_path(self.mlp(self.norm2(x)))



        return x



    def extra_repr(self) -> str:

        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
            f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"



    def flops(self):

        flops = 0

        H, W = self.input_resolution

        # norm1

        flops += self.dim * H * W

        # W-MSA/SW-MSA

        nW = H * W / self.window_size / self.window_size

        flops += nW * self.attn.flops(self.window_size * self.window_size)

        # mlp

        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio

        # norm2

        flops += self.dim * H * W

        return flops
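

# Example (hypothetical helpers, not part of the original repository):
# a standalone sketch of the SW-MSA mask built in SwinTransformerBlock.__init__.
# After the cyclic shift, a window can contain tokens from up to four image
# regions that were not adjacent before the roll; token pairs whose region
# labels differ get -100 so softmax effectively ignores them. The partition
# helper is a simplified stand-in for window_partition (defined elsewhere in
# this file) and assumes H and W are multiples of window_size.
import torch


def _example_window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)


def _example_shift_attn_mask(H, W, window_size, shift_size):
    img_mask = torch.zeros((1, H, W, 1))
    slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    cnt = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = cnt  # label each region with a distinct id
            cnt += 1
    mask_windows = _example_window_partition(img_mask, window_size).view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # nW, N, N
    return attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)

# For a 4x4 map with window 2 and shift 1, the top-left window is all zeros
# (one region), while the bottom-right window masks every off-diagonal pair.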





class PatchMerging(nn.Module):

    r""" Patch Merging Layer.



    Args:

        input_resolution (tuple[int]): Resolution of input feature.

        dim (int): Number of input channels.

        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm

    """



    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):

        super().__init__()

        self.input_resolution = input_resolution

        self.dim = dim

        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

        self.norm = norm_layer(4 * dim)



    def forward(self, x):

        """

        x: B, H*W, C

        """

        H, W = self.input_resolution

        B, L, C = x.shape

        assert L == H * W, "input feature has wrong size"

        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."



        x = x.view(B, H, W, C)



        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C

        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C

        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C

        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C

        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C

        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C



        x = self.norm(x)

        x = self.reduction(x)



        return x



    def extra_repr(self) -> str:

        return f"input_resolution={self.input_resolution}, dim={self.dim}"



    def flops(self):

        H, W = self.input_resolution

        flops = H * W * self.dim

        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim

        return flops
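

# Example (hypothetical helper, not part of the original repository):
# PatchMerging.forward gathers each 2x2 neighbourhood into one token by
# concatenating x0..x3 along channels (B, H*W, C -> B, H/2*W/2, 4C) before
# the LayerNorm and the 4C -> 2C linear reduction. This sketch repeats only
# the gather step so the token ordering can be inspected.
import torch


def _example_merge_2x2(x, H, W):
    B, L, C = x.shape
    x = x.view(B, H, W, C)
    x0 = x[:, 0::2, 0::2, :]  # top-left of each 2x2 block
    x1 = x[:, 1::2, 0::2, :]  # bottom-left
    x2 = x[:, 0::2, 1::2, :]  # top-right
    x3 = x[:, 1::2, 1::2, :]  # bottom-right
    return torch.cat([x0, x1, x2, x3], -1).view(B, -1, 4 * C)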



class PatchExpand(nn.Module):

    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):

        super().__init__()

        self.input_resolution = input_resolution

        self.dim = dim

        self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity()

        self.norm = norm_layer(dim // dim_scale)



    def forward(self, x):

        """

        x: B, H*W, C

        """

        H, W = self.input_resolution

        x = self.expand(x)

        B, L, C = x.shape

        assert L == H * W, "input feature has wrong size"



        x = x.view(B, H, W, C)

        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)

        x = x.view(B,-1,C//4)

        x= self.norm(x)



        return x
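

# Example (hypothetical helper, not part of the original repository):
# PatchExpand and FinalPatchExpand_X4 upsample tokens with the einops pattern
# 'b h w (p1 p2 c) -> b (h p1) (w p2) c': each token's channel vector is split
# into a p x p block of smaller tokens. The same rearrangement written with
# plain view/permute for an upsampling factor `scale`:
import torch


def _example_expand_tokens(x, scale):
    B, H, W, C = x.shape
    c = C // (scale * scale)  # channels left per output token
    x = x.view(B, H, W, scale, scale, c)  # b h w p1 p2 c
    x = x.permute(0, 1, 3, 2, 4, 5)  # b h p1 w p2 c
    return x.reshape(B, H * scale, W * scale, c)

# PatchExpand first maps dim -> 2*dim, so with scale=2 each output token keeps
# dim/2 channels; FinalPatchExpand_X4 maps dim -> 16*dim, so with scale=4 each
# of the 4x4 output tokens keeps dim channels.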



class FinalPatchExpand_X4(nn.Module):

    def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):

        super().__init__()

        self.input_resolution = input_resolution

        self.dim = dim

        self.dim_scale = dim_scale

        self.expand = nn.Linear(dim, 16*dim, bias=False)

        self.output_dim = dim 

        self.norm = norm_layer(self.output_dim)



    def forward(self, x):

        """

        x: B, H*W, C

        """

        H, W = self.input_resolution

        x = self.expand(x)

        B, L, C = x.shape

        assert L == H * W, "input feature has wrong size"



        x = x.view(B, H, W, C)

        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))

        x = x.view(B,-1,self.output_dim)

        x= self.norm(x)



        return x



class BasicLayer(nn.Module):

    """ A basic Swin Transformer layer for one stage.



    Args:

        dim (int): Number of input channels.

        input_resolution (tuple[int]): Input resolution.

        depth (int): Number of blocks.

        num_heads (int): Number of attention heads.

        window_size (int): Local window size.

        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.

        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True

        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.

        drop (float, optional): Dropout rate. Default: 0.0

        attn_drop (float, optional): Attention dropout rate. Default: 0.0

        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0

        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm

        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None

        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.

    """



    def __init__(self, dim, input_resolution, depth, num_heads, window_size,

                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,

                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):



        super().__init__()

        self.dim = dim

        self.input_resolution = input_resolution

        self.depth = depth

        self.use_checkpoint = use_checkpoint



        # build blocks

        self.blocks = nn.ModuleList([

            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,

                                 num_heads=num_heads, window_size=window_size,

                                 shift_size=0 if (i % 2 == 0) else window_size // 2,

                                 mlp_ratio=mlp_ratio,

                                 qkv_bias=qkv_bias, qk_scale=qk_scale,

                                 drop=drop, attn_drop=attn_drop,

                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,

                                 norm_layer=norm_layer)

            for i in range(depth)])



        # patch merging layer

        if downsample is not None:

            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)

        else:

            self.downsample = None



    def forward(self, x):

        for blk in self.blocks:

            if self.use_checkpoint:

                x = checkpoint.checkpoint(blk, x)

            else:

                x = blk(x)

        if self.downsample is not None:

            x = self.downsample(x)

        return x



    def extra_repr(self) -> str:

        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"



    def flops(self):

        flops = 0

        for blk in self.blocks:

            flops += blk.flops()

        if self.downsample is not None:

            flops += self.downsample.flops()

        return flops
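

# Example (hypothetical helper, not part of the original repository):
# BasicLayer alternates regular windows (shift 0, even block index) and
# shifted windows (shift window_size // 2, odd index). SwinTransformerBlock
# then clamps both when the feature map is no larger than one window: the
# window shrinks to min(input_resolution) and the shift is disabled. This
# sketch reproduces that per-block (window_size, shift_size) schedule.
def _example_block_schedule(input_resolution, window_size, depth):
    schedule = []
    for i in range(depth):
        ws = window_size
        shift = 0 if i % 2 == 0 else window_size // 2
        if min(input_resolution) <= ws:  # same rule as SwinTransformerBlock.__init__
            ws, shift = min(input_resolution), 0
        schedule.append((ws, shift))
    return schedule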



class BasicLayer_up(nn.Module):

    """ A basic Swin Transformer layer for one stage.



    Args:

        dim (int): Number of input channels.

        input_resolution (tuple[int]): Input resolution.

        depth (int): Number of blocks.

        num_heads (int): Number of attention heads.

        window_size (int): Local window size.

        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.

        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True

        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.

        drop (float, optional): Dropout rate. Default: 0.0

        attn_drop (float, optional): Attention dropout rate. Default: 0.0

        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0

        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm

        upsample (nn.Module | None, optional): Upsample layer at the end of the layer. Default: None

        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.

    """



    def __init__(self, dim, input_resolution, depth, num_heads, window_size,

                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,

                 drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):



        super().__init__()

        self.dim = dim

        self.input_resolution = input_resolution

        self.depth = depth

        self.use_checkpoint = use_checkpoint



        # build blocks

        self.blocks = nn.ModuleList([

            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,

                                 num_heads=num_heads, window_size=window_size,

                                 shift_size=0 if (i % 2 == 0) else window_size // 2,

                                 mlp_ratio=mlp_ratio,

                                 qkv_bias=qkv_bias, qk_scale=qk_scale,

                                 drop=drop, attn_drop=attn_drop,

                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,

                                 norm_layer=norm_layer)

            for i in range(depth)])



        # patch expanding layer

        if upsample is not None:

            self.upsample = upsample(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)

        else:

            self.upsample = None



    def forward(self, x):

        for blk in self.blocks:

            if self.use_checkpoint:

                x = checkpoint.checkpoint(blk, x)

            else:

                x = blk(x)

        if self.upsample is not None:

            x = self.upsample(x)

        return x



class PatchEmbed(nn.Module):

    r""" Image to Patch Embedding



    Args:

        img_size (int): Image size.  Default: 224.

        patch_size (int): Patch token size. Default: 4.

        in_chans (int): Number of input image channels. Default: 3.

        embed_dim (int): Number of linear projection output channels. Default: 96.

        norm_layer (nn.Module, optional): Normalization layer. Default: None

    """



    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):

        super().__init__()

        img_size = to_2tuple(img_size)

        patch_size = to_2tuple(patch_size)

        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]

        self.img_size = img_size

        self.patch_size = patch_size

        self.patches_resolution = patches_resolution

        self.num_patches = patches_resolution[0] * patches_resolution[1]



        self.in_chans = in_chans

        self.embed_dim = embed_dim



        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        if norm_layer is not None:

            self.norm = norm_layer(embed_dim)

        else:

            self.norm = None



    def forward(self, x):

        B, C, H, W = x.shape

        # FIXME look at relaxing size constraints

        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C

        if self.norm is not None:

            x = self.norm(x)

        return x



    def flops(self):

        Ho, Wo = self.patches_resolution

        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])

        if self.norm is not None:

            flops += Ho * Wo * self.embed_dim

        return flops
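

# Example (hypothetical helper, not part of the original repository):
# PatchEmbed uses a Conv2d whose kernel_size equals its stride, so each
# non-overlapping patch becomes one token; flatten(2).transpose(1, 2) then
# turns (B, embed_dim, Ph, Pw) into (B, Ph*Pw, embed_dim). With the defaults
# (224 px, patch 4) that gives 56 * 56 = 3136 tokens.
import torch
import torch.nn as nn


def _example_patch_tokens(img_size=224, patch_size=4, in_chans=3, embed_dim=96, batch=1):
    proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    x = torch.zeros(batch, in_chans, img_size, img_size)
    return proj(x).flatten(2).transpose(1, 2)  # B, Ph*Pw, embed_dim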





class SwinTransformerSys(nn.Module):

    r""" Swin Transformer

        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -

          https://arxiv.org/pdf/2103.14030



    Args:

        img_size (int | tuple(int)): Input image size. Default 224

        patch_size (int | tuple(int)): Patch size. Default: 4

        in_chans (int): Number of input image channels. Default: 3

        num_classes (int): Number of output segmentation classes. Default: 1000

        embed_dim (int): Patch embedding dimension. Default: 96

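        depths (tuple(int)): Depth of each Swin Transformer encoder layer.

        depths_decoder (tuple(int)): Depth of each decoder layer. Currently only logged; the decoder reuses `depths`.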

        num_heads (tuple(int)): Number of attention heads in different layers.

        window_size (int): Window size. Default: 7

        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4

        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True

        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None

        drop_rate (float): Dropout rate. Default: 0

        attn_drop_rate (float): Attention dropout rate. Default: 0

        drop_path_rate (float): Stochastic depth rate. Default: 0.1

        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.

        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False

        patch_norm (bool): If True, add normalization after patch embedding. Default: True

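        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False

        final_upsample (str): Final upsampling mode. Only "expand_first" builds the 4x patch-expanding head and the 1x1 output conv. Default: "expand_first"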

    """



    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,

                 embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],

                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,

                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,

                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,

                 use_checkpoint=False, final_upsample="expand_first", **kwargs):

        super().__init__()



        print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths,

        depths_decoder,drop_path_rate,num_classes))



        self.num_classes = num_classes

        self.num_layers = len(depths)

        self.embed_dim = embed_dim

        self.ape = ape

        self.patch_norm = patch_norm

        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))

        self.num_features_up = int(embed_dim * 2)

        self.mlp_ratio = mlp_ratio

        self.final_upsample = final_upsample



        # split image into non-overlapping patches

        self.patch_embed = PatchEmbed(

            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,

            norm_layer=norm_layer if self.patch_norm else None)

        num_patches = self.patch_embed.num_patches

        patches_resolution = self.patch_embed.patches_resolution

        self.patches_resolution = patches_resolution



        # absolute position embedding

        if self.ape:

            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

            trunc_normal_(self.absolute_pos_embed, std=.02)



        self.pos_drop = nn.Dropout(p=drop_rate)



        # stochastic depth

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule



        # build encoder and bottleneck layers

        self.layers = nn.ModuleList()

        for i_layer in range(self.num_layers):

            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),

                               input_resolution=(patches_resolution[0] // (2 ** i_layer),

                                                 patches_resolution[1] // (2 ** i_layer)),

                               depth=depths[i_layer],

                               num_heads=num_heads[i_layer],

                               window_size=window_size,

                               mlp_ratio=self.mlp_ratio,

                               qkv_bias=qkv_bias, qk_scale=qk_scale,

                               drop=drop_rate, attn_drop=attn_drop_rate,

                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],

                               norm_layer=norm_layer,

                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,

                               use_checkpoint=use_checkpoint)

            self.layers.append(layer)

        

        # build decoder layers

        self.layers_up = nn.ModuleList()

        self.concat_back_dim = nn.ModuleList()

        for i_layer in range(self.num_layers):

            concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)),

            int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity()

            if i_layer ==0 :

                layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),

                patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer)

            else:

                layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)),

                                input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),

                                                    patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))),

                                depth=depths[(self.num_layers-1-i_layer)],

                                num_heads=num_heads[(self.num_layers-1-i_layer)],

                                window_size=window_size,

                                mlp_ratio=self.mlp_ratio,

                                qkv_bias=qkv_bias, qk_scale=qk_scale,

                                drop=drop_rate, attn_drop=attn_drop_rate,

                                drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])],

                                norm_layer=norm_layer,

                                upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,

                                use_checkpoint=use_checkpoint)

            self.layers_up.append(layer_up)

            self.concat_back_dim.append(concat_linear)



        self.norm = norm_layer(self.num_features)

        self.norm_up= norm_layer(self.embed_dim)



        if self.final_upsample == "expand_first":

            print("---final upsample expand_first---")

            self.up = FinalPatchExpand_X4(input_resolution=(img_size//patch_size,img_size//patch_size),dim_scale=4,dim=embed_dim)

            self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=1,bias=False)



        self.apply(self._init_weights)



    def _init_weights(self, m):

        if isinstance(m, nn.Linear):

            trunc_normal_(m.weight, std=.02)

            if m.bias is not None:

                nn.init.constant_(m.bias, 0)

        elif isinstance(m, nn.LayerNorm):

            nn.init.constant_(m.bias, 0)

            nn.init.constant_(m.weight, 1.0)



    @torch.jit.ignore

    def no_weight_decay(self):

        return {'absolute_pos_embed'}



    @torch.jit.ignore

    def no_weight_decay_keywords(self):

        return {'relative_position_bias_table'}



    # Encoder and bottleneck

    def forward_features(self, x):

        x = self.patch_embed(x)

        if self.ape:

            x = x + self.absolute_pos_embed

        x = self.pos_drop(x)

        x_downsample = []



        for layer in self.layers:

            x_downsample.append(x)

            x = layer(x)



        x = self.norm(x)  # B L C

  

        return x, x_downsample



    # Decoder and skip connections

    def forward_up_features(self, x, x_downsample):

        for inx, layer_up in enumerate(self.layers_up):

            if inx == 0:

                x = layer_up(x)

            else:

                x = torch.cat([x, x_downsample[self.num_layers - 1 - inx]], -1)  # skip connection from the matching encoder stage

                x = self.concat_back_dim[inx](x)

                x = layer_up(x)



        x = self.norm_up(x)  # B L C

  

        return x



    def up_x4(self, x):

        H, W = self.patches_resolution

        B, L, C = x.shape

        assert L == H*W, "input features has wrong size"



        if self.final_upsample=="expand_first":

            x = self.up(x)

            x = x.view(B,4*H,4*W,-1)

            x = x.permute(0,3,1,2) #B,C,H,W

            x = self.output(x)

            

        return x



    def forward(self, x):

        x, x_downsample = self.forward_features(x)

        x = self.forward_up_features(x,x_downsample)

        x = self.up_x4(x)



        return x



    def flops(self):

        flops = 0

        flops += self.patch_embed.flops()

        for i, layer in enumerate(self.layers):

            flops += layer.flops()

        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)

        flops += self.num_features * self.num_classes

        return flops
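

# Example (hypothetical helper, not part of the original repository):
# SwinTransformerSys builds encoder stage i at resolution
# patches_resolution // 2**i with embed_dim * 2**i channels, and splits one
# linearly increasing stochastic depth schedule
# (torch.linspace(0, drop_path_rate, sum(depths))) into per-stage slices.
# The decoder mirrors this: stage inx reuses the settings of encoder stage
# num_layers - 1 - inx and concatenates that stage's skip feature.
import torch


def _example_stage_schedule(img_size=224, patch_size=4, embed_dim=96,
                            depths=(2, 2, 6, 2), drop_path_rate=0.1):
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
    res0 = img_size // patch_size
    stages = []
    for i in range(len(depths)):
        stages.append({
            "resolution": res0 // (2 ** i),
            "dim": embed_dim * 2 ** i,
            "drop_path": dpr[sum(depths[:i]):sum(depths[:i + 1])],
        })
    return stages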




logger = logging.getLogger(__name__)



class SwinUnet_config():
    def __init__(self):
        self.patch_size = 4
        self.in_chans = 3
        self.num_classes = 4
        self.embed_dim = 96
        self.depths = [2, 2, 6, 2]
        self.num_heads = [3, 6, 12, 24]
        self.window_size = 7
        self.mlp_ratio = 4.
        self.qkv_bias = True
        self.qk_scale = None
        self.drop_rate = 0.
        self.drop_path_rate = 0.1
        self.ape = False
        self.patch_norm = True
        self.use_checkpoint = False
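

# Example (hypothetical helper, not part of the original repository):
# with this config a 224 x 224 input gives stage resolutions 56, 28, 14, 7.
# A sketch of the size constraints the encoder appears to impose:
# PatchMerging asserts an even resolution before every downsampling, and each
# stage resolution must either split into whole windows (assumption:
# window_partition reshapes without padding, as in the reference Swin code)
# or be no larger than one window, in which case the window is clamped.
def _example_valid_input_size(img_size, patch_size=4, window_size=7, num_layers=4):
    res = img_size // patch_size
    for i in range(num_layers):
        if res > window_size and res % window_size != 0:
            return False
        if i < num_layers - 1:
            if res % 2 != 0:
                return False
            res //= 2
    return True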

class SwinUnet(nn.Module):

    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):

        super(SwinUnet, self).__init__()

        self.num_classes = num_classes

        self.zero_head = zero_head

        self.config = config



        self.swin_unet = SwinTransformerSys(img_size=img_size,

                                patch_size=config.patch_size,

                                in_chans=config.in_chans,

                                num_classes=self.num_classes,

                                embed_dim=config.embed_dim,

                                depths=config.depths,

                                num_heads=config.num_heads,

                                window_size=config.window_size,

                                mlp_ratio=config.mlp_ratio,

                                qkv_bias=config.qkv_bias,

                                qk_scale=config.qk_scale,

                                drop_rate=config.drop_rate,

                                drop_path_rate=config.drop_path_rate,

                                ape=config.ape,

                                patch_norm=config.patch_norm,

                                use_checkpoint=config.use_checkpoint)



    def forward(self, x):

        if x.size()[1] == 1:

            # replicate a single-channel (grayscale) input to the 3 channels expected with in_chans=3

            x = x.repeat(1,3,1,1)

        logits = self.swin_unet(x)

        return logits



    def load_from(self, pretrained_path):


        if pretrained_path is not None:

            print("pretrained_path:{}".format(pretrained_path))

            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            pretrained_dict = torch.load(pretrained_path, map_location=device)

            if "model"  not in pretrained_dict:

                print("---start load pretrained modle by splitting---")

                pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()}

                for k in list(pretrained_dict.keys()):

                    if "output" in k:

                        print("delete key:{}".format(k))

                        del pretrained_dict[k]

                msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False)

                # print(msg)

                return

            pretrained_dict = pretrained_dict['model']

            print("---start load pretrained modle of swin encoder---")



            model_dict = self.swin_unet.state_dict()

            full_dict = copy.deepcopy(pretrained_dict)

            for k, v in pretrained_dict.items():

                if "layers." in k:

                    # mirror encoder stage i onto decoder stage 3-i (assumes 4 stages with single-digit indices)

                    current_layer_num = 3-int(k[7:8])

                    current_k = "layers_up." + str(current_layer_num) + k[8:]

                    full_dict.update({current_k:v})

            for k in list(full_dict.keys()):

                if k in model_dict:

                    if full_dict[k].shape != model_dict[k].shape:

                        print("delete:{};shape pretrain:{};shape model:{}".format(k,v.shape,model_dict[k].shape))

                        del full_dict[k]



            msg = self.swin_unet.load_state_dict(full_dict, strict=False)

            # print(msg)

        else:

            print("none pretrain")


================================================
FILE: model/transunet.py
================================================
# The code is borrowed from original repo: https://github.com/Beckschen/TransUNet/tree/d68a53a2da73ecb496bb7585340eb660ecda1d59
# coding=utf-8

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function


import ml_collections

import copy

import logging

import math

from collections import OrderedDict

from os.path import join as pjoin



import torch

import torch.nn as nn

import torch.nn.functional as F

import numpy as np



from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm

from torch.nn.modules.utils import _pair

from scipy import ndimage






logger = logging.getLogger(__name__)





ATTENTION_Q = "MultiHeadDotProductAttention_1/query"

ATTENTION_K = "MultiHeadDotProductAttention_1/key"

ATTENTION_V = "MultiHeadDotProductAttention_1/value"

ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"

FC_0 = "MlpBlock_3/Dense_0"

FC_1 = "MlpBlock_3/Dense_1"

ATTENTION_NORM = "LayerNorm_0"

MLP_NORM = "LayerNorm_2"





def np2th(weights, conv=False):

    """Possibly convert HWIO to OIHW."""

    if conv:

        weights = weights.transpose([3, 2, 0, 1])

    return torch.from_numpy(weights)
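

# Example (hypothetical helper, not part of the original repository):
# NumPy checkpoints of the kind TransUNet was written to load (presumably
# Google ViT .npz files) store conv kernels as HWIO (height, width, in, out),
# whereas nn.Conv2d expects OIHW (out, in, height, width). np2th(..., conv=True)
# applies the axis permutation (3, 2, 0, 1), sketched here on its own.
import numpy as np


def _example_hwio_to_oihw(weights):
    return np.transpose(weights, (3, 2, 0, 1))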





def swish(x):

    return x * torch.sigmoid(x)





ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}





class Attention(nn.Module):

    def __init__(self, config, vis):

        super(Attention, self).__init__()

        self.vis = vis

        self.num_attention_heads = config.transformer["num_heads"]

        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)

        self.all_head_size = self.num_attention_heads * self.attention_head_size



        self.query = Linear(config.hidden_size, self.all_head_size)

        self.key = Linear(config.hidden_size, self.all_head_size)

        self.value = Linear(config.hidden_size, self.all_head_size)



        self.out = Linear(config.hidden_size, config.hidden_size)

        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])



        self.softmax = Softmax(dim=-1)



    def transpose_for_scores(self, x):

        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)

        x = x.view(*new_x_shape)

        return x.permute(0, 2, 1, 3)



    def forward(self, hidden_states):

        mixed_query_layer = self.query(hidden_states)

        mixed_key_layer = self.key(hidden_states)

        mixed_value_layer = self.value(hidden_states)



        query_layer = self.transpose_for_scores(mixed_query_layer)

        key_layer = self.transpose_for_scores(mixed_key_layer)

        value_layer = self.transpose_for_scores(mixed_value_layer)



        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        attention_probs = self.softmax(attention_scores)

        weights = attention_probs if self.vis else None

        attention_probs = self.attn_dropout(attention_probs)



        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)

        context_layer = context_layer.view(*new_context_layer_shape)

        attention_output = self.out(context_layer)

        attention_output = self.proj_dropout(attention_output)

        return attention_output, weights
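The core of `Attention.forward` (head split, scaled dot product, softmax, head merge) can be sketched in NumPy as follows. This is an illustration only: it assumes `q`, `k`, `v` are already projected (the learned `query`/`key`/`value`/`out` layers and dropout are omitted), and the function names are hypothetical.

```python
import numpy as np


def split_heads(x, num_heads):
    """(B, N, C) -> (B, heads, N, C // heads), mirroring transpose_for_scores."""
    b, n, c = x.shape
    return x.reshape(b, n, num_heads, c // num_heads).transpose(0, 2, 1, 3)


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_attention(q, k, v, num_heads):
    """Scaled dot-product attention over already-projected q, k, v of shape (B, N, C)."""
    qh, kh, vh = (split_heads(t, num_heads) for t in (q, k, v))
    head_dim = qh.shape[-1]
    scores = qh @ kh.transpose(0, 1, 3, 2) / np.sqrt(head_dim)  # (B, heads, N, N)
    probs = softmax(scores)
    context = probs @ vh  # (B, heads, N, head_dim)
    b, h, n, d = context.shape
    # Merge heads back: (B, heads, N, d) -> (B, N, heads * d)
    return context.transpose(0, 2, 1, 3).reshape(b, n, h * d), probs


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16, 64))  # 2 images, 16 tokens, hidden size 64
out, probs = multi_head_attention(x, x, x, num_heads=4)
print(out.shape, probs.shape)  # (2, 16, 64) (2, 4, 16, 16)
```

Each row of `probs` is a distribution over tokens; this is the tensor the module returns as `weights` when `vis=True`.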





class Mlp(nn.Module):

    def __init__(self, config):

        super(Mlp, self).__init__()

        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])

        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)

        self.act_fn = ACT2FN["gelu"]

        self.dropout = Dropout(config.transformer["dropout_rate"])



        self._init_weights()



    def _init_weights(self):

        nn.init.xavier_uniform_(self.fc1.weight)

        nn.init.xavier_uniform_(self.fc2.weight)

        nn.init.normal_(self.fc1.bias, std=1e-6)

        nn.init.normal_(self.fc2.bias, std=1e-6)



    def forward(self, x):

        x = self.fc1(x)

        x = self.act_fn(x)

        x = self.dropout(x)

        x = self.fc2(x)

        x = self.dropout(x)

        return x





class Embeddings(nn.Module):

    """Construct patch embeddings (optionally on top of a ResNet backbone) and add learned position embeddings."""

    def __init__(self, config, img_size, in_channels=3):

        super(Embeddings, self).__init__()

        self.hybrid = None

        self.config = config

        img_size = _pair(img_size)



        if config.patches.get("grid") is not None:   # ResNet

            grid_size = config.patches["grid"]

            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])

            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)

            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])  

            self.hybrid = True

        else:

            patch_size = _pair(config.patches["size"])

            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])

            self.hybrid = False



        if self.hybrid:

            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)

            in_channels = self.hybrid_model.width * 16

        self.patch_embeddings = Conv2d(in_channels=in_channels,

                                       out_channels=config.hidden_size,

                                       kernel_size=patch_size,

                                       stride=patch_size)

        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))



        self.dropout = Dropout(config.transformer["dropout_rate"])





    def forward(self, x):

        if self.hybrid:

            x, features = self.hybrid_model(x)

        else:

            features = None

        x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/2), n_patches^(1/2))

        x = x.flatten(2)

        x = x.transpose(-1, -2)  # (B, n_patches, hidden)



        embeddings = x + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings, features
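The patch arithmetic in `Embeddings.__init__` can be checked in plain Python. In hybrid mode the ResNet backbone already downsamples by 16, so the `Conv2d` patch size is chosen to yield `grid` tokens per side. The helper `patch_layout` is a hypothetical sketch that mirrors that logic; it raises a `ValueError` where the original code would hit a `ZeroDivisionError` (for example, a 224x224 input with the default `grid=(16, 16)`).

```python
def patch_layout(img_size, grid=None, patch=16):
    """Return (Conv2d patch size, effective patch size in pixels, n_patches)."""
    if grid is not None:  # hybrid ResNet + ViT
        patch_size = (img_size // 16 // grid[0], img_size // 16 // grid[1])
        if 0 in patch_size:
            raise ValueError("grid must not exceed img_size // 16")
        real = (patch_size[0] * 16, patch_size[1] * 16)
    else:  # pure ViT
        patch_size = real = (patch, patch)
    n_patches = (img_size // real[0]) * (img_size // real[1])
    return patch_size, real, n_patches


print(patch_layout(256, grid=(16, 16)))  # ((1, 1), (16, 16), 256)
print(patch_layout(224, grid=(14, 14)))  # ((1, 1), (16, 16), 196)
print(patch_layout(256))                 # ((16, 16), (16, 16), 256)
```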





class Block(nn.Module):

    def __init__(self, config, vis):

        super(Block, self).__init__()

        self.hidden_size = config.hidden_size

        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)

        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)

        self.ffn = Mlp(config)

        self.attn = Attention(config, vis)



    def forward(self, x):

        h = x

        x = self.attention_norm(x)

        x, weights = self.attn(x)

        x = x + h



        h = x

        x = self.ffn_norm(x)

        x = self.ffn(x)

        x = x + h

        return x, weights



    def load_from(self, weights, n_block):

        ROOT = f"Transformer/encoderblock_{n_block}"

        with torch.no_grad():

            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()



            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)

            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)

            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)

            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)



            self.attn.query.weight.copy_(query_weight)

            self.attn.key.weight.copy_(key_weight)

            self.attn.value.weight.copy_(value_weight)

            self.attn.out.weight.copy_(out_weight)

            self.attn.query.bias.copy_(query_bias)

            self.attn.key.bias.copy_(key_bias)

            self.attn.value.bias.copy_(value_bias)

            self.attn.out.bias.copy_(out_bias)



            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()

            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()

            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()

            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()



            self.ffn.fc1.weight.copy_(mlp_weight_0)

            self.ffn.fc2.weight.copy_(mlp_weight_1)

            self.ffn.fc1.bias.copy_(mlp_bias_0)

            self.ffn.fc2.bias.copy_(mlp_bias_1)



            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))

            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))

            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))

            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
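`Block.load_from` transposes every dense kernel because Flax computes `y = x @ kernel` with a kernel of shape `(in, out)`, while `nn.Linear` stores `weight` as `(out, in)` and computes `y = x @ weight.T`. The NumPy sketch below assumes the query kernel is stored as `(in_features, heads, head_dim)`, which is what the `.view(hidden_size, hidden_size)` flattening implies; sizes are arbitrary. It checks that both conventions give the same output and that the flattened columns are grouped head by head, as `transpose_for_scores` expects.

```python
import numpy as np

hidden, heads = 12, 3
head_dim = hidden // heads
rng = np.random.default_rng(0)

# Assumed Flax layout: (in_features, heads, head_dim), applied as y = x @ kernel.
flax_kernel = rng.standard_normal((hidden, heads, head_dim))
x = rng.standard_normal((5, hidden))
y_flax = x @ flax_kernel.reshape(hidden, hidden)

# What load_from does: flatten to (in, out), then transpose to nn.Linear's (out, in).
torch_weight = flax_kernel.reshape(hidden, hidden).T
y_torch = x @ torch_weight.T  # nn.Linear without bias

print(np.allclose(y_flax, y_torch))  # True

# Output columns [h * head_dim, (h + 1) * head_dim) belong to head h.
h = 1
assert np.allclose(y_torch[:, h * head_dim:(h + 1) * head_dim], x @ flax_kernel[:, h, :])
```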





class Encoder(nn.Module):

    def __init__(self, config, vis):

        super(Encoder, self).__init__()

        self.vis = vis

        self.layer = nn.ModuleList()

        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)

        for _ in range(config.transformer["num_layers"]):

            layer = Block(config, vis)

            self.layer.append(copy.deepcopy(layer))



    def forward(self, hidden_states):

        attn_weights = []

        for layer_block in self.layer:

            hidden_states, weights = layer_block(hidden_states)

            if self.vis:

                attn_weights.append(weights)

        encoded = self.encoder_norm(hidden_states)

        return encoded, attn_weights





class Transformer(nn.Module):

    def __init__(self, config, img_size, vis):

        super(Transformer, self).__init__()

        self.embeddings = Embeddings(config, img_size=img_size)

        self.encoder = Encoder(config, vis)



    def forward(self, input_ids):

        embedding_output, features = self.embeddings(input_ids)

        encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patch, hidden)

        return encoded, attn_weights, features





class Conv2dReLU(nn.Sequential):

    def __init__(

            self,

            in_channels,

            out_channels,

            kernel_size,

            padding=0,

            stride=1,

            use_batchnorm=True,

    ):

        conv = nn.Conv2d(

            in_channels,

            out_channels,

            kernel_size,

            stride=stride,

            padding=padding,

            bias=not (use_batchnorm),

        )

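        relu = nn.ReLU(inplace=True)

        # Add BatchNorm only when requested; without it the conv keeps its bias.
        if use_batchnorm:
            super(Conv2dReLU, self).__init__(conv, nn.BatchNorm2d(out_channels), relu)
        else:
            super(Conv2dReLU, self).__init__(conv, relu)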





class DecoderBlock(nn.Module):

    def __init__(

            self,

            in_channels,

            out_channels,

            skip_channels=0,

            use_batchnorm=True,

    ):

        super().__init__()

        self.conv1 = Conv2dReLU(

            in_channels + skip_channels,

            out_channels,

            kernel_size=3,

            padding=1,

            use_batchnorm=use_batchnorm,

        )

        self.conv2 = Conv2dReLU(

            out_channels,

            out_channels,

            kernel_size=3,

            padding=1,

            use_batchnorm=use_batchnorm,

        )

        self.up = nn.UpsamplingBilinear2d(scale_factor=2)



    def forward(self, x, skip=None):

        x = self.up(x)

        if skip is not None:

            x = torch.cat([x, skip], dim=1)

        x = self.conv1(x)

        x = self.conv2(x)

        return x





class SegmentationHead(nn.Sequential):



    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):

        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)

        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()

        super().__init__(conv2d, upsampling)





class DecoderCup(nn.Module):

    def __init__(self, config):

        super().__init__()

        self.config = config

        head_channels = 512

        self.conv_more = Conv2dReLU(

            config.hidden_size,

            head_channels,

            kernel_size=3,

            padding=1,

            use_batchnorm=True,

        )

        decoder_channels = config.decoder_channels

        in_channels = [head_channels] + list(decoder_channels[:-1])

        out_channels = decoder_channels



        if self.config.n_skip != 0:

            skip_channels = list(self.config.skip_channels)  # copy: the loop below must not mutate the shared config

            for i in range(4-self.config.n_skip):  # re-select the skip channels according to n_skip

                skip_channels[3-i]=0



        else:

            skip_channels=[0,0,0,0]



        blocks = [

            DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)

        ]

        self.blocks = nn.ModuleList(blocks)



    def forward(self, hidden_states, features=None):

        B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hidden)

        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))

        x = hidden_states.permute(0, 2, 1)

        x = x.contiguous().view(B, hidden, h, w)

        x = self.conv_more(x)

        for i, decoder_block in enumerate(self.blocks):

            if features is not None:

                skip = features[i] if (i < self.config.n_skip) else None

            else:

                skip = None

            x = decoder_block(x, skip=skip)

        return x
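Two steps of `DecoderCup` are easy to check in isolation: which decoder stages receive a skip connection for a given `n_skip`, and how the `(B, n_patch, hidden)` token sequence is folded back into a square feature map. The NumPy sketch below mirrors both; the function names are hypothetical.

```python
import numpy as np


def select_skip_channels(skip_channels, n_skip):
    """Zero the channels of decoder stages that get no skip (DecoderCup.__init__)."""
    if n_skip == 0:
        return [0, 0, 0, 0]
    chans = list(skip_channels)  # copy so the caller's list is left untouched
    for i in range(4 - n_skip):
        chans[3 - i] = 0
    return chans


def tokens_to_map(hidden_states):
    """(B, n_patch, hidden) -> (B, hidden, h, w) for a square, row-major token grid."""
    b, n_patch, hidden = hidden_states.shape
    h = w = int(np.sqrt(n_patch))
    return hidden_states.transpose(0, 2, 1).reshape(b, hidden, h, w)


print(select_skip_channels([512, 256, 64, 16], n_skip=3))  # [512, 256, 64, 0]
tokens = np.zeros((2, 256, 768))  # R50-ViT-B_16 tokens for a 256x256 input
print(tokens_to_map(tokens).shape)  # (2, 768, 16, 16)
```

Each of the four `DecoderBlock`s then upsamples by 2, so a 16x16 map returns to 16 * 2**4 = 256 pixels per side.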





class VisionTransformer(nn.Module):

    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):

        super(VisionTransformer, self).__init__()

        self.num_classes = num_classes

        self.zero_head = zero_head

        self.classifier = config.classifier

        self.transformer = Transformer(config, img_size, vis)

        self.decoder = DecoderCup(config)

        self.segmentation_head = SegmentationHead(

            in_channels=config['decoder_channels'][-1],

            out_channels=config['n_classes'],

            kernel_size=3,

        )

        self.config = config



    def forward(self, x):

        if x.size()[1] == 1:

            x = x.repeat(1,3,1,1)

        x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)

        x = self.decoder(x, features)

        logits = self.segmentation_head(x)

        return logits



    def load_from(self, weights):

        with torch.no_grad():



            res_weight = weights

            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))

            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))



            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))

            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))



            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])



            posemb_new = self.transformer.embeddings.position_embeddings

            if posemb.size() == posemb_new.size():

                self.transformer.embeddings.position_embeddings.copy_(posemb)

            elif posemb.size()[1]-1 == posemb_new.size()[1]:

                posemb = posemb[:, 1:]

                self.transformer.embeddings.position_embeddings.copy_(posemb)

            else:

                logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))

                ntok_new = posemb_new.size(1)

                # position_embeddings has no class-token slot, so the pretrained
                # class-token embedding (index 0) is always dropped before resizing.
                posemb_grid = posemb[0, 1:]

                gs_old = int(np.sqrt(len(posemb_grid)))

                gs_new = int(np.sqrt(ntok_new))

                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))

                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

                zoom = (gs_new / gs_old, gs_new / gs_old, 1)

                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)  # th2np

                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)

                posemb = posemb_grid

                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))



            # Encoder whole

            for bname, block in self.transformer.encoder.named_children():

                for uname, unit in block.named_children():

                    unit.load_from(weights, n_block=uname)



            if self.transformer.embeddings.hybrid:

                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))

                gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)

                gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)

                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)

                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)



                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():

                    for uname, unit in block.named_children():

                        unit.load_from(res_weight, n_block=bname, n_unit=uname)
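When the pretrained position embedding does not match the model's token count, `load_from` drops the class token and resizes the square grid with `scipy.ndimage.zoom(..., order=1)`. The sketch below reproduces that path in NumPy/SciPy with a hypothetical `resize_posemb` helper and a small random embedding (D = 8 instead of 768) standing in for a real checkpoint.

```python
import numpy as np
from scipy import ndimage


def resize_posemb(posemb, n_tokens_new):
    """Resize a (1, 1 + gs_old**2, D) position embedding to (1, n_tokens_new, D)."""
    grid = posemb[0, 1:]  # (gs_old**2, D), class token removed
    gs_old = int(np.sqrt(len(grid)))
    gs_new = int(np.sqrt(n_tokens_new))
    grid = grid.reshape(gs_old, gs_old, -1)
    zoom = (gs_new / gs_old, gs_new / gs_old, 1)  # do not rescale the embedding axis
    grid = ndimage.zoom(grid, zoom, order=1)  # order=1: linear interpolation
    return grid.reshape(1, gs_new * gs_new, -1)


pretrained = np.random.rand(1, 197, 8)  # 224x224 / 16 -> 14x14 grid + class token
print(resize_posemb(pretrained, 256).shape)  # (1, 256, 8)
```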




def get_b16_config():

    """Returns the ViT-B/16 configuration."""

    config = ml_collections.ConfigDict()

    config.patches = ml_collections.ConfigDict({'size': (16, 16)})

    config.hidden_size = 768

    config.transformer = ml_collections.ConfigDict()

    config.transformer.mlp_dim = 3072

    config.transformer.num_heads = 12

    config.transformer.num_layers = 12

    config.transformer.attention_dropout_rate = 0.0

    config.transformer.dropout_rate = 0.1



    config.classifier = 'seg'

    config.representation_size = None

    config.resnet_pretrained_path = None

    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz'

    config.patch_size = 16


    config.skip_channels = [0, 0, 0, 0]

    config.decoder_channels = (256, 128, 64, 16)

    config.n_classes = 2

    config.activation = 'softmax'

    return config





def get_testing():

    """Returns a minimal configuration for testing."""

    config = ml_collections.ConfigDict()

    config.patches = ml_collections.ConfigDict({'size': (16, 16)})

    config.hidden_size = 1

    config.transformer = ml_collections.ConfigDict()

    config.transformer.mlp_dim = 1

    config.transformer.num_heads = 1

    config.transformer.num_layers = 1

    config.transformer.attention_dropout_rate = 0.0

    config.transformer.dropout_rate = 0.1

    config.classifier = 'token'

    config.representation_size = None

    return config



def get_r50_b16_config():

    """Returns the Resnet50 + ViT-B/16 configuration."""

    config = get_b16_config()

    config.patches.grid = (16, 16)

    config.resnet = ml_collections.ConfigDict()

    config.resnet.num_layers = (3, 4, 9)

    config.resnet.width_factor = 1



    config.classifier = 'seg'

    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'

    config.decoder_channels = (256, 128, 64, 16)

    config.skip_channels = [512, 256, 64, 16]

    config.n_classes = 2

    config.n_skip = 3

    config.activation = 'softmax'



    return config





def get_b32_config():

    """Returns the ViT-B/32 configuration."""

    config = get_b16_config()

    config.patches.size = (32, 32)

    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz'

    return config





def get_l16_config():

    """Returns the ViT-L/16 configuration."""

    config = ml_collections.ConfigDict()

    config.patches = ml_collections.ConfigDict({'size': (16, 16)})

    config.hidden_size = 1024

    config.transformer = ml_collections.ConfigDict()

    config.transformer.mlp_dim = 4096

    config.transformer.num_heads = 16

    config.transformer.num_layers = 24

    config.transformer.attention_dropout_rate = 0.0

    config.transformer.dropout_rate = 0.1

    config.representation_size = None



    # custom

    config.classifier = 'seg'

    config.resnet_pretrained_path = None

    config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz'

    config.decoder_channels = (256, 128, 64, 16)

    config.n_classes = 2

    config.activation = 'softmax'

    return config





def get_r50_l16_config():

    """Returns the ResNet-50 + ViT-L/16 configuration (customized)."""

    config = get_l16_config()

    config.patches.grid = (16, 16)

    config.resnet = ml_collections.ConfigDict()

    config.resnet.num_layers = (3, 4, 9)

    config.resnet.width_factor = 1



    config.classifier = 'seg'

    config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'

    config.decoder_channels = (256, 128, 64, 16)

    config.skip_channels = [512, 256, 64, 16]

    config.n_classes = 2

    config.activation = 'softmax'

    return config





def get_l32_config():

    """Returns the ViT-L/32 configuration."""

    config = get_l16_config()

    config.patches.size = (32, 32)

    return config





def get_h14_config():

    """Returns the ViT-H/14 configuration."""

    config = ml_collections.ConfigDict()

    config.patches = ml_collections.ConfigDict({'size': (14, 14)})

    config.hidden_size = 1280

    config.transformer = ml_collections.ConfigDict()

    config.transformer.mlp_dim = 5120

    config.transformer.num_heads = 16

    config.transformer.num_layers = 32

    config.transformer.attention_dropout_rate = 0.0

    config.transformer.dropout_rate = 0.1

    config.classifier = 'token'

    config.representation_size = None



    return config









CONFIGS = {

    'ViT-B_16': get_b16_config(),

    'ViT-B_32': get_b32_config(),

    'ViT-L_16': get_l16_config(),

    'ViT-L_32': get_l32_config(),

    'ViT-H_14': get_h14_config(),

    'R50-ViT-B_16': get_r50_b16_config(),

    'R50-ViT-L_16': get_r50_l16_config(),

    'testing': get_testing(),

}












class StdConv2d(nn.Conv2d):



    def forward(self, x):

        w = self.weight

        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)

        w = (w - m) / torch.sqrt(v + 1e-5)

        return F.conv2d(x, w, self.bias, self.stride, self.padding,

                        self.dilation, self.groups)
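`StdConv2d` applies Weight Standardization before every convolution: each output filter is shifted to zero mean and scaled to unit (biased) variance over its `(in, kH, kW)` axes. A NumPy sketch of the same normalization (the helper name is hypothetical):

```python
import numpy as np


def standardize_weight(w, eps=1e-5):
    """Per output filter: zero mean and unit variance over axes (1, 2, 3)."""
    mean = w.mean(axis=(1, 2, 3), keepdims=True)
    var = w.var(axis=(1, 2, 3), keepdims=True)  # ddof=0 matches unbiased=False
    return (w - mean) / np.sqrt(var + eps)


w = np.random.default_rng(0).normal(loc=3.0, scale=2.0, size=(8, 4, 3, 3))
ws = standardize_weight(w)
print(np.allclose(ws.mean(axis=(1, 2, 3)), 0.0))            # True
print(np.allclose(ws.std(axis=(1, 2, 3)), 1.0, atol=1e-4))  # True
```

Because the standardization runs inside `forward`, the stored `self.weight` stays unnormalized, which is why the raw BiT kernels can be copied in directly by `load_from`.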





def conv3x3(cin, cout, stride=1, groups=1, bias=False):

    return StdConv2d(cin, cout, kernel_size=3, stride=stride,

                     padding=1, bias=bias, groups=groups)





def conv1x1(cin, cout, stride=1, bias=False):

    return StdConv2d(cin, cout, kernel_size=1, stride=stride,

                     padding=0, bias=bias)





class PreActBottleneck(nn.Module):

    """Pre-activation (v2) bottleneck block.

    """



    def __init__(self, cin, cout=None, cmid=None, stride=1):

        super().__init__()

        cout = cout or cin

        cmid = cmid or cout//4



        self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)

        self.conv1 = conv1x1(cin, cmid, bias=False)

        self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)

        self.conv2 = conv3x3(cmid, cmid, stride, bias=False)  # Original code has it on conv1!!

        self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)

        self.conv3 = conv1x1(cmid, cout, bias=False)

        self.relu = nn.ReLU(inplace=True)



        if (stride != 1 or cin != cout):

            # Projection also with pre-activation according to paper.

            self.downsample = conv1x1(cin, cout, stride, bias=False)

            self.gn_proj = nn.GroupNorm(cout, cout)



    def forward(self, x):



        # Residual branch

        residual = x

        if hasattr(self, 'downsample'):

            residual = self.downsample(x)

            residual = self.gn_proj(residual)



        # Unit's branch

        y = self.relu(self.gn1(self.conv1(x)))

        y = self.relu(self.gn2(self.conv2(y)))

        y = self.gn3(self.conv3(y))



        y = self.relu(residual + y)

        return y



    def load_from(self, weights, n_block, n_unit):

        conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)

        conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)

        conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)



        gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])

        gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])



        gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])

        gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])



        gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])

        gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])



        self.conv1.weight.copy_(conv1_weight)

        self.conv2.weight.copy_(conv2_weight)

        self.conv3.weight.copy_(conv3_weight)



        self.gn1.weight.copy_(gn1_weight.view(-1))

        self.gn1.bias.copy_(gn1_bias.view(-1))



        self.gn2.weight.copy_(gn2_weight.view(-1))

        self.gn2.bias.copy_(gn2_bias.view(-1))



        self.gn3.weight.copy_(gn3_weight.view(-1))

        self.gn3.bias.copy_(gn3_bias.view(-1))



        if hasattr(self, 'downsample'):

            proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)

            proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])

            proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])



            self.downsample.weight.copy_(proj_conv_weight)

            self.gn_proj.weight.copy_(proj_gn_weight.view(-1))

            self.gn_proj.bias.copy_(proj_gn_bias.view(-1))



class ResNetV2(nn.Module):

    """Implementation of the pre-activation (v2) ResNet model."""



    def __init__(self, block_units, width_factor):

        super().__init__()

        width = int(64 * width_factor)

        self.width = width



        self.root = nn.Sequential(OrderedDict([

            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),

            ('gn', nn.GroupNorm(32, width, eps=1e-6)),

            ('relu', nn.ReLU(inplace=True)),

            # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))

        ]))



        self.body = nn.Sequential(OrderedDict([

            ('block1', nn.Sequential(OrderedDict(

                [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +

                [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],

                ))),

            ('block2', nn.Sequential(OrderedDict(

                [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +

                [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],

                ))),

            ('block3', nn.Sequential(OrderedDict(

                [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +

                [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],

                ))),

        ]))



    def forward(self, x):

        features = []

        b, c, in_size, _ = x.size()

        x = self.root(x)

        features.append(x)

        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=0)

        for i in range(len(self.body)-1):

            x = self.body[i](x)

            right_size = int(in_size / 4 / (i+1))

            if x.size()[2] != right_size:

                pad = right_size - x.size()[2]

                assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size)

                feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)

                feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]

            else:

                feat = x

            features.append(feat)

        x = self.body[-1](x)

        return x, features[::-1]
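The spatial sizes produced by `ResNetV2.forward`, and the reason for its 1 to 2 pixel zero-padding, can be traced with plain integer arithmetic. This sketch assumes a square input; the helper names are hypothetical. With `width_factor=1` the returned skip features have 512, 256 and 64 channels, and the final map has `width * 16 = 1024` channels, which `Embeddings` uses as `in_channels`.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a conv or max-pool layer (PyTorch floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def pad_to(actual, target):
    """ResNetV2.forward zero-pads a skip feature up to `target` (by 1 or 2 pixels)."""
    if actual != target:
        assert 0 < target - actual < 3, (actual, target)
    return target


def resnetv2_sizes(in_size):
    """Trace square feature-map sizes through ResNetV2.forward."""
    root = conv_out(in_size, 7, 2, 3)          # root StdConv2d, 7x7, stride 2
    x = conv_out(root, 3, 2, 0)                # max-pool 3x3, stride 2, no padding
    skip1 = pad_to(x, int(in_size / 4 / 1))    # block1 keeps the resolution
    x = conv_out(x, 3, 2, 1)                   # block2 unit1: 3x3 conv, stride 2
    skip2 = pad_to(x, int(in_size / 4 / 2))
    x = conv_out(x, 3, 2, 1)                   # block3 unit1: 3x3 conv, stride 2
    return x, [skip2, skip1, root]             # reversed, like features[::-1]


print(resnetv2_sizes(256))  # (16, [32, 64, 128])
print(resnetv2_sizes(224))  # (14, [28, 56, 112])
```

The unpadded max-pool output (63 for a 256 input) is one pixel short of 256 / 4, which is exactly the case the padding branch handles.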


================================================
FILE: model/unet_utils.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):

    """(convolution => [BN] => ReLU) * 2"""



    def __init__(self, in_channels, out_channels, mid_channels=None):

        super().__init__()

        if not mid_channels:

            mid_channels = out_channels

        self.double_conv = nn.Sequential(

            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),

            nn.BatchNorm2d(mid_channels),

            nn.ReLU(inplace=True),

            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),

            nn.BatchNorm2d(out_channels),

            nn.ReLU(inplace=True)

        )



    def forward(self, x):

        return self.double_conv(x)





class Down(nn.Module):

    """Downscaling with maxpool then double conv"""



    def __init__(self, in_channels, out_channels):

        super().__init__()

        self.maxpool_conv = nn.Sequential(

            nn.MaxPool2d(2),

            DoubleConv(in_channels, out_channels)

        )



    def forward(self, x):

        return self.maxpool_conv(x)





class Up(nn.Module):

    """Upscaling then double conv"""



    def __init__(self, in_channels, out_channels, bilinear=True):

        super().__init__()



        # if bilinear, use the normal convolutions to reduce the number of channels

        if bilinear:

            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

        else:

            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)

            self.conv = DoubleConv(in_channels, out_channels)





    def forward(self, x1, x2):

        x1 = self.up(x1)

        # inputs are NCHW: dim 2 is height, dim 3 is width

        diffY = x2.size()[2] - x1.size()[2]

        diffX = x2.size()[3] - x1.size()[3]



        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,

                        diffY // 2, diffY - diffY // 2])

        # if you have padding issues, see

        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a

        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        x = torch.cat([x2, x1], dim=1)

        return self.conv(x)
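`Up.forward` pads the upsampled map `x1` so it matches the skip map `x2` before concatenation, splitting any odd difference so the extra pixel goes to the right or bottom. The argument list follows `F.pad`'s convention for 4-D inputs: the last dimension first (left, right), then height (top, bottom). A plain-Python sketch with a hypothetical helper name:

```python
def center_pad(x1_hw, x2_hw):
    """Return the [left, right, top, bottom] list Up.forward passes to F.pad."""
    diff_y = x2_hw[0] - x1_hw[0]
    diff_x = x2_hw[1] - x1_hw[1]
    return [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2]


# A 37x37 map upsampled to 74x74 must grow by one row and column to match a 75x75 skip.
print(center_pad((74, 74), (75, 75)))  # [0, 1, 0, 1]
print(center_pad((64, 60), (64, 63)))  # [1, 2, 0, 0]
```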





class OutConv(nn.Module):

    def __init__(self, in_channels, out_channels):

        super(OutConv, self).__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)



    def forward(self, x):

        return self.conv(x)


def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
    


class BasicBlock(nn.Module):

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or inplanes != planes:
            self.shortcut = nn.Sequential(
                    nn.BatchNorm2d(inplanes),
                    self.relu,
                    nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
                    )

    def forward(self, x): 
        residue = x 

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out += self.shortcut(residue)

        return out 

class BottleneckBlock(nn.Module):

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        self.conv1 = conv1x1(inplanes, planes//4, stride=1)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = conv3x3(planes//4, planes//4, stride=stride)
        self.bn2 = nn.BatchNorm2d(planes//4)

        self.conv3 = conv1x1(planes//4, planes, stride=1)
        self.bn3 = nn.BatchNorm2d(planes//4)

        self.shortcut = nn.Sequential()
        if stride != 1 or inplanes != planes:
            self.shortcut = nn.Sequential(
                    nn.BatchNorm2d(inplanes),
                    self.relu,
                    nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
                    )

    def forward(self, x):
        residue = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        out += self.shortcut(residue)

        return out
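The two residual blocks above trade width for cost. As a rough illustration (the helpers below are hypothetical and not part of the repository), this sketch counts only the convolution weights of each block, ignoring BatchNorm parameters, and follows the channel layout of the code above: two bias-free 3x3 convs for `BasicBlock`, a 1x1 -> 3x3 -> 1x1 path at width `planes//4` for `BottleneckBlock`, plus a 1x1 projection shortcut whenever the stride or channel count changes.

```python
def basic_block_conv_params(inplanes, planes, stride=1):
    """Conv weights of BasicBlock: 3x3 (inplanes->planes) + 3x3 (planes->planes)."""
    params = 9 * inplanes * planes + 9 * planes * planes
    if stride != 1 or inplanes != planes:
        params += inplanes * planes  # 1x1 projection shortcut
    return params


def bottleneck_block_conv_params(inplanes, planes, stride=1):
    """Conv weights of BottleneckBlock: 1x1 reduce, 3x3 at planes//4, 1x1 expand."""
    mid = planes // 4
    params = inplanes * mid + 9 * mid * mid + mid * planes
    if stride != 1 or inplanes != planes:
        params += inplanes * planes  # 1x1 projection shortcut
    return params


for inp, pl in [(64, 64), (64, 128)]:
    print(inp, pl, basic_block_conv_params(inp, pl), bottleneck_block_conv_params(inp, pl))
```

For 64 input and 64 output channels this gives 73,728 conv weights for the basic block versus 4,352 for the bottleneck block.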



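class inconv(nn.Module):
    """Stem: a 3x3 conv (in_ch -> out_ch) followed by one residual block.
    Note that `self.relu` is created but not used in forward."""
    def __init__(self, in_ch, out_ch, bottleneck=False):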
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

        if bottleneck:
            self.conv2 = BottleneckBlock(out_ch, out_ch)
        else:
            self.conv2 = BasicBlock(out_ch, out_ch)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)

        return out


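class down_block(nn.Module):
    """Encoder stage made of num_block residual blocks, the first mapping in_ch -> out_ch.
    Downsampling is done by MaxPool2d(scale) before the first block (pool=True), or by
    giving the first block stride 2 (pool=False, in which case `scale` is ignored)."""
    def __init__(self, in_ch, out_ch, scale, num_block, bottleneck=False, pool=True):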
        super().__init__()

        block_list = []

        if bottleneck:
            block = BottleneckBlock
        else:
            block = BasicBlock


        if pool:
            block_list.append(nn.MaxPool2d(scale))
            block_list.append(block(in_ch, out_ch))
        else:
            block_list.append(block(in_ch, out_ch, stride=2))

        for i in range(num_block-1):
            block_list.append(block(out_ch, out_ch, stride=1))

        self.conv = nn.Sequential(*block_list)

    def forward(self, x):
        return self.conv(x)




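class up_block(nn.Module):
    """Decoder stage: bilinearly upsample x1 by `scale`, project it to out_ch with a
    1x1 conv, concatenate the skip tensor x2 (out_ch channels) with it, and refine the
    2*out_ch channels with num_block residual blocks (the first maps back to out_ch)."""
    def __init__(self, in_ch, out_ch, num_block, scale=(2, 2), bottleneck=False):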
        super().__init__()
        self.scale=scale

        self.conv_ch = nn.Conv2d(in_ch, out_ch, kernel_size=1)

        if bottleneck:
            block = BottleneckBlock
        else:
            block = BasicBlock


        block_list = []
        block_list.append(block(2*out_ch, out_ch))

        for i in range(num_block-1):
            block_list.append(block(out_ch, out_ch))

        self.conv = nn.Sequential(*block_list)

    def forward(self, x1, x2):
        x1 = F.interpolate(x1, scale_factor=self.scale, mode='bilinear', align_corners=True)
        x1 = self.conv_ch(x1)

        out = torch.cat([x2, x1], dim=1)
        out = self.conv(out)

        return out





================================================
FILE: model/utnet.py
================================================
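import torch
import torch.nn as nn
import torch.nn.functional as F

from .unet_utils import up_block, down_block
# BasicBlock and the transformer blocks (BasicTransBlock, down_block_trans,
# up_block_trans) are provided by this wildcard import
from .conv_trans_utils import *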



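class UTNet(nn.Module):
    """U-shaped hybrid CNN/Transformer segmentation network.

    `block_list` lists the stages ('0'-'4') that use transformer blocks; the
    other stages use convolutional residual blocks. Encoder stage k reads
    num_blocks[k-5] and num_heads[k-5] (stage 0 ignores num_blocks), and its
    paired decoder stage (up4 for '0', up3 for '1', up2 for '2', up1 for '3')
    reads num_heads[k-4]. Stage '4' is the bottleneck and has no transformer
    decoder partner. Consequently '1' needs at least 4 entries in num_blocks and
    num_heads, and '0' needs at least 5 heads.

    With aux_loss=True, forward returns (out, out3, out2, out1): the final
    logits followed by auxiliary logits from up3, up2 and up1, each resized to
    the input resolution for deep supervision.
    """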
    def __init__(self, in_chan, base_chan, num_classes=1, reduce_size=8, block_list='234', num_blocks=[1, 2, 4], projection='interp', num_heads=[2,4,8], attn_drop=0., proj_drop=0., bottleneck=False, maxpool=True, rel_pos=True, aux_loss=False):
        super().__init__()

        self.aux_loss = aux_loss
        self.inc = [BasicBlock(in_chan, base_chan)]
        if '0' in block_list:
            self.inc.append(BasicTransBlock(base_chan, heads=num_heads[-5], dim_head=base_chan//num_heads[-5], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos))
            self.up4 = up_block_trans(2*base_chan, base_chan, num_block=0, bottleneck=bottleneck, heads=num_heads[-4], dim_head=base_chan//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        
        else:
            self.inc.append(BasicBlock(base_chan, base_chan))
            self.up4 = up_block(2*base_chan, base_chan, scale=(2,2), num_block=2)
        self.inc = nn.Sequential(*self.inc)


        if '1' in block_list:
            self.down1 = down_block_trans(base_chan, 2*base_chan, num_block=num_blocks[-4], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-4], dim_head=2*base_chan//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
            self.up3 = up_block_trans(4*base_chan, 2*base_chan, num_block=0, bottleneck=bottleneck, heads=num_heads[-3], dim_head=2*base_chan//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        else:
            self.down1 = down_block(base_chan, 2*base_chan, (2,2), num_block=2)
            self.up3 = up_block(4*base_chan, 2*base_chan, scale=(2,2), num_block=2)

        if '2' in block_list:
            self.down2 = down_block_trans(2*base_chan, 4*base_chan, num_block=num_blocks[-3], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-3], dim_head=4*base_chan//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
            self.up2 = up_block_trans(8*base_chan, 4*base_chan, num_block=0, bottleneck=bottleneck, heads=num_heads[-2], dim_head=4*base_chan//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)

        else:
            self.down2 = down_block(2*base_chan, 4*base_chan, (2, 2), num_block=2)
            self.up2 = up_block(8*base_chan, 4*base_chan, scale=(2,2), num_block=2)

        if '3' in block_list:
            self.down3 = down_block_trans(4*base_chan, 8*base_chan, num_block=num_blocks[-2], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-2], dim_head=8*base_chan//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
            self.up1 = up_block_trans(16*base_chan, 8*base_chan, num_block=0, bottleneck=bottleneck, heads=num_heads[-1], dim_head=8*base_chan//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)

        else:
            self.down3 = down_block(4*base_chan, 8*base_chan, (2,2), num_block=2)
            self.up1 = up_block(16*base_chan, 8*base_chan, scale=(2,2), num_block=2)

        if '4' in block_list:
            self.down4 = down_block_trans(8*base_chan, 16*base_chan, num_block=num_blocks[-1], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-1], dim_head=16*base_chan//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        else:
            self.down4 = down_block(8*base_chan, 16*base_chan, (2,2), num_block=2)


        self.outc = nn.Conv2d(base_chan, num_classes, kernel_size=1, bias=True)

        if aux_loss:
            self.out1 = nn.Conv2d(8*base_chan, num_classes, kernel_size=1, bias=True)
            self.out2 = nn.Conv2d(4*base_chan, num_classes, kernel_size=1, bias=True)
            self.out3 = nn.Conv2d(2*base_chan, num_classes, kernel_size=1, bias=True)
            


    def forward(self, x):
        
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        
        if self.aux_loss:
            out = self.up1(x5, x4)
            out1 = F.interpolate(self.out1(out), size=x.shape[-2:], mode='bilinear', align_corners=True)

            out = self.up2(out, x3)
            out2 = F.interpolate(self.out2(out), size=x.shape[-2:], mode='bilinear', align_corners=True)

            out = self.up3(out, x2)
            out3 = F.interpolate(self.out3(out), size=x.shape[-2:], mode='bilinear', align_corners=True)

            out = self.up4(out, x1)
            out = self.outc(out)

            return out, out3, out2, out1

        else:
            out = self.up1(x5, x4)
            out = self.up2(out, x3)
            out = self.up3(out, x2)

            out = self.up4(out, x1)
            out = self.outc(out)

            return out
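The shape bookkeeping in `UTNet.forward` can be checked without building the network. The pure-Python helper below (`utnet_feature_shapes`, a hypothetical name) lists the (channels, height, width) of each skip tensor `x1`..`x5` and of each decoder output, assuming every down stage halves the resolution and every up stage doubles it with the default `scale=(2, 2)`. It rejects sizes not divisible by 16, since after four halvings the upsampled map would no longer match the skip tensor it is concatenated with.

```python
def utnet_feature_shapes(in_hw, base_chan):
    """Return (channels, height, width) of UTNet's skip tensors and decoder outputs.

    Hypothetical helper: assumes every down stage halves H and W and every
    up stage doubles them (the default scale=(2, 2)).
    """
    h, w = in_hw
    if h % 16 or w % 16:
        # Four 2x downsamplings must divide evenly, otherwise the upsampled
        # map and the skip tensor no longer have the same size in up_block
        raise ValueError("height and width must be divisible by 16")
    encoder = {}
    for k in range(5):  # x1 (full resolution) .. x5 (bottleneck)
        encoder["x%d" % (k + 1)] = (base_chan * 2 ** k, h // 2 ** k, w // 2 ** k)
    decoder = {}
    for n, name in enumerate(["up1", "up2", "up3", "up4"]):
        level = 3 - n  # up1 returns to the resolution of x4, up4 to that of x1
        decoder[name] = (base_chan * 2 ** level, h // 2 ** level, w // 2 ** level)
    return encoder, decoder


encoder, decoder = utnet_feature_shapes((256, 256), base_chan=32)
for name, shape in list(encoder.items()) + list(decoder.items()):
    print(name, shape)
```

For a 256x256 input and `base_chan=32` (the training defaults), the bottleneck `x5` is 512x16x16. With `aux_loss=True`, the heads `out1`, `out2` and `out3` are 1x1 convs applied to the `up1`, `up2` and `up3` outputs (256, 128 and 64 channels here) before being resized to 256x256.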



        



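class UTNet_Encoderonly(nn.Module):
    """UTNet variant with transformer blocks only in the encoder.

    `block_list`, `num_blocks` and `num_heads` select and configure the encoder
    stages exactly as in UTNet; every decoder stage is a convolutional up_block.
    """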
    def __init__(self, in_chan, base_chan, num_classes=1, reduce_size=8, block_list='234', num_blocks=[1, 2, 4], projection='interp', num_heads=[2,4,8], attn_drop=0., proj_drop=0., bottleneck=False, maxpool=True, rel_pos=True, aux_loss=False):
        super().__init__()

        self.aux_loss = aux_loss
        
        self.inc = [BasicBlock(in_chan, base_chan)]
        if '0' in block_list:
            self.inc.append(BasicTransBlock(base_chan, heads=num_heads[-5], dim_head=base_chan//num_heads[-5], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos))
        
        else:
            self.inc.append(BasicBlock(base_chan, base_chan))
        self.inc = nn.Sequential(*self.inc)


        if '1' in block_list:
            self.down1 = down_block_trans(base_chan, 2*base_chan, num_block=num_blocks[-4], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-4], dim_head=2*base_chan//num_heads[-4], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        else:
            self.down1 = down_block(base_chan, 2*base_chan, (2,2), num_block=2)


        if '2' in block_list:
            self.down2 = down_block_trans(2*base_chan, 4*base_chan, num_block=num_blocks[-3], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-3], dim_head=4*base_chan//num_heads[-3], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        else:
            self.down2 = down_block(2*base_chan, 4*base_chan, (2, 2), num_block=2)


        if '3' in block_list:
            self.down3 = down_block_trans(4*base_chan, 8*base_chan, num_block=num_blocks[-2], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-2], dim_head=8*base_chan//num_heads[-2], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        else:
            self.down3 = down_block(4*base_chan, 8*base_chan, (2,2), num_block=2)

        if '4' in block_list:
            self.down4 = down_block_trans(8*base_chan, 16*base_chan, num_block=num_blocks[-1], bottleneck=bottleneck, maxpool=maxpool, heads=num_heads[-1], dim_head=16*base_chan//num_heads[-1], attn_drop=attn_drop, proj_drop=proj_drop, reduce_size=reduce_size, projection=projection, rel_pos=rel_pos)
        else:
            self.down4 = down_block(8*base_chan, 16*base_chan, (2,2), num_block=2)


        self.up1 = up_block(16*base_chan, 8*base_chan, scale=(2,2), num_block=2)
        self.up2 = up_block(8*base_chan, 4*base_chan, scale=(2,2), num_block=2)
        self.up3 = up_block(4*base_chan, 2*base_chan, scale=(2,2), num_block=2)
        self.up4 = up_block(2*base_chan, base_chan, scale=(2,2), num_block=2)

        self.outc = nn.Conv2d(base_chan, num_classes, kernel_size=1, bias=True)

        if aux_loss:
            self.out1 = nn.Conv2d(8*base_chan, num_classes, kernel_size=1, bias=True)
            self.out2 = nn.Conv2d(4*base_chan, num_classes, kernel_size=1, bias=True)
            self.out3 = nn.Conv2d(2*base_chan, num_classes, kernel_size=1, bias=True)
            


    def forward(self, x):
        
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        
        if self.aux_loss:
            out = self.up1(x5, x4)
            out1 = F.interpolate(self.out1(out), size=x.shape[-2:], mode='bilinear', align_corners=True)

            out = self.up2(out, x3)
            out2 = F.interpolate(self.out2(out), size=x.shape[-2:], mode='bilinear', align_corners=True)

            out = self.up3(out, x2)
            out3 = F.interpolate(self.out3(out), size=x.shape[-2:], mode='bilinear', align_corners=True)

            out = self.up4(out, x1)
            out = self.outc(out)

            return out, out3, out2, out1

        else:
            out = self.up1(x5, x4)
            out = self.up2(out, x3)
            out = self.up3(out, x2)

            out = self.up4(out, x1)
            out = self.outc(out)

            return out
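Because both constructors index `num_blocks` and `num_heads` from the end, a mismatched `block_list` only surfaces as an `IndexError` inside `__init__`. A small pre-flight check, sketched here as a hypothetical helper that mirrors the indexing of `UTNet.__init__` (encoder stage k reads index k-5, its transformer decoder partner reads head index k-4), reports which entries each stage will use. For `UTNet_Encoderonly` only the encoder indices apply.

```python
def check_utnet_config(block_list, num_blocks, num_heads):
    """Map each transformer stage in block_list to the list entries it reads.

    Hypothetical helper mirroring UTNet.__init__: encoder stage k reads
    num_blocks[k - 5] (not for stage 0) and num_heads[k - 5]; its decoder
    partner reads num_heads[k - 4] (stage 4 has no decoder partner).
    Raises ValueError instead of the IndexError the constructor would hit.
    """
    plan = {}
    for ch in sorted(set(block_list)):
        k = int(ch)
        if not 0 <= k <= 4:
            raise ValueError("stage indices must be in '01234', got %r" % ch)
        head_idx = [k - 5] + ([k - 4] if k < 4 else [])
        block_idx = [k - 5] if k > 0 else []
        for idx, values, name in [(i, num_heads, "num_heads") for i in head_idx] + \
                                 [(i, num_blocks, "num_blocks") for i in block_idx]:
            if -idx > len(values):
                raise ValueError("stage %d needs %s[%d] but only %d entries were given"
                                 % (k, name, idx, len(values)))
        plan[k] = {"heads": [num_heads[i] for i in head_idx],
                   "num_blocks": [num_blocks[i] for i in block_idx]}
    return plan


print(check_utnet_config('1234', [1, 1, 1, 1], [4, 4, 4, 4]))
```

With the training script's settings (`block_list='1234'`, four heads, four block counts) the check passes; adding stage '0' fails because stage 0 needs `num_heads[-5]`.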




================================================
FILE: train_deep.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

import numpy as np
from model.utnet import UTNet, UTNet_Encoderonly

from dataset_domain import CMRDataset

from torch.utils import data
from losses import DiceLoss
from utils.utils import *
from utils import metrics
from optparse import OptionParser
import SimpleITK as sitk

from torch.utils.tensorboard import SummaryWriter
import time
import math
import os
import sys
import pdb
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
DEBUG = False

def train_net(net, options):
    
    data_path = options.data_path

    trainset = CMRDataset(data_path, mode='train', domain=options.domain, debug=DEBUG, scale=options.scale, rotate=options.rotate, crop_size=options.crop_size)
    trainLoader = data.DataLoader(trainset, batch_size=options.batch_size, shuffle=True, num_workers=16)

    testset_A = CMRDataset(data_path, mode='test', domain='A', debug=DEBUG, crop_size=options.crop_size)
    testLoader_A = data.DataLoader(testset_A, batch_size=1, shuffle=False, num_workers=2)
    testset_B = CMRDataset(data_path, mode='test', domain='B', debug=DEBUG, crop_size=options.crop_size)
    testLoader_B = data.DataLoader(testset_B, batch_size=1, shuffle=False, num_workers=2)
    testset_C = CMRDataset(data_path, mode='test', domain='C', debug=DEBUG, crop_size=options.crop_size)
    testLoader_C = data.DataLoader(testset_C, batch_size=1, shuffle=False, num_workers=2)
    testset_D = CMRDataset(data_path, mode='test', domain='D', debug=DEBUG, crop_size=options.crop_size)
    testLoader_D = data.DataLoader(testset_D, batch_size=1, shuffle=False, num_workers=2)





    writer = SummaryWriter(options.log_path + options.unique_name)

    optimizer = optim.SGD(net.parameters(), lr=options.lr, momentum=0.9, weight_decay=options.weight_decay)

    criterion = nn.CrossEntropyLoss(weight=torch.tensor(options.weight).cuda())
    criterion_dl = DiceLoss()


    best_dice = 0
    for epoch in range(options.epochs):
        print('Starting epoch {}/{}'.format(epoch+1, options.epochs))
        epoch_loss = 0

        exp_scheduler = exp_lr_scheduler_with_warmup(optimizer, init_lr=options.lr, epoch=epoch, warmup_epoch=5, max_epoch=options.epochs)

        print('current lr:', exp_scheduler)

        net.train()  # validation() switches the model to eval mode, so re-enable training mode each epoch
        for i, (img, label) in enumerate(trainLoader, 0):

            img = img.cuda()
            label = label.cuda()

            end = time.time()

            optimizer.zero_grad()
            
            result = net(img)
            
            loss = 0
            
            if isinstance(result, (tuple, list)):
                # Deep supervision: CE + Dice of output j is weighted by options.aux_weight[j]
                for j in range(len(result)):
                    loss += options.aux_weight[j] * (criterion(result[j], label.squeeze(1)) + criterion_dl(result[j], label))
            else:
                loss = criterion(result, label.squeeze(1)) + criterion_dl(result, label)


            loss.backward()
            optimizer.step()


            epoch_loss += loss.item()
            batch_time = time.time() - end
            print('batch loss: %.5f, batch_time:%.5f'%(loss.item(), batch_time))
        print('[epoch %d] epoch loss: %.5f'%(epoch+1, epoch_loss/(i+1)))

        writer.add_scalar('Train/Loss', epoch_loss/(i+1), epoch+1)
        writer.add_scalar('LR', exp_scheduler, epoch+1)

        # Creates the checkpoint directory (and cp_path itself) if it does not exist yet
        os.makedirs('%s%s/'%(options.cp_path, options.unique_name), exist_ok=True)

        if epoch % 20 == 0 or epoch > options.epochs-10:
            torch.save(net.state_dict(), '%s%s/CP%d.pth'%(options.cp_path, options.unique_name, epoch))
        
        if (epoch+1) >90 or (epoch+1) % 10 == 0:
            dice_list_A, ASD_list_A, HD_list_A = validation(net, testLoader_A, options)
            log_evaluation_result(writer, dice_list_A, ASD_list_A, HD_list_A, 'A', epoch)
            
            dice_list_B, ASD_list_B, HD_list_B = validation(net, testLoader_B, options)
            log_evaluation_result(writer, dice_list_B, ASD_list_B, HD_list_B, 'B', epoch)

            dice_list_C, ASD_list_C, HD_list_C = validation(net, testLoader_C, options)
            log_evaluation_result(writer, dice_list_C, ASD_list_C, HD_list_C, 'C', epoch)

            dice_list_D, ASD_list_D, HD_list_D = validation(net, testLoader_D, options)
            log_evaluation_result(writer, dice_list_D, ASD_list_D, HD_list_D, 'D', epoch)


            AVG_dice_list = 20 * dice_list_A + 50 * dice_list_B + 50 * dice_list_C + 50 * dice_list_D
            AVG_dice_list /= 170

            AVG_ASD_list = 20 * ASD_list_A + 50 * ASD_list_B + 50 * ASD_list_C + 50 * ASD_list_D
            AVG_ASD_list /= 170

            AVG_HD_list = 20 * HD_list_A + 50 * HD_list_B + 50 * HD_list_C + 50 * HD_list_D
            AVG_HD_list /= 170

            log_evaluation_result(writer, AVG_dice_list, AVG_ASD_list, AVG_HD_list, 'mean', epoch)



            # The best checkpoint is selected by the mean Dice on domain A only
            if dice_list_A.mean() >= best_dice:
                best_dice = dice_list_A.mean()
                torch.save(net.state_dict(), '%s%s/best.pth'%(options.cp_path, options.unique_name))
                print('save done')

            print('dice: %.5f/best dice: %.5f'%(dice_list_A.mean(), best_dice))
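
When the model returns several outputs (UTNet with `aux_loss=True` returns `(out, out3, out2, out1)`), `train_net` sums cross-entropy plus Dice for each output, scaled by `options.aux_weight[j]`. The sketch below restates that combination on plain numbers; `deep_supervision_loss` is a hypothetical helper, and its default weights are the script's `--aux_weight` default, so the full-resolution output dominates and the coarser auxiliary heads contribute less.

```python
def deep_supervision_loss(ce_losses, dice_losses, aux_weight=(1.0, 0.4, 0.2, 0.1)):
    """Weighted sum of (CE + Dice) over model outputs, as in train_net.

    ce_losses[j] and dice_losses[j] are scalar losses of the j-th output,
    ordered like UTNet's return value: final output first, then the
    increasingly coarse auxiliary heads.
    """
    if not (len(ce_losses) == len(dice_losses) <= len(aux_weight)):
        raise ValueError("need one CE and one Dice value per output, and a weight for each")
    return sum(w * (ce + dl) for w, ce, dl in zip(aux_weight, ce_losses, dice_losses))


print(deep_supervision_loss([0.9, 1.1, 1.3, 1.5], [0.4, 0.5, 0.6, 0.7]))
```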

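The "mean" metrics logged by `train_net` are a weighted average of the four per-domain vectors with fixed weights 20, 50, 50 and 50, normalized by their sum (170); the source does not state what the weights represent. A minimal numpy sketch of that computation, using a hypothetical helper name:

```python
import numpy as np


def weighted_domain_mean(per_domain, weights):
    """Weighted mean of per-class metric vectors across domains.

    per_domain: dict domain -> sequence of shape (num_classes,)
    weights:    dict domain -> non-negative weight (train_net uses A=20, B=C=D=50)
    """
    total = float(sum(weights[d] for d in per_domain))
    acc = sum(weights[d] * np.asarray(per_domain[d], dtype=float) for d in per_domain)
    return acc / total


dice = {"A": [0.90, 0.80, 0.85], "B": [0.88, 0.79, 0.84],
        "C": [0.86, 0.78, 0.83], "D": [0.87, 0.77, 0.82]}
print(weighted_domain_mean(dice, {"A": 20, "B": 50, "C": 50, "D": 50}))
```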

def validation(net, test_loader, options):

    net.eval()
    
    dice_list = np.zeros(3)
    ASD_list = np.zeros(3)
    HD_list = np.zeros(3)

    counter = 0
    with torch.no_grad():
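        # `img` (not `data`) so the loop variable does not shadow the torch.utils.data module
        for i, (img, label, spacing) in enumerate(test_loader):

            inputs, labels = img.float().cuda(), label.long().cuda()
            # The loader yields one volume per batch (batch_size=1); swapping the first two
            # axes turns its slices into a batch of single-channel 2D images for the network
            inputs = inputs.permute(1, 0, 2, 3)
            labels = labels.permute(1, 0, 2, 3)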
    
            pred = net(inputs)
            if options.model == 'FCN_Res50' or options.model == 'FCN_Res101':
                pred = pred['out']
            elif isinstance(pred, tuple):
                pred = pred[0]
            pred = F.softmax(pred, dim=1)
    
            _, label_pred = torch.max(pred, dim=1)
            
            tmp_ASD_list, tmp_HD_list = cal_distance(label_pred, labels, spacing)
            ASD_list += np.clip(np.nan_to_num(tmp_ASD_list, nan=500), 0, 500)
            HD_list += np.clip(np.nan_to_num(tmp_HD_list, nan=500), 0, 500)
            
            label_pred = label_pred.view(-1, 1)
            label_true = labels.view(-1, 1)

            dice, _, _ = cal_dice(label_pred, label_true, 4)
        
            dice_list += dice.cpu().numpy()[1:]
            counter += 1
    dice_list /= counter
    ASD_list /= counter
    HD_list /= counter

    return dice_list , ASD_list, HD_list
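`validation` accumulates per-class Dice from `cal_dice` (dropping index 0, the background) and replaces undefined surface distances by a cap of 500. `cal_dice` lives in `utils/utils.py` and is not shown here, so the Dice function below is a hypothetical stand-in that may differ in details such as how a class absent from both masks is scored. The clipping helper reproduces the `np.clip(np.nan_to_num(..., nan=500), 0, 500)` expression used above, which also maps infinite values to the cap.

```python
import numpy as np


def per_class_dice(pred, target, num_classes, eps=1e-8):
    """Dice for foreground classes 1..num_classes-1 (background 0 is skipped).

    Hypothetical stand-in for utils.cal_dice; eps avoids division by zero,
    so a class missing from both masks scores 0 here.
    """
    scores = np.zeros(num_classes - 1)
    for c in range(1, num_classes):
        p = pred == c
        t = target == c
        inter = np.logical_and(p, t).sum()
        scores[c - 1] = 2.0 * inter / (p.sum() + t.sum() + eps)
    return scores


def sanitize_distance(values, cap=500.0):
    """Replace NaN by `cap`, map +/-inf to huge finite values, then clip to [0, cap]."""
    return np.clip(np.nan_to_num(np.asarray(values, dtype=float), nan=cap), 0, cap)


print(per_class_dice(np.array([0, 1, 1, 2, 3]), np.array([0, 1, 2, 2, 3]), 4))
print(sanitize_distance([np.nan, 3.0, np.inf]))
```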




def cal_distance(label_pred, label_true, spacing):
    label_pred = label_pred.squeeze(1).cpu().numpy()
    label_true = label_true.squeeze(1).cpu().numpy()
    spacing = spacing.numpy()[0]

    ASD_list = np.zeros(3)
    HD_list = np.zeros(3)

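    for i in range(3):
        # Foreground classes are 1..3 (0 is background)
        tmp_surface = metrics.compute_surface_distances(label_true==(i+1), label_pred==(i+1), spacing)
        dis_gt_to_pred, dis_pred_to_gt = metrics.compute_average_surface_distance(tmp_surface)
        # Symmetric average surface distance: mean of the two directed averages
        ASD_list[i] = (dis_gt_to_pred + dis_pred_to_gt) / 2

        # The 100th percentile turns the robust Hausdorff distance into the classic (maximum) one
        HD = metrics.compute_robust_hausdorff(tmp_surface, 100)
        HD_list[i] = HD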

    return ASD_list, HD_list
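`cal_distance` reports, per foreground class, the symmetric average surface distance (mean of the two directed averages) and the Hausdorff distance (the 100th-percentile robust Hausdorff, i.e. the largest directed distance). The brute-force numpy sketch below illustrates those two definitions on voxel surfaces. It is a simplification, not the `metrics` module used above, which works on surface elements weighted by their area, so its values will differ. `surface_points` and `symmetric_surface_distances` are hypothetical names, and `spacing` must give one voxel size per mask axis.

```python
import numpy as np


def surface_points(mask):
    """Indices of foreground voxels with at least one background face-neighbour."""
    mask = np.asarray(mask, dtype=bool)
    padded = np.pad(mask, 1, constant_values=False)
    interior = padded.copy()
    for axis in range(mask.ndim):
        # A voxel is interior only if both neighbours along every axis are foreground
        interior &= np.roll(padded, 1, axis) & np.roll(padded, -1, axis)
    border = padded & ~interior
    return np.argwhere(border) - 1  # undo the padding offset


def symmetric_surface_distances(mask_gt, mask_pred, spacing):
    """Return (ASD, HD) between two binary masks; NaN if either surface is empty."""
    spacing = np.asarray(spacing, dtype=float)
    gt = surface_points(mask_gt) * spacing
    pred = surface_points(mask_pred) * spacing
    if len(gt) == 0 or len(pred) == 0:
        return np.nan, np.nan
    # Pairwise Euclidean distances between all surface points (in physical units)
    d = np.linalg.norm(gt[:, None, :] - pred[None, :, :], axis=-1)
    gt_to_pred = d.min(axis=1)
    pred_to_gt = d.min(axis=0)
    asd = (gt_to_pred.mean() + pred_to_gt.mean()) / 2
    hd = max(gt_to_pred.max(), pred_to_gt.max())
    return asd, hd
```

Returning NaN for an empty surface mirrors the situation that `validation` handles by capping the distance at 500.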




if __name__ == '__main__':
    parser = OptionParser()
    def get_comma_separated_int_args(option, opt, value, parser):
        # Parse e.g. "1,2,4" into [1, 2, 4]
        value_list = value.split(',')
        value_list = [int(i) for i in value_list]
        setattr(parser.values, option.dest, value_list)

    def get_comma_separated_float_args(option, opt, value, parser):
        # Parse e.g. "0.5,1,1,1" into [0.5, 1.0, 1.0, 1.0]
        value_list = [float(i) for i in value.split(',')]
        setattr(parser.values, option.dest, value_list)

    parser.add_option('-e', '--epochs', dest='epochs', default=150, type='int', help='number of epochs')
    parser.add_option('-b', '--batch_size', dest='batch_size', default=32, type='int', help='batch size')
    parser.add_option('-l', '--learning-rate', dest='lr', default=0.05, type='float', help='learning rate')
    parser.add_option('-c', '--resume', type='str', dest='load', default=False, help='load pretrained model')
    parser.add_option('-p', '--checkpoint-path', type='str', dest='cp_path', default='./checkpoint/', help='checkpoint path')
    parser.add_option('--data_path', type='str', dest='data_path', default='/research/cbim/vast/yg397/vision_transformer/dataset/resampled_dataset/', help='dataset path')

    parser.add_option('-o', '--log-path', type='str', dest='log_path', default='./log/', help='log path')
    parser.add_option('-m', type='str', dest='model', default='UTNet', help='use which model')
    parser.add_option('--num_class', type='int', dest='num_class', default=4, help='number of segmentation classes')
    parser.add_option('--base_chan', type='int', dest='base_chan', default=32, help='number of channels of first expansion in UNet')
    parser.add_option('-u', '--unique_name', type='str', dest='unique_name', default='test', help='unique experiment name')
    parser.add_option('--rlt', type='float', dest='rlt', default=1, help='relation between CE/FL and dice')
    # --weight and --aux_weight are lists, so they are parsed from comma-separated strings
    parser.add_option('--weight', type='string', dest='weight', default=[0.5,1,1,1], action='callback',
                      callback=get_comma_separated_float_args, help='comma-separated weight of each class in the CE loss')
    parser.add_option('--weight_decay', type='float', dest='weight_decay',
                      default=0.0001)
    parser.add_option('--scale', type='float', dest='scale', default=0.30)
    parser.add_option('--rotate', type='float', dest='rotate', default=180)
    parser.add_option('--crop_size', type='int', dest='crop_size', default=256)
    parser.add_option('--domain', type='str', dest='domain', default='A')
    parser.add_option('--aux_weight', type='string', dest='aux_weight', default=[1, 0.4, 0.2, 0.1], action='callback',
                      callback=get_comma_separated_float_args, help='comma-separated loss weight of each deep-supervision output')
    parser.add_option('--reduce_size', dest='reduce_size', default=8, type='int')
    parser.add_option('--block_list', dest='block_list', default='1234', type='str')
    parser.add_option('--num_blocks', dest='num_blocks', default=[1,1,1,1], type='string', action='callback', callback=get_comma_separated_int_args)
    parser.add_option('--aux_loss', dest='aux_loss', action='store_true', help='using aux loss for deep supervision')

    
    parser.add_option('--gpu', type='str', dest='gpu', default='0')
    options, args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = options.gpu

    print('Using model:', options.model)

    if options.model == 'UTNet':
        net = UTNet(1, options.base_chan, options.num_class, reduce_size=options.reduce_size, block_list=options.block_list, num_blocks=options.num_blocks, num_heads=[4,4,4,4], projection='interp', attn_drop=0.1, proj_drop=0.1, rel_pos=True, aux_loss=options.aux_loss, maxpool=True)
    elif options.model == 'UTNet_encoder':
        # Apply transformer blocks only in the encoder
        net = UTNet_Encoderonly(1, options.base_chan, options.num_class, reduce_size=options.reduce_size, block_list=options.block_list, num_blocks=options.num_blocks, num_heads=[4,4,4,4], projection='interp', attn_drop=0.1, proj_drop=0.1, rel_pos=True, aux_loss=options.aux_loss, maxpool=True)
    elif options.model =='TransUNet':
        from model.transunet import VisionTransformer as ViT_seg
        from model.transunet import CONFIGS as CONFIGS_ViT_seg
        config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']
        config_vit.n_classes = 4 
        config_vit.n_skip = 3 
        config_vit.patches.grid = (int(256/16), int(256/16))
        net = ViT_seg(config_vit, img_size=256, num_classes=4)
        #net.load_from(weights=np.load('./initmodel/R50+ViT-B_16.npz')) # uncomment this to use pretrain model download from TransUnet git repo

    elif options.model == 'ResNet_UTNet':
        from model.resnet_utnet import ResNet_UTNet
        net = ResNet_UTNet(1, options.num_class, reduce_size=options.reduce_size, block_list=options.block_list, num_blocks=options.num_blocks, num_heads=[4,4,4,4], projection='interp', attn_drop=0.1, proj_drop=0.1, rel_pos=True)
    
    elif options.model == 'SwinUNet':
        from model.swin_unet import SwinUnet, SwinUnet_config
        config = SwinUnet_config()
        net = SwinUnet(config, img_size=224, num_classes=options.num_class)
        net.load_from('./initmodel/swin_tiny_patch4_window7_224.pth')


    else:
        raise NotImplementedError(options.model + " has not been implemented")
    if options.load:
        # Load onto the CPU; the model is moved to the GPU further below
        net.load_state_dict(torch.load(options.load, map_location='cpu'))
        print('Model loaded from {}'.format(options.load))
    
    param_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    
    print(net)
    print(param_num)
    
    net.cuda()
    
    train_net(net, options)

    print('done')

    sys.exit(0)


================================================
FILE: utils/__init__.py
================================================


================================================
FILE: utils/lookup_tables.py
================================================
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Lookup tables used by surface distance metrics."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import numpy as np

ENCODE_NEIGHBOURHOOD_3D_KERNEL = np.array([[[128, 64], [32, 16]],
                                           [[8, 4], [2, 1]]])


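The kernel assigns one bit to each corner of a 2x2x2 neighbourhood, so summing the weights of the foreground corners yields a code from 0 to 255 that indexes `_NEIGHBOUR_CODE_TO_NORMALS`, and the summed lengths of a code's normals give its surfel area (before any scaling by voxel spacing). The loop-based sketch below computes the code of every 2x2x2 block, indexed by the block's first corner, by weighting its eight voxels with the kernel. The two-entry table excerpt is copied from the table below; the helper names are hypothetical.

```python
import numpy as np

# Same weights as ENCODE_NEIGHBOURHOOD_3D_KERNEL: one bit per corner of a 2x2x2 cube
KERNEL = np.array([[[128, 64], [32, 16]], [[8, 4], [2, 1]]])

# Excerpt: the first two entries of _NEIGHBOUR_CODE_TO_NORMALS
NORMALS_EXCERPT = {
    0: [[0, 0, 0]],
    1: [[0.125, 0.125, 0.125]],
}


def neighbourhood_codes(mask):
    """Return the 8-bit code of every 2x2x2 block of a 3D binary mask."""
    m = np.asarray(mask, dtype=np.int64)
    d, h, w = m.shape
    codes = np.zeros((d - 1, h - 1, w - 1), dtype=np.int64)
    for i in range(2):
        for j in range(2):
            for k in range(2):
                # Add this corner's bit wherever the corresponding voxel is foreground
                codes += KERNEL[i, j, k] * m[i:i + d - 1, j:j + h - 1, k:k + w - 1]
    return codes


def surfel_area(code, table=NORMALS_EXCERPT):
    """Surfel area for a code at unit spacing: sum of the lengths of its normals."""
    return float(np.linalg.norm(np.asarray(table[code], dtype=float), axis=-1).sum())


print(neighbourhood_codes(np.ones((2, 2, 2))), surfel_area(1))
```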

# _NEIGHBOUR_CODE_TO_NORMALS is a lookup table.
# For every binary neighbour code
# (2x2x2 neighbourhood = 8 neighbours = 8 bits = 256 codes)
# it contains the surface normals of the triangles (called "surfel" for
# "surface element" in the following). The length of the normal
# vector encodes the surfel area.
#
# created using the marching_cube algorithm
# see e.g. https://en.wikipedia.org/wiki/Marching_cubes
# pylint: disable=line-too-long

_NEIGHBOUR_CODE_TO_NORMALS = [

    [[0, 0, 0]],

    [[0.125, 0.125, 0.125]],

    [[-0.125, -0.125, 0.125]],

    [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],

    [[0.125, -0.125, 0.125]],

    [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]],

    [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],

    [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],

    [[-0.125, 0.125, 0.125]],

    [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125]],

    [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]],

    [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],

    [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],

    [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125]],

    [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],

    [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0]],

    [[0.125, -0.125, -0.125]],

    [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25]],

    [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],

    [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],

    [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],

    [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125]],

    [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],

    [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],

    [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125]],

    [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125]],

    [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125]],

    [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]],

    [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],

    [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],

    [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]],

    [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25]],

    [[0.125, -0.125, 0.125]],

    [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],

    [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],

    [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25]],

    [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],

    [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]],

    [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125]],

    [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]],

    [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],

    [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],

    [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],

    [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]],

    [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],

    [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]],

    [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]],

    [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],

    [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]],

    [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],

    [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25]],

    [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0]],

    [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125]],

    [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],

    [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],

    [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],

    [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]],

    [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],

    [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],

    [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],

    [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],

    [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125]],

    [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]],

    [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]],

    [[-0.125, -0.125, 0.125]],

    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]],

    [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],

    [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],

    [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25]],

    [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]],

    [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125]],

    [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]],

    [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],

    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],

    [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]],

    [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],

    [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]],

    [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]],

    [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]],

    [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]],

    [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],

    [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],

    [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],

    [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]],

    [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],

    [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5]],

    [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],

    [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],

    [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125]],

    [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],

    [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],

    [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],

    [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]],

    [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]],

    [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],

    [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],

    [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],

    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],

    [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],

    [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],

    [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125]],

    [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],

    [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],

    [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125]],

    [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],

    [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],

    [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],

    [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],

    [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],

    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],

    [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125]],

    [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125]],

    [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125]],

    [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],

    [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],

    [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],

    [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]],

    [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],

    [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],

    [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],

    [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],

    [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],

    [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125]],

    [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],

    [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125]],

    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]],

    [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],

    [[0.125, 0.125, 0.125]],

    [[0.125, 0.125, 0.125]],

    [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],

    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]],

    [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125]],

    [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],

    [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125]],

    [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],

    [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],

    [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],

    [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],

    [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],

    [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]],

    [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],

    [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],

    [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],

    [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125]],

    [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125]],

    [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125]],

    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],

    [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],

    [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],

    [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],

    [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],

    [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],

    [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125]],

    [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],

    [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],

    [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125]],

    [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],

    [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],

    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],

    [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],

    [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],

    [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],

    [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]],

    [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]],

    [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],

    [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],

    [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],

    [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125]],

    [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],

    [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],

    [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5]],

    [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],

    [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]],

    [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],

    [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],

    [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],

    [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]],

    [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]],

    [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]],

    [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]],

    [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],

    [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]],

    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],

    [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],

    [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]],

    [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125]],

    [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]],

    [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25]],

    [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],

    [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],

    [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]],

    [[-0.125, -0.125, 0.125]],

    [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]],

    [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]],

    [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125]],

    [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],

    [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],

    [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],

    [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],

    [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]],

    [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],

    [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],

    [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],

    [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125]],

    [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0]],

    [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25]],

    [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],

    [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]],

    [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],

    [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]],

    [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]],

    [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],

    [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]],

    [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],

    [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],

    [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],

    [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]],

    [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125]],

    [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]],

    [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],

    [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25]],

    [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],

    [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],

    [[0.125, -0.125, 0.125]],

    [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25]],

    [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]],

    [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],

    [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],

    [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]],

    [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125]],

    [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125]],

    [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125]],

    [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],

    [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],

    [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125]],

    [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],

    [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],

    [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],

    [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25]],

    [[0.125, -0.125, -0.125]],

    [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0]],

    [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],

    [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125]],

    [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],

    [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],

    [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]],

    [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125]],

    [[-0.125, 0.125, 0.125]],

    [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],

    [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],

    [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]],

    [[0.125, -0.125, 0.125]],

    [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],

    [[-0.125, -0.125, 0.125]],

    [[0.125, 0.125, 0.125]],

    [[0, 0, 0]]]

# pylint: enable=line-too-long





def create_table_neighbour_code_to_surface_area(spacing_mm):

  """Returns an array mapping neighbourhood code to the surface elements area.

  Note that the normals encode the surface area for unit voxel spacing. This

  function computes the area corresponding to the given `spacing_mm`.

  Args:

    spacing_mm: 3-element list-like structure. Voxel spacing in x0, x1 and x2

      directions.

  """

  # compute the area for all 256 possible surface elements

  # (given a 2x2x2 neighbourhood) according to the spacing_mm

  neighbour_code_to_surface_area = np.zeros([256])

  for code in range(256):

    normals = np.array(_NEIGHBOUR_CODE_TO_NORMALS[code])

    sum_area = 0

    for normal_idx in range(normals.shape[0]):

      # normal vector

      n = np.zeros([3])

      n[0] = normals[normal_idx, 0] * spacing_mm[1] * spacing_mm[2]

      n[1] = normals[normal_idx, 1] * spacing_mm[0] * spacing_mm[2]

      n[2] = normals[normal_idx, 2] * spacing_mm[0] * spacing_mm[1]

      area = np.linalg.norm(n)

      sum_area += area

    neighbour_code_to_surface_area[code] = sum_area



  return neighbour_code_to_surface_area
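
As a sanity check on the scaling above, here is a minimal vectorized sketch of the per-code area computation for one list of normals. `surfel_area` is a hypothetical helper, not part of this module. It assumes `normals` has the same layout as one entry of `_NEIGHBOUR_CODE_TO_NORMALS` (a list of 3-vectors).

```python
import numpy as np


def surfel_area(normals, spacing_mm):
    """Hypothetical helper: surface area of one neighbourhood code (sketch)."""
    n = np.asarray(normals, dtype=float).reshape(-1, 3)
    s0, s1, s2 = spacing_mm
    # Each normal component is scaled by the area of the voxel face it is
    # perpendicular to, as in create_table_neighbour_code_to_surface_area.
    scale = np.array([s1 * s2, s0 * s2, s0 * s1])
    # Sum the lengths of the individual scaled normals.
    return float(np.linalg.norm(n * scale, axis=1).sum())


print(surfel_area([[0.125, 0.125, 0.125]], (1.0, 1.0, 1.0)))
```

Summing the norms of the individual normals (rather than taking the norm of their sum) mirrors the inner loop of `create_table_neighbour_code_to_surface_area`.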





# In the neighbourhood, points are ordered: top left, top right, bottom left,

# bottom right.

ENCODE_NEIGHBOURHOOD_2D_KERNEL = np.array([[8, 4], [2, 1]])
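
To make the bit ordering concrete, the sketch below encodes a single 2x2 binary block with the same weights: top left contributes 8, top right 4, bottom left 2 and bottom right 1. `encode_block` and `KERNEL_2D` are hypothetical names used only for illustration.

```python
import numpy as np

# Same weights as ENCODE_NEIGHBOURHOOD_2D_KERNEL (copied, not imported).
KERNEL_2D = np.array([[8, 4], [2, 1]])


def encode_block(block):
    """Hypothetical helper: neighbourhood code of one 2x2 binary block."""
    return int(np.sum(np.asarray(block, dtype=np.int64) * KERNEL_2D))


# Bottom row inside, top row outside: binary "00" "11".
print(encode_block([[0, 0], [1, 1]]))
```

Reading the block row by row gives the binary string used by the `int("00" "11", 2)` keys in `create_table_neighbour_code_to_contour_length`.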





def create_table_neighbour_code_to_contour_length(spacing_mm):

  """Returns an array mapping neighbourhood code to the contour length.

  For the list of possible cases and their figures, see page 38 from:

  https://nccastaff.bournemouth.ac.uk/jmacey/MastersProjects/MSc14/06/thesis.pdf

  In 2D, a 2x2 neighbourhood has 4 points. Thus, there are 16 configurations. A

  configuration is encoded with '1' meaning "inside the object" and '0' "outside

  the object". The points are ordered: top left, top right, bottom left, bottom

  right.

  The x0 axis is assumed vertical downward, and the x1 axis is horizontal to the

  right:

   (0, 0) --> (0, 1)

     |

   (1, 0)

  Args:

    spacing_mm: 2-element list-like structure. Voxel spacing in x0 and x1

      directions.

  """

  neighbour_code_to_contour_length = np.zeros([16])



  vertical = spacing_mm[0]

  horizontal = spacing_mm[1]

  diag = 0.5 * math.sqrt(spacing_mm[0]**2 + spacing_mm[1]**2)

  # pyformat: disable

  neighbour_code_to_contour_length[int("00"

                                       "01", 2)] = diag



  neighbour_code_to_contour_length[int("00"

                                       "10", 2)] = diag



  neighbour_code_to_contour_length[int("00"

                                       "11", 2)] = horizontal



  neighbour_code_to_contour_length[int("01"

                                       "00", 2)] = diag



  neighbour_code_to_contour_length[int("01"

                                       "01", 2)] = vertical



  neighbour_code_to_contour_length[int("01"

                                       "10", 2)] = 2*diag



  neighbour_code_to_contour_length[int("01"

                                       "11", 2)] = diag



  neighbour_code_to_contour_length[int("10"

                                       "00", 2)] = diag



  neighbour_code_to_contour_length[int("10"

                                       "01", 2)] = 2*diag



  neighbour_code_to_contour_length[int("10"

                                       "10", 2)] = vertical



  neighbour_code_to_contour_length[int("10"

                                       "11", 2)] = diag



  neighbour_code_to_contour_length[int("11"

                                       "00", 2)] = horizontal



  neighbour_code_to_contour_length[int("11"

                                       "01", 2)] = diag



  neighbour_code_to_contour_length[int("11"

                                       "10", 2)] = diag

  # pyformat: enable



  return neighbour_code_to_contour_length
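
The 14 explicit assignments above fall into four symmetry groups. The following is a compact sketch that is intended to produce the same 16-entry table under the same bit ordering (TL=8, TR=4, BL=2, BR=1); `contour_length_table` is a hypothetical name, not part of this module.

```python
import math

import numpy as np


def contour_length_table(spacing_mm):
    """Hypothetical compact rebuild of the 2D contour-length lookup table."""
    vertical, horizontal = spacing_mm[0], spacing_mm[1]
    diag = 0.5 * math.sqrt(spacing_mm[0]**2 + spacing_mm[1]**2)
    table = np.zeros(16)  # codes 0b0000 and 0b1111 have no contour
    # One corner differs from the other three: a single half-diagonal cut.
    for code in (0b0001, 0b0010, 0b0100, 0b1000,
                 0b0111, 0b1011, 0b1101, 0b1110):
        table[code] = diag
    table[0b0011] = table[0b1100] = horizontal  # top row vs bottom row
    table[0b0101] = table[0b1010] = vertical    # left column vs right column
    table[0b0110] = table[0b1001] = 2 * diag    # checkerboard: two cuts
    return table
```

Grouping by symmetry makes it easier to see that the table is invariant under swapping inside and outside (code `c` and `15 - c` map to the same length).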


================================================
FILE: utils/metrics.py
================================================
# Copyright 2018 Google Inc. All Rights Reserved.

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#      http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS-IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.

"""Module exposing surface distance based measures."""



from __future__ import absolute_import

from __future__ import division

from __future__ import print_function



from . import lookup_tables  # pylint: disable=relative-beyond-top-level

import numpy as np

from scipy import ndimage





def _assert_is_numpy_array(name, array):

  """Raises an exception if `array` is not a numpy array."""

  if not isinstance(array, np.ndarray):

    raise ValueError("The argument {!r} should be a numpy array, not a "

                     "{}".format(name, type(array)))





def _check_nd_numpy_array(name, array, num_dims):

  """Raises an exception if `array` is not a `num_dims`-D numpy array."""

  if len(array.shape) != num_dims:

    raise ValueError("The argument {!r} should be a {}D array, not of "

                     "shape {}".format(name, num_dims, array.shape))





def _check_2d_numpy_array(name, array):

  _check_nd_numpy_array(name, array, num_dims=2)





def _check_3d_numpy_array(name, array):

  _check_nd_numpy_array(name, array, num_dims=3)





def _assert_is_bool_numpy_array(name, array):

  _assert_is_numpy_array(name, array)

  # `np.bool` does not exist in NumPy 1.24 through 1.26; `np.bool_` works in

  # all NumPy versions.

  if array.dtype != np.bool_:

    raise ValueError("The argument {!r} should be a numpy array of type bool, "

                     "not {}".format(name, array.dtype))





def _compute_bounding_box(mask):

  """Computes the bounding box of the masks.

  This function generalizes to an arbitrary number of dimensions greater than or

  equal to 1.

  Args:

    mask: The 2D or 3D numpy mask, where '0' means background and non-zero means

      foreground.

  Returns:

    A tuple:

     - The coordinates of the first point of the bounding box (smallest on all

       axes), or `None` if the mask contains only zeros.

     - The coordinates of the second point of the bounding box (greatest on all

       axes), or `None` if the mask contains only zeros.

  """

  num_dims = len(mask.shape)

  bbox_min = np.zeros(num_dims, np.int64)

  bbox_max = np.zeros(num_dims, np.int64)



  # max projection to the x0-axis

  proj_0 = np.amax(mask, axis=tuple(range(num_dims))[1:])

  idx_nonzero_0 = np.nonzero(proj_0)[0]

  if len(idx_nonzero_0) == 0:  # pylint: disable=g-explicit-length-test

    return None, None



  bbox_min[0] = np.min(idx_nonzero_0)

  bbox_max[0] = np.max(idx_nonzero_0)



  # max projection to the i-th-axis for i in {1, ..., num_dims - 1}

  for axis in range(1, num_dims):

    max_over_axes = list(range(num_dims))  # Python 3 compatible

    max_over_axes.pop(axis)  # Remove the i-th dimension from the max

    max_over_axes = tuple(max_over_axes)  # numpy expects a tuple of ints

    proj = np.amax(mask, axis=max_over_axes)

    idx_nonzero = np.nonzero(proj)[0]

    bbox_min[axis] = np.min(idx_nonzero)

    bbox_max[axis] = np.max(idx_nonzero)



  return bbox_min, bbox_max
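
For comparison, the same bounding box can be obtained from the coordinates of all non-zero voxels. This is a minimal sketch with a hypothetical name (`bounding_box`); it follows the same `(None, None)` convention for an empty mask.

```python
import numpy as np


def bounding_box(mask):
    """Hypothetical alternative to _compute_bounding_box using np.argwhere."""
    coords = np.argwhere(mask)  # shape (num_nonzero, num_dims)
    if coords.size == 0:
        return None, None
    return coords.min(axis=0), coords.max(axis=0)
```

The design trade-off: `np.argwhere` materializes every foreground coordinate, whereas the max-projection approach in `_compute_bounding_box` only keeps one 1D projection per axis, which is cheaper for large 3D masks.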





def _crop_to_bounding_box(mask, bbox_min, bbox_max):

  """Crops a 2D or 3D mask to the bounding box specified by `bbox_{min,max}`."""

  # we need to zero-pad the cropped region with 1 voxel at the lower and

  # right sides (and the back side in 3D). This is required to obtain the

  # "full" convolution result with the 2x2 (or 2x2x2 in 3D) kernel.

  # TODO: This is correct only if the object is interior to the

  # bounding box.

  cropmask = np.zeros((bbox_max - bbox_min) + 2, np.uint8)



  num_dims = len(mask.shape)

  # pyformat: disable

  if num_dims == 2:

    cropmask[0:-1, 0:-1] = mask[bbox_min[0]:bbox_max[0] + 1,

                                bbox_min[1]:bbox_max[1] + 1]

  elif num_dims == 3:

    cropmask[0:-1, 0:-1, 0:-1] = mask[bbox_min[0]:bbox_max[0] + 1,

                                      bbox_min[1]:bbox_max[1] + 1,

                                      bbox_min[2]:bbox_max[2] + 1]

  # pyformat: enable

  else:

    assert False



  return cropmask
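
The explicit 2D and 3D branches above can be expressed for any number of dimensions with tuples of slices. This sketch uses a hypothetical name (`crop_with_trailing_pad`) and keeps the same layout: the cropped region is copied to the start of each axis and one zero voxel is left at the end.

```python
import numpy as np


def crop_with_trailing_pad(mask, bbox_min, bbox_max):
    """Hypothetical N-D version of _crop_to_bounding_box (sketch)."""
    bbox_min = np.asarray(bbox_min)
    bbox_max = np.asarray(bbox_max)
    # Extent along each axis is (max - min + 1); one extra voxel of padding.
    out = np.zeros(tuple(bbox_max - bbox_min + 2), dtype=np.uint8)
    src = tuple(slice(lo, hi + 1) for lo, hi in zip(bbox_min, bbox_max))
    dst = tuple(slice(0, -1) for _ in bbox_min)  # leave the last voxel at 0
    out[dst] = mask[src]
    return out
```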





def _sort_distances_surfels(distances, surfel_areas):

  """Sorts the two list with respect to the tuple of (distance, surfel_area).



  Args:

    distances: The distances from A to B (e.g. `distances_gt_to_pred`).

    surfel_areas: The surfel areas for A (e.g. `surfel_areas_gt`).



  Returns:

    A tuple of the sorted (distances, surfel_areas).

  """

  sorted_surfels = np.array(sorted(zip(distances, surfel_areas)))

  return sorted_surfels[:, 0], sorted_surfels[:, 1]
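
The same lexicographic order (by distance, then by surfel area) can be obtained without building Python tuples by using `np.lexsort`, whose last key is the primary one. A minimal sketch with a hypothetical name:

```python
import numpy as np


def sort_by_distance(distances, surfel_areas):
    """Hypothetical np.lexsort-based equivalent of _sort_distances_surfels."""
    distances = np.asarray(distances, dtype=float)
    surfel_areas = np.asarray(surfel_areas, dtype=float)
    order = np.lexsort((surfel_areas, distances))  # distances sorted first
    return distances[order], surfel_areas[order]
```

Unlike the `zip`/`sorted` version, this also accepts empty inputs, so the caller would not need the `shape != (0,)` guard.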





def compute_surface_distances(mask_gt,

                              mask_pred,

                              spacing_mm):

  """Computes closest distances from all surface points to the other surface.



  This function can be applied to 2D or 3D tensors. For 2D, both masks must be

  2D and `spacing_mm` must be a 2-element list. For 3D, both masks must be 3D

  and `spacing_mm` must be a 3-element list. The description is done for the 2D

  case, and the formulation for the 3D case is given in parentheses,

  introduced by "resp.".



  Finds all contour elements (resp surface elements "surfels" in 3D) in the

  ground truth mask `mask_gt` and the predicted mask `mask_pred`, computes their

  length in mm (resp. area in mm^2) and the distance to the closest point on the

  other contour (resp. surface). It returns two sorted lists of distances

  together with the corresponding contour lengths (resp. surfel areas). If one

  of the masks is empty, the corresponding lists are empty and all distances in

  the other list are `inf`.



  Args:

    mask_gt: 2-dim (resp. 3-dim) bool Numpy array. The ground truth mask.

    mask_pred: 2-dim (resp. 3-dim) bool Numpy array. The predicted mask.

    spacing_mm: 2-element (resp. 3-element) list-like structure. Voxel spacing

      in x0 and x1 (resp. x0, x1 and x2) directions.



  Returns:

    A dict with:

    "distances_gt_to_pred": 1-dim numpy array of type float. The distances in mm

        from all ground truth surface elements to the predicted surface,

        sorted from smallest to largest.

    "distances_pred_to_gt": 1-dim numpy array of type float. The distances in mm

        from all predicted surface elements to the ground truth surface,

        sorted from smallest to largest.

    "surfel_areas_gt": 1-dim numpy array of type float. The length of the

      of the ground truth contours in mm (resp. the surface elements area in

      mm^2) in the same order as distances_gt_to_pred.

    "surfel_areas_pred": 1-dim numpy array of type float. The length of the

      of the predicted contours in mm (resp. the surface elements area in

      mm^2) in the same order as distances_gt_to_pred.



  Raises:

    ValueError: If the masks and the `spacing_mm` arguments are of incompatible

      shape or type. Or if the masks are not 2D or 3D.

  """

  # The terms used in this function are for the 3D case. In particular,

  # "surface" in 3D stands for "contour" in 2D, and the surface elements in 3D

  # correspond to the line elements in 2D.



  _assert_is_bool_numpy_array("mask_gt", mask_gt)

  _assert_is_bool_numpy_array("mask_pred", mask_pred)



  if not len(mask_gt.shape) == len(mask_pred.shape) == len(spacing_mm):

    raise ValueError("The arguments must be of compatible shape. Got mask_gt "

                     "with {} dimensions ({}) and mask_pred with {} dimensions "

                     "({}), while the spacing_mm was {} elements.".format(

                         len(mask_gt.shape),

                         mask_gt.shape, len(mask_pred.shape), mask_pred.shape,

                         len(spacing_mm)))



  num_dims = len(spacing_mm)

  if num_dims == 2:

    _check_2d_numpy_array("mask_gt", mask_gt)

    _check_2d_numpy_array("mask_pred", mask_pred)



    # compute the area for all 16 possible surface elements

    # (given a 2x2 neighbourhood) according to the spacing_mm

    neighbour_code_to_surface_area = (

        lookup_tables.create_table_neighbour_code_to_contour_length(spacing_mm))

    kernel = lookup_tables.ENCODE_NEIGHBOURHOOD_2D_KERNEL

    full_true_neighbours = 0b1111

  elif num_dims == 3:

    _check_3d_numpy_array("mask_gt", mask_gt)

    _check_3d_numpy_array("mask_pred", mask_pred)



    # compute the area for all 256 possible surface elements

    # (given a 2x2x2 neighbourhood) according to the spacing_mm

    neighbour_code_to_surface_area = (

        lookup_tables.create_table_neighbour_code_to_surface_area(spacing_mm))

    kernel = lookup_tables.ENCODE_NEIGHBOURHOOD_3D_KERNEL

    full_true_neighbours = 0b11111111

  else:

    raise ValueError("Only 2D and 3D masks are supported, not "

                     "{}D.".format(num_dims))



  # compute the bounding box of the masks to trim the volume to the smallest

  # possible processing subvolume

  bbox_min, bbox_max = _compute_bounding_box(mask_gt | mask_pred)

  # Both the min/max bbox are None at the same time, so we only check one.

  if bbox_min is None:

    return {

        "distances_gt_to_pred": np.array([]),

        "distances_pred_to_gt": np.array([]),

        "surfel_areas_gt": np.array([]),

        "surfel_areas_pred": np.array([]),

    }



  # crop the processing subvolume.

  cropmask_gt = _crop_to_bounding_box(mask_gt, bbox_min, bbox_max)

  cropmask_pred = _crop_to_bounding_box(mask_pred, bbox_min, bbox_max)



  # compute the neighbour code (local binary pattern) for each voxel

  # the resulting arrays are spatially shifted by minus half a voxel in each

  # axis.

  # i.e. the points are located at the corners of the original voxels

  neighbour_code_map_gt = ndimage.correlate(

      cropmask_gt.astype(np.uint8), kernel, mode="constant", cval=0)

  neighbour_code_map_pred = ndimage.correlate(

      cropmask_pred.astype(np.uint8), kernel, mode="constant", cval=0)



  # create masks with the surface voxels

  borders_gt = ((neighbour_code_map_gt != 0) &

                (neighbour_code_map_gt != full_true_neighbours))

  borders_pred = ((neighbour_code_map_pred != 0) &

                  (neighbour_code_map_pred != full_true_neighbours))



  # compute the distance transform (closest distance of each voxel to the

  # surface voxels)

  if borders_gt.any():

    distmap_gt = ndimage.distance_transform_edt(

        ~borders_gt, sampling=spacing_mm)

  else:

    distmap_gt = np.full(borders_gt.shape, np.inf)



  if borders_pred.any():

    distmap_pred = ndimage.distance_transform_edt(

        ~borders_pred, sampling=spacing_mm)

  else:

    distmap_pred = np.full(borders_pred.shape, np.inf)



  # compute the area of each surface element

  surface_area_map_gt = neighbour_code_to_surface_area[neighbour_code_map_gt]

  surface_area_map_pred = neighbour_code_to_surface_area[

      neighbour_code_map_pred]



  # create a list of all surface elements with distance and area

  distances_gt_to_pred = distmap_pred[borders_gt]

  distances_pred_to_gt = distmap_gt[borders_pred]

  surfel_areas_gt = surface_area_map_gt[borders_gt]

  surfel_areas_pred = surface_area_map_pred[borders_pred]



  # sort them by distance

  if distances_gt_to_pred.shape != (0,):

    distances_gt_to_pred, surfel_areas_gt = _sort_distances_surfels(

        distances_gt_to_pred, surfel_areas_gt)



  if distances_pred_to_gt.shape != (0,):

    distances_pred_to_gt, surfel_areas_pred = _sort_distances_surfels(

        distances_pred_to_gt, surfel_areas_pred)



  return {

      "distances_gt_to_pred": distances_gt_to_pred,

      "distances_pred_to_gt": distances_pred_to_gt,

      "surfel_areas_gt": surfel_areas_gt,

      "surfel_areas_pred": surfel_areas_pred,

  }
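
The crop, correlate and border steps of `compute_surface_distances` can be reproduced in isolation for a 2D mask. This is a minimal sketch, assuming the kernel is a copy of `ENCODE_NEIGHBOURHOOD_2D_KERNEL`; `border_codes_2d` is a hypothetical name. The codes live on the corners of the original pixels, which is why the output is one pixel larger along each axis.

```python
import numpy as np
from scipy import ndimage

# Copy of ENCODE_NEIGHBOURHOOD_2D_KERNEL (TL=8, TR=4, BL=2, BR=1).
KERNEL_2D = np.array([[8, 4], [2, 1]])


def border_codes_2d(mask):
    """Hypothetical helper: neighbour codes and border mask of a 2D mask."""
    # One zero pixel of padding at the end of each axis, as in
    # _crop_to_bounding_box.
    padded = np.zeros(tuple(s + 1 for s in mask.shape), dtype=np.uint8)
    padded[:-1, :-1] = mask
    codes = ndimage.correlate(padded, KERNEL_2D, mode="constant", cval=0)
    # Border = at least one corner inside and at least one corner outside.
    borders = (codes != 0) & (codes != 0b1111)
    return codes, borders


codes, borders = border_codes_2d(np.array([[True]]))
print(codes)
```

For a single foreground pixel, each of its four corners sees exactly one inside pixel, so each corner code maps to one half-diagonal in the contour-length table; with unit spacing the contour is a diamond of total length `4 * 0.5 * sqrt(2)`.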





def compute_average_surface_distance(surface_distances):

  """Returns the average surface distance.



  Computes the average surface distances by correctly taking the area of each

  surface element into account. Call compute_surface_distances(...) first to

  obtain the `surface_distances` dict.



  Args:

    surface_distances: dict with "distances_gt_to_pred", "distances_pred_to_gt"

    "surfel_areas_gt", "surfel_areas_pred" created by

    compute_surface_distances()



  Returns:

    A tuple with two float values:

      - the average distance (in mm) from the ground truth surface to the

        predicted surface

      - the average distance from the predicted surface to the ground truth

        surface.

  """

  distances_gt_to_pred = surface_distances["distances_gt_to_pred"]

  distances_pred_to_gt = surface_distances["distances_pred_to_gt"]

  surfel_areas_gt = surface_distances["surfel_areas_gt"]

  surfel_areas_pred = surface_distances["surfel_areas_pred"]

  average_distance_gt_to_pred = (

      np.sum(distances_gt_to_pred * surfel_areas_gt) / np.sum(surfel_areas_gt))

  average_distance_pred_to_gt = (

      np.sum(distances_pred_to_gt * surfel_areas_pred) /

      np.sum(surfel_areas_pred))

  return (average_distance_gt_to_pred, average_distance_pred_to_gt)
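
The two ratios above are area-weighted means, so they can also be written with `np.average(..., weights=...)`. A minimal sketch with a hypothetical function name that takes the four arrays directly instead of the dict:

```python
import numpy as np


def average_surface_distance(dist_gt_to_pred, areas_gt,
                             dist_pred_to_gt, areas_pred):
    """Hypothetical np.average-based version of the two weighted means."""
    return (float(np.average(dist_gt_to_pred, weights=areas_gt)),
            float(np.average(dist_pred_to_gt, weights=areas_pred)))
```

One behavioural difference: `np.average` raises `ZeroDivisionError` when the weights sum to zero (an empty surface), whereas the explicit division above yields `nan` with a runtime warning.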





def compute_robust_hausdorff(surface_distances, percent):

  """Computes the robust Hausdorff distance.



  Computes the robust Hausdorff distance. "Robust", because it uses the

  `percent` percentile of the distances instead of the maximum distance. The

  percentage is computed by correctly taking the area of each surface element

  into account.



  Args:

    surface_distances: dict with "distances_gt_to_pred", "distances_pred_to_gt"

      "surfel_areas_gt", "surfel_areas_pred" created by

      compute_surface_distances()

    percent: a float value between 0 and 100.



  Returns:

    a float value. The robust Hausdorff distance in mm.

  """

  distances_gt_to_pred = surface_distances["distances_gt_to_pred"]

  distances_pred_to_gt = surface_distances["distances_pred_to_gt"]

  surfel_areas_gt = surface_distances["surfel_areas_gt"]

  surfel_areas_pred = surface_distances["surfel_areas_pred"]

  if len(distances_gt_to_pred) > 0:  # pylint: disable=g-explicit-length-test

    surfel_areas_cum_gt = np.cumsum(surfel_areas_gt) / np.sum(surfel_areas_gt)

    idx = np.searchsorted(surfel_areas_cum_gt, percent/100.0)

    perc_distance_gt_to_pred = distances_gt_to_pred[

        min(idx, len(distances_gt_to_pred)-1)]

  else:

    perc_distance_gt_to_pred = np.inf



  if len(distances_pred_to_gt) > 0:  # pylint: disable=g-explicit-length-test

    surfel_areas_cum_pred = (np.cumsum(surfel_areas_pred) /

                             np.sum(surfel_areas_pred))

    idx = np.searchsorted(surfel_areas_cum_pred, percent/100.0)

    perc_distance_pred_to_gt = distances_pred_to_gt[

        min(idx, len(distances_pred_to_gt)-1)]

  else:

    perc_distance_pred_to_gt = np.inf



  return max(perc_distance_gt_to_pred, perc_distance_pred_to_gt)
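
The percentile lookup above can be factored into one helper per direction. This sketch uses hypothetical names and assumes the distances are already sorted in ascending order, as `compute_surface_distances` returns them. `np.searchsorted` with its default `side="left"` returns the first index whose cumulative area fraction is at least `percent / 100`.

```python
import numpy as np


def area_weighted_percentile(sorted_distances, areas, percent):
    """Hypothetical helper: distance at the area-weighted `percent` level."""
    if len(sorted_distances) == 0:
        return np.inf
    cum = np.cumsum(areas) / np.sum(areas)       # cumulative area fraction
    idx = np.searchsorted(cum, percent / 100.0)  # first index with cum >= p
    return float(sorted_distances[min(idx, len(sorted_distances) - 1)])


def robust_hausdorff(d_gt, a_gt, d_pred, a_pred, percent=95.0):
    """Symmetric robust Hausdorff distance: max of both directions."""
    return max(area_weighted_percentile(d_gt, a_gt, percent),
               area_weighted_percentile(d_pred, a_pred, percent))
```

With four equal-area surfels at distances `[0, 1, 2, 10]`, the 50th percentile is 1.0 while the 95th already reaches the outlier at 10.0, which shows how sensitive the 95% Hausdorff distance still is to a small fraction of the surface.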





def compute_surface_overlap_at_tolerance(surface_distances, tolerance_mm):

  """Computes the overlap of the surfaces at a specified tolerance.



  Computes the overlap of the ground truth surface with the predicted surface

  and vice versa allowing a specified tolerance (maximum surface-to-surface

  distance that is regarded as overlapping). The overlapping fraction is

  computed by correctly taking the area of each surface element into account.



  Args:

    surface_distances: dict with "distances_gt_to_pred", "distances_pred_to_gt"

      "surfel_areas_gt", "surfel_areas_pred" created by

      compute_surface_distances()

    tolerance_mm: a float value. The tolerance in mm



  Returns:

    A tuple of two float values. The overlap fraction in [0.0, 1.0] of the

    ground truth surface with the predicted surface and vice versa.

  """

  distances_gt_to_pred = surface_distances["distances_gt_to_pred"]

  distances_pred_to_gt = surface_distances["distances_pred_to_gt"]

  surfel_areas_gt = surface_distances["surfel_areas_gt"]

  surfel_areas_pred = surface_distances["surfel_areas_pred"]

  rel_overlap_gt = (

      np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm]) /

      np.sum(surfel_areas_gt))

  rel_overlap_pred = (

      np.sum(surfel_areas_pred[distances_pred_to_gt <= tolerance_mm]) /

      np.sum(surfel_areas_pred))

  return (rel_overlap_gt, rel_overlap_pred)
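A toy run of the same mask-and-sum computation shows that the two overlap fractions need not agree. The values are hypothetical (distances in mm, surfel areas in mm²); the arrays stand in for the entries of the `surface_distances` dict.

```python
import numpy as np

# Hypothetical toy values: distances in mm, surfel areas in mm^2.
distances_gt_to_pred = np.array([0.0, 1.0, 2.0, 5.0])
surfel_areas_gt = np.array([1.0, 1.0, 1.0, 1.0])
distances_pred_to_gt = np.array([0.0, 0.5])
surfel_areas_pred = np.array([1.0, 3.0])
tolerance_mm = 1.0

# Select the surface elements within the tolerance, sum their areas, and
# normalise by the total area of that surface.
rel_overlap_gt = (np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm])
                  / np.sum(surfel_areas_gt))
rel_overlap_pred = (np.sum(surfel_areas_pred[distances_pred_to_gt <= tolerance_mm])
                    / np.sum(surfel_areas_pred))

print(rel_overlap_gt, rel_overlap_pred)  # 0.5 1.0
```

Here every predicted surface element lies within 1 mm of the ground truth, but only half of the ground-truth surface area lies within 1 mm of the prediction, a pattern an under-segmented structure could produce.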





def compute_surface_dice_at_tolerance(surface_distances, tolerance_mm):

  """Computes the _surface_ DICE coefficient at a specified tolerance.



  Computes the _surface_ DICE coefficient at a specified tolerance. Not to be

  confused with the standard _volumetric_ DICE coefficient. The surface DICE

  measures the overlap of two surfaces instead of two volumes. A surface

  element is counted as overlapping (or touching) when the closest distance to

  the other surface is less than or equal to the specified tolerance. The DICE

  coefficient ranges from 0.0 (no overlap) to 1.0 (perfect overlap).



  Args:

    surface_distances: dict with "distances_gt_to_pred", "distances_pred_to_gt",

      "surfel_areas_gt", "surfel_areas_pred" created by

      compute_surface_distances()

    tolerance_mm: a float value. The tolerance in mm.



  Returns:

    A float value. The surface DICE coefficient in [0.0, 1.0].

  """

  distances_gt_to_pred = surface_distances["distances_gt_to_pred"]

  distances_pred_to_gt = surface_distances["distances_pred_to_gt"]

  surfel_areas_gt = surface_distances["surfel_areas_gt"]

  surfel_areas_pred = surface_distances["surfel_areas_pred"]

  overlap_gt = np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm])

  overlap_pred = np.sum(surfel_areas_pred[distances_pred_to_gt <= tolerance_mm])

  surface_dice = (overlap_gt + overlap_pred) / (

      np.sum(surfel_areas_gt) + np.sum(surfel_areas_pred))

  return surface_dice
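Written out, the surface Dice computed above is an area-weighted average of the two overlap fractions returned by `compute_surface_overlap_at_tolerance`. In the notation below, d are directed surface distances, a are surfel areas, τ is the tolerance, A is the total area of each surface and o is its overlap fraction; the pred-side quantities are defined analogously to the gt-side ones.

```latex
\mathrm{SD}_\tau
  = \frac{\sum_{i:\,d^{gt}_i \le \tau} a^{gt}_i + \sum_{j:\,d^{pred}_j \le \tau} a^{pred}_j}
         {A_{gt} + A_{pred}}
  = \frac{A_{gt}\,o_{gt} + A_{pred}\,o_{pred}}{A_{gt} + A_{pred}},
\qquad
A_{gt} = \sum_i a^{gt}_i,\quad
o_{gt} = \frac{1}{A_{gt}} \sum_{i:\,d^{gt}_i \le \tau} a^{gt}_i
```

For example, A_gt = A_pred = 4 mm², o_gt = 0.5 and o_pred = 1.0 give SD = (4·0.5 + 4·1.0) / 8 = 0.75. The larger surface therefore dominates the score.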





def compute_dice_coefficient(mask_gt, mask_pred):

  """Computes the Soerensen-Dice coefficient.



  Computes the Soerensen-Dice coefficient between the ground truth mask `mask_gt`

  and the predicted mask `mask_pred`.



  Args:

    mask_gt: 3-dim Numpy array of type bool. The ground truth mask.

    mask_pred: 3-dim Numpy array of type bool. The predicted mask.



  Returns:

    The Dice coefficient as a float. If both masks are empty, the result is NaN.

  """

  volume_sum = mask_gt.sum() + mask_pred.sum()

  if volume_sum == 0:

    return np.nan

  volume_intersect = (mask_gt & mask_pred).sum()

  return 2*volume_intersect / volume_sum
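`compute_dice_coefficient` expects boolean masks: on integer label maps, `&` would be a bitwise AND and `.sum()` would add label values instead of counting voxels. The minimal sketch below assumes integer label maps with background 0 and classes 1 to 3 (the class count is an assumption, loosely based on the three classes logged in `utils/utils.py`) and builds per-class boolean masks with `==`. `dice_coefficient` is a hypothetical standalone copy of the formula so that the example runs on its own.

```python
import numpy as np


def dice_coefficient(mask_gt, mask_pred):
    # Same formula as compute_dice_coefficient: 2|A & B| / (|A| + |B|).
    volume_sum = mask_gt.sum() + mask_pred.sum()
    if volume_sum == 0:
        return np.nan  # both masks empty: Dice is undefined
    return 2 * (mask_gt & mask_pred).sum() / volume_sum


# Hypothetical 3-D integer label maps (0 = background, classes 1 to 3).
label_gt = np.zeros((2, 4, 4), dtype=np.int64)
label_pred = np.zeros((2, 4, 4), dtype=np.int64)
label_gt[:, :2, :2] = 1     # class 1: 8 voxels
label_pred[:, :2, :1] = 1   # class 1: 4 voxels, all inside the ground truth
label_gt[:, 2:, 2:] = 2     # class 2: 8 voxels, missed by the prediction
# Class 3 is absent from both maps.

# `== c` yields boolean masks, so `&` is a logical AND (set intersection).
per_class = [dice_coefficient(label_gt == c, label_pred == c) for c in (1, 2, 3)]
print(per_class)
```

Class 1 scores 2·4 / (8 + 4) ≈ 0.667, the missed class 2 scores 0.0, and class 3 returns NaN because it is empty in both maps. An average over classes will therefore need `np.nanmean` or an explicit policy for NaN.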


================================================
FILE: utils/utils.py
================================================
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import SimpleITK as sitk
import pdb 


def log_evaluation_result(writer, dice_list, ASD_list, HD_list, name, epoch):
    
    writer.add_scalar('Test_Dice/%s_AVG'%name, dice_list.mean(), epoch+1)
    for idx in range(3):
        writer.add_scalar('Test_Dice/%s_Dice%d'%(name, idx+1), dice_list[idx], epoch+1)
    writer.add_scalar('Test_ASD/%s_AVG'%name, ASD_list.mean(), epoch+1)
    for idx in range(3):
        writer.add_scalar('Test_ASD/%s_ASD%d'%(name, idx+1), ASD_list[idx], epoch+1)
    writer.add_scalar('Test_HD/%s_AVG'%name, HD_list.mean(), epoch+1)
    for idx in range(3):
        writer.add_scalar('Test_HD/%s_HD%d'%(name, idx+1), HD_list[idx], epoch+1)


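def multistep_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch, lr_decay_epoch, max_epoch, gamma=0.1):
    # Warmup: the lr grows exponentially from init_lr * 2.718**(-10) at epoch 0
    # to init_lr at epoch == warmup_epoch (2.718 approximates e).
    if 0 <= epoch <= warmup_epoch:
        if epoch == warmup_epoch:
            # Handled first, so warmup_epoch == 0 does not divide by zero.
            lr = init_lr
        else:
            lr = init_lr * 2.718 ** (10 * (float(epoch) / float(warmup_epoch) - 1.))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        return lr

    # Step decay: at the i-th epoch listed in lr_decay_epoch the lr is set to
    # init_lr * gamma**(i + 1); on any other epoch the current lr is kept.
    # max_epoch is not used by this scheduler.
    if epoch in lr_decay_epoch:
        i = list(lr_decay_epoch).index(epoch)
        lr = init_lr * gamma ** (i + 1)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        return lr

    return optimizer.param_groups[0]['lr']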
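`multistep_lr_scheduler_with_warmup` only writes a new learning rate during warmup and at the exact milestone epochs. When it is called once per epoch, its schedule can therefore be written in closed form. The sketch below is a hypothetical pure function (`multistep_lr_with_warmup`) that does not touch an optimizer. It assumes that `lr_decay_epoch` is sorted ascending and that every milestone lies after `warmup_epoch`.

```python
def multistep_lr_with_warmup(init_lr, epoch, warmup_epoch, lr_decay_epoch, gamma=0.1):
    """Hypothetical closed-form learning rate for a given epoch (no optimizer)."""
    if 0 <= epoch <= warmup_epoch:
        if epoch == warmup_epoch:
            return init_lr
        # Exponential warmup from init_lr * 2.718**(-10) at epoch 0 up to init_lr.
        return init_lr * 2.718 ** (10 * (epoch / warmup_epoch - 1.0))
    # After warmup: one factor of gamma per milestone already reached
    # (assumes milestones are sorted and all later than warmup_epoch).
    num_decays = sum(1 for milestone in lr_decay_epoch if epoch >= milestone)
    return init_lr * gamma ** num_decays


schedule = [multistep_lr_with_warmup(1e-3, e, 5, [50, 80]) for e in range(100)]
print(schedule[5], schedule[50], schedule[80])
```

Under these assumptions the closed form matches the stateful scheduler epoch by epoch. If the stateful scheduler is not called at a milestone epoch, however, it silently skips that decay step, whereas the closed form does not.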

def exp_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch, max_epoch):
    # Same exponential warmup as multistep_lr_scheduler_with_warmup.
    if 0 <= epoch <= warmup_epoch:
        if epoch == warmup_epoch:
            # Handled first, so warmup_epoch == 0 does not divide by zero.
            lr = init_lr
        else:
            lr = init_lr * 2.718 ** (10 * (float(epoch) / float(warmup_epoch) - 1.))

    else:
        # Despite the function name, the decay is polynomial ("poly", power 0.9).
        # Clamping at 0 avoids a complex result if epoch exceeds max_epoch.
        lr = init_lr * max(0., 1 - epoch / max_epoch) ** 0.9

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr
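The post-warmup branch of `exp_lr_scheduler_with_warmup` is polynomial ("poly") decay rather than exponential decay. The unclamped expression `(1 - epoch / max_epoch)**0.9` also turns into a complex number once `epoch` exceeds `max_epoch`, because Python raises a negative float to a fractional power in the complex domain. The minimal sketch below covers the decay alone. It uses a hypothetical helper `poly_decay_lr` and clamps the base at zero.

```python
def poly_decay_lr(init_lr, epoch, max_epoch, power=0.9):
    """Hypothetical helper: the post-warmup branch of exp_lr_scheduler_with_warmup."""
    # Polynomial ("poly") decay from init_lr at epoch 0 down to 0 at max_epoch.
    # max(0.0, ...) clamps epochs beyond max_epoch; without the clamp Python
    # returns a complex number for a negative base raised to a fractional power.
    return init_lr * max(0.0, 1.0 - epoch / max_epoch) ** power


lrs = [poly_decay_lr(1e-3, e, 100) for e in (10, 50, 99, 100, 120)]
unclamped = (1.0 - 120 / 100) ** 0.9  # what the formula gives without the clamp
print(type(unclamped).__name__)  # complex
```

With power 0.9 the curve is slightly concave, so the learning rate falls a little more slowly than a linear ramp over most of training and then drops toward zero near `max_epoch`.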



def cal_dice(pred, target, C):
    # pred, target: integer class-index tensors, expected shape (N, 1) so that
    # scatter_ along dim 1 builds (N, C) one-hot masks. C: number of classes.
    N = pred.shape[0]
    target_mask = target.new_zeros((N, C))
    target_mask.scatter_(1, target, 1)

    pred_mask = pred.new_zeros((N, C))
    pred_mask.scatter_(1, pred, 1)

    # Per-class intersection and |pred| + |target| counts.
    intersection = (pred_mask * target_mask).sum(0).type(torch.float32)
    summ = (pred_mask + target_mask).sum(0).type(torch.float32)

    # A small epsilon avoids 0/0 for classes absent from both pred and target.
    # Adding a Python scalar keeps the computation on summ's device (CPU or GPU).
    summ += 1e-7
    dice = 2 * intersection / summ

    return dice, intersection, summ
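`cal_dice` one-hot encodes class indices with `scatter_` and then sums per class. The sketch below reproduces the same arithmetic in NumPy so that it runs without a GPU. `per_class_dice` is a hypothetical name, and it takes flat label vectors of shape (N,), whereas `cal_dice` appears to expect (N, 1) index tensors for `scatter_` along dimension 1.

```python
import numpy as np


def per_class_dice(pred, target, num_classes, eps=1e-7):
    # One-hot encode flat label vectors into (N, C) indicator matrices; indexing
    # an identity matrix by label plays the role of scatter_ in cal_dice.
    pred_mask = np.eye(num_classes, dtype=np.float32)[pred]
    target_mask = np.eye(num_classes, dtype=np.float32)[target]
    intersection = (pred_mask * target_mask).sum(axis=0)
    summ = (pred_mask + target_mask).sum(axis=0) + eps  # eps avoids 0/0 for absent classes
    return 2 * intersection / summ


# Hypothetical flattened predictions and labels for 3 classes.
dice = per_class_dice(np.array([0, 1, 1, 2]), np.array([0, 1, 2, 2]), 3)
print(dice)  # roughly 1.0, 0.667, 0.667 per class
```

Because of the epsilon, a class that is absent from both inputs scores 0 rather than NaN. This differs from `compute_dice_coefficient` in `utils/metrics.py`, which returns NaN in that case.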

def cal_asd(itkPred, itkGT):
    # Average symmetric surface distance between two binary SimpleITK label images.
    # SimpleITK's SignedMaurerDistanceMap defaults to useImageSpacing=False, so unless
    # the images have unit spacing these distances are in voxel units, not mm.

    # |signed distance| to the ground-truth boundary for every voxel, plus the GT contour.
    reference_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(itkGT, squaredDistance=False))
    reference_surface = sitk.LabelContour(itkGT)

    statistics_image_filter = sitk.StatisticsImageFilter()
    statistics_image_filter.Execute(reference_surface)
    num_reference_surface_pixels = int(statistics_image_filter.GetSum())

    # The same quantities for the prediction.
    segmented_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(itkPred, squaredDistance=False))
    segmented_surface = sitk.LabelContour(itkPred)

    # Masking each distance map with the other mask's contour keeps only the
    # contour-to-contour distances in each direction.
    seg2ref_distance_map = reference_distance_map * sitk.Cast(segmented_surface, sitk.sitkFloat32)
    ref2seg_distance_map = segmented_distance_map * sitk.Cast(reference_surface, sitk.sitkFloat32)

    statistics_image_filter.Execute(segmented_surface)
    num_segmented_surface_pixels = int(statistics_image_filter.GetSum())

    # Selecting non-zero entries drops contour voxels whose distance is exactly 0,
    # so they are appended back as zeros to keep one distance per contour voxel.
    seg2ref_distance_map_arr = sitk.GetArrayViewFromImage(seg2ref_distance_map)
    seg2ref_distances = list(seg2ref_distance_map_arr[seg2ref_distance_map_arr!=0])
    seg2ref_distances = seg2ref_distances + \
                        list(np.zeros(num_segmented_surface_pixels - len(seg2ref_distances)))
    ref2seg_distance_map_arr = sitk.GetArrayViewFromImage(ref2seg_distance_map)
    ref2seg_distances = list(ref2seg_distance_map_arr[ref2seg_distance_map_arr!=0])
    ref2seg_distances = ref2seg_distances + \
                        list(np.zeros(num_reference_surface_pixels - len(ref2seg_distances)))

    # Pool both directions and average.
    all_surface_distances = seg2ref_distances + ref2seg_distances

    ASD = np.mean(all_surface_distances)

    return ASD
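`cal_asd` relies on SimpleITK's Maurer distance maps. To show what it averages without SimpleITK, here is a brute-force NumPy sketch for small 2-D binary masks. It extracts 4-connected boundary pixels and takes each boundary pixel's distance to the nearest boundary pixel of the other mask, in both directions. It then averages the pooled list. Distances are in pixel units. The boundary definition is an assumption and may not match `sitk.LabelContour` pixel for pixel, both masks must be non-empty, and the helper names `surface_pixels` and `average_surface_distance` are hypothetical.

```python
import numpy as np


def surface_pixels(mask):
    """Coordinates of foreground pixels with at least one 4-connected background neighbour."""
    padded = np.pad(mask, 1, constant_values=False)
    # A pixel is interior if it and its four neighbours are all foreground.
    interior = (padded[1:-1, 1:-1] & padded[:-2, 1:-1] & padded[2:, 1:-1]
                & padded[1:-1, :-2] & padded[1:-1, 2:])
    return np.argwhere(mask & ~interior)


def average_surface_distance(mask_pred, mask_gt):
    pred_pts = surface_pixels(mask_pred).astype(float)
    gt_pts = surface_pixels(mask_gt).astype(float)
    # Pairwise Euclidean distances between the two contours (fine for small masks).
    d = np.linalg.norm(pred_pts[:, None, :] - gt_pts[None, :, :], axis=-1)
    pred_to_gt = d.min(axis=1)  # each predicted boundary pixel -> nearest GT boundary pixel
    gt_to_pred = d.min(axis=0)  # each GT boundary pixel -> nearest predicted boundary pixel
    # Pool both directions and average, as cal_asd does with its concatenated lists.
    return np.concatenate([pred_to_gt, gt_to_pred]).mean()


gt = np.zeros((8, 8), dtype=bool)
gt[2:5, 2:5] = True
pred = np.zeros((8, 8), dtype=bool)
pred[2:5, 3:6] = True  # the same square shifted one pixel to the right
print(average_surface_distance(pred, gt))
```

The pairwise matrix is O(P·G) in memory, so this is only meant for illustration. A distance-transform approach like the one in `cal_asd` scales to full volumes.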

SYMBOL INDEX (230 symbols across 13 files)

FILE: data_preprocess.py
  function ResampleXYZAxis (line 7) | def ResampleXYZAxis(imImage, space=(1., 1., 1.), interp=sitk.sitkLinear):
  function ResampleFullImageToRef (line 23) | def ResampleFullImageToRef(imImage, imRef, interp=sitk.sitkNearestNeighb...
  function ResampleCMRImage (line 37) | def ResampleCMRImage(imImage, imLabel, save_path, patient_name, name, ta...

FILE: dataset_domain.py
  class CMRDataset (line 13) | class CMRDataset(Dataset):
    method __init__ (line 14) | def __init__(self, dataset_dir, mode='train', domain='A', crop_size=25...
    method __len__ (line 103) | def __len__(self):
    method preprocess (line 106) | def preprocess(self, itk_img, itk_lab):
    method __getitem__ (line 131) | def __getitem__(self, idx):
    method randcrop (line 162) | def randcrop(self, img, label):
    method center_crop (line 177) | def center_crop(self, img, label):
    method random_zoom_rotate (line 191) | def random_zoom_rotate(self, img, label):

FILE: losses.py
  class DiceLoss (line 8) | class DiceLoss(nn.Module):
    method __init__ (line 10) | def __init__(self, alpha=0.5, beta=0.5, size_average=True, reduce=True):
    method forward (line 18) | def forward(self, preds, targets):
  class FocalLoss (line 60) | class FocalLoss(nn.Module):
    method __init__ (line 61) | def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
    method forward (line 72) | def forward(self, preds, targets):

FILE: model/conv_trans_utils.py
  function conv3x3 (line 9) | def conv3x3(in_planes, out_planes, stride=1):
  function conv1x1 (line 11) | def conv1x1(in_planes, out_planes, stride=1):
  class depthwise_separable_conv (line 14) | class depthwise_separable_conv(nn.Module):
    method __init__ (line 15) | def __init__(self, in_ch, out_ch, stride=1, kernel_size=3, padding=1, ...
    method forward (line 20) | def forward(self, x):
  class Mlp (line 26) | class Mlp(nn.Module):
    method __init__ (line 27) | def __init__(self, in_ch, hid_ch=None, out_ch=None, act_layer=nn.GELU,...
    method forward (line 37) | def forward(self, x):
  class BasicBlock (line 46) | class BasicBlock(nn.Module):
    method __init__ (line 48) | def __init__(self, inplanes, planes, stride=1):
    method forward (line 65) | def forward(self, x):
  class BasicTransBlock (line 80) | class BasicTransBlock(nn.Module):
    method __init__ (line 82) | def __init__(self, in_ch, heads, dim_head, attn_drop=0., proj_drop=0.,...
    method forward (line 93) | def forward(self, x):
  class BasicTransDecoderBlock (line 109) | class BasicTransDecoderBlock(nn.Module):
    method __init__ (line 111) | def __init__(self, in_ch, out_ch, heads, dim_head, attn_drop=0., proj_...
    method forward (line 124) | def forward(self, x1, x2):
  class LinearAttention (line 150) | class LinearAttention(nn.Module):
    method __init__ (line 152) | def __init__(self, dim, heads=4, dim_head=64, attn_drop=0., proj_drop=...
    method forward (line 180) | def forward(self, x):
  class LinearAttentionDecoder (line 217) | class LinearAttentionDecoder(nn.Module):
    method __init__ (line 219) | def __init__(self, in_dim, out_dim, heads=4, dim_head=64, attn_drop=0....
    method forward (line 247) | def forward(self, q, x):
  class RelativePositionEmbedding (line 284) | class RelativePositionEmbedding(nn.Module):
    method __init__ (line 286) | def __init__(self, dim, shape):
    method forward (line 303) | def forward(self, q, Nh, H, W, dim_head):
    method relative_logits_1d (line 316) | def relative_logits_1d(self, q, rel_k, case):
  class RelativePositionBias (line 344) | class RelativePositionBias(nn.Module):
    method __init__ (line 348) | def __init__(self, num_heads, h, w):
    method forward (line 371) | def forward(self, H, W):
  class down_block_trans (line 385) | class down_block_trans(nn.Module):
    method __init__ (line 386) | def __init__(self, in_ch, out_ch, num_block, bottleneck=False, maxpool...
    method forward (line 411) | def forward(self, x):
  class up_block_trans (line 418) | class up_block_trans(nn.Module):
    method __init__ (line 419) | def __init__(self, in_ch, out_ch, num_block, bottleneck=False, heads=4...
    method forward (line 439) | def forward(self, x1, x2):
  class block_trans (line 447) | class block_trans(nn.Module):
    method __init__ (line 448) | def __init__(self, in_ch, num_block, heads=4, dim_head=64, attn_drop=0...
    method forward (line 462) | def forward(self, x):

FILE: model/resnet_utnet.py
  class ResNet_UTNet (line 13) | class ResNet_UTNet(nn.Module):
    method __init__ (line 14) | def __init__(self, in_ch, num_class, reduce_size=8, block_list='234', ...
    method forward (line 51) | def forward(self, x):

FILE: model/swin_unet.py
  class Mlp (line 48) | class Mlp(nn.Module):
    method __init__ (line 50) | def __init__(self, in_features, hidden_features=None, out_features=Non...
    method forward (line 68) | def forward(self, x):
  function window_partition (line 86) | def window_partition(x, window_size):
  function window_reverse (line 116) | def window_reverse(windows, window_size, H, W):
  class WindowAttention (line 150) | class WindowAttention(nn.Module):
    method __init__ (line 178) | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scal...
    method forward (line 246) | def forward(self, x, mask=None):
    method extra_repr (line 312) | def extra_repr(self) -> str:
    method flops (line 318) | def flops(self, N):
  class SwinTransformerBlock (line 346) | class SwinTransformerBlock(nn.Module):
    method __init__ (line 384) | def __init__(self, dim, input_resolution, num_heads, window_size=7, sh...
    method forward (line 486) | def forward(self, x):
    method extra_repr (line 564) | def extra_repr(self) -> str:
    method flops (line 571) | def flops(self):
  class PatchMerging (line 601) | class PatchMerging(nn.Module):
    method __init__ (line 619) | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
    method forward (line 633) | def forward(self, x):
    method extra_repr (line 679) | def extra_repr(self) -> str:
    method flops (line 685) | def flops(self):
  class PatchExpand (line 697) | class PatchExpand(nn.Module):
    method __init__ (line 699) | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.L...
    method forward (line 713) | def forward(self, x):
  class FinalPatchExpand_X4 (line 745) | class FinalPatchExpand_X4(nn.Module):
    method __init__ (line 747) | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.L...
    method forward (line 765) | def forward(self, x):
  class BasicLayer (line 797) | class BasicLayer(nn.Module):
    method __init__ (line 837) | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
    method forward (line 893) | def forward(self, x):
    method extra_repr (line 913) | def extra_repr(self) -> str:
    method flops (line 919) | def flops(self):
  class BasicLayer_up (line 935) | class BasicLayer_up(nn.Module):
    method __init__ (line 975) | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
    method forward (line 1031) | def forward(self, x):
  class PatchEmbed (line 1051) | class PatchEmbed(nn.Module):
    method __init__ (line 1073) | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=9...
    method forward (line 1111) | def forward(self, x):
    method flops (line 1130) | def flops(self):
  class SwinTransformerSys (line 1146) | class SwinTransformerSys(nn.Module):
    method __init__ (line 1198) | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes...
    method _init_weights (line 1384) | def _init_weights(self, m):
    method no_weight_decay (line 1404) | def no_weight_decay(self):
    method no_weight_decay_keywords (line 1412) | def no_weight_decay_keywords(self):
    method forward_features (line 1420) | def forward_features(self, x):
    method forward_up_features (line 1452) | def forward_up_features(self, x, x_downsample):
    method up_x4 (line 1478) | def up_x4(self, x):
    method forward (line 1504) | def forward(self, x):
    method flops (line 1518) | def flops(self):
  class SwinUnet_config (line 1541) | class SwinUnet_config():
    method __init__ (line 1542) | def __init__(self):
  class SwinUnet (line 1559) | class SwinUnet(nn.Module):
    method __init__ (line 1561) | def __init__(self, config, img_size=224, num_classes=21843, zero_head=...
    method forward (line 1607) | def forward(self, x):
    method load_from (line 1619) | def load_from(self, pretrained_path):

FILE: model/transunet.py
  function np2th (line 72) | def np2th(weights, conv=False):
  function swish (line 86) | def swish(x):
  class Attention (line 100) | class Attention(nn.Module):
    method __init__ (line 102) | def __init__(self, config, vis):
    method transpose_for_scores (line 136) | def transpose_for_scores(self, x):
    method forward (line 146) | def forward(self, hidden_states):
  class Mlp (line 194) | class Mlp(nn.Module):
    method __init__ (line 196) | def __init__(self, config):
    method _init_weights (line 214) | def _init_weights(self):
    method forward (line 226) | def forward(self, x):
  class Embeddings (line 244) | class Embeddings(nn.Module):
    method __init__ (line 250) | def __init__(self, config, img_size, in_channels=3):
    method forward (line 308) | def forward(self, x):
  class Block (line 336) | class Block(nn.Module):
    method __init__ (line 338) | def __init__(self, config, vis):
    method forward (line 354) | def forward(self, x):
    method load_from (line 378) | def load_from(self, weights, n_block):
  class Encoder (line 454) | class Encoder(nn.Module):
    method __init__ (line 456) | def __init__(self, config, vis):
    method forward (line 474) | def forward(self, hidden_states):
  class Transformer (line 494) | class Transformer(nn.Module):
    method __init__ (line 496) | def __init__(self, config, img_size, vis):
    method forward (line 506) | def forward(self, input_ids):
  class Conv2dReLU (line 518) | class Conv2dReLU(nn.Sequential):
    method __init__ (line 520) | def __init__(
  class DecoderBlock (line 568) | class DecoderBlock(nn.Module):
    method __init__ (line 570) | def __init__(
    method forward (line 618) | def forward(self, x, skip=None):
  class SegmentationHead (line 636) | class SegmentationHead(nn.Sequential):
    method __init__ (line 640) | def __init__(self, in_channels, out_channels, kernel_size=3, upsamplin...
  class DecoderCup (line 652) | class DecoderCup(nn.Module):
    method __init__ (line 654) | def __init__(self, config):
    method forward (line 710) | def forward(self, hidden_states, features=None):
  class VisionTransformer (line 740) | class VisionTransformer(nn.Module):
    method __init__ (line 742) | def __init__(self, config, img_size=224, num_classes=21843, zero_head=...
    method forward (line 770) | def forward(self, x):
    method load_from (line 786) | def load_from(self, weights):
  function get_b16_config (line 885) | def get_b16_config():
  function get_testing (line 934) | def get_testing():
  function get_r50_b16_config (line 964) | def get_r50_b16_config():
  function get_b32_config (line 1002) | def get_b32_config():
  function get_l16_config (line 1018) | def get_l16_config():
  function get_r50_l16_config (line 1064) | def get_r50_l16_config():
  function get_l32_config (line 1098) | def get_l32_config():
  function get_h14_config (line 1112) | def get_h14_config():
  function np2th (line 1176) | def np2th(weights, conv=False):
  class StdConv2d (line 1190) | class StdConv2d(nn.Conv2d):
    method forward (line 1194) | def forward(self, x):
  function conv3x3 (line 1210) | def conv3x3(cin, cout, stride=1, groups=1, bias=False):
  function conv1x1 (line 1220) | def conv1x1(cin, cout, stride=1, bias=False):
  class PreActBottleneck (line 1230) | class PreActBottleneck(nn.Module):
    method __init__ (line 1238) | def __init__(self, cin, cout=None, cmid=None, stride=1):
    method forward (line 1274) | def forward(self, x):
    method load_from (line 1306) | def load_from(self, weights, n_block, n_unit):
  class ResNetV2 (line 1378) | class ResNetV2(nn.Module):
    method __init__ (line 1384) | def __init__(self, block_units, width_factor):
    method forward (line 1438) | def forward(self, x):

FILE: model/unet_utils.py
  class DoubleConv (line 6) | class DoubleConv(nn.Module):
    method __init__ (line 12) | def __init__(self, in_channels, out_channels, mid_channels=None):
    method forward (line 38) | def forward(self, x):
  class Down (line 46) | class Down(nn.Module):
    method __init__ (line 52) | def __init__(self, in_channels, out_channels):
    method forward (line 66) | def forward(self, x):
  class Up (line 74) | class Up(nn.Module):
    method __init__ (line 80) | def __init__(self, in_channels, out_channels, bilinear=True):
    method forward (line 104) | def forward(self, x1, x2):
  class OutConv (line 134) | class OutConv(nn.Module):
    method __init__ (line 136) | def __init__(self, in_channels, out_channels):
    method forward (line 144) | def forward(self, x):
  function conv3x3 (line 149) | def conv3x3(in_planes, out_planes, stride=1):
  function conv1x1 (line 151) | def conv1x1(in_planes, out_planes, stride=1):
  class BasicBlock (line 156) | class BasicBlock(nn.Module):
    method __init__ (line 158) | def __init__(self, inplanes, planes, stride=1):
    method forward (line 175) | def forward(self, x):
  class BottleneckBlock (line 190) | class BottleneckBlock(nn.Module):
    method __init__ (line 192) | def __init__(self, inplanes, planes, stride=1):
    method forward (line 212) | def forward(self, x):
  class inconv (line 233) | class inconv(nn.Module):
    method __init__ (line 234) | def __init__(self, in_ch, out_ch, bottleneck=False):
    method forward (line 244) | def forward(self, x):
  class down_block (line 251) | class down_block(nn.Module):
    method __init__ (line 252) | def __init__(self, in_ch, out_ch, scale, num_block, bottleneck=False, ...
    method forward (line 274) | def forward(self, x):
  class up_block (line 280) | class up_block(nn.Module):
    method __init__ (line 281) | def __init__(self, in_ch, out_ch, num_block, scale=(2,2),bottleneck=Fa...
    method forward (line 301) | def forward(self, x1, x2):

FILE: model/utnet.py
  class UTNet (line 11) | class UTNet(nn.Module):
    method __init__ (line 13) | def __init__(self, in_chan, base_chan, num_classes=1, reduce_size=8, b...
    method forward (line 66) | def forward(self, x):
  class UTNet_Encoderonly (line 105) | class UTNet_Encoderonly(nn.Module):
    method __init__ (line 107) | def __init__(self, in_chan, base_chan, num_classes=1, reduce_size=8, b...
    method forward (line 158) | def forward(self, x):

FILE: train_deep.py
  function train_net (line 28) | def train_net(net, options):
  function validation (line 141) | def validation(net, test_loader, options):
  function cal_distance (line 187) | def cal_distance(label_pred, label_true, spacing):
  function get_comma_separated_int_args (line 210) | def get_comma_separated_int_args(option, opt, value, parser):

FILE: utils/lookup_tables.py
  function create_table_neighbour_code_to_surface_area (line 591) | def create_table_neighbour_code_to_surface_area(spacing_mm):
  function create_table_neighbour_code_to_contour_length (line 655) | def create_table_neighbour_code_to_contour_length(spacing_mm):

FILE: utils/metrics.py
  function _assert_is_numpy_array (line 49) | def _assert_is_numpy_array(name, array):
  function _check_nd_numpy_array (line 63) | def _check_nd_numpy_array(name, array, num_dims):
  function _check_2d_numpy_array (line 77) | def _check_2d_numpy_array(name, array):
  function _check_3d_numpy_array (line 85) | def _check_3d_numpy_array(name, array):
  function _assert_is_bool_numpy_array (line 93) | def _assert_is_bool_numpy_array(name, array):
  function _compute_bounding_box (line 107) | def _compute_bounding_box(mask):
  function _crop_to_bounding_box (line 187) | def _crop_to_bounding_box(mask, bbox_min, bbox_max):
  function _sort_distances_surfels (line 237) | def _sort_distances_surfels(distances, surfel_areas):
  function compute_surface_distances (line 265) | def compute_surface_distances(mask_gt,
  function compute_average_surface_distance (line 579) | def compute_average_surface_distance(surface_distances):
  function compute_robust_hausdorff (line 641) | def compute_robust_hausdorff(surface_distances, percent):
  function compute_surface_overlap_at_tolerance (line 723) | def compute_surface_overlap_at_tolerance(surface_distances, tolerance_mm):
  function compute_surface_dice_at_tolerance (line 785) | def compute_surface_dice_at_tolerance(surface_distances, tolerance_mm):
  function compute_dice_coefficient (line 845) | def compute_dice_coefficient(mask_gt, mask_pred):

FILE: utils/utils.py
  function log_evaluation_result (line 9) | def log_evaluation_result(writer, dice_list, ASD_list, HD_list, name, ep...
  function multistep_lr_scheduler_with_warmup (line 22) | def multistep_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup...
  function exp_lr_scheduler_with_warmup (line 49) | def exp_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch...
  function cal_dice (line 69) | def cal_dice(pred, target, C):
  function cal_asd (line 91) | def cal_asd(itkPred, itkGT):