[
  {
    "path": ".gitignore",
    "content": ".idea\n*.pyc\n.DS_Store\n*.egg-info\n.pytest_cache/*\n\ndata\noutput_experiment\nvenv\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"{}\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright {yyyy} {name of copyright owner}\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License."
  },
  {
    "path": "Readme.md",
    "content": "﻿# Basic U-Net example by MIC@DKFZ\nCopyright © German Cancer Research Center (DKFZ), Division of Medical Image Computing (MIC). Please make sure that your usage of this code is in compliance with the code license:\n[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/MIC-DKFZ/basic_unet_example/blob/master/LICENSE)\n\nThis python code is an example project of how to use a U-Net [1] for segmentation on medical images using PyTorch (https://www.pytorch.org).\nIt was developed at the Division of Medical Image Computing at the German Cancer Research Center (DKFZ).\nIt is also an example of how to use our other python packages batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) and \nTrixi (https://github.com/MIC-DKFZ/trixi) [2] to suit all our deep learning data augmentation needs.\n\nIf you have any questions or issues or you encounter a bug, feel free to contact us, open a GitHub issue or ask the community on Gitter:\n[![Gitter](https://badges.gitter.im/basic-Unet/community.svg)](https://gitter.im/basic-Unet/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)\n\n> **WARNING**: This repo was implemented and tested on Linux. We highly recommend using it within a Linux environment. If you use Windows you might experience some issues (see\n> section \"Errors and how to handle them\")\n\n## How to set it up\nThe example is very easy to use. Just create a new virtual environment in python and install the requirements. \nThis example requires python3. It was implemented with python 3.5. \n\n> **WARNING**: The newest supported version is python 3.7.9. For newer python versions there are some requirements that are not available in the needed version.\n```\npip3 install -r requirements.txt\n```\n\nIn this example code, we show how to use visdom for live visualization. 
See the Trixi documentation for more details or information about other tools like tensorboard.\nAfter setting up the virtual environment you have to start visdom once so it can download some needed files. You only\nhave to do that once. You can stop the visdom server after a few seconds once it has finished downloading the files.\n```\npython3 -m visdom.server\n```\n\nYou can edit the paths for data storage and logging in the config file. By default, everything is stored in your working directory.\n\n\n## How to use it\nTo start the training simply run \n```\npython3 run_train_pipeline.py\n```\n\nThis will download the Hippocampus dataset from the medical segmentation decathlon (http://medicaldecathlon.com),\nextract and preprocess it and then start the training. The preprocessing loads the images (imagesTr) and the corresponding labels (labelsTr), performs some normalization and padding operations and saves the data as NPY files. The available images are then split into `train`, `validation` and `test` sets.\nThe splits are saved to a `splits.pkl` file. The images in `imagesTs` are not used in the example, because they are the test set for the medical segmentation decathlon and\n therefore no ground truth is provided.\n\nIf you run the pipeline again, the dataset will not be downloaded, extracted or preprocessed again. To enforce it, just delete the folder.\n\nThe training process will automatically be visualized using trixi/visdom. After starting the training you navigate in your browser to the port which is printed by the training script. Then you should see your loss curve and so on.\n\nBy default, a 2-dimensional U-Net is used. The example also comes with a 3-D version of the network (Özgün Cicek et al.).\nTo use the 3-D version, simply use\n```\npython train3D.py\n```\n\n> **WARNING**: The 3-D version is not yet tested thoroughly. Use it with caution!\n\n## How to use it for your own data\nThis description is work in progress. 
If you use this repo for your own data please share your experience, so we can update this part.\n\n### Config\n\nThe included `Config_unet.py` is an example config file. You have to adapt this to fit your local environment, e.g., if you run out of CUDA memory, try to reduce `batch_size` or\n `patch_size`. All the other parameters should be self-explanatory or described directly in the code comments. \n\nChoose the `#Train parameters` to fit both, your data and your workstation. \nWith `fold` you can choose which split from your `splits.pkl` you want to use for the training.\n\nYou may also need to adapt the paths (`data_root_dir, data_dir, data_test_dir and split_dir`).\n\nYou can change the `Logging parameters` if you want to. With `append_rnd_string`, you can give each experiment you start a unique name.\nIf you want to start your visdom server manually, just set `start_visdom=False`. If you do not want to use visdom logging at all, just remove the visdom logger from your\n experiment, e.g. `run_train_pipeline.py` line 47:\n \n ```\n loggers={\n       \"visdom\": (\"visdom\", {\"auto_start\": c.start_visdom})\n }\n ```\n\n### Datasets\nIf you want to use the provided DataLoader, you need to preprocess your data appropriately. An example can be found in the \n\"example_dataset\" folder. Make sure to load your images and your labels as numpy arrays. The required shape is `(#slices, w,h)`. \nThen save both using:\n```\nresult = np.stack((image, label))\n\nnp.save(output_filename, result)\n```\n\nThe provided DataLoader requires a splits.pkl file, that contains a dictionary of all the files used for training, validation and testing.\nIt looks like this:\n```\n[{'train': ['dataset_name_1',...], 'val': ['dataset_name_2', ...], 'test': ['dataset_name_3', ...]}]\n```\n\nWe use the `MIC/batchgenerators` to perform data augmentation. 
The example uses cropping, mirroring and some elastic spatial transformation.\nYou can change the data augmentation by editing the `data_augmentation.py`. Please see the `MIC/batchgenerators` documentation for more details.\n\nTo train your network, simply run\n```\npython train.py\n```\n\nYou can either edit the config file or add command line parameters like this:\n```\npython train.py --n_epochs 100 [...]\n```\n\n## Networks\nThis example contains a simple implementation of the U-Net [1], which can be found in `networks>UNET.py`. \nA slightly more generic version of the U-Net, as well as the 3D U-Net [3], can be found in `networks>RecursiveUNet.py` \nand `networks>RecursiveUNet3D.py`, respectively. This implementation is done recursively.\nIt is therefore very easy to configure the number of downsamplings. Also, the type of normalization can be passed as a parameter (default is nn.InstanceNorm2d).\n\n## Errors and how to handle them\nIn this section, we want to collect common errors that may occur when using this repository.\nIf you encounter something, feel free to let us know about it and we will include it here.\n\n### Windows related issues\n\nIf you want to use this repo on Windows, please note that you have to adapt a few things.\nWe recommend installing PyTorch via conda on Windows using: `python -m conda install pytorch torchvision cpuonly -c pytorch`\nYou then have to remove torch from the requirements.txt.\n\nIf you run into issues like the following one:\n\n ```\nAttributeError: Can't pickle local object 'MultiThreadedDataLoader.get_worker_init_fn.<locals>.init_fn'`\n ```\n\ntry to use SingleProcessDataLoader instead. 
This error is probably caused by how multithreading is handled in python on Windows.\nTo fix this, add `num_processes=0` to your dataloaders:\n\n ```\nself.train_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size, \n                                        batch_size=self.config.batch_size, keys=tr_keys, num_processes=0)\nself.val_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size, \n                                        batch_size=self.config.batch_size, keys=val_keys, mode=\"val\", do_reshuffle=False, num_processes=0)\nself.test_data_loader = NumpyDataSet(self.config.data_test_dir, target_size=self.config.patch_size, \n                                        batch_size=self.config.batch_size, keys=test_keys, mode=\"test\", do_reshuffle=False, num_processes=0)\n ```\n\n### Multiple Labels\nDepending on your dataset you might be dealing with multiple labels. For example, the\ndata from BRATS (https://www.med.upenn.edu/sbia/brats2017.html) has the following labels:\n ```\n \"labels\": {\n\t \"0\": \"background\",\n\t \"1\": \"edema\",\n\t \"2\": \"non-enhancing tumor\",\n\t \"3\": \"enhancing tumour\"\n },\n ```\n* If you run into an error like this:\n    ```\n    Experiment exited. Checkpoints stored =)\n    INFO:default-z3HafHO4CS:Experiment exited. 
Checkpoints stored =)\n    Unhandled exception in thread started by <function PytorchExperimentLogger.save_checkpoint_static at 0x7fd07c3e8510>\n    Traceback (most recent call last):\n      File \"/python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py\", line 196, in save_checkpoint_static\n       torch.save(to_cpu(kwargs), checkpoint_file)\n      File \"/python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py\", line 191, in to_cpu\n        return {key: to_cpu(val) for key, val in obj.items()}\n      File \"//python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py\", line 191, in <dictcomp>\n        return {key: to_cpu(val) for key, val in obj.items()}\n      File \"/python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py\", line 191, in to_cpu\n        return {key: to_cpu(val) for key, val in obj.items()}\n      File \"/python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py\", line 191, in <dictcomp>\n        return {key: to_cpu(val) for key, val in obj.items()}\n      File \"/python3.5/site-packages/trixi/logger/experiment/pytorchexperimentlogger.py\", line 189, in to_cpu\n        return obj.cpu()\n    RuntimeError: CUDA error: device-side assert triggered\n    ```\n    make sure you updated `num_classes` in your config file. 
The value of `num_classes` should always\n    equal the number of your labels including background.\n\n* If you run into an error like this:\n    ```\n    File \"/home/student/basic_unet/trixi/trixi/experiment/experiment.py\", line 108, in run\n      self.process_err(e)\n    File \"/home/student/basic_unet/trixi/trixi/experiment/pytorchexperiment.py\", line 391, in process_err\n      raise e\n    File \"/home/student/basic_unet/trixi/trixi/experiment/experiment.py\", line 89, in run\n      self.train(epoch=self._epoch_idx)\n    File \"/home/student/PycharmProjects/new_unet/experiments/UNetExperiment.py\", line 113, in train\n      loss = self.dice_loss(pred_softmax, target.squeeze()) + self.ce_loss(pred, target.squeeze())\n    File \"/opt/anaconda3/envs/a_new_test/lib/python3.6/site-packages/torch/nn/modules/module.py\", line 493, in call\n      result = self.forward(input, *kwargs)\n    File \"/home/student/PycharmProjects/new_unet/loss_functions/dice_loss.py\", line 125, in forward\n      yonehot.scatter(1, y, 1)\n    RuntimeError: Invalid index in scatter at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:551\n    ```\n    make sure to check your labels again. The error may be caused by the fact that the labels are not sequential. This causes `scatter` to crash. Consider changing the values of      your labels.\n\n## References\n[1] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" \nInternational Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.\n[2] David Zimmerer, Jens Petersen, GregorKoehler, Jakob Wasserthal, dzimmm, Tim, … André Pequeño. (2018, November 23). MIC-DKFZ/trixi: Alpha (Version v0.1.1). \nZenodo. http://doi.org/10.5281/zenodo.1495180\n[3] Çiçek, Özgün, et al. \"3D U-Net: learning dense volumetric segmentation from sparse annotation.\" \nInternational conference on medical image computing and computer-assisted intervention. 
Springer, Cham, 2016.\n\n"
  },
  {
    "path": "__init__.py",
    "content": ""
  },
  {
    "path": "configs/Config_unet.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nfrom trixi.util import Config\n\n\ndef get_config():\n    # Set your own path, if needed.\n    data_root_dir = os.path.abspath('data')  # The path where the downloaded dataset is stored.\n\n    c = Config(\n        update_from_argv=True,  # If set 'True', it allows to update each configuration by a cmd/terminal parameter.\n\n        # Train parameters\n        num_classes=3,\n        in_channels=1,\n        batch_size=8,\n        patch_size=64,\n        n_epochs=10,\n        learning_rate=0.0002,\n        fold=0,  # The 'splits.pkl' may contain multiple folds. Here we choose which one we want to use.\n\n        device=\"cuda\",  # 'cuda' is the default CUDA device, you can use also 'cpu'. 
For more information, see https://pytorch.org/docs/stable/notes/cuda.html\n\n        # Logging parameters\n        name='Basic_Unet',\n        author='kleina',  # Author of this project\n        plot_freq=10,  # How often should stuff be shown in visdom\n        append_rnd_string=False,  # Appends a random string to the experiment name to make it unique.\n        start_visdom=True,  # You can either start a visdom server manually or have trixi start it for you.\n\n        do_instancenorm=True,  # Defines whether or not the UNet does an instance normalization in the contracting path\n        do_load_checkpoint=False,\n        checkpoint_dir='',\n\n        # Adapt to your own path, if needed.\n        google_drive_id='1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C',  # This id is used to download the example dataset.\n        dataset_name='Task04_Hippocampus',\n        base_dir=os.path.abspath('output_experiment'),  # Where to log the output of the experiment.\n\n        data_root_dir=data_root_dir,  # The path where the downloaded dataset is stored.\n        data_dir=os.path.join(data_root_dir, 'Task04_Hippocampus/preprocessed'),  # This is where your training and validation data is stored\n        data_test_dir=os.path.join(data_root_dir, 'Task04_Hippocampus/preprocessed'),  # This is where your test data is stored\n\n        split_dir=os.path.join(data_root_dir, 'Task04_Hippocampus'),  # This is where the 'splits.pkl' file is located, that holds your splits.\n\n        # execute a segmentation process on a specific image using the model\n        model_dir=os.path.join(os.path.abspath('output_experiment'), ''),  # the model being used for segmentation\n    )\n\n    print(c)\n    return c\n"
  },
  {
    "path": "configs/Config_unet_spleen.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nfrom trixi.util import Config\n\n\ndef get_config():\n\n    # Set your own path, if needed.\n    data_root_dir = os.path.abspath('data')  # The path where the downloaded dataset is stored.\n\n    c = Config(\n        update_from_argv=True,  # If set 'True', it allows to update each configuration by a cmd/terminal parameter.\n\n        # Train parameters\n        num_classes=2,\n        in_channels=1,\n        batch_size=3,       # works with 6 on GB GPU\n        patch_size=512,\n        n_epochs=1,\n        learning_rate=0.0002,\n        fold=0,  # The 'splits.pkl' may contain multiple folds. Here we choose which one we want to use.\n\n        device=\"cuda\",  # 'cuda' is the default CUDA device, you can use also 'cpu'. 
For more information, see https://pytorch.org/docs/stable/notes/cuda.html\n\n        # Logging parameters\n        name='Basic_Unet',\n        author='kleina',  # Author of this project\n        plot_freq=10,  # How often should stuff be shown in visdom\n        append_rnd_string=False,  # Appends a random string to the experiment name to make it unique.\n        start_visdom=True,  # You can either start a visdom server manually or have trixi start it for you.\n\n        do_instancenorm=True,  # Defines whether or not the UNet does an instance normalization in the contracting path\n        do_load_checkpoint=False,\n        checkpoint_dir='',\n\n        # Adapt to your own path, if needed.\n        google_drive_id='1jzeNU1EKnK81PyTsrx0ujfNl-t0Jo8uE', #spleen\n        dataset_name='Task09_Spleen',\n        base_dir=os.path.abspath('output_experiment'),  # Where to log the output of the experiment.\n\n        data_root_dir=data_root_dir,  # The path where the downloaded dataset is stored.\n        data_dir=os.path.join(data_root_dir, 'Task09_Spleen/preprocessed'),  # This is where your training and validation data is stored\n        data_test_dir=os.path.join(data_root_dir, 'Task09_Spleen/preprocessed'),  # This is where your test data is stored\n\n        split_dir=os.path.join(data_root_dir, 'Task09_Spleen'),  # This is where the 'splits.pkl' file is located, that holds your splits.\n\n        # execute a segmentation process on a specific image using the model\n        model_dir=os.path.join(os.path.abspath('output_experiment'), '20200108-035420_Basic_Unet/checkpoint/checkpoint_current'),   # the model being used for segmentation\n    )\n\n    print(c)\n    return c\n"
  },
  {
    "path": "configs/__init__.py",
    "content": ""
  },
  {
    "path": "datasets/__init__.py",
    "content": ""
  },
  {
    "path": "datasets/data_loader.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom torch.utils.data import DataLoader, Dataset\nfrom trixi.util.pytorchutils import set_seed\n\n\nclass WrappedDataset(Dataset):\n    def __init__(self, dataset, transform):\n        self.transform = transform\n        self.dataset = dataset\n\n        self.is_indexable = False\n        if hasattr(self.dataset, \"__getitem__\") and not (hasattr(self.dataset, \"use_next\") and self.dataset.use_next is True):\n            self.is_indexable = True\n\n    def __getitem__(self, index):\n\n        if not self.is_indexable:\n            item = next(self.dataset)\n        else:\n            item = self.dataset[index]\n        item = self.transform(**item)\n        return item\n\n    def __len__(self):\n        return int(self.dataset.num_batches)\n\n\nclass MultiThreadedDataLoader(object):\n    def __init__(self, data_loader, transform, num_processes, **kwargs):\n\n        self.cntr = 1\n        self.ds_wrapper = WrappedDataset(data_loader, transform)\n\n        self.generator = DataLoader(self.ds_wrapper, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,\n                                    num_workers=num_processes, pin_memory=True, drop_last=False,\n                                    worker_init_fn=self.get_worker_init_fn())\n\n        
self.num_processes = num_processes\n        self.iter = None\n\n    def get_worker_init_fn(self):\n        def init_fn(worker_id):\n            set_seed(worker_id + self.cntr)\n\n        return init_fn\n\n    def __iter__(self):\n        self.kill_iterator()\n        self.iter = iter(self.generator)\n        return self.iter\n\n    def __next__(self):\n        if self.iter is None:\n            self.iter = iter(self.generator)\n        return next(self.iter)\n\n    def renew(self):\n        self.cntr += 1\n        self.kill_iterator()\n        self.generator.worker_init_fn = self.get_worker_init_fn()\n        self.iter = iter(self.generator)\n\n    def restart(self):\n        pass\n        # self.iter = iter(self.generator)\n\n    def kill_iterator(self):\n        try:\n            if self.iter is not None:\n                self.iter._shutdown_workers()\n                for p in self.iter.workers:\n                    p.terminate()\n        except:\n            print(\"Could not kill Dataloader Iterator\")\n"
  },
  {
    "path": "datasets/example_dataset/__init__.py",
    "content": ""
  },
  {
    "path": "datasets/example_dataset/create_splits.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pickle\nfrom utilities.file_and_folder_operations import subfiles\n\nimport os\nimport random\n\n\ndef create_splits(output_dir, image_dir):\n    \"\"\"File to split the dataset into multiple folds and the train, validation and test set.\n\n    :param output_dir: Directory to write the splits file to\n    :param image_dir: Directory where the images lie in.\n    \"\"\"\n    npy_files = subfiles(image_dir, suffix=\".npy\", join=False)\n    sample_size = len(npy_files)\n\n    testset_size = int(sample_size * 0.25)\n    valset_size = int(sample_size * 0.25)\n    trainset_size = sample_size - valset_size - testset_size  # Assure all samples are used.\n\n    if sample_size < (trainset_size + valset_size + testset_size):\n        raise ValueError(\"Assure more total samples exist than train test and val samples combined!\")\n\n    splits = []\n    sample_set = {sample[:-4] for sample in npy_files.copy()}  # Remove the file extension\n    test_samples = random.sample(sample_set, testset_size)  # IMO the Testset should be static for all splits\n\n    for split in range(0, 5):\n        train_samples = random.sample(sample_set - set(test_samples), trainset_size)\n        val_samples = list(sample_set - set(train_samples) - set(test_samples))\n\n        
train_samples.sort()\n        val_samples.sort()\n\n        split_dict = dict()\n        split_dict['train'] = train_samples\n        split_dict['val'] = val_samples\n        split_dict['test'] = test_samples\n\n        splits.append(split_dict)\n\n    # Todo: IMO it is better to write that dict as JSON. This (unlike pickle) allows the user to inspect the file with an editor.\n    with open(os.path.join(output_dir, 'splits.pkl'), 'wb') as f:\n        pickle.dump(splits, f)\n\n    splits_sanity_check(output_dir)\n\n\n# ToDo: The file name \"splits.pkl\" should not be spread over multiple files, because that makes it harder to change.\n#   Instead, move saving and loading to one file. (Here would be a good place.)\n#   Other usages are: spleen/create_splits.py:57 (which is redundant anyway?);\n#   UNetExperiment3D.py:55 and UNetExperiment.py:55\ndef splits_sanity_check(path):\n    \"\"\" Takes the path to the directory containing the splits file and verifies that the train, validation and test\n    sets of every split are pairwise disjoint (e.g. no test samples leaked into train or validation).\n    :param path: Directory that contains the 'splits.pkl' file\n    \"\"\"\n    with open(os.path.join(path, 'splits.pkl'), 'rb') as f:\n        splits = pickle.load(f)\n        for i in range(len(splits)):\n            samples = splits[i]\n            tr_samples = set(samples[\"train\"])\n            vl_samples = set(samples[\"val\"])\n            ts_samples = set(samples[\"test\"])\n\n            assert len(tr_samples.intersection(vl_samples)) == 0, \"Train and validation samples overlap!\"\n            assert len(vl_samples.intersection(ts_samples)) == 0, \"Validation and Test samples overlap!\"\n            assert len(tr_samples.intersection(ts_samples)) == 0, \"Train and Test samples overlap!\"\n    return\n"
  },
  {
    "path": "datasets/example_dataset/preprocessing.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom collections import defaultdict\n\nfrom batchgenerators.augmentations.utils import pad_nd_image\nfrom medpy.io import load\nimport os\nimport numpy as np\nimport torch\n\nfrom utilities.file_and_folder_operations import subfiles\n\n\ndef preprocess_data(root_dir, y_shape=64, z_shape=64):\n    image_dir = os.path.join(root_dir, 'imagesTr')\n    label_dir = os.path.join(root_dir, 'labelsTr')\n    output_dir = os.path.join(root_dir, 'preprocessed')\n    classes = 3\n\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n        print('Created' + output_dir + '...')\n\n    class_stats = defaultdict(int)\n    total = 0\n\n    nii_files = subfiles(image_dir, suffix=\".nii.gz\", join=False)\n\n    for i in range(0, len(nii_files)):\n        if nii_files[i].startswith(\"._\"):\n            nii_files[i] = nii_files[i][2:]\n\n    for f in nii_files:\n        image, _ = load(os.path.join(image_dir, f))\n        label, _ = load(os.path.join(label_dir, f.replace('_0000', '')))\n\n        print(f)\n\n        for i in range(classes):\n            class_stats[i] += np.sum(label == i)\n            total += np.sum(label == i)\n\n        # normalize images\n        image = (image - image.min())/(image.max()-image.min())\n\n        image = 
pad_nd_image(image, (image.shape[0], y_shape, z_shape), \"constant\", kwargs={'constant_values': image.min()})\n        label = pad_nd_image(label, (image.shape[0], y_shape, z_shape), \"constant\", kwargs={'constant_values': label.min()})\n\n        result = np.stack((image, label))\n\n        np.save(os.path.join(output_dir, f.split('.')[0]+'.npy'), result)\n        print(f)\n\n    print(total)\n    for i in range(classes):\n        print(class_stats[i], class_stats[i]/total)\n\n\ndef preprocess_single_file(image_file):\n    image, image_header = load(image_file)\n    image = (image - image.min()) / (image.max() - image.min())\n\n    data = np.expand_dims(image, 1)\n\n    return torch.from_numpy(data), image_header\n\n\ndef postprocess_single_image(image):\n    # desired shape is [b w h]\n    result_converted = image[::, 0, ::, ::]\n    result_mapped = [i * 255 for i in result_converted]\n\n    return result_mapped\n"
  },
  {
    "path": "datasets/spleen/__init__.py",
    "content": ""
  },
  {
    "path": "datasets/spleen/create_splits.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pickle\nfrom utilities.file_and_folder_operations import subfiles\n\nimport os\nimport numpy as np\n\n\ndef create_splits(output_dir, image_dir):\n    npy_files = subfiles(image_dir, suffix=\".npy\", join=False)\n\n    trainset_size = len(npy_files)*50//100\n    valset_size = len(npy_files)*25//100\n    testset_size = len(npy_files)*25//100\n\n    splits = []\n    for split in range(0, 5):\n        image_list = npy_files.copy()\n        trainset = []\n        valset = []\n        testset = []\n        for i in range(0, trainset_size):\n            patient = np.random.choice(image_list)\n            image_list.remove(patient)\n            trainset.append(patient[:-4])\n        for i in range(0, valset_size):\n            patient = np.random.choice(image_list)\n            image_list.remove(patient)\n            valset.append(patient[:-4])\n        for i in range(0, testset_size):\n            patient = np.random.choice(image_list)\n            image_list.remove(patient)\n            testset.append(patient[:-4])\n        split_dict = dict()\n        split_dict['train'] = trainset\n        split_dict['val'] = valset\n        split_dict['test'] = testset\n\n        splits.append(split_dict)\n\n    with open(os.path.join(output_dir, 'splits.pkl'), 'wb') 
as f:\n        pickle.dump(splits, f)\n"
  },
  {
    "path": "datasets/spleen/preprocessing.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom collections import defaultdict\n\nfrom medpy.io import load\nimport os\nimport numpy as np\n\nfrom utilities.file_and_folder_operations import subfiles\nimport torch\n\n\ndef preprocess_data(root_dir, y_shape=64, z_shape=64):\n    image_dir = os.path.join(root_dir, 'imagesTr')\n    label_dir = os.path.join(root_dir, 'labelsTr')\n    output_dir = os.path.join(root_dir, 'preprocessed')\n\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir)\n        print('Created' + output_dir + '...')\n\n    class_stats = defaultdict(int)\n    total = 0\n\n    nii_files = subfiles(image_dir, suffix=\".nii.gz\", join=False)\n\n    for i in range(0, len(nii_files)):\n        if nii_files[i].startswith(\"._\"):\n            nii_files[i] = nii_files[i][2:]\n\n    for f in nii_files:\n        image, _ = load(os.path.join(image_dir, f))\n        label, _ = load(os.path.join(label_dir, f.replace('_0000', '')))\n\n        print(f)\n\n        # normalize images\n        image = (image - image.min())/(image.max()-image.min())\n\n        image = np.swapaxes(image, 0, 2)\n        image = np.swapaxes(image, 1, 2)\n\n        label = np.swapaxes(label, 0, 2)\n        label = np.swapaxes(label, 1, 2)\n        result = np.stack((image, label))\n\n        
np.save(os.path.join(output_dir, f.split('.')[0]+'.npy'), result)\n        print(f)\n\n    print(total)\n\n\ndef preprocess_single_file(image_file):\n    image, image_header = load(image_file)\n    image = (image - image.min()) / (image.max() - image.min())\n\n    image = np.swapaxes(image, 0, 2)\n    image = np.swapaxes(image, 1, 2)\n\n    # TODO: check the original shape and reshape the data if necessary\n    # image = reshape(image, append_value=0, new_shape=(image.shape[0], y_shape, z_shape))\n    # numpy_array = np.array(image)\n\n    # Image shape is [b, w, h] and has one channel only\n    # Desired shape = [b, c, w, h]\n    # --> insert a channel axis (c=1) so that the data is in the desired shape\n    data = np.expand_dims(image, 1)\n\n    return torch.from_numpy(data), image_header\n\n\ndef postprocess_single_image(image):\n    # desired shape is [b w h]\n    result_converted = image[::, 0, ::, ::]\n    result_mapped = [i * 255 for i in result_converted]\n\n    # undo the two axis swaps applied in preprocess_single_file (in reverse order)\n    result_mapped = np.swapaxes(result_mapped, 2, 1)\n    result_mapped = np.swapaxes(result_mapped, 2, 0)\n\n    return result_mapped\n"
  },
  {
    "path": "datasets/three_dim/NumpyDataLoader.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport fnmatch\nimport random\n\nimport numpy as np\n\nfrom batchgenerators.dataloading import SlimDataLoaderBase\nfrom datasets.data_loader import MultiThreadedDataLoader\nfrom .data_augmentation import get_transforms\n\n\ndef load_dataset(base_dir, pattern='*.npy', keys=None):\n    fls = []\n    files_len = []\n    dataset = []\n\n    for root, dirs, files in os.walk(base_dir):\n        i = 0\n        for filename in sorted(fnmatch.filter(files, pattern)):\n\n            if keys is not None and filename[:-4] in keys:\n                npy_file = os.path.join(root, filename)\n                numpy_array = np.load(npy_file, mmap_mode=\"r\")\n\n                fls.append(npy_file)\n                files_len.append(numpy_array.shape[1])\n\n                dataset.extend([i])\n\n                i += 1\n\n    return fls, files_len, dataset\n\n\nclass NumpyDataSet(object):\n    \"\"\"\n    TODO\n    \"\"\"\n    def __init__(self, base_dir, mode=\"train\", batch_size=16, num_batches=10000000, seed=None, num_processes=8, num_cached_per_queue=8 * 4, target_size=128,\n                 file_pattern='*.npy', label=1, input=(0,), do_reshuffle=True, keys=None):\n\n        data_loader = NumpyDataLoader(base_dir=base_dir, mode=mode, batch_size=batch_size, 
num_batches=num_batches, seed=seed, file_pattern=file_pattern,\n                                      input=input, label=label, keys=keys)\n\n        self.data_loader = data_loader\n        self.batch_size = batch_size\n        self.do_reshuffle = do_reshuffle\n        self.number_of_slices = 1\n\n        self.transforms = get_transforms(mode=mode, target_size=target_size)\n        self.augmenter = MultiThreadedDataLoader(data_loader, self.transforms, num_processes=num_processes,\n                                                 num_cached_per_queue=num_cached_per_queue, seeds=seed,\n                                                 shuffle=do_reshuffle)\n        self.augmenter.restart()\n\n    def __len__(self):\n        return len(self.data_loader)\n\n    def __iter__(self):\n        if self.do_reshuffle:\n            self.data_loader.reshuffle()\n        self.augmenter.renew()\n        return self.augmenter\n\n    def __next__(self):\n        return next(self.augmenter)\n\n\nclass NumpyDataLoader(SlimDataLoaderBase):\n    def __init__(self, base_dir, mode=\"train\", batch_size=16, num_batches=10000000,\n                 seed=None, file_pattern='*.npy', label=1, input=(0,), keys=None):\n\n        self.files, self.file_len, self.dataset = load_dataset(base_dir=base_dir, pattern=file_pattern, keys=keys, )\n        super(NumpyDataLoader, self).__init__(self.dataset, batch_size, num_batches)\n\n        self.batch_size = batch_size\n\n        self.use_next = False\n        if mode == \"train\":\n            self.use_next = False\n\n        self.idxs = list(range(0, len(self.dataset)))\n\n        self.data_len = len(self.dataset)\n\n        self.num_batches = min((self.data_len // self.batch_size)+10, num_batches)\n\n        if isinstance(label, int):\n            label = (label,)\n        self.input = input\n        self.label = label\n\n        self.np_data = np.asarray(self.dataset)\n\n    def reshuffle(self):\n        print(\"Reshuffle...\")\n        
random.shuffle(self.idxs)\n        print(\"Initializing... this might take a while...\")\n\n    def generate_train_batch(self):\n        open_arr = random.sample(self._data, self.batch_size)\n        return self.get_data_from_array(open_arr)\n\n    def __len__(self):\n        n_items = min(self.data_len // self.batch_size, self.num_batches)\n        return n_items\n\n    def __getitem__(self, item):\n        idxs = self.idxs\n        data_len = len(self.dataset)\n        np_data = self.np_data\n\n        if item > len(self):\n            raise StopIteration()\n        if (item * self.batch_size) == data_len:\n            raise StopIteration()\n\n        start_idx = (item * self.batch_size) % data_len\n        stop_idx = ((item + 1) * self.batch_size) % data_len\n\n        if ((item + 1) * self.batch_size) == data_len:\n            stop_idx = data_len\n\n        if stop_idx > start_idx:\n            idxs = idxs[start_idx:stop_idx]\n        else:\n            raise StopIteration()\n\n        open_arr = np_data[idxs]\n\n        return self.get_data_from_array(open_arr)\n\n    def get_data_from_array(self, open_array):\n        data = []\n        fnames = []\n        idxs = []\n        labels = []\n\n        for idx in open_array:\n            fn_name = self.files[idx]\n\n            numpy_array = np.load(fn_name, mmap_mode=\"r\")\n\n            data.append(numpy_array[list(self.input)])   # 'None' keeps the dimension\n\n            if self.label is not None:\n                labels.append(numpy_array[list(self.input)])   # 'None' keeps the dimension\n\n            fnames.append(self.files[idx])\n            idxs.append(idx)\n\n        ret_dict = {'data': data, 'fnames': fnames, 'idxs': idxs}\n        if self.label is not None:\n            ret_dict['seg'] = labels\n\n        return ret_dict\n"
  },
  {
    "path": "datasets/three_dim/__init__.py",
    "content": ""
  },
  {
    "path": "datasets/three_dim/data_augmentation.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom batchgenerators.transforms import Compose, MirrorTransform\nfrom batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform\nfrom batchgenerators.transforms.spatial_transforms import ResizeTransform, SpatialTransform\nfrom batchgenerators.transforms.utility_transforms import NumpyToTensor\n\n\ndef get_transforms(mode=\"train\", target_size=128):\n    transform_list = []\n\n    if mode == \"train\":\n        transform_list = [CenterCropTransform(crop_size=target_size),\n                          ResizeTransform(target_size=target_size, order=1),\n                          MirrorTransform(axes=(2,)),\n                          SpatialTransform(patch_size=(target_size, target_size, target_size), random_crop=False,\n                                           patch_center_dist_from_border=target_size // 2,\n                                           do_elastic_deform=True, alpha=(0., 1000.), sigma=(40., 60.),\n                                           do_rotation=True,\n                                           angle_x=(-0.1, 0.1), angle_y=(0, 1e-8), angle_z=(0, 1e-8),\n                                           scale=(0.9, 1.4),\n                                           border_mode_data=\"nearest\", border_mode_seg=\"nearest\"),\n   
                       ]\n\n    elif mode == \"val\":\n        transform_list = [CenterCropTransform(crop_size=target_size),\n                          ResizeTransform(target_size=target_size, order=1),\n                          ]\n\n    elif mode == \"test\":\n        transform_list = [CenterCropTransform(crop_size=target_size),\n                          ResizeTransform(target_size=target_size, order=1),\n                          ]\n\n    transform_list.append(NumpyToTensor())\n\n    return Compose(transform_list)\n"
  },
  {
    "path": "datasets/two_dim/NumpyDataLoader.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport fnmatch\nimport random\n\nimport numpy as np\n\nfrom batchgenerators.dataloading import SlimDataLoaderBase\nfrom datasets.data_loader import MultiThreadedDataLoader\nfrom .data_augmentation import get_transforms\n\n\ndef load_dataset(base_dir, pattern='*.npy', slice_offset=5, keys=None):\n    fls = []\n    files_len = []\n    slices_ax = []\n\n    for root, dirs, files in os.walk(base_dir):\n        i = 0\n        for filename in sorted(fnmatch.filter(files, pattern)):\n\n            if keys is not None and filename[:-4] in keys:\n                npy_file = os.path.join(root, filename)\n                numpy_array = np.load(npy_file, mmap_mode=\"r\")\n\n                fls.append(npy_file)\n                files_len.append(numpy_array.shape[1])\n\n                slices_ax.extend([(i, j) for j in range(slice_offset, files_len[-1] - slice_offset)])\n\n                i += 1\n\n    return fls, files_len, slices_ax,\n\n\nclass NumpyDataSet(object):\n    \"\"\"\n    TODO\n    \"\"\"\n    def __init__(self, base_dir, mode=\"train\", batch_size=16, num_batches=10000000, num_processes=8, num_cached_per_queue=8 * 4, target_size=128,\n                 file_pattern='*.npy', label_slice=1, input_slice=(0,), do_reshuffle=True, keys=None):\n\n        
data_loader = NumpyDataLoader(base_dir=base_dir, mode=mode, batch_size=batch_size, num_batches=num_batches, file_pattern=file_pattern,\n                                      input_slice=input_slice, label_slice=label_slice, keys=keys)\n\n        self.data_loader = data_loader\n        self.batch_size = batch_size\n        self.do_reshuffle = do_reshuffle\n        self.number_of_slices = 1\n\n        self.transforms = get_transforms(mode=mode, target_size=target_size)\n        self.augmenter = MultiThreadedDataLoader(data_loader, self.transforms, num_processes=num_processes,\n                                                 num_cached_per_queue=num_cached_per_queue,\n                                                 shuffle=do_reshuffle)\n        self.augmenter.restart()\n\n    def __len__(self):\n        return len(self.data_loader)\n\n    def __iter__(self):\n        if self.do_reshuffle:\n            self.data_loader.reshuffle()\n        self.augmenter.renew()\n        return self.augmenter\n\n    def __next__(self):\n        return next(self.augmenter)\n\n\nclass NumpyDataLoader(SlimDataLoaderBase):\n    def __init__(self, base_dir, mode=\"train\", batch_size=16, num_batches=10000000,\n                 file_pattern='*.npy', label_slice=1, input_slice=(0,), keys=None):\n\n        self.files, self.file_len, self.slices = load_dataset(base_dir=base_dir, pattern=file_pattern, slice_offset=0, keys=keys, )\n        super(NumpyDataLoader, self).__init__(self.slices, batch_size, num_batches)\n\n        self.batch_size = batch_size\n\n        self.use_next = False\n        if mode == \"train\":\n            self.use_next = False\n\n        self.slice_idxs = list(range(0, len(self.slices)))\n\n        self.data_len = len(self.slices)\n\n        self.num_batches = min((self.data_len // self.batch_size)+10, num_batches)\n\n        if isinstance(label_slice, int):\n            label_slice = (label_slice,)\n        self.input_slice = input_slice\n        self.label_slice = 
label_slice\n\n        self.np_data = np.asarray(self.slices)\n\n    def reshuffle(self):\n        print(\"Reshuffle...\")\n        random.shuffle(self.slice_idxs)\n        print(\"Initializing... this might take a while...\")\n\n    def generate_train_batch(self):\n        open_arr = random.sample(self._data, self.batch_size)\n        return self.get_data_from_array(open_arr)\n\n    def __len__(self):\n        n_items = min(self.data_len // self.batch_size, self.num_batches)\n        return n_items\n\n    def __getitem__(self, item):\n        slice_idxs = self.slice_idxs\n        data_len = len(self.slices)\n        np_data = self.np_data\n\n        if item > len(self):\n            raise StopIteration()\n        if (item * self.batch_size) == data_len:\n            raise StopIteration()\n\n        start_idx = (item * self.batch_size) % data_len\n        stop_idx = ((item + 1) * self.batch_size) % data_len\n\n        if ((item + 1) * self.batch_size) == data_len:\n            stop_idx = data_len\n\n        if stop_idx > start_idx:\n            idxs = slice_idxs[start_idx:stop_idx]\n        else:\n            raise StopIteration()\n\n        open_arr = np_data[idxs]\n\n        return self.get_data_from_array(open_arr)\n\n    def get_data_from_array(self, open_array):\n        data = []\n        fnames = []\n        slice_idxs = []\n        labels = []\n\n        for slice in open_array:\n            fn_name = self.files[slice[0]]\n\n            numpy_array = np.load(fn_name, mmap_mode=\"r\")\n\n            numpy_slice = numpy_array[:, slice[1], ]\n            data.append(numpy_slice[list(self.input_slice)])   # 'None' keeps the dimension\n\n            if self.label_slice is not None:\n                labels.append(numpy_slice[list(self.label_slice)])   # 'None' keeps the dimension\n\n            fnames.append(self.files[slice[0]])\n            slice_idxs.append(slice[1])\n\n        ret_dict = {'data': np.asarray(data), 'fnames': fnames, 'slice_idxs': slice_idxs}\n 
       if self.label_slice is not None:\n            ret_dict['seg'] = np.asarray(labels)\n\n        return ret_dict\n"
  },
  {
    "path": "datasets/two_dim/__init__.py",
    "content": ""
  },
  {
    "path": "datasets/two_dim/data_augmentation.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom batchgenerators.transforms import Compose, MirrorTransform\nfrom batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform, RandomCropTransform\nfrom batchgenerators.transforms.spatial_transforms import ResizeTransform, SpatialTransform\nfrom batchgenerators.transforms.utility_transforms import NumpyToTensor\n\nimport numpy as np\n\n\ndef get_transforms(mode=\"train\", target_size=128):\n    tranform_list = []\n\n    if mode == \"train\":\n        tranform_list = [# CenterCropTransform(crop_size=target_size),\n                         ResizeTransform(target_size=(target_size,target_size), order=1),\n                         MirrorTransform(axes=(1,)),\n                         SpatialTransform(patch_size=(target_size, target_size), random_crop=False,\n                                          patch_center_dist_from_border=target_size // 2,\n                                          do_elastic_deform=True, alpha=(0., 900.), sigma=(20., 30.),\n                                          do_rotation=True, p_rot_per_sample=0.8,\n                                          angle_x=(-15. / 360 * 2. * np.pi, 15. / 360 * 2. 
* np.pi), angle_y=(0, 1e-8), angle_z=(0, 1e-8),\n                                          scale=(0.85, 1.25), p_scale_per_sample=0.8,\n                                          border_mode_data=\"nearest\", border_mode_seg=\"nearest\"),\n                         ]\n\n\n    elif mode == \"val\":\n        tranform_list = [# CenterCropTransform(crop_size=target_size),\n                         ResizeTransform(target_size=target_size, order=1),\n                         ]\n\n    elif mode == \"test\":\n        tranform_list = [# CenterCropTransform(crop_size=target_size),\n                         ResizeTransform(target_size=target_size, order=1),\n                         ]\n\n    tranform_list.append(NumpyToTensor())\n\n    return Compose(tranform_list)\n"
  },
  {
    "path": "datasets/utils.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom os.path import exists\nimport tarfile\n\nfrom google_drive_downloader import GoogleDriveDownloader as gdd\n\n\ndef download_dataset(dest_path, dataset, id=''):\n    if not exists(os.path.join(dest_path, dataset)):\n        tar_path = os.path.join(dest_path, dataset) + '.tar'\n        gdd.download_file_from_google_drive(file_id=id,\n                                            dest_path=tar_path, overwrite=False,\n                                            unzip=False)\n\n        print('Extracting data [STARTED]')\n        tar = tarfile.open(tar_path)\n        tar.extractall(dest_path)\n        print('Extracting data [DONE]')\n    else:\n        print('Data already downloaded. Files are not extracted again.')\n        print('Data already downloaded. Files are not extracted again.')\n\n    return\n"
  },
  {
    "path": "evaluation/__init__.py",
    "content": ""
  },
  {
    "path": "evaluation/evaluator.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\nimport collections\nimport inspect\nimport json\nimport hashlib\nfrom datetime import datetime\nimport numpy as np\nimport pandas as pd\nimport SimpleITK as sitk\nfrom evaluation.metrics import ConfusionMatrix, ALL_METRICS\n\n\nclass Evaluator:\n    \"\"\"Object that holds test and reference segmentations with label information\n    and computes a number of metrics on the two. 'labels' must either be an\n    iterable of numeric values (or tuples thereof) or a dictionary with string\n    names and numeric values.\n    \"\"\"\n\n    default_metrics = [\n        \"False Positive Rate\",\n        \"Dice\",\n        \"Jaccard\",\n        \"Precision\",\n        \"Recall\",\n        \"Accuracy\",\n        \"False Omission Rate\",\n        \"Negative Predictive Value\",\n        \"False Negative Rate\",\n        \"True Negative Rate\",\n        \"False Discovery Rate\",\n        \"Total Positives Test\",\n        \"Total Positives Reference\"\n    ]\n\n    default_advanced_metrics = [\n        \"Hausdorff Distance\",\n        \"Hausdorff Distance 95\",\n        \"Avg. Surface Distance\",\n        \"Avg. 
Symmetric Surface Distance\"\n    ]\n\n    def __init__(self,\n                 test=None,\n                 reference=None,\n                 labels=None,\n                 metrics=None,\n                 advanced_metrics=None,\n                 nan_for_nonexisting=True):\n\n        self.test = None\n        self.reference = None\n        self.confusion_matrix = ConfusionMatrix()\n        self.labels = None\n        self.nan_for_nonexisting = nan_for_nonexisting\n        self.result = None\n\n        self.metrics = []\n        if metrics is None:\n            for m in self.default_metrics:\n                self.metrics.append(m)\n        else:\n            for m in metrics:\n                self.metrics.append(m)\n\n        self.advanced_metrics = []\n        if advanced_metrics is None:\n            for m in self.default_advanced_metrics:\n                self.advanced_metrics.append(m)\n        else:\n            for m in advanced_metrics:\n                self.advanced_metrics.append(m)\n\n        self.set_reference(reference)\n        self.set_test(test)\n        if labels is not None:\n            self.set_labels(labels)\n        else:\n            if test is not None and reference is not None:\n                self.construct_labels()\n\n    def set_test(self, test):\n        \"\"\"Set the test segmentation.\"\"\"\n\n        self.test = test\n\n    def set_reference(self, reference):\n        \"\"\"Set the reference segmentation.\"\"\"\n\n        self.reference = reference\n\n    def set_labels(self, labels):\n        \"\"\"Set the labels.\n        :param labels= may be a dictionary (int->str), a set (of ints), a tuple (of ints) or a list (of ints). 
Labels\n        will only have names if you pass a dictionary\"\"\"\n\n        if not isinstance(labels, (dict, set, list, tuple)):\n            raise ValueError(\"Labels must be either list, tuple, set or dict\")\n        elif isinstance(labels, dict):\n            self.labels = collections.OrderedDict(labels)\n        elif isinstance(labels, set):\n            self.labels = list(labels)\n        elif isinstance(labels, (list, tuple)):\n            self.labels = labels\n        else:\n            raise TypeError(\"Can only handle dict, list, tuple, set & numpy array, but input is of type {}\".format(type(labels)))\n\n    def construct_labels(self):\n        \"\"\"Construct label set from unique entries in segmentations.\"\"\"\n\n        if self.test is None and self.reference is None:\n            raise ValueError(\"No test or reference segmentations.\")\n        elif self.test is None:\n            labels = np.unique(self.reference)\n        else:\n            labels = np.union1d(np.unique(self.test),\n                                np.unique(self.reference))\n        self.labels = list(map(lambda x: int(x), labels))\n\n    def set_metrics(self, metrics):\n        \"\"\"Set evaluation metrics\"\"\"\n\n        if isinstance(metrics, set):\n            self.metrics = list(metrics)\n        elif isinstance(metrics, (list, tuple, np.ndarray)):\n            self.metrics = metrics\n        else:\n            raise TypeError(\"Can only handle list, tuple, set & numpy array, but input is of type {}\".format(type(metrics)))\n\n    def add_metric(self, metric):\n\n        if metric not in self.metrics:\n            self.metrics.append(metric)\n\n    def evaluate(self, test=None, reference=None, advanced=False, **metric_kwargs):\n        \"\"\"Compute metrics for segmentations.\"\"\"\n        if test is not None:\n            self.set_test(test)\n\n        if reference is not None:\n            self.set_reference(reference)\n\n        if self.test is None or self.reference 
is None:\n            raise ValueError(\"Need both test and reference segmentations.\")\n\n        if self.labels is None:\n            self.construct_labels()\n\n        self.metrics.sort()\n\n        # get functions for evaluation\n        # somewhat convoluted, but allows users to define additonal metrics\n        # on the fly, e.g. inside an IPython console\n        _funcs = {m: ALL_METRICS[m] for m in self.metrics + self.advanced_metrics}\n        frames = inspect.getouterframes(inspect.currentframe())\n        for metric in self.metrics:\n            for f in frames:\n                if metric in f[0].f_locals:\n                    _funcs[metric] = f[0].f_locals[metric]\n                    break\n            else:\n                if metric in _funcs:\n                    continue\n                else:\n                    raise NotImplementedError(\n                        \"Metric {} not implemented.\".format(metric))\n\n        # get results\n        self.result = {}\n\n        eval_metrics = self.metrics\n        if advanced:\n            eval_metrics += self.advanced_metrics\n\n        if isinstance(self.labels, dict):\n\n            for label, name in self.labels.items():\n                self.result[name] = {}\n                if not hasattr(label, \"__iter__\"):\n                    self.confusion_matrix.set_test(self.test == label)\n                    self.confusion_matrix.set_reference(self.reference == label)\n                else:\n                    current_test = 0\n                    current_reference = 0\n                    for l in label:\n                        current_test += (self.test == l)\n                        current_reference += (self.reference == l)\n                    self.confusion_matrix.set_test(current_test)\n                    self.confusion_matrix.set_reference(current_reference)\n                for metric in eval_metrics:\n                    self.result[name][metric] = 
_funcs[metric](confusion_matrix=self.confusion_matrix,\n                                                               nan_for_nonexisting=self.nan_for_nonexisting,\n                                                               **metric_kwargs)\n\n        else:\n\n            for i, l in enumerate(self.labels):\n                self.result[l] = {}\n                self.confusion_matrix.set_test(self.test == l)\n                self.confusion_matrix.set_reference(self.reference == l)\n                for metric in eval_metrics:\n                    self.result[l][metric] = _funcs[metric](confusion_matrix=self.confusion_matrix,\n                                                            nan_for_nonexisting=self.nan_for_nonexisting,\n                                                            **metric_kwargs)\n\n        return self.result\n\n    def to_dict(self):\n\n        if self.result is None:\n            self.evaluate()\n        return self.result\n\n    def to_array(self):\n        \"\"\"Return result as numpy array (labels x metrics).\"\"\"\n\n        if self.result is None:\n            self.evaluate\n\n        result_metrics = sorted(self.result[list(self.result.keys())[0]].keys())\n\n        a = np.zeros((len(self.labels), len(result_metrics)), dtype=np.float32)\n\n        if isinstance(self.labels, dict):\n            for i, label in enumerate(self.labels.keys()):\n                for j, metric in enumerate(result_metrics):\n                    a[i][j] = self.result[self.labels[label]][metric]\n        else:\n            for i, label in enumerate(self.labels):\n                for j, metric in enumerate(result_metrics):\n                    a[i][j] = self.result[label][metric]\n\n        return a\n\n    def to_pandas(self):\n        \"\"\"Return result as pandas DataFrame.\"\"\"\n\n        a = self.to_array()\n\n        if isinstance(self.labels, dict):\n            labels = list(self.labels.values())\n        else:\n            labels = self.labels\n\n 
       result_metrics = sorted(self.result[list(self.result.keys())[0]].keys())\n\n        return pd.DataFrame(a, index=labels, columns=result_metrics)\n\n\nclass NiftiEvaluator(Evaluator):\n\n    def __init__(self, *args, **kwargs):\n\n        self.test_nifti = None\n        self.reference_nifti = None\n        super(NiftiEvaluator, self).__init__(*args, **kwargs)\n\n    def set_test(self, test):\n        \"\"\"Set the test segmentation.\"\"\"\n\n        if test is not None:\n            self.test_nifti = sitk.ReadImage(test)\n            super(NiftiEvaluator, self).set_test(sitk.GetArrayFromImage(self.test_nifti))\n        else:\n            self.test_nifti = None\n            super(NiftiEvaluator, self).set_test(test)\n\n    def set_reference(self, reference):\n        \"\"\"Set the reference segmentation.\"\"\"\n\n        if reference is not None:\n            self.reference_nifti = sitk.ReadImage(reference)\n            super(NiftiEvaluator, self).set_reference(sitk.GetArrayFromImage(self.reference_nifti))\n        else:\n            self.reference_nifti = None\n            super(NiftiEvaluator, self).set_reference(reference)\n\n    def evaluate(self, test=None, reference=None, voxel_spacing=None, **metric_kwargs):\n\n        if voxel_spacing is None:\n            voxel_spacing = np.array(self.test_nifti.GetSpacing())[::-1]\n            metric_kwargs[\"voxel_spacing\"] = voxel_spacing\n\n        return super(NiftiEvaluator, self).evaluate(test, reference, **metric_kwargs)\n\n\ndef aggregate_scores(test_ref_pairs,\n                     evaluator=NiftiEvaluator,\n                     labels=None,\n                     nanmean=True,\n                     json_output_file=None,\n                     json_name=\"\",\n                     json_description=\"\",\n                     json_author=\"Fabian\",\n                     json_task=\"\",\n                     **metric_kwargs):\n    \"\"\"\n    test = predicted image\n    :param test_ref_pairs:\n    :param 
evaluator:\n    :param labels: must be a dict of int-> str or a list of int\n    :param nanmean:\n    :param json_output_file:\n    :param json_name:\n    :param json_description:\n    :param json_author:\n    :param json_task:\n    :param metric_kwargs:\n    :return:\n    \"\"\"\n\n    if type(evaluator) == type:\n        evaluator = evaluator()\n\n    if labels is not None:\n        evaluator.set_labels(labels)\n\n    all_scores = {}\n    all_scores[\"all\"] = []\n    all_scores[\"mean\"] = {}\n\n    for i, (test, ref) in enumerate(test_ref_pairs):\n\n        # evaluate\n        evaluator.set_test(test)\n        evaluator.set_reference(ref)\n        if evaluator.labels is None:\n            evaluator.construct_labels()\n        current_scores = evaluator.evaluate(**metric_kwargs)\n        if type(test) == str:\n            current_scores[\"test\"] = test\n        if type(ref) == str:\n            current_scores[\"reference\"] = ref\n        all_scores[\"all\"].append(current_scores)\n\n        # append score list for mean\n        for label, score_dict in current_scores.items():\n            if label in (\"test\", \"reference\"):\n                continue\n            if label not in all_scores[\"mean\"]:\n                all_scores[\"mean\"][label] = {}\n            for score, value in score_dict.items():\n                if score not in all_scores[\"mean\"][label]:\n                    all_scores[\"mean\"][label][score] = []\n                all_scores[\"mean\"][label][score].append(value)\n\n    for label in all_scores[\"mean\"]:\n        for score in all_scores[\"mean\"][label]:\n            if nanmean:\n                all_scores[\"mean\"][label][score] = float(np.nanmean(all_scores[\"mean\"][label][score]))\n            else:\n                all_scores[\"mean\"][label][score] = float(np.mean(all_scores[\"mean\"][label][score]))\n\n    # save to file if desired\n    # we create a hopefully unique id by hashing the entire output dictionary\n    if 
json_output_file is not None:\n        if type(json_output_file) == str:\n            json_output_file = open(json_output_file, \"w\")\n        json_dict = {}\n        json_dict[\"name\"] = json_name\n        json_dict[\"description\"] = json_description\n        timestamp = datetime.today()\n        json_dict[\"timestamp\"] = str(timestamp)\n        json_dict[\"task\"] = json_task\n        json_dict[\"author\"] = json_author\n        json_dict[\"results\"] = all_scores\n        json_dict[\"id\"] = hashlib.md5(json.dumps(json_dict).encode(\"utf-8\")).hexdigest()[:12]\n        json.dump(json_dict, json_output_file, indent=4, separators=(\",\", \": \"))\n        json_output_file.close()\n\n    return all_scores\n\n\ndef aggregate_scores_for_experiment(score_file,\n                                    labels=None,\n                                    metrics=Evaluator.default_metrics,\n                                    nanmean=True,\n                                    json_output_file=None,\n                                    json_name=\"\",\n                                    json_description=\"\",\n                                    json_author=\"Fabian\",\n                                    json_task=\"\"):\n\n    scores = np.load(score_file)\n    scores_mean = scores.mean(0)\n    if labels is None:\n        labels = list(map(str, range(scores.shape[1])))\n\n    results = []\n    results_mean = {}\n    for i in range(scores.shape[0]):\n        results.append({})\n        for l, label in enumerate(labels):\n            results[-1][label] = {}\n            results_mean[label] = {}\n            for m, metric in enumerate(metrics):\n                results[-1][label][metric] = float(scores[i][l][m])\n                results_mean[label][metric] = float(scores_mean[l][m])\n\n    json_dict = {}\n    json_dict[\"name\"] = json_name\n    json_dict[\"description\"] = json_description\n    timestamp = datetime.today()\n    json_dict[\"timestamp\"] = str(timestamp)\n    
json_dict[\"task\"] = json_task\n    json_dict[\"author\"] = json_author\n    json_dict[\"results\"] = {\"all\": results, \"mean\": results_mean}\n    json_dict[\"id\"] = hashlib.md5(json.dumps(json_dict).encode(\"utf-8\")).hexdigest()[:12]\n    if json_output_file is not None:\n        json_output_file = open(json_output_file, \"w\")\n        json.dump(json_dict, json_output_file, indent=4, separators=(\",\", \": \"))\n        json_output_file.close()\n\n    return json_dict\n"
  },
  {
    "path": "evaluation/metrics.py",
    "content": "import numpy as np\nfrom medpy import metric\n\n\ndef assert_shape(test, reference):\n\n    assert test.shape == reference.shape, \"Shape mismatch: {} and {}\".format(\n        test.shape, reference.shape)\n\n\nclass ConfusionMatrix:\n\n    def __init__(self, test=None, reference=None):\n\n        self.tp = None\n        self.fp = None\n        self.tn = None\n        self.fn = None\n        self.size = None\n        self.reference_empty = None\n        self.reference_full = None\n        self.test_empty = None\n        self.test_full = None\n        self.set_reference(reference)\n        self.set_test(test)\n\n    def set_test(self, test):\n\n        self.test = test\n        self.reset()\n\n    def set_reference(self, reference):\n\n        self.reference = reference\n        self.reset()\n\n    def reset(self):\n\n        self.tp = None\n        self.fp = None\n        self.tn = None\n        self.fn = None\n        self.size = None\n        self.test_empty = None\n        self.test_full = None\n        self.reference_empty = None\n        self.reference_full = None\n\n    def compute(self):\n\n        if self.test is None or self.reference is None:\n            raise ValueError(\"'test' and 'reference' must both be set to compute confusion matrix.\")\n\n        assert_shape(self.test, self.reference)\n\n        self.tp = int(((self.test != 0) * (self.reference != 0)).sum())\n        self.fp = int(((self.test != 0) * (self.reference == 0)).sum())\n        self.tn = int(((self.test == 0) * (self.reference == 0)).sum())\n        self.fn = int(((self.test == 0) * (self.reference != 0)).sum())\n        self.size = int(np.product(self.reference.shape))\n        self.test_empty = not np.any(self.test)\n        self.test_full = np.all(self.test)\n        self.reference_empty = not np.any(self.reference)\n        self.reference_full = np.all(self.reference)\n\n    def get_matrix(self):\n\n        for entry in (self.tp, self.fp, self.tn, self.fn):\n       
     if entry is None:\n                self.compute()\n                break\n\n        return self.tp, self.fp, self.tn, self.fn\n\n    def get_size(self):\n\n        if self.size is None:\n            self.compute()\n        return self.size\n\n    def get_existence(self):\n\n        for case in (self.test_empty, self.test_full, self.reference_empty, self.reference_full):\n            if case is None:\n                self.compute()\n                break\n\n        return self.test_empty, self.test_full, self.reference_empty, self.reference_full\n\n\ndef dice(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):\n    \"\"\"2TP / (2TP + FP + FN)\"\"\"\n\n    if confusion_matrix is None:\n        confusion_matrix = ConfusionMatrix(test, reference)\n\n    tp, fp, tn, fn = confusion_matrix.get_matrix()\n    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()\n\n    if test_empty and reference_empty:\n        if nan_for_nonexisting:\n            return float(\"NaN\")\n        else:\n            return 0.\n\n    return float(2. 
* tp / (2 * tp + fp + fn))\n\n\ndef jaccard(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):\n    \"\"\"TP / (TP + FP + FN)\"\"\"\n\n    if confusion_matrix is None:\n        confusion_matrix = ConfusionMatrix(test, reference)\n\n    tp, fp, tn, fn = confusion_matrix.get_matrix()\n    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()\n\n    if test_empty and reference_empty:\n        if nan_for_nonexisting:\n            return float(\"NaN\")\n        else:\n            return 0.\n\n    return float(tp / (tp + fp + fn))\n\n\ndef precision(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):\n    \"\"\"TP / (TP + FP)\"\"\"\n\n    if confusion_matrix is None:\n        confusion_matrix = ConfusionMatrix(test, reference)\n\n    tp, fp, tn, fn = confusion_matrix.get_matrix()\n    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()\n\n    if test_empty:\n        if nan_for_nonexisting:\n            return float(\"NaN\")\n        else:\n            return 0.\n\n    return float(tp / (tp + fp))\n\n\ndef sensitivity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):\n    \"\"\"TP / (TP + FN)\"\"\"\n\n    if confusion_matrix is None:\n        confusion_matrix = ConfusionMatrix(test, reference)\n\n    tp, fp, tn, fn = confusion_matrix.get_matrix()\n    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()\n\n    if reference_empty:\n        if nan_for_nonexisting:\n            return float(\"NaN\")\n        else:\n            return 0.\n\n    return float(tp / (tp + fn))\n\n\ndef recall(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):\n    \"\"\"TP / (TP + FN)\"\"\"\n\n    return sensitivity(test, reference, confusion_matrix, nan_for_nonexisting, **kwargs)\n\n\ndef specificity(test=None, reference=None, 
confusion_matrix=None, nan_for_nonexisting=True, **kwargs):\n    \"\"\"TN / (TN + FP)\"\"\"\n\n    if confusion_matrix is None:\n        confusion_matrix = ConfusionMatrix(test, reference)\n\n    tp, fp, tn, fn = confusion_matrix.get_matrix()\n    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()\n\n    if reference_full:\n        if nan_for_nonexisting:\n            return float(\"NaN\")\n        else:\n            return 0.\n\n    return float(tn / (tn + fp))\n\n\ndef accuracy(test=None, reference=None, confusion_matrix=None, **kwargs):\n    \"\"\"(TP + TN) / (TP + FP + FN + TN)\"\"\"\n\n    if confusion_matrix is None:\n        confusion_matrix = ConfusionMatrix(test, reference)\n\n    tp, fp, tn, fn = confusion_matrix.get_matrix()\n\n    return float((tp + tn) / (tp + fp + tn + fn))\n\n\ndef fscore(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, beta=1., **kwargs):\n    \"\"\"(1 + b^2) * TP / ((1 + b^2) * TP + b^2 * FN + FP)\"\"\"\n\n    precision_ = precision(test, reference, confusion_matrix, nan_for_nonexisting)\n    recall_ = recall(test, reference, confusion_matrix, nan_for_nonexisting)\n\n    return (1 + beta*beta) * precision_ * recall_ /\\\n        ((beta*beta * precision_) + recall_)\n\n\ndef false_positive_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):\n    \"\"\"FP / (FP + TN)\"\"\"\n\n    return 1 - specificity(test, reference, confusion_matrix, nan_for_nonexisting)\n\n\ndef false_omission_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):\n    \"\"\"FN / (TN + FN)\"\"\"\n\n    if confusion_matrix is None:\n        confusion_matrix = ConfusionMatrix(test, reference)\n\n    tp, fp, tn, fn = confusion_matrix.get_matrix()\n    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()\n\n    if test_full:\n        if nan_for_nonexisting:\n            return 
float(\"NaN\")\n        else:\n            return 0.\n\n    return float(fn / (fn + tn))\n\n\ndef false_negative_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):\n    \"\"\"FN / (TP + FN)\"\"\"\n\n    return 1 - sensitivity(test, reference, confusion_matrix, nan_for_nonexisting)\n\n\ndef true_negative_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):\n    \"\"\"TN / (TN + FP)\"\"\"\n\n    return specificity(test, reference, confusion_matrix, nan_for_nonexisting)\n\n\ndef false_discovery_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):\n    \"\"\"FP / (TP + FP)\"\"\"\n\n    return 1 - precision(test, reference, confusion_matrix, nan_for_nonexisting)\n\n\ndef negative_predictive_value(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):\n    \"\"\"TN / (TN + FN)\"\"\"\n\n    return 1 - false_omission_rate(test, reference, confusion_matrix, nan_for_nonexisting)\n\n\ndef total_positives_test(test=None, reference=None, confusion_matrix=None, **kwargs):\n    \"\"\"TP + FP\"\"\"\n\n    if confusion_matrix is None:\n        confusion_matrix = ConfusionMatrix(test, reference)\n\n    tp, fp, tn, fn = confusion_matrix.get_matrix()\n\n    return tp + fp\n\n\ndef total_negatives_test(test=None, reference=None, confusion_matrix=None, **kwargs):\n    \"\"\"TN + FN\"\"\"\n\n    if confusion_matrix is None:\n        confusion_matrix = ConfusionMatrix(test, reference)\n\n    tp, fp, tn, fn = confusion_matrix.get_matrix()\n\n    return tn + fn\n\n\ndef total_positives_reference(test=None, reference=None, confusion_matrix=None, **kwargs):\n    \"\"\"TP + FN\"\"\"\n\n    if confusion_matrix is None:\n        confusion_matrix = ConfusionMatrix(test, reference)\n\n    tp, fp, tn, fn = confusion_matrix.get_matrix()\n\n    return tp + fn\n\n\ndef total_negatives_reference(test=None, reference=None, confusion_matrix=None, 
**kwargs):\n    \"\"\"TN + FP\"\"\"\n\n    if confusion_matrix is None:\n        confusion_matrix = ConfusionMatrix(test, reference)\n\n    tp, fp, tn, fn = confusion_matrix.get_matrix()\n\n    return tn + fp\n\n\ndef hausdorff_distance(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):\n\n    if confusion_matrix is None:\n        confusion_matrix = ConfusionMatrix(test, reference)\n\n    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()\n\n    if test_empty or test_full or reference_empty or reference_full:\n        if nan_for_nonexisting:\n            return float(\"NaN\")\n        else:\n            return 0\n\n    test, reference = confusion_matrix.test, confusion_matrix.reference\n\n    return metric.hd(test, reference, voxel_spacing, connectivity)\n\n\ndef hausdorff_distance_95(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):\n\n    if confusion_matrix is None:\n        confusion_matrix = ConfusionMatrix(test, reference)\n\n    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()\n\n    if test_empty or test_full or reference_empty or reference_full:\n        if nan_for_nonexisting:\n            return float(\"NaN\")\n        else:\n            return 0\n\n    test, reference = confusion_matrix.test, confusion_matrix.reference\n\n    return metric.hd95(test, reference, voxel_spacing, connectivity)\n\n\ndef avg_surface_distance(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):\n\n    if confusion_matrix is None:\n        confusion_matrix = ConfusionMatrix(test, reference)\n\n    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()\n\n    if test_empty or test_full or reference_empty or reference_full:\n        if 
nan_for_nonexisting:\n            return float(\"NaN\")\n        else:\n            return 0\n\n    test, reference = confusion_matrix.test, confusion_matrix.reference\n\n    return metric.asd(test, reference, voxel_spacing, connectivity)\n\n\ndef avg_surface_distance_symmetric(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):\n\n    if confusion_matrix is None:\n        confusion_matrix = ConfusionMatrix(test, reference)\n\n    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()\n\n    if test_empty or test_full or reference_empty or reference_full:\n        if nan_for_nonexisting:\n            return float(\"NaN\")\n        else:\n            return 0\n\n    test, reference = confusion_matrix.test, confusion_matrix.reference\n\n    return metric.assd(test, reference, voxel_spacing, connectivity)\n\n\nALL_METRICS = {\n    \"False Positive Rate\": false_positive_rate,\n    \"Dice\": dice,\n    \"Jaccard\": jaccard,\n    \"Hausdorff Distance\": hausdorff_distance,\n    \"Hausdorff Distance 95\": hausdorff_distance_95,\n    \"Precision\": precision,\n    \"Recall\": recall,\n    \"Avg. Symmetric Surface Distance\": avg_surface_distance_symmetric,\n    \"Avg. Surface Distance\": avg_surface_distance,\n    \"Accuracy\": accuracy,\n    \"False Omission Rate\": false_omission_rate,\n    \"Negative Predictive Value\": negative_predictive_value,\n    \"False Negative Rate\": false_negative_rate,\n    \"True Negative Rate\": true_negative_rate,\n    \"False Discovery Rate\": false_discovery_rate,\n    \"Total Positives Test\": total_positives_test,\n    \"Total Negatives Test\": total_negatives_test,\n    \"Total Positives Reference\": total_positives_reference,\n    \"total Negatives Reference\": total_negatives_reference\n}\n"
  },
  {
    "path": "evaluation/readme.md",
    "content": "# Evaluation Suite\n\n### Metrics\n\nAll metrics can be used either by passing test and reference segmentations as\nparameters or by passing a `ConfusionMatrix` object. The latter is useful when many\nmetrics need to be computed, because the relevant computations are only done once.\nAll metrics assume binary segmentation inputs.\n\n`ConfusionMatrix` has two important methods: `.get_matrix()`, which returns 4 ints for true positives, false positives, true negatives and false negatives, and\n`.get_existence()`, which returns 4 bools, indicating whether test and reference\nsegmentations are all ones or all zeros. The latter is used when you specify\n`nan_for_nonexisting=True` in metric calls to return NaN instead of 0 when the result\nis undefined, i.e. would require dividing by 0.\n\n### Evaluator\n\nThe `Evaluator` is a class that holds one test and one reference segmentation at a time that can contain multiple labels (one-hot encoding is not supported). It also holds a labels attribute than can either be a list of ints (or tuples of ints) or a dictionary\nthat maps from ints (or tuples of ints) to label names. A typical labels dictionary\ncould look like this:\n\n```python\nlabels = {\n    1: \"Edema\",\n    2: \"Enhancing Tumor\",\n    3: \"Necrosis\",\n    (1, 2, 3): \"Whole Tumor\"\n}\n```\n\nLabels in a tuple will be joined. If no labels are set, they will be automatically constructed from the unique entries in the segmentations upon evaluation. The Evaluator has both a regular set of metrics\nthat will always be computed and a set of advanced metrics that will only be computed\nif `.evaluate(advanced=True)` is passed. The `.evaluate()` method is designed to\nlook for metric definitions in the current frame, so when you work in an interactive shell and redefine something there (e.g. for testing purposes), the newly defined metric will be used. 
You can also pass test and reference segmentations directly to evaluate to save calls to `.set_test()` and `.set_reference()`. `.evaluate()` will return a result dictionary and also save it in the `.result` attribute, so you can call `.to_array()` (numpy) or `.to_pandas()` (pandas) later. The resulting shape will be (labels x metrics). `.evaluate()` also takes additional `**metric_kwargs` that will be passed to each metric call.\n\n### NiftiEvaluator\n\n`NiftiEvaluator` redefines the `.set_test()` and `.set_reference()` methods of the `Evaluator` to take path strings instead of arrays. It will read the NIfTI files using SimpleITK, save the SimpleITK images in the `.test_nifti` and `.reference_nifti` attributes and set the arrays as test and reference segmentations. `.evaluate()` has an additional parameter `voxel_spacing`, which should be an iterable of floats. If the parameter is None, the spacing will be automatically read from the SimpleITK images. If you manually read the spacing from SimpleITK images, note that you have to reverse the ordering, because SimpleITK's `GetSpacing()` returns (x,y,z) ordering while the arrays obtained via `GetArrayFromImage()` are indexed in (z,y,x) order.\n\n### Evaluating multiple segmentations\n\nIf you want to evaluate multiple test/reference pairs and get aggregate statistics, use the `aggregate_scores()` function. It expects an iterable of test/reference pairs and an evaluator (instance or type, will automatically initialize if necessary), which is the `NiftiEvaluator` by default. Test and reference will be set via `.set_test()` and `.set_reference()`, so make sure you're passing the right type for the evaluator. 
The method will return a dictionary that contains a list of all separate results as well as their mean:\n\n```python\nresults = {\n    \"all\": [\n        {\n            \"Label\": {\n                \"Metric\": float,\n                \"Metric\": float,\n                ...\n            },\n            \"Label\": {\n                ...\n            },\n            ...\n        },\n        {\n            \"Label\": ...,\n            \"Label\": ...,\n            ...\n        },\n        ...\n    ],\n    \"mean\": {\n        \"Label\": ...,\n        \"Label\": ...,\n        ...\n    }\n}\n```\n\n`nanmean=True` will use `np.nanmean` instead of `np.mean`. It should be easy to adjust the code to compute arbitrary statistics, but at the moment only mean is supported. If you specify a `json_output_file`, a json file will be written that contains the result dictionary as well as additional information you can specify using the other `json_*` parameters:\n\n```python\njson = {\n    \"name\": json_name,  # experiment name, not yours\n    \"description\": json_description,  # a longer description so you know what you did\n    \"timestamp\": \"YYYY-MM-DD hh:mm:ss.ffffff\",  # automatically generated\n    \"task\": json_task,  # the decathlon task\n    \"author\": json_author,  # probably Fabian :)\n    \"results\": ...,  # the above dictionary\n    \"id\": \"001122334455\"  # hash of other entries as unique id\n}\n```\n\n`labels` is passed to the evaluator and `**metric_kwargs` is passed to all `.evaluate()` calls."
  },
  {
    "path": "experiments/UNetExperiment.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport pickle\n\nimport numpy as np\nimport torch\nimport torch.optim as optim\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\nimport torch.nn.functional as F\n\nfrom datasets.two_dim.NumpyDataLoader import NumpyDataSet\nfrom trixi.experiment.pytorchexperiment import PytorchExperiment\n\nfrom networks.UNET import UNet\nfrom loss_functions.dice_loss import SoftDiceLoss\n\n\nclass UNetExperiment(PytorchExperiment):\n    \"\"\"\n    The UnetExperiment is inherited from the PytorchExperiment. 
It implements the basic life cycle for a segmentation task with UNet(https://arxiv.org/abs/1505.04597).\n    It is optimized to work with the provided NumpyDataLoader.\n\n    The basic life cycle of a UnetExperiment is the same s PytorchExperiment:\n\n        setup()\n        (--> Automatically restore values if a previous checkpoint is given)\n        prepare()\n\n        for epoch in n_epochs:\n            train()\n            validate()\n            (--> save current checkpoint)\n\n        end()\n    \"\"\"\n\n    def setup(self):\n        pkl_dir = self.config.split_dir\n        with open(os.path.join(pkl_dir, \"splits.pkl\"), 'rb') as f:\n            splits = pickle.load(f)\n\n        tr_keys = splits[self.config.fold]['train']\n        val_keys = splits[self.config.fold]['val']\n        test_keys = splits[self.config.fold]['test']\n\n        self.device = torch.device(self.config.device if torch.cuda.is_available() else \"cpu\")\n\n        self.train_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,\n                                              keys=tr_keys)\n        self.val_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,\n                                            keys=val_keys, mode=\"val\", do_reshuffle=False)\n        self.test_data_loader = NumpyDataSet(self.config.data_test_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,\n                                             keys=test_keys, mode=\"test\", do_reshuffle=False)\n        self.model = UNet(num_classes=self.config.num_classes, in_channels=self.config.in_channels)\n\n        self.model.to(self.device)\n\n        # We use a combination of DICE-loss and CE-Loss in this example.\n        # This proved good in the medical segmentation decathlon.\n        self.dice_loss = SoftDiceLoss(batch_dice=True)  # Softmax for DICE Loss!\n        
self.ce_loss = torch.nn.CrossEntropyLoss()  # No softmax for CE Loss -> is implemented in torch!\n\n        self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.learning_rate)\n        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min')\n\n        # If directory for checkpoint is provided, we load it.\n        if self.config.do_load_checkpoint:\n            if self.config.checkpoint_dir == '':\n                print('checkpoint_dir is empty, please provide directory to load checkpoint.')\n            else:\n                self.load_checkpoint(name=self.config.checkpoint_dir, save_types=(\"model\",))\n\n        self.save_checkpoint(name=\"checkpoint_start\")\n        self.elog.print('Experiment set up.')\n\n    def train(self, epoch):\n        self.elog.print('=====TRAIN=====')\n        self.model.train()\n\n        data = None\n        batch_counter = 0\n        for data_batch in self.train_data_loader:\n\n            self.optimizer.zero_grad()\n\n            # Shape of data_batch = [1, b, c, w, h]\n            # Desired shape = [b, c, w, h]\n            # Move data and target to the GPU\n            data = data_batch['data'][0].float().to(self.device)\n            target = data_batch['seg'][0].long().to(self.device)\n\n            pred = self.model(data)\n            pred_softmax = F.softmax(pred, dim=1)  # We calculate a softmax, because our SoftDiceLoss expects that as an input. 
The CE-Loss does the softmax internally.\n\n            loss = self.dice_loss(pred_softmax, target.squeeze()) + self.ce_loss(pred, target.squeeze())\n            # loss = self.ce_loss(pred, target.squeeze())\n\n            loss.backward()\n            self.optimizer.step()\n\n            # Some logging and plotting\n            if (batch_counter % self.config.plot_freq) == 0:\n                self.elog.print('Epoch: {0} Loss: {1:.4f}'.format(self._epoch_idx, loss))\n\n                self.add_result(value=loss.item(), name='Train_Loss', tag='Loss', counter=epoch + (batch_counter / self.train_data_loader.data_loader.num_batches))\n\n                self.clog.show_image_grid(data.float().cpu(), name=\"data\", normalize=True, scale_each=True, n_iter=epoch)\n                self.clog.show_image_grid(target.float().cpu(), name=\"mask\", title=\"Mask\", n_iter=epoch)\n                self.clog.show_image_grid(pred.cpu()[:, 1:2, ], name=\"unt\", normalize=True, scale_each=True, n_iter=epoch)\n\n            batch_counter += 1\n\n        assert data is not None, 'data is None. Please check if your dataloader works properly'\n\n    def validate(self, epoch):\n        self.elog.print('VALIDATE')\n        self.model.eval()\n\n        data = None\n        loss_list = []\n\n        with torch.no_grad():\n            for data_batch in self.val_data_loader:\n                data = data_batch['data'][0].float().to(self.device)\n                target = data_batch['seg'][0].long().to(self.device)\n\n                pred = self.model(data)\n                pred_softmax = F.softmax(pred, dim=1)  # We calculate a softmax, because our SoftDiceLoss expects that as an input. The CE-Loss does the softmax internally.\n\n                loss = self.dice_loss(pred_softmax, target.squeeze()) + self.ce_loss(pred, target.squeeze())\n                loss_list.append(loss.item())\n\n        assert data is not None, 'data is None. 
Please check if your dataloader works properly'\n        self.scheduler.step(np.mean(loss_list))\n\n        self.elog.print('Epoch: %d Loss: %.4f' % (self._epoch_idx, float(np.mean(loss_list))))\n\n        self.add_result(value=np.mean(loss_list), name='Val_Loss', tag='Loss', counter=epoch+1)\n\n        self.clog.show_image_grid(data.float().cpu(), name=\"data_val\", normalize=True, scale_each=True, n_iter=epoch)\n        self.clog.show_image_grid(target.float().cpu(), name=\"mask_val\", title=\"Mask\", n_iter=epoch)\n        self.clog.show_image_grid(pred.data.cpu()[:, 1:2, ], name=\"unt_val\", normalize=True, scale_each=True, n_iter=epoch)\n\n    def test(self):\n        from evaluation.evaluator import aggregate_scores, Evaluator\n        from collections import defaultdict\n\n        self.elog.print('=====TEST=====')\n        self.model.eval()\n\n        pred_dict = defaultdict(list)\n        gt_dict = defaultdict(list)\n\n        batch_counter = 0\n        with torch.no_grad():\n            for data_batch in self.test_data_loader:\n                print('testing...', batch_counter)\n                batch_counter += 1\n\n                # Get data_batches\n                mr_data = data_batch['data'][0].float().to(self.device)\n                mr_target = data_batch['seg'][0].float().to(self.device)\n\n                pred = self.model(mr_data)\n                pred_argmax = torch.argmax(pred.data.cpu(), dim=1, keepdim=True)\n\n                fnames = data_batch['fnames']\n                for i, fname in enumerate(fnames):\n                    pred_dict[fname[0]].append(pred_argmax[i].detach().cpu().numpy())\n                    gt_dict[fname[0]].append(mr_target[i].detach().cpu().numpy())\n\n        test_ref_list = []\n        for key in pred_dict.keys():\n            test_ref_list.append((np.stack(pred_dict[key]), np.stack(gt_dict[key])))\n\n        scores = aggregate_scores(test_ref_list, evaluator=Evaluator, json_author=self.config.author, 
json_task=self.config.name, json_name=self.config.name,\n                                  json_output_file=self.elog.work_dir + \"/{}_\".format(self.config.author) + self.config.name + '.json')\n\n        print(\"Scores:\\n\", scores)\n\n    def segment_single_image(self, data):\n        self.model = UNet(num_classes=self.config.num_classes, in_channels=self.config.in_channels)\n        self.device = torch.device(self.config.device if torch.cuda.is_available() else \"cpu\")\n\n        # a model must be present and loaded in here\n        if self.config.model_dir == '':\n            print('model_dir is empty, please provide directory to load checkpoint.')\n        else:\n            self.load_checkpoint(name=self.config.model_dir, save_types=(\"model\",))\n\n        self.elog.print(\"=====SEGMENT_SINGLE_IMAGE=====\")\n        self.model.eval()\n        self.model.to(self.device)\n\n        # Desired shape = [b, c, w, h]\n        # split into even chunks (lets use size)\n        with torch.no_grad():\n\n            ######\n            # When working entirely on CPU and in memory, the following lines replace the split/concat method\n            # mr_data = data.float().to(self.device)\n            # pred = self.model(mr_data)\n            # pred_argmax = torch.argmax(pred.data.cpu(), dim=1, keepdim=True)\n            ######\n\n            ######\n            # for CUDA (also works on CPU) split into batches\n            blocksize = self.config.batch_size\n\n            # number_of_elements = round(data.shape[0]/blocksize+0.5)     # make blocks large enough to not lose any slices\n            chunks = [data[i:i+blocksize, ::, ::, ::] for i in range(0, data.shape[0], blocksize)]\n            pred_list = []\n            for data_batch in chunks:\n                mr_data = data_batch.float().to(self.device)\n                pred_dict = self.model(mr_data)\n                pred_list.append(pred_dict.cpu())\n\n            pred = torch.Tensor(np.concatenate(pred_list))\n    
        pred_argmax = torch.argmax(pred, dim=1, keepdim=True)\n\n        # detach result and put it back to cpu so that we can work with, create a numpy array\n        result = pred_argmax.short().detach().cpu().numpy()\n\n        return result\n"
  },
  {
    "path": "experiments/UNetExperiment3D.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport pickle\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch\nimport torch.optim as optim\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\n\nfrom datasets.three_dim.NumpyDataLoader import NumpyDataSet\nfrom trixi.experiment.pytorchexperiment import PytorchExperiment\n\nfrom networks.RecursiveUNet3D import UNet3D\nfrom loss_functions.dice_loss import SoftDiceLoss, DC_and_CE_loss\n\n\nclass UNetExperiment3D(PytorchExperiment):\n    \"\"\"\n    The UnetExperiment is inherited from the PytorchExperiment. 
It implements the basic life cycle for a segmentation task with UNet(https://arxiv.org/abs/1505.04597).\n    It is optimized to work with the provided NumpyDataLoader.\n\n    The basic life cycle of a UnetExperiment is the same s PytorchExperiment:\n\n        setup()\n        (--> Automatically restore values if a previous checkpoint is given)\n        prepare()\n\n        for epoch in n_epochs:\n            train()\n            validate()\n            (--> save current checkpoint)\n\n        end()\n    \"\"\"\n\n    def setup(self):\n        pkl_dir = self.config.split_dir\n        with open(os.path.join(pkl_dir, \"splits.pkl\"), 'rb') as f:\n            splits = pickle.load(f)\n\n        tr_keys = splits[self.config.fold]['train']\n        val_keys = splits[self.config.fold]['val']\n        test_keys = splits[self.config.fold]['test']\n\n        self.device = torch.device(self.config.device if torch.cuda.is_available() else \"cpu\")\n\n        self.train_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,\n                                              keys=tr_keys)\n        self.val_data_loader = NumpyDataSet(self.config.data_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,\n                                            keys=val_keys, mode=\"val\", do_reshuffle=False)\n        self.test_data_loader = NumpyDataSet(self.config.data_test_dir, target_size=self.config.patch_size, batch_size=self.config.batch_size,\n                                             keys=test_keys, mode=\"test\", do_reshuffle=False)\n        self.model = UNet3D(num_classes=3, in_channels=1)\n\n        self.model.to(self.device)\n\n        # We use a combination of DICE-loss and CE-Loss in this example.\n        # This proved good in the medical segmentation decathlon.\n        self.loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'smooth_in_nom': True,\n                                    
'do_bg': False, 'rebalance_weights': None, 'background_weight': 1}, OrderedDict())\n\n        self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.learning_rate)\n        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min')\n\n        # If directory for checkpoint is provided, we load it.\n        if self.config.do_load_checkpoint:\n            if self.config.checkpoint_dir == '':\n                print('checkpoint_dir is empty, please provide directory to load checkpoint.')\n            else:\n                self.load_checkpoint(name=self.config.checkpoint_dir, save_types=(\"model\",))\n\n        self.save_checkpoint(name=\"checkpoint_start\")\n        self.elog.print('Experiment set up.')\n\n    def train(self, epoch):\n        self.elog.print('=====TRAIN=====')\n        self.model.train()\n\n        batch_counter = 0\n        for data_batch in self.train_data_loader:\n\n            self.optimizer.zero_grad()\n\n            # Shape of data_batch = [1, b, c, w, h]\n            # Desired shape = [b, c, w, h]\n            # Move data and target to the GPU\n            data = data_batch['data'][0].float().to(self.device)\n            target = data_batch['seg'][0].long().to(self.device)\n\n            pred = self.model(data)\n\n            loss = self.loss(pred, target.squeeze())\n            # loss = self.ce_loss(pred, target.squeeze())\n            loss.backward()\n            self.optimizer.step()\n\n            # Some logging and plotting\n            if (batch_counter % self.config.plot_freq) == 0:\n                self.elog.print('Epoch: %d Loss: %.4f' % (self._epoch_idx, loss))\n\n                self.add_result(value=loss.item(), name='Train_Loss', tag='Loss', counter=epoch + (batch_counter / self.train_data_loader.data_loader.num_batches))\n\n                self.clog.show_image_grid(data[:,:,30].float(), name=\"data\", normalize=True, scale_each=True, n_iter=epoch)\n                self.clog.show_image_grid(target[:,:,30].float(), 
name=\"mask\", title=\"Mask\", n_iter=epoch)\n                self.clog.show_image_grid(torch.argmax(pred.cpu(), dim=1, keepdim=True)[:,:,30], name=\"unt_argmax\", title=\"Unet\", n_iter=epoch)\n\n            batch_counter += 1\n\n    def validate(self, epoch):\n        if epoch % 5 != 0:\n            return\n        self.elog.print('VALIDATE')\n        self.model.eval()\n\n        data = None\n        loss_list = []\n\n        with torch.no_grad():\n            for data_batch in self.val_data_loader:\n                data = data_batch['data'][0].float().to(self.device)\n                target = data_batch['seg'][0].long().to(self.device)\n\n                pred = self.model(data)\n\n                loss = self.loss(pred, target.squeeze())\n                loss_list.append(loss.item())\n\n        assert data is not None, 'data is None. Please check if your dataloader works properly'\n        self.scheduler.step(np.mean(loss_list))\n\n        self.elog.print('Epoch: %d Loss: %.4f' % (self._epoch_idx, float(np.mean(loss_list))))\n\n        self.add_result(value=np.mean(loss_list), name='Val_Loss', tag='Loss', counter=epoch+1)\n\n        self.clog.show_image_grid(data[:,:,30].float(), name=\"data_val\", normalize=True, scale_each=True, n_iter=epoch)\n        self.clog.show_image_grid(target[:,:,30].float(), name=\"mask_val\", title=\"Mask\", n_iter=epoch)\n        self.clog.show_image_grid(torch.argmax(pred.data.cpu()[:,:,30], dim=1, keepdim=True), name=\"unt_argmax_val\", title=\"Unet\", n_iter=epoch)\n\n    def test(self):\n        # TODO\n        print('TODO: Implement your test() method here')\n"
  },
  {
    "path": "experiments/__init__.py",
    "content": ""
  },
  {
    "path": "loss_functions/ND_Crossentropy.py",
    "content": "import torch\n\n\nclass CrossentropyND(torch.nn.CrossEntropyLoss):\n    \"\"\"\n    Network has to have NO NONLINEARITY!\n    \"\"\"\n    def forward(self, inp, target):\n        target = target.long()\n        num_classes = inp.size()[1]\n\n        i0 = 1\n        i1 = 2\n\n        while i1 < len(inp.shape): # this is ugly but torch only allows to transpose two axes at once\n            inp = inp.transpose(i0, i1)\n            i0 += 1\n            i1 += 1\n\n        inp = inp.contiguous()\n        inp = inp.view(-1, num_classes)\n\n        target = target.view(-1,)\n\n        return super(CrossentropyND, self).forward(inp, target)"
  },
  {
    "path": "loss_functions/__init__.py",
    "content": ""
  },
  {
    "path": "loss_functions/dice_loss.py",
    "content": "import torch\nimport numpy as np\nfrom loss_functions.ND_Crossentropy import CrossentropyND\nfrom loss_functions.topk_loss import TopKLoss\nfrom torch import nn\n\n\ndef softmax_helper(x):\n    rpt = [1 for _ in range(len(x.size()))]\n    rpt[1] = x.size(1)\n    x_max = x.max(1, keepdim=True)[0].repeat(*rpt)\n    e_x = torch.exp(x - x_max)\n    return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)\n\n\ndef get_tp_fp_fn(net_output, gt, axes=None, mask=None):\n    \"\"\"\n    net_output must be (b, c, x, y(, z)))\n    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))\n    if mask is provided it must have shape (b, 1, x, y(, z)))\n    :param net_output:\n    :param gt:\n    :param axes:\n    :param mask:\n    :return:\n    \"\"\"\n    if axes is None:\n        axes = tuple(range(2, len(net_output.size())))\n\n    shp_x = net_output.shape\n    shp_y = gt.shape\n\n    with torch.no_grad():\n        if len(shp_x) != len(shp_y):\n            gt = gt.view((shp_y[0], 1, *shp_y[1:]))\n\n        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):\n            # if this is the case then gt is probably already a one hot encoding\n            y_onehot = gt\n        else:\n            gt = gt.long()\n            y_onehot = torch.zeros(shp_x)\n            if net_output.device.type == \"cuda\":\n                y_onehot = y_onehot.cuda(net_output.device.index)\n            y_onehot.scatter_(1, gt, 1)\n\n    tp = net_output * y_onehot\n    fp = net_output * (1 - y_onehot)\n    fn = (1 - net_output) * y_onehot\n\n    if mask is not None:\n        tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)\n        fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)\n        fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)\n\n    tp = sum_tensor(tp, axes, keepdim=False)\n    fp = sum_tensor(fp, axes, 
keepdim=False)\n    fn = sum_tensor(fn, axes, keepdim=False)\n\n    return tp, fp, fn\n\n\ndef sum_tensor(inp, axes, keepdim=False):\n    axes = np.unique(axes).astype(int)\n    if keepdim:\n        for ax in axes:\n            inp = inp.sum(int(ax), keepdim=True)\n    else:\n        for ax in sorted(axes, reverse=True):\n            inp = inp.sum(int(ax))\n    return inp\n\n\ndef mean_tensor(inp, axes, keepdim=False):\n    axes = np.unique(axes).astype(int)\n    if keepdim:\n        for ax in axes:\n            inp = inp.mean(int(ax), keepdim=True)\n    else:\n        for ax in sorted(axes, reverse=True):\n            inp = inp.mean(int(ax))\n    return inp\n\n\nclass SoftDiceLoss(nn.Module):\n    def __init__(self, smooth=1., apply_nonlin=None, batch_dice=False, do_bg=True, smooth_in_nom=True, background_weight=1, rebalance_weights=None):\n        \"\"\"\n        hahaa no documentation for you today\n        :param smooth:\n        :param apply_nonlin:\n        :param batch_dice:\n        :param do_bg:\n        :param smooth_in_nom:\n        :param background_weight:\n        :param rebalance_weights:\n        \"\"\"\n        super(SoftDiceLoss, self).__init__()\n        if not do_bg:\n            assert background_weight == 1, \"if there is no bg, then set background weight to 1 you dummy\"\n        self.rebalance_weights = rebalance_weights\n        self.background_weight = background_weight\n        if smooth_in_nom:\n            self.smooth_in_nom = smooth\n        else:\n            self.smooth_in_nom = 0\n        self.do_bg = do_bg\n        self.batch_dice = batch_dice\n        self.apply_nonlin = apply_nonlin\n        self.smooth = smooth\n        self.y_onehot = None\n\n    def forward(self, x, y):\n        with torch.no_grad():\n            y = y.long()\n        shp_x = x.shape\n        shp_y = y.shape\n        if self.apply_nonlin is not None:\n            x = self.apply_nonlin(x)\n        if len(shp_x) != len(shp_y):\n            y = y.view((shp_y[0], 
1, *shp_y[1:]))\n        # now x and y should have shape (B, C, X, Y(, Z))) and (B, 1, X, Y(, Z))), respectively\n        y_onehot = torch.zeros(shp_x)\n        if x.device.type == \"cuda\":\n            y_onehot = y_onehot.cuda(x.device.index)\n        y_onehot.scatter_(1, y, 1)\n        if not self.do_bg:\n            x = x[:, 1:]\n            y_onehot = y_onehot[:, 1:]\n        if not self.batch_dice:\n            if self.background_weight != 1 or (self.rebalance_weights is not None):\n                raise NotImplementedError(\"nah son\")\n            l = soft_dice(x, y_onehot, self.smooth, self.smooth_in_nom)\n        else:\n            l = soft_dice_per_batch_2(x, y_onehot, self.smooth, self.smooth_in_nom,\n                                      background_weight=self.background_weight,\n                                      rebalance_weights=self.rebalance_weights)\n        return l\n\n\ndef soft_dice_per_batch(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1):\n    axes = tuple([0] + list(range(2, len(net_output.size()))))\n    intersect = sum_tensor(net_output * gt, axes, keepdim=False)\n    denom = sum_tensor(net_output + gt, axes, keepdim=False)\n    weights = torch.ones(intersect.shape)\n    weights[0] = background_weight\n    if net_output.device.type == \"cuda\":\n        weights = weights.cuda(net_output.device.index)\n    result = (- ((2 * intersect + smooth_in_nom) / (denom + smooth)) * weights).mean()\n    return result\n\n\ndef soft_dice_per_batch_2(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1, rebalance_weights=None):\n    if rebalance_weights is not None and len(rebalance_weights) != gt.shape[1]:\n        rebalance_weights = rebalance_weights[1:] # this is the case when use_bg=False\n    axes = tuple([0] + list(range(2, len(net_output.size()))))\n    tp = sum_tensor(net_output * gt, axes, keepdim=False)\n    fn = sum_tensor((1 - net_output) * gt, axes, keepdim=False)\n    fp = sum_tensor(net_output * (1 - gt), 
axes, keepdim=False)\n    weights = torch.ones(tp.shape)\n    weights[0] = background_weight\n    if net_output.device.type == \"cuda\":\n        weights = weights.cuda(net_output.device.index)\n    if rebalance_weights is not None:\n        rebalance_weights = torch.from_numpy(rebalance_weights).float()\n        if net_output.device.type == \"cuda\":\n            rebalance_weights = rebalance_weights.cuda(net_output.device.index)\n        tp = tp * rebalance_weights\n        fn = fn * rebalance_weights\n    result = (- ((2 * tp + smooth_in_nom) / (2 * tp + fp + fn + smooth)) * weights).mean()\n    return result\n\n\ndef soft_dice(net_output, gt, smooth=1., smooth_in_nom=1.):\n    axes = tuple(range(2, len(net_output.size())))\n    intersect = sum_tensor(net_output * gt, axes, keepdim=False)\n    denom = sum_tensor(net_output + gt, axes, keepdim=False)\n    result = (- ((2 * intersect + smooth_in_nom) / (denom + smooth))).mean()\n    return result\n\n\nclass MultipleOutputLoss(nn.Module):\n    def __init__(self, loss, weight_factors=None):\n        \"\"\"\n        use this if you have several outputs that should predict the same y\n        :param loss:\n        :param weight_factors:\n        \"\"\"\n        super(MultipleOutputLoss, self).__init__()\n        self.weight_factors = weight_factors\n        self.loss = loss\n\n    def forward(self, x, y):\n        assert isinstance(x, (tuple, list)), \"x must be either tuple or list\"\n        if self.weight_factors is None:\n            weights = [1] * len(x)\n        else:\n            weights = self.weight_factors\n        l = weights[0] * self.loss(x[0], y)\n        for i in range(1, len(x)):\n            l += weights[i] * self.loss(x[i], y)\n        return l\n\n\nclass DC_and_CE_loss(nn.Module):\n    def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate=\"sum\"):\n        super(DC_and_CE_loss, self).__init__()\n        self.aggregate = aggregate\n        self.ce = CrossentropyND(**ce_kwargs)\n        self.dc 
= SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs)\n\n    def forward(self, net_output, target):\n        dc_loss = self.dc(net_output, target)\n        ce_loss = self.ce(net_output, target)\n        if self.aggregate == \"sum\":\n            result = ce_loss + dc_loss\n        else:\n            raise NotImplementedError(\"nah son\") # reserved for other stuff (later)\n        return result\n\n\nclass DC_and_topk_loss(nn.Module):\n    def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate=\"sum\"):\n        super(DC_and_topk_loss, self).__init__()\n        self.aggregate = aggregate\n        self.ce = TopKLoss(**ce_kwargs)\n        self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs)\n\n    def forward(self, net_output, target):\n        dc_loss = self.dc(net_output, target)\n        ce_loss = self.ce(net_output, target)\n        if self.aggregate == \"sum\":\n            result = ce_loss + dc_loss\n        else:\n            raise NotImplementedError(\"nah son\") # reserved for other stuff (later?)\n        return result\n\n\nclass CrossentropyWithLossMask(nn.CrossEntropyLoss):\n    def __init__(self, k=None):\n        \"\"\"\n        This implementation ignores weight, ignore_index (use loss mask!) 
and reduction!\n        :param k:\n        \"\"\"\n        super(CrossentropyWithLossMask, self).__init__(weight=None, ignore_index=-100, reduction='none')\n        self.k = k\n\n    def forward(self, inp, target, loss_mask=None):\n        target = target.long()\n        inp = inp.float()\n        if loss_mask is not None:\n            loss_mask = loss_mask.float()\n        num_classes = inp.size()[1]\n\n        i0 = 1\n        i1 = 2\n\n        while i1 < len(inp.shape): # this is ugly but torch only allows to transpose two axes at once\n            inp = inp.transpose(i0, i1)\n            i0 += 1\n            i1 += 1\n\n        if not inp.is_contiguous():\n            inp = inp.contiguous()\n\n        inp = inp.view(target.shape[0], -1, num_classes)\n\n        target = target.view(target.shape[0], -1)\n        if loss_mask is not None:\n            loss_mask = loss_mask.view(target.shape[0], -1)\n\n        if self.k is not None:\n            if loss_mask is not None:\n                num_sel = torch.stack(tuple([i.sum() / self.k for i in torch.unbind(loss_mask, 0)]), 0).long()\n                loss = torch.stack(tuple([\n                    torch.topk(super(CrossentropyWithLossMask, self).forward(inp[i], target[i])[loss_mask[i].byte()],\n                               num_sel[i], sorted=False)[0].mean()\n                    for i in range(target.shape[0])\n                ])\n                )\n            else:\n                num_sel = [np.prod(inp.shape[2:]) / self.k] * inp.shape[0]\n                loss = torch.stack(tuple([\n                    torch.topk(super(CrossentropyWithLossMask, self).forward(inp[i], target[i]),\n                               num_sel[i], sorted=False)[0].mean()\n                    for i in range(target.shape[0])\n                ])\n                )\n        else:\n            if loss_mask is not None:\n                loss = torch.stack(tuple([\n                    super(CrossentropyWithLossMask, self).forward(inp[i], 
target[i])[loss_mask[i].byte()].mean()\n                    for i in range(target.shape[0])\n                ])\n                )\n            else:\n                loss = torch.stack(tuple([\n                    super(CrossentropyWithLossMask, self).forward(inp[i], target[i]).mean()\n                    for i in range(target.shape[0])\n                ])\n                )\n\n        loss = loss.mean()\n        return loss\n"
  },
  {
    "path": "loss_functions/topk_loss.py",
    "content": "import numpy as np\nimport torch\nfrom loss_functions.ND_Crossentropy import CrossentropyND\n\n\nclass TopKLoss(CrossentropyND):\n    \"\"\"\n    Network has to have NO NONLINEARITY!\n    \"\"\"\n\n    def __init__(self, weight=None, ignore_index=-100, k=10):\n        self.k = k\n        super(TopKLoss, self).__init__(weight, False, ignore_index, reduce=False)\n\n    def forward(self, inp, target):\n        target = target[:, 0].long()\n        res = super(TopKLoss, self).forward(inp, target)\n        num_voxels = np.prod(res.shape)\n        res, _ = torch.topk(res.view((-1,)), int(num_voxels // self.k), sorted=False)\n        return res.mean()"
  },
  {
    "path": "networks/RecursiveUNet.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Defines the Unet.\n# |num_downs|: number of downsamplings in UNet. For example,\n# if |num_downs| == 7, image of size 128x128 will become of size 1x1 at the bottleneck\n\n# recursive implementation of Unet\nimport torch\n\nfrom torch import nn\n\n\nclass UNet(nn.Module):\n    def __init__(self, num_classes=3, in_channels=1, initial_filter_size=64, kernel_size=3, num_downs=4, norm_layer=nn.InstanceNorm2d):\n        # norm_layer=nn.BatchNorm2d, use_dropout=False):\n        super(UNet, self).__init__()\n\n        # construct unet structure\n        unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-1), out_channels=initial_filter_size * 2 ** num_downs,\n                                             num_classes=num_classes, kernel_size=kernel_size, norm_layer=norm_layer, innermost=True)\n        for i in range(1, num_downs):\n            unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-(i+1)),\n                                                 out_channels=initial_filter_size * 2 ** (num_downs-i),\n                                                 num_classes=num_classes, kernel_size=kernel_size, submodule=unet_block, norm_layer=norm_layer)\n        unet_block = 
UnetSkipConnectionBlock(in_channels=in_channels, out_channels=initial_filter_size,\n                                             num_classes=num_classes, kernel_size=kernel_size, submodule=unet_block, norm_layer=norm_layer,\n                                             outermost=True)\n\n        self.model = unet_block\n\n    def forward(self, x):\n        return self.model(x)\n\n\n# Defines the submodule with skip connection.\n# X -------------------identity---------------------- X\n#   |-- downsampling -- |submodule| -- upsampling --|\nclass UnetSkipConnectionBlock(nn.Module):\n    def __init__(self, in_channels=None, out_channels=None, num_classes=1, kernel_size=3,\n                 submodule=None, outermost=False, innermost=False, norm_layer=nn.InstanceNorm2d, use_dropout=False):\n        super(UnetSkipConnectionBlock, self).__init__()\n        self.outermost = outermost\n        # downconv\n        pool = nn.MaxPool2d(2, stride=2)\n        conv1 = self.contract(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)\n        conv2 = self.contract(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)\n\n        # upconv\n        conv3 = self.expand(in_channels=out_channels*2, out_channels=out_channels, kernel_size=kernel_size)\n        conv4 = self.expand(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size)\n\n        if outermost:\n            final = nn.Conv2d(out_channels, num_classes, kernel_size=1)\n            down = [conv1, conv2]\n            up = [conv3, conv4, final]\n            model = down + [submodule] + up\n        elif innermost:\n            upconv = nn.ConvTranspose2d(in_channels*2, in_channels,\n                                        kernel_size=2, stride=2)\n            model = [pool, conv1, conv2, upconv]\n        else:\n            upconv = nn.ConvTranspose2d(in_channels*2, in_channels, kernel_size=2, stride=2)\n\n           
 down = [pool, conv1, conv2]\n            up = [conv3, conv4, upconv]\n\n            if use_dropout:\n                model = down + [submodule] + up + [nn.Dropout(0.5)]\n            else:\n                model = down + [submodule] + up\n\n        self.model = nn.Sequential(*model)\n\n    @staticmethod\n    def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.InstanceNorm2d):\n        layer = nn.Sequential(\n            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),\n            norm_layer(out_channels),\n            nn.LeakyReLU(inplace=True))\n        return layer\n\n    @staticmethod\n    def expand(in_channels, out_channels, kernel_size=3):\n        layer = nn.Sequential(\n            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),\n            nn.LeakyReLU(inplace=True),\n        )\n        return layer\n\n    @staticmethod\n    def center_crop(layer, target_width, target_height):\n        batch_size, n_channels, layer_width, layer_height = layer.size()\n        xy1 = (layer_width - target_width) // 2\n        xy2 = (layer_height - target_height) // 2\n        return layer[:, :, xy1:(xy1 + target_width), xy2:(xy2 + target_height)]\n\n    def forward(self, x):\n        if self.outermost:\n            return self.model(x)\n        else:\n            crop = self.center_crop(self.model(x), x.size()[2], x.size()[3])\n            return torch.cat([x, crop], 1)\n"
  },
  {
    "path": "networks/RecursiveUNet3D.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Defines the Unet.\n# |num_downs|: number of downsamplings in UNet. For example,\n# if |num_downs| == 7, image of size 128x128 will become of size 1x1 at the bottleneck\n\n# recursive implementation of Unet\nimport torch\n\nfrom torch import nn\n\n\nclass UNet3D(nn.Module):\n    def __init__(self, num_classes=3, in_channels=1, initial_filter_size=64, kernel_size=3, num_downs=3, norm_layer=nn.InstanceNorm3d):\n        # norm_layer=nn.BatchNorm2d, use_dropout=False):\n        super(UNet3D, self).__init__()\n\n        # construct unet structure\n        unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-1), out_channels=initial_filter_size * 2 ** num_downs,\n                                             num_classes=num_classes, kernel_size=kernel_size, norm_layer=norm_layer, innermost=True)\n        for i in range(1, num_downs):\n            unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-(i+1)),\n                                                 out_channels=initial_filter_size * 2 ** (num_downs-i),\n                                                 num_classes=num_classes, kernel_size=kernel_size, submodule=unet_block, norm_layer=norm_layer)\n        unet_block = 
UnetSkipConnectionBlock(in_channels=in_channels, out_channels=initial_filter_size,\n                                             num_classes=num_classes, kernel_size=kernel_size, submodule=unet_block, norm_layer=norm_layer,\n                                             outermost=True)\n\n        self.model = unet_block\n\n    def forward(self, x):\n        return self.model(x)\n\n\n# Defines the submodule with skip connection.\n# X -------------------identity---------------------- X\n#   |-- downsampling -- |submodule| -- upsampling --|\nclass UnetSkipConnectionBlock(nn.Module):\n    def __init__(self, in_channels=None, out_channels=None, num_classes=1, kernel_size=3,\n                 submodule=None, outermost=False, innermost=False, norm_layer=nn.InstanceNorm3d, use_dropout=False):\n        super(UnetSkipConnectionBlock, self).__init__()\n        self.outermost = outermost\n        # downconv\n        pool = nn.MaxPool3d(2, stride=2)\n        conv1 = self.contract(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)\n        conv2 = self.contract(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)\n\n        # upconv\n        conv3 = self.expand(in_channels=out_channels*2, out_channels=out_channels, kernel_size=kernel_size)\n        conv4 = self.expand(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size)\n\n        if outermost:\n            final = nn.Conv3d(out_channels, num_classes, kernel_size=1)\n            down = [conv1, conv2]\n            up = [conv3, conv4, final]\n            model = down + [submodule] + up\n        elif innermost:\n            upconv = nn.ConvTranspose3d(in_channels*2, in_channels,\n                                        kernel_size=2, stride=2)\n            model = [pool, conv1, conv2, upconv]\n        else:\n            upconv = nn.ConvTranspose3d(in_channels*2, in_channels, kernel_size=2, stride=2)\n\n           
 down = [pool, conv1, conv2]\n            up = [conv3, conv4, upconv]\n\n            if use_dropout:\n                model = down + [submodule] + up + [nn.Dropout(0.5)]\n            else:\n                model = down + [submodule] + up\n\n        self.model = nn.Sequential(*model)\n\n    @staticmethod\n    def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.InstanceNorm3d):\n        layer = nn.Sequential(\n            nn.Conv3d(in_channels, out_channels, kernel_size, padding=1),\n            norm_layer(out_channels),\n            nn.LeakyReLU(inplace=True))\n        return layer\n\n    @staticmethod\n    def expand(in_channels, out_channels, kernel_size=3):\n        layer = nn.Sequential(\n            nn.Conv3d(in_channels, out_channels, kernel_size, padding=1),\n            nn.LeakyReLU(inplace=True),\n        )\n        return layer\n\n    @staticmethod\n    def center_crop(layer, target_depth, target_width, target_height):\n        batch_size, n_channels, layer_depth, layer_width, layer_height = layer.size()\n        xy0 = (layer_depth - target_depth) // 2\n        xy1 = (layer_width - target_width) // 2\n        xy2 = (layer_height - target_height) // 2\n        return layer[:, :, xy0:(xy0 + target_depth), xy1:(xy1 + target_width), xy2:(xy2 + target_height)]\n\n    def forward(self, x):\n        if self.outermost:\n            return self.model(x)\n        else:\n            crop = self.center_crop(self.model(x), x.size()[2], x.size()[3], x.size()[4])\n            return torch.cat([x, crop], 1)\n"
  },
  {
    "path": "networks/UNET.py",
    "content": "#!/usr/bin/env python\r\n# -*- coding: utf-8 -*-\r\n#\r\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport torch\r\nimport torch.nn as nn\r\n\r\n\r\nclass UNet(nn.Module):\r\n\r\n    def __init__(self, num_classes, in_channels=1, initial_filter_size=64, kernel_size=3, do_instancenorm=True):\r\n        super().__init__()\r\n\r\n        self.contr_1_1 = self.contract(in_channels, initial_filter_size, kernel_size, instancenorm=do_instancenorm)\r\n        self.contr_1_2 = self.contract(initial_filter_size, initial_filter_size, kernel_size, instancenorm=do_instancenorm)\r\n        self.pool = nn.MaxPool2d(2, stride=2)\r\n\r\n        self.contr_2_1 = self.contract(initial_filter_size, initial_filter_size*2, kernel_size, instancenorm=do_instancenorm)\r\n        self.contr_2_2 = self.contract(initial_filter_size*2, initial_filter_size*2, kernel_size, instancenorm=do_instancenorm)\r\n        # self.pool2 = nn.MaxPool2d(2, stride=2)\r\n\r\n        self.contr_3_1 = self.contract(initial_filter_size*2, initial_filter_size*2**2, kernel_size, instancenorm=do_instancenorm)\r\n        self.contr_3_2 = self.contract(initial_filter_size*2**2, initial_filter_size*2**2, kernel_size, instancenorm=do_instancenorm)\r\n        # self.pool3 = nn.MaxPool2d(2, stride=2)\r\n\r\n        self.contr_4_1 = self.contract(initial_filter_size*2**2, 
initial_filter_size*2**3, kernel_size, instancenorm=do_instancenorm)\r\n        self.contr_4_2 = self.contract(initial_filter_size*2**3, initial_filter_size*2**3, kernel_size, instancenorm=do_instancenorm)\r\n        # self.pool4 = nn.MaxPool2d(2, stride=2)\r\n\r\n        self.center = nn.Sequential(\r\n            nn.Conv2d(initial_filter_size*2**3, initial_filter_size*2**4, 3, padding=1),\r\n            nn.ReLU(inplace=True),\r\n            nn.Conv2d(initial_filter_size*2**4, initial_filter_size*2**4, 3, padding=1),\r\n            nn.ReLU(inplace=True),\r\n            nn.ConvTranspose2d(initial_filter_size*2**4, initial_filter_size*2**3, 2, stride=2),\r\n            nn.ReLU(inplace=True),\r\n        )\r\n\r\n        self.expand_4_1 = self.expand(initial_filter_size*2**4, initial_filter_size*2**3)\r\n        self.expand_4_2 = self.expand(initial_filter_size*2**3, initial_filter_size*2**3)\r\n        self.upscale4 = nn.ConvTranspose2d(initial_filter_size*2**3, initial_filter_size*2**2, kernel_size=2, stride=2)\r\n\r\n        self.expand_3_1 = self.expand(initial_filter_size*2**3, initial_filter_size*2**2)\r\n        self.expand_3_2 = self.expand(initial_filter_size*2**2, initial_filter_size*2**2)\r\n        self.upscale3 = nn.ConvTranspose2d(initial_filter_size*2**2, initial_filter_size*2, 2, stride=2)\r\n\r\n        self.expand_2_1 = self.expand(initial_filter_size*2**2, initial_filter_size*2)\r\n        self.expand_2_2 = self.expand(initial_filter_size*2, initial_filter_size*2)\r\n        self.upscale2 = nn.ConvTranspose2d(initial_filter_size*2, initial_filter_size, 2, stride=2)\r\n\r\n        self.expand_1_1 = self.expand(initial_filter_size*2, initial_filter_size)\r\n        self.expand_1_2 = self.expand(initial_filter_size, initial_filter_size)\r\n        # Output layer for segmentation\r\n        self.final = nn.Conv2d(initial_filter_size, num_classes, kernel_size=1)  # kernel size for final layer = 1, see paper\r\n\r\n        self.softmax = 
torch.nn.Softmax2d()\r\n\r\n        # Output layer for \"autoencoder-mode\"\r\n        self.output_reconstruction_map = nn.Conv2d(initial_filter_size, out_channels=1, kernel_size=1)\r\n\r\n    @staticmethod\r\n    def contract(in_channels, out_channels, kernel_size=3, instancenorm=True):\r\n        if instancenorm:\r\n            layer = nn.Sequential(\r\n                nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),\r\n                nn.InstanceNorm2d(out_channels),\r\n                nn.LeakyReLU(inplace=True))\r\n        else:\r\n            layer = nn.Sequential(\r\n                nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),\r\n                nn.LeakyReLU(inplace=True))\r\n        return layer\r\n\r\n    @staticmethod\r\n    def expand(in_channels, out_channels, kernel_size=3):\r\n        layer = nn.Sequential(\r\n            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),\r\n            nn.LeakyReLU(inplace=True),\r\n            )\r\n        return layer\r\n\r\n    @staticmethod\r\n    def center_crop(layer, target_width, target_height):\r\n        batch_size, n_channels, layer_width, layer_height = layer.size()\r\n        xy1 = (layer_width - target_width) // 2\r\n        xy2 = (layer_height - target_height) // 2\r\n        return layer[:, :, xy1:(xy1 + target_width), xy2:(xy2 + target_height)]\r\n\r\n    def forward(self, x, enable_concat=True, print_layer_shapes=False):\r\n        concat_weight = 1\r\n        if not enable_concat:\r\n            concat_weight = 0\r\n\r\n        contr_1 = self.contr_1_2(self.contr_1_1(x))\r\n        pool = self.pool(contr_1)\r\n\r\n        contr_2 = self.contr_2_2(self.contr_2_1(pool))\r\n        pool = self.pool(contr_2)\r\n\r\n        contr_3 = self.contr_3_2(self.contr_3_1(pool))\r\n        pool = self.pool(contr_3)\r\n\r\n        contr_4 = self.contr_4_2(self.contr_4_1(pool))\r\n        pool = self.pool(contr_4)\r\n\r\n        center = self.center(pool)\r\n\r\n        crop 
= self.center_crop(contr_4, center.size()[2], center.size()[3])\r\n        concat = torch.cat([center, crop*concat_weight], 1)\r\n\r\n        expand = self.expand_4_2(self.expand_4_1(concat))\r\n        upscale = self.upscale4(expand)\r\n\r\n        crop = self.center_crop(contr_3, upscale.size()[2], upscale.size()[3])\r\n        concat = torch.cat([upscale, crop*concat_weight], 1)\r\n\r\n        expand = self.expand_3_2(self.expand_3_1(concat))\r\n        upscale = self.upscale3(expand)\r\n\r\n        crop = self.center_crop(contr_2, upscale.size()[2], upscale.size()[3])\r\n        concat = torch.cat([upscale, crop*concat_weight], 1)\r\n\r\n        expand = self.expand_2_2(self.expand_2_1(concat))\r\n        upscale = self.upscale2(expand)\r\n\r\n        crop = self.center_crop(contr_1, upscale.size()[2], upscale.size()[3])\r\n        concat = torch.cat([upscale, crop*concat_weight], 1)\r\n\r\n        expand = self.expand_1_2(self.expand_1_1(concat))\r\n\r\n        if enable_concat:\r\n            output = self.final(expand)\r\n        if not enable_concat:\r\n            output = self.output_reconstruction_map(expand)\r\n\r\n        return output\r\n"
  },
  {
    "path": "requirements.txt",
    "content": "googledrivedownloader==0.4\nMedPy==0.4.0\ntorch==1.3.1\ntorchfile==0.1.0\ntrixi==0.1.2.1\nbatchgenerators==0.19.3\n\n# Workaround for scipy issues\nscipy==1.1.0\n\n# Workaround for slackclient version issues\nslackclient==2.0.0\n\n# Fix compatibility issues\ntorchvision==0.4.2\n\n"
  },
  {
    "path": "run_preprocessing.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nfrom configs.Config_unet import get_config\nfrom datasets.example_dataset.create_splits import create_splits\nfrom datasets.utils import download_dataset\nfrom datasets.example_dataset.preprocessing import preprocess_data\n\nif __name__ == \"__main__\":\n    c = get_config()\n\n    download_dataset(dest_path=c.data_root_dir, dataset=c.dataset_name, id=c.google_drive_id)\n\n    print('Preprocessing data. [STARTED]')\n    preprocess_data(root_dir=os.path.join(c.data_root_dir, c.dataset_name))\n    create_splits(output_dir=c.split_dir, image_dir=c.data_dir)\n    print('Preprocessing data. [DONE]')\n"
  },
  {
    "path": "run_train_pipeline.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom os.path import exists\nfrom configs.Config_unet import get_config\nfrom datasets.example_dataset.create_splits import create_splits\nfrom datasets.utils import download_dataset\nfrom datasets.example_dataset.preprocessing import preprocess_data\nfrom experiments.UNetExperiment import UNetExperiment\n\n\nif __name__ == \"__main__\":\n    c = get_config()\n\n    # print(\"Executing: EPOCHS = {} / LEARNING RATE = {}\".format(c.n_epochs, c.learning_rate))\n\n    download_dataset(dest_path=c.data_root_dir, dataset=c.dataset_name, id=c.google_drive_id)\n\n    if not exists(os.path.join(os.path.join(c.data_root_dir, c.dataset_name), 'preprocessed')):\n        print('Preprocessing data. [STARTED]')\n        preprocess_data(root_dir=os.path.join(c.data_root_dir, c.dataset_name), y_shape=c.patch_size, z_shape=c.patch_size)\n        create_splits(output_dir=c.split_dir, image_dir=c.data_dir)\n        print('Preprocessing data. [DONE]')\n    else:\n        print('The data has already been preprocessed. It will not be preprocessed again. 
Delete the folder to enforce it.')\n\n    exp = UNetExperiment(config=c, name=c.name, n_epochs=c.n_epochs,\n                         seed=42, append_rnd_to_name=c.append_rnd_string, globs=globals(),\n                         # visdomlogger_kwargs={\"auto_start\": c.start_visdom},\n                         loggers={\n                             \"visdom\": (\"visdom\", {\"auto_start\": c.start_visdom})\n                         }\n                         )\n\n    exp.run()\n    exp.run_test(setup=False)\n"
  },
  {
    "path": "runner.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom configs.Config_unet_spleen import get_config\nimport subprocess\n\nif __name__ == \"__main__\":\n    c = get_config()\n    n_epochs = c.n_epochs\n    learning_rate = c.learning_rate\n    step = 0\n\n    while True:\n        result = subprocess.run(['python', 'run_train_pipeline.py',\n                                 '--n_epochs', '{}'.format(n_epochs),\n                                 '--learning_rate', '{}'.format(learning_rate)])\n\n        if divmod(step, 2)[1] == 0:\n            n_epochs = n_epochs + 20\n        else:\n            learning_rate = learning_rate / 2\n        step += 1\n"
  },
  {
    "path": "segment_a_spleen.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport sys\n\nfrom medpy.io import save\n\nfrom configs.Config_unet_spleen import get_config\nfrom datasets.spleen.preprocessing import preprocess_single_file, postprocess_single_image\nfrom experiments.UNetExperiment import UNetExperiment\n\n\ndef save_single_image(image, image_header, filename):\n    # medpy.io.save\n    save(image, filename, image_header)\n    print('> Resulting Image stored as {}'.format(filename))\n\n\nif __name__ == \"__main__\":\n    c = get_config()\n\n    if len(sys.argv) == 1:\n        print(\"USAGE:\\n\\npython {} imagefilename [model_checkpoint [shapesize]]\\n\\n\"\n              \"  imagefilename - a filename that stores a nii.gz formatted file.\\n\"\n              \"  model_checkpoint - a checkpoint filename to reload\\n\"\n              \"  shapesize - optional value that defines \"\n              \"the size of the shape, default is 64 (not yet used).\".format(sys.argv[0]))\n        filename = \"data/Task09_Spleen/imagesTs/spleen_15.nii.gz\"\n    else:\n        filename = sys.argv[1]\n\n    print(\"Loading and processing file {}\".format(filename))\n\n    if len(sys.argv) > 2:\n        c.checkpoint_dir = sys.argv[2]\n        c.do_load_checkpoint = True\n\n    print(\"Loading model from checkpoint 
{}\".format(c.model_dir))\n    if len(c.model_dir) == 0 or not os.path.isdir(os.path.split(c.model_dir)[0]):\n        print(\"ERROR /!\\\\: No checkpoint dir is set, please provide in Config file.\")\n        exit()\n\n    shapesize = 64\n    if len(sys.argv) > 3:\n        shapesize = int(sys.argv[3])\n\n    # Get the header in order to preserve voxel dimensions to store the segmented image later on\n    print('Preprocessing data.')\n    data, header = preprocess_single_file(filename, y_shape=shapesize, z_shape=shapesize)\n\n    print('Setting up model and start segmentation.')\n    exp = UNetExperiment(config=c, name=c.name, n_epochs=c.n_epochs,\n                         seed=42, append_rnd_to_name=c.append_rnd_string, globs=globals()\n                         )\n\n    result = exp.segment_single_image(data)\n\n    print('Postprocessing data.')\n    result = postprocess_single_image(result)\n\n    pathname, fname = os.path.split(filename)\n    destination_filename = pathname+\"/segmented_\"+fname\n    print('Saving file to disk: {}'.format(destination_filename))\n    save_single_image(result, header, destination_filename)\n"
  },
  {
    "path": "train.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport matplotlib\nmatplotlib.use('Agg')\n\nfrom configs.Config_unet import get_config\nfrom experiments.UNetExperiment import UNetExperiment\n\nif __name__ == \"__main__\":\n    c = get_config()\n\n    exp = UNetExperiment(config=c, name=c.name, n_epochs=c.n_epochs,\n                         seed=42, append_rnd_to_name=c.append_rnd_string, globs=globals(),\n                         # visdomlogger_kwargs={\"auto_start\": c.start_visdom},\n                         loggers={\n                             \"visdom\": (\"visdom\", {\"auto_start\": c.start_visdom}),\n                             # \"tb\": (\"tensorboard\"),\n                             # \"slack\": (\"slack\", {\"token\": \"XXXXXXXX\",\n                             #                     \"user_email\": \"x\"})\n                         }\n                         )\n\n    exp.run()\n    exp.run_test(setup=False)\n"
  },
  {
    "path": "train3D.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom configs.Config_unet import get_config\nfrom experiments.UNetExperiment3D import UNetExperiment3D\n\nif __name__ == \"__main__\":\n    c = get_config()\n\n    exp = UNetExperiment3D(config=c, name=c.name, n_epochs=c.n_epochs,\n                         seed=42, append_rnd_to_name=c.append_rnd_string, globs=globals(),\n                         # visdomlogger_kwargs={\"auto_start\": c.start_visdom},\n                         loggers={\n                             \"visdom\": (\"visdom\", {\"auto_start\": c.start_visdom}),\n                             # \"tb\": (\"tensorboard\"),\n                             # \"slack\": (\"slack\", {\"token\": \"XXXXXXXX\",\n                             #                     \"user_email\": \"x\"})\n                         }\n                         )\n    exp.run()\n    exp.run_test(setup=False)\n"
  },
  {
    "path": "utilities/__init__.py",
    "content": ""
  },
  {
    "path": "utilities/file_and_folder_operations.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\n\ndef subdirs(folder, join=True, prefix=None, suffix=None, sort=True):\n    if join:\n        l = os.path.join\n    else:\n        l = lambda x, y: y\n    res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))\n            and (prefix is None or i.startswith(prefix))\n            and (suffix is None or i.endswith(suffix))]\n    if sort:\n        res.sort()\n    return res\n\n\ndef subfiles(folder, join=True, prefix=None, suffix=None, sort=True):\n    if join:\n        l = os.path.join\n    else:\n        l = lambda x, y: y\n    res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))\n            and (prefix is None or i.startswith(prefix))\n            and (suffix is None or i.endswith(suffix))]\n    if sort:\n        res.sort()\n    return res\n\n\ndef maybe_mkdir_p(directory):\n    splits = directory.split(\"/\")[1:]\n    for i in range(0, len(splits)):\n        if not os.path.isdir(os.path.join(\"/\", *splits[:i+1])):\n            os.mkdir(os.path.join(\"/\", *splits[:i+1]))\n"
  }
]