Repository: MIC-DKFZ/basic_unet_example
Branch: master
Commit: 063353b31517
Files: 45
Total size: 149.7 KB
Directory structure:
gitextract_9r9wvxz8/
├── .gitignore
├── LICENSE
├── Readme.md
├── __init__.py
├── configs/
│ ├── Config_unet.py
│ ├── Config_unet_spleen.py
│ └── __init__.py
├── datasets/
│ ├── __init__.py
│ ├── data_loader.py
│ ├── example_dataset/
│ │ ├── __init__.py
│ │ ├── create_splits.py
│ │ └── preprocessing.py
│ ├── spleen/
│ │ ├── __init__.py
│ │ ├── create_splits.py
│ │ └── preprocessing.py
│ ├── three_dim/
│ │ ├── NumpyDataLoader.py
│ │ ├── __init__.py
│ │ └── data_augmentation.py
│ ├── two_dim/
│ │ ├── NumpyDataLoader.py
│ │ ├── __init__.py
│ │ └── data_augmentation.py
│ └── utils.py
├── evaluation/
│ ├── __init__.py
│ ├── evaluator.py
│ ├── metrics.py
│ └── readme.md
├── experiments/
│ ├── UNetExperiment.py
│ ├── UNetExperiment3D.py
│ └── __init__.py
├── loss_functions/
│ ├── ND_Crossentropy.py
│ ├── __init__.py
│ ├── dice_loss.py
│ └── topk_loss.py
├── networks/
│ ├── RecursiveUNet.py
│ ├── RecursiveUNet3D.py
│ └── UNET.py
├── requirements.txt
├── run_preprocessing.py
├── run_train_pipeline.py
├── runner.py
├── segment_a_spleen.py
├── train.py
├── train3D.py
└── utilities/
├── __init__.py
└── file_and_folder_operations.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.idea
*.pyc
.DS_Store
*.egg-info
.pytest_cache/*
data
output_experiment
venv
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: Readme.md
================================================
# Basic U-Net example by MIC@DKFZ
Copyright © German Cancer Research Center (DKFZ), Division of Medical Image Computing (MIC). Please make sure that your usage of this code is in compliance with the code license:
[License](https://github.com/MIC-DKFZ/basic_unet_example/blob/master/LICENSE)
This python code is an example project of how to use a U-Net [1] for segmentation on medical images using PyTorch (https://www.pytorch.org).
It was developed at the Division of Medical Image Computing at the German Cancer Research Center (DKFZ).
It is also an example of how to use our other python packages batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) and
Trixi (https://github.com/MIC-DKFZ/trixi) [2] to suit all our deep learning data augmentation needs.
If you have any questions or issues or you encounter a bug, feel free to contact us, open a GitHub issue or ask the community on Gitter:
[Gitter](https://gitter.im/basic-Unet/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
> **WARNING**: This repo was implemented and tested on Linux. We highly recommend using it within a Linux environment. If you use Windows, you might experience some issues (see
> section "Errors and how to handle them").
## How to set it up
The example is very easy to use. Just create a new virtual environment in python and install the requirements.
This example requires python3. It was implemented with python 3.5.
> **WARNING**: The newest supported version is python 3.7.9. For newer python versions there are some requirements that are not available in the needed version.
```
pip3 install -r requirements.txt
```
In this example code, we show how to use visdom for live visualization. See the Trixi documentation for more details or information about other tools like tensorboard.
After setting up the virtual environment, start visdom once so that it can download the files it needs.
You can stop the visdom server after a few seconds, once the download has finished.
```
python3 -m visdom.server
```
You can edit the paths for data storage and logging in the config file. By default, everything is stored in your working directory.
## How to use it
To start the training simply run
```
python3 run_train_pipeline.py
```
This will download the Hippocampus dataset from the medical segmentation decathlon (http://medicaldecathlon.com),
extract and preprocess it and then start the training. The preprocessing loads the images (imagesTr) and the corresponding labels (labelsTr), performs some normalization and padding operations and saves the data as NPY files. The available images are then split into `train`, `validation` and `test` sets.
The splits are saved to a `splits.pkl` file. The images in `imagesTs` are not used in the example, because they are the test set for the medical segmentation decathlon and
therefore no ground truth is provided.
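The normalization and padding steps mentioned above might look roughly like the following sketch. It is not the repository's `preprocessing.py`: the function names `normalize` and `pad_to_square`, the min-max scaling and the zero padding are assumptions chosen for illustration.

```python
import numpy as np


def normalize(volume, eps=1e-8):
    """Scale intensities to [0, 1] (min-max scaling, assumed here for illustration)."""
    volume = volume.astype(np.float32)
    v_min, v_max = volume.min(), volume.max()
    return (volume - v_min) / (v_max - v_min + eps)


def pad_to_square(volume, target_size):
    """Zero-pad the last two axes of a (slices, w, h) volume to target_size x target_size."""
    _, w, h = volume.shape
    if w > target_size or h > target_size:
        raise ValueError("target_size must be at least as large as the slice size")
    pad_w, pad_h = target_size - w, target_size - h
    # Split the padding evenly; any odd remainder goes to the end of the axis.
    pad_width = ((0, 0), (pad_w // 2, pad_w - pad_w // 2), (pad_h // 2, pad_h - pad_h // 2))
    return np.pad(volume, pad_width, mode="constant", constant_values=0)


image = np.random.rand(30, 50, 35) * 255  # synthetic stand-in for a loaded scan
padded = pad_to_square(normalize(image), target_size=64)
print(padded.shape)  # (30, 64, 64)
```

Labels would need the same padding so that image and segmentation stay aligned, but they should not be normalized.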
If you run the pipeline again, the dataset will not be downloaded, extracted or preprocessed again. To force these steps to run again, delete the dataset folder.
The training process is visualized automatically using trixi/visdom. After starting the training, open your browser at the port printed by the training script. You should then see your loss curve, among other things.
By default, a 2-dimensional U-Net is used. The example also comes with a 3-D version of the network (Özgün Cicek et al.).
To use the 3-D version, simply run
```
python train3D.py
```
> **WARNING**: The 3-D version is not yet tested thoroughly. Use it with caution!
## How to use it for your own data
This description is a work in progress. If you use this repo for your own data, please share your experience so we can update this part.
### Config
The included `Config_unet.py` is an example config file. You have to adapt this to fit your local environment, e.g., if you run out of CUDA memory, try to reduce `batch_size` or
`patch_size`. All the other parameters should be self-explanatory or described directly in the code comments.
Choose the `#Train parameters` to fit both your data and your workstation.
With `fold` you can choose which split from your `splits.pkl` you want to use for the training.
You may also need to adapt the paths (`data_root_dir`, `data_dir`, `data_test_dir` and `split_dir`).
You can change the `Logging parameters` if you want to. With `append_rnd_string`, you can give each experiment you start a unique name.
If you want to start your visdom server manually, just set `start_visdom=False`. If you do not want to use visdom logging at all, just remove the visdom logger from your
experiment, e.g. `run_train_pipeline.py` line 47:
```
loggers={
    "visdom": ("visdom", {"auto_start": c.start_visdom})
}
```
### Datasets
If you want to use the provided DataLoader, you need to preprocess your data appropriately. An example can be found in the
"example_dataset" folder. Make sure to load your images and your labels as numpy arrays. The required shape is `(#slices, w,h)`.
Then save both using:
```
result = np.stack((image, label))
np.save(output_filename, result)
```
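To check that a saved file has the expected layout, it can be loaded back and split into image and label again. The sketch below uses synthetic arrays and a hypothetical file name (`case_001.npy`). Note that `np.stack` promotes both arrays to a common dtype, so the label is stored as float here.

```python
import os
import tempfile

import numpy as np

# Synthetic stand-ins for one scan and its segmentation, both shaped (#slices, w, h).
image = np.random.rand(12, 64, 64).astype(np.float32)
label = np.random.randint(0, 3, size=(12, 64, 64)).astype(np.float32)

with tempfile.TemporaryDirectory() as tmp_dir:
    output_filename = os.path.join(tmp_dir, "case_001.npy")  # hypothetical file name
    np.save(output_filename, np.stack((image, label)))
    loaded = np.load(output_filename)

# Axis 0 separates the image (index 0) from the label (index 1).
loaded_image, loaded_label = loaded[0], loaded[1]
print(loaded.shape)  # (2, 12, 64, 64)
```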
The provided DataLoader requires a `splits.pkl` file that contains a list of dictionaries, one per fold, each naming the files used for training, validation and testing.
It looks like this:
```
[{'train': ['dataset_name_1',...], 'val': ['dataset_name_2', ...], 'test': ['dataset_name_3', ...]}]
```
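A file in this format can be written and read with the standard `pickle` module. The following is a minimal sketch with hypothetical case names and two folds; the variable names `tr_keys`, `val_keys` and `test_keys` simply mirror those used elsewhere in this README.

```python
import os
import pickle
import tempfile

# Hypothetical case names; in practice these are the .npy file names without extension.
splits = [
    {"train": ["case_001", "case_002"], "val": ["case_003"], "test": ["case_004"]},
    {"train": ["case_001", "case_003"], "val": ["case_002"], "test": ["case_004"]},
]

with tempfile.TemporaryDirectory() as split_dir:
    splits_path = os.path.join(split_dir, "splits.pkl")
    with open(splits_path, "wb") as f:
        pickle.dump(splits, f)

    with open(splits_path, "rb") as f:
        loaded_splits = pickle.load(f)

fold = 1  # plays the role of the `fold` config parameter
tr_keys = loaded_splits[fold]["train"]
val_keys = loaded_splits[fold]["val"]
test_keys = loaded_splits[fold]["test"]
print(tr_keys)  # ['case_001', 'case_003']
```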
We use the `MIC/batchgenerators` to perform data augmentation. The example uses cropping, mirroring and some elastic spatial transformation.
You can change the data augmentation by editing the `data_augmentation.py`. Please see the `MIC/batchgenerators` documentation for more details.
To train your network, simply run
```
python train.py
```
You can either edit the config file or add command line parameters like this:
```
python train.py --n_epochs 100 [...]
```
## Networks
This example contains a simple implementation of the U-Net [1], which can be found in `networks>UNET.py`.
A slightly more generic version of the U-Net, as well as the 3D U-Net [3], can be found in `networks>RecursiveUNet.py`
and `networks>RecursiveUNet3D.py`, respectively. These implementations are recursive.
It is therefore very easy to configure the number of downsamplings. Also, the type of normalization can be passed as a parameter (default is nn.InstanceNorm2d).
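When changing the number of downsamplings, it may help to check that `patch_size` is still compatible. The sketch below assumes that each downsampling step halves the spatial resolution and that skip connections are concatenated without cropping, so the patch size has to be divisible by `2 ** num_downs`. Both helper functions and the parameter name `num_downs` are hypothetical and not taken from the repository.

```python
def check_patch_size(patch_size, num_downs):
    """Return True if patch_size can be halved num_downs times without a remainder."""
    factor = 2 ** num_downs
    return patch_size % factor == 0


def nearest_valid_patch_size(patch_size, num_downs):
    """Round patch_size up to the next multiple of 2 ** num_downs."""
    factor = 2 ** num_downs
    return ((patch_size + factor - 1) // factor) * factor


print(check_patch_size(64, num_downs=4))           # True
print(check_patch_size(100, num_downs=4))          # False
print(nearest_valid_patch_size(100, num_downs=4))  # 112
```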
## Errors and how to handle them
In this section, we want to collect common errors that may occur when using this repository.
If you encounter something, feel free to let us know about it and we will include it here.
### Windows related issues
If you want to use this repo on Windows, please note that you have to adapt a few things.
We recommend installing PyTorch via conda on Windows using: `python -m conda install pytorch torchvision cpuonly -c pytorch`
You then have to remove torch from the requirements.txt.
If you run into issues like the following one:
```
AttributeError: Can't pickle local object 'MultiThreadedDataLoader.get_worker_init_fn.<locals>.init_fn'
```
try to use SingleProcessDataLoader instead. This error is probably caused by how multiprocessing is handled in Python on Windows.
To fix this, add `num_processes=0` to your dataloaders:
```
self.train_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size,
                                      batch_size=self.config.batch_size, keys=tr_keys, num_processes=0)
self.val_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size,
                                    batch_size=self.config.batch_size, keys=val_keys, mode="val", do_reshuffle=False, num_processes=0)
self.test_data_loader = NumpyDataSet(self.config.data_test_dir, target_size=self.config.patch_size,
                                     batch_size=self.config.batch_size, keys=test_keys, mode="test", do_reshuffle=False, num_processes=0)
```
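The error message points at a function defined inside another function. Such local functions cannot be pickled, and starting worker processes on Windows requires pickling the objects handed to them. The sketch below reproduces this with the standard library only. `make_local_init_fn` mimics the structure of `get_worker_init_fn`, while `SeededInit` and `can_pickle` are hypothetical names introduced for illustration. The functions return a number instead of setting a seed to keep the example self-contained.

```python
import pickle


def make_local_init_fn(cntr):
    """Mimics get_worker_init_fn: returns a function defined inside another function."""
    def init_fn(worker_id):
        return worker_id + cntr
    return init_fn


class SeededInit:
    """Hypothetical alternative: a module-level callable class holding the counter."""

    def __init__(self, cntr):
        self.cntr = cntr

    def __call__(self, worker_id):
        return worker_id + self.cntr


def can_pickle(obj):
    """Return True if obj survives pickle.dumps, which process spawning relies on."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, TypeError, pickle.PicklingError):
        return False


print(can_pickle(make_local_init_fn(1)))  # False
```

An instance of a class defined at module level, such as `SeededInit`, can normally be pickled, so it could serve as a worker init function if you need more than zero worker processes. This has not been tested with this repository.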
### Multiple Labels
Depending on your dataset you might be dealing with multiple labels. For example the
data from BRATS (https://www.med.upenn.edu/sbia/brats2017.html) has the following labels:
```
"labels": {
    "0": "background",
    "1": "edema",
    "2": "non-enhancing tumor",
    "3": "enhancing tumour"
},
```
* If you run into an error like this:
```
Experiment exited. Checkpoints stored =)
INFO:default-z3HafHO4CS:Experiment exited. Checkpoints stored =)
Unhandled exception in thread started by <function PytorchExperimentLogger.save_checkpoint_static at 0x7fd07c3e8510>
Traceback (most recent call last):
File "/python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py", line 196, in save_checkpoint_static
torch.save(to_cpu(kwargs), checkpoint_file)
File "/python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py", line 191, in to_cpu
return {key: to_cpu(val) for key, val in obj.items()}
File "//python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py", line 191, in <dictcomp>
return {key: to_cpu(val) for key, val in obj.items()}
File "/python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py", line 191, in to_cpu
return {key: to_cpu(val) for key, val in obj.items()}
File "/python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py", line 191, in <dictcomp>
return {key: to_cpu(val) for key, val in obj.items()}
File "/python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py", line 189, in to_cpu
return obj.cpu()
RuntimeError: CUDA error: device-side assert triggered
```
make sure you updated `num_classes` in your config file. The value of `num_classes` should always
equal the number of your labels including background.
* If you run into an error like this:
```
File "/home/student/basic_unet/trixi/trixi/experiment/experiment.py", line 108, in run
self.process_err(e)
File "/home/student/basic_unet/trixi/trixi/experiment/pytorchexperiment.py", line 391, in process_err
raise e
File "/home/student/basic_unet/trixi/trixi/experiment/experiment.py", line 89, in run
self.train(epoch=self._epoch_idx)
File "/home/student/PycharmProjects/new_unet/experiments/UNetExperiment.py", line 113, in train
loss = self.dice_loss(pred_softmax, target.squeeze()) + self.ce_loss(pred, target.squeeze())
File "/opt/anaconda3/envs/a_new_test/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in call
result = self.forward(input, *kwargs)
File "/home/student/PycharmProjects/new_unet/loss_functions/dice_loss.py", line 125, in forward
yonehot.scatter(1, y, 1)
RuntimeError: Invalid index in scatter at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:551
```
make sure to check your labels again. The error may be caused by the fact that the labels are not sequential. This causes `scatter` to crash. Consider changing the values of your labels.
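Both problems can be checked before training. The following sketch, using NumPy only, lists label values that fall outside `0 .. num_classes - 1` and remaps arbitrary label values to consecutive integers. The function names are hypothetical and the label array is synthetic.

```python
import numpy as np


def find_out_of_range_labels(label_volume, num_classes):
    """Return the label values that do not fit into 0 .. num_classes - 1."""
    labels = np.unique(label_volume)
    return labels[(labels < 0) | (labels >= num_classes)].tolist()


def make_labels_sequential(label_volume):
    """Map the distinct label values onto 0 .. K-1, preserving their order."""
    values, inverse = np.unique(label_volume, return_inverse=True)
    return inverse.reshape(label_volume.shape), values.tolist()


label = np.array([[0, 1, 1], [2, 4, 0]])  # four labels, but the value 3 is skipped
print(find_out_of_range_labels(label, num_classes=4))  # [4]

remapped, original_values = make_labels_sequential(label)
print(original_values)  # [0, 1, 2, 4]
print(find_out_of_range_labels(remapped, num_classes=4))  # []
```

Keep `original_values` if you need to map predictions back to the original label values later.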
## References
[1] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation."
International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.
[2] David Zimmerer, Jens Petersen, GregorKoehler, Jakob Wasserthal, dzimmm, Tim, … André Pequeño. (2018, November 23). MIC-DKFZ/trixi: Alpha (Version v0.1.1).
Zenodo. http://doi.org/10.5281/zenodo.1495180
[3] Çiçek, Özgün, et al. "3D U-Net: learning dense volumetric segmentation from sparse annotation."
International conference on medical image computing and computer-assisted intervention. Springer, Cham, 2016.
================================================
FILE: __init__.py
================================================
================================================
FILE: configs/Config_unet.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from trixi.util import Config
def get_config():
    # Set your own path, if needed.
    data_root_dir = os.path.abspath('data')  # The path where the downloaded dataset is stored.

    c = Config(
        update_from_argv=True,  # If set to 'True', each configuration value can be updated by a cmd/terminal parameter.

        # Train parameters
        num_classes=3,
        in_channels=1,
        batch_size=8,
        patch_size=64,
        n_epochs=10,
        learning_rate=0.0002,
        fold=0,  # The 'splits.pkl' may contain multiple folds. Here we choose which one we want to use.

        device="cuda",  # 'cuda' is the default CUDA device, you can also use 'cpu'. For more information, see https://pytorch.org/docs/stable/notes/cuda.html

        # Logging parameters
        name='Basic_Unet',
        author='kleina',  # Author of this project
        plot_freq=10,  # How often should stuff be shown in visdom
        append_rnd_string=False,  # Appends a random string to the experiment name to make it unique.
        start_visdom=True,  # You can either start a visdom server manually or have trixi start it for you.

        do_instancenorm=True,  # Defines whether or not the UNet does an instance normalization in the contracting path
        do_load_checkpoint=False,
        checkpoint_dir='',

        # Adapt to your own path, if needed.
        google_drive_id='1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C',  # This id is used to download the example dataset.
        dataset_name='Task04_Hippocampus',
        base_dir=os.path.abspath('output_experiment'),  # Where to log the output of the experiment.

        data_root_dir=data_root_dir,  # The path where the downloaded dataset is stored.
        data_dir=os.path.join(data_root_dir, 'Task04_Hippocampus/preprocessed'),  # This is where your training and validation data is stored
        data_test_dir=os.path.join(data_root_dir, 'Task04_Hippocampus/preprocessed'),  # This is where your test data is stored

        split_dir=os.path.join(data_root_dir, 'Task04_Hippocampus'),  # This is where the 'splits.pkl' file is located, which holds your splits.

        # Execute a segmentation process on a specific image using the model
        model_dir=os.path.join(os.path.abspath('output_experiment'), ''),  # The model being used for segmentation
    )

    print(c)
    return c
================================================
FILE: configs/Config_unet_spleen.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from trixi.util import Config
def get_config():
    # Set your own path, if needed.
    data_root_dir = os.path.abspath('data')  # The path where the downloaded dataset is stored.

    c = Config(
        update_from_argv=True,  # If set to 'True', each configuration value can be updated by a cmd/terminal parameter.

        # Train parameters
        num_classes=2,
        in_channels=1,
        batch_size=3,  # works with 6 on GB GPU
        patch_size=512,
        n_epochs=1,
        learning_rate=0.0002,
        fold=0,  # The 'splits.pkl' may contain multiple folds. Here we choose which one we want to use.

        device="cuda",  # 'cuda' is the default CUDA device, you can also use 'cpu'. For more information, see https://pytorch.org/docs/stable/notes/cuda.html

        # Logging parameters
        name='Basic_Unet',
        author='kleina',  # Author of this project
        plot_freq=10,  # How often should stuff be shown in visdom
        append_rnd_string=False,  # Appends a random string to the experiment name to make it unique.
        start_visdom=True,  # You can either start a visdom server manually or have trixi start it for you.

        do_instancenorm=True,  # Defines whether or not the UNet does an instance normalization in the contracting path
        do_load_checkpoint=False,
        checkpoint_dir='',

        # Adapt to your own path, if needed.
        google_drive_id='1jzeNU1EKnK81PyTsrx0ujfNl-t0Jo8uE',  # spleen
        dataset_name='Task09_Spleen',
        base_dir=os.path.abspath('output_experiment'),  # Where to log the output of the experiment.

        data_root_dir=data_root_dir,  # The path where the downloaded dataset is stored.
        data_dir=os.path.join(data_root_dir, 'Task09_Spleen/preprocessed'),  # This is where your training and validation data is stored
        data_test_dir=os.path.join(data_root_dir, 'Task09_Spleen/preprocessed'),  # This is where your test data is stored

        split_dir=os.path.join(data_root_dir, 'Task09_Spleen'),  # This is where the 'splits.pkl' file is located, which holds your splits.

        # Execute a segmentation process on a specific image using the model
        model_dir=os.path.join(os.path.abspath('output_experiment'), '20200108-035420_Basic_Unet/checkpoint/checkpoint_current'),  # The model being used for segmentation
    )

    print(c)
    return c
================================================
FILE: configs/__init__.py
================================================
================================================
FILE: datasets/__init__.py
================================================
================================================
FILE: datasets/data_loader.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.utils.data import DataLoader, Dataset
from trixi.util.pytorchutils import set_seed
class WrappedDataset(Dataset):
    def __init__(self, dataset, transform):
        self.transform = transform
        self.dataset = dataset

        # Fall back to next() if the dataset cannot be indexed or explicitly asks for it.
        self.is_indexable = False
        if hasattr(self.dataset, "__getitem__") and not (hasattr(self.dataset, "use_next") and self.dataset.use_next is True):
            self.is_indexable = True

    def __getitem__(self, index):
        if not self.is_indexable:
            item = next(self.dataset)
        else:
            item = self.dataset[index]
        item = self.transform(**item)
        return item

    def __len__(self):
        return int(self.dataset.num_batches)


class MultiThreadedDataLoader(object):
    def __init__(self, data_loader, transform, num_processes, **kwargs):
        self.cntr = 1
        self.ds_wrapper = WrappedDataset(data_loader, transform)

        self.generator = DataLoader(self.ds_wrapper, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                                    num_workers=num_processes, pin_memory=True, drop_last=False,
                                    worker_init_fn=self.get_worker_init_fn())

        self.num_processes = num_processes
        self.iter = None

    def get_worker_init_fn(self):
        # Each worker gets its own seed, which changes whenever renew() increments the counter.
        def init_fn(worker_id):
            set_seed(worker_id + self.cntr)

        return init_fn

    def __iter__(self):
        self.kill_iterator()
        self.iter = iter(self.generator)
        return self.iter

    def __next__(self):
        if self.iter is None:
            self.iter = iter(self.generator)
        return next(self.iter)

    def renew(self):
        self.cntr += 1
        self.kill_iterator()
        self.generator.worker_init_fn = self.get_worker_init_fn()
        self.iter = iter(self.generator)

    def restart(self):
        pass
        # self.iter = iter(self.generator)

    def kill_iterator(self):
        try:
            if self.iter is not None:
                self.iter._shutdown_workers()
                for p in self.iter.workers:
                    p.terminate()
        except Exception:  # Avoid a bare except so that KeyboardInterrupt and SystemExit still propagate.
            print("Could not kill Dataloader Iterator")
================================================
FILE: datasets/example_dataset/__init__.py
================================================
================================================
FILE: datasets/example_dataset/create_splits.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from utilities.file_and_folder_operations import subfiles
import os
import random
def create_splits(output_dir, image_dir):
    """Split the dataset into multiple folds, each with a train, validation and test set.
    :param output_dir: Directory to write the splits file to
    :param image_dir: Directory containing the preprocessed images (.npy files)
    """
npy_files = subfiles(image_dir, suffix=".npy", join=False)
sample_size = len(npy_files)
testset_size = int(sample_size * 0.25)
valset_size = int(sample_size * 0.25)
trainset_size = sample_size - valset_size - testset_size # Assure all samples are used.
if sample_size < (trainset_size + valset_size + testset_size):
raise ValueError("Assure more total samples exist than train test and val samples combined!")
splits = []
    sample_set = {sample[:-4] for sample in npy_files}  # Remove the file extension
    # random.sample() needs a sequence (sets are rejected since Python 3.11), so sort the set first.
    test_samples = random.sample(sorted(sample_set), testset_size)  # IMO the test set should be static for all splits
    for split in range(0, 5):
        train_samples = random.sample(sorted(sample_set - set(test_samples)), trainset_size)
val_samples = list(sample_set - set(train_samples) - set(test_samples))
train_samples.sort()
val_samples.sort()
split_dict = dict()
split_dict['train'] = train_samples
split_dict['val'] = val_samples
split_dict['test'] = test_samples
splits.append(split_dict)
# Todo: IMO it is better to write that dict as JSON. This (unlike pickle) allows the user to inspect the file with an editor
with open(os.path.join(output_dir, 'splits.pkl'), 'wb') as f:
pickle.dump(splits, f)
splits_sanity_check(output_dir)
# ToDo: The file name "splits.pkl" should not be spread over multiple files; this makes changing it error-prone.
#  Instead, move saving and loading into one file. (Here would be a good place.)
#  Other usages are: spleen/create_splits.py:57 (which is redundant anyway?);
#  UNetExperiment3D.py:55 and UNetExperiment.py:55
def splits_sanity_check(path):
    """Load splits.pkl from the given directory and verify that the train, validation and test sets of every split are disjoint.
    :param path: Directory containing the splits.pkl file
    """
with open(os.path.join(path, 'splits.pkl'), 'rb') as f:
splits = pickle.load(f)
for i in range(len(splits)):
samples = splits[i]
tr_samples = set(samples["train"])
vl_samples = set(samples["val"])
ts_samples = set(samples["test"])
assert len(tr_samples.intersection(vl_samples)) == 0, "Train and validation samples overlap!"
assert len(vl_samples.intersection(ts_samples)) == 0, "Validation and Test samples overlap!"
assert len(tr_samples.intersection(ts_samples)) == 0, "Train and Test samples overlap!"
return
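# --- Illustrative sketch, not part of the original pipeline ---
# The Todo in create_splits() suggests storing the splits as JSON so the file
# can be inspected in an editor. A minimal version could look like the helpers
# below. The helper names and the file name 'splits.json' are assumptions made
# for this example; the experiments that read 'splits.pkl' would still have to
# be pointed at the new file.
import json
import os
def save_splits_json(splits, output_dir, filename='splits.json'):
    """Write the list of split dicts to a human-readable JSON file (hypothetical helper)."""
    with open(os.path.join(output_dir, filename), 'w') as f:
        json.dump(splits, f, indent=2)
def load_splits_json(output_dir, filename='splits.json'):
    """Read the splits back as a list of dicts with 'train', 'val' and 'test' keys (hypothetical helper)."""
    with open(os.path.join(output_dir, filename), 'r') as f:
        return json.load(f)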
================================================
FILE: datasets/example_dataset/preprocessing.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from batchgenerators.augmentations.utils import pad_nd_image
from medpy.io import load
import os
import numpy as np
import torch
from utilities.file_and_folder_operations import subfiles
def preprocess_data(root_dir, y_shape=64, z_shape=64):
image_dir = os.path.join(root_dir, 'imagesTr')
label_dir = os.path.join(root_dir, 'labelsTr')
output_dir = os.path.join(root_dir, 'preprocessed')
classes = 3
if not os.path.exists(output_dir):
os.makedirs(output_dir)
        print('Created ' + output_dir + '...')
class_stats = defaultdict(int)
total = 0
nii_files = subfiles(image_dir, suffix=".nii.gz", join=False)
for i in range(0, len(nii_files)):
if nii_files[i].startswith("._"):
nii_files[i] = nii_files[i][2:]
for f in nii_files:
image, _ = load(os.path.join(image_dir, f))
label, _ = load(os.path.join(label_dir, f.replace('_0000', '')))
print(f)
for i in range(classes):
class_stats[i] += np.sum(label == i)
total += np.sum(label == i)
# normalize images
image = (image - image.min())/(image.max()-image.min())
image = pad_nd_image(image, (image.shape[0], y_shape, z_shape), "constant", kwargs={'constant_values': image.min()})
label = pad_nd_image(label, (image.shape[0], y_shape, z_shape), "constant", kwargs={'constant_values': label.min()})
result = np.stack((image, label))
np.save(os.path.join(output_dir, f.split('.')[0]+'.npy'), result)
print(f)
print(total)
for i in range(classes):
print(class_stats[i], class_stats[i]/total)
def preprocess_single_file(image_file):
image, image_header = load(image_file)
image = (image - image.min()) / (image.max() - image.min())
data = np.expand_dims(image, 1)
return torch.from_numpy(data), image_header
def postprocess_single_image(image):
# desired shape is [b w h]
result_converted = image[::, 0, ::, ::]
result_mapped = [i * 255 for i in result_converted]
return result_mapped
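# --- Illustrative sketch, not part of the original pipeline ---
# The min-max normalization used above divides by (image.max() - image.min()),
# which is zero for a constant image. The helper below shows one possible
# guard; its name, the eps threshold and the choice to return zeros are
# assumptions made for this example.
import numpy as np
def min_max_normalize(image, eps=1e-8):
    """Scale an array to [0, 1]; return zeros if the array is (almost) constant (hypothetical helper)."""
    image = np.asarray(image, dtype=np.float32)
    value_range = image.max() - image.min()
    if value_range < eps:
        return np.zeros_like(image)
    return (image - image.min()) / value_range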
================================================
FILE: datasets/spleen/__init__.py
================================================
================================================
FILE: datasets/spleen/create_splits.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from utilities.file_and_folder_operations import subfiles
import os
import numpy as np
def create_splits(output_dir, image_dir):
npy_files = subfiles(image_dir, suffix=".npy", join=False)
trainset_size = len(npy_files)*50//100
valset_size = len(npy_files)*25//100
testset_size = len(npy_files)*25//100
splits = []
for split in range(0, 5):
image_list = npy_files.copy()
trainset = []
valset = []
testset = []
for i in range(0, trainset_size):
patient = np.random.choice(image_list)
image_list.remove(patient)
trainset.append(patient[:-4])
for i in range(0, valset_size):
patient = np.random.choice(image_list)
image_list.remove(patient)
valset.append(patient[:-4])
for i in range(0, testset_size):
patient = np.random.choice(image_list)
image_list.remove(patient)
testset.append(patient[:-4])
split_dict = dict()
split_dict['train'] = trainset
split_dict['val'] = valset
split_dict['test'] = testset
splits.append(split_dict)
with open(os.path.join(output_dir, 'splits.pkl'), 'wb') as f:
pickle.dump(splits, f)
================================================
FILE: datasets/spleen/preprocessing.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from medpy.io import load
import os
import numpy as np
from utilities.file_and_folder_operations import subfiles
import torch
def preprocess_data(root_dir, y_shape=64, z_shape=64):
image_dir = os.path.join(root_dir, 'imagesTr')
label_dir = os.path.join(root_dir, 'labelsTr')
output_dir = os.path.join(root_dir, 'preprocessed')
if not os.path.exists(output_dir):
os.makedirs(output_dir)
        print('Created ' + output_dir + '...')
class_stats = defaultdict(int)
total = 0
nii_files = subfiles(image_dir, suffix=".nii.gz", join=False)
for i in range(0, len(nii_files)):
if nii_files[i].startswith("._"):
nii_files[i] = nii_files[i][2:]
for f in nii_files:
image, _ = load(os.path.join(image_dir, f))
label, _ = load(os.path.join(label_dir, f.replace('_0000', '')))
print(f)
# normalize images
image = (image - image.min())/(image.max()-image.min())
image = np.swapaxes(image, 0, 2)
image = np.swapaxes(image, 1, 2)
label = np.swapaxes(label, 0, 2)
label = np.swapaxes(label, 1, 2)
result = np.stack((image, label))
np.save(os.path.join(output_dir, f.split('.')[0]+'.npy'), result)
print(f)
print(total)
def preprocess_single_file(image_file):
image, image_header = load(image_file)
image = (image - image.min()) / (image.max() - image.min())
image = np.swapaxes(image, 0, 2)
image = np.swapaxes(image, 1, 2)
# TODO check original shape and reshape data if necessary
# image = reshape(image, append_value=0, new_shape=(image.shape[0], y_shape, z_shape))
# numpy_array = np.array(image)
# Image shape is [b, w, h] and has one channel only
# Desired shape = [b, c, w, h]
# --> expand to have only one channel c=1 - data is in desired shape
data = np.expand_dims(image, 1)
return torch.from_numpy(data), image_header
def postprocess_single_image(image):
# desired shape is [b w h]
result_converted = image[::, 0, ::, ::]
result_mapped = [i * 255 for i in result_converted]
    # swap axes back to the original orientation (inverse of the swaps in preprocess_single_file)
result_mapped = np.swapaxes(result_mapped, 2, 1)
result_mapped = np.swapaxes(result_mapped, 2, 0)
return result_mapped
================================================
FILE: datasets/three_dim/NumpyDataLoader.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import fnmatch
import random
import numpy as np
from batchgenerators.dataloading import SlimDataLoaderBase
from datasets.data_loader import MultiThreadedDataLoader
from .data_augmentation import get_transforms
def load_dataset(base_dir, pattern='*.npy', keys=None):
fls = []
files_len = []
dataset = []
for root, dirs, files in os.walk(base_dir):
i = 0
for filename in sorted(fnmatch.filter(files, pattern)):
if keys is not None and filename[:-4] in keys:
npy_file = os.path.join(root, filename)
numpy_array = np.load(npy_file, mmap_mode="r")
fls.append(npy_file)
files_len.append(numpy_array.shape[1])
dataset.extend([i])
i += 1
return fls, files_len, dataset
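class NumpyDataSet(object):
    """
    Iterable over augmented batches of whole volumes: wraps a NumpyDataLoader and the transforms
    from get_transforms() in a MultiThreadedDataLoader.
    """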
def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000, seed=None, num_processes=8, num_cached_per_queue=8 * 4, target_size=128,
file_pattern='*.npy', label=1, input=(0,), do_reshuffle=True, keys=None):
data_loader = NumpyDataLoader(base_dir=base_dir, mode=mode, batch_size=batch_size, num_batches=num_batches, seed=seed, file_pattern=file_pattern,
input=input, label=label, keys=keys)
self.data_loader = data_loader
self.batch_size = batch_size
self.do_reshuffle = do_reshuffle
self.number_of_slices = 1
self.transforms = get_transforms(mode=mode, target_size=target_size)
self.augmenter = MultiThreadedDataLoader(data_loader, self.transforms, num_processes=num_processes,
num_cached_per_queue=num_cached_per_queue, seeds=seed,
shuffle=do_reshuffle)
self.augmenter.restart()
def __len__(self):
return len(self.data_loader)
def __iter__(self):
if self.do_reshuffle:
self.data_loader.reshuffle()
self.augmenter.renew()
return self.augmenter
def __next__(self):
return next(self.augmenter)
class NumpyDataLoader(SlimDataLoaderBase):
def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000,
seed=None, file_pattern='*.npy', label=1, input=(0,), keys=None):
self.files, self.file_len, self.dataset = load_dataset(base_dir=base_dir, pattern=file_pattern, keys=keys, )
super(NumpyDataLoader, self).__init__(self.dataset, batch_size, num_batches)
self.batch_size = batch_size
self.use_next = False
if mode == "train":
self.use_next = False
self.idxs = list(range(0, len(self.dataset)))
self.data_len = len(self.dataset)
self.num_batches = min((self.data_len // self.batch_size)+10, num_batches)
if isinstance(label, int):
label = (label,)
self.input = input
self.label = label
self.np_data = np.asarray(self.dataset)
def reshuffle(self):
print("Reshuffle...")
random.shuffle(self.idxs)
print("Initializing... this might take a while...")
def generate_train_batch(self):
open_arr = random.sample(self._data, self.batch_size)
return self.get_data_from_array(open_arr)
def __len__(self):
n_items = min(self.data_len // self.batch_size, self.num_batches)
return n_items
def __getitem__(self, item):
idxs = self.idxs
data_len = len(self.dataset)
np_data = self.np_data
if item > len(self):
raise StopIteration()
if (item * self.batch_size) == data_len:
raise StopIteration()
start_idx = (item * self.batch_size) % data_len
stop_idx = ((item + 1) * self.batch_size) % data_len
if ((item + 1) * self.batch_size) == data_len:
stop_idx = data_len
if stop_idx > start_idx:
idxs = idxs[start_idx:stop_idx]
else:
raise StopIteration()
open_arr = np_data[idxs]
return self.get_data_from_array(open_arr)
def get_data_from_array(self, open_array):
data = []
fnames = []
idxs = []
labels = []
for idx in open_array:
fn_name = self.files[idx]
numpy_array = np.load(fn_name, mmap_mode="r")
            data.append(numpy_array[list(self.input)])   # indexing with a list keeps the channel dimension
            if self.label is not None:
                labels.append(numpy_array[list(self.label)])   # take the label channel(s), not the input
fnames.append(self.files[idx])
idxs.append(idx)
ret_dict = {'data': data, 'fnames': fnames, 'idxs': idxs}
if self.label is not None:
ret_dict['seg'] = labels
return ret_dict
================================================
FILE: datasets/three_dim/__init__.py
================================================
================================================
FILE: datasets/three_dim/data_augmentation.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.transforms import Compose, MirrorTransform
from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
from batchgenerators.transforms.spatial_transforms import ResizeTransform, SpatialTransform
from batchgenerators.transforms.utility_transforms import NumpyToTensor
def get_transforms(mode="train", target_size=128):
transform_list = []
if mode == "train":
transform_list = [CenterCropTransform(crop_size=target_size),
ResizeTransform(target_size=target_size, order=1),
MirrorTransform(axes=(2,)),
SpatialTransform(patch_size=(target_size, target_size, target_size), random_crop=False,
patch_center_dist_from_border=target_size // 2,
do_elastic_deform=True, alpha=(0., 1000.), sigma=(40., 60.),
do_rotation=True,
angle_x=(-0.1, 0.1), angle_y=(0, 1e-8), angle_z=(0, 1e-8),
scale=(0.9, 1.4),
border_mode_data="nearest", border_mode_seg="nearest"),
]
elif mode == "val":
transform_list = [CenterCropTransform(crop_size=target_size),
ResizeTransform(target_size=target_size, order=1),
]
elif mode == "test":
transform_list = [CenterCropTransform(crop_size=target_size),
ResizeTransform(target_size=target_size, order=1),
]
transform_list.append(NumpyToTensor())
return Compose(transform_list)
================================================
FILE: datasets/two_dim/NumpyDataLoader.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import fnmatch
import random
import numpy as np
from batchgenerators.dataloading import SlimDataLoaderBase
from datasets.data_loader import MultiThreadedDataLoader
from .data_augmentation import get_transforms
def load_dataset(base_dir, pattern='*.npy', slice_offset=5, keys=None):
fls = []
files_len = []
slices_ax = []
for root, dirs, files in os.walk(base_dir):
i = 0
for filename in sorted(fnmatch.filter(files, pattern)):
if keys is not None and filename[:-4] in keys:
npy_file = os.path.join(root, filename)
numpy_array = np.load(npy_file, mmap_mode="r")
fls.append(npy_file)
files_len.append(numpy_array.shape[1])
slices_ax.extend([(i, j) for j in range(slice_offset, files_len[-1] - slice_offset)])
i += 1
    return fls, files_len, slices_ax
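class NumpyDataSet(object):
    """
    Iterable over augmented batches of 2D slices: wraps a NumpyDataLoader and the transforms
    from get_transforms() in a MultiThreadedDataLoader.
    """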
def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000, num_processes=8, num_cached_per_queue=8 * 4, target_size=128,
file_pattern='*.npy', label_slice=1, input_slice=(0,), do_reshuffle=True, keys=None):
data_loader = NumpyDataLoader(base_dir=base_dir, mode=mode, batch_size=batch_size, num_batches=num_batches, file_pattern=file_pattern,
input_slice=input_slice, label_slice=label_slice, keys=keys)
self.data_loader = data_loader
self.batch_size = batch_size
self.do_reshuffle = do_reshuffle
self.number_of_slices = 1
self.transforms = get_transforms(mode=mode, target_size=target_size)
self.augmenter = MultiThreadedDataLoader(data_loader, self.transforms, num_processes=num_processes,
num_cached_per_queue=num_cached_per_queue,
shuffle=do_reshuffle)
self.augmenter.restart()
def __len__(self):
return len(self.data_loader)
def __iter__(self):
if self.do_reshuffle:
self.data_loader.reshuffle()
self.augmenter.renew()
return self.augmenter
def __next__(self):
return next(self.augmenter)
class NumpyDataLoader(SlimDataLoaderBase):
def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000,
file_pattern='*.npy', label_slice=1, input_slice=(0,), keys=None):
self.files, self.file_len, self.slices = load_dataset(base_dir=base_dir, pattern=file_pattern, slice_offset=0, keys=keys, )
super(NumpyDataLoader, self).__init__(self.slices, batch_size, num_batches)
self.batch_size = batch_size
self.use_next = False
if mode == "train":
self.use_next = False
self.slice_idxs = list(range(0, len(self.slices)))
self.data_len = len(self.slices)
self.num_batches = min((self.data_len // self.batch_size)+10, num_batches)
if isinstance(label_slice, int):
label_slice = (label_slice,)
self.input_slice = input_slice
self.label_slice = label_slice
self.np_data = np.asarray(self.slices)
def reshuffle(self):
print("Reshuffle...")
random.shuffle(self.slice_idxs)
print("Initializing... this might take a while...")
def generate_train_batch(self):
open_arr = random.sample(self._data, self.batch_size)
return self.get_data_from_array(open_arr)
def __len__(self):
n_items = min(self.data_len // self.batch_size, self.num_batches)
return n_items
def __getitem__(self, item):
slice_idxs = self.slice_idxs
data_len = len(self.slices)
np_data = self.np_data
if item > len(self):
raise StopIteration()
if (item * self.batch_size) == data_len:
raise StopIteration()
start_idx = (item * self.batch_size) % data_len
stop_idx = ((item + 1) * self.batch_size) % data_len
if ((item + 1) * self.batch_size) == data_len:
stop_idx = data_len
if stop_idx > start_idx:
idxs = slice_idxs[start_idx:stop_idx]
else:
raise StopIteration()
open_arr = np_data[idxs]
return self.get_data_from_array(open_arr)
def get_data_from_array(self, open_array):
data = []
fnames = []
slice_idxs = []
labels = []
for slice in open_array:
fn_name = self.files[slice[0]]
numpy_array = np.load(fn_name, mmap_mode="r")
numpy_slice = numpy_array[:, slice[1], ]
            data.append(numpy_slice[list(self.input_slice)])   # indexing with a list keeps the channel dimension
            if self.label_slice is not None:
                labels.append(numpy_slice[list(self.label_slice)])   # indexing with a list keeps the channel dimension
fnames.append(self.files[slice[0]])
slice_idxs.append(slice[1])
ret_dict = {'data': np.asarray(data), 'fnames': fnames, 'slice_idxs': slice_idxs}
if self.label_slice is not None:
ret_dict['seg'] = np.asarray(labels)
return ret_dict
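# --- Illustrative sketch, not part of the original pipeline ---
# The index arithmetic in NumpyDataLoader.__getitem__ can be hard to follow.
# The stand-alone function below mirrors it so the behaviour can be checked in
# isolation: it returns the (start, stop) range of batch number `item`, or None
# where __getitem__ would raise StopIteration. The function name is an
# assumption made for this example, and the `item > len(self)` check is left out.
def batch_bounds(item, batch_size, data_len):
    """Return (start, stop) indices for batch `item`, or None if no full batch is left (hypothetical helper)."""
    if item * batch_size == data_len:
        return None
    start_idx = (item * batch_size) % data_len
    stop_idx = ((item + 1) * batch_size) % data_len
    if (item + 1) * batch_size == data_len:
        stop_idx = data_len
    if stop_idx > start_idx:
        return start_idx, stop_idx
    # The modulo wrapped around, so only a partial batch remains and it is dropped.
    return None
# With 10 slices and batch_size=4, batches 0 and 1 cover slices 0-7 and the
# trailing two slices are not served in that epoch.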
================================================
FILE: datasets/two_dim/__init__.py
================================================
================================================
FILE: datasets/two_dim/data_augmentation.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.transforms import Compose, MirrorTransform
from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform, RandomCropTransform
from batchgenerators.transforms.spatial_transforms import ResizeTransform, SpatialTransform
from batchgenerators.transforms.utility_transforms import NumpyToTensor
import numpy as np
def get_transforms(mode="train", target_size=128):
    transform_list = []
    if mode == "train":
        transform_list = [  # CenterCropTransform(crop_size=target_size),
            ResizeTransform(target_size=(target_size, target_size), order=1),
            MirrorTransform(axes=(1,)),
            SpatialTransform(patch_size=(target_size, target_size), random_crop=False,
                             patch_center_dist_from_border=target_size // 2,
                             do_elastic_deform=True, alpha=(0., 900.), sigma=(20., 30.),
                             do_rotation=True, p_rot_per_sample=0.8,
                             angle_x=(-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi), angle_y=(0, 1e-8), angle_z=(0, 1e-8),
                             scale=(0.85, 1.25), p_scale_per_sample=0.8,
                             border_mode_data="nearest", border_mode_seg="nearest"),
        ]
    elif mode == "val":
        transform_list = [  # CenterCropTransform(crop_size=target_size),
            ResizeTransform(target_size=target_size, order=1),
        ]
    elif mode == "test":
        transform_list = [  # CenterCropTransform(crop_size=target_size),
            ResizeTransform(target_size=target_size, order=1),
        ]
    transform_list.append(NumpyToTensor())
    return Compose(transform_list)
================================================
FILE: datasets/utils.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from os.path import exists
import tarfile
from google_drive_downloader import GoogleDriveDownloader as gdd
def download_dataset(dest_path, dataset, id=''):
if not exists(os.path.join(dest_path, dataset)):
tar_path = os.path.join(dest_path, dataset) + '.tar'
gdd.download_file_from_google_drive(file_id=id,
dest_path=tar_path, overwrite=False,
unzip=False)
print('Extracting data [STARTED]')
        with tarfile.open(tar_path) as tar:  # close the archive once extraction is done
            tar.extractall(dest_path)
print('Extracting data [DONE]')
else:
print('Data already downloaded. Files are not extracted again.')
return
================================================
FILE: evaluation/__init__.py
================================================
================================================
FILE: evaluation/evaluator.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import collections
import inspect
import json
import hashlib
from datetime import datetime
import numpy as np
import pandas as pd
import SimpleITK as sitk
from evaluation.metrics import ConfusionMatrix, ALL_METRICS
class Evaluator:
"""Object that holds test and reference segmentations with label information
and computes a number of metrics on the two. 'labels' must either be an
iterable of numeric values (or tuples thereof) or a dictionary with string
names and numeric values.
"""
default_metrics = [
"False Positive Rate",
"Dice",
"Jaccard",
"Precision",
"Recall",
"Accuracy",
"False Omission Rate",
"Negative Predictive Value",
"False Negative Rate",
"True Negative Rate",
"False Discovery Rate",
"Total Positives Test",
"Total Positives Reference"
]
default_advanced_metrics = [
"Hausdorff Distance",
"Hausdorff Distance 95",
"Avg. Surface Distance",
"Avg. Symmetric Surface Distance"
]
def __init__(self,
test=None,
reference=None,
labels=None,
metrics=None,
advanced_metrics=None,
nan_for_nonexisting=True):
self.test = None
self.reference = None
self.confusion_matrix = ConfusionMatrix()
self.labels = None
self.nan_for_nonexisting = nan_for_nonexisting
self.result = None
self.metrics = []
if metrics is None:
for m in self.default_metrics:
self.metrics.append(m)
else:
for m in metrics:
self.metrics.append(m)
self.advanced_metrics = []
if advanced_metrics is None:
for m in self.default_advanced_metrics:
self.advanced_metrics.append(m)
else:
for m in advanced_metrics:
self.advanced_metrics.append(m)
self.set_reference(reference)
self.set_test(test)
if labels is not None:
self.set_labels(labels)
else:
if test is not None and reference is not None:
self.construct_labels()
def set_test(self, test):
"""Set the test segmentation."""
self.test = test
def set_reference(self, reference):
"""Set the reference segmentation."""
self.reference = reference
    def set_labels(self, labels):
        """Set the labels.
        :param labels: may be a dictionary (int->str), a set (of ints), a tuple (of ints) or a list (of ints). Labels
        will only have names if you pass a dictionary"""
if not isinstance(labels, (dict, set, list, tuple)):
raise ValueError("Labels must be either list, tuple, set or dict")
elif isinstance(labels, dict):
self.labels = collections.OrderedDict(labels)
elif isinstance(labels, set):
self.labels = list(labels)
elif isinstance(labels, (list, tuple)):
self.labels = labels
else:
raise TypeError("Can only handle dict, list, tuple, set & numpy array, but input is of type {}".format(type(labels)))
def construct_labels(self):
"""Construct label set from unique entries in segmentations."""
if self.test is None and self.reference is None:
raise ValueError("No test or reference segmentations.")
elif self.test is None:
labels = np.unique(self.reference)
else:
labels = np.union1d(np.unique(self.test),
np.unique(self.reference))
self.labels = list(map(lambda x: int(x), labels))
def set_metrics(self, metrics):
"""Set evaluation metrics"""
if isinstance(metrics, set):
self.metrics = list(metrics)
elif isinstance(metrics, (list, tuple, np.ndarray)):
self.metrics = metrics
else:
raise TypeError("Can only handle list, tuple, set & numpy array, but input is of type {}".format(type(metrics)))
def add_metric(self, metric):
if metric not in self.metrics:
self.metrics.append(metric)
def evaluate(self, test=None, reference=None, advanced=False, **metric_kwargs):
"""Compute metrics for segmentations."""
if test is not None:
self.set_test(test)
if reference is not None:
self.set_reference(reference)
if self.test is None or self.reference is None:
raise ValueError("Need both test and reference segmentations.")
if self.labels is None:
self.construct_labels()
self.metrics.sort()
# get functions for evaluation
        # somewhat convoluted, but allows users to define additional metrics
# on the fly, e.g. inside an IPython console
_funcs = {m: ALL_METRICS[m] for m in self.metrics + self.advanced_metrics}
frames = inspect.getouterframes(inspect.currentframe())
for metric in self.metrics:
for f in frames:
if metric in f[0].f_locals:
_funcs[metric] = f[0].f_locals[metric]
break
else:
if metric in _funcs:
continue
else:
raise NotImplementedError(
"Metric {} not implemented.".format(metric))
# get results
self.result = {}
        eval_metrics = list(self.metrics)  # copy, so that self.metrics is not extended in place below
if advanced:
eval_metrics += self.advanced_metrics
if isinstance(self.labels, dict):
for label, name in self.labels.items():
self.result[name] = {}
if not hasattr(label, "__iter__"):
self.confusion_matrix.set_test(self.test == label)
self.confusion_matrix.set_reference(self.reference == label)
else:
current_test = 0
current_reference = 0
for l in label:
current_test += (self.test == l)
current_reference += (self.reference == l)
self.confusion_matrix.set_test(current_test)
self.confusion_matrix.set_reference(current_reference)
for metric in eval_metrics:
self.result[name][metric] = _funcs[metric](confusion_matrix=self.confusion_matrix,
nan_for_nonexisting=self.nan_for_nonexisting,
**metric_kwargs)
else:
for i, l in enumerate(self.labels):
self.result[l] = {}
self.confusion_matrix.set_test(self.test == l)
self.confusion_matrix.set_reference(self.reference == l)
for metric in eval_metrics:
self.result[l][metric] = _funcs[metric](confusion_matrix=self.confusion_matrix,
nan_for_nonexisting=self.nan_for_nonexisting,
**metric_kwargs)
return self.result
def to_dict(self):
if self.result is None:
self.evaluate()
return self.result
def to_array(self):
"""Return result as numpy array (labels x metrics)."""
if self.result is None:
self.evaluate()
result_metrics = sorted(self.result[list(self.result.keys())[0]].keys())
a = np.zeros((len(self.labels), len(result_metrics)), dtype=np.float32)
if isinstance(self.labels, dict):
for i, label in enumerate(self.labels.keys()):
for j, metric in enumerate(result_metrics):
a[i][j] = self.result[self.labels[label]][metric]
else:
for i, label in enumerate(self.labels):
for j, metric in enumerate(result_metrics):
a[i][j] = self.result[label][metric]
return a
def to_pandas(self):
"""Return result as pandas DataFrame."""
a = self.to_array()
if isinstance(self.labels, dict):
labels = list(self.labels.values())
else:
labels = self.labels
result_metrics = sorted(self.result[list(self.result.keys())[0]].keys())
return pd.DataFrame(a, index=labels, columns=result_metrics)
class NiftiEvaluator(Evaluator):
def __init__(self, *args, **kwargs):
self.test_nifti = None
self.reference_nifti = None
super(NiftiEvaluator, self).__init__(*args, **kwargs)
def set_test(self, test):
"""Set the test segmentation."""
if test is not None:
self.test_nifti = sitk.ReadImage(test)
super(NiftiEvaluator, self).set_test(sitk.GetArrayFromImage(self.test_nifti))
else:
self.test_nifti = None
super(NiftiEvaluator, self).set_test(test)
def set_reference(self, reference):
"""Set the reference segmentation."""
if reference is not None:
self.reference_nifti = sitk.ReadImage(reference)
super(NiftiEvaluator, self).set_reference(sitk.GetArrayFromImage(self.reference_nifti))
else:
self.reference_nifti = None
super(NiftiEvaluator, self).set_reference(reference)
def evaluate(self, test=None, reference=None, voxel_spacing=None, **metric_kwargs):
if voxel_spacing is None:
voxel_spacing = np.array(self.test_nifti.GetSpacing())[::-1]
metric_kwargs["voxel_spacing"] = voxel_spacing
return super(NiftiEvaluator, self).evaluate(test, reference, **metric_kwargs)
def aggregate_scores(test_ref_pairs,
evaluator=NiftiEvaluator,
labels=None,
nanmean=True,
json_output_file=None,
json_name="",
json_description="",
json_author="Fabian",
json_task="",
**metric_kwargs):
"""
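Evaluate several test/reference pairs and aggregate the scores (test = predicted image).
:param test_ref_pairs: iterable of (test, reference) pairs; each item is passed to the evaluator's set_test() / set_reference()
:param evaluator: evaluator instance or type (a type is instantiated automatically)
:param labels: must be a dict of int-> str or a list of int
:param nanmean: if True, use np.nanmean instead of np.mean for the mean scores
:param json_output_file: optional path or open file handle; if given, the results are written to it as JSON
:param json_name: experiment name stored in the JSON output
:param json_description: longer description stored in the JSON output
:param json_author: author stored in the JSON output
:param json_task: task name stored in the JSON output
:param metric_kwargs: additional keyword arguments passed to evaluator.evaluate()
:return: dict with "all" (list of per-pair results) and "mean" (per-label mean of each metric)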
"""
if type(evaluator) == type:
evaluator = evaluator()
if labels is not None:
evaluator.set_labels(labels)
all_scores = {}
all_scores["all"] = []
all_scores["mean"] = {}
for i, (test, ref) in enumerate(test_ref_pairs):
# evaluate
evaluator.set_test(test)
evaluator.set_reference(ref)
if evaluator.labels is None:
evaluator.construct_labels()
current_scores = evaluator.evaluate(**metric_kwargs)
if type(test) == str:
current_scores["test"] = test
if type(ref) == str:
current_scores["reference"] = ref
all_scores["all"].append(current_scores)
# append score list for mean
for label, score_dict in current_scores.items():
if label in ("test", "reference"):
continue
if label not in all_scores["mean"]:
all_scores["mean"][label] = {}
for score, value in score_dict.items():
if score not in all_scores["mean"][label]:
all_scores["mean"][label][score] = []
all_scores["mean"][label][score].append(value)
for label in all_scores["mean"]:
for score in all_scores["mean"][label]:
if nanmean:
all_scores["mean"][label][score] = float(np.nanmean(all_scores["mean"][label][score]))
else:
all_scores["mean"][label][score] = float(np.mean(all_scores["mean"][label][score]))
# save to file if desired
# we create a hopefully unique id by hashing the entire output dictionary
if json_output_file is not None:
if type(json_output_file) == str:
json_output_file = open(json_output_file, "w")
json_dict = {}
json_dict["name"] = json_name
json_dict["description"] = json_description
timestamp = datetime.today()
json_dict["timestamp"] = str(timestamp)
json_dict["task"] = json_task
json_dict["author"] = json_author
json_dict["results"] = all_scores
json_dict["id"] = hashlib.md5(json.dumps(json_dict).encode("utf-8")).hexdigest()[:12]
json.dump(json_dict, json_output_file, indent=4, separators=(",", ": "))
json_output_file.close()
return all_scores
def aggregate_scores_for_experiment(score_file,
labels=None,
metrics=Evaluator.default_metrics,
nanmean=True,
json_output_file=None,
json_name="",
json_description="",
json_author="Fabian",
json_task=""):
scores = np.load(score_file)
scores_mean = scores.mean(0)
if labels is None:
labels = list(map(str, range(scores.shape[1])))
results = []
results_mean = {}
for i in range(scores.shape[0]):
results.append({})
for l, label in enumerate(labels):
results[-1][label] = {}
results_mean[label] = {}
for m, metric in enumerate(metrics):
results[-1][label][metric] = float(scores[i][l][m])
results_mean[label][metric] = float(scores_mean[l][m])
json_dict = {}
json_dict["name"] = json_name
json_dict["description"] = json_description
timestamp = datetime.today()
json_dict["timestamp"] = str(timestamp)
json_dict["task"] = json_task
json_dict["author"] = json_author
json_dict["results"] = {"all": results, "mean": results_mean}
json_dict["id"] = hashlib.md5(json.dumps(json_dict).encode("utf-8")).hexdigest()[:12]
if json_output_file is not None:
json_output_file = open(json_output_file, "w")
json.dump(json_dict, json_output_file, indent=4, separators=(",", ": "))
json_output_file.close()
return json_dict
================================================
FILE: evaluation/metrics.py
================================================
import numpy as np
from medpy import metric
def assert_shape(test, reference):
assert test.shape == reference.shape, "Shape mismatch: {} and {}".format(
test.shape, reference.shape)
class ConfusionMatrix:
def __init__(self, test=None, reference=None):
self.tp = None
self.fp = None
self.tn = None
self.fn = None
self.size = None
self.reference_empty = None
self.reference_full = None
self.test_empty = None
self.test_full = None
self.set_reference(reference)
self.set_test(test)
def set_test(self, test):
self.test = test
self.reset()
def set_reference(self, reference):
self.reference = reference
self.reset()
def reset(self):
self.tp = None
self.fp = None
self.tn = None
self.fn = None
self.size = None
self.test_empty = None
self.test_full = None
self.reference_empty = None
self.reference_full = None
def compute(self):
if self.test is None or self.reference is None:
raise ValueError("'test' and 'reference' must both be set to compute confusion matrix.")
assert_shape(self.test, self.reference)
self.tp = int(((self.test != 0) * (self.reference != 0)).sum())
self.fp = int(((self.test != 0) * (self.reference == 0)).sum())
self.tn = int(((self.test == 0) * (self.reference == 0)).sum())
self.fn = int(((self.test == 0) * (self.reference != 0)).sum())
self.size = int(np.prod(self.reference.shape))  # np.product was removed in NumPy 2.0
self.test_empty = not np.any(self.test)
self.test_full = np.all(self.test)
self.reference_empty = not np.any(self.reference)
self.reference_full = np.all(self.reference)
def get_matrix(self):
for entry in (self.tp, self.fp, self.tn, self.fn):
if entry is None:
self.compute()
break
return self.tp, self.fp, self.tn, self.fn
def get_size(self):
if self.size is None:
self.compute()
return self.size
def get_existence(self):
for case in (self.test_empty, self.test_full, self.reference_empty, self.reference_full):
if case is None:
self.compute()
break
return self.test_empty, self.test_full, self.reference_empty, self.reference_full
def dice(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
"""2TP / (2TP + FP + FN)"""
if confusion_matrix is None:
confusion_matrix = ConfusionMatrix(test, reference)
tp, fp, tn, fn = confusion_matrix.get_matrix()
test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
if test_empty and reference_empty:
if nan_for_nonexisting:
return float("NaN")
else:
return 0.
return float(2. * tp / (2 * tp + fp + fn))
def jaccard(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
"""TP / (TP + FP + FN)"""
if confusion_matrix is None:
confusion_matrix = ConfusionMatrix(test, reference)
tp, fp, tn, fn = confusion_matrix.get_matrix()
test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
if test_empty and reference_empty:
if nan_for_nonexisting:
return float("NaN")
else:
return 0.
return float(tp / (tp + fp + fn))
def precision(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
"""TP / (TP + FP)"""
if confusion_matrix is None:
confusion_matrix = ConfusionMatrix(test, reference)
tp, fp, tn, fn = confusion_matrix.get_matrix()
test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
if test_empty:
if nan_for_nonexisting:
return float("NaN")
else:
return 0.
return float(tp / (tp + fp))
def sensitivity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
"""TP / (TP + FN)"""
if confusion_matrix is None:
confusion_matrix = ConfusionMatrix(test, reference)
tp, fp, tn, fn = confusion_matrix.get_matrix()
test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
if reference_empty:
if nan_for_nonexisting:
return float("NaN")
else:
return 0.
return float(tp / (tp + fn))
def recall(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
"""TP / (TP + FN)"""
return sensitivity(test, reference, confusion_matrix, nan_for_nonexisting, **kwargs)
def specificity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
"""TN / (TN + FP)"""
if confusion_matrix is None:
confusion_matrix = ConfusionMatrix(test, reference)
tp, fp, tn, fn = confusion_matrix.get_matrix()
test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
if reference_full:
if nan_for_nonexisting:
return float("NaN")
else:
return 0.
return float(tn / (tn + fp))
def accuracy(test=None, reference=None, confusion_matrix=None, **kwargs):
"""(TP + TN) / (TP + FP + FN + TN)"""
if confusion_matrix is None:
confusion_matrix = ConfusionMatrix(test, reference)
tp, fp, tn, fn = confusion_matrix.get_matrix()
return float((tp + tn) / (tp + fp + tn + fn))
def fscore(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, beta=1., **kwargs):
"""(1 + b^2) * TP / ((1 + b^2) * TP + b^2 * FN + FP)"""
precision_ = precision(test, reference, confusion_matrix, nan_for_nonexisting)
recall_ = recall(test, reference, confusion_matrix, nan_for_nonexisting)
return (1 + beta*beta) * precision_ * recall_ /\
((beta*beta * precision_) + recall_)
def false_positive_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
"""FP / (FP + TN)"""
return 1 - specificity(test, reference, confusion_matrix, nan_for_nonexisting)
def false_omission_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
"""FN / (TN + FN)"""
if confusion_matrix is None:
confusion_matrix = ConfusionMatrix(test, reference)
tp, fp, tn, fn = confusion_matrix.get_matrix()
test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
if test_full:
if nan_for_nonexisting:
return float("NaN")
else:
return 0.
return float(fn / (fn + tn))
def false_negative_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
"""FN / (TP + FN)"""
return 1 - sensitivity(test, reference, confusion_matrix, nan_for_nonexisting)
def true_negative_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
"""TN / (TN + FP)"""
return specificity(test, reference, confusion_matrix, nan_for_nonexisting)
def false_discovery_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
"""FP / (TP + FP)"""
return 1 - precision(test, reference, confusion_matrix, nan_for_nonexisting)
def negative_predictive_value(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
"""TN / (TN + FN)"""
return 1 - false_omission_rate(test, reference, confusion_matrix, nan_for_nonexisting)
def total_positives_test(test=None, reference=None, confusion_matrix=None, **kwargs):
"""TP + FP"""
if confusion_matrix is None:
confusion_matrix = ConfusionMatrix(test, reference)
tp, fp, tn, fn = confusion_matrix.get_matrix()
return tp + fp
def total_negatives_test(test=None, reference=None, confusion_matrix=None, **kwargs):
"""TN + FN"""
if confusion_matrix is None:
confusion_matrix = ConfusionMatrix(test, reference)
tp, fp, tn, fn = confusion_matrix.get_matrix()
return tn + fn
def total_positives_reference(test=None, reference=None, confusion_matrix=None, **kwargs):
"""TP + FN"""
if confusion_matrix is None:
confusion_matrix = ConfusionMatrix(test, reference)
tp, fp, tn, fn = confusion_matrix.get_matrix()
return tp + fn
def total_negatives_reference(test=None, reference=None, confusion_matrix=None, **kwargs):
"""TN + FP"""
if confusion_matrix is None:
confusion_matrix = ConfusionMatrix(test, reference)
tp, fp, tn, fn = confusion_matrix.get_matrix()
return tn + fp
def hausdorff_distance(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):
if confusion_matrix is None:
confusion_matrix = ConfusionMatrix(test, reference)
test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
if test_empty or test_full or reference_empty or reference_full:
if nan_for_nonexisting:
return float("NaN")
else:
return 0
test, reference = confusion_matrix.test, confusion_matrix.reference
return metric.hd(test, reference, voxel_spacing, connectivity)
def hausdorff_distance_95(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):
if confusion_matrix is None:
confusion_matrix = ConfusionMatrix(test, reference)
test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
if test_empty or test_full or reference_empty or reference_full:
if nan_for_nonexisting:
return float("NaN")
else:
return 0
test, reference = confusion_matrix.test, confusion_matrix.reference
return metric.hd95(test, reference, voxel_spacing, connectivity)
def avg_surface_distance(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):
if confusion_matrix is None:
confusion_matrix = ConfusionMatrix(test, reference)
test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
if test_empty or test_full or reference_empty or reference_full:
if nan_for_nonexisting:
return float("NaN")
else:
return 0
test, reference = confusion_matrix.test, confusion_matrix.reference
return metric.asd(test, reference, voxel_spacing, connectivity)
def avg_surface_distance_symmetric(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):
if confusion_matrix is None:
confusion_matrix = ConfusionMatrix(test, reference)
test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
if test_empty or test_full or reference_empty or reference_full:
if nan_for_nonexisting:
return float("NaN")
else:
return 0
test, reference = confusion_matrix.test, confusion_matrix.reference
return metric.assd(test, reference, voxel_spacing, connectivity)
ALL_METRICS = {
"False Positive Rate": false_positive_rate,
"Dice": dice,
"Jaccard": jaccard,
"Hausdorff Distance": hausdorff_distance,
"Hausdorff Distance 95": hausdorff_distance_95,
"Precision": precision,
"Recall": recall,
"Avg. Symmetric Surface Distance": avg_surface_distance_symmetric,
"Avg. Surface Distance": avg_surface_distance,
"Accuracy": accuracy,
"False Omission Rate": false_omission_rate,
"Negative Predictive Value": negative_predictive_value,
"False Negative Rate": false_negative_rate,
"True Negative Rate": true_negative_rate,
"False Discovery Rate": false_discovery_rate,
"Total Positives Test": total_positives_test,
"Total Negatives Test": total_negatives_test,
"Total Positives Reference": total_positives_reference,
"total Negatives Reference": total_negatives_reference
}
================================================
FILE: evaluation/readme.md
================================================
# Evaluation Suite
### Metrics
All metrics can be used either by passing test and reference segmentations as
parameters or by passing a `ConfusionMatrix` object. The latter is useful when many
metrics need to be computed, because the relevant computations are only done once.
All metrics assume binary segmentation inputs.
`ConfusionMatrix` has two important methods: `.get_matrix()`, which returns 4 ints for true positives, false positives, true negatives and false negatives, and
`.get_existence()`, which returns 4 bools, indicating whether test and reference
segmentations are all ones or all zeros. The latter is used when you specify
`nan_for_nonexisting=True` in metric calls to return NaN instead of 0 when the result
is undefined, i.e. would require dividing by 0.
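
As a rough illustration of why `nan_for_nonexisting` matters, the following dependency-free sketch mirrors the counting logic on flat Python lists. The helper names `confusion_counts` and `dice_score` are hypothetical and not part of this package, which works on NumPy arrays through `ConfusionMatrix`.

```python
def confusion_counts(test, reference):
    """Count TP, FP, TN, FN for two equally long binary sequences."""
    if len(test) != len(reference):
        raise ValueError("Shape mismatch: {} and {}".format(len(test), len(reference)))
    pairs = list(zip(test, reference))
    tp = sum(1 for t, r in pairs if t != 0 and r != 0)
    fp = sum(1 for t, r in pairs if t != 0 and r == 0)
    tn = sum(1 for t, r in pairs if t == 0 and r == 0)
    fn = sum(1 for t, r in pairs if t == 0 and r != 0)
    return tp, fp, tn, fn


def dice_score(test, reference, nan_for_nonexisting=True):
    """2TP / (2TP + FP + FN), with the empty/empty case handled explicitly."""
    tp, fp, tn, fn = confusion_counts(test, reference)
    if tp + fp == 0 and tp + fn == 0:
        # Both segmentations are empty, so the ratio would be 0 / 0.
        return float("nan") if nan_for_nonexisting else 0.0
    return 2.0 * tp / (2 * tp + fp + fn)


test = [0, 1, 1, 0, 1, 0]
reference = [0, 1, 0, 0, 1, 1]

print(confusion_counts(test, reference))  # (2, 1, 2, 1)
print(dice_score(test, reference))        # roughly 0.667
print(dice_score([0, 0], [0, 0]))         # nan
print(dice_score([0, 0], [0, 0], nan_for_nonexisting=False))  # 0.0
```

Returning NaN for the undefined case lets an aggregation step skip such cases instead of averaging in an arbitrary 0.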
### Evaluator
The `Evaluator` is a class that holds one test and one reference segmentation at a time. Each segmentation can contain multiple labels (one-hot encoding is not supported). It also holds a `labels` attribute that can either be a list of ints (or tuples of ints) or a dictionary
that maps from ints (or tuples of ints) to label names. A typical labels dictionary
could look like this:
```python
labels = {
1: "Edema",
2: "Enhancing Tumor",
3: "Necrosis",
(1, 2, 3): "Whole Tumor"
}
```
Labels in a tuple will be joined. If no labels are set, they will be constructed automatically from the unique entries in the segmentations upon evaluation.

The Evaluator has a regular set of metrics that is always computed and a set of advanced metrics that is only computed when you call `.evaluate(advanced=True)`. The `.evaluate()` method looks for metric definitions in the calling frames, so when you work in an interactive shell and redefine a metric there (e.g. for testing purposes), the newly defined metric will be used.

You can also pass test and reference segmentations directly to `.evaluate()` to save calls to `.set_test()` and `.set_reference()`. `.evaluate()` returns a result dictionary and also stores it in the `.result` attribute, so you can call `.to_array()` (numpy) or `.to_pandas()` (pandas) later. The resulting shape is (labels x metrics). `.evaluate()` also takes additional `**metric_kwargs` that are passed to each metric call.
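
The joining of tuple labels can be sketched as follows. This is a minimal illustration on a flattened toy segmentation stored in a plain list; `label_mask` is a hypothetical helper, while the package itself builds the masks from NumPy arrays inside `.evaluate()`.

```python
def label_mask(segmentation, label):
    """Return a binary mask for one label or for a tuple of joined labels."""
    if hasattr(label, "__iter__"):
        wanted = set(label)   # tuple of labels: any of them counts as foreground
    else:
        wanted = {label}      # single label
    return [1 if value in wanted else 0 for value in segmentation]


labels = {
    1: "Edema",
    2: "Enhancing Tumor",
    3: "Necrosis",
    (1, 2, 3): "Whole Tumor",
}

# A flattened toy segmentation; 0 is background.
segmentation = [0, 1, 1, 2, 0, 3, 3, 0]

masks = {name: label_mask(segmentation, label) for label, name in labels.items()}
for name, mask in masks.items():
    print(name, mask)
```

Each binary mask is what a metric then receives, which is why all metrics can assume binary inputs even for multi-label segmentations.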
### NiftiEvaluator
`NiftiEvaluator` redefines the `.set_test()` and `.set_reference()` methods of the `Evaluator` to take path strings instead of arrays. It reads the NIfTI files using SimpleITK, stores the SimpleITK images in the `.test_nifti` and `.reference_nifti` attributes and sets the arrays as test and reference segmentations. `.evaluate()` has an additional parameter `voxel_spacing`, which should be an iterable of floats. If the parameter is None, the spacing is read automatically from the test image. If you read the spacing from SimpleITK images manually, note that you have to reverse the ordering, because SimpleITK's `GetSpacing()` returns (x, y, z) ordering while the arrays returned by `GetArrayFromImage()` are indexed (z, y, x).
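
A small sketch of the axis-order bookkeeping, using made-up numbers rather than a real image: the spacing tuple stands in for what SimpleITK's `GetSpacing()` reports and the shape stands in for an array obtained with `GetArrayFromImage()`.

```python
# Hypothetical spacing as GetSpacing() would report it: (x, y, z) in mm.
spacing_xyz = (0.75, 0.75, 5.0)

# Hypothetical array shape as GetArrayFromImage() would produce it: (z, y, x).
array_shape_zyx = (40, 512, 512)

# Reverse the spacing so that entry i describes array axis i.
voxel_spacing = spacing_xyz[::-1]

# Physical extent of the volume along each array axis, in mm.
extent_mm = tuple(n * s for n, s in zip(array_shape_zyx, voxel_spacing))

print(voxel_spacing)  # (5.0, 0.75, 0.75)
print(extent_mm)      # (200.0, 384.0, 384.0)
```

Passing the spacing in the wrong order would not raise an error for the distance-based metrics. It would only give wrong distances when the voxels are anisotropic.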
### Evaluating multiple segmentations
If you want to evaluate multiple test/reference pairs and get aggregate statistics, use the `aggregate_scores()` function. It expects an iterable of test/reference pairs and an evaluator (instance or type; a type is instantiated automatically), which is the `NiftiEvaluator` by default. Test and reference are set via `.set_test()` and `.set_reference()`, so make sure you pass the right type for the evaluator. The function returns a dictionary that contains a list of all separate results as well as their mean:
```python
results = {
"all": [
{
"Label": {
"Metric": float,
"Metric": float,
...
},
"Label": {
...
},
...
},
{
"Label": ...,
"Label": ...,
...
},
...
],
"mean": {
"Label": ...,
"Label": ...,
...
}
}
```
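
How the `"mean"` entry relates to the `"all"` list can be sketched in plain Python. `aggregate_mean` is a hypothetical helper written for this illustration, and the scores are invented. The package itself relies on `np.nanmean` / `np.mean`.

```python
import math


def aggregate_mean(all_scores, nanmean=True):
    """Average each metric per label over a list of per-case result dicts."""
    collected = {}
    for case in all_scores:
        for label, score_dict in case.items():
            for metric, value in score_dict.items():
                collected.setdefault(label, {}).setdefault(metric, []).append(value)

    mean = {}
    for label, metrics in collected.items():
        mean[label] = {}
        for metric, values in metrics.items():
            if nanmean:
                # Ignore undefined (NaN) scores instead of propagating them.
                values = [v for v in values if not math.isnan(v)]
            mean[label][metric] = sum(values) / len(values) if values else float("nan")
    return mean


# Invented scores for three cases; the second case has an undefined Dice.
all_scores = [
    {"Edema": {"Dice": 0.8, "Precision": 0.9}},
    {"Edema": {"Dice": float("nan"), "Precision": 0.7}},
    {"Edema": {"Dice": 0.6, "Precision": 0.8}},
]

mean_scores = aggregate_mean(all_scores)
print(mean_scores)
print(aggregate_mean(all_scores, nanmean=False))
```

With NaN-aware averaging the Dice mean is taken over the two defined cases only. Without it, a single NaN makes the whole mean NaN.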
`nanmean=True` will use `np.nanmean` instead of `np.mean`. It should be easy to adjust the code to compute arbitrary statistics, but at the moment only mean is supported. If you specify a `json_output_file`, a json file will be written that contains the result dictionary as well as additional information you can specify using the other `json_*` parameters:
```python
json = {
"name": json_name, # experiment name, not yours
"description": json_description, # a longer description so you know what you did
"timestamp": "YYYY-MM-DD hh:mm:ss.ffffff", # automatically generated
"task": json_task, # the decathlon task
"author": json_author, # probably Fabian :)
"results": ..., # the above dictionary
"id": "001122334455" # first 12 hex characters of an MD5 hash of the other entries
}
```
`labels` is passed to the evaluator and `**metric_kwargs` is passed to all `.evaluate()` calls.
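
The `id` field can be reproduced with the standard library alone. The following is a minimal sketch of that step with invented field values. Only the hashing line follows what `aggregate_scores()` does.

```python
import hashlib
import json
from datetime import datetime

# Hypothetical field values, for illustration only.
json_dict = {
    "name": "example_experiment",
    "description": "toy run",
    "timestamp": str(datetime.today()),
    "task": "example_task",
    "author": "Fabian",
    "results": {"all": [], "mean": {}},
}

# Hash the serialized dictionary and keep the first 12 hex characters as the id.
json_dict["id"] = hashlib.md5(json.dumps(json_dict).encode("utf-8")).hexdigest()[:12]

print(json_dict["id"])
```

Because the timestamp is part of the hashed dictionary, two runs with identical results still get different ids.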
================================================
FILE: experiments/UNetExperiment.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
import numpy as np
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F
from datasets.two_dim.NumpyDataLoader import NumpyDataSet
from trixi.experiment.pytorchexperiment import PytorchExperiment
from networks.UNET import UNet
from loss_functions.dice_loss import SoftDiceLoss
class UNetExperiment(PytorchExperiment):
"""
The UnetExperiment inherits from the PytorchExperiment. It implements the basic life cycle for a segmentation task with UNet (https://arxiv.org/abs/1505.04597).
It is optimized to work with the provided NumpyDataLoader.
The basic life cycle of a UnetExperiment is the same as PytorchExperiment:
setup()
(--> Automatically restore values if a previous checkpoint is given)
prepare()
for epoch in n_epochs:
train()
validate()
(--> save current checkpoint)
end()
"""
def setup(self):
pkl_dir = self.config.split_dir
with open(os.path.join(pkl_dir, "splits.pkl"), 'rb') as f:
splits = pickle.load(f)
tr_keys = splits[self.config.fold]['train']
val_keys = splits[self.config.fold]['val']
test_keys = splits[self.config.fold]['test']
self.device = torch.device(self.config.device if torch.cuda.is_available() else "cpu")
self.train_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,
keys=tr_keys)
self.val_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,
keys=val_keys, mode="val", do_reshuffle=False)
self.test_data_loader = NumpyDataSet(self.config.data_test_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,
keys=test_keys, mode="test", do_reshuffle=False)
self.model = UNet(num_classes=self.config.num_classes, in_channels=self.config.in_channels)
self.model.to(self.device)
# We use a combination of DICE-loss and CE-Loss in this example.
# This combination worked well in the Medical Segmentation Decathlon.
self.dice_loss = SoftDiceLoss(batch_dice=True) # Softmax for DICE Loss!
self.ce_loss = torch.nn.CrossEntropyLoss() # No softmax for CE Loss -> is implemented in torch!
self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.learning_rate)
self.scheduler = ReduceLROnPlateau(self.optimizer, 'min')
# If directory for checkpoint is provided, we load it.
if self.config.do_load_checkpoint:
if self.config.checkpoint_dir == '':
print('checkpoint_dir is empty, please provide directory to load checkpoint.')
else:
self.load_checkpoint(name=self.config.checkpoint_dir, save_types=("model",))
self.save_checkpoint(name="checkpoint_start")
self.elog.print('Experiment set up.')
def train(self, epoch):
self.elog.print('=====TRAIN=====')
self.model.train()
data = None
batch_counter = 0
for data_batch in self.train_data_loader:
self.optimizer.zero_grad()
# Shape of data_batch = [1, b, c, w, h]
# Desired shape = [b, c, w, h]
# Move data and target to the GPU
data = data_batch['data'][0].float().to(self.device)
target = data_batch['seg'][0].long().to(self.device)
pred = self.model(data)
pred_softmax = F.softmax(pred, dim=1) # We calculate a softmax, because our SoftDiceLoss expects that as an input. The CE-Loss does the softmax internally.
loss = self.dice_loss(pred_softmax, target.squeeze()) + self.ce_loss(pred, target.squeeze())
# loss = self.ce_loss(pred, target.squeeze())
loss.backward()
self.optimizer.step()
# Some logging and plotting
if (batch_counter % self.config.plot_freq) == 0:
self.elog.print('Epoch: {0} Loss: {1:.4f}'.format(self._epoch_idx, loss))
self.add_result(value=loss.item(), name='Train_Loss', tag='Loss', counter=epoch + (batch_counter / self.train_data_loader.data_loader.num_batches))
self.clog.show_image_grid(data.float().cpu(), name="data", normalize=True, scale_each=True, n_iter=epoch)
self.clog.show_image_grid(target.float().cpu(), name="mask", title="Mask", n_iter=epoch)
self.clog.show_image_grid(pred.cpu()[:, 1:2, ], name="unt", normalize=True, scale_each=True, n_iter=epoch)
batch_counter += 1
assert data is not None, 'data is None. Please check if your dataloader works properly'
def validate(self, epoch):
self.elog.print('VALIDATE')
self.model.eval()
data = None
loss_list = []
with torch.no_grad():
for data_batch in self.val_data_loader:
data = data_batch['data'][0].float().to(self.device)
target = data_batch['seg'][0].long().to(self.device)
pred = self.model(data)
pred_softmax = F.softmax(pred, dim=1) # We calculate a softmax, because our SoftDiceLoss expects that as an input. The CE-Loss does the softmax internally.
loss = self.dice_loss(pred_softmax, target.squeeze()) + self.ce_loss(pred, target.squeeze())
loss_list.append(loss.item())
assert data is not None, 'data is None. Please check if your dataloader works properly'
self.scheduler.step(np.mean(loss_list))
self.elog.print('Epoch: %d Loss: %.4f' % (self._epoch_idx, float(np.mean(loss_list))))
self.add_result(value=np.mean(loss_list), name='Val_Loss', tag='Loss', counter=epoch+1)
self.clog.show_image_grid(data.float().cpu(), name="data_val", normalize=True, scale_each=True, n_iter=epoch)
self.clog.show_image_grid(target.float().cpu(), name="mask_val", title="Mask", n_iter=epoch)
self.clog.show_image_grid(pred.data.cpu()[:, 1:2, ], name="unt_val", normalize=True, scale_each=True, n_iter=epoch)
def test(self):
from evaluation.evaluator import aggregate_scores, Evaluator
from collections import defaultdict
self.elog.print('=====TEST=====')
self.model.eval()
pred_dict = defaultdict(list)
gt_dict = defaultdict(list)
batch_counter = 0
with torch.no_grad():
for data_batch in self.test_data_loader:
print('testing...', batch_counter)
batch_counter += 1
# Get data_batches
mr_data = data_batch['data'][0].float().to(self.device)
mr_target = data_batch['seg'][0].float().to(self.device)
pred = self.model(mr_data)
pred_argmax = torch.argmax(pred.data.cpu(), dim=1, keepdim=True)
fnames = data_batch['fnames']
for i, fname in enumerate(fnames):
pred_dict[fname[0]].append(pred_argmax[i].detach().cpu().numpy())
gt_dict[fname[0]].append(mr_target[i].detach().cpu().numpy())
test_ref_list = []
for key in pred_dict.keys():
test_ref_list.append((np.stack(pred_dict[key]), np.stack(gt_dict[key])))
scores = aggregate_scores(test_ref_list, evaluator=Evaluator, json_author=self.config.author, json_task=self.config.name, json_name=self.config.name,
json_output_file=self.elog.work_dir + "/{}_".format(self.config.author) + self.config.name + '.json')
print("Scores:\n", scores)
def segment_single_image(self, data):
self.model = UNet(num_classes=self.config.num_classes, in_channels=self.config.in_channels)
self.device = torch.device(self.config.device if torch.cuda.is_available() else "cpu")
# a model must be present and loaded in here
if self.config.model_dir == '':
print('model_dir is empty, please provide directory to load checkpoint.')
else:
self.load_checkpoint(name=self.config.model_dir, save_types=("model",))
self.elog.print("=====SEGMENT_SINGLE_IMAGE=====")
self.model.eval()
self.model.to(self.device)
# Desired shape = [b, c, w, h]
# split into chunks of at most batch_size slices
with torch.no_grad():
######
# When working entirely on CPU and in memory, the following lines replace the split/concat method
# mr_data = data.float().to(self.device)
# pred = self.model(mr_data)
# pred_argmax = torch.argmax(pred.data.cpu(), dim=1, keepdim=True)
######
######
# for CUDA (also works on CPU) split into batches
blocksize = self.config.batch_size
# number_of_elements = round(data.shape[0]/blocksize+0.5) # make blocks large enough to not lose any slices
chunks = [data[i:i+blocksize, ::, ::, ::] for i in range(0, data.shape[0], blocksize)]
pred_list = []
for data_batch in chunks:
mr_data = data_batch.float().to(self.device)
pred_dict = self.model(mr_data)
pred_list.append(pred_dict.cpu())
pred = torch.cat(pred_list)  # join the per-chunk predictions along the batch dimension
pred_argmax = torch.argmax(pred, dim=1, keepdim=True)
# detach the result, move it back to the CPU and convert it to a numpy array
result = pred_argmax.short().detach().cpu().numpy()
return result
================================================
FILE: experiments/UNetExperiment3D.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
from collections import OrderedDict
import numpy as np
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from datasets.three_dim.NumpyDataLoader import NumpyDataSet
from trixi.experiment.pytorchexperiment import PytorchExperiment
from networks.RecursiveUNet3D import UNet3D
from loss_functions.dice_loss import SoftDiceLoss, DC_and_CE_loss
class UNetExperiment3D(PytorchExperiment):
"""
The UNetExperiment3D inherits from the PytorchExperiment. It implements the basic life cycle for a segmentation task with a 3D variant of UNet (https://arxiv.org/abs/1505.04597).
It is optimized to work with the provided NumpyDataLoader.
The basic life cycle of a UNetExperiment3D is the same as PytorchExperiment:
setup()
(--> Automatically restore values if a previous checkpoint is given)
prepare()
for epoch in n_epochs:
train()
validate()
(--> save current checkpoint)
end()
"""
def setup(self):
pkl_dir = self.config.split_dir
with open(os.path.join(pkl_dir, "splits.pkl"), 'rb') as f:
splits = pickle.load(f)
tr_keys = splits[self.config.fold]['train']
val_keys = splits[self.config.fold]['val']
test_keys = splits[self.config.fold]['test']
self.device = torch.device(self.config.device if torch.cuda.is_available() else "cpu")
self.train_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,
keys=tr_keys)
self.val_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,
keys=val_keys, mode="val", do_reshuffle=False)
self.test_data_loader = NumpyDataSet(self.config.data_test_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,
keys=test_keys, mode="test", do_reshuffle=False)
self.model = UNet3D(num_classes=3, in_channels=1)
self.model.to(self.device)
# We use a combination of DICE-loss and CE-Loss in this example.
# This combination worked well in the Medical Segmentation Decathlon.
self.loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'smooth_in_nom': True,
'do_bg': False, 'rebalance_weights': None, 'background_weight': 1}, OrderedDict())
self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.learning_rate)
self.scheduler = ReduceLROnPlateau(self.optimizer, 'min')
# If directory for checkpoint is provided, we load it.
if self.config.do_load_checkpoint:
if self.config.checkpoint_dir == '':
print('checkpoint_dir is empty, please provide directory to load checkpoint.')
else:
self.load_checkpoint(name=self.config.checkpoint_dir, save_types=("model",))
self.save_checkpoint(name="checkpoint_start")
self.elog.print('Experiment set up.')
def train(self, epoch):
self.elog.print('=====TRAIN=====')
self.model.train()
batch_counter = 0
for data_batch in self.train_data_loader:
self.optimizer.zero_grad()
# Shape of data_batch['data'] = [1, b, c, x, y, z]
# Desired shape = [b, c, x, y, z]
# Move data and target to the GPU
data = data_batch['data'][0].float().to(self.device)
target = data_batch['seg'][0].long().to(self.device)
pred = self.model(data)
loss = self.loss(pred, target.squeeze(1))  # squeeze only the channel axis so that a batch of size 1 keeps its batch axis
# loss = self.ce_loss(pred, target.squeeze())
loss.backward()
self.optimizer.step()
# Some logging and plotting
if (batch_counter % self.config.plot_freq) == 0:
self.elog.print('Epoch: %d Loss: %.4f' % (self._epoch_idx, loss))
self.add_result(value=loss.item(), name='Train_Loss', tag='Loss', counter=epoch + (batch_counter / self.train_data_loader.data_loader.num_batches))
self.clog.show_image_grid(data[:,:,30].float(), name="data", normalize=True, scale_each=True, n_iter=epoch)
self.clog.show_image_grid(target[:,:,30].float(), name="mask", title="Mask", n_iter=epoch)
self.clog.show_image_grid(torch.argmax(pred.cpu(), dim=1, keepdim=True)[:,:,30], name="unt_argmax", title="Unet", n_iter=epoch)
batch_counter += 1
def validate(self, epoch):
if epoch % 5 != 0:
return
self.elog.print('VALIDATE')
self.model.eval()
data = None
loss_list = []
with torch.no_grad():
for data_batch in self.val_data_loader:
data = data_batch['data'][0].float().to(self.device)
target = data_batch['seg'][0].long().to(self.device)
pred = self.model(data)
loss = self.loss(pred, target.squeeze(1))  # squeeze only the channel axis so that a batch of size 1 keeps its batch axis
loss_list.append(loss.item())
assert data is not None, 'data is None. Please check if your dataloader works properly'
self.scheduler.step(np.mean(loss_list))
self.elog.print('Epoch: %d Loss: %.4f' % (self._epoch_idx, float(np.mean(loss_list))))
self.add_result(value=np.mean(loss_list), name='Val_Loss', tag='Loss', counter=epoch+1)
self.clog.show_image_grid(data[:,:,30].float(), name="data_val", normalize=True, scale_each=True, n_iter=epoch)
self.clog.show_image_grid(target[:,:,30].float(), name="mask_val", title="Mask", n_iter=epoch)
self.clog.show_image_grid(torch.argmax(pred.data.cpu()[:,:,30], dim=1, keepdim=True), name="unt_argmax_val", title="Unet", n_iter=epoch)
def test(self):
# TODO
print('TODO: Implement your test() method here')
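

# Illustrative sketch (not part of the original experiment): validate() only
# runs when epoch % 5 == 0, and the ReduceLROnPlateau scheduler is stepped only
# inside validate(). Its patience therefore counts validation runs, not epochs.
# The hypothetical helper below estimates the first epoch at which the learning
# rate would drop if the validation loss never improved after the first run.
# Assumptions: epochs are counted from 0, PyTorch's default patience of 10 is
# in effect, and the rate is reduced once the number of non-improving scheduler
# steps exceeds the patience.
def _first_lr_drop_epoch(patience=10, validate_every=5):
    """Return the epoch of the first learning rate reduction under the assumptions above."""
    bad_steps = 0
    epoch = 0  # the validation run at epoch 0 sets the initial best value
    while bad_steps <= patience:
        epoch += validate_every
        bad_steps += 1
    return epoch


# With the defaults this gives epoch 55, so a short training run may never see
# a reduction unless patience is lowered or validation runs more often.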
================================================
FILE: experiments/__init__.py
================================================
================================================
FILE: loss_functions/ND_Crossentropy.py
================================================
import torch
class CrossentropyND(torch.nn.CrossEntropyLoss):
"""
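Expects raw logits: the network must NOT apply a softmax or any other nonlinearity to its output,
because CrossEntropyLoss applies log-softmax internally.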
"""
def forward(self, inp, target):
target = target.long()
num_classes = inp.size()[1]
i0 = 1
i1 = 2
while i1 < len(inp.shape):  # move the class axis to the last position by swapping adjacent axes (same result as a single permute)
inp = inp.transpose(i0, i1)
i0 += 1
i1 += 1
inp = inp.contiguous()
inp = inp.view(-1, num_classes)
target = target.view(-1,)
return super(CrossentropyND, self).forward(inp, target)
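

# Illustrative sketch (not part of the original module): the loop in
# CrossentropyND.forward swaps adjacent axes until the class axis, which starts
# at position 1, ends up last. The hypothetical helper below replays those
# swaps on a plain list of axis indices, assuming only what the loop shows.
def _channel_last_axis_order(ndim):
    """Return the axis order produced by swapping (1, 2), (2, 3), ... in turn."""
    order = list(range(ndim))
    i0, i1 = 1, 2
    while i1 < ndim:
        order[i0], order[i1] = order[i1], order[i0]
        i0 += 1
        i1 += 1
    return order


# For a (b, c, x, y) input the resulting order is (b, x, y, c), which is the
# same ordering a single inp.permute(0, 2, 3, 1) call would give.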
================================================
FILE: loss_functions/__init__.py
================================================
================================================
FILE: loss_functions/dice_loss.py
================================================
import torch
import numpy as np
from loss_functions.ND_Crossentropy import CrossentropyND
from loss_functions.topk_loss import TopKLoss
from torch import nn
def softmax_helper(x):
rpt = [1 for _ in range(len(x.size()))]
rpt[1] = x.size(1)
x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
e_x = torch.exp(x - x_max)
return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
def get_tp_fp_fn(net_output, gt, axes=None, mask=None):
"""
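net_output must be (b, c, x, y(, z))
gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or a one-hot encoding (b, c, x, y(, z))
if mask is provided it must have shape (b, 1, x, y(, z))
:param net_output: network output, typically softmax probabilities
:param gt: ground truth, as a label map or a one-hot encoding
:param axes: axes to sum over; defaults to all spatial axes (2 and up)
:param mask: optional mask; positions where mask is 0 do not contribute to tp, fp or fn
:return: tuple (tp, fp, fn), each summed over axes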
"""
if axes is None:
axes = tuple(range(2, len(net_output.size())))
shp_x = net_output.shape
shp_y = gt.shape
with torch.no_grad():
if len(shp_x) != len(shp_y):
gt = gt.view((shp_y[0], 1, *shp_y[1:]))
if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
# if this is the case then gt is probably already a one hot encoding
y_onehot = gt
else:
gt = gt.long()
y_onehot = torch.zeros(shp_x)
if net_output.device.type == "cuda":
y_onehot = y_onehot.cuda(net_output.device.index)
y_onehot.scatter_(1, gt, 1)
tp = net_output * y_onehot
fp = net_output * (1 - y_onehot)
fn = (1 - net_output) * y_onehot
if mask is not None:
tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
tp = sum_tensor(tp, axes, keepdim=False)
fp = sum_tensor(fp, axes, keepdim=False)
fn = sum_tensor(fn, axes, keepdim=False)
return tp, fp, fn
def sum_tensor(inp, axes, keepdim=False):
axes = np.unique(axes).astype(int)
if keepdim:
for ax in axes:
inp = inp.sum(int(ax), keepdim=True)
else:
for ax in sorted(axes, reverse=True):
inp = inp.sum(int(ax))
return inp
def mean_tensor(inp, axes, keepdim=False):
axes = np.unique(axes).astype(int)
if keepdim:
for ax in axes:
inp = inp.mean(int(ax), keepdim=True)
else:
for ax in sorted(axes, reverse=True):
inp = inp.mean(int(ax))
return inp
class SoftDiceLoss(nn.Module):
def __init__(self, smooth=1., apply_nonlin=None, batch_dice=False, do_bg=True, smooth_in_nom=True, background_weight=1, rebalance_weights=None):
"""
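Soft Dice loss. Returns the negative soft Dice score, so a perfect prediction yields a value close to -1.
:param smooth: constant added to the denominator to avoid division by zero
:param apply_nonlin: optional callable applied to the network output first (e.g. softmax_helper)
:param batch_dice: if True, compute one Dice score per class over the whole batch; otherwise one per sample and class
:param do_bg: if False, drop class 0 (background) before computing the Dice score
:param smooth_in_nom: if True, add smooth to the numerator as well
:param background_weight: weight of the first class channel (batch_dice only); must be 1 if do_bg is False
:param rebalance_weights: optional numpy array of per-class weights applied to tp and fn (batch_dice only)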
"""
super(SoftDiceLoss, self).__init__()
if not do_bg:
assert background_weight == 1, "background_weight must be 1 when do_bg is False"
self.rebalance_weights = rebalance_weights
self.background_weight = background_weight
if smooth_in_nom:
self.smooth_in_nom = smooth
else:
self.smooth_in_nom = 0
self.do_bg = do_bg
self.batch_dice = batch_dice
self.apply_nonlin = apply_nonlin
self.smooth = smooth
self.y_onehot = None
def forward(self, x, y):
with torch.no_grad():
y = y.long()
shp_x = x.shape
shp_y = y.shape
if self.apply_nonlin is not None:
x = self.apply_nonlin(x)
if len(shp_x) != len(shp_y):
y = y.view((shp_y[0], 1, *shp_y[1:]))
# now x and y should have shape (B, C, X, Y(, Z)) and (B, 1, X, Y(, Z)), respectively
y_onehot = torch.zeros(shp_x)
if x.device.type == "cuda":
y_onehot = y_onehot.cuda(x.device.index)
y_onehot.scatter_(1, y, 1)
if not self.do_bg:
x = x[:, 1:]
y_onehot = y_onehot[:, 1:]
if not self.batch_dice:
if self.background_weight != 1 or (self.rebalance_weights is not None):
raise NotImplementedError("background_weight and rebalance_weights are only supported with batch_dice=True")
l = soft_dice(x, y_onehot, self.smooth, self.smooth_in_nom)
else:
l = soft_dice_per_batch_2(x, y_onehot, self.smooth, self.smooth_in_nom,
background_weight=self.background_weight,
rebalance_weights=self.rebalance_weights)
return l
def soft_dice_per_batch(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1):
axes = tuple([0] + list(range(2, len(net_output.size()))))
intersect = sum_tensor(net_output * gt, axes, keepdim=False)
denom = sum_tensor(net_output + gt, axes, keepdim=False)
weights = torch.ones(intersect.shape)
weights[0] = background_weight
if net_output.device.type == "cuda":
weights = weights.cuda(net_output.device.index)
result = (- ((2 * intersect + smooth_in_nom) / (denom + smooth)) * weights).mean()
return result
def soft_dice_per_batch_2(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1, rebalance_weights=None):
if rebalance_weights is not None and len(rebalance_weights) != gt.shape[1]:
rebalance_weights = rebalance_weights[1:]  # drop the background weight; this is the case when do_bg=False
axes = tuple([0] + list(range(2, len(net_output.size()))))
tp = sum_tensor(net_output * gt, axes, keepdim=False)
fn = sum_tensor((1 - net_output) * gt, axes, keepdim=False)
fp = sum_tensor(net_output * (1 - gt), axes, keepdim=False)
weights = torch.ones(tp.shape)
weights[0] = background_weight
if net_output.device.type == "cuda":
weights = weights.cuda(net_output.device.index)
if rebalance_weights is not None:
rebalance_weights = torch.from_numpy(rebalance_weights).float()
if net_output.device.type == "cuda":
rebalance_weights = rebalance_weights.cuda(net_output.device.index)
tp = tp * rebalance_weights
fn = fn * rebalance_weights
result = (- ((2 * tp + smooth_in_nom) / (2 * tp + fp + fn + smooth)) * weights).mean()
return result
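

# Worked example (illustrative, not part of the original module): for a single
# class, the batch Dice term computed by soft_dice_per_batch_2 is
#   -(2 * tp + smooth_in_nom) / (2 * tp + fp + fn + smooth).
# The hypothetical helper below evaluates that expression on flat Python lists
# so the arithmetic can be checked by hand. It assumes one class and no
# background or rebalancing weights.
def _soft_dice_single_class(probs, onehot, smooth=1e-5, smooth_in_nom=1e-5):
    """Negative soft Dice for one class from per-voxel probabilities and 0/1 labels."""
    tp = sum(p * g for p, g in zip(probs, onehot))
    fp = sum(p * (1 - g) for p, g in zip(probs, onehot))
    fn = sum((1 - p) * g for p, g in zip(probs, onehot))
    return -(2 * tp + smooth_in_nom) / (2 * tp + fp + fn + smooth)


# A perfect prediction gives -1 (the minimum of the loss). A prediction with
# tp = fp = fn = 1 gives roughly -0.5.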
def soft_dice(net_output, gt, smooth=1., smooth_in_nom=1.):
axes = tuple(range(2, len(net_output.size())))
intersect = sum_tensor(net_output * gt, axes, keepdim=False)
denom = sum_tensor(net_output + gt, axes, keepdim=False)
result = (- ((2 * intersect + smooth_in_nom) / (denom + smooth))).mean()
return result
class MultipleOutputLoss(nn.Module):
def __init__(self, loss, weight_factors=None):
"""
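Use this if you have several outputs that should all predict the same y.
:param loss: loss module applied to each output
:param weight_factors: optional per-output weights; defaults to 1 for every output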
"""
super(MultipleOutputLoss, self).__init__()
self.weight_factors = weight_factors
self.loss = loss
def forward(self, x, y):
assert isinstance(x, (tuple, list)), "x must be either tuple or list"
if self.weight_factors is None:
weights = [1] * len(x)
else:
weights = self.weight_factors
l = weights[0] * self.loss(x[0], y)
for i in range(1, len(x)):
l += weights[i] * self.loss(x[i], y)
return l
class DC_and_CE_loss(nn.Module):
def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum"):
super(DC_and_CE_loss, self).__init__()
self.aggregate = aggregate
self.ce = CrossentropyND(**ce_kwargs)
self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs)
def forward(self, net_output, target):
dc_loss = self.dc(net_output, target)
ce_loss = self.ce(net_output, target)
if self.aggregate == "sum":
result = ce_loss + dc_loss
else:
raise NotImplementedError("only aggregate='sum' is implemented")  # reserved for other aggregation modes
return result
class DC_and_topk_loss(nn.Module):
def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum"):
super(DC_and_topk_loss, self).__init__()
self.aggregate = aggregate
self.ce = TopKLoss(**ce_kwargs)
self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs)
def forward(self, net_output, target):
dc_loss = self.dc(net_output, target)
ce_loss = self.ce(net_output, target)
if self.aggregate == "sum":
result = ce_loss + dc_loss
else:
raise NotImplementedError("only aggregate='sum' is implemented")  # reserved for other aggregation modes
return result
class CrossentropyWithLossMask(nn.CrossEntropyLoss):
def __init__(self, k=None):
"""
This implementation ignores weight, ignore_index (use loss_mask instead!) and reduction!
:param k: if given, only the 1/k fraction of voxels with the highest loss is averaged per sample
"""
super(CrossentropyWithLossMask, self).__init__(weight=None, ignore_index=-100, reduction='none')
self.k = k
def forward(self, inp, target, loss_mask=None):
target = target.long()
inp = inp.float()
if loss_mask is not None:
loss_mask = loss_mask.float()
num_classes = inp.size()[1]
i0 = 1
i1 = 2
while i1 < len(inp.shape):  # move the class axis to the last position by swapping adjacent axes (same result as a single permute)
inp = inp.transpose(i0, i1)
i0 += 1
i1 += 1
if not inp.is_contiguous():
inp = inp.contiguous()
inp = inp.view(target.shape[0], -1, num_classes)
target = target.view(target.shape[0], -1)
if loss_mask is not None:
loss_mask = loss_mask.view(target.shape[0], -1)
if self.k is not None:
if loss_mask is not None:
num_sel = torch.stack(tuple([i.sum() / self.k for i in torch.unbind(loss_mask, 0)]), 0).long()
loss = torch.stack(tuple([
torch.topk(super(CrossentropyWithLossMask, self).forward(inp[i], target[i])[loss_mask[i].byte()],
num_sel[i], sorted=False)[0].mean()
for i in range(target.shape[0])
])
)
else:
num_sel = [int(inp.shape[1] // self.k)] * inp.shape[0]  # inp is (b, n_voxels, c) after the view above; topk needs an integer count
loss = torch.stack(tuple([
torch.topk(super(CrossentropyWithLossMask, self).forward(inp[i], target[i]),
num_sel[i], sorted=False)[0].mean()
for i in range(target.shape[0])
])
)
else:
if loss_mask is not None:
loss = torch.stack(tuple([
super(CrossentropyWithLossMask, self).forward(inp[i], target[i])[loss_mask[i].byte()].mean()
for i in range(target.shape[0])
])
)
else:
loss = torch.stack(tuple([
super(CrossentropyWithLossMask, self).forward(inp[i], target[i]).mean()
for i in range(target.shape[0])
])
)
loss = loss.mean()
return loss
================================================
FILE: loss_functions/topk_loss.py
================================================
import numpy as np
import torch
from loss_functions.ND_Crossentropy import CrossentropyND
class TopKLoss(CrossentropyND):
"""
Expects raw logits: the network must NOT apply a nonlinearity (e.g. softmax) to its output.
"""
def __init__(self, weight=None, ignore_index=-100, k=10):
self.k = k
super(TopKLoss, self).__init__(weight=weight, ignore_index=ignore_index, reduction='none')  # per-voxel losses are needed for the top-k selection
def forward(self, inp, target):
target = target[:, 0].long()
res = super(TopKLoss, self).forward(inp, target)
num_voxels = np.prod(res.shape)
res, _ = torch.topk(res.view((-1,)), int(num_voxels // self.k), sorted=False)
return res.mean()
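

# Illustrative sketch (not part of the original module): TopKLoss averages only
# the largest num_voxels // k per-voxel losses, so k acts as a divisor (k=10
# keeps the hardest tenth of the voxels). The hypothetical helper below mirrors
# that selection on a plain list instead of a tensor. It assumes the list has
# at least k entries, so that at least one value is kept.
def _mean_of_hardest(losses, k=10):
    """Mean of the len(losses) // k largest values in losses."""
    n_keep = len(losses) // k
    hardest = sorted(losses, reverse=True)[:n_keep]
    return sum(hardest) / len(hardest)


# With 20 losses and k=10 only the two largest values enter the mean.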
================================================
FILE: networks/RecursiveUNet.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Defines the Unet.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1 at the bottleneck
# recursive implementation of Unet
import torch
from torch import nn
class UNet(nn.Module):
def __init__(self, num_classes=3, in_channels=1, initial_filter_size=64, kernel_size=3, num_downs=4, norm_layer=nn.InstanceNorm2d):
# norm_layer=nn.BatchNorm2d, use_dropout=False):
super(UNet, self).__init__()
# construct unet structure
unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-1), out_channels=initial_filter_size * 2 ** num_downs,
num_classes=num_classes, kernel_size=kernel_size, norm_layer=norm_layer, innermost=True)
for i in range(1, num_downs):
unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-(i+1)),
out_channels=initial_filter_size * 2 ** (num_downs-i),
num_classes=num_classes, kernel_size=kernel_size, submodule=unet_block, norm_layer=norm_layer)
unet_block = UnetSkipConnectionBlock(in_channels=in_channels, out_channels=initial_filter_size,
num_classes=num_classes, kernel_size=kernel_size, submodule=unet_block, norm_layer=norm_layer,
outermost=True)
self.model = unet_block
def forward(self, x):
return self.model(x)
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, in_channels=None, out_channels=None, num_classes=1, kernel_size=3,
submodule=None, outermost=False, innermost=False, norm_layer=nn.InstanceNorm2d, use_dropout=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
# downconv
pool = nn.MaxPool2d(2, stride=2)
conv1 = self.contract(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)
conv2 = self.contract(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)
# upconv
conv3 = self.expand(in_channels=out_channels*2, out_channels=out_channels, kernel_size=kernel_size)
conv4 = self.expand(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size)
if outermost:
final = nn.Conv2d(out_channels, num_classes, kernel_size=1)
down = [conv1, conv2]
up = [conv3, conv4, final]
model = down + [submodule] + up
elif innermost:
upconv = nn.ConvTranspose2d(in_channels*2, in_channels,
kernel_size=2, stride=2)
model = [pool, conv1, conv2, upconv]
else:
upconv = nn.ConvTranspose2d(in_channels*2, in_channels, kernel_size=2, stride=2)
down = [pool, conv1, conv2]
up = [conv3, conv4, upconv]
if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up
self.model = nn.Sequential(*model)
@staticmethod
def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.InstanceNorm2d):
layer = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
norm_layer(out_channels),
nn.LeakyReLU(inplace=True))
return layer
@staticmethod
def expand(in_channels, out_channels, kernel_size=3):
layer = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
nn.LeakyReLU(inplace=True),
)
return layer
@staticmethod
def center_crop(layer, target_width, target_height):
batch_size, n_channels, layer_width, layer_height = layer.size()
xy1 = (layer_width - target_width) // 2
xy2 = (layer_height - target_height) // 2
return layer[:, :, xy1:(xy1 + target_width), xy2:(xy2 + target_height)]
def forward(self, x):
if self.outermost:
return self.model(x)
else:
crop = self.center_crop(self.model(x), x.size()[2], x.size()[3])
return torch.cat([x, crop], 1)
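

# Illustrative sketch (not part of the original module): the UNet constructor
# builds its blocks from the innermost one outwards, which makes the channel
# widths hard to read off. The hypothetical helper below lists the
# (in_channels, out_channels) pair passed to each UnetSkipConnectionBlock,
# ordered from the outermost block to the innermost one, using the same
# arithmetic as the constructor.
def _unet_channel_plan(in_channels=1, initial_filter_size=64, num_downs=4):
    """Return the (in_channels, out_channels) pairs per block, outermost first."""
    plan = [(in_channels, initial_filter_size)]
    for level in range(1, num_downs + 1):
        plan.append((initial_filter_size * 2 ** (level - 1),
                     initial_filter_size * 2 ** level))
    return plan


# With the defaults the widths double at every level:
# (1, 64), (64, 128), (128, 256), (256, 512), (512, 1024).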
================================================
FILE: networks/RecursiveUNet3D.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Defines the Unet.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1 at the bottleneck
# recursive implementation of Unet
import torch
from torch import nn
class UNet3D(nn.Module):
def __init__(self, num_classes=3, in_channels=1, initial_filter_size=64, kernel_size=3, num_downs=3, norm_layer=nn.InstanceNorm3d):
# norm_layer=nn.BatchNorm2d, use_dropout=False):
super(UNet3D, self).__init__()
# construct unet structure
unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-1), out_channels=initial_filter_size * 2 ** num_downs,
num_classes=num_classes, kernel_size=kernel_size, norm_layer=norm_layer, innermost=True)
for i in range(1, num_downs):
unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-(i+1)),
out_channels=initial_filter_size * 2 ** (num_downs-i),
num_classes=num_classes, kernel_size=kernel_size, submodule=unet_block, norm_layer=norm_layer)
unet_block = UnetSkipConnectionBlock(in_channels=in_channels, out_channels=initial_filter_size,
num_classes=num_classes, kernel_size=kernel_size, submodule=unet_block, norm_layer=norm_layer,
outermost=True)
self.model = unet_block
def forward(self, x):
return self.model(x)
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, in_channels=None, out_channels=None, num_classes=1, kernel_size=3,
submodule=None, outermost=False, innermost=False, norm_layer=nn.InstanceNorm3d, use_dropout=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
# downconv
pool = nn.MaxPool3d(2, stride=2)
conv1 = self.contract(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)
conv2 = self.contract(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)
# upconv
conv3 = self.expand(in_channels=out_channels*2, out_channels=out_channels, kernel_size=kernel_size)
conv4 = self.expand(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size)
if outermost:
final = nn.Conv3d(out_channels, num_classes, kernel_size=1)
down = [conv1, conv2]
up = [conv3, conv4, final]
model = down + [submodule] + up
elif innermost:
upconv = nn.ConvTranspose3d(in_channels*2, in_channels,
kernel_size=2, stride=2)
model = [pool, conv1, conv2, upconv]
else:
upconv = nn.ConvTranspose3d(in_channels*2, in_channels, kernel_size=2, stride=2)
down = [pool, conv1, conv2]
up = [conv3, conv4, upconv]
if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up
self.model = nn.Sequential(*model)
@staticmethod
def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.InstanceNorm3d):
layer = nn.Sequential(
nn.Conv3d(in_channels, out_channels, kernel_size, padding=1),
norm_layer(out_channels),
nn.LeakyReLU(inplace=True))
return layer
@staticmethod
def expand(in_channels, out_channels, kernel_size=3):
layer = nn.Sequential(
nn.Conv3d(in_channels, out_channels, kernel_size, padding=1),
nn.LeakyReLU(inplace=True),
)
return layer
@staticmethod
def center_crop(layer, target_depth, target_width, target_height):
batch_size, n_channels, layer_depth, layer_width, layer_height = layer.size()
xy0 = (layer_depth - target_depth) // 2
xy1 = (layer_width - target_width) // 2
xy2 = (layer_height - target_height) // 2
return layer[:, :, xy0:(xy0 + target_depth), xy1:(xy1 + target_width), xy2:(xy2 + target_height)]
def forward(self, x):
if self.outermost:
return self.model(x)
else:
crop = self.center_crop(self.model(x), x.size()[2], x.size()[3], x.size()[4])
return torch.cat([x, crop], 1)
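

# Illustrative sketch (not part of the original module): center_crop computes
# one start offset per spatial axis and slices target-size voxels from there.
# The hypothetical helper below shows that arithmetic for a single axis.
def _center_crop_bounds(layer_size, target_size):
    """Return the (start, stop) slice bounds used to crop one axis."""
    start = (layer_size - target_size) // 2
    return start, start + target_size


# When the size difference is odd, the integer division rounds the start
# offset down, so the extra voxel is trimmed from the end of the axis.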
================================================
FILE: networks/UNET.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
class UNet(nn.Module):
def __init__(self, num_classes, in_channels=1, initial_filter_size=64, kernel_size=3, do_instancenorm=True):
super().__init__()
self.contr_1_1 = self.contract(in_channels, initial_filter_size, kernel_size, instancenorm=do_instancenorm)
self.contr_1_2 = self.contract(initial_filter_size, initial_filter_size, kernel_size, instancenorm=do_instancenorm)
self.pool = nn.MaxPool2d(2, stride=2)
self.contr_2_1 = self.contract(initial_filter_size, initial_filter_size*2, kernel_size, instancenorm=do_instancenorm)
self.contr_2_2 = self.contract(initial_filter_size*2, initial_filter_size*2, kernel_size, instancenorm=do_instancenorm)
# self.pool2 = nn.MaxPool2d(2, stride=2)
self.contr_3_1 = self.contract(initial_filter_size*2, initial_filter_size*2**2, kernel_size, instancenorm=do_instancenorm)
self.contr_3_2 = self.contract(initial_filter_size*2**2, initial_filter_size*2**2, kernel_size, instancenorm=do_instancenorm)
# self.pool3 = nn.MaxPool2d(2, stride=2)
self.contr_4_1 = self.contract(initial_filter_size*2**2, initial_filter_size*2**3, kernel_size, instancenorm=do_instancenorm)
self.contr_4_2 = self.contract(initial_filter_size*2**3, initial_filter_size*2**3, kernel_size, instancenorm=do_instancenorm)
# self.pool4 = nn.MaxPool2d(2, stride=2)
self.center = nn.Sequential(
nn.Conv2d(initial_filter_size*2**3, initial_filter_size*2**4, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(initial_filter_size*2**4, initial_filter_size*2**4, 3, padding=1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(initial_filter_size*2**4, initial_filter_size*2**3, 2, stride=2),
nn.ReLU(inplace=True),
)
self.expand_4_1 = self.expand(initial_filter_size*2**4, initial_filter_size*2**3)
self.expand_4_2 = self.expand(initial_filter_size*2**3, initial_filter_size*2**3)
self.upscale4 = nn.ConvTranspose2d(initial_filter_size*2**3, initial_filter_size*2**2, kernel_size=2, stride=2)
self.expand_3_1 = self.expand(initial_filter_size*2**3, initial_filter_size*2**2)
self.expand_3_2 = self.expand(initial_filter_size*2**2, initial_filter_size*2**2)
self.upscale3 = nn.ConvTranspose2d(initial_filter_size*2**2, initial_filter_size*2, 2, stride=2)
self.expand_2_1 = self.expand(initial_filter_size*2**2, initial_filter_size*2)
self.expand_2_2 = self.expand(initial_filter_size*2, initial_filter_size*2)
self.upscale2 = nn.ConvTranspose2d(initial_filter_size*2, initial_filter_size, 2, stride=2)
self.expand_1_1 = self.expand(initial_filter_size*2, initial_filter_size)
self.expand_1_2 = self.expand(initial_filter_size, initial_filter_size)
# Output layer for segmentation
self.final = nn.Conv2d(initial_filter_size, num_classes, kernel_size=1) # kernel size for final layer = 1, see paper
self.softmax = torch.nn.Softmax2d()
# Output layer for "autoencoder-mode"
self.output_reconstruction_map = nn.Conv2d(initial_filter_size, out_channels=1, kernel_size=1)
@staticmethod
def contract(in_channels, out_channels, kernel_size=3, instancenorm=True):
if instancenorm:
layer = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
nn.InstanceNorm2d(out_channels),
nn.LeakyReLU(inplace=True))
else:
layer = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
nn.LeakyReLU(inplace=True))
return layer
@staticmethod
def expand(in_channels, out_channels, kernel_size=3):
layer = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
nn.LeakyReLU(inplace=True),
)
return layer
@staticmethod
def center_crop(layer, target_width, target_height):
batch_size, n_channels, layer_width, layer_height = layer.size()
xy1 = (layer_width - target_width) // 2
xy2 = (layer_height - target_height) // 2
return layer[:, :, xy1:(xy1 + target_width), xy2:(xy2 + target_height)]
def forward(self, x, enable_concat=True, print_layer_shapes=False):
concat_weight = 1
if not enable_concat:
concat_weight = 0
contr_1 = self.contr_1_2(self.contr_1_1(x))
pool = self.pool(contr_1)
contr_2 = self.contr_2_2(self.contr_2_1(pool))
pool = self.pool(contr_2)
contr_3 = self.contr_3_2(self.contr_3_1(pool))
pool = self.pool(contr_3)
contr_4 = self.contr_4_2(self.contr_4_1(pool))
pool = self.pool(contr_4)
center = self.center(pool)
crop = self.center_crop(contr_4, center.size()[2], center.size()[3])
concat = torch.cat([center, crop*concat_weight], 1)
expand = self.expand_4_2(self.expand_4_1(concat))
upscale = self.upscale4(expand)
crop = self.center_crop(contr_3, upscale.size()[2], upscale.size()[3])
concat = torch.cat([upscale, crop*concat_weight], 1)
expand = self.expand_3_2(self.expand_3_1(concat))
upscale = self.upscale3(expand)
crop = self.center_crop(contr_2, upscale.size()[2], upscale.size()[3])
concat = torch.cat([upscale, crop*concat_weight], 1)
expand = self.expand_2_2(self.expand_2_1(concat))
upscale = self.upscale2(expand)
crop = self.center_crop(contr_1, upscale.size()[2], upscale.size()[3])
concat = torch.cat([upscale, crop*concat_weight], 1)
expand = self.expand_1_2(self.expand_1_1(concat))
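if enable_concat:
# segmentation mode: map the features to class scores
output = self.final(expand)
else:
# "autoencoder-mode": map the features to a single reconstruction channel
output = self.output_reconstruction_map(expand)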
return output
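

# Illustrative sketch (not part of the original module): forward() applies four
# 2x2 max-pooling steps and four stride-2 transposed convolutions, and crops
# each skip connection to the size of the upsampled feature map. The
# hypothetical helper below estimates the spatial size of the output along one
# axis. Assumptions: the default kernel_size=3 is used, so the padded
# convolutions keep the spatial size, and pooling rounds odd sizes down.
def _unet_output_size(size, n_pools=4):
    """Return the output size along one axis for an input of the given size."""
    bottleneck = size
    for _ in range(n_pools):
        bottleneck //= 2
    return bottleneck * 2 ** n_pools


# Under these assumptions the output matches the input only when the input
# size is divisible by 16; an input of 70 would yield an output of 64.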
================================================
FILE: requirements.txt
================================================
googledrivedownloader==0.4
MedPy==0.4.0
torch==1.3.1
torchfile==0.1.0
trixi==0.1.2.1
batchgenerators==0.19.3
# Workaround for scipy issues
scipy==1.1.0
# Workaround for slackclient version issues
slackclient==2.0.0
# Fix compatibility issues
torchvision==0.4.2
================================================
FILE: run_preprocessing.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from configs.Config_unet import get_config
from datasets.example_dataset.create_splits import create_splits
from datasets.utils import download_dataset
from datasets.example_dataset.preprocessing import preprocess_data
if __name__ == "__main__":
c = get_config()
download_dataset(dest_path=c.data_root_dir, dataset=c.dataset_name, id=c.google_drive_id)
print('Preprocessing data. [STARTED]')
preprocess_data(root_dir=os.path.join(c.data_root_dir, c.dataset_name))
create_splits(output_dir=c.split_dir, image_dir=c.data_dir)
print('Preprocessing data. [DONE]')
================================================
FILE: run_train_pipeline.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from os.path import exists

from configs.Config_unet import get_config
from datasets.example_dataset.create_splits import create_splits
from datasets.utils import download_dataset
from datasets.example_dataset.preprocessing import preprocess_data
from experiments.UNetExperiment import UNetExperiment

if __name__ == "__main__":
    c = get_config()

    # print("Executing: EPOCHS = {} / LEARNING RATE = {}".format(c.n_epochs, c.learning_rate))
    download_dataset(dest_path=c.data_root_dir, dataset=c.dataset_name, id=c.google_drive_id)

    if not exists(os.path.join(os.path.join(c.data_root_dir, c.dataset_name), 'preprocessed')):
        print('Preprocessing data. [STARTED]')
        preprocess_data(root_dir=os.path.join(c.data_root_dir, c.dataset_name), y_shape=c.patch_size, z_shape=c.patch_size)
        create_splits(output_dir=c.split_dir, image_dir=c.data_dir)
        print('Preprocessing data. [DONE]')
    else:
        print('The data has already been preprocessed. It will not be preprocessed again. Delete the folder to enforce it.')

    exp = UNetExperiment(config=c, name=c.name, n_epochs=c.n_epochs,
                         seed=42, append_rnd_to_name=c.append_rnd_string, globs=globals(),
                         # visdomlogger_kwargs={"auto_start": c.start_visdom},
                         loggers={
                             "visdom": ("visdom", {"auto_start": c.start_visdom})
                         }
                         )

    exp.run()
    exp.run_test(setup=False)
================================================
FILE: runner.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from configs.Config_unet_spleen import get_config
import subprocess

if __name__ == "__main__":
    c = get_config()

    n_epochs = c.n_epochs
    learning_rate = c.learning_rate
    step = 0

    # Runs until interrupted: every iteration launches one full training run
    # in a subprocess and then adjusts the settings for the next run.
    while True:
        result = subprocess.run(['python', 'run_train_pipeline.py',
                                 '--n_epochs', '{}'.format(n_epochs),
                                 '--learning_rate', '{}'.format(learning_rate)])

        # After even steps add 20 epochs, after odd steps halve the learning rate.
        if step % 2 == 0:
            n_epochs = n_epochs + 20
        else:
            learning_rate = learning_rate / 2
        step += 1
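
The loop in runner.py alternates two adjustments between runs: after an even-numbered step it adds 20 epochs, and after an odd-numbered step it halves the learning rate. A minimal sketch of that schedule as a bounded generator, useful for checking which settings a sweep would visit before launching any training. `sweep_schedule` and `max_steps` are hypothetical names that do not exist in the repository, and the starting values below are placeholders, not the values from `Config_unet_spleen`.

```python
def sweep_schedule(n_epochs, learning_rate, max_steps):
    """Yield the (n_epochs, learning_rate) pair used at each step of the sweep."""
    for step in range(max_steps):
        yield n_epochs, learning_rate
        if step % 2 == 0:
            # Even step: the next run trains for 20 more epochs.
            n_epochs += 20
        else:
            # Odd step: the next run uses half the learning rate.
            learning_rate /= 2


if __name__ == "__main__":
    # Placeholder starting values, chosen only for illustration.
    for step, (epochs, lr) in enumerate(sweep_schedule(10, 1.0, max_steps=4)):
        print(step, epochs, lr)
```

Bounding the sweep with `max_steps` is a design choice of this sketch; the script itself keeps running until it is interrupted.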
================================================
FILE: segment_a_spleen.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

from medpy.io import save

from configs.Config_unet_spleen import get_config
from datasets.spleen.preprocessing import preprocess_single_file, postprocess_single_image
from experiments.UNetExperiment import UNetExperiment


def save_single_image(image, image_header, filename):
    # medpy.io.save
    save(image, filename, image_header)
    print('> Resulting Image stored as {}'.format(filename))


if __name__ == "__main__":
    c = get_config()

    if len(sys.argv) == 1:
        print("USAGE:\n\npython {} imagefilename [model_checkpoint [shapesize]]\n\n"
              " imagefilename - a filename that stores a nii.gz formatted file.\n"
              " model_checkpoint - a checkpoint filename to reload\n"
              " shapesize - optional value that defines "
              "the size of the shape, default is 64 (not yet used).".format(sys.argv[0]))
        filename = "data/Task09_Spleen/imagesTs/spleen_15.nii.gz"
    else:
        filename = sys.argv[1]

    print("Loading and processing file {}".format(filename))
    if len(sys.argv) > 2:
        c.checkpoint_dir = sys.argv[2]
        c.do_load_checkpoint = True

    print("Loading model from checkpoint {}".format(c.model_dir))
    if len(c.model_dir) == 0 or not os.path.isdir(os.path.split(c.model_dir)[0]):
        print("ERROR /!\\: No checkpoint dir is set, please provide in Config file.")
        sys.exit(1)

    shapesize = 64
    if len(sys.argv) > 3:
        shapesize = int(sys.argv[3])

    # Get the header in order to preserve voxel dimensions to store the segmented image later on
    print('Preprocessing data.')
    data, header = preprocess_single_file(filename, y_shape=shapesize, z_shape=shapesize)

    print('Setting up model and start segmentation.')
    exp = UNetExperiment(config=c, name=c.name, n_epochs=c.n_epochs,
                         seed=42, append_rnd_to_name=c.append_rnd_string, globs=globals()
                         )
    result = exp.segment_single_image(data)

    print('Postprocessing data.')
    result = postprocess_single_image(result)

    # Write the result next to the input file; os.path.join also handles
    # an input filename that has no directory part.
    pathname, fname = os.path.split(filename)
    destination_filename = os.path.join(pathname, "segmented_" + fname)
    print('Saving file to disk: {}'.format(destination_filename))
    save_single_image(result, header, destination_filename)
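
The USAGE text in segment_a_spleen.py describes three positional arguments, each optional from left to right. A minimal sketch of the same command line expressed with `argparse`, assuming the defaults stated in the script (the sample image path and a shape size of 64). `build_parser` and `DEFAULT_IMAGE` are hypothetical names and not part of the repository.

```python
import argparse

# Fallback image used by the script when no filename is given.
DEFAULT_IMAGE = "data/Task09_Spleen/imagesTs/spleen_15.nii.gz"


def build_parser():
    """Build a parser for: imagefilename [model_checkpoint [shapesize]]."""
    parser = argparse.ArgumentParser(description="Segment a spleen in a nii.gz image.")
    parser.add_argument("imagefilename", nargs="?", default=DEFAULT_IMAGE,
                        help="a filename that stores a nii.gz formatted file")
    parser.add_argument("model_checkpoint", nargs="?", default=None,
                        help="a checkpoint filename to reload")
    parser.add_argument("shapesize", nargs="?", type=int, default=64,
                        help="size of the shape, default is 64")
    return parser


# Example invocation with explicit arguments instead of sys.argv.
args = build_parser().parse_args(["scan.nii.gz", "checkpoint_dir", "128"])
print(args.imagefilename, args.model_checkpoint, args.shapesize)
```

With `nargs="?"` on every positional, omitted trailing arguments fall back to their defaults, which matches the bracketed form in the USAGE string.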
================================================
FILE: train.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib
matplotlib.use('Agg')

from configs.Config_unet import get_config
from experiments.UNetExperiment import UNetExperiment

if __name__ == "__main__":
    c = get_config()

    exp = UNetExperiment(config=c, name=c.name, n_epochs=c.n_epochs,
                         seed=42, append_rnd_to_name=c.append_rnd_string, globs=globals(),
                         # visdomlogger_kwargs={"auto_start": c.start_visdom},
                         loggers={
                             "visdom": ("visdom", {"auto_start": c.start_visdom}),
                             # "tb": ("tensorboard"),
                             # "slack": ("slack", {"token": "XXXXXXXX",
                             #                     "user_email": "x"})
                         }
                         )

    exp.run()
    exp.run_test(setup=False)
================================================
FILE: train3D.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from configs.Config_unet import get_config
from experiments.UNetExperiment3D import UNetExperiment3D

if __name__ == "__main__":
    c = get_config()

    exp = UNetExperiment3D(config=c, name=c.name, n_epochs=c.n_epochs,
                           seed=42, append_rnd_to_name=c.append_rnd_string, globs=globals(),
                           # visdomlogger_kwargs={"auto_start": c.start_visdom},
                           loggers={
                               "visdom": ("visdom", {"auto_start": c.start_visdom}),
                               # "tb": ("tensorboard"),
                               # "slack": ("slack", {"token": "XXXXXXXX",
                               #                     "user_email": "x"})
                           }
                           )

    exp.run()
    exp.run_test(setup=False)
================================================
FILE: utilities/__init__.py
================================================
================================================
FILE: utilities/file_and_folder_operations.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os


def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
    # List the sub-directories of `folder`, optionally filtered by name prefix/suffix.
    if join:
        l = os.path.join
    else:
        l = lambda x, y: y
    res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))
           and (prefix is None or i.startswith(prefix))
           and (suffix is None or i.endswith(suffix))]
    if sort:
        res.sort()
    return res


def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
    # List the regular files in `folder`, optionally filtered by name prefix/suffix.
    if join:
        l = os.path.join
    else:
        l = lambda x, y: y
    res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
           and (prefix is None or i.startswith(prefix))
           and (suffix is None or i.endswith(suffix))]
    if sort:
        res.sort()
    return res


def maybe_mkdir_p(directory):
    # Create `directory` and any missing parents; do nothing if it already exists.
    # Works for both absolute and relative paths.
    os.makedirs(directory, exist_ok=True)
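
A short usage sketch for the filtering behaviour of `subfiles`. To keep the example self-contained it defines `list_files`, a hypothetical stand-in that applies the same rules (regular files only, optional prefix and suffix filters, optional sorting); the file names are invented for illustration.

```python
import os
import tempfile


def list_files(folder, join=True, prefix=None, suffix=None, sort=True):
    """Stand-in that applies the same filtering rules as `subfiles`."""
    names = [name for name in os.listdir(folder)
             if os.path.isfile(os.path.join(folder, name))
             and (prefix is None or name.startswith(prefix))
             and (suffix is None or name.endswith(suffix))]
    if sort:
        names.sort()
    return [os.path.join(folder, name) for name in names] if join else names


def demo():
    """Create a few sample entries and list only the .npy files."""
    with tempfile.TemporaryDirectory() as tmp:
        for name in ("spleen_2.npy", "spleen_10.npy", "notes.txt"):
            open(os.path.join(tmp, name), "w").close()
        # A directory whose name ends in .npy is skipped: only regular files count.
        os.mkdir(os.path.join(tmp, "splits.npy"))
        return list_files(tmp, join=False, suffix=".npy")


if __name__ == "__main__":
    print(demo())
```

The sort is lexicographic, so `spleen_10.npy` is listed before `spleen_2.npy`; zero-padded case numbers avoid that surprise.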
SYMBOL INDEX (169 symbols across 24 files)
FILE: configs/Config_unet.py
function get_config (line 23) | def get_config():
FILE: configs/Config_unet_spleen.py
function get_config (line 23) | def get_config():
FILE: datasets/data_loader.py
class WrappedDataset (line 22) | class WrappedDataset(Dataset):
method __init__ (line 23) | def __init__(self, dataset, transform):
method __getitem__ (line 31) | def __getitem__(self, index):
method __len__ (line 40) | def __len__(self):
class MultiThreadedDataLoader (line 44) | class MultiThreadedDataLoader(object):
method __init__ (line 45) | def __init__(self, data_loader, transform, num_processes, **kwargs):
method get_worker_init_fn (line 57) | def get_worker_init_fn(self):
method __iter__ (line 63) | def __iter__(self):
method __next__ (line 68) | def __next__(self):
method renew (line 73) | def renew(self):
method restart (line 79) | def restart(self):
method kill_iterator (line 83) | def kill_iterator(self):
FILE: datasets/example_dataset/create_splits.py
function create_splits (line 25) | def create_splits(output_dir, image_dir):
function splits_sanity_check (line 70) | def splits_sanity_check(path):
FILE: datasets/example_dataset/preprocessing.py
function preprocess_data (line 29) | def preprocess_data(root_dir, y_shape=64, z_shape=64):
function preprocess_single_file (line 74) | def preprocess_single_file(image_file):
function postprocess_single_image (line 83) | def postprocess_single_image(image):
FILE: datasets/spleen/create_splits.py
function create_splits (line 25) | def create_splits(output_dir, image_dir):
FILE: datasets/spleen/preprocessing.py
function preprocess_data (line 28) | def preprocess_data(root_dir, y_shape=64, z_shape=64):
function preprocess_single_file (line 68) | def preprocess_single_file(image_file):
function postprocess_single_image (line 87) | def postprocess_single_image(image):
FILE: datasets/three_dim/NumpyDataLoader.py
function load_dataset (line 29) | def load_dataset(base_dir, pattern='*.npy', keys=None):
class NumpyDataSet (line 52) | class NumpyDataSet(object):
method __init__ (line 56) | def __init__(self, base_dir, mode="train", batch_size=16, num_batches=...
method __len__ (line 73) | def __len__(self):
method __iter__ (line 76) | def __iter__(self):
method __next__ (line 82) | def __next__(self):
class NumpyDataLoader (line 86) | class NumpyDataLoader(SlimDataLoaderBase):
method __init__ (line 87) | def __init__(self, base_dir, mode="train", batch_size=16, num_batches=...
method reshuffle (line 112) | def reshuffle(self):
method generate_train_batch (line 117) | def generate_train_batch(self):
method __len__ (line 121) | def __len__(self):
method __getitem__ (line 125) | def __getitem__(self, item):
method get_data_from_array (line 150) | def get_data_from_array(self, open_array):
FILE: datasets/three_dim/data_augmentation.py
function get_transforms (line 24) | def get_transforms(mode="train", target_size=128):
FILE: datasets/two_dim/NumpyDataLoader.py
function load_dataset (line 29) | def load_dataset(base_dir, pattern='*.npy', slice_offset=5, keys=None):
class NumpyDataSet (line 52) | class NumpyDataSet(object):
method __init__ (line 56) | def __init__(self, base_dir, mode="train", batch_size=16, num_batches=...
method __len__ (line 73) | def __len__(self):
method __iter__ (line 76) | def __iter__(self):
method __next__ (line 82) | def __next__(self):
class NumpyDataLoader (line 86) | class NumpyDataLoader(SlimDataLoaderBase):
method __init__ (line 87) | def __init__(self, base_dir, mode="train", batch_size=16, num_batches=...
method reshuffle (line 112) | def reshuffle(self):
method generate_train_batch (line 117) | def generate_train_batch(self):
method __len__ (line 121) | def __len__(self):
method __getitem__ (line 125) | def __getitem__(self, item):
method get_data_from_array (line 150) | def get_data_from_array(self, open_array):
FILE: datasets/two_dim/data_augmentation.py
function get_transforms (line 26) | def get_transforms(mode="train", target_size=128):
FILE: datasets/utils.py
function download_dataset (line 25) | def download_dataset(dest_path, dataset, id=''):
FILE: evaluation/evaluator.py
class Evaluator (line 15) | class Evaluator:
method __init__ (line 45) | def __init__(self,
method set_test (line 84) | def set_test(self, test):
method set_reference (line 89) | def set_reference(self, reference):
method set_labels (line 94) | def set_labels(self, labels):
method construct_labels (line 110) | def construct_labels(self):
method set_metrics (line 122) | def set_metrics(self, metrics):
method add_metric (line 132) | def add_metric(self, metric):
method evaluate (line 137) | def evaluate(self, test=None, reference=None, advanced=False, **metric...
method to_dict (line 210) | def to_dict(self):
method to_array (line 216) | def to_array(self):
method to_pandas (line 237) | def to_pandas(self):
class NiftiEvaluator (line 252) | class NiftiEvaluator(Evaluator):
method __init__ (line 254) | def __init__(self, *args, **kwargs):
method set_test (line 260) | def set_test(self, test):
method set_reference (line 270) | def set_reference(self, reference):
method evaluate (line 280) | def evaluate(self, test=None, reference=None, voxel_spacing=None, **me...
function aggregate_scores (line 289) | def aggregate_scores(test_ref_pairs,
function aggregate_scores_for_experiment (line 376) | def aggregate_scores_for_experiment(score_file,
FILE: evaluation/metrics.py
function assert_shape (line 5) | def assert_shape(test, reference):
class ConfusionMatrix (line 11) | class ConfusionMatrix:
method __init__ (line 13) | def __init__(self, test=None, reference=None):
method set_test (line 27) | def set_test(self, test):
method set_reference (line 32) | def set_reference(self, reference):
method reset (line 37) | def reset(self):
method compute (line 49) | def compute(self):
method get_matrix (line 66) | def get_matrix(self):
method get_size (line 75) | def get_size(self):
method get_existence (line 81) | def get_existence(self):
function dice (line 91) | def dice(test=None, reference=None, confusion_matrix=None, nan_for_nonex...
function jaccard (line 109) | def jaccard(test=None, reference=None, confusion_matrix=None, nan_for_no...
function precision (line 127) | def precision(test=None, reference=None, confusion_matrix=None, nan_for_...
function sensitivity (line 145) | def sensitivity(test=None, reference=None, confusion_matrix=None, nan_fo...
function recall (line 163) | def recall(test=None, reference=None, confusion_matrix=None, nan_for_non...
function specificity (line 169) | def specificity(test=None, reference=None, confusion_matrix=None, nan_fo...
function accuracy (line 187) | def accuracy(test=None, reference=None, confusion_matrix=None, **kwargs):
function fscore (line 198) | def fscore(test=None, reference=None, confusion_matrix=None, nan_for_non...
function false_positive_rate (line 208) | def false_positive_rate(test=None, reference=None, confusion_matrix=None...
function false_omission_rate (line 214) | def false_omission_rate(test=None, reference=None, confusion_matrix=None...
function false_negative_rate (line 232) | def false_negative_rate(test=None, reference=None, confusion_matrix=None...
function true_negative_rate (line 238) | def true_negative_rate(test=None, reference=None, confusion_matrix=None,...
function false_discovery_rate (line 244) | def false_discovery_rate(test=None, reference=None, confusion_matrix=Non...
function negative_predictive_value (line 250) | def negative_predictive_value(test=None, reference=None, confusion_matri...
function total_positives_test (line 256) | def total_positives_test(test=None, reference=None, confusion_matrix=Non...
function total_negatives_test (line 267) | def total_negatives_test(test=None, reference=None, confusion_matrix=Non...
function total_positives_reference (line 278) | def total_positives_reference(test=None, reference=None, confusion_matri...
function total_negatives_reference (line 289) | def total_negatives_reference(test=None, reference=None, confusion_matri...
function hausdorff_distance (line 300) | def hausdorff_distance(test=None, reference=None, confusion_matrix=None,...
function hausdorff_distance_95 (line 318) | def hausdorff_distance_95(test=None, reference=None, confusion_matrix=No...
function avg_surface_distance (line 336) | def avg_surface_distance(test=None, reference=None, confusion_matrix=Non...
function avg_surface_distance_symmetric (line 354) | def avg_surface_distance_symmetric(test=None, reference=None, confusion_...
FILE: experiments/UNetExperiment.py
class UNetExperiment (line 34) | class UNetExperiment(PytorchExperiment):
method setup (line 53) | def setup(self):
method train (line 92) | def train(self, epoch):
method validate (line 131) | def validate(self, epoch):
method test (line 160) | def test(self):
method segment_single_image (line 197) | def segment_single_image(self, data):
FILE: experiments/UNetExperiment3D.py
class UNetExperiment3D (line 34) | class UNetExperiment3D(PytorchExperiment):
method setup (line 53) | def setup(self):
method train (line 92) | def train(self, epoch):
method validate (line 126) | def validate(self, epoch):
method test (line 156) | def test(self):
FILE: loss_functions/ND_Crossentropy.py
class CrossentropyND (line 4) | class CrossentropyND(torch.nn.CrossEntropyLoss):
method forward (line 8) | def forward(self, inp, target):
FILE: loss_functions/dice_loss.py
function softmax_helper (line 8) | def softmax_helper(x):
function get_tp_fp_fn (line 16) | def get_tp_fp_fn(net_output, gt, axes=None, mask=None):
function sum_tensor (line 63) | def sum_tensor(inp, axes, keepdim=False):
function mean_tensor (line 74) | def mean_tensor(inp, axes, keepdim=False):
class SoftDiceLoss (line 85) | class SoftDiceLoss(nn.Module):
method __init__ (line 86) | def __init__(self, smooth=1., apply_nonlin=None, batch_dice=False, do_...
method forward (line 112) | def forward(self, x, y):
function soft_dice_per_batch (line 140) | def soft_dice_per_batch(net_output, gt, smooth=1., smooth_in_nom=1., bac...
function soft_dice_per_batch_2 (line 152) | def soft_dice_per_batch_2(net_output, gt, smooth=1., smooth_in_nom=1., b...
function soft_dice (line 173) | def soft_dice(net_output, gt, smooth=1., smooth_in_nom=1.):
class MultipleOutputLoss (line 181) | class MultipleOutputLoss(nn.Module):
method __init__ (line 182) | def __init__(self, loss, weight_factors=None):
method forward (line 192) | def forward(self, x, y):
class DC_and_CE_loss (line 204) | class DC_and_CE_loss(nn.Module):
method __init__ (line 205) | def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum"):
method forward (line 211) | def forward(self, net_output, target):
class DC_and_topk_loss (line 221) | class DC_and_topk_loss(nn.Module):
method __init__ (line 222) | def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum"):
method forward (line 228) | def forward(self, net_output, target):
class CrossentropyWithLossMask (line 238) | class CrossentropyWithLossMask(nn.CrossEntropyLoss):
method __init__ (line 239) | def __init__(self, k=None):
method forward (line 247) | def forward(self, inp, target, loss_mask=None):
FILE: loss_functions/topk_loss.py
class TopKLoss (line 6) | class TopKLoss(CrossentropyND):
method __init__ (line 11) | def __init__(self, weight=None, ignore_index=-100, k=10):
method forward (line 15) | def forward(self, inp, target):
FILE: networks/RecursiveUNet.py
class UNet (line 28) | class UNet(nn.Module):
method __init__ (line 29) | def __init__(self, num_classes=3, in_channels=1, initial_filter_size=6...
method forward (line 46) | def forward(self, x):
class UnetSkipConnectionBlock (line 53) | class UnetSkipConnectionBlock(nn.Module):
method __init__ (line 54) | def __init__(self, in_channels=None, out_channels=None, num_classes=1,...
method contract (line 90) | def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.I...
method expand (line 98) | def expand(in_channels, out_channels, kernel_size=3):
method center_crop (line 106) | def center_crop(layer, target_width, target_height):
method forward (line 112) | def forward(self, x):
FILE: networks/RecursiveUNet3D.py
class UNet3D (line 28) | class UNet3D(nn.Module):
method __init__ (line 29) | def __init__(self, num_classes=3, in_channels=1, initial_filter_size=6...
method forward (line 46) | def forward(self, x):
class UnetSkipConnectionBlock (line 53) | class UnetSkipConnectionBlock(nn.Module):
method __init__ (line 54) | def __init__(self, in_channels=None, out_channels=None, num_classes=1,...
method contract (line 90) | def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.I...
method expand (line 98) | def expand(in_channels, out_channels, kernel_size=3):
method center_crop (line 106) | def center_crop(layer, target_depth, target_width, target_height):
method forward (line 113) | def forward(self, x):
FILE: networks/UNET.py
class UNet (line 22) | class UNet(nn.Module):
method __init__ (line 24) | def __init__(self, num_classes, in_channels=1, initial_filter_size=64,...
method contract (line 75) | def contract(in_channels, out_channels, kernel_size=3, instancenorm=Tr...
method expand (line 88) | def expand(in_channels, out_channels, kernel_size=3):
method center_crop (line 96) | def center_crop(layer, target_width, target_height):
method forward (line 102) | def forward(self, x, enable_concat=True, print_layer_shapes=False):
FILE: segment_a_spleen.py
function save_single_image (line 28) | def save_single_image(image, image_header, filename):
FILE: utilities/file_and_folder_operations.py
function subdirs (line 21) | def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
function subfiles (line 34) | def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
function maybe_mkdir_p (line 47) | def maybe_mkdir_p(directory):