Full Code of MIC-DKFZ/basic_unet_example for AI

Repository: MIC-DKFZ/basic_unet_example
Branch: master
Commit: 063353b31517
Files: 45
Total size: 149.7 KB

Directory structure:
gitextract_9r9wvxz8/

├── .gitignore
├── LICENSE
├── Readme.md
├── __init__.py
├── configs/
│   ├── Config_unet.py
│   ├── Config_unet_spleen.py
│   └── __init__.py
├── datasets/
│   ├── __init__.py
│   ├── data_loader.py
│   ├── example_dataset/
│   │   ├── __init__.py
│   │   ├── create_splits.py
│   │   └── preprocessing.py
│   ├── spleen/
│   │   ├── __init__.py
│   │   ├── create_splits.py
│   │   └── preprocessing.py
│   ├── three_dim/
│   │   ├── NumpyDataLoader.py
│   │   ├── __init__.py
│   │   └── data_augmentation.py
│   ├── two_dim/
│   │   ├── NumpyDataLoader.py
│   │   ├── __init__.py
│   │   └── data_augmentation.py
│   └── utils.py
├── evaluation/
│   ├── __init__.py
│   ├── evaluator.py
│   ├── metrics.py
│   └── readme.md
├── experiments/
│   ├── UNetExperiment.py
│   ├── UNetExperiment3D.py
│   └── __init__.py
├── loss_functions/
│   ├── ND_Crossentropy.py
│   ├── __init__.py
│   ├── dice_loss.py
│   └── topk_loss.py
├── networks/
│   ├── RecursiveUNet.py
│   ├── RecursiveUNet3D.py
│   └── UNET.py
├── requirements.txt
├── run_preprocessing.py
├── run_train_pipeline.py
├── runner.py
├── segment_a_spleen.py
├── train.py
├── train3D.py
└── utilities/
    ├── __init__.py
    └── file_and_folder_operations.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.idea
*.pyc
.DS_Store
*.egg-info
.pytest_cache/*

data
output_experiment
venv


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright {yyyy} {name of copyright owner}

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

================================================
FILE: Readme.md
================================================
# Basic U-Net example by MIC@DKFZ
Copyright © German Cancer Research Center (DKFZ), Division of Medical Image Computing (MIC). Please make sure that your usage of this code is in compliance with the code license:
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/MIC-DKFZ/basic_unet_example/blob/master/LICENSE)

This python code is an example project of how to use a U-Net [1] for segmentation on medical images using PyTorch (https://www.pytorch.org).
It was developed at the Division of Medical Image Computing at the German Cancer Research Center (DKFZ).
It is also an example of how to use our other python packages batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) and 
Trixi (https://github.com/MIC-DKFZ/trixi) [2] to suit all our deep learning data augmentation needs.

If you have any questions or issues or you encounter a bug, feel free to contact us, open a GitHub issue or ask the community on Gitter:
[![Gitter](https://badges.gitter.im/basic-Unet/community.svg)](https://gitter.im/basic-Unet/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)

> **WARNING**: This repo was implemented and tested on Linux. We highly recommend using it within a Linux environment. If you use Windows you might experience some issues (see
> section "Errors and how to handle them")

## How to set it up
The example is very easy to use. Just create a new virtual environment in python and install the requirements. 
This example requires python3. It was implemented with python 3.5. 

> **WARNING**: The newest supported version is python 3.7.9. For newer python versions there are some requirements that are not available in the needed version.
```
pip3 install -r requirements.txt
```

In this example code, we show how to use visdom for live visualization. See the Trixi documentation for more details or information about other tools like tensorboard.
After setting up the virtual environment you have to start visdom once so it can download some needed files.
You can stop the visdom server after a few seconds, once it has finished downloading the files.
```
python3 -m visdom.server
```

You can edit the paths for data storage and logging in the config file. By default, everything is stored in your working directory.


## How to use it
To start the training simply run 
```
python3 run_train_pipeline.py
```

This will download the Hippocampus dataset from the medical segmentation decathlon (http://medicaldecathlon.com),
extract and preprocess it and then start the training. The preprocessing loads the images (imagesTr) and the corresponding labels (labelsTr), performs some normalization and padding operations and saves the data as NPY files. The available images are then split into `train`, `validation` and `test` sets.
The splits are saved to a `splits.pkl` file. The images in `imagesTs` are not used in the example, because they are the test set for the medical segmentation decathlon and
 therefore no ground truth is provided.

If you run the pipeline again, the dataset will not be downloaded, extracted or preprocessed again. To force these steps to run again, just delete the dataset folder.

The training process will automatically be visualized using trixi/visdom. After starting the training you navigate in your browser to the port which is printed by the training script. Then you should see your loss curve and so on.

By default, a 2-dimensional U-Net is used. The example also comes with a 3-D version of the network (Özgün Çiçek et al. [3]).
To use the 3-D version, simply run
```
python train3D.py
```

> **WARNING**: The 3-D version is not yet tested thoroughly. Use it with caution!

## How to use it for your own data
This description is work in progress. If you use this repo for your own data please share your experience, so we can update this part.

### Config

The included `Config_unet.py` is an example config file. You have to adapt this to fit your local environment, e.g., if you run out of CUDA memory, try to reduce `batch_size` or
 `patch_size`. All the other parameters should be self-explanatory or described directly in the code comments. 

Choose the `#Train parameters` to fit both your data and your workstation. 
With `fold` you can choose which split from your `splits.pkl` you want to use for the training.

You may also need to adapt the paths (`data_root_dir, data_dir, data_test_dir and split_dir`).

You can change the `Logging parameters` if you want to. With `append_rnd_string`, you can give each experiment you start a unique name.
If you want to start your visdom server manually, just set `start_visdom=False`. If you do not want to use visdom logging at all, just remove the visdom logger from your
 experiment, e.g. `run_train_pipeline.py` line 47:
 
 ```
 loggers={
       "visdom": ("visdom", {"auto_start": c.start_visdom})
 }
 ```

### Datasets
If you want to use the provided DataLoader, you need to preprocess your data appropriately. An example can be found in the 
"example_dataset" folder. Make sure to load your images and your labels as numpy arrays. The required shape is `(#slices, w, h)`. 
Then save both using:
```
result = np.stack((image, label))

np.save(output_filename, result)
```
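As a minimal sketch of the save step above, assuming an image and a label volume of identical shape `(#slices, w, h)`: the array contents and the file name `example_case.npy` are made up for illustration and are not part of the repository.

```python
import os
import tempfile

import numpy as np

# Hypothetical volumes: 4 slices of 8x8 pixels each.
image = np.random.rand(4, 8, 8).astype(np.float32)
label = np.zeros((4, 8, 8), dtype=np.float32)
label[:, 2:5, 2:5] = 1  # a small foreground square in every slice

# Stack image and label along a new first axis -> shape (2, #slices, w, h).
result = np.stack((image, label))

# Hypothetical output location; use your own preprocessed data folder instead.
output_filename = os.path.join(tempfile.mkdtemp(), "example_case.npy")
np.save(output_filename, result)

# Loading the file back gives the image at index 0 and the label at index 1.
loaded = np.load(output_filename)
loaded_image, loaded_label = loaded[0], loaded[1]
```

Because `np.stack` needs arrays of equal shape, a mismatch between image and label is caught at this point rather than later during training.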

The provided DataLoader requires a `splits.pkl` file that contains a list of dictionaries (one per fold), each listing the files used for training, validation and testing.
It looks like this:
```
[{'train': ['dataset_name_1',...], 'val': ['dataset_name_2', ...], 'test': ['dataset_name_3', ...]}]
```
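A file in this format could be written roughly as follows. This is only a sketch under stated assumptions, not the repository's own `create_splits.py`: the helper `create_splits`, the case names and the split sizes are hypothetical.

```python
import os
import pickle
import random
import tempfile


def create_splits(case_names, n_val, n_test, seed=0):
    """Return a one-fold split list in the format shown above (hypothetical helper)."""
    names = sorted(case_names)
    random.Random(seed).shuffle(names)  # reproducible shuffle
    test = names[:n_test]
    val = names[n_test:n_test + n_val]
    train = names[n_test + n_val:]
    return [{"train": train, "val": val, "test": test}]


# Hypothetical dataset names, i.e. file names without the .npy extension.
case_names = ["case_{:03d}".format(i) for i in range(10)]
splits = create_splits(case_names, n_val=2, n_test=2)

# Hypothetical split directory; in practice this would be your `split_dir`.
split_dir = tempfile.mkdtemp()
with open(os.path.join(split_dir, "splits.pkl"), "wb") as f:
    pickle.dump(splits, f)

# Reading it back: index 0 selects the first fold.
with open(os.path.join(split_dir, "splits.pkl"), "rb") as f:
    loaded = pickle.load(f)
train_keys = loaded[0]["train"]
```

Appending further dictionaries to the list would give additional folds, which is presumably what the `fold` config parameter selects.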

We use the `MIC/batchgenerators` to perform data augmentation. The example uses cropping, mirroring and some elastic spatial transformation.
You can change the data augmentation by editing the `data_augmentation.py`. Please see the `MIC/batchgenerators` documentation for more details.

To train your network, simply run
```
python train.py
```

You can either edit the config file or add command line parameters like this:
```
python train.py --n_epochs 100 [...]
```

## Networks
This example contains a simple implementation of the U-Net [1], which can be found in `networks/UNET.py`. 
A slightly more generic version of the U-Net and a 3D U-Net [3] can be found in `networks/RecursiveUNet.py` 
and `networks/RecursiveUNet3D.py`, respectively. These implementations are recursive,
which makes it very easy to configure the number of downsamplings. The type of normalization can also be passed as a parameter (default is nn.InstanceNorm2d).

## Errors and how to handle them
In this section, we want to collect common errors that may occur when using this repository.
If you encounter something, feel free to let us know about it and we will include it here.

### Windows related issues

If you want to use this repo on Windows, please note that you have to adapt a few things.
We recommend installing PyTorch via conda on Windows using: `python -m conda install pytorch torchvision cpuonly -c pytorch`
You then have to remove torch from the requirements.txt.

If you run into issues like the following one:

 ```
AttributeError: Can't pickle local object 'MultiThreadedDataLoader.get_worker_init_fn.<locals>.init_fn'
 ```

try to use SingleProcessDataLoader instead. This error is probably caused by how multithreading is handled in python on Windows.
To fix this, add `num_processes=0` to your dataloaders:

 ```
self.train_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size, 
                                        batch_size=self.config.batch_size, keys=tr_keys, num_processes=0)
self.val_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size, 
                                        batch_size=self.config.batch_size, keys=val_keys, mode="val", do_reshuffle=False, num_processes=0)
self.test_data_loader = NumpyDataSet(self.config.data_test_dir, target_size=self.config.patch_size, 
                                        batch_size=self.config.batch_size, keys=test_keys, mode="test", do_reshuffle=False, num_processes=0)
 ```

### Multiple Labels
Depending on your dataset you might be dealing with multiple labels. For example the
data from BRATS (https://www.med.upenn.edu/sbia/brats2017.html) has the following labels:
 ```
 "labels": {
	 "0": "background",
	 "1": "edema",
	 "2": "non-enhancing tumor",
	 "3": "enhancing tumour"
 },
 ```
* If you run into an error like this:
    ```
    Experiment exited. Checkpoints stored =)
    INFO:default-z3HafHO4CS:Experiment exited. Checkpoints stored =)
    Unhandled exception in thread started by <function PytorchExperimentLogger.save_checkpoint_static at 0x7fd07c3e8510>
    Traceback (most recent call last):
      File "/python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py", line 196, in save_checkpoint_static
       torch.save(to_cpu(kwargs), checkpoint_file)
      File "/python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py", line 191, in to_cpu
        return {key: to_cpu(val) for key, val in obj.items()}
      File "//python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py", line 191, in <dictcomp>
        return {key: to_cpu(val) for key, val in obj.items()}
      File "/python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py", line 191, in to_cpu
        return {key: to_cpu(val) for key, val in obj.items()}
      File "/python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py", line 191, in <dictcomp>
        return {key: to_cpu(val) for key, val in obj.items()}
      File "/python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py", line 189, in to_cpu
        return obj.cpu()
    RuntimeError: CUDA error: device-side assert triggered
    ```
    make sure you updated `num_classes` in your config file. The value of `num_classes` should always
    equal the number of your labels including background.

* If you run into an error like this:
    ```
    File "/home/student/basic_unet/trixi/trixi/experiment/experiment.py", line 108, in run
      self.process_err(e)
    File "/home/student/basic_unet/trixi/trixi/experiment/pytorchexperiment.py", line 391, in process_err
      raise e
    File "/home/student/basic_unet/trixi/trixi/experiment/experiment.py", line 89, in run
      self.train(epoch=self._epoch_idx)
    File "/home/student/PycharmProjects/new_unet/experiments/UNetExperiment.py", line 113, in train
      loss = self.dice_loss(pred_softmax, target.squeeze()) + self.ce_loss(pred, target.squeeze())
    File "/opt/anaconda3/envs/a_new_test/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in call
      result = self.forward(input, *kwargs)
    File "/home/student/PycharmProjects/new_unet/loss_functions/dice_loss.py", line 125, in forward
      yonehot.scatter(1, y, 1)
    RuntimeError: Invalid index in scatter at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:551
    ```
    make sure to check your labels again. The error may be caused by the fact that the labels are not sequential. This causes `scatter` to crash. Consider changing the values of your labels so that they are consecutive integers starting at 0.
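
One way to make label values sequential during preprocessing is sketched below. It assumes you know the full set of label values in your dataset; the values `0, 1, 2, 4` and the helper `remap_labels` are hypothetical and not part of this repository.

```python
import numpy as np


def remap_labels(label, original_values):
    """Map the given label values to 0..K-1, keeping their sorted order (hypothetical helper)."""
    original_values = np.asarray(sorted(original_values))
    if not np.isin(label, original_values).all():
        raise ValueError("label array contains values that are not in original_values")
    # In a sorted array, searchsorted returns the position of each existing value.
    return np.searchsorted(original_values, label)


# Hypothetical label set with a gap: the value 3 is never used.
original_values = [0, 1, 2, 4]
label = np.array([[0, 1, 4], [4, 2, 0]])

remapped = remap_labels(label, original_values)
num_classes = len(original_values)  # this is the value to put into the config
```

Using a fixed list of values for the whole dataset, rather than the values found in each volume, keeps the mapping consistent even when a volume does not contain every label.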

## References
[1] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." 
International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.
[2] David Zimmerer, Jens Petersen, GregorKoehler, Jakob Wasserthal, dzimmm, Tim, … André Pequeño. (2018, November 23). MIC-DKFZ/trixi: Alpha (Version v0.1.1). 
Zenodo. http://doi.org/10.5281/zenodo.1495180
[3] Çiçek, Özgün, et al. "3D U-Net: learning dense volumetric segmentation from sparse annotation." 
International conference on medical image computing and computer-assisted intervention. Springer, Cham, 2016.



================================================
FILE: __init__.py
================================================


================================================
FILE: configs/Config_unet.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from trixi.util import Config


def get_config():
    # Set your own path, if needed.
    data_root_dir = os.path.abspath('data')  # The path where the downloaded dataset is stored.

    c = Config(
        update_from_argv=True,  # If set to 'True', each configuration value can be updated via a cmd/terminal parameter.

        # Train parameters
        num_classes=3,
        in_channels=1,
        batch_size=8,
        patch_size=64,
        n_epochs=10,
        learning_rate=0.0002,
        fold=0,  # The 'splits.pkl' may contain multiple folds. Here we choose which one we want to use.

        device="cuda",  # 'cuda' is the default CUDA device, you can also use 'cpu'. For more information, see https://pytorch.org/docs/stable/notes/cuda.html

        # Logging parameters
        name='Basic_Unet',
        author='kleina',  # Author of this project
        plot_freq=10,  # How often should stuff be shown in visdom
        append_rnd_string=False,  # Appends a random string to the experiment name to make it unique.
        start_visdom=True,  # You can either start a visdom server manually or have trixi start it for you.

        do_instancenorm=True,  # Defines whether or not the UNet does an instance normalization in the contracting path
        do_load_checkpoint=False,
        checkpoint_dir='',

        # Adapt to your own path, if needed.
        google_drive_id='1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C',  # This id is used to download the example dataset.
        dataset_name='Task04_Hippocampus',
        base_dir=os.path.abspath('output_experiment'),  # Where to log the output of the experiment.

        data_root_dir=data_root_dir,  # The path where the downloaded dataset is stored.
        data_dir=os.path.join(data_root_dir, 'Task04_Hippocampus/preprocessed'),  # This is where your training and validation data is stored
        data_test_dir=os.path.join(data_root_dir, 'Task04_Hippocampus/preprocessed'),  # This is where your test data is stored

        split_dir=os.path.join(data_root_dir, 'Task04_Hippocampus'),  # This is where the 'splits.pkl' file is located, that holds your splits.

        # execute a segmentation process on a specific image using the model
        model_dir=os.path.join(os.path.abspath('output_experiment'), ''),  # the model being used for segmentation
    )

    print(c)
    return c


================================================
FILE: configs/Config_unet_spleen.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from trixi.util import Config


def get_config():

    # Set your own path, if needed.
    data_root_dir = os.path.abspath('data')  # The path where the downloaded dataset is stored.

    c = Config(
        update_from_argv=True,  # If set to 'True', each configuration value can be updated via a cmd/terminal parameter.

        # Train parameters
        num_classes=2,
        in_channels=1,
        batch_size=3,       # works with 6 on GB GPU
        patch_size=512,
        n_epochs=1,
        learning_rate=0.0002,
        fold=0,  # The 'splits.pkl' may contain multiple folds. Here we choose which one we want to use.

        device="cuda",  # 'cuda' is the default CUDA device, you can also use 'cpu'. For more information, see https://pytorch.org/docs/stable/notes/cuda.html

        # Logging parameters
        name='Basic_Unet',
        author='kleina',  # Author of this project
        plot_freq=10,  # How often should stuff be shown in visdom
        append_rnd_string=False,  # Appends a random string to the experiment name to make it unique.
        start_visdom=True,  # You can either start a visdom server manually or have trixi start it for you.

        do_instancenorm=True,  # Defines whether or not the UNet does an instance normalization in the contracting path
        do_load_checkpoint=False,
        checkpoint_dir='',

        # Adapt to your own path, if needed.
        google_drive_id='1jzeNU1EKnK81PyTsrx0ujfNl-t0Jo8uE', #spleen
        dataset_name='Task09_Spleen',
        base_dir=os.path.abspath('output_experiment'),  # Where to log the output of the experiment.

        data_root_dir=data_root_dir,  # The path where the downloaded dataset is stored.
        data_dir=os.path.join(data_root_dir, 'Task09_Spleen/preprocessed'),  # This is where your training and validation data is stored
        data_test_dir=os.path.join(data_root_dir, 'Task09_Spleen/preprocessed'),  # This is where your test data is stored

        split_dir=os.path.join(data_root_dir, 'Task09_Spleen'),  # This is where the 'splits.pkl' file is located, that holds your splits.

        # execute a segmentation process on a specific image using the model
        model_dir=os.path.join(os.path.abspath('output_experiment'), '20200108-035420_Basic_Unet/checkpoint/checkpoint_current'),   # the model being used for segmentation
    )

    print(c)
    return c


================================================
FILE: configs/__init__.py
================================================


================================================
FILE: datasets/__init__.py
================================================


================================================
FILE: datasets/data_loader.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torch.utils.data import DataLoader, Dataset
from trixi.util.pytorchutils import set_seed


class WrappedDataset(Dataset):
    def __init__(self, dataset, transform):
        self.transform = transform
        self.dataset = dataset

        self.is_indexable = False
        if hasattr(self.dataset, "__getitem__") and not (hasattr(self.dataset, "use_next") and self.dataset.use_next is True):
            self.is_indexable = True

    def __getitem__(self, index):

        if not self.is_indexable:
            item = next(self.dataset)
        else:
            item = self.dataset[index]
        item = self.transform(**item)
        return item

    def __len__(self):
        return int(self.dataset.num_batches)


class MultiThreadedDataLoader(object):
    def __init__(self, data_loader, transform, num_processes, **kwargs):

        self.cntr = 1
        self.ds_wrapper = WrappedDataset(data_loader, transform)

        self.generator = DataLoader(self.ds_wrapper, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                                    num_workers=num_processes, pin_memory=True, drop_last=False,
                                    worker_init_fn=self.get_worker_init_fn())

        self.num_processes = num_processes
        self.iter = None

    def get_worker_init_fn(self):
        def init_fn(worker_id):
            set_seed(worker_id + self.cntr)

        return init_fn

    def __iter__(self):
        self.kill_iterator()
        self.iter = iter(self.generator)
        return self.iter

    def __next__(self):
        if self.iter is None:
            self.iter = iter(self.generator)
        return next(self.iter)

    def renew(self):
        self.cntr += 1
        self.kill_iterator()
        self.generator.worker_init_fn = self.get_worker_init_fn()
        self.iter = iter(self.generator)

    def restart(self):
        pass
        # self.iter = iter(self.generator)

    def kill_iterator(self):
        try:
            if self.iter is not None:
                self.iter._shutdown_workers()
                for p in self.iter.workers:
                    p.terminate()
        except Exception:  # Do not swallow KeyboardInterrupt/SystemExit
            print("Could not kill Dataloader Iterator")


================================================
FILE: datasets/example_dataset/__init__.py
================================================


================================================
FILE: datasets/example_dataset/create_splits.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
from utilities.file_and_folder_operations import subfiles

import os
import random


def create_splits(output_dir, image_dir):
    """Split the dataset into five folds, each consisting of a train, validation and test set.

    :param output_dir: Directory to write the splits file to
    :param image_dir: Directory where the images lie in.
    """
    npy_files = subfiles(image_dir, suffix=".npy", join=False)
    sample_size = len(npy_files)

    testset_size = int(sample_size * 0.25)
    valset_size = int(sample_size * 0.25)
    trainset_size = sample_size - valset_size - testset_size  # Assure all samples are used.

    if sample_size < (trainset_size + valset_size + testset_size):
        raise ValueError("Assure more total samples exist than train test and val samples combined!")

    splits = []
    sample_set = {sample[:-4] for sample in npy_files}  # Remove the file extension
    # random.sample() no longer accepts sets (removed in Python 3.11), so sample from a sorted list.
    test_samples = random.sample(sorted(sample_set), testset_size)  # IMO the Testset should be static for all splits

    for split in range(0, 5):
        train_samples = random.sample(sorted(sample_set - set(test_samples)), trainset_size)
        val_samples = list(sample_set - set(train_samples) - set(test_samples))

        train_samples.sort()
        val_samples.sort()

        split_dict = dict()
        split_dict['train'] = train_samples
        split_dict['val'] = val_samples
        split_dict['test'] = test_samples

        splits.append(split_dict)

    # Todo: IMO it is better to write that dict as JSON. This (unlike pickle) allows the user to inspect the file with an editor
    with open(os.path.join(output_dir, 'splits.pkl'), 'wb') as f:
        pickle.dump(splits, f)

    splits_sanity_check(output_dir)


# ToDo: The name "splits.pkl" should not be spread over multiple files; that makes it harder to change.
#   Instead, move saving and loading to one file. (Here would be a good place.)
#   Other usages are: spleen/create_splits.py:57 (which is redundant anyway?);
#   UNetExperiment3D.py:55 and UNetExperiment.py:55
def splits_sanity_check(path):
    """Takes the directory containing 'splits.pkl' and verifies that train, validation and test samples do not overlap.

    :param path: Directory in which the 'splits.pkl' file is located.
    """
    with open(os.path.join(path, 'splits.pkl'), 'rb') as f:
        splits = pickle.load(f)
        for i in range(len(splits)):
            samples = splits[i]
            tr_samples = set(samples["train"])
            vl_samples = set(samples["val"])
            ts_samples = set(samples["test"])

            assert len(tr_samples.intersection(vl_samples)) == 0, "Train and validation samples overlap!"
            assert len(vl_samples.intersection(ts_samples)) == 0, "Validation and Test samples overlap!"
            assert len(tr_samples.intersection(ts_samples)) == 0, "Train and Test samples overlap!"
    return
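

# --- Illustrative sketch (not part of the original module) ---
# The Todo in create_splits() suggests writing the splits as JSON instead of pickle,
# so that the file can be inspected with an editor. A minimal way to do that might
# look like the following. The helper names and the 'splits.json' file name are
# assumptions made for this sketch; nothing else in the repository uses them.
import json
import os


def save_splits_json(splits, output_dir, filename='splits.json'):
    """Write the list of split dicts as human-readable JSON and return the file path."""
    path = os.path.join(output_dir, filename)
    with open(path, 'w') as f:
        json.dump(splits, f, indent=2)
    return path


def load_splits_json(output_dir, filename='splits.json'):
    """Read splits that were previously written by save_splits_json."""
    with open(os.path.join(output_dir, filename), 'r') as f:
        return json.load(f)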


================================================
FILE: datasets/example_dataset/preprocessing.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict

from batchgenerators.augmentations.utils import pad_nd_image
from medpy.io import load
import os
import numpy as np
import torch

from utilities.file_and_folder_operations import subfiles


def preprocess_data(root_dir, y_shape=64, z_shape=64):
    image_dir = os.path.join(root_dir, 'imagesTr')
    label_dir = os.path.join(root_dir, 'labelsTr')
    output_dir = os.path.join(root_dir, 'preprocessed')
    classes = 3

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print('Created ' + output_dir + '...')

    class_stats = defaultdict(int)
    total = 0

    nii_files = subfiles(image_dir, suffix=".nii.gz", join=False)

    for i in range(0, len(nii_files)):
        if nii_files[i].startswith("._"):
            nii_files[i] = nii_files[i][2:]

    for f in nii_files:
        image, _ = load(os.path.join(image_dir, f))
        label, _ = load(os.path.join(label_dir, f.replace('_0000', '')))

        print(f)

        for i in range(classes):
            class_stats[i] += np.sum(label == i)
            total += np.sum(label == i)

        # normalize images
        image = (image - image.min())/(image.max()-image.min())

        image = pad_nd_image(image, (image.shape[0], y_shape, z_shape), "constant", kwargs={'constant_values': image.min()})
        label = pad_nd_image(label, (image.shape[0], y_shape, z_shape), "constant", kwargs={'constant_values': label.min()})

        result = np.stack((image, label))

        np.save(os.path.join(output_dir, f.split('.')[0]+'.npy'), result)
        print(f)

    print(total)
    for i in range(classes):
        print(class_stats[i], class_stats[i]/total)


def preprocess_single_file(image_file):
    image, image_header = load(image_file)
    image = (image - image.min()) / (image.max() - image.min())

    data = np.expand_dims(image, 1)

    return torch.from_numpy(data), image_header


def postprocess_single_image(image):
    # desired shape is [b w h]
    result_converted = image[::, 0, ::, ::]
    result_mapped = [i * 255 for i in result_converted]

    return result_mapped


================================================
FILE: datasets/spleen/__init__.py
================================================


================================================
FILE: datasets/spleen/create_splits.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
from utilities.file_and_folder_operations import subfiles

import os
import numpy as np


def create_splits(output_dir, image_dir):
    npy_files = subfiles(image_dir, suffix=".npy", join=False)

    trainset_size = len(npy_files)*50//100
    valset_size = len(npy_files)*25//100
    testset_size = len(npy_files)*25//100

    splits = []
    for split in range(0, 5):
        image_list = npy_files.copy()
        trainset = []
        valset = []
        testset = []
        for i in range(0, trainset_size):
            patient = np.random.choice(image_list)
            image_list.remove(patient)
            trainset.append(patient[:-4])
        for i in range(0, valset_size):
            patient = np.random.choice(image_list)
            image_list.remove(patient)
            valset.append(patient[:-4])
        for i in range(0, testset_size):
            patient = np.random.choice(image_list)
            image_list.remove(patient)
            testset.append(patient[:-4])
        split_dict = dict()
        split_dict['train'] = trainset
        split_dict['val'] = valset
        split_dict['test'] = testset

        splits.append(split_dict)

    with open(os.path.join(output_dir, 'splits.pkl'), 'wb') as f:
        pickle.dump(splits, f)


================================================
FILE: datasets/spleen/preprocessing.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict

from medpy.io import load
import os
import numpy as np

from utilities.file_and_folder_operations import subfiles
import torch


def preprocess_data(root_dir, y_shape=64, z_shape=64):
    image_dir = os.path.join(root_dir, 'imagesTr')
    label_dir = os.path.join(root_dir, 'labelsTr')
    output_dir = os.path.join(root_dir, 'preprocessed')

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print('Created ' + output_dir + '...')

    class_stats = defaultdict(int)
    total = 0

    nii_files = subfiles(image_dir, suffix=".nii.gz", join=False)

    for i in range(0, len(nii_files)):
        if nii_files[i].startswith("._"):
            nii_files[i] = nii_files[i][2:]

    for f in nii_files:
        image, _ = load(os.path.join(image_dir, f))
        label, _ = load(os.path.join(label_dir, f.replace('_0000', '')))

        print(f)

        # normalize images
        image = (image - image.min())/(image.max()-image.min())

        image = np.swapaxes(image, 0, 2)
        image = np.swapaxes(image, 1, 2)

        label = np.swapaxes(label, 0, 2)
        label = np.swapaxes(label, 1, 2)
        result = np.stack((image, label))

        np.save(os.path.join(output_dir, f.split('.')[0]+'.npy'), result)
        print(f)

    print(total)


def preprocess_single_file(image_file):
    image, image_header = load(image_file)
    image = (image - image.min()) / (image.max() - image.min())

    image = np.swapaxes(image, 0, 2)
    image = np.swapaxes(image, 1, 2)

    # TODO check original shape and reshape data if necessary
    # image = reshape(image, append_value=0, new_shape=(image.shape[0], y_shape, z_shape))
    # numpy_array = np.array(image)

    # Image shape is [b, w, h] and has one channel only
    # Desired shape = [b, c, w, h]
    # --> expand to have only one channel c=1 - data is in desired shape
    data = np.expand_dims(image, 1)

    return torch.from_numpy(data), image_header


def postprocess_single_image(image):
    # desired shape is [b w h]
    result_converted = image[::, 0, ::, ::]
    result_mapped = [i * 255 for i in result_converted]

    # swap the axes back to the original orientation (inverse of the swaps in preprocess_single_file)
    result_mapped = np.swapaxes(result_mapped, 2, 1)
    result_mapped = np.swapaxes(result_mapped, 2, 0)

    return result_mapped
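

# --- Illustrative sketch (not part of the original module) ---
# preprocess_single_file() moves the last axis to the front with two swapaxes calls,
# and postprocess_single_image() undoes this by applying the inverse swaps in reverse
# order. The hypothetical helper below checks that round trip on a small array with
# distinct axis lengths; its name and default shape are assumptions for this sketch.
import numpy as np


def check_axis_roundtrip(shape=(2, 3, 4)):
    original = np.arange(np.prod(shape)).reshape(shape)

    # Forward, as in preprocessing: [x, y, z] -> [z, x, y]
    moved = np.swapaxes(original, 0, 2)
    moved = np.swapaxes(moved, 1, 2)

    # Backward, as in postprocessing: [z, x, y] -> [x, y, z]
    restored = np.swapaxes(moved, 2, 1)
    restored = np.swapaxes(restored, 2, 0)

    return moved.shape, np.array_equal(original, restored)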


================================================
FILE: datasets/three_dim/NumpyDataLoader.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import fnmatch
import random

import numpy as np

from batchgenerators.dataloading import SlimDataLoaderBase
from datasets.data_loader import MultiThreadedDataLoader
from .data_augmentation import get_transforms


def load_dataset(base_dir, pattern='*.npy', keys=None):
    fls = []
    files_len = []
    dataset = []

    for root, dirs, files in os.walk(base_dir):
        i = 0
        for filename in sorted(fnmatch.filter(files, pattern)):

            if keys is not None and filename[:-4] in keys:
                npy_file = os.path.join(root, filename)
                numpy_array = np.load(npy_file, mmap_mode="r")

                fls.append(npy_file)
                files_len.append(numpy_array.shape[1])

                dataset.extend([i])

                i += 1

    return fls, files_len, dataset


class NumpyDataSet(object):
    """
    TODO
    """
    def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000, seed=None, num_processes=8, num_cached_per_queue=8 * 4, target_size=128,
                 file_pattern='*.npy', label=1, input=(0,), do_reshuffle=True, keys=None):

        data_loader = NumpyDataLoader(base_dir=base_dir, mode=mode, batch_size=batch_size, num_batches=num_batches, seed=seed, file_pattern=file_pattern,
                                      input=input, label=label, keys=keys)

        self.data_loader = data_loader
        self.batch_size = batch_size
        self.do_reshuffle = do_reshuffle
        self.number_of_slices = 1

        self.transforms = get_transforms(mode=mode, target_size=target_size)
        self.augmenter = MultiThreadedDataLoader(data_loader, self.transforms, num_processes=num_processes,
                                                 num_cached_per_queue=num_cached_per_queue, seeds=seed,
                                                 shuffle=do_reshuffle)
        self.augmenter.restart()

    def __len__(self):
        return len(self.data_loader)

    def __iter__(self):
        if self.do_reshuffle:
            self.data_loader.reshuffle()
        self.augmenter.renew()
        return self.augmenter

    def __next__(self):
        return next(self.augmenter)


class NumpyDataLoader(SlimDataLoaderBase):
    def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000,
                 seed=None, file_pattern='*.npy', label=1, input=(0,), keys=None):

        self.files, self.file_len, self.dataset = load_dataset(base_dir=base_dir, pattern=file_pattern, keys=keys, )
        super(NumpyDataLoader, self).__init__(self.dataset, batch_size, num_batches)

        self.batch_size = batch_size

        self.use_next = False
        if mode == "train":
            self.use_next = False

        self.idxs = list(range(0, len(self.dataset)))

        self.data_len = len(self.dataset)

        self.num_batches = min((self.data_len // self.batch_size)+10, num_batches)

        if isinstance(label, int):
            label = (label,)
        self.input = input
        self.label = label

        self.np_data = np.asarray(self.dataset)

    def reshuffle(self):
        print("Reshuffle...")
        random.shuffle(self.idxs)
        print("Initializing... this might take a while...")

    def generate_train_batch(self):
        open_arr = random.sample(self._data, self.batch_size)
        return self.get_data_from_array(open_arr)

    def __len__(self):
        n_items = min(self.data_len // self.batch_size, self.num_batches)
        return n_items

    def __getitem__(self, item):
        idxs = self.idxs
        data_len = len(self.dataset)
        np_data = self.np_data

        if item > len(self):
            raise StopIteration()
        if (item * self.batch_size) == data_len:
            raise StopIteration()

        start_idx = (item * self.batch_size) % data_len
        stop_idx = ((item + 1) * self.batch_size) % data_len

        if ((item + 1) * self.batch_size) == data_len:
            stop_idx = data_len

        if stop_idx > start_idx:
            idxs = idxs[start_idx:stop_idx]
        else:
            raise StopIteration()

        open_arr = np_data[idxs]

        return self.get_data_from_array(open_arr)

    def get_data_from_array(self, open_array):
        data = []
        fnames = []
        idxs = []
        labels = []

        for idx in open_array:
            fn_name = self.files[idx]

            numpy_array = np.load(fn_name, mmap_mode="r")

            data.append(numpy_array[list(self.input)])   # indexing with a list keeps the channel dimension

            if self.label is not None:
                labels.append(numpy_array[list(self.label)])   # use the label channel(s), not the input channel(s)

            fnames.append(self.files[idx])
            idxs.append(idx)

        ret_dict = {'data': data, 'fnames': fnames, 'idxs': idxs}
        if self.label is not None:
            ret_dict['seg'] = labels

        return ret_dict


================================================
FILE: datasets/three_dim/__init__.py
================================================


================================================
FILE: datasets/three_dim/data_augmentation.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from batchgenerators.transforms import Compose, MirrorTransform
from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
from batchgenerators.transforms.spatial_transforms import ResizeTransform, SpatialTransform
from batchgenerators.transforms.utility_transforms import NumpyToTensor


def get_transforms(mode="train", target_size=128):
    transform_list = []

    if mode == "train":
        transform_list = [CenterCropTransform(crop_size=target_size),
                          ResizeTransform(target_size=target_size, order=1),
                          MirrorTransform(axes=(2,)),
                          SpatialTransform(patch_size=(target_size, target_size, target_size), random_crop=False,
                                           patch_center_dist_from_border=target_size // 2,
                                           do_elastic_deform=True, alpha=(0., 1000.), sigma=(40., 60.),
                                           do_rotation=True,
                                           angle_x=(-0.1, 0.1), angle_y=(0, 1e-8), angle_z=(0, 1e-8),
                                           scale=(0.9, 1.4),
                                           border_mode_data="nearest", border_mode_seg="nearest"),
                          ]

    elif mode == "val":
        transform_list = [CenterCropTransform(crop_size=target_size),
                          ResizeTransform(target_size=target_size, order=1),
                          ]

    elif mode == "test":
        transform_list = [CenterCropTransform(crop_size=target_size),
                          ResizeTransform(target_size=target_size, order=1),
                          ]

    transform_list.append(NumpyToTensor())

    return Compose(transform_list)


================================================
FILE: datasets/two_dim/NumpyDataLoader.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import fnmatch
import random

import numpy as np

from batchgenerators.dataloading import SlimDataLoaderBase
from datasets.data_loader import MultiThreadedDataLoader
from .data_augmentation import get_transforms


def load_dataset(base_dir, pattern='*.npy', slice_offset=5, keys=None):
    fls = []
    files_len = []
    slices_ax = []

    for root, dirs, files in os.walk(base_dir):
        i = 0
        for filename in sorted(fnmatch.filter(files, pattern)):

            if keys is not None and filename[:-4] in keys:
                npy_file = os.path.join(root, filename)
                numpy_array = np.load(npy_file, mmap_mode="r")

                fls.append(npy_file)
                files_len.append(numpy_array.shape[1])

                slices_ax.extend([(i, j) for j in range(slice_offset, files_len[-1] - slice_offset)])

                i += 1

    return fls, files_len, slices_ax


class NumpyDataSet(object):
    """
    TODO
    """
    def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000, num_processes=8, num_cached_per_queue=8 * 4, target_size=128,
                 file_pattern='*.npy', label_slice=1, input_slice=(0,), do_reshuffle=True, keys=None):

        data_loader = NumpyDataLoader(base_dir=base_dir, mode=mode, batch_size=batch_size, num_batches=num_batches, file_pattern=file_pattern,
                                      input_slice=input_slice, label_slice=label_slice, keys=keys)

        self.data_loader = data_loader
        self.batch_size = batch_size
        self.do_reshuffle = do_reshuffle
        self.number_of_slices = 1

        self.transforms = get_transforms(mode=mode, target_size=target_size)
        self.augmenter = MultiThreadedDataLoader(data_loader, self.transforms, num_processes=num_processes,
                                                 num_cached_per_queue=num_cached_per_queue,
                                                 shuffle=do_reshuffle)
        self.augmenter.restart()

    def __len__(self):
        return len(self.data_loader)

    def __iter__(self):
        if self.do_reshuffle:
            self.data_loader.reshuffle()
        self.augmenter.renew()
        return self.augmenter

    def __next__(self):
        return next(self.augmenter)


class NumpyDataLoader(SlimDataLoaderBase):
    def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000,
                 file_pattern='*.npy', label_slice=1, input_slice=(0,), keys=None):

        self.files, self.file_len, self.slices = load_dataset(base_dir=base_dir, pattern=file_pattern, slice_offset=0, keys=keys, )
        super(NumpyDataLoader, self).__init__(self.slices, batch_size, num_batches)

        self.batch_size = batch_size

        self.use_next = False
        if mode == "train":
            self.use_next = False

        self.slice_idxs = list(range(0, len(self.slices)))

        self.data_len = len(self.slices)

        self.num_batches = min((self.data_len // self.batch_size)+10, num_batches)

        if isinstance(label_slice, int):
            label_slice = (label_slice,)
        self.input_slice = input_slice
        self.label_slice = label_slice

        self.np_data = np.asarray(self.slices)

    def reshuffle(self):
        print("Reshuffle...")
        random.shuffle(self.slice_idxs)
        print("Initializing... this might take a while...")

    def generate_train_batch(self):
        open_arr = random.sample(self._data, self.batch_size)
        return self.get_data_from_array(open_arr)

    def __len__(self):
        n_items = min(self.data_len // self.batch_size, self.num_batches)
        return n_items

    def __getitem__(self, item):
        slice_idxs = self.slice_idxs
        data_len = len(self.slices)
        np_data = self.np_data

        if item > len(self):
            raise StopIteration()
        if (item * self.batch_size) == data_len:
            raise StopIteration()

        start_idx = (item * self.batch_size) % data_len
        stop_idx = ((item + 1) * self.batch_size) % data_len

        if ((item + 1) * self.batch_size) == data_len:
            stop_idx = data_len

        if stop_idx > start_idx:
            idxs = slice_idxs[start_idx:stop_idx]
        else:
            raise StopIteration()

        open_arr = np_data[idxs]

        return self.get_data_from_array(open_arr)

    def get_data_from_array(self, open_array):
        data = []
        fnames = []
        slice_idxs = []
        labels = []

        for slice in open_array:
            fn_name = self.files[slice[0]]

            numpy_array = np.load(fn_name, mmap_mode="r")

            numpy_slice = numpy_array[:, slice[1], ]
            data.append(numpy_slice[list(self.input_slice)])   # indexing with a list keeps the channel dimension

            if self.label_slice is not None:
                labels.append(numpy_slice[list(self.label_slice)])   # indexing with a list keeps the channel dimension

            fnames.append(self.files[slice[0]])
            slice_idxs.append(slice[1])

        ret_dict = {'data': np.asarray(data), 'fnames': fnames, 'slice_idxs': slice_idxs}
        if self.label_slice is not None:
            ret_dict['seg'] = np.asarray(labels)

        return ret_dict


================================================
FILE: datasets/two_dim/__init__.py
================================================


================================================
FILE: datasets/two_dim/data_augmentation.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from batchgenerators.transforms import Compose, MirrorTransform
from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform, RandomCropTransform
from batchgenerators.transforms.spatial_transforms import ResizeTransform, SpatialTransform
from batchgenerators.transforms.utility_transforms import NumpyToTensor

import numpy as np


def get_transforms(mode="train", target_size=128):
    transform_list = []

    if mode == "train":
        transform_list = [# CenterCropTransform(crop_size=target_size),
                          ResizeTransform(target_size=(target_size, target_size), order=1),
                          MirrorTransform(axes=(1,)),
                          SpatialTransform(patch_size=(target_size, target_size), random_crop=False,
                                           patch_center_dist_from_border=target_size // 2,
                                           do_elastic_deform=True, alpha=(0., 900.), sigma=(20., 30.),
                                           do_rotation=True, p_rot_per_sample=0.8,
                                           angle_x=(-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi), angle_y=(0, 1e-8), angle_z=(0, 1e-8),
                                           scale=(0.85, 1.25), p_scale_per_sample=0.8,
                                           border_mode_data="nearest", border_mode_seg="nearest"),
                          ]

    elif mode == "val":
        transform_list = [# CenterCropTransform(crop_size=target_size),
                          ResizeTransform(target_size=target_size, order=1),
                          ]

    elif mode == "test":
        transform_list = [# CenterCropTransform(crop_size=target_size),
                          ResizeTransform(target_size=target_size, order=1),
                          ]

    transform_list.append(NumpyToTensor())

    return Compose(transform_list)


================================================
FILE: datasets/utils.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from os.path import exists
import tarfile

from google_drive_downloader import GoogleDriveDownloader as gdd


def download_dataset(dest_path, dataset, id=''):
    if not exists(os.path.join(dest_path, dataset)):
        tar_path = os.path.join(dest_path, dataset) + '.tar'
        gdd.download_file_from_google_drive(file_id=id,
                                            dest_path=tar_path, overwrite=False,
                                            unzip=False)

        print('Extracting data [STARTED]')
        # use a context manager so the archive is closed after extraction
        with tarfile.open(tar_path) as tar:
            tar.extractall(dest_path)
        print('Extracting data [DONE]')
    else:
        print('Data already downloaded. Files are not extracted again.')

    return


================================================
FILE: evaluation/__init__.py
================================================


================================================
FILE: evaluation/evaluator.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import collections
import inspect
import json
import hashlib
from datetime import datetime
import numpy as np
import pandas as pd
import SimpleITK as sitk
from evaluation.metrics import ConfusionMatrix, ALL_METRICS


class Evaluator:
    """Object that holds test and reference segmentations with label information
    and computes a number of metrics on the two. 'labels' must either be an
    iterable of numeric values (or tuples thereof) or a dictionary with string
    names and numeric values.
    """

    default_metrics = [
        "False Positive Rate",
        "Dice",
        "Jaccard",
        "Precision",
        "Recall",
        "Accuracy",
        "False Omission Rate",
        "Negative Predictive Value",
        "False Negative Rate",
        "True Negative Rate",
        "False Discovery Rate",
        "Total Positives Test",
        "Total Positives Reference"
    ]

    default_advanced_metrics = [
        "Hausdorff Distance",
        "Hausdorff Distance 95",
        "Avg. Surface Distance",
        "Avg. Symmetric Surface Distance"
    ]

    def __init__(self,
                 test=None,
                 reference=None,
                 labels=None,
                 metrics=None,
                 advanced_metrics=None,
                 nan_for_nonexisting=True):

        self.test = None
        self.reference = None
        self.confusion_matrix = ConfusionMatrix()
        self.labels = None
        self.nan_for_nonexisting = nan_for_nonexisting
        self.result = None

        self.metrics = []
        if metrics is None:
            for m in self.default_metrics:
                self.metrics.append(m)
        else:
            for m in metrics:
                self.metrics.append(m)

        self.advanced_metrics = []
        if advanced_metrics is None:
            for m in self.default_advanced_metrics:
                self.advanced_metrics.append(m)
        else:
            for m in advanced_metrics:
                self.advanced_metrics.append(m)

        self.set_reference(reference)
        self.set_test(test)
        if labels is not None:
            self.set_labels(labels)
        else:
            if test is not None and reference is not None:
                self.construct_labels()

    def set_test(self, test):
        """Set the test segmentation."""

        self.test = test

    def set_reference(self, reference):
        """Set the reference segmentation."""

        self.reference = reference

    def set_labels(self, labels):
        """Set the labels.
        :param labels: may be a dictionary (int->str), a set (of ints), a tuple (of ints) or a list (of ints). Labels
        will only have names if you pass a dictionary."""

        if isinstance(labels, dict):
            self.labels = collections.OrderedDict(labels)
        elif isinstance(labels, set):
            self.labels = list(labels)
        elif isinstance(labels, (list, tuple)):
            self.labels = labels
        else:
            raise ValueError("Labels must be either list, tuple, set or dict, but input is of type {}".format(type(labels)))

    def construct_labels(self):
        """Construct label set from unique entries in segmentations."""

        if self.test is None and self.reference is None:
            raise ValueError("No test or reference segmentations.")
        elif self.test is None:
            labels = np.unique(self.reference)
        else:
            labels = np.union1d(np.unique(self.test),
                                np.unique(self.reference))
        self.labels = list(map(lambda x: int(x), labels))

    def set_metrics(self, metrics):
        """Set evaluation metrics"""

        if isinstance(metrics, set):
            self.metrics = list(metrics)
        elif isinstance(metrics, (list, tuple, np.ndarray)):
            self.metrics = metrics
        else:
            raise TypeError("Can only handle list, tuple, set & numpy array, but input is of type {}".format(type(metrics)))

    def add_metric(self, metric):

        if metric not in self.metrics:
            self.metrics.append(metric)

    def evaluate(self, test=None, reference=None, advanced=False, **metric_kwargs):
        """Compute metrics for segmentations."""
        if test is not None:
            self.set_test(test)

        if reference is not None:
            self.set_reference(reference)

        if self.test is None or self.reference is None:
            raise ValueError("Need both test and reference segmentations.")

        if self.labels is None:
            self.construct_labels()

        self.metrics.sort()

        # get functions for evaluation
        # somewhat convoluted, but allows users to define additional metrics
        # on the fly, e.g. inside an IPython console
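        # only pre-populate known metrics; unknown ones are looked up in the calling frames below
        _funcs = {m: ALL_METRICS[m] for m in self.metrics + self.advanced_metrics if m in ALL_METRICS}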
        frames = inspect.getouterframes(inspect.currentframe())
        for metric in self.metrics:
            for f in frames:
                if metric in f[0].f_locals:
                    _funcs[metric] = f[0].f_locals[metric]
                    break
            else:
                if metric in _funcs:
                    continue
                else:
                    raise NotImplementedError(
                        "Metric {} not implemented.".format(metric))

        # get results
        self.result = {}

        # copy the list so that the advanced metrics are not appended to self.metrics in place
        eval_metrics = list(self.metrics)
        if advanced:
            eval_metrics += self.advanced_metrics

        if isinstance(self.labels, dict):

            for label, name in self.labels.items():
                self.result[name] = {}
                if not hasattr(label, "__iter__"):
                    self.confusion_matrix.set_test(self.test == label)
                    self.confusion_matrix.set_reference(self.reference == label)
                else:
                    current_test = 0
                    current_reference = 0
                    for l in label:
                        current_test += (self.test == l)
                        current_reference += (self.reference == l)
                    self.confusion_matrix.set_test(current_test)
                    self.confusion_matrix.set_reference(current_reference)
                for metric in eval_metrics:
                    self.result[name][metric] = _funcs[metric](confusion_matrix=self.confusion_matrix,
                                                               nan_for_nonexisting=self.nan_for_nonexisting,
                                                               **metric_kwargs)

        else:

            for i, l in enumerate(self.labels):
                self.result[l] = {}
                self.confusion_matrix.set_test(self.test == l)
                self.confusion_matrix.set_reference(self.reference == l)
                for metric in eval_metrics:
                    self.result[l][metric] = _funcs[metric](confusion_matrix=self.confusion_matrix,
                                                            nan_for_nonexisting=self.nan_for_nonexisting,
                                                            **metric_kwargs)

        return self.result

    def to_dict(self):

        if self.result is None:
            self.evaluate()
        return self.result

    def to_array(self):
        """Return result as numpy array (labels x metrics)."""

        if self.result is None:
            self.evaluate()

        result_metrics = sorted(self.result[list(self.result.keys())[0]].keys())

        a = np.zeros((len(self.labels), len(result_metrics)), dtype=np.float32)

        if isinstance(self.labels, dict):
            for i, label in enumerate(self.labels.keys()):
                for j, metric in enumerate(result_metrics):
                    a[i][j] = self.result[self.labels[label]][metric]
        else:
            for i, label in enumerate(self.labels):
                for j, metric in enumerate(result_metrics):
                    a[i][j] = self.result[label][metric]

        return a

    def to_pandas(self):
        """Return result as pandas DataFrame."""

        a = self.to_array()

        if isinstance(self.labels, dict):
            labels = list(self.labels.values())
        else:
            labels = self.labels

        result_metrics = sorted(self.result[list(self.result.keys())[0]].keys())

        return pd.DataFrame(a, index=labels, columns=result_metrics)


class NiftiEvaluator(Evaluator):

    def __init__(self, *args, **kwargs):

        self.test_nifti = None
        self.reference_nifti = None
        super(NiftiEvaluator, self).__init__(*args, **kwargs)

    def set_test(self, test):
        """Set the test segmentation."""

        if test is not None:
            self.test_nifti = sitk.ReadImage(test)
            super(NiftiEvaluator, self).set_test(sitk.GetArrayFromImage(self.test_nifti))
        else:
            self.test_nifti = None
            super(NiftiEvaluator, self).set_test(test)

    def set_reference(self, reference):
        """Set the reference segmentation."""

        if reference is not None:
            self.reference_nifti = sitk.ReadImage(reference)
            super(NiftiEvaluator, self).set_reference(sitk.GetArrayFromImage(self.reference_nifti))
        else:
            self.reference_nifti = None
            super(NiftiEvaluator, self).set_reference(reference)

    def evaluate(self, test=None, reference=None, voxel_spacing=None, **metric_kwargs):

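        if voxel_spacing is None:
            # SimpleITK reports spacing as (x, y, z); reverse it to match the (z, y, x) array axes
            voxel_spacing = np.array(self.test_nifti.GetSpacing())[::-1]
        # forward the spacing in both cases, also when the caller supplied it
        metric_kwargs["voxel_spacing"] = voxel_spacing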

        return super(NiftiEvaluator, self).evaluate(test, reference, **metric_kwargs)


def aggregate_scores(test_ref_pairs,
                     evaluator=NiftiEvaluator,
                     labels=None,
                     nanmean=True,
                     json_output_file=None,
                     json_name="",
                     json_description="",
                     json_author="Fabian",
                     json_task="",
                     **metric_kwargs):
    """
    test = predicted image
    :param test_ref_pairs:
    :param evaluator:
    :param labels: must be a dict of int-> str or a list of int
    :param nanmean:
    :param json_output_file:
    :param json_name:
    :param json_description:
    :param json_author:
    :param json_task:
    :param metric_kwargs:
    :return:
    """

    if type(evaluator) == type:
        evaluator = evaluator()

    if labels is not None:
        evaluator.set_labels(labels)

    all_scores = {}
    all_scores["all"] = []
    all_scores["mean"] = {}

    for i, (test, ref) in enumerate(test_ref_pairs):

        # evaluate
        evaluator.set_test(test)
        evaluator.set_reference(ref)
        if evaluator.labels is None:
            evaluator.construct_labels()
        current_scores = evaluator.evaluate(**metric_kwargs)
        if type(test) == str:
            current_scores["test"] = test
        if type(ref) == str:
            current_scores["reference"] = ref
        all_scores["all"].append(current_scores)

        # append score list for mean
        for label, score_dict in current_scores.items():
            if label in ("test", "reference"):
                continue
            if label not in all_scores["mean"]:
                all_scores["mean"][label] = {}
            for score, value in score_dict.items():
                if score not in all_scores["mean"][label]:
                    all_scores["mean"][label][score] = []
                all_scores["mean"][label][score].append(value)

    for label in all_scores["mean"]:
        for score in all_scores["mean"][label]:
            if nanmean:
                all_scores["mean"][label][score] = float(np.nanmean(all_scores["mean"][label][score]))
            else:
                all_scores["mean"][label][score] = float(np.mean(all_scores["mean"][label][score]))

    # save to file if desired
    # we create a hopefully unique id by hashing the entire output dictionary
    if json_output_file is not None:
        if type(json_output_file) == str:
            json_output_file = open(json_output_file, "w")
        json_dict = {}
        json_dict["name"] = json_name
        json_dict["description"] = json_description
        timestamp = datetime.today()
        json_dict["timestamp"] = str(timestamp)
        json_dict["task"] = json_task
        json_dict["author"] = json_author
        json_dict["results"] = all_scores
        json_dict["id"] = hashlib.md5(json.dumps(json_dict).encode("utf-8")).hexdigest()[:12]
        json.dump(json_dict, json_output_file, indent=4, separators=(",", ": "))
        json_output_file.close()

    return all_scores


def aggregate_scores_for_experiment(score_file,
                                    labels=None,
                                    metrics=Evaluator.default_metrics,
                                    nanmean=True,
                                    json_output_file=None,
                                    json_name="",
                                    json_description="",
                                    json_author="Fabian",
                                    json_task=""):

    scores = np.load(score_file)
    scores_mean = scores.mean(0)
    if labels is None:
        labels = list(map(str, range(scores.shape[1])))

    results = []
    results_mean = {}
    for i in range(scores.shape[0]):
        results.append({})
        for l, label in enumerate(labels):
            results[-1][label] = {}
            results_mean[label] = {}
            for m, metric in enumerate(metrics):
                results[-1][label][metric] = float(scores[i][l][m])
                results_mean[label][metric] = float(scores_mean[l][m])

    json_dict = {}
    json_dict["name"] = json_name
    json_dict["description"] = json_description
    timestamp = datetime.today()
    json_dict["timestamp"] = str(timestamp)
    json_dict["task"] = json_task
    json_dict["author"] = json_author
    json_dict["results"] = {"all": results, "mean": results_mean}
    json_dict["id"] = hashlib.md5(json.dumps(json_dict).encode("utf-8")).hexdigest()[:12]
    if json_output_file is not None:
        json_output_file = open(json_output_file, "w")
        json.dump(json_dict, json_output_file, indent=4, separators=(",", ": "))
        json_output_file.close()

    return json_dict


================================================
FILE: evaluation/metrics.py
================================================
import numpy as np
from medpy import metric


def assert_shape(test, reference):

    assert test.shape == reference.shape, "Shape mismatch: {} and {}".format(
        test.shape, reference.shape)


class ConfusionMatrix:

    def __init__(self, test=None, reference=None):

        self.tp = None
        self.fp = None
        self.tn = None
        self.fn = None
        self.size = None
        self.reference_empty = None
        self.reference_full = None
        self.test_empty = None
        self.test_full = None
        self.set_reference(reference)
        self.set_test(test)

    def set_test(self, test):

        self.test = test
        self.reset()

    def set_reference(self, reference):

        self.reference = reference
        self.reset()

    def reset(self):

        self.tp = None
        self.fp = None
        self.tn = None
        self.fn = None
        self.size = None
        self.test_empty = None
        self.test_full = None
        self.reference_empty = None
        self.reference_full = None

    def compute(self):

        if self.test is None or self.reference is None:
            raise ValueError("'test' and 'reference' must both be set to compute confusion matrix.")

        assert_shape(self.test, self.reference)

        self.tp = int(((self.test != 0) * (self.reference != 0)).sum())
        self.fp = int(((self.test != 0) * (self.reference == 0)).sum())
        self.tn = int(((self.test == 0) * (self.reference == 0)).sum())
        self.fn = int(((self.test == 0) * (self.reference != 0)).sum())
        self.size = int(np.prod(self.reference.shape))
        self.test_empty = not np.any(self.test)
        self.test_full = np.all(self.test)
        self.reference_empty = not np.any(self.reference)
        self.reference_full = np.all(self.reference)

    def get_matrix(self):

        for entry in (self.tp, self.fp, self.tn, self.fn):
            if entry is None:
                self.compute()
                break

        return self.tp, self.fp, self.tn, self.fn

    def get_size(self):

        if self.size is None:
            self.compute()
        return self.size

    def get_existence(self):

        for case in (self.test_empty, self.test_full, self.reference_empty, self.reference_full):
            if case is None:
                self.compute()
                break

        return self.test_empty, self.test_full, self.reference_empty, self.reference_full


def dice(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """2TP / (2TP + FP + FN)"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()
    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty and reference_empty:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0.

    return float(2. * tp / (2 * tp + fp + fn))


def jaccard(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TP / (TP + FP + FN)"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()
    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty and reference_empty:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0.

    return float(tp / (tp + fp + fn))


def precision(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TP / (TP + FP)"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()
    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0.

    return float(tp / (tp + fp))


def sensitivity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TP / (TP + FN)"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()
    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if reference_empty:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0.

    return float(tp / (tp + fn))


def recall(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TP / (TP + FN)"""

    return sensitivity(test, reference, confusion_matrix, nan_for_nonexisting, **kwargs)


def specificity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TN / (TN + FP)"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()
    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if reference_full:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0.

    return float(tn / (tn + fp))


def accuracy(test=None, reference=None, confusion_matrix=None, **kwargs):
    """(TP + TN) / (TP + FP + FN + TN)"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()

    return float((tp + tn) / (tp + fp + tn + fn))


def fscore(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, beta=1., **kwargs):
    """(1 + b^2) * TP / ((1 + b^2) * TP + b^2 * FN + FP)"""

    precision_ = precision(test, reference, confusion_matrix, nan_for_nonexisting)
    recall_ = recall(test, reference, confusion_matrix, nan_for_nonexisting)

    return (1 + beta*beta) * precision_ * recall_ /\
        ((beta*beta * precision_) + recall_)


def false_positive_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FP / (FP + TN)"""

    return 1 - specificity(test, reference, confusion_matrix, nan_for_nonexisting)


def false_omission_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FN / (TN + FN)"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()
    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_full:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0.

    return float(fn / (fn + tn))


def false_negative_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FN / (TP + FN)"""

    return 1 - sensitivity(test, reference, confusion_matrix, nan_for_nonexisting)


def true_negative_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TN / (TN + FP)"""

    return specificity(test, reference, confusion_matrix, nan_for_nonexisting)


def false_discovery_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FP / (TP + FP)"""

    return 1 - precision(test, reference, confusion_matrix, nan_for_nonexisting)


def negative_predictive_value(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TN / (TN + FN)"""

    return 1 - false_omission_rate(test, reference, confusion_matrix, nan_for_nonexisting)


def total_positives_test(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TP + FP"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()

    return tp + fp


def total_negatives_test(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TN + FN"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()

    return tn + fn


def total_positives_reference(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TP + FN"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()

    return tp + fn


def total_negatives_reference(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TN + FP"""

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    tp, fp, tn, fn = confusion_matrix.get_matrix()

    return tn + fp


def hausdorff_distance(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty or test_full or reference_empty or reference_full:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0

    test, reference = confusion_matrix.test, confusion_matrix.reference

    return metric.hd(test, reference, voxel_spacing, connectivity)


def hausdorff_distance_95(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty or test_full or reference_empty or reference_full:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0

    test, reference = confusion_matrix.test, confusion_matrix.reference

    return metric.hd95(test, reference, voxel_spacing, connectivity)


def avg_surface_distance(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty or test_full or reference_empty or reference_full:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0

    test, reference = confusion_matrix.test, confusion_matrix.reference

    return metric.asd(test, reference, voxel_spacing, connectivity)


def avg_surface_distance_symmetric(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):

    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty or test_full or reference_empty or reference_full:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0

    test, reference = confusion_matrix.test, confusion_matrix.reference

    return metric.assd(test, reference, voxel_spacing, connectivity)


ALL_METRICS = {
    "False Positive Rate": false_positive_rate,
    "Dice": dice,
    "Jaccard": jaccard,
    "Hausdorff Distance": hausdorff_distance,
    "Hausdorff Distance 95": hausdorff_distance_95,
    "Precision": precision,
    "Recall": recall,
    "Avg. Symmetric Surface Distance": avg_surface_distance_symmetric,
    "Avg. Surface Distance": avg_surface_distance,
    "Accuracy": accuracy,
    "False Omission Rate": false_omission_rate,
    "Negative Predictive Value": negative_predictive_value,
    "False Negative Rate": false_negative_rate,
    "True Negative Rate": true_negative_rate,
    "False Discovery Rate": false_discovery_rate,
    "Total Positives Test": total_positives_test,
    "Total Negatives Test": total_negatives_test,
    "Total Positives Reference": total_positives_reference,
    "Total Negatives Reference": total_negatives_reference
}


================================================
FILE: evaluation/readme.md
================================================
# Evaluation Suite

### Metrics

All metrics can be used either by passing test and reference segmentations as
parameters or by passing a `ConfusionMatrix` object. The latter is useful when many
metrics need to be computed, because the relevant computations are only done once.
All metrics assume binary segmentation inputs.

`ConfusionMatrix` has two important methods. `.get_matrix()` returns 4 ints: true
positives, false positives, true negatives and false negatives. `.get_existence()`
returns 4 bools (`test_empty`, `test_full`, `reference_empty`, `reference_full`)
indicating whether the test and reference segmentations are all zeros or all ones.
Metrics use these flags to detect results that are undefined, i.e. would require
dividing by 0: with `nan_for_nonexisting=True` (the default) they return NaN in that
case, otherwise 0.
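
As a rough illustration of what these counts and the NaN handling amount to, here is a dependency-free sketch. It works on flat Python lists rather than NumPy arrays, and the helper names `confusion_counts` and `dice_from_counts` are made up for this example; they are not part of the package.

```python
import math


def confusion_counts(test, reference):
    """Count TP, FP, TN, FN for two flat binary sequences of equal length."""
    if len(test) != len(reference):
        raise ValueError("Shape mismatch")
    pairs = list(zip(test, reference))
    tp = sum(1 for t, r in pairs if t and r)
    fp = sum(1 for t, r in pairs if t and not r)
    tn = sum(1 for t, r in pairs if not t and not r)
    fn = sum(1 for t, r in pairs if not t and r)
    return tp, fp, tn, fn


def dice_from_counts(tp, fp, fn, nan_for_nonexisting=True):
    """2TP / (2TP + FP + FN), with the undefined 0/0 case handled explicitly."""
    if tp + fp + fn == 0:  # test and reference are both empty
        return float("nan") if nan_for_nonexisting else 0.0
    return 2.0 * tp / (2 * tp + fp + fn)


test = [0, 1, 1, 0, 1]        # toy prediction, made up for illustration
reference = [0, 1, 0, 0, 1]   # toy ground truth
tp, fp, tn, fn = confusion_counts(test, reference)
print(tp, fp, tn, fn)
print(dice_from_counts(tp, fp, fn))
print(math.isnan(dice_from_counts(0, 0, 0)))
```

Computing the counts once and deriving several metrics from them is the same idea that makes passing a shared `ConfusionMatrix` cheaper than passing raw segmentations to every metric.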

### Evaluator

The `Evaluator` is a class that holds one test and one reference segmentation at a time. Each segmentation can contain multiple labels (one-hot encoding is not supported). It also holds a `labels` attribute that can either be a list of ints (or tuples of ints) or a dictionary
that maps from ints (or tuples of ints) to label names. A typical labels dictionary
could look like this:

```python
labels = {
    1: "Edema",
    2: "Enhancing Tumor",
    3: "Necrosis",
    (1, 2, 3): "Whole Tumor"
}
```

Labels in a tuple will be joined. If no labels are set, they will be automatically constructed from the unique entries in the segmentations upon evaluation.

The `Evaluator` has both a regular set of metrics that will always be computed and a set of advanced metrics that will only be computed if you call `.evaluate(advanced=True)`. The `.evaluate()` method is designed to look for metric definitions in the current call stack, so when you work in an interactive shell and redefine something there (e.g. for testing purposes), the newly defined metric will be used.

You can also pass test and reference segmentations directly to `.evaluate()` to save calls to `.set_test()` and `.set_reference()`. `.evaluate()` returns a result dictionary and also saves it in the `.result` attribute, so you can call `.to_array()` (numpy) or `.to_pandas()` (pandas) later. The resulting shape will be (labels x metrics). `.evaluate()` also takes additional `**metric_kwargs` that will be passed to each metric call.
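
To make "joined" concrete: a tuple label is evaluated on the union of its member labels. The sketch below shows this with plain Python lists; `binary_mask` is a hypothetical helper written for this example and does not exist in the package, which performs the equivalent operation on NumPy arrays.

```python
def binary_mask(segmentation, label):
    """Return a 0/1 mask for a single label, or for the union of a tuple of labels."""
    wanted = set(label) if isinstance(label, tuple) else {label}
    return [1 if value in wanted else 0 for value in segmentation]


labels = {
    1: "Edema",
    2: "Enhancing Tumor",
    3: "Necrosis",
    (1, 2, 3): "Whole Tumor",
}
segmentation = [0, 1, 1, 2, 3, 0]  # flat toy segmentation, made up for illustration

# one binary mask per named label; each mask is then scored like a binary segmentation
masks = {name: binary_mask(segmentation, label) for label, name in labels.items()}
print(masks["Edema"])
print(masks["Whole Tumor"])
```

Each mask is compared against the corresponding reference mask, which is why the metrics themselves only ever need to handle binary inputs.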

### NiftiEvaluator

`NiftiEvaluator` redefines the `.set_test()` and `.set_reference()` methods of the `Evaluator` to take path strings instead of arrays. It reads the NIfTI files using SimpleITK, saves the SimpleITK images in the `.test_nifti` and `.reference_nifti` attributes and sets the arrays as test and reference segmentations. `.evaluate()` has an additional parameter `voxel_spacing`, which should be an iterable of floats. If the parameter is None, the spacing will be automatically read from the SimpleITK test image. If you manually read the spacing from SimpleITK images, note that you have to reverse the ordering, because SimpleITK's `GetSpacing()` returns (x,y,z) ordering while the arrays obtained via `GetArrayFromImage()` are indexed (z,y,x).

### Evaluating multiple segmentations

If you want to evaluate multiple test/reference pairs and get aggregate statistics, use the `aggregate_scores()` function. It expects an iterable of test/reference pairs and an evaluator (an instance or a type; a type is instantiated automatically), which is the `NiftiEvaluator` by default. Test and reference will be set via `.set_test()` and `.set_reference()`, so make sure you're passing the right type for the evaluator. The function returns a dictionary that contains a list of all separate results as well as their mean:

```python
results = {
    "all": [
        {
            "Label": {
                "Metric": float,
                "Metric": float,
                ...
            },
            "Label": {
                ...
            },
            ...
        },
        {
            "Label": ...,
            "Label": ...,
            ...
        },
        ...
    ],
    "mean": {
        "Label": ...,
        "Label": ...,
        ...
    }
}
```
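
The `"mean"` entry can be thought of as a per-label, per-metric average over the entries of `"all"`. The following is a rough sketch of that idea rather than the code used by `aggregate_scores()`; the helper `aggregate_mean` and the example numbers are assumptions, and the `nanmean` switch mirrors the parameter of the same name.

```python
import numpy as np


def aggregate_mean(all_results, nanmean=True):
    """Average every metric per label over a list of {label: {metric: value}} dicts."""
    mean_fn = np.nanmean if nanmean else np.mean
    collected = {}
    for result in all_results:
        for label, metrics in result.items():
            for metric, value in metrics.items():
                collected.setdefault(label, {}).setdefault(metric, []).append(value)
    return {
        label: {metric: float(mean_fn(values)) for metric, values in metrics.items()}
        for label, metrics in collected.items()
    }


# Two made-up per-case results; one metric is undefined (NaN) for the second case
all_results = [
    {"Edema": {"Dice": 0.8, "Precision": 0.9}},
    {"Edema": {"Dice": 0.6, "Precision": float("nan")}},
]

results = {"all": all_results, "mean": aggregate_mean(all_results, nanmean=True)}
print(results["mean"])
```

With `nanmean=True`, NaN entries are skipped, so a metric that is undefined for some cases still gets a mean from the remaining ones; with `nanmean=False` a single NaN makes the mean NaN.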

`nanmean=True` will use `np.nanmean` instead of `np.mean`. It should be easy to adjust the code to compute arbitrary statistics, but at the moment only mean is supported. If you specify a `json_output_file`, a json file will be written that contains the result dictionary as well as additional information you can specify using the other `json_*` parameters:

```python
json = {
    "name": json_name,  # experiment name, not yours
    "description": json_description,  # a longer description so you know what you did
    "timestamp": "YYYY-MM-DD hh:mm:ss.ffffff",  # automatically generated
    "task": json_task,  # the decathlon task
    "author": json_author,  # probably Fabian :)
    "results": ...,  # the above dictionary
    "id": "001122334455"  # hash of other entries as unique id
}
```

`labels` is passed to the evaluator and `**metric_kwargs` is passed to all `.evaluate()` calls.
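
As an illustration of the JSON layout described above, here is a hedged sketch of how such a file could be assembled. The function `build_report`, the use of SHA-256 and the 12-character id length are assumptions for this example; the actual hashing scheme used by `aggregate_scores()` may differ.

```python
import hashlib
import json
from datetime import datetime


def build_report(results, name="", description="", task="", author=""):
    """Assemble a JSON-serializable report with an id derived from its other entries."""
    report = {
        "name": name,
        "description": description,
        "timestamp": str(datetime.today()),
        "task": task,
        "author": author,
        "results": results,
    }
    # Assumption: the id is a short digest of the serialized entries
    serialized = json.dumps(report, sort_keys=True).encode("utf-8")
    report["id"] = hashlib.sha256(serialized).hexdigest()[:12]
    return report


report = build_report({"mean": {"Edema": {"Dice": 0.7}}}, name="demo", author="someone")
print(json.dumps(report, indent=4))
```

Because the timestamp is part of the hashed content, two runs with identical results still get different ids.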

================================================
FILE: experiments/UNetExperiment.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle

import numpy as np
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F

from datasets.two_dim.NumpyDataLoader import NumpyDataSet
from trixi.experiment.pytorchexperiment import PytorchExperiment

from networks.UNET import UNet
from loss_functions.dice_loss import SoftDiceLoss


class UNetExperiment(PytorchExperiment):
    """
    The UNetExperiment is inherited from the PytorchExperiment. It implements the basic life cycle for a segmentation task with UNet (https://arxiv.org/abs/1505.04597).
    It is optimized to work with the provided NumpyDataLoader.

    The basic life cycle of a UNetExperiment is the same as PytorchExperiment:

        setup()
        (--> Automatically restore values if a previous checkpoint is given)
        prepare()

        for epoch in n_epochs:
            train()
            validate()
            (--> save current checkpoint)

        end()
    """

    def setup(self):
        pkl_dir = self.config.split_dir
        with open(os.path.join(pkl_dir, "splits.pkl"), 'rb') as f:
            splits = pickle.load(f)

        tr_keys = splits[self.config.fold]['train']
        val_keys = splits[self.config.fold]['val']
        test_keys = splits[self.config.fold]['test']

        self.device = torch.device(self.config.device if torch.cuda.is_available() else "cpu")

        self.train_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,
                                              keys=tr_keys)
        self.val_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,
                                            keys=val_keys, mode="val", do_reshuffle=False)
        self.test_data_loader = NumpyDataSet(self.config.data_test_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,
                                             keys=test_keys, mode="test", do_reshuffle=False)
        self.model = UNet(num_classes=self.config.num_classes, in_channels=self.config.in_channels)

        self.model.to(self.device)

        # We use a combination of DICE-loss and CE-Loss in this example.
        # This proved good in the medical segmentation decathlon.
        self.dice_loss = SoftDiceLoss(batch_dice=True)  # Softmax for DICE Loss!
        self.ce_loss = torch.nn.CrossEntropyLoss()  # No softmax for CE Loss -> is implemented in torch!

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.learning_rate)
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min')

        # If directory for checkpoint is provided, we load it.
        if self.config.do_load_checkpoint:
            if self.config.checkpoint_dir == '':
                print('checkpoint_dir is empty, please provide directory to load checkpoint.')
            else:
                self.load_checkpoint(name=self.config.checkpoint_dir, save_types=("model",))

        self.save_checkpoint(name="checkpoint_start")
        self.elog.print('Experiment set up.')

    def train(self, epoch):
        self.elog.print('=====TRAIN=====')
        self.model.train()

        data = None
        batch_counter = 0
        for data_batch in self.train_data_loader:

            self.optimizer.zero_grad()

            # Shape of data_batch = [1, b, c, w, h]
            # Desired shape = [b, c, w, h]
            # Move data and target to the GPU
            data = data_batch['data'][0].float().to(self.device)
            target = data_batch['seg'][0].long().to(self.device)

            pred = self.model(data)
            pred_softmax = F.softmax(pred, dim=1)  # We calculate a softmax, because our SoftDiceLoss expects that as an input. The CE-Loss does the softmax internally.

            loss = self.dice_loss(pred_softmax, target.squeeze()) + self.ce_loss(pred, target.squeeze())
            # loss = self.ce_loss(pred, target.squeeze())

            loss.backward()
            self.optimizer.step()

            # Some logging and plotting
            if (batch_counter % self.config.plot_freq) == 0:
                self.elog.print('Epoch: {0} Loss: {1:.4f}'.format(self._epoch_idx, loss))

                self.add_result(value=loss.item(), name='Train_Loss', tag='Loss', counter=epoch + (batch_counter / self.train_data_loader.data_loader.num_batches))

                self.clog.show_image_grid(data.float().cpu(), name="data", normalize=True, scale_each=True, n_iter=epoch)
                self.clog.show_image_grid(target.float().cpu(), name="mask", title="Mask", n_iter=epoch)
                self.clog.show_image_grid(pred.cpu()[:, 1:2, ], name="unt", normalize=True, scale_each=True, n_iter=epoch)

            batch_counter += 1

        assert data is not None, 'data is None. Please check if your dataloader works properly'

    def validate(self, epoch):
        self.elog.print('VALIDATE')
        self.model.eval()

        data = None
        loss_list = []

        with torch.no_grad():
            for data_batch in self.val_data_loader:
                data = data_batch['data'][0].float().to(self.device)
                target = data_batch['seg'][0].long().to(self.device)

                pred = self.model(data)
                pred_softmax = F.softmax(pred, dim=1)  # We calculate a softmax, because our SoftDiceLoss expects that as an input. The CE-Loss does the softmax internally.

                loss = self.dice_loss(pred_softmax, target.squeeze()) + self.ce_loss(pred, target.squeeze())
                loss_list.append(loss.item())

        assert data is not None, 'data is None. Please check if your dataloader works properly'
        self.scheduler.step(np.mean(loss_list))

        self.elog.print('Epoch: %d Loss: %.4f' % (self._epoch_idx, float(np.mean(loss_list))))

        self.add_result(value=np.mean(loss_list), name='Val_Loss', tag='Loss', counter=epoch+1)

        self.clog.show_image_grid(data.float().cpu(), name="data_val", normalize=True, scale_each=True, n_iter=epoch)
        self.clog.show_image_grid(target.float().cpu(), name="mask_val", title="Mask", n_iter=epoch)
        self.clog.show_image_grid(pred.data.cpu()[:, 1:2, ], name="unt_val", normalize=True, scale_each=True, n_iter=epoch)

    def test(self):
        from evaluation.evaluator import aggregate_scores, Evaluator
        from collections import defaultdict

        self.elog.print('=====TEST=====')
        self.model.eval()

        pred_dict = defaultdict(list)
        gt_dict = defaultdict(list)

        batch_counter = 0
        with torch.no_grad():
            for data_batch in self.test_data_loader:
                print('testing...', batch_counter)
                batch_counter += 1

                # Get data_batches
                mr_data = data_batch['data'][0].float().to(self.device)
                mr_target = data_batch['seg'][0].float().to(self.device)

                pred = self.model(mr_data)
                pred_argmax = torch.argmax(pred.data.cpu(), dim=1, keepdim=True)

                fnames = data_batch['fnames']
                for i, fname in enumerate(fnames):
                    pred_dict[fname[0]].append(pred_argmax[i].detach().cpu().numpy())
                    gt_dict[fname[0]].append(mr_target[i].detach().cpu().numpy())

        test_ref_list = []
        for key in pred_dict.keys():
            test_ref_list.append((np.stack(pred_dict[key]), np.stack(gt_dict[key])))

        scores = aggregate_scores(test_ref_list, evaluator=Evaluator, json_author=self.config.author, json_task=self.config.name, json_name=self.config.name,
                                  json_output_file=self.elog.work_dir + "/{}_".format(self.config.author) + self.config.name + '.json')

        print("Scores:\n", scores)

    def segment_single_image(self, data):
        self.model = UNet(num_classes=self.config.num_classes, in_channels=self.config.in_channels)
        self.device = torch.device(self.config.device if torch.cuda.is_available() else "cpu")

        # a model must be present and loaded in here
        if self.config.model_dir == '':
            print('model_dir is empty, please provide directory to load checkpoint.')
        else:
            self.load_checkpoint(name=self.config.model_dir, save_types=("model",))

        self.elog.print("=====SEGMENT_SINGLE_IMAGE=====")
        self.model.eval()
        self.model.to(self.device)

        # Desired shape = [b, c, w, h]
        # split into even chunks (lets use size)
        with torch.no_grad():

            ######
            # When working entirely on CPU and in memory, the following lines replace the split/concat method
            # mr_data = data.float().to(self.device)
            # pred = self.model(mr_data)
            # pred_argmax = torch.argmax(pred.data.cpu(), dim=1, keepdim=True)
            ######

            ######
            # for CUDA (also works on CPU) split into batches
            blocksize = self.config.batch_size

            # number_of_elements = round(data.shape[0]/blocksize+0.5)     # make blocks large enough to not lose any slices
            chunks = [data[i:i+blocksize, ::, ::, ::] for i in range(0, data.shape[0], blocksize)]
            pred_list = []
            for data_batch in chunks:
                mr_data = data_batch.float().to(self.device)
                pred_batch = self.model(mr_data)
                pred_list.append(pred_batch.cpu())

            # Concatenate the per-chunk predictions along the batch axis
            pred = torch.cat(pred_list, dim=0)
            pred_argmax = torch.argmax(pred, dim=1, keepdim=True)

        # detach result and put it back to cpu so that we can work with, create a numpy array
        result = pred_argmax.short().detach().cpu().numpy()

        return result


================================================
FILE: experiments/UNetExperiment3D.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
from collections import OrderedDict

import numpy as np
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from datasets.three_dim.NumpyDataLoader import NumpyDataSet
from trixi.experiment.pytorchexperiment import PytorchExperiment

from networks.RecursiveUNet3D import UNet3D
from loss_functions.dice_loss import SoftDiceLoss, DC_and_CE_loss


class UNetExperiment3D(PytorchExperiment):
    """
    The UNetExperiment3D is inherited from the PytorchExperiment. It implements the basic life cycle for a segmentation task with a 3D UNet (see https://arxiv.org/abs/1505.04597 for the original UNet).
    It is optimized to work with the provided NumpyDataLoader.

    The basic life cycle of a UNetExperiment3D is the same as PytorchExperiment:

        setup()
        (--> Automatically restore values if a previous checkpoint is given)
        prepare()

        for epoch in n_epochs:
            train()
            validate()
            (--> save current checkpoint)

        end()
    """

    def setup(self):
        pkl_dir = self.config.split_dir
        with open(os.path.join(pkl_dir, "splits.pkl"), 'rb') as f:
            splits = pickle.load(f)

        tr_keys = splits[self.config.fold]['train']
        val_keys = splits[self.config.fold]['val']
        test_keys = splits[self.config.fold]['test']

        self.device = torch.device(self.config.device if torch.cuda.is_available() else "cpu")

        self.train_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,
                                              keys=tr_keys)
        self.val_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,
                                            keys=val_keys, mode="val", do_reshuffle=False)
        self.test_data_loader = NumpyDataSet(self.config.data_test_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,
                                             keys=test_keys, mode="test", do_reshuffle=False)
        self.model = UNet3D(num_classes=3, in_channels=1)

        self.model.to(self.device)

        # We use a combination of DICE-loss and CE-Loss in this example.
        # This proved good in the medical segmentation decathlon.
        self.loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'smooth_in_nom': True,
                                    'do_bg': False, 'rebalance_weights': None, 'background_weight': 1}, OrderedDict())

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.learning_rate)
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min')

        # If directory for checkpoint is provided, we load it.
        if self.config.do_load_checkpoint:
            if self.config.checkpoint_dir == '':
                print('checkpoint_dir is empty, please provide directory to load checkpoint.')
            else:
                self.load_checkpoint(name=self.config.checkpoint_dir, save_types=("model",))

        self.save_checkpoint(name="checkpoint_start")
        self.elog.print('Experiment set up.')

    def train(self, epoch):
        self.elog.print('=====TRAIN=====')
        self.model.train()

        batch_counter = 0
        for data_batch in self.train_data_loader:

            self.optimizer.zero_grad()

            # Shape of data_batch = [1, b, c, x, y, z]
            # Desired shape = [b, c, x, y, z]
            # Move data and target to the GPU
            data = data_batch['data'][0].float().to(self.device)
            target = data_batch['seg'][0].long().to(self.device)

            pred = self.model(data)

            loss = self.loss(pred, target.squeeze())
            # loss = self.ce_loss(pred, target.squeeze())
            loss.backward()
            self.optimizer.step()

            # Some logging and plotting
            if (batch_counter % self.config.plot_freq) == 0:
                self.elog.print('Epoch: %d Loss: %.4f' % (self._epoch_idx, loss))

                self.add_result(value=loss.item(), name='Train_Loss', tag='Loss', counter=epoch + (batch_counter / self.train_data_loader.data_loader.num_batches))

                self.clog.show_image_grid(data[:,:,30].float(), name="data", normalize=True, scale_each=True, n_iter=epoch)
                self.clog.show_image_grid(target[:,:,30].float(), name="mask", title="Mask", n_iter=epoch)
                self.clog.show_image_grid(torch.argmax(pred.cpu(), dim=1, keepdim=True)[:,:,30], name="unt_argmax", title="Unet", n_iter=epoch)

            batch_counter += 1

    def validate(self, epoch):
        if epoch % 5 != 0:
            return
        self.elog.print('VALIDATE')
        self.model.eval()

        data = None
        loss_list = []

        with torch.no_grad():
            for data_batch in self.val_data_loader:
                data = data_batch['data'][0].float().to(self.device)
                target = data_batch['seg'][0].long().to(self.device)

                pred = self.model(data)

                loss = self.loss(pred, target.squeeze())
                loss_list.append(loss.item())

        assert data is not None, 'data is None. Please check if your dataloader works properly'
        self.scheduler.step(np.mean(loss_list))

        self.elog.print('Epoch: %d Loss: %.4f' % (self._epoch_idx, float(np.mean(loss_list))))

        self.add_result(value=np.mean(loss_list), name='Val_Loss', tag='Loss', counter=epoch+1)

        self.clog.show_image_grid(data[:,:,30].float(), name="data_val", normalize=True, scale_each=True, n_iter=epoch)
        self.clog.show_image_grid(target[:,:,30].float(), name="mask_val", title="Mask", n_iter=epoch)
        self.clog.show_image_grid(torch.argmax(pred.data.cpu()[:,:,30], dim=1, keepdim=True), name="unt_argmax_val", title="Unet", n_iter=epoch)

    def test(self):
        # TODO
        print('TODO: Implement your test() method here')


================================================
FILE: experiments/__init__.py
================================================


================================================
FILE: loss_functions/ND_Crossentropy.py
================================================
import torch


class CrossentropyND(torch.nn.CrossEntropyLoss):
    """
    Network has to have NO NONLINEARITY!
    """
    def forward(self, inp, target):
        target = target.long()
        num_classes = inp.size()[1]

        i0 = 1
        i1 = 2

        while i1 < len(inp.shape): # this is ugly but torch only allows to transpose two axes at once
            inp = inp.transpose(i0, i1)
            i0 += 1
            i1 += 1

        inp = inp.contiguous()
        inp = inp.view(-1, num_classes)

        target = target.view(-1,)

        return super(CrossentropyND, self).forward(inp, target)

================================================
FILE: loss_functions/__init__.py
================================================


================================================
FILE: loss_functions/dice_loss.py
================================================
import torch
import numpy as np
from loss_functions.ND_Crossentropy import CrossentropyND
from loss_functions.topk_loss import TopKLoss
from torch import nn


def softmax_helper(x):
    rpt = [1 for _ in range(len(x.size()))]
    rpt[1] = x.size(1)
    x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
    e_x = torch.exp(x - x_max)
    return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)


def get_tp_fp_fn(net_output, gt, axes=None, mask=None):
    """
    net_output must be (b, c, x, y(, z)))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z)))
    :param net_output:
    :param gt:
    :param axes:
    :param mask:
    :return:
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            y_onehot = torch.zeros(shp_x)
            if net_output.device.type == "cuda":
                y_onehot = y_onehot.cuda(net_output.device.index)
            y_onehot.scatter_(1, gt, 1)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot

    if mask is not None:
        tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
        fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
        fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)

    tp = sum_tensor(tp, axes, keepdim=False)
    fp = sum_tensor(fp, axes, keepdim=False)
    fn = sum_tensor(fn, axes, keepdim=False)

    return tp, fp, fn


def sum_tensor(inp, axes, keepdim=False):
    axes = np.unique(axes).astype(int)
    if keepdim:
        for ax in axes:
            inp = inp.sum(int(ax), keepdim=True)
    else:
        for ax in sorted(axes, reverse=True):
            inp = inp.sum(int(ax))
    return inp


def mean_tensor(inp, axes, keepdim=False):
    axes = np.unique(axes).astype(int)
    if keepdim:
        for ax in axes:
            inp = inp.mean(int(ax), keepdim=True)
    else:
        for ax in sorted(axes, reverse=True):
            inp = inp.mean(int(ax))
    return inp


class SoftDiceLoss(nn.Module):
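    def __init__(self, smooth=1., apply_nonlin=None, batch_dice=False, do_bg=True, smooth_in_nom=True, background_weight=1, rebalance_weights=None):
        """
        Soft Dice loss on probabilities; returns the negative (weighted) mean Dice.
        :param smooth: constant added to the denominator to avoid division by zero
        :param apply_nonlin: optional callable (e.g. a softmax) applied to x before computing Dice
        :param batch_dice: if True, Dice is computed over the whole batch instead of per sample
        :param do_bg: if False, channel 0 (background) is excluded
        :param smooth_in_nom: if True, smooth is also added to the numerator
        :param background_weight: weight of channel 0 in the mean (batch_dice=True only)
        :param rebalance_weights: per-class numpy weights applied to tp and fn (batch_dice=True only)
        """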
        super(SoftDiceLoss, self).__init__()
        if not do_bg:
            assert background_weight == 1, "if there is no bg, then set background weight to 1 you dummy"
        self.rebalance_weights = rebalance_weights
        self.background_weight = background_weight
        if smooth_in_nom:
            self.smooth_in_nom = smooth
        else:
            self.smooth_in_nom = 0
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.y_onehot = None

    def forward(self, x, y):
        with torch.no_grad():
            y = y.long()
        shp_x = x.shape
        shp_y = y.shape
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        if len(shp_x) != len(shp_y):
            y = y.view((shp_y[0], 1, *shp_y[1:]))
        # now x and y should have shape (B, C, X, Y(, Z))) and (B, 1, X, Y(, Z))), respectively
        y_onehot = torch.zeros(shp_x)
        if x.device.type == "cuda":
            y_onehot = y_onehot.cuda(x.device.index)
        y_onehot.scatter_(1, y, 1)
        if not self.do_bg:
            x = x[:, 1:]
            y_onehot = y_onehot[:, 1:]
        if not self.batch_dice:
            if self.background_weight != 1 or (self.rebalance_weights is not None):
                raise NotImplementedError("nah son")
            l = soft_dice(x, y_onehot, self.smooth, self.smooth_in_nom)
        else:
            l = soft_dice_per_batch_2(x, y_onehot, self.smooth, self.smooth_in_nom,
                                      background_weight=self.background_weight,
                                      rebalance_weights=self.rebalance_weights)
        return l


def soft_dice_per_batch(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1):
    axes = tuple([0] + list(range(2, len(net_output.size()))))
    intersect = sum_tensor(net_output * gt, axes, keepdim=False)
    denom = sum_tensor(net_output + gt, axes, keepdim=False)
    weights = torch.ones(intersect.shape)
    weights[0] = background_weight
    if net_output.device.type == "cuda":
        weights = weights.cuda(net_output.device.index)
    result = (- ((2 * intersect + smooth_in_nom) / (denom + smooth)) * weights).mean()
    return result


def soft_dice_per_batch_2(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1, rebalance_weights=None):
    if rebalance_weights is not None and len(rebalance_weights) != gt.shape[1]:
        rebalance_weights = rebalance_weights[1:] # this is the case when use_bg=False
    axes = tuple([0] + list(range(2, len(net_output.size()))))
    tp = sum_tensor(net_output * gt, axes, keepdim=False)
    fn = sum_tensor((1 - net_output) * gt, axes, keepdim=False)
    fp = sum_tensor(net_output * (1 - gt), axes, keepdim=False)
    weights = torch.ones(tp.shape)
    weights[0] = background_weight
    if net_output.device.type == "cuda":
        weights = weights.cuda(net_output.device.index)
    if rebalance_weights is not None:
        rebalance_weights = torch.from_numpy(rebalance_weights).float()
        if net_output.device.type == "cuda":
            rebalance_weights = rebalance_weights.cuda(net_output.device.index)
        tp = tp * rebalance_weights
        fn = fn * rebalance_weights
    result = (- ((2 * tp + smooth_in_nom) / (2 * tp + fp + fn + smooth)) * weights).mean()
    return result


def soft_dice(net_output, gt, smooth=1., smooth_in_nom=1.):
    axes = tuple(range(2, len(net_output.size())))
    intersect = sum_tensor(net_output * gt, axes, keepdim=False)
    denom = sum_tensor(net_output + gt, axes, keepdim=False)
    result = (- ((2 * intersect + smooth_in_nom) / (denom + smooth))).mean()
    return result


class MultipleOutputLoss(nn.Module):
    def __init__(self, loss, weight_factors=None):
        """
        use this if you have several outputs that should predict the same y
        :param loss:
        :param weight_factors:
        """
        super(MultipleOutputLoss, self).__init__()
        self.weight_factors = weight_factors
        self.loss = loss

    def forward(self, x, y):
        assert isinstance(x, (tuple, list)), "x must be either tuple or list"
        if self.weight_factors is None:
            weights = [1] * len(x)
        else:
            weights = self.weight_factors
        l = weights[0] * self.loss(x[0], y)
        for i in range(1, len(x)):
            l += weights[i] * self.loss(x[i], y)
        return l


class DC_and_CE_loss(nn.Module):
    def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum"):
        super(DC_and_CE_loss, self).__init__()
        self.aggregate = aggregate
        self.ce = CrossentropyND(**ce_kwargs)
        self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs)

    def forward(self, net_output, target):
        dc_loss = self.dc(net_output, target)
        ce_loss = self.ce(net_output, target)
        if self.aggregate == "sum":
            result = ce_loss + dc_loss
        else:
            raise NotImplementedError("nah son") # reserved for other stuff (later)
        return result


class DC_and_topk_loss(nn.Module):
    def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum"):
        super(DC_and_topk_loss, self).__init__()
        self.aggregate = aggregate
        self.ce = TopKLoss(**ce_kwargs)
        self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs)

    def forward(self, net_output, target):
        dc_loss = self.dc(net_output, target)
        ce_loss = self.ce(net_output, target)
        if self.aggregate == "sum":
            result = ce_loss + dc_loss
        else:
            raise NotImplementedError("nah son") # reserved for other stuff (later?)
        return result


class CrossentropyWithLossMask(nn.CrossEntropyLoss):
    def __init__(self, k=None):
        """
        This implementation ignores weight, ignore_index (use loss mask!) and reduction!
        :param k:
        """
        super(CrossentropyWithLossMask, self).__init__(weight=None, ignore_index=-100, reduction='none')
        self.k = k

    def forward(self, inp, target, loss_mask=None):
        target = target.long()
        inp = inp.float()
        if loss_mask is not None:
            loss_mask = loss_mask.float()
        num_classes = inp.size()[1]

        i0 = 1
        i1 = 2

        while i1 < len(inp.shape): # this is ugly but torch only allows to transpose two axes at once
            inp = inp.transpose(i0, i1)
            i0 += 1
            i1 += 1

        if not inp.is_contiguous():
            inp = inp.contiguous()

        inp = inp.view(target.shape[0], -1, num_classes)

        target = target.view(target.shape[0], -1)
        if loss_mask is not None:
            loss_mask = loss_mask.view(target.shape[0], -1)

        if self.k is not None:
            if loss_mask is not None:
                num_sel = torch.stack(tuple([i.sum() / self.k for i in torch.unbind(loss_mask, 0)]), 0).long()
                loss = torch.stack(tuple([
                    torch.topk(super(CrossentropyWithLossMask, self).forward(inp[i], target[i])[loss_mask[i].byte()],
                               num_sel[i], sorted=False)[0].mean()
                    for i in range(target.shape[0])
                ])
                )
            else:
                # inp was reshaped to (b, n_voxels, num_classes) above, so the voxel count is inp.shape[1]
                num_sel = [int(inp.shape[1] // self.k)] * inp.shape[0]
                loss = torch.stack(tuple([
                    torch.topk(super(CrossentropyWithLossMask, self).forward(inp[i], target[i]),
                               num_sel[i], sorted=False)[0].mean()
                    for i in range(target.shape[0])
                ])
                )
        else:
            if loss_mask is not None:
                loss = torch.stack(tuple([
                    super(CrossentropyWithLossMask, self).forward(inp[i], target[i])[loss_mask[i].byte()].mean()
                    for i in range(target.shape[0])
                ])
                )
            else:
                loss = torch.stack(tuple([
                    super(CrossentropyWithLossMask, self).forward(inp[i], target[i]).mean()
                    for i in range(target.shape[0])
                ])
                )

        loss = loss.mean()
        return loss


================================================
FILE: loss_functions/topk_loss.py
================================================
import numpy as np
import torch
from loss_functions.ND_Crossentropy import CrossentropyND


class TopKLoss(CrossentropyND):
    """
    Network has to have NO LINEARITY!
    """

    def __init__(self, weight=None, ignore_index=-100, k=10):
        self.k = k
        super(TopKLoss, self).__init__(weight, False, ignore_index, reduce=False)

    def forward(self, inp, target):
        target = target[:, 0].long()
        res = super(TopKLoss, self).forward(inp, target)
        num_voxels = np.prod(res.shape)
        res, _ = torch.topk(res.view((-1,)), int(num_voxels // self.k), sorted=False)
        return res.mean()
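

# Illustrative sketch, not part of the original module: TopKLoss.forward keeps
# only the num_voxels // k largest per-voxel losses before averaging, so the
# default k=10 averages roughly the hardest 10% of voxels. The helper below is
# hypothetical (its name and the use of heapq instead of torch.topk are
# assumptions made for illustration) and only mirrors that selection rule.
import heapq


def topk_mean_sketch(per_voxel_losses, k=10):
    """Return the mean of the len(per_voxel_losses) // k largest values."""
    num_voxels = len(per_voxel_losses)
    num_selected = num_voxels // k  # same integer division as in forward()
    if num_selected == 0:
        # With fewer than k values nothing would be selected and the mean is
        # undefined, so this sketch fails loudly instead.
        raise ValueError("need at least k values to select one")
    largest = heapq.nlargest(num_selected, per_voxel_losses)
    return sum(largest) / num_selected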

================================================
FILE: networks/RecursiveUNet.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Defines the Unet.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1 at the bottleneck

# recursive implementation of Unet
import torch

from torch import nn


class UNet(nn.Module):
    def __init__(self, num_classes=3, in_channels=1, initial_filter_size=64, kernel_size=3, num_downs=4, norm_layer=nn.InstanceNorm2d):
        # norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UNet, self).__init__()

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-1), out_channels=initial_filter_size * 2 ** num_downs,
                                             num_classes=num_classes, kernel_size=kernel_size, norm_layer=norm_layer, innermost=True)
        for i in range(1, num_downs):
            unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-(i+1)),
                                                 out_channels=initial_filter_size * 2 ** (num_downs-i),
                                                 num_classes=num_classes, kernel_size=kernel_size, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(in_channels=in_channels, out_channels=initial_filter_size,
                                             num_classes=num_classes, kernel_size=kernel_size, submodule=unet_block, norm_layer=norm_layer,
                                             outermost=True)

        self.model = unet_block

    def forward(self, x):
        return self.model(x)


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, in_channels=None, out_channels=None, num_classes=1, kernel_size=3,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.InstanceNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # downconv
        pool = nn.MaxPool2d(2, stride=2)
        conv1 = self.contract(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)
        conv2 = self.contract(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)

        # upconv
        conv3 = self.expand(in_channels=out_channels*2, out_channels=out_channels, kernel_size=kernel_size)
        conv4 = self.expand(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size)

        if outermost:
            final = nn.Conv2d(out_channels, num_classes, kernel_size=1)
            down = [conv1, conv2]
            up = [conv3, conv4, final]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(in_channels*2, in_channels,
                                        kernel_size=2, stride=2)
            model = [pool, conv1, conv2, upconv]
        else:
            upconv = nn.ConvTranspose2d(in_channels*2, in_channels, kernel_size=2, stride=2)

            down = [pool, conv1, conv2]
            up = [conv3, conv4, upconv]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    @staticmethod
    def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.InstanceNorm2d):
        layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
            norm_layer(out_channels),
            nn.LeakyReLU(inplace=True))
        return layer

    @staticmethod
    def expand(in_channels, out_channels, kernel_size=3):
        layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
            nn.LeakyReLU(inplace=True),
        )
        return layer

    @staticmethod
    def center_crop(layer, target_width, target_height):
        batch_size, n_channels, layer_width, layer_height = layer.size()
        xy1 = (layer_width - target_width) // 2
        xy2 = (layer_height - target_height) // 2
        return layer[:, :, xy1:(xy1 + target_width), xy2:(xy2 + target_height)]

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            crop = self.center_crop(self.model(x), x.size()[2], x.size()[3])
            return torch.cat([x, crop], 1)
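

# Illustrative sketch, not part of the original module: the recursive UNet
# constructor above builds its blocks from the innermost one outwards. The two
# hypothetical helpers below (their names are assumptions) reproduce the channel
# arithmetic and the bottleneck size in plain Python, assuming kernel_size=3
# with padding=1 so that only the pooling layers change the spatial size.
def block_channels_sketch(in_channels=1, initial_filter_size=64, num_downs=4):
    """List (in_channels, out_channels) per block, outermost block first."""
    f = initial_filter_size
    blocks = [(f * 2 ** (num_downs - 1), f * 2 ** num_downs)]  # innermost
    for i in range(1, num_downs):
        blocks.append((f * 2 ** (num_downs - (i + 1)), f * 2 ** (num_downs - i)))
    blocks.append((in_channels, f))  # outermost
    return blocks[::-1]


def bottleneck_size_sketch(image_size, num_downs=4):
    """Spatial size after num_downs 2x2 max-pooling steps with stride 2."""
    size = image_size
    for _ in range(num_downs):
        size //= 2
    return size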


================================================
FILE: networks/RecursiveUNet3D.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Defines the Unet.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, a volume of size 128x128x128 will become of size 1x1x1 at the bottleneck

# recursive implementation of Unet
import torch

from torch import nn


class UNet3D(nn.Module):
    def __init__(self, num_classes=3, in_channels=1, initial_filter_size=64, kernel_size=3, num_downs=3, norm_layer=nn.InstanceNorm3d):
        # norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UNet3D, self).__init__()

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-1), out_channels=initial_filter_size * 2 ** num_downs,
                                             num_classes=num_classes, kernel_size=kernel_size, norm_layer=norm_layer, innermost=True)
        for i in range(1, num_downs):
            unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-(i+1)),
                                                 out_channels=initial_filter_size * 2 ** (num_downs-i),
                                                 num_classes=num_classes, kernel_size=kernel_size, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(in_channels=in_channels, out_channels=initial_filter_size,
                                             num_classes=num_classes, kernel_size=kernel_size, submodule=unet_block, norm_layer=norm_layer,
                                             outermost=True)

        self.model = unet_block

    def forward(self, x):
        return self.model(x)


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, in_channels=None, out_channels=None, num_classes=1, kernel_size=3,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.InstanceNorm3d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # downconv
        pool = nn.MaxPool3d(2, stride=2)
        conv1 = self.contract(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)
        conv2 = self.contract(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)

        # upconv
        conv3 = self.expand(in_channels=out_channels*2, out_channels=out_channels, kernel_size=kernel_size)
        conv4 = self.expand(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size)

        if outermost:
            final = nn.Conv3d(out_channels, num_classes, kernel_size=1)
            down = [conv1, conv2]
            up = [conv3, conv4, final]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose3d(in_channels*2, in_channels,
                                        kernel_size=2, stride=2)
            model = [pool, conv1, conv2, upconv]
        else:
            upconv = nn.ConvTranspose3d(in_channels*2, in_channels, kernel_size=2, stride=2)

            down = [pool, conv1, conv2]
            up = [conv3, conv4, upconv]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    @staticmethod
    def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.InstanceNorm3d):
        layer = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size, padding=1),
            norm_layer(out_channels),
            nn.LeakyReLU(inplace=True))
        return layer

    @staticmethod
    def expand(in_channels, out_channels, kernel_size=3):
        layer = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size, padding=1),
            nn.LeakyReLU(inplace=True),
        )
        return layer

    @staticmethod
    def center_crop(layer, target_depth, target_width, target_height):
        batch_size, n_channels, layer_depth, layer_width, layer_height = layer.size()
        xy0 = (layer_depth - target_depth) // 2
        xy1 = (layer_width - target_width) // 2
        xy2 = (layer_height - target_height) // 2
        return layer[:, :, xy0:(xy0 + target_depth), xy1:(xy1 + target_width), xy2:(xy2 + target_height)]

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            crop = self.center_crop(self.model(x), x.size()[2], x.size()[3], x.size()[4])
            return torch.cat([x, crop], 1)
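

# Illustrative sketch, not part of the original module: center_crop above trims
# each spatial axis around its center; when the size difference is odd, the
# larger remainder is trimmed from the end of the axis. The hypothetical helper
# below computes the same slice bounds for any number of spatial axes without
# needing torch.
def center_crop_bounds_sketch(layer_shape, target_shape):
    """Return one (start, stop) pair per spatial axis."""
    bounds = []
    for layer_size, target_size in zip(layer_shape, target_shape):
        start = (layer_size - target_size) // 2
        bounds.append((start, start + target_size))
    return bounds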


================================================
FILE: networks/UNET.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn


class UNet(nn.Module):

    def __init__(self, num_classes, in_channels=1, initial_filter_size=64, kernel_size=3, do_instancenorm=True):
        super().__init__()

        self.contr_1_1 = self.contract(in_channels, initial_filter_size, kernel_size, instancenorm=do_instancenorm)
        self.contr_1_2 = self.contract(initial_filter_size, initial_filter_size, kernel_size, instancenorm=do_instancenorm)
        self.pool = nn.MaxPool2d(2, stride=2)

        self.contr_2_1 = self.contract(initial_filter_size, initial_filter_size*2, kernel_size, instancenorm=do_instancenorm)
        self.contr_2_2 = self.contract(initial_filter_size*2, initial_filter_size*2, kernel_size, instancenorm=do_instancenorm)
        # self.pool2 = nn.MaxPool2d(2, stride=2)

        self.contr_3_1 = self.contract(initial_filter_size*2, initial_filter_size*2**2, kernel_size, instancenorm=do_instancenorm)
        self.contr_3_2 = self.contract(initial_filter_size*2**2, initial_filter_size*2**2, kernel_size, instancenorm=do_instancenorm)
        # self.pool3 = nn.MaxPool2d(2, stride=2)

        self.contr_4_1 = self.contract(initial_filter_size*2**2, initial_filter_size*2**3, kernel_size, instancenorm=do_instancenorm)
        self.contr_4_2 = self.contract(initial_filter_size*2**3, initial_filter_size*2**3, kernel_size, instancenorm=do_instancenorm)
        # self.pool4 = nn.MaxPool2d(2, stride=2)

        self.center = nn.Sequential(
            nn.Conv2d(initial_filter_size*2**3, initial_filter_size*2**4, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(initial_filter_size*2**4, initial_filter_size*2**4, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(initial_filter_size*2**4, initial_filter_size*2**3, 2, stride=2),
            nn.ReLU(inplace=True),
        )

        self.expand_4_1 = self.expand(initial_filter_size*2**4, initial_filter_size*2**3)
        self.expand_4_2 = self.expand(initial_filter_size*2**3, initial_filter_size*2**3)
        self.upscale4 = nn.ConvTranspose2d(initial_filter_size*2**3, initial_filter_size*2**2, kernel_size=2, stride=2)

        self.expand_3_1 = self.expand(initial_filter_size*2**3, initial_filter_size*2**2)
        self.expand_3_2 = self.expand(initial_filter_size*2**2, initial_filter_size*2**2)
        self.upscale3 = nn.ConvTranspose2d(initial_filter_size*2**2, initial_filter_size*2, 2, stride=2)

        self.expand_2_1 = self.expand(initial_filter_size*2**2, initial_filter_size*2)
        self.expand_2_2 = self.expand(initial_filter_size*2, initial_filter_size*2)
        self.upscale2 = nn.ConvTranspose2d(initial_filter_size*2, initial_filter_size, 2, stride=2)

        self.expand_1_1 = self.expand(initial_filter_size*2, initial_filter_size)
        self.expand_1_2 = self.expand(initial_filter_size, initial_filter_size)
        # Output layer for segmentation
        self.final = nn.Conv2d(initial_filter_size, num_classes, kernel_size=1)  # kernel size for final layer = 1, see paper

        self.softmax = torch.nn.Softmax2d()

        # Output layer for "autoencoder-mode"
        self.output_reconstruction_map = nn.Conv2d(initial_filter_size, out_channels=1, kernel_size=1)

    @staticmethod
    def contract(in_channels, out_channels, kernel_size=3, instancenorm=True):
        if instancenorm:
            layer = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
                nn.InstanceNorm2d(out_channels),
                nn.LeakyReLU(inplace=True))
        else:
            layer = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
                nn.LeakyReLU(inplace=True))
        return layer

    @staticmethod
    def expand(in_channels, out_channels, kernel_size=3):
        layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
            nn.LeakyReLU(inplace=True),
            )
        return layer

    @staticmethod
    def center_crop(layer, target_width, target_height):
        batch_size, n_channels, layer_width, layer_height = layer.size()
        xy1 = (layer_width - target_width) // 2
        xy2 = (layer_height - target_height) // 2
        return layer[:, :, xy1:(xy1 + target_width), xy2:(xy2 + target_height)]

    def forward(self, x, enable_concat=True, print_layer_shapes=False):
        concat_weight = 1
        if not enable_concat:
            concat_weight = 0

        contr_1 = self.contr_1_2(self.contr_1_1(x))
        pool = self.pool(contr_1)

        contr_2 = self.contr_2_2(self.contr_2_1(pool))
        pool = self.pool(contr_2)

        contr_3 = self.contr_3_2(self.contr_3_1(pool))
        pool = self.pool(contr_3)

        contr_4 = self.contr_4_2(self.contr_4_1(pool))
        pool = self.pool(contr_4)

        center = self.center(pool)

        crop = self.center_crop(contr_4, center.size()[2], center.size()[3])
        concat = torch.cat([center, crop*concat_weight], 1)

        expand = self.expand_4_2(self.expand_4_1(concat))
        upscale = self.upscale4(expand)

        crop = self.center_crop(contr_3, upscale.size()[2], upscale.size()[3])
        concat = torch.cat([upscale, crop*concat_weight], 1)

        expand = self.expand_3_2(self.expand_3_1(concat))
        upscale = self.upscale3(expand)

        crop = self.center_crop(contr_2, upscale.size()[2], upscale.size()[3])
        concat = torch.cat([upscale, crop*concat_weight], 1)

        expand = self.expand_2_2(self.expand_2_1(concat))
        upscale = self.upscale2(expand)

        crop = self.center_crop(contr_1, upscale.size()[2], upscale.size()[3])
        concat = torch.cat([upscale, crop*concat_weight], 1)

        expand = self.expand_1_2(self.expand_1_1(concat))

        if enable_concat:
            # segmentation mode: per-class logits
            output = self.final(expand)
        else:
            # "autoencoder-mode": single-channel reconstruction map
            output = self.output_reconstruction_map(expand)

        return output
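

# Illustrative sketch, not part of the original module: forward() pools four
# times and upsamples four times, cropping each skip connection to the size of
# the upsampled tensor. Assuming kernel_size=3 (so padding=1 preserves the
# spatial size), the output side length should be the input side length rounded
# down to a multiple of 16. The helper name below is hypothetical.
def unet_output_size_sketch(input_size, num_pools=4):
    """Side length of the output map for a given input side length."""
    size = input_size
    for _ in range(num_pools):
        size //= 2  # MaxPool2d(2, stride=2) floors odd sizes
    for _ in range(num_pools):
        size *= 2  # ConvTranspose2d with kernel 2 and stride 2 doubles the size
    return size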


================================================
FILE: requirements.txt
================================================
googledrivedownloader==0.4
MedPy==0.4.0
torch==1.3.1
torchfile==0.1.0
trixi==0.1.2.1
batchgenerators==0.19.3

# Workaround for scipy issues
scipy==1.1.0

# Workaround for slackclient version issues
slackclient==2.0.0

# Fix compatibility issues
torchvision==0.4.2



================================================
FILE: run_preprocessing.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from configs.Config_unet import get_config
from datasets.example_dataset.create_splits import create_splits
from datasets.utils import download_dataset
from datasets.example_dataset.preprocessing import preprocess_data

if __name__ == "__main__":
    c = get_config()

    download_dataset(dest_path=c.data_root_dir, dataset=c.dataset_name, id=c.google_drive_id)

    print('Preprocessing data. [STARTED]')
    preprocess_data(root_dir=os.path.join(c.data_root_dir, c.dataset_name))
    create_splits(output_dir=c.split_dir, image_dir=c.data_dir)
    print('Preprocessing data. [DONE]')


================================================
FILE: run_train_pipeline.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from os.path import exists
from configs.Config_unet import get_config
from datasets.example_dataset.create_splits import create_splits
from datasets.utils import download_dataset
from datasets.example_dataset.preprocessing import preprocess_data
from experiments.UNetExperiment import UNetExperiment


if __name__ == "__main__":
    c = get_config()

    # print("Executing: EPOCHS = {} / LEARNING RATE = {}".format(c.n_epochs, c.learning_rate))

    download_dataset(dest_path=c.data_root_dir, dataset=c.dataset_name, id=c.google_drive_id)

    if not exists(os.path.join(c.data_root_dir, c.dataset_name, 'preprocessed')):
        print('Preprocessing data. [STARTED]')
        preprocess_data(root_dir=os.path.join(c.data_root_dir, c.dataset_name), y_shape=c.patch_size, z_shape=c.patch_size)
        create_splits(output_dir=c.split_dir, image_dir=c.data_dir)
        print('Preprocessing data. [DONE]')
    else:
        print("The data has already been preprocessed and will not be preprocessed again. "
              "Delete the 'preprocessed' folder to force a new preprocessing run.")

    exp = UNetExperiment(config=c, name=c.name, n_epochs=c.n_epochs,
                         seed=42, append_rnd_to_name=c.append_rnd_string, globs=globals(),
                         # visdomlogger_kwargs={"auto_start": c.start_visdom},
                         loggers={
                             "visdom": ("visdom", {"auto_start": c.start_visdom})
                         }
                         )

    exp.run()
    exp.run_test(setup=False)


================================================
FILE: runner.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from configs.Config_unet_spleen import get_config
import subprocess

if __name__ == "__main__":
    c = get_config()
    n_epochs = c.n_epochs
    learning_rate = c.learning_rate
    step = 0

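    # Re-run the training pipeline until the process is stopped manually,
    # alternating between adding 20 epochs and halving the learning rate.
    while True:
        result = subprocess.run(['python', 'run_train_pipeline.py',
                                 '--n_epochs', '{}'.format(n_epochs),
                                 '--learning_rate', '{}'.format(learning_rate)])

        if step % 2 == 0:
            n_epochs = n_epochs + 20
        else:
            learning_rate = learning_rate / 2
        step += 1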
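

# Illustrative sketch, not part of the original script: the loop above
# alternates between adding 20 epochs and halving the learning rate after each
# run. The hypothetical generator below lists the settings that the first few
# runs would receive, without launching any subprocess.
def schedule_sketch(n_epochs, learning_rate, num_runs):
    """Yield the (n_epochs, learning_rate) pair used by each successive run."""
    for step in range(num_runs):
        yield n_epochs, learning_rate
        if step % 2 == 0:
            n_epochs += 20
        else:
            learning_rate /= 2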


================================================
FILE: segment_a_spleen.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

from medpy.io import save

from configs.Config_unet_spleen import get_config
from datasets.spleen.preprocessing import preprocess_single_file, postprocess_single_image
from experiments.UNetExperiment import UNetExperiment


def save_single_image(image, image_header, filename):
    # medpy.io.save
    save(image, filename, image_header)
    print('> Resulting Image stored as {}'.format(filename))


if __name__ == "__main__":
    c = get_config()

    if len(sys.argv) == 1:
        print("USAGE:\n\npython {} imagefilename [model_checkpoint [shapesize]]\n\n"
              "  imagefilename - a filename that stores a nii.gz formatted file.\n"
              "  model_checkpoint - a checkpoint filename to reload\n"
              "  shapesize - optional value that defines "
              "the size of the shape, default is 64 (not yet used).".format(sys.argv[0]))
        filename = "data/Task09_Spleen/imagesTs/spleen_15.nii.gz"
    else:
        filename = sys.argv[1]

    print("Loading and processing file {}".format(filename))

    if len(sys.argv) > 2:
        c.checkpoint_dir = sys.argv[2]
        c.do_load_checkpoint = True

    print("Loading model from checkpoint {}".format(c.model_dir))
    if len(c.model_dir) == 0 or not os.path.isdir(os.path.split(c.model_dir)[0]):
        print("ERROR /!\\: No checkpoint dir is set, please provide in Config file.")
        sys.exit(1)

    shapesize = 64
    if len(sys.argv) > 3:
        shapesize = int(sys.argv[3])

    # Get the header in order to preserve voxel dimensions to store the segmented image later on
    print('Preprocessing data.')
    data, header = preprocess_single_file(filename, y_shape=shapesize, z_shape=shapesize)

    print('Setting up model and start segmentation.')
    exp = UNetExperiment(config=c, name=c.name, n_epochs=c.n_epochs,
                         seed=42, append_rnd_to_name=c.append_rnd_string, globs=globals()
                         )

    result = exp.segment_single_image(data)

    print('Postprocessing data.')
    result = postprocess_single_image(result)

    pathname, fname = os.path.split(filename)
    # os.path.join keeps the result relative when the input has no directory part
    destination_filename = os.path.join(pathname, "segmented_" + fname)
    print('Saving file to disk: {}'.format(destination_filename))
    save_single_image(result, header, destination_filename)


================================================
FILE: train.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import matplotlib
matplotlib.use('Agg')

from configs.Config_unet import get_config
from experiments.UNetExperiment import UNetExperiment

if __name__ == "__main__":
    c = get_config()

    exp = UNetExperiment(config=c, name=c.name, n_epochs=c.n_epochs,
                         seed=42, append_rnd_to_name=c.append_rnd_string, globs=globals(),
                         # visdomlogger_kwargs={"auto_start": c.start_visdom},
                         loggers={
                             "visdom": ("visdom", {"auto_start": c.start_visdom}),
                             # "tb": ("tensorboard"),
                             # "slack": ("slack", {"token": "XXXXXXXX",
                             #                     "user_email": "x"})
                         }
                         )

    exp.run()
    exp.run_test(setup=False)


================================================
FILE: train3D.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from configs.Config_unet import get_config
from experiments.UNetExperiment3D import UNetExperiment3D

if __name__ == "__main__":
    c = get_config()

    exp = UNetExperiment3D(config=c, name=c.name, n_epochs=c.n_epochs,
                         seed=42, append_rnd_to_name=c.append_rnd_string, globs=globals(),
                         # visdomlogger_kwargs={"auto_start": c.start_visdom},
                         loggers={
                             "visdom": ("visdom", {"auto_start": c.start_visdom}),
                             # "tb": ("tensorboard"),
                             # "slack": ("slack", {"token": "XXXXXXXX",
                             #                     "user_email": "x"})
                         }
                         )
    exp.run()
    exp.run_test(setup=False)


================================================
FILE: utilities/__init__.py
================================================


================================================
FILE: utilities/file_and_folder_operations.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os


def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
    if join:
        l = os.path.join
    else:
        l = lambda x, y: y
    res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))
            and (prefix is None or i.startswith(prefix))
            and (suffix is None or i.endswith(suffix))]
    if sort:
        res.sort()
    return res


def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
    if join:
        l = os.path.join
    else:
        l = lambda x, y: y
    res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
            and (prefix is None or i.startswith(prefix))
            and (suffix is None or i.endswith(suffix))]
    if sort:
        res.sort()
    return res


def maybe_mkdir_p(directory):
    # Create the directory and any missing parents; works for absolute and
    # relative paths and does nothing if the directory already exists.
    os.makedirs(directory, exist_ok=True)

SYMBOL INDEX (169 symbols across 24 files)

FILE: configs/Config_unet.py
  function get_config (line 23) | def get_config():

FILE: configs/Config_unet_spleen.py
  function get_config (line 23) | def get_config():

FILE: datasets/data_loader.py
  class WrappedDataset (line 22) | class WrappedDataset(Dataset):
    method __init__ (line 23) | def __init__(self, dataset, transform):
    method __getitem__ (line 31) | def __getitem__(self, index):
    method __len__ (line 40) | def __len__(self):
  class MultiThreadedDataLoader (line 44) | class MultiThreadedDataLoader(object):
    method __init__ (line 45) | def __init__(self, data_loader, transform, num_processes, **kwargs):
    method get_worker_init_fn (line 57) | def get_worker_init_fn(self):
    method __iter__ (line 63) | def __iter__(self):
    method __next__ (line 68) | def __next__(self):
    method renew (line 73) | def renew(self):
    method restart (line 79) | def restart(self):
    method kill_iterator (line 83) | def kill_iterator(self):

FILE: datasets/example_dataset/create_splits.py
  function create_splits (line 25) | def create_splits(output_dir, image_dir):
  function splits_sanity_check (line 70) | def splits_sanity_check(path):

FILE: datasets/example_dataset/preprocessing.py
  function preprocess_data (line 29) | def preprocess_data(root_dir, y_shape=64, z_shape=64):
  function preprocess_single_file (line 74) | def preprocess_single_file(image_file):
  function postprocess_single_image (line 83) | def postprocess_single_image(image):

FILE: datasets/spleen/create_splits.py
  function create_splits (line 25) | def create_splits(output_dir, image_dir):

FILE: datasets/spleen/preprocessing.py
  function preprocess_data (line 28) | def preprocess_data(root_dir, y_shape=64, z_shape=64):
  function preprocess_single_file (line 68) | def preprocess_single_file(image_file):
  function postprocess_single_image (line 87) | def postprocess_single_image(image):

FILE: datasets/three_dim/NumpyDataLoader.py
  function load_dataset (line 29) | def load_dataset(base_dir, pattern='*.npy', keys=None):
  class NumpyDataSet (line 52) | class NumpyDataSet(object):
    method __init__ (line 56) | def __init__(self, base_dir, mode="train", batch_size=16, num_batches=...
    method __len__ (line 73) | def __len__(self):
    method __iter__ (line 76) | def __iter__(self):
    method __next__ (line 82) | def __next__(self):
  class NumpyDataLoader (line 86) | class NumpyDataLoader(SlimDataLoaderBase):
    method __init__ (line 87) | def __init__(self, base_dir, mode="train", batch_size=16, num_batches=...
    method reshuffle (line 112) | def reshuffle(self):
    method generate_train_batch (line 117) | def generate_train_batch(self):
    method __len__ (line 121) | def __len__(self):
    method __getitem__ (line 125) | def __getitem__(self, item):
    method get_data_from_array (line 150) | def get_data_from_array(self, open_array):

FILE: datasets/three_dim/data_augmentation.py
  function get_transforms (line 24) | def get_transforms(mode="train", target_size=128):

FILE: datasets/two_dim/NumpyDataLoader.py
  function load_dataset (line 29) | def load_dataset(base_dir, pattern='*.npy', slice_offset=5, keys=None):
  class NumpyDataSet (line 52) | class NumpyDataSet(object):
    method __init__ (line 56) | def __init__(self, base_dir, mode="train", batch_size=16, num_batches=...
    method __len__ (line 73) | def __len__(self):
    method __iter__ (line 76) | def __iter__(self):
    method __next__ (line 82) | def __next__(self):
  class NumpyDataLoader (line 86) | class NumpyDataLoader(SlimDataLoaderBase):
    method __init__ (line 87) | def __init__(self, base_dir, mode="train", batch_size=16, num_batches=...
    method reshuffle (line 112) | def reshuffle(self):
    method generate_train_batch (line 117) | def generate_train_batch(self):
    method __len__ (line 121) | def __len__(self):
    method __getitem__ (line 125) | def __getitem__(self, item):
    method get_data_from_array (line 150) | def get_data_from_array(self, open_array):

FILE: datasets/two_dim/data_augmentation.py
  function get_transforms (line 26) | def get_transforms(mode="train", target_size=128):

FILE: datasets/utils.py
  function download_dataset (line 25) | def download_dataset(dest_path, dataset, id=''):

FILE: evaluation/evaluator.py
  class Evaluator (line 15) | class Evaluator:
    method __init__ (line 45) | def __init__(self,
    method set_test (line 84) | def set_test(self, test):
    method set_reference (line 89) | def set_reference(self, reference):
    method set_labels (line 94) | def set_labels(self, labels):
    method construct_labels (line 110) | def construct_labels(self):
    method set_metrics (line 122) | def set_metrics(self, metrics):
    method add_metric (line 132) | def add_metric(self, metric):
    method evaluate (line 137) | def evaluate(self, test=None, reference=None, advanced=False, **metric...
    method to_dict (line 210) | def to_dict(self):
    method to_array (line 216) | def to_array(self):
    method to_pandas (line 237) | def to_pandas(self):
  class NiftiEvaluator (line 252) | class NiftiEvaluator(Evaluator):
    method __init__ (line 254) | def __init__(self, *args, **kwargs):
    method set_test (line 260) | def set_test(self, test):
    method set_reference (line 270) | def set_reference(self, reference):
    method evaluate (line 280) | def evaluate(self, test=None, reference=None, voxel_spacing=None, **me...
  function aggregate_scores (line 289) | def aggregate_scores(test_ref_pairs,
  function aggregate_scores_for_experiment (line 376) | def aggregate_scores_for_experiment(score_file,

FILE: evaluation/metrics.py
  function assert_shape (line 5) | def assert_shape(test, reference):
  class ConfusionMatrix (line 11) | class ConfusionMatrix:
    method __init__ (line 13) | def __init__(self, test=None, reference=None):
    method set_test (line 27) | def set_test(self, test):
    method set_reference (line 32) | def set_reference(self, reference):
    method reset (line 37) | def reset(self):
    method compute (line 49) | def compute(self):
    method get_matrix (line 66) | def get_matrix(self):
    method get_size (line 75) | def get_size(self):
    method get_existence (line 81) | def get_existence(self):
  function dice (line 91) | def dice(test=None, reference=None, confusion_matrix=None, nan_for_nonex...
  function jaccard (line 109) | def jaccard(test=None, reference=None, confusion_matrix=None, nan_for_no...
  function precision (line 127) | def precision(test=None, reference=None, confusion_matrix=None, nan_for_...
  function sensitivity (line 145) | def sensitivity(test=None, reference=None, confusion_matrix=None, nan_fo...
  function recall (line 163) | def recall(test=None, reference=None, confusion_matrix=None, nan_for_non...
  function specificity (line 169) | def specificity(test=None, reference=None, confusion_matrix=None, nan_fo...
  function accuracy (line 187) | def accuracy(test=None, reference=None, confusion_matrix=None, **kwargs):
  function fscore (line 198) | def fscore(test=None, reference=None, confusion_matrix=None, nan_for_non...
  function false_positive_rate (line 208) | def false_positive_rate(test=None, reference=None, confusion_matrix=None...
  function false_omission_rate (line 214) | def false_omission_rate(test=None, reference=None, confusion_matrix=None...
  function false_negative_rate (line 232) | def false_negative_rate(test=None, reference=None, confusion_matrix=None...
  function true_negative_rate (line 238) | def true_negative_rate(test=None, reference=None, confusion_matrix=None,...
  function false_discovery_rate (line 244) | def false_discovery_rate(test=None, reference=None, confusion_matrix=Non...
  function negative_predictive_value (line 250) | def negative_predictive_value(test=None, reference=None, confusion_matri...
  function total_positives_test (line 256) | def total_positives_test(test=None, reference=None, confusion_matrix=Non...
  function total_negatives_test (line 267) | def total_negatives_test(test=None, reference=None, confusion_matrix=Non...
  function total_positives_reference (line 278) | def total_positives_reference(test=None, reference=None, confusion_matri...
  function total_negatives_reference (line 289) | def total_negatives_reference(test=None, reference=None, confusion_matri...
  function hausdorff_distance (line 300) | def hausdorff_distance(test=None, reference=None, confusion_matrix=None,...
  function hausdorff_distance_95 (line 318) | def hausdorff_distance_95(test=None, reference=None, confusion_matrix=No...
  function avg_surface_distance (line 336) | def avg_surface_distance(test=None, reference=None, confusion_matrix=Non...
  function avg_surface_distance_symmetric (line 354) | def avg_surface_distance_symmetric(test=None, reference=None, confusion_...

FILE: experiments/UNetExperiment.py
  class UNetExperiment (line 34) | class UNetExperiment(PytorchExperiment):
    method setup (line 53) | def setup(self):
    method train (line 92) | def train(self, epoch):
    method validate (line 131) | def validate(self, epoch):
    method test (line 160) | def test(self):
    method segment_single_image (line 197) | def segment_single_image(self, data):

FILE: experiments/UNetExperiment3D.py
  class UNetExperiment3D (line 34) | class UNetExperiment3D(PytorchExperiment):
    method setup (line 53) | def setup(self):
    method train (line 92) | def train(self, epoch):
    method validate (line 126) | def validate(self, epoch):
    method test (line 156) | def test(self):

FILE: loss_functions/ND_Crossentropy.py
  class CrossentropyND (line 4) | class CrossentropyND(torch.nn.CrossEntropyLoss):
    method forward (line 8) | def forward(self, inp, target):

FILE: loss_functions/dice_loss.py
  function softmax_helper (line 8) | def softmax_helper(x):
  function get_tp_fp_fn (line 16) | def get_tp_fp_fn(net_output, gt, axes=None, mask=None):
  function sum_tensor (line 63) | def sum_tensor(inp, axes, keepdim=False):
  function mean_tensor (line 74) | def mean_tensor(inp, axes, keepdim=False):
  class SoftDiceLoss (line 85) | class SoftDiceLoss(nn.Module):
    method __init__ (line 86) | def __init__(self, smooth=1., apply_nonlin=None, batch_dice=False, do_...
    method forward (line 112) | def forward(self, x, y):
  function soft_dice_per_batch (line 140) | def soft_dice_per_batch(net_output, gt, smooth=1., smooth_in_nom=1., bac...
  function soft_dice_per_batch_2 (line 152) | def soft_dice_per_batch_2(net_output, gt, smooth=1., smooth_in_nom=1., b...
  function soft_dice (line 173) | def soft_dice(net_output, gt, smooth=1., smooth_in_nom=1.):
  class MultipleOutputLoss (line 181) | class MultipleOutputLoss(nn.Module):
    method __init__ (line 182) | def __init__(self, loss, weight_factors=None):
    method forward (line 192) | def forward(self, x, y):
  class DC_and_CE_loss (line 204) | class DC_and_CE_loss(nn.Module):
    method __init__ (line 205) | def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum"):
    method forward (line 211) | def forward(self, net_output, target):
  class DC_and_topk_loss (line 221) | class DC_and_topk_loss(nn.Module):
    method __init__ (line 222) | def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum"):
    method forward (line 228) | def forward(self, net_output, target):
  class CrossentropyWithLossMask (line 238) | class CrossentropyWithLossMask(nn.CrossEntropyLoss):
    method __init__ (line 239) | def __init__(self, k=None):
    method forward (line 247) | def forward(self, inp, target, loss_mask=None):

FILE: loss_functions/topk_loss.py
  class TopKLoss (line 6) | class TopKLoss(CrossentropyND):
    method __init__ (line 11) | def __init__(self, weight=None, ignore_index=-100, k=10):
    method forward (line 15) | def forward(self, inp, target):

FILE: networks/RecursiveUNet.py
  class UNet (line 28) | class UNet(nn.Module):
    method __init__ (line 29) | def __init__(self, num_classes=3, in_channels=1, initial_filter_size=6...
    method forward (line 46) | def forward(self, x):
  class UnetSkipConnectionBlock (line 53) | class UnetSkipConnectionBlock(nn.Module):
    method __init__ (line 54) | def __init__(self, in_channels=None, out_channels=None, num_classes=1,...
    method contract (line 90) | def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.I...
    method expand (line 98) | def expand(in_channels, out_channels, kernel_size=3):
    method center_crop (line 106) | def center_crop(layer, target_width, target_height):
    method forward (line 112) | def forward(self, x):

FILE: networks/RecursiveUNet3D.py
  class UNet3D (line 28) | class UNet3D(nn.Module):
    method __init__ (line 29) | def __init__(self, num_classes=3, in_channels=1, initial_filter_size=6...
    method forward (line 46) | def forward(self, x):
  class UnetSkipConnectionBlock (line 53) | class UnetSkipConnectionBlock(nn.Module):
    method __init__ (line 54) | def __init__(self, in_channels=None, out_channels=None, num_classes=1,...
    method contract (line 90) | def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.I...
    method expand (line 98) | def expand(in_channels, out_channels, kernel_size=3):
    method center_crop (line 106) | def center_crop(layer, target_depth, target_width, target_height):
    method forward (line 113) | def forward(self, x):

FILE: networks/UNET.py
  class UNet (line 22) | class UNet(nn.Module):
    method __init__ (line 24) | def __init__(self, num_classes, in_channels=1, initial_filter_size=64,...
    method contract (line 75) | def contract(in_channels, out_channels, kernel_size=3, instancenorm=Tr...
    method expand (line 88) | def expand(in_channels, out_channels, kernel_size=3):
    method center_crop (line 96) | def center_crop(layer, target_width, target_height):
    method forward (line 102) | def forward(self, x, enable_concat=True, print_layer_shapes=False):

FILE: segment_a_spleen.py
  function save_single_image (line 28) | def save_single_image(image, image_header, filename):

FILE: utilities/file_and_folder_operations.py
  function subdirs (line 21) | def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
  function subfiles (line 34) | def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
  function maybe_mkdir_p (line 47) | def maybe_mkdir_p(directory):
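
A minimal usage sketch for the `subfiles` helper indexed above, assuming a hypothetical folder that holds a mix of `.npy` files, other files and subdirectories. The function body is a local copy, repeated only so the snippet runs on its own; the file names and the `demo` function are illustrative and not part of the repository.

```python
import os
import tempfile


def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
    # Local copy of the helper from utilities/file_and_folder_operations.py,
    # repeated here only so the sketch is self-contained.
    combine = os.path.join if join else (lambda _folder, name: name)
    res = [combine(folder, name) for name in os.listdir(folder)
           if os.path.isfile(os.path.join(folder, name))
           and (prefix is None or name.startswith(prefix))
           and (suffix is None or name.endswith(suffix))]
    if sort:
        res.sort()
    return res


def demo():
    # Hypothetical layout: two .npy files, one .txt file and one subdirectory.
    with tempfile.TemporaryDirectory() as root:
        for name in ("case_01.npy", "case_00.npy", "notes.txt"):
            open(os.path.join(root, name), "w").close()
        os.makedirs(os.path.join(root, "preprocessed"), exist_ok=True)

        # join=False returns bare file names; the default returns joined paths.
        names = subfiles(root, join=False, suffix=".npy")
        full_paths = subfiles(root, suffix=".npy")
        return names, [os.path.relpath(p, root) for p in full_paths]
```

With `join=False` the result contains only file names, which is convenient when the same names are later joined onto a different output folder; directories are skipped because the filter uses `os.path.isfile`.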
