Repository: Amshaker/unetr_plus_plus
Branch: main
Commit: dcba1c6e8179
Files: 155
Total size: 1.2 MB
Directory structure:
gitextract_3co428a0/
├── LICENSE
├── README.md
├── evaluation_scripts/
│ ├── run_evaluation_acdc.sh
│ ├── run_evaluation_lung.sh
│ ├── run_evaluation_synapse.sh
│ └── run_evaluation_tumor.sh
├── requirements.txt
├── training_scripts/
│ ├── run_training_acdc.sh
│ ├── run_training_lung.sh
│ ├── run_training_synapse.sh
│ └── run_training_tumor.sh
└── unetr_pp/
├── __init__.py
├── configuration.py
├── evaluation/
│ ├── __init__.py
│ ├── add_dummy_task_with_mean_over_all_tasks.py
│ ├── add_mean_dice_to_json.py
│ ├── collect_results_files.py
│ ├── evaluator.py
│ ├── metrics.py
│ ├── model_selection/
│ │ ├── __init__.py
│ │ ├── collect_all_fold0_results_and_summarize_in_one_csv.py
│ │ ├── ensemble.py
│ │ ├── figure_out_what_to_submit.py
│ │ ├── rank_candidates.py
│ │ ├── rank_candidates_StructSeg.py
│ │ ├── rank_candidates_cascade.py
│ │ ├── summarize_results_in_one_json.py
│ │ └── summarize_results_with_plans.py
│ ├── region_based_evaluation.py
│ ├── surface_dice.py
│ ├── unetr_pp_acdc_checkpoint/
│ │ └── unetr_pp/
│ │ └── 3d_fullres/
│ │ └── Task001_ACDC/
│ │ └── unetr_pp_trainer_acdc__unetr_pp_Plansv2.1/
│ │ └── fold_0/
│ │ └── .gitignore
│ ├── unetr_pp_lung_checkpoint/
│ │ └── unetr_pp/
│ │ └── 3d_fullres/
│ │ └── Task006_Lung/
│ │ └── unetr_pp_trainer_lung__unetr_pp_Plansv2.1/
│ │ └── fold_0/
│ │ └── .gitignore
│ ├── unetr_pp_synapse_checkpoint/
│ │ └── unetr_pp/
│ │ └── 3d_fullres/
│ │ └── Task002_Synapse/
│ │ └── unetr_pp_trainer_synapse__unetr_pp_Plansv2.1/
│ │ └── fold_0/
│ │ └── .gitignore
│ └── unetr_pp_tumor_checkpoint/
│ └── unetr_pp/
│ └── 3d_fullres/
│ └── Task003_tumor/
│ └── unetr_pp_trainer_tumor__unetr_pp_Plansv2.1/
│ └── fold_0/
│ └── .gitignore
├── experiment_planning/
│ ├── DatasetAnalyzer.py
│ ├── __init__.py
│ ├── alternative_experiment_planning/
│ │ ├── experiment_planner_baseline_3DUNet_v21_11GB.py
│ │ ├── experiment_planner_baseline_3DUNet_v21_16GB.py
│ │ ├── experiment_planner_baseline_3DUNet_v21_32GB.py
│ │ ├── experiment_planner_baseline_3DUNet_v21_3convperstage.py
│ │ ├── experiment_planner_baseline_3DUNet_v22.py
│ │ ├── experiment_planner_baseline_3DUNet_v23.py
│ │ ├── experiment_planner_residual_3DUNet_v21.py
│ │ ├── normalization/
│ │ │ ├── experiment_planner_2DUNet_v21_RGB_scaleto_0_1.py
│ │ │ ├── experiment_planner_3DUNet_CT2.py
│ │ │ └── experiment_planner_3DUNet_nonCT.py
│ │ ├── patch_size/
│ │ │ ├── experiment_planner_3DUNet_isotropic_in_mm.py
│ │ │ └── experiment_planner_3DUNet_isotropic_in_voxels.py
│ │ ├── pooling_and_convs/
│ │ │ ├── experiment_planner_baseline_3DUNet_allConv3x3.py
│ │ │ └── experiment_planner_baseline_3DUNet_poolBasedOnSpacing.py
│ │ └── target_spacing/
│ │ ├── experiment_planner_baseline_3DUNet_targetSpacingForAnisoAxis.py
│ │ ├── experiment_planner_baseline_3DUNet_v21_customTargetSpacing_2x2x2.py
│ │ └── experiment_planner_baseline_3DUNet_v21_noResampling.py
│ ├── change_batch_size.py
│ ├── common_utils.py
│ ├── experiment_planner_baseline_2DUNet.py
│ ├── experiment_planner_baseline_2DUNet_v21.py
│ ├── experiment_planner_baseline_3DUNet.py
│ ├── experiment_planner_baseline_3DUNet_v21.py
│ ├── nnFormer_convert_decathlon_task.py
│ ├── nnFormer_plan_and_preprocess.py
│ ├── summarize_plans.py
│ └── utils.py
├── inference/
│ ├── __init__.py
│ ├── inferTs/
│ │ └── swin_nomask_2/
│ │ └── plans.pkl
│ ├── predict.py
│ ├── predict_simple.py
│ └── segmentation_export.py
├── inference_acdc.py
├── inference_synapse.py
├── inference_tumor.py
├── network_architecture/
│ ├── README.md
│ ├── __init__.py
│ ├── acdc/
│ │ ├── __init__.py
│ │ ├── model_components.py
│ │ ├── transformerblock.py
│ │ └── unetr_pp_acdc.py
│ ├── dynunet_block.py
│ ├── generic_UNet.py
│ ├── initialization.py
│ ├── layers.py
│ ├── lung/
│ │ ├── __init__.py
│ │ ├── model_components.py
│ │ ├── transformerblock.py
│ │ └── unetr_pp_lung.py
│ ├── neural_network.py
│ ├── synapse/
│ │ ├── __init__.py
│ │ ├── model_components.py
│ │ ├── transformerblock.py
│ │ └── unetr_pp_synapse.py
│ └── tumor/
│ ├── __init__.py
│ ├── model_components.py
│ ├── transformerblock.py
│ └── unetr_pp_tumor.py
├── paths.py
├── postprocessing/
│ ├── connected_components.py
│ ├── consolidate_all_for_paper.py
│ ├── consolidate_postprocessing.py
│ └── consolidate_postprocessing_simple.py
├── preprocessing/
│ ├── cropping.py
│ ├── custom_preprocessors/
│ │ └── preprocessor_scale_RGB_to_0_1.py
│ ├── preprocessing.py
│ └── sanity_checks.py
├── run/
│ ├── __init__.py
│ ├── default_configuration.py
│ └── run_training.py
├── training/
│ ├── __init__.py
│ ├── cascade_stuff/
│ │ ├── __init__.py
│ │ └── predict_next_stage.py
│ ├── data_augmentation/
│ │ ├── __init__.py
│ │ ├── custom_transforms.py
│ │ ├── data_augmentation_insaneDA.py
│ │ ├── data_augmentation_insaneDA2.py
│ │ ├── data_augmentation_moreDA.py
│ │ ├── data_augmentation_noDA.py
│ │ ├── default_data_augmentation.py
│ │ ├── downsampling.py
│ │ └── pyramid_augmentations.py
│ ├── dataloading/
│ │ ├── __init__.py
│ │ └── dataset_loading.py
│ ├── learning_rate/
│ │ └── poly_lr.py
│ ├── loss_functions/
│ │ ├── TopK_loss.py
│ │ ├── __init__.py
│ │ ├── crossentropy.py
│ │ ├── deep_supervision.py
│ │ └── dice_loss.py
│ ├── model_restore.py
│ ├── network_training/
│ │ ├── Trainer_acdc.py
│ │ ├── Trainer_lung.py
│ │ ├── Trainer_synapse.py
│ │ ├── Trainer_tumor.py
│ │ ├── network_trainer_acdc.py
│ │ ├── network_trainer_lung.py
│ │ ├── network_trainer_synapse.py
│ │ ├── network_trainer_tumor.py
│ │ ├── unetr_pp_trainer_acdc.py
│ │ ├── unetr_pp_trainer_lung.py
│ │ ├── unetr_pp_trainer_synapse.py
│ │ └── unetr_pp_trainer_tumor.py
│ └── optimizer/
│ └── ranger.py
└── utilities/
├── __init__.py
├── distributed.py
├── file_conversions.py
├── file_endings.py
├── folder_names.py
├── nd_softmax.py
├── one_hot_encoding.py
├── overlay_plots.py
├── random_stuff.py
├── recursive_delete_npz.py
├── recursive_rename_taskXX_to_taskXXX.py
├── sitk_stuff.py
├── task_name_id_conversion.py
├── tensor_utilities.py
└── to_torch.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [2024] [Abdelrahman Mohamed Shaker Youssief]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation

[Abdelrahman Shaker](https://scholar.google.com/citations?hl=en&user=eEz4Wu4AAAAJ)*1, [Muhammad Maaz](https://scholar.google.com/citations?user=vTy9Te8AAAAJ&hl=en&authuser=1&oi=sra)1, [Hanoona Rasheed](https://scholar.google.com/citations?user=yhDdEuEAAAAJ&hl=en&authuser=1&oi=sra)1, [Salman Khan](https://salman-h-khan.github.io/)1, [Ming-Hsuan Yang](https://scholar.google.com/citations?user=p9-ohHsAAAAJ&hl=en)2,3 and [Fahad Shahbaz Khan](https://scholar.google.es/citations?user=zvaeYnUAAAAJ&hl=en)1,4
Mohamed Bin Zayed University of Artificial Intelligence1, University of California Merced2, Google Research3, Linkoping University4
[](https://ieeexplore.ieee.org/document/10526382)
[](https://amshaker.github.io/unetr_plus_plus/)
[](https://docs.google.com/presentation/d/e/2PACX-1vRtrxSfA2kU1fBmPxBdMQioLfsvjcjWBoaOVf3aupqajm0mw_C4TEz05Yk4ZF_vqoMyA8iiUJE60ynm/pub?start=true&loop=false&delayms=10000)
## :rocket: News
* **(Dec 2025):** UNETR++ is now available in Keras 3 with a simple model initialization as part of [AI Toolkit for Healthcare Imaging](https://github.com/innat/medic-ai/tree/main/medicai/models/unetr_plus_plus)!
* **(Oct 2025):** UNETR++ became the **#1 most popular article** of 2025 in [IEEE Transactions on Medical Imaging (TMI)](https://ieeexplore.ieee.org/xpl/topAccessedArticles.jsp?punumber=42)! 🎉
* **(May 04, 2024):** We're thrilled to share that UNETR++ has been accepted to IEEE TMI-2024! 🎊
* **(Jun 01, 2023):** UNETR++ code & weights are released for Decathlon-Lung and BRaTs.
* **(Dec 15, 2022):** UNETR++ weights are released for Synapse & ACDC datasets.
* **(Dec 09, 2022):** UNETR++ training and evaluation codes are released.

> **Abstract:** *Owing to the success of transformer models, recent works study their applicability in 3D medical segmentation tasks. Within the transformer models, the self-attention mechanism is one of the main building blocks that strives to capture long-range dependencies. However, the self-attention operation has quadratic complexity, which proves to be a computational bottleneck, especially in volumetric medical imaging, where the inputs are 3D with numerous slices. In this paper, we propose a 3D medical image segmentation approach, named UNETR++, that offers both high-quality segmentation masks as well as efficiency in terms of parameters, compute cost, and inference speed. The core of our design is the introduction of a novel efficient paired attention (EPA) block that efficiently learns spatial and channel-wise discriminative features using a pair of inter-dependent branches based on spatial and channel attention.
Our spatial attention formulation is efficient, having linear complexity with respect to the input sequence length. To enable communication between spatial and channel-focused branches, we share the weights of query and key mapping functions that provide a complementary benefit (paired attention), while also reducing the overall network parameters. Our extensive evaluations on five benchmarks, Synapse, BTCV, ACDC, BRaTs, and Decathlon-Lung, reveal the effectiveness of our contributions in terms of both efficiency and accuracy. On Synapse, our UNETR++ sets a new state-of-the-art with a Dice Score of 87.2%, while being significantly efficient with a reduction of over 71% in terms of both parameters and FLOPs, compared to the best method in the literature.*
## Architecture overview of UNETR++
Overview of our UNETR++ approach with hierarchical encoder-decoder structure. The 3D patches are fed to the encoder, whose outputs are then connected to the decoder via skip connections followed by convolutional blocks to produce the final segmentation mask. The focus of our design is the introduction of an _efficient paired-attention_ (EPA) block. Each EPA block performs two tasks using parallel attention modules with shared keys-queries and different value layers to efficiently learn enriched spatial-channel feature representations. As illustrated in the EPA block diagram (on the right), the first (top) attention module aggregates the spatial features by a weighted sum of the projected features in a linear manner to compute the spatial attention maps, while the second (bottom) attention module emphasizes the dependencies in the channels and computes the channel attention maps. Finally, the outputs of the two attention modules are fused and passed to convolutional blocks to enhance the feature representation, leading to better segmentation masks.

## Results
### Synapse Dataset
State-of-the-art comparison on the abdominal multi-organ Synapse dataset. We report both the segmentation performance (DSC, HD95) and model complexity (parameters and FLOPs).
Our proposed UNETR++ achieves favorable segmentation performance against existing methods, while considerably reducing the model complexity. Best results are in bold.
Abbreviations stand for: Spl: _spleen_, RKid: _right kidney_, LKid: _left kidney_, Gal: _gallbladder_, Liv: _liver_, Sto: _stomach_, Aor: _aorta_, Pan: _pancreas_.
Best results are in bold.

## Qualitative Comparison
### Synapse Dataset
Qualitative comparison on multi-organ segmentation task. Here, we compare our UNETR++ with existing methods: UNETR, Swin UNETR, and nnFormer.
The different abdominal organs are shown in the legend below the examples. Existing methods struggle to correctly segment different organs (marked in red dashed box).
Our UNETR++ achieves promising segmentation performance by accurately segmenting the organs.

### ACDC Dataset
Qualitative comparison on the ACDC dataset. We compare our UNETR++ with existing methods: UNETR and nnFormer. It is noticeable that the existing methods struggle to correctly segment different organs (marked in red dashed box). Our UNETR++ achieves favorable segmentation performance by accurately segmenting the organs.

## Installation
The code is tested with PyTorch 1.11.0 and CUDA 11.3. After cloning the repository, follow the below steps for installation,
1. Create and activate conda environment
```shell
conda create --name unetr_pp python=3.8
conda activate unetr_pp
```
2. Install PyTorch and torchvision
```shell
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
```
3. Install other dependencies
```shell
pip install -r requirements.txt
```
## Dataset
We follow the same dataset preprocessing as in [nnFormer](https://github.com/282857341/nnFormer). We conducted extensive experiments on five benchmarks: Synapse, BTCV, ACDC, BRaTs, and Decathlon-Lung.
The dataset folders for Synapse should be organized as follows:
```
./DATASET_Synapse/
├── unetr_pp_raw/
├── unetr_pp_raw_data/
├── Task02_Synapse/
├── imagesTr/
├── imagesTs/
├── labelsTr/
├── labelsTs/
├── dataset.json
├── Task002_Synapse
├── unetr_pp_cropped_data/
├── Task002_Synapse
```
The dataset folders for ACDC should be organized as follows:
```
./DATASET_Acdc/
├── unetr_pp_raw/
├── unetr_pp_raw_data/
├── Task01_ACDC/
├── imagesTr/
├── imagesTs/
├── labelsTr/
├── labelsTs/
├── dataset.json
├── Task001_ACDC
├── unetr_pp_cropped_data/
├── Task001_ACDC
```
The dataset folders for Decathlon-Lung should be organized as follows:
```
./DATASET_Lungs/
├── unetr_pp_raw/
├── unetr_pp_raw_data/
├── Task06_Lung/
├── imagesTr/
├── imagesTs/
├── labelsTr/
├── labelsTs/
├── dataset.json
├── Task006_Lung
├── unetr_pp_cropped_data/
├── Task006_Lung
```
The dataset folders for BRaTs should be organized as follows:
```
./DATASET_Tumor/
├── unetr_pp_raw/
├── unetr_pp_raw_data/
├── Task03_tumor/
├── imagesTr/
├── imagesTs/
├── labelsTr/
├── labelsTs/
├── dataset.json
├── Task003_tumor
├── unetr_pp_cropped_data/
├── Task003_tumor
```
Please refer to [Setting up the datasets](https://github.com/282857341/nnFormer) on nnFormer repository for more details.
Alternatively, you can download the preprocessed dataset for [Synapse](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/abdelrahman_youssief_mbzuai_ac_ae/EbHDhSjkQW5Ak9SMPnGCyb8BOID98wdg3uUvQ0eNvTZ8RA?e=YVhfdg), [ACDC](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/abdelrahman_youssief_mbzuai_ac_ae/EY9qieTkT3JFrhCJQiwZXdsB1hJ4ebVAtNdBNOs2HAo3CQ?e=VwfFHC), [Decathlon-Lung](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/abdelrahman_youssief_mbzuai_ac_ae/EWhU1T7c-mNKgkS2PQjFwP0B810LCiX3D2CvCES2pHDVSg?e=OqcIW3), [BRaTs](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/abdelrahman_youssief_mbzuai_ac_ae/EaQOxpD2yE5Btl-UEBAbQa0BYFBCL4J2Ph-VF_sqZlBPSQ?e=DFY41h), and extract it under the project directory.
## Training
The following scripts can be used for training our UNETR++ model on the datasets:
```shell
bash training_scripts/run_training_synapse.sh
bash training_scripts/run_training_acdc.sh
bash training_scripts/run_training_lung.sh
bash training_scripts/run_training_tumor.sh
```
## Evaluation
To reproduce the results of UNETR++:
1- Download [Synapse weights](https://drive.google.com/file/d/13JuLMeDQRR_a3c3tr2V2oav6I29fJoBa) and paste ```model_final_checkpoint.model``` in the following path:
```shell
unetr_pp/evaluation/unetr_pp_synapse_checkpoint/unetr_pp/3d_fullres/Task002_Synapse/unetr_pp_trainer_synapse__unetr_pp_Plansv2.1/fold_0/
```
Then, run
```shell
bash evaluation_scripts/run_evaluation_synapse.sh
```
2- Download [ACDC weights](https://drive.google.com/file/d/15YXiHai1zLc1ycmXaiSHetYbLGum3tV5) and paste ```model_final_checkpoint.model``` in the following path:
```shell
unetr_pp/evaluation/unetr_pp_acdc_checkpoint/unetr_pp/3d_fullres/Task001_ACDC/unetr_pp_trainer_acdc__unetr_pp_Plansv2.1/fold_0/
```
Then, run
```shell
bash evaluation_scripts/run_evaluation_acdc.sh
```
3- Download [Decathlon-Lung weights](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/abdelrahman_youssief_mbzuai_ac_ae/ETAlc8WTjV1BhZx7zwFpA8UBS4og6upb1qX2UKkypMoTjw?e=KfzAiG) and paste ```model_final_checkpoint.model``` in the following path:
```shell
unetr_pp/evaluation/unetr_pp_lung_checkpoint/unetr_pp/3d_fullres/Task006_Lung/unetr_pp_trainer_lung__unetr_pp_Plansv2.1/fold_0/
```
Then, run
```shell
bash evaluation_scripts/run_evaluation_lung.sh
```
4- Download [BRaTs weights](https://drive.google.com/file/d/1LiqnVKKv3DrDKvo6J0oClhIFirhaz5PG) and paste ```model_final_checkpoint.model``` in the following path:
```shell
unetr_pp/evaluation/unetr_pp_tumor_checkpoint/unetr_pp/3d_fullres/Task003_tumor/unetr_pp_trainer_tumor__unetr_pp_Plansv2.1/fold_0/
```
Then, run
```shell
bash evaluation_scripts/run_evaluation_tumor.sh
```
## Acknowledgement
This repository is built based on [nnFormer](https://github.com/282857341/nnFormer) repository.
## Citation
If you use our work, please consider citing:
```bibtex
@ARTICLE{10526382,
title={UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation},
author={Shaker, Abdelrahman M. and Maaz, Muhammad and Rasheed, Hanoona and Khan, Salman and Yang, Ming-Hsuan and Khan, Fahad Shahbaz},
journal={IEEE Transactions on Medical Imaging},
year={2024},
doi={10.1109/TMI.2024.3398728}}
```
## Contact
Should you have any question, please create an issue on this repository or contact me at abdelrahman.youssief@mbzuai.ac.ae.
================================================
FILE: evaluation_scripts/run_evaluation_acdc.sh
================================================
#!/bin/sh
# Run validation (-val) of the ACDC fold-0 checkpoint shipped with the repo.
# Intended to be launched from the evaluation_scripts/ directory.
DATA_ROOT="../DATASET_Acdc"
CKPT_DIR="../unetr_pp/evaluation/unetr_pp_acdc_checkpoint"

# Environment expected by the nnUNet-style pipeline: where to write results,
# where the raw data lives, and where the preprocessed task is found.
export PYTHONPATH=".././"
export RESULTS_FOLDER="$CKPT_DIR"
export unetr_pp_raw_data_base="$DATA_ROOT/unetr_pp_raw"
export unetr_pp_preprocessed="$DATA_ROOT/unetr_pp_raw/unetr_pp_raw_data/Task01_ACDC"

# Config 3d_fullres, trainer unetr_pp_trainer_acdc, task 1, fold 0, validation only.
python ../unetr_pp/run/run_training.py 3d_fullres unetr_pp_trainer_acdc 1 0 -val
================================================
FILE: evaluation_scripts/run_evaluation_lung.sh
================================================
#!/bin/sh
# Run validation (-val) of the Decathlon-Lung fold-0 checkpoint shipped with the repo.
# Intended to be launched from the evaluation_scripts/ directory.
DATA_ROOT="../DATASET_Lungs"
CKPT_DIR="../unetr_pp/evaluation/unetr_pp_lung_checkpoint"

# Environment expected by the nnUNet-style pipeline: where to write results,
# where the raw data lives, and where the preprocessed task is found.
export PYTHONPATH=".././"
export RESULTS_FOLDER="$CKPT_DIR"
export unetr_pp_raw_data_base="$DATA_ROOT/unetr_pp_raw"
export unetr_pp_preprocessed="$DATA_ROOT/unetr_pp_raw/unetr_pp_raw_data/Task06_Lung"

# Config 3d_fullres, trainer unetr_pp_trainer_lung, task 6, fold 0, validation only.
python ../unetr_pp/run/run_training.py 3d_fullres unetr_pp_trainer_lung 6 0 -val
================================================
FILE: evaluation_scripts/run_evaluation_synapse.sh
================================================
#!/bin/sh
# Run validation (-val) of the Synapse fold-0 checkpoint shipped with the repo.
# Intended to be launched from the evaluation_scripts/ directory.
DATA_ROOT="../DATASET_Synapse"
CKPT_DIR="../unetr_pp/evaluation/unetr_pp_synapse_checkpoint"

# Environment expected by the nnUNet-style pipeline: where to write results,
# where the raw data lives, and where the preprocessed task is found.
export PYTHONPATH=".././"
export RESULTS_FOLDER="$CKPT_DIR"
export unetr_pp_raw_data_base="$DATA_ROOT/unetr_pp_raw"
export unetr_pp_preprocessed="$DATA_ROOT/unetr_pp_raw/unetr_pp_raw_data/Task02_Synapse"

# Config 3d_fullres, trainer unetr_pp_trainer_synapse, task 2, fold 0, validation only.
python ../unetr_pp/run/run_training.py 3d_fullres unetr_pp_trainer_synapse 2 0 -val
================================================
FILE: evaluation_scripts/run_evaluation_tumor.sh
================================================
#!/bin/sh
# Run inference + evaluation for the BRaTs (tumor) fold-0 checkpoint.
# Intended to be launched from the evaluation_scripts/ directory, like the
# other evaluation scripts.
DATASET_PATH=../DATASET_Tumor
CHECKPOINT_PATH=../unetr_pp/evaluation/unetr_pp_tumor_checkpoint

export PYTHONPATH=.././
export RESULTS_FOLDER="$CHECKPOINT_PATH"
export unetr_pp_preprocessed="$DATASET_PATH"/unetr_pp_raw/unetr_pp_raw_data/Task03_tumor
export unetr_pp_raw_data_base="$DATASET_PATH"/unetr_pp_raw

# Only for Tumor, it is recommended to train unetr_plus_plus first, and then use
# the provided checkpoint to evaluate. It might raise issues regarding the pickle
# files if you evaluated without training.
#
# NOTE: input/output paths are derived from DATASET_PATH/CHECKPOINT_PATH so they
# resolve from evaluation_scripts/ like every other script; the previous
# hard-coded ../unetr_plus_plus/... paths did not match that working directory.
python ../unetr_pp/inference/predict_simple.py -i "$DATASET_PATH"/unetr_pp_raw/unetr_pp_raw_data/Task003_tumor/imagesTs -o "$CHECKPOINT_PATH"/inferTs -m 3d_fullres -t 3 -f 0 -chk model_final_checkpoint -tr unetr_pp_trainer_tumor
python ../unetr_pp/inference_tumor.py 0
================================================
FILE: requirements.txt
================================================
# argparse==1.4.0  # removed: argparse is in the standard library since Python 3.2
numpy==1.20.1
batchgenerators==0.21
matplotlib==3.5.1
# typing==3.7.4.3  # removed: stdlib since Python 3.5; the PyPI backport breaks on newer interpreters
# sklearn==0.0  # removed: deprecated/uninstallable alias of scikit-learn, which is pinned below
scikit-learn==1.0.2
tqdm==4.32.1
fvcore==0.1.5.post20220414
scikit-image==0.18.1
simpleitk==2.2.0
tifffile==2021.3.31
medpy==0.4.0
pandas==1.2.3
scipy==1.6.2
nibabel==4.0.1
timm==0.4.12
monai==0.7.0
einops==0.6.0
tensorboardX==2.2
================================================
FILE: training_scripts/run_training_acdc.sh
================================================
#!/bin/sh
# Train UNETR++ on ACDC (task 1, fold 0, 3d_fullres).
# Intended to be launched from the training_scripts/ directory.
DATA_ROOT="../DATASET_Acdc"

# Environment expected by the nnUNet-style pipeline.
export PYTHONPATH=".././"
export RESULTS_FOLDER="../output_acdc"
export unetr_pp_raw_data_base="$DATA_ROOT/unetr_pp_raw"
export unetr_pp_preprocessed="$DATA_ROOT/unetr_pp_raw/unetr_pp_raw_data/Task01_ACDC"

python ../unetr_pp/run/run_training.py 3d_fullres unetr_pp_trainer_acdc 1 0
================================================
FILE: training_scripts/run_training_lung.sh
================================================
#!/bin/sh
# Train UNETR++ on Decathlon-Lung (task 6, fold 0, 3d_fullres).
# Intended to be launched from the training_scripts/ directory.
DATA_ROOT="../DATASET_Lungs"

# Environment expected by the nnUNet-style pipeline.
export PYTHONPATH=".././"
export RESULTS_FOLDER="../output_lung"
export unetr_pp_raw_data_base="$DATA_ROOT/unetr_pp_raw"
export unetr_pp_preprocessed="$DATA_ROOT/unetr_pp_raw/unetr_pp_raw_data/Task06_Lung"

python ../unetr_pp/run/run_training.py 3d_fullres unetr_pp_trainer_lung 6 0
================================================
FILE: training_scripts/run_training_synapse.sh
================================================
#!/bin/sh
# Train UNETR++ on Synapse (task 2, fold 0, 3d_fullres).
# Intended to be launched from the training_scripts/ directory.
DATA_ROOT="../DATASET_Synapse"

# Environment expected by the nnUNet-style pipeline.
export PYTHONPATH=".././"
export RESULTS_FOLDER="../output_synapse"
export unetr_pp_raw_data_base="$DATA_ROOT/unetr_pp_raw"
export unetr_pp_preprocessed="$DATA_ROOT/unetr_pp_raw/unetr_pp_raw_data/Task02_Synapse"

python ../unetr_pp/run/run_training.py 3d_fullres unetr_pp_trainer_synapse 2 0
================================================
FILE: training_scripts/run_training_tumor.sh
================================================
#!/bin/sh
# Train UNETR++ on the brain tumor dataset: 3d_fullres configuration, task id 3, fold 0.
DATASET_PATH=../DATASET_Tumor
# Make the repository root importable so the `unetr_pp` package resolves.
export PYTHONPATH=.././
# All checkpoints and logs are written below this folder.
export RESULTS_FOLDER=../output_tumor
# Preprocessed / raw data locations read by the nnUNet-style pipeline.
# NOTE(review): spelled Task03_tumor here but Task003_tumor elsewhere in the
# repo — confirm which spelling the data layout actually uses.
export unetr_pp_preprocessed="$DATASET_PATH"/unetr_pp_raw/unetr_pp_raw_data/Task03_tumor
export unetr_pp_raw_data_base="$DATASET_PATH"/unetr_pp_raw
python ../unetr_pp/run/run_training.py 3d_fullres unetr_pp_trainer_tumor 3 0
================================================
FILE: unetr_pp/__init__.py
================================================
from __future__ import absolute_import
# Package initializer. NOTE(review): `from . import *` only pulls in names
# listed in an `__all__` defined here; since none is defined this is
# effectively a no-op placeholder kept for parity with the nnUNet/nnFormer
# package layout — confirm before relying on it for side-effect imports.
from . import *
================================================
FILE: unetr_pp/configuration.py
================================================
import os

# Number of worker processes used throughout the pipeline. Can be overridden
# with the `nnFormer_def_n_proc` environment variable; defaults to 8.
default_num_threads = int(os.environ.get('nnFormer_def_n_proc', 8))

# determines what threshold to use for resampling the low resolution axis
# separately (with NN)
RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD = 3
================================================
FILE: unetr_pp/evaluation/__init__.py
================================================
from __future__ import absolute_import
from . import *
================================================
FILE: unetr_pp/evaluation/add_dummy_task_with_mean_over_all_tasks.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import subfiles
import os
from collections import OrderedDict
# Aggregates per-task leaderboard jsons: for every experiment listed in
# task_descriptors, average the mean foreground score over all Decathlon
# tasks and write a synthetic "Task999_ALL" summary json for it.
folder = "/home/fabian/drives/E132-Projekte/Projects/2018_MedicalDecathlon/Leaderboard"
# Experiment names (the 'name' field inside the result jsons) to aggregate.
task_descriptors = ['2D final 2',
                    '2D final, less pool, dc and topK, fold0',
                    '2D final pseudo3d 7, fold0',
                    '2D final, less pool, dc and ce, fold0',
                    '3D stage0 final 2, fold0',
                    '3D fullres final 2, fold0']
# Tasks with no low-resolution stage; their stage0 entry is backfilled with
# the fullres score below so every experiment covers every task.
task_ids_with_no_stage0 = ["Task001_BrainTumour", "Task004_Hippocampus", "Task005_Prostate"]
mean_scores = OrderedDict()
for t in task_descriptors:
    mean_scores[t] = OrderedDict()
# Collect all result jsons (recursively), skipping hidden files.
json_files = subfiles(folder, True, None, ".json", True)
json_files = [i for i in json_files if not i.split("/")[-1].startswith(".")]  # stupid mac
for j in json_files:
    with open(j, 'r') as f:
        res = json.load(f)
    task = res['task']
    if task != "Task999_ALL":
        name = res['name']
        if name in task_descriptors:
            # Each (experiment, task) pair may appear only once.
            if task not in list(mean_scores[name].keys()):
                mean_scores[name][task] = res['results']['mean']['mean']
            else:
                raise RuntimeError("duplicate task %s for description %s" % (task, name))
# Backfill missing stage0 entries from the fullres experiment.
for t in task_ids_with_no_stage0:
    mean_scores["3D stage0 final 2, fold0"][t] = mean_scores["3D fullres final 2, fold0"][t]
# Union of all tasks seen across experiments, used as the completeness check.
a = set()
for i in mean_scores.keys():
    a = a.union(list(mean_scores[i].keys()))
for i in mean_scores.keys():
    try:
        # Every experiment must report every task, otherwise its global mean
        # would not be comparable with the others.
        for t in list(a):
            assert t in mean_scores[i].keys(), "did not find task %s for experiment %s" % (t, i)
        new_res = OrderedDict()
        new_res['name'] = i
        new_res['author'] = "Fabian"
        new_res['task'] = "Task999_ALL"
        new_res['results'] = OrderedDict()
        new_res['results']['mean'] = OrderedDict()
        new_res['results']['mean']['mean'] = OrderedDict()
        tasks = list(mean_scores[i].keys())
        # Assumes all tasks report the same metric set; take it from the first.
        metrics = mean_scores[i][tasks[0]].keys()
        for m in metrics:
            foreground_values = [mean_scores[i][n][m] for n in tasks]
            new_res['results']['mean']["mean"][m] = np.nanmean(foreground_values)
        output_fname = i.replace(" ", "_") + "_globalMean.json"
        with open(os.path.join(folder, output_fname), 'w') as f:
            json.dump(new_res, f)
    except AssertionError:
        # `t` still holds the task that failed the completeness check above.
        print("could not process experiment %s" % i)
        print("did not find task %s for experiment %s" % (t, i))
================================================
FILE: unetr_pp/evaluation/add_mean_dice_to_json.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import subfiles
from collections import OrderedDict
def foreground_mean(filename):
    """Add an averaged 'mean' entry to the 'mean' section of a summary json.

    Reads the evaluation summary at `filename`, averages every metric over
    all foreground classes (background 0 and the ignore labels -1 / 99 are
    excluded, and the '99' entry is removed from the output) and writes the
    file back in place with the averages under res['results']['mean']['mean'].

    :param filename: path to a summary json produced by aggregate_scores
    """
    with open(filename, 'r') as f:
        res = json.load(f)
    # Numeric class ids, with background (0) and ignore labels (-1, 99) dropped.
    class_ids = np.array([int(i) for i in res['results']['mean'].keys() if (i != 'mean')])
    class_ids = class_ids[(class_ids != 0) & (class_ids != -1) & (class_ids != 99)]
    # Label 99 is an ignore label; drop it from the output entirely.
    # (pop with default replaces the original get-then-pop two-step.)
    res['results']['mean'].pop('99', None)
    # Take the metric names from the first remaining foreground class instead
    # of hard-coding '1', so summaries without a label 1 also work.
    metrics = res['results']['mean'][str(class_ids[0])].keys()
    res['results']['mean']["mean"] = OrderedDict()
    for m in metrics:
        foreground_values = [res['results']['mean'][str(i)][m] for i in class_ids]
        # np.nanmean ignores NaN scores from cases where a class was absent.
        res['results']['mean']["mean"][m] = np.nanmean(foreground_values)
    with open(filename, 'w') as f:
        json.dump(res, f, indent=4, sort_keys=True)
def run_in_folder(folder):
    """Apply foreground_mean to every json in `folder`, skipping hidden files
    and previously generated *_globalMean.json outputs."""
    candidates = subfiles(folder, True, None, ".json", True)
    for path in candidates:
        basename = path.split("/")[-1]
        # Skip mac dot-files and the aggregated outputs of other scripts.
        if basename.startswith(".") or path.endswith("_globalMean.json"):
            continue
        foreground_mean(path)
if __name__ == "__main__":
    # Hard-coded path from the original author's machine; adjust before use.
    folder = "/media/fabian/Results/nnFormerOutput_final/summary_jsons"
    run_in_folder(folder)
================================================
FILE: unetr_pp/evaluation/collect_results_files.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from batchgenerators.utilities.file_and_folder_operations import subdirs, subfiles
def crawl_and_copy(current_folder, out_folder, prefix="fabian_", suffix="ummary.json"):
    """
    Recursively walk `current_folder` and copy every file whose name ends in
    `suffix` into `out_folder`, prepending a prefix that is accumulated from
    the folder names on the way down. Only folders whose path contains
    'fold0' contribute files.

    :param current_folder: root folder to crawl
    :param out_folder: destination folder for the copies
    :param prefix: prefix accumulated from parent folder names
    :param suffix: copy only files whose name ends with this suffix
    :return: None
    """
    subdirectories = subdirs(current_folder, join=False)
    matching = [name for name in subfiles(current_folder, join=False) if name.endswith(suffix)]
    if "fold0" in current_folder:
        for name in matching:
            shutil.copy(os.path.join(current_folder, name), os.path.join(out_folder, prefix + name))
    for sub in subdirectories:
        # Encode the folder path into the prefix so copies cannot collide.
        extension = sub if prefix == "" else "__" + sub
        crawl_and_copy(os.path.join(current_folder, sub), out_folder, prefix=prefix + extension)
if __name__ == "__main__":
    from unetr_pp.paths import network_training_output_dir
    # Hard-coded output path from the original author's machine; adjust before use.
    output_folder = "/home/fabian/PhD/results/nnFormerV2/leaderboard"
    crawl_and_copy(network_training_output_dir, output_folder)
    # Recompute the foreground means for all collected summary jsons.
    from unetr_pp.evaluation.add_mean_dice_to_json import run_in_folder
    run_in_folder(output_folder)
================================================
FILE: unetr_pp/evaluation/evaluator.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import inspect
import json
import hashlib
from datetime import datetime
from multiprocessing.pool import Pool
import numpy as np
import pandas as pd
import SimpleITK as sitk
from unetr_pp.evaluation.metrics import ConfusionMatrix, ALL_METRICS
from batchgenerators.utilities.file_and_folder_operations import save_json, subfiles, join
from collections import OrderedDict
class Evaluator:
    """Object that holds test and reference segmentations with label information
    and computes a number of metrics on the two. 'labels' must either be an
    iterable of numeric values (or tuples thereof) or a dictionary with string
    names and numeric values.
    """

    default_metrics = [
        "False Positive Rate",
        "Dice",
        "Jaccard",
        "Precision",
        "Recall",
        "Accuracy",
        "False Omission Rate",
        "Negative Predictive Value",
        "False Negative Rate",
        "True Negative Rate",
        "False Discovery Rate",
        "Total Positives Test",
        "Total Positives Reference"
    ]

    default_advanced_metrics = [
        #"Hausdorff Distance",
        "Hausdorff Distance 95",
        #"Avg. Surface Distance",
        #"Avg. Symmetric Surface Distance"
    ]

    def __init__(self,
                 test=None,
                 reference=None,
                 labels=None,
                 metrics=None,
                 advanced_metrics=None,
                 nan_for_nonexisting=True):
        """
        :param test: test (predicted) segmentation array, or None
        :param reference: reference (ground truth) segmentation array, or None
        :param labels: dict (value -> name), set, list, tuple or ndarray of labels;
            derived from the data if omitted and both segmentations are given
        :param metrics: metric names to compute (defaults to default_metrics)
        :param advanced_metrics: surface metrics to compute when
            evaluate(advanced=True) is used (defaults to default_advanced_metrics)
        :param nan_for_nonexisting: return NaN (instead of 0) for undefined scores
        """
        self.test = None
        self.reference = None
        self.confusion_matrix = ConfusionMatrix()
        self.labels = None
        self.nan_for_nonexisting = nan_for_nonexisting
        self.result = None
        # Copy the metric lists so instances never share (or mutate) the
        # class-level defaults or a caller's list.
        self.metrics = list(self.default_metrics) if metrics is None else list(metrics)
        self.advanced_metrics = (list(self.default_advanced_metrics)
                                 if advanced_metrics is None else list(advanced_metrics))
        self.set_reference(reference)
        self.set_test(test)
        if labels is not None:
            self.set_labels(labels)
        else:
            if test is not None and reference is not None:
                self.construct_labels()

    def set_test(self, test):
        """Set the test segmentation."""
        self.test = test

    def set_reference(self, reference):
        """Set the reference segmentation."""
        self.reference = reference

    def set_labels(self, labels):
        """Set the labels.
        :param labels= may be a dictionary (int->str), a set (of ints), a tuple (of ints) or a list (of ints). Labels
        will only have names if you pass a dictionary"""
        if isinstance(labels, dict):
            self.labels = collections.OrderedDict(labels)
        elif isinstance(labels, set):
            self.labels = list(labels)
        elif isinstance(labels, np.ndarray):
            self.labels = [i for i in labels]
        elif isinstance(labels, (list, tuple)):
            self.labels = labels
        else:
            raise TypeError("Can only handle dict, list, tuple, set & numpy array, but input is of type {}".format(type(labels)))

    def construct_labels(self):
        """Construct label set from unique entries in segmentations."""
        if self.test is None and self.reference is None:
            raise ValueError("No test or reference segmentations.")
        elif self.test is None:
            labels = np.unique(self.reference)
        else:
            labels = np.union1d(np.unique(self.test),
                                np.unique(self.reference))
        self.labels = list(map(lambda x: int(x), labels))

    def set_metrics(self, metrics):
        """Set evaluation metrics"""
        if isinstance(metrics, set):
            self.metrics = list(metrics)
        elif isinstance(metrics, (list, tuple, np.ndarray)):
            self.metrics = metrics
        else:
            raise TypeError("Can only handle list, tuple, set & numpy array, but input is of type {}".format(type(metrics)))

    def add_metric(self, metric):
        """Append `metric` to the metric list if it is not already present."""
        if metric not in self.metrics:
            self.metrics.append(metric)

    def evaluate(self, test=None, reference=None, advanced=False, **metric_kwargs):
        """Compute metrics for segmentations.

        :param test: optional new test segmentation
        :param reference: optional new reference segmentation
        :param advanced: also compute the advanced (surface distance) metrics
        :param metric_kwargs: forwarded to every metric function
        :return: OrderedDict {str(label or name): {metric: value}} (also stored
            in self.result)
        """
        if test is not None:
            self.set_test(test)
        if reference is not None:
            self.set_reference(reference)
        if self.test is None or self.reference is None:
            raise ValueError("Need both test and reference segmentations.")
        if self.labels is None:
            self.construct_labels()
        self.metrics.sort()

        # get functions for evaluation
        # somewhat convoluted, but allows users to define additonal metrics
        # on the fly, e.g. inside an IPython console
        _funcs = {m: ALL_METRICS[m] for m in self.metrics + self.advanced_metrics}
        frames = inspect.getouterframes(inspect.currentframe())
        for metric in self.metrics:
            for f in frames:
                if metric in f[0].f_locals:
                    _funcs[metric] = f[0].f_locals[metric]
                    break
            else:
                if metric in _funcs:
                    continue
                else:
                    raise NotImplementedError(
                        "Metric {} not implemented.".format(metric))

        # get results
        self.result = OrderedDict()
        # BUGFIX: build a fresh list instead of aliasing self.metrics and
        # extending it in place (`eval_metrics = self.metrics; eval_metrics +=
        # ...`), which made repeated evaluate(advanced=True) calls accumulate
        # duplicate advanced metrics in self.metrics.
        eval_metrics = self.metrics + self.advanced_metrics if advanced else list(self.metrics)

        if isinstance(self.labels, dict):
            for label, name in self.labels.items():
                k = str(name)
                self.result[k] = OrderedDict()
                if not hasattr(label, "__iter__"):
                    self.confusion_matrix.set_test(self.test == label)
                    self.confusion_matrix.set_reference(self.reference == label)
                else:
                    # An iterable of label values is evaluated as their union.
                    current_test = 0
                    current_reference = 0
                    for l in label:
                        current_test += (self.test == l)
                        current_reference += (self.reference == l)
                    self.confusion_matrix.set_test(current_test)
                    self.confusion_matrix.set_reference(current_reference)
                for metric in eval_metrics:
                    self.result[k][metric] = _funcs[metric](confusion_matrix=self.confusion_matrix,
                                                            nan_for_nonexisting=self.nan_for_nonexisting,
                                                            **metric_kwargs)
        else:
            for i, l in enumerate(self.labels):
                k = str(l)
                self.result[k] = OrderedDict()
                self.confusion_matrix.set_test(self.test == l)
                self.confusion_matrix.set_reference(self.reference == l)
                for metric in eval_metrics:
                    self.result[k][metric] = _funcs[metric](confusion_matrix=self.confusion_matrix,
                                                            nan_for_nonexisting=self.nan_for_nonexisting,
                                                            **metric_kwargs)

        return self.result

    def to_dict(self):
        """Return (computing if necessary) the result dictionary."""
        if self.result is None:
            self.evaluate()
        return self.result

    def to_array(self):
        """Return result as numpy array (labels x metrics)."""
        if self.result is None:
            # BUGFIX: was `self.evaluate` without parentheses, so the result
            # was never computed and the lookups below raised.
            self.evaluate()

        result_metrics = sorted(self.result[list(self.result.keys())[0]].keys())

        a = np.zeros((len(self.labels), len(result_metrics)), dtype=np.float32)

        if isinstance(self.labels, dict):
            for i, label in enumerate(self.labels.keys()):
                for j, metric in enumerate(result_metrics):
                    # self.result is keyed by str(name) — see evaluate().
                    a[i][j] = self.result[str(self.labels[label])][metric]
        else:
            for i, label in enumerate(self.labels):
                for j, metric in enumerate(result_metrics):
                    # self.result is keyed by str(label) — see evaluate().
                    a[i][j] = self.result[str(label)][metric]

        return a

    def to_pandas(self):
        """Return result as pandas DataFrame."""
        a = self.to_array()

        if isinstance(self.labels, dict):
            labels = list(self.labels.values())
        else:
            labels = self.labels

        result_metrics = sorted(self.result[list(self.result.keys())[0]].keys())

        return pd.DataFrame(a, index=labels, columns=result_metrics)
class NiftiEvaluator(Evaluator):
    """Evaluator that loads test/reference segmentations from NIfTI files."""

    def __init__(self, *args, **kwargs):
        # Keep the sitk images so voxel spacing can be queried in evaluate().
        self.test_nifti = None
        self.reference_nifti = None
        super(NiftiEvaluator, self).__init__(*args, **kwargs)

    def set_test(self, test):
        """Load the test segmentation from the file path `test`, or clear it."""
        if test is None:
            self.test_nifti = None
            super(NiftiEvaluator, self).set_test(test)
        else:
            self.test_nifti = sitk.ReadImage(test)
            super(NiftiEvaluator, self).set_test(sitk.GetArrayFromImage(self.test_nifti))

    def set_reference(self, reference):
        """Load the reference segmentation from the file path `reference`, or clear it."""
        if reference is None:
            self.reference_nifti = None
            super(NiftiEvaluator, self).set_reference(reference)
        else:
            self.reference_nifti = sitk.ReadImage(reference)
            super(NiftiEvaluator, self).set_reference(sitk.GetArrayFromImage(self.reference_nifti))

    def evaluate(self, test=None, reference=None, voxel_spacing=None, **metric_kwargs):
        # Default to the spacing stored in the test image header, reversed to
        # match the axis order numpy arrays get from GetArrayFromImage.
        if voxel_spacing is None:
            voxel_spacing = np.array(self.test_nifti.GetSpacing())[::-1]
        metric_kwargs["voxel_spacing"] = voxel_spacing
        return super(NiftiEvaluator, self).evaluate(test, reference, **metric_kwargs)
def run_evaluation(args):
    """Multiprocessing worker: score a single (test, reference) pair.

    :param args: 4-tuple (test, reference, evaluator instance, metric_kwargs)
    :return: the evaluator's score dict, annotated with "test"/"reference"
        entries when the corresponding inputs were given as file paths (str)
    """
    test, ref, evaluator, metric_kwargs = args
    evaluator.set_test(test)
    evaluator.set_reference(ref)
    # Derive labels from the data if none were configured up-front.
    if evaluator.labels is None:
        evaluator.construct_labels()
    scores = evaluator.evaluate(**metric_kwargs)
    # Record provenance so aggregated output can be traced back to files.
    if type(test) is str:
        scores["test"] = test
    if type(ref) is str:
        scores["reference"] = ref
    return scores
def aggregate_scores(test_ref_pairs,
                     evaluator=NiftiEvaluator,
                     labels=None,
                     nanmean=True,
                     json_output_file=None,
                     json_name="",
                     json_description="",
                     json_author="Fabian",
                     json_task="",
                     num_threads=2,
                     **metric_kwargs):
    """
    Evaluate many (test, reference) pairs in parallel and aggregate the scores.

    test = predicted image
    :param test_ref_pairs: list of (test_file, reference_file) pairs
    :param evaluator: Evaluator class or instance used for scoring
    :param labels: must be a dict of int-> str or a list of int
    :param nanmean: if True, NaN scores are ignored when averaging (np.nanmean)
    :param json_output_file: if given, the aggregated scores are saved here
    :param json_name: 'name' field of the saved json
    :param json_description: 'description' field of the saved json
    :param json_author: 'author' field of the saved json
    :param json_task: 'task' field of the saved json
    :param metric_kwargs: forwarded to every evaluate() call
    :return: OrderedDict with per-case scores under "all" and per-label
        averages under "mean"
    """
    if type(evaluator) == type:
        # A class (not an instance) was passed; instantiate it.
        evaluator = evaluator()

    if labels is not None:
        evaluator.set_labels(labels)

    all_scores = OrderedDict()
    all_scores["all"] = []
    all_scores["mean"] = OrderedDict()

    test = [i[0] for i in test_ref_pairs]
    ref = [i[1] for i in test_ref_pairs]
    # One shared evaluator is passed to every worker; multiprocessing pickles
    # it per task, so the workers operate on independent copies.
    p = Pool(num_threads)
    all_res = p.map(run_evaluation, zip(test, ref, [evaluator]*len(ref), [metric_kwargs]*len(ref)))
    p.close()
    p.join()

    for i in range(len(all_res)):
        all_scores["all"].append(all_res[i])
        # append score list for mean
        for label, score_dict in all_res[i].items():
            # "test"/"reference" are provenance entries, not labels.
            if label in ("test", "reference"):
                continue
            if label not in all_scores["mean"]:
                all_scores["mean"][label] = OrderedDict()
            for score, value in score_dict.items():
                if score not in all_scores["mean"][label]:
                    all_scores["mean"][label][score] = []
                all_scores["mean"][label][score].append(value)

    # Collapse the collected per-case score lists into their (nan)means.
    for label in all_scores["mean"]:
        for score in all_scores["mean"][label]:
            if nanmean:
                all_scores["mean"][label][score] = float(np.nanmean(all_scores["mean"][label][score]))
            else:
                all_scores["mean"][label][score] = float(np.mean(all_scores["mean"][label][score]))

    # save to file if desired
    # we create a hopefully unique id by hashing the entire output dictionary
    if json_output_file is not None:
        json_dict = OrderedDict()
        json_dict["name"] = json_name
        json_dict["description"] = json_description
        timestamp = datetime.today()
        json_dict["timestamp"] = str(timestamp)
        json_dict["task"] = json_task
        json_dict["author"] = json_author
        json_dict["results"] = all_scores
        json_dict["id"] = hashlib.md5(json.dumps(json_dict).encode("utf-8")).hexdigest()[:12]
        save_json(json_dict, json_output_file)

    return all_scores
def aggregate_scores_for_experiment(score_file,
                                    labels=None,
                                    metrics=Evaluator.default_metrics,
                                    nanmean=True,
                                    json_output_file=None,
                                    json_name="",
                                    json_description="",
                                    json_author="Fabian",
                                    json_task=""):
    """Summarize a saved score array of shape (cases x labels x metrics).

    Loads `score_file` with np.load, builds per-case results plus the mean
    over cases, and optionally writes everything as a json file.
    Note: `nanmean` is currently unused (means use plain ndarray.mean).

    :return: the assembled json dictionary
    """
    scores = np.load(score_file)
    scores_mean = scores.mean(0)
    if labels is None:
        # Default label names are just the stringified column indices.
        labels = list(map(str, range(scores.shape[1])))

    results_mean = OrderedDict()
    results = []
    for case_scores in scores:
        per_case = OrderedDict()
        for l, label in enumerate(labels):
            per_case[label] = OrderedDict()
            results_mean[label] = OrderedDict()
            for m, metric in enumerate(metrics):
                per_case[label][metric] = float(case_scores[l][m])
                results_mean[label][metric] = float(scores_mean[l][m])
        results.append(per_case)

    json_dict = OrderedDict()
    json_dict["name"] = json_name
    json_dict["description"] = json_description
    json_dict["timestamp"] = str(datetime.today())
    json_dict["task"] = json_task
    json_dict["author"] = json_author
    json_dict["results"] = {"all": results, "mean": results_mean}
    # Hopefully-unique id from hashing the whole dictionary.
    json_dict["id"] = hashlib.md5(json.dumps(json_dict).encode("utf-8")).hexdigest()[:12]

    if json_output_file is not None:
        with open(json_output_file, "w") as f:
            json.dump(json_dict, f, indent=4, separators=(",", ": "))

    return json_dict
def evaluate_folder(folder_with_gts: str, folder_with_predictions: str, labels: tuple, **metric_kwargs):
    """
    writes a summary.json to folder_with_predictions
    :param folder_with_gts: folder where the ground truth segmentations are saved. Must be nifti files.
    :param folder_with_predictions: folder where the predicted segmentations are saved. Must be nifti files.
    :param labels: tuple of int with the labels in the dataset. For example (0, 1, 2, 3) for Task001_BrainTumour.
    :return: the aggregated score dictionary from aggregate_scores
    """
    gt_files = subfiles(folder_with_gts, suffix=".nii.gz", join=False)
    pred_files = subfiles(folder_with_predictions, suffix=".nii.gz", join=False)
    # The two folders must contain exactly matching file names.
    pred_set = set(pred_files)
    gt_set = set(gt_files)
    assert all(f in pred_set for f in gt_files), "files missing in folder_with_predictions"
    assert all(f in gt_set for f in pred_files), "files missing in folder_with_gts"
    pairs = [(join(folder_with_predictions, f), join(folder_with_gts, f)) for f in pred_files]
    return aggregate_scores(pairs, json_output_file=join(folder_with_predictions, "summary.json"),
                            num_threads=8, labels=labels, **metric_kwargs)
def nnformer_evaluate_folder():
    """Command line entry point: evaluate predicted segmentations against a
    reference folder and write a summary.json into the prediction folder.

    Expects -ref and -pred folders with matching nifti file names and -l with
    the label ids to evaluate.
    """
    import argparse
    parser = argparse.ArgumentParser("Evaluates the segmentations located in the folder pred. Output of this script is "
                                     "a json file. At the very bottom of the json file is going to be a 'mean' "
                                     "entry with averages metrics across all cases")
    parser.add_argument('-ref', required=True, type=str, help="Folder containing the reference segmentations in nifti "
                                                              "format.")
    parser.add_argument('-pred', required=True, type=str, help="Folder containing the predicted segmentations in nifti "
                                                               "format. File names must match between the folders!")
    # Typo fix in the help text below: "gie" -> "give".
    parser.add_argument('-l', nargs='+', type=int, required=True, help="List of label IDs (integer values) that should "
                                                                       "be evaluated. Best practice is to use all int "
                                                                       "values present in the dataset, so for example "
                                                                       "for LiTS the labels are 0: background, 1: "
                                                                       "liver, 2: tumor. So this argument "
                                                                       "should be -l 1 2. You can if you want also "
                                                                       "evaluate the background label (0) but in "
                                                                       "this case that would not give any useful "
                                                                       "information.")
    args = parser.parse_args()
    return evaluate_folder(args.ref, args.pred, args.l)
================================================
FILE: unetr_pp/evaluation/metrics.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from medpy import metric
def assert_shape(test, reference):
    """Raise AssertionError when the two arrays differ in shape."""
    msg = "Shape mismatch: {} and {}".format(test.shape, reference.shape)
    assert test.shape == reference.shape, msg
class ConfusionMatrix:
    """Lazily computed confusion matrix between a test and a reference
    segmentation (any nonzero voxel counts as foreground). Derived counts are
    cached and invalidated whenever either input changes."""

    def __init__(self, test=None, reference=None):
        self.tp = None
        self.fp = None
        self.tn = None
        self.fn = None
        self.size = None
        self.reference_empty = None
        self.reference_full = None
        self.test_empty = None
        self.test_full = None
        self.set_reference(reference)
        self.set_test(test)

    def set_test(self, test):
        """Replace the test segmentation and drop cached statistics."""
        self.test = test
        self.reset()

    def set_reference(self, reference):
        """Replace the reference segmentation and drop cached statistics."""
        self.reference = reference
        self.reset()

    def reset(self):
        """Invalidate every cached statistic."""
        self.tp = self.fp = self.tn = self.fn = None
        self.size = None
        self.test_empty = self.test_full = None
        self.reference_empty = self.reference_full = None

    def compute(self):
        """Populate all cached counts from the current test/reference pair."""
        if self.test is None or self.reference is None:
            raise ValueError("'test' and 'reference' must both be set to compute confusion matrix.")
        assert_shape(self.test, self.reference)
        pred_fg = self.test != 0
        ref_fg = self.reference != 0
        self.tp = int((pred_fg & ref_fg).sum())
        self.fp = int((pred_fg & ~ref_fg).sum())
        self.tn = int((~pred_fg & ~ref_fg).sum())
        self.fn = int((~pred_fg & ref_fg).sum())
        self.size = int(np.prod(self.reference.shape, dtype=np.int64))
        self.test_empty = not np.any(self.test)
        self.test_full = np.all(self.test)
        self.reference_empty = not np.any(self.reference)
        self.reference_full = np.all(self.reference)

    def get_matrix(self):
        """Return (tp, fp, tn, fn), computing them if anything is stale."""
        if any(v is None for v in (self.tp, self.fp, self.tn, self.fn)):
            self.compute()
        return self.tp, self.fp, self.tn, self.fn

    def get_size(self):
        """Return the total voxel count of the reference."""
        if self.size is None:
            self.compute()
        return self.size

    def get_existence(self):
        """Return (test_empty, test_full, reference_empty, reference_full)."""
        flags = (self.test_empty, self.test_full, self.reference_empty, self.reference_full)
        if any(f is None for f in flags):
            self.compute()
        return self.test_empty, self.test_full, self.reference_empty, self.reference_full
def dice(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """2TP / (2TP + FP + FN)"""
    cm = confusion_matrix if confusion_matrix is not None else ConfusionMatrix(test, reference)
    tp, fp, _, fn = cm.get_matrix()
    test_empty, _, reference_empty, _ = cm.get_existence()
    # When both masks are empty the score is undefined rather than perfect.
    if test_empty and reference_empty:
        return float("NaN") if nan_for_nonexisting else 0.
    return float(2. * tp / (2 * tp + fp + fn))
def jaccard(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TP / (TP + FP + FN)"""
    cm = confusion_matrix if confusion_matrix is not None else ConfusionMatrix(test, reference)
    tp, fp, _, fn = cm.get_matrix()
    test_empty, _, reference_empty, _ = cm.get_existence()
    # When both masks are empty the score is undefined rather than perfect.
    if test_empty and reference_empty:
        return float("NaN") if nan_for_nonexisting else 0.
    return float(tp / (tp + fp + fn))
def precision(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TP / (TP + FP)"""
    cm = confusion_matrix if confusion_matrix is not None else ConfusionMatrix(test, reference)
    tp, fp, _, _ = cm.get_matrix()
    test_empty, _, _, _ = cm.get_existence()
    # With no predicted positives at all, precision is undefined.
    if test_empty:
        return float("NaN") if nan_for_nonexisting else 0.
    return float(tp / (tp + fp))
def sensitivity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TP / (TP + FN)"""
    cm = confusion_matrix if confusion_matrix is not None else ConfusionMatrix(test, reference)
    tp, _, _, fn = cm.get_matrix()
    _, _, reference_empty, _ = cm.get_existence()
    # With no actual positives in the reference, sensitivity is undefined.
    if reference_empty:
        return float("NaN") if nan_for_nonexisting else 0.
    return float(tp / (tp + fn))
def recall(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TP / (TP + FN) — recall is an alias for sensitivity."""
    result = sensitivity(test, reference, confusion_matrix, nan_for_nonexisting, **kwargs)
    return result
def specificity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TN / (TN + FP)"""
    cm = confusion_matrix if confusion_matrix is not None else ConfusionMatrix(test, reference)
    _, fp, tn, _ = cm.get_matrix()
    _, _, _, reference_full = cm.get_existence()
    # An all-foreground reference has no true negatives to measure.
    if reference_full:
        return float("NaN") if nan_for_nonexisting else 0.
    return float(tn / (tn + fp))
def accuracy(test=None, reference=None, confusion_matrix=None, **kwargs):
    """(TP + TN) / (TP + FP + FN + TN)"""
    cm = confusion_matrix if confusion_matrix is not None else ConfusionMatrix(test, reference)
    tp, fp, tn, fn = cm.get_matrix()
    correct = tp + tn
    total = tp + fp + tn + fn
    return float(correct / total)
def fscore(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, beta=1., **kwargs):
    """(1 + b^2) * TP / ((1 + b^2) * TP + b^2 * FN + FP)"""
    p = precision(test, reference, confusion_matrix, nan_for_nonexisting)
    r = recall(test, reference, confusion_matrix, nan_for_nonexisting)
    b2 = beta * beta
    # Harmonic-mean style combination of precision and recall.
    return (1 + b2) * p * r / ((b2 * p) + r)
def false_positive_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FP / (FP + TN) — the complement of specificity."""
    spec = specificity(test, reference, confusion_matrix, nan_for_nonexisting)
    return 1 - spec
def false_omission_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FN / (TN + FN)"""
    cm = confusion_matrix if confusion_matrix is not None else ConfusionMatrix(test, reference)
    _, _, tn, fn = cm.get_matrix()
    _, test_full, _, _ = cm.get_existence()
    # An all-positive prediction makes no negative calls at all.
    if test_full:
        return float("NaN") if nan_for_nonexisting else 0.
    return float(fn / (fn + tn))
def false_negative_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FN / (TP + FN) — the complement of sensitivity."""
    sens = sensitivity(test, reference, confusion_matrix, nan_for_nonexisting)
    return 1 - sens
def true_negative_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TN / (TN + FP) — identical to specificity."""
    spec = specificity(test, reference, confusion_matrix, nan_for_nonexisting)
    return spec
def false_discovery_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FP / (TP + FP) — the complement of precision."""
    prec = precision(test, reference, confusion_matrix, nan_for_nonexisting)
    return 1 - prec
def negative_predictive_value(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TN / (TN + FN) — the complement of the false omission rate."""
    fo_rate = false_omission_rate(test, reference, confusion_matrix, nan_for_nonexisting)
    return 1 - fo_rate
def total_positives_test(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TP + FP — total voxels predicted as positive."""
    cm = confusion_matrix if confusion_matrix is not None else ConfusionMatrix(test, reference)
    tp, fp, _, _ = cm.get_matrix()
    return tp + fp
def total_negatives_test(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TN + FN — total voxels predicted as negative."""
    cm = confusion_matrix if confusion_matrix is not None else ConfusionMatrix(test, reference)
    _, _, tn, fn = cm.get_matrix()
    return tn + fn
def total_positives_reference(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TP + FN — total positive voxels in the reference."""
    cm = confusion_matrix if confusion_matrix is not None else ConfusionMatrix(test, reference)
    tp, _, _, fn = cm.get_matrix()
    return tp + fn
def total_negatives_reference(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TN + FP — total negative voxels in the reference."""
    cm = confusion_matrix if confusion_matrix is not None else ConfusionMatrix(test, reference)
    _, fp, tn, _ = cm.get_matrix()
    return tn + fp
def hausdorff_distance(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):
    """Hausdorff distance between the two masks (delegates to medpy)."""
    cm = confusion_matrix if confusion_matrix is not None else ConfusionMatrix(test, reference)
    test_empty, test_full, reference_empty, reference_full = cm.get_existence()
    # Surface distances are undefined when either mask has no boundary.
    if any((test_empty, test_full, reference_empty, reference_full)):
        return float("NaN") if nan_for_nonexisting else 0
    return metric.hd(cm.test, cm.reference, voxel_spacing, connectivity)
def hausdorff_distance_95(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):
    """95th percentile Hausdorff distance between test and reference.

    The distance is undefined when either segmentation is empty or fills the
    whole volume; in that case NaN (or 0 if ``nan_for_nonexisting`` is False)
    is returned.
    """
    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    existence_flags = confusion_matrix.get_existence()
    if any(existence_flags):
        # any of: test empty/full, reference empty/full -> metric undefined
        return float("NaN") if nan_for_nonexisting else 0

    return metric.hd95(confusion_matrix.test, confusion_matrix.reference, voxel_spacing, connectivity)
def avg_surface_distance(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):
    """Average (one-sided) surface distance from test to reference.

    The distance is undefined when either segmentation is empty or fills the
    whole volume; in that case NaN (or 0 if ``nan_for_nonexisting`` is False)
    is returned.
    """
    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    existence_flags = confusion_matrix.get_existence()
    if any(existence_flags):
        # any of: test empty/full, reference empty/full -> metric undefined
        return float("NaN") if nan_for_nonexisting else 0

    return metric.asd(confusion_matrix.test, confusion_matrix.reference, voxel_spacing, connectivity)
def avg_surface_distance_symmetric(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):
    """Average symmetric surface distance between test and reference.

    The distance is undefined when either segmentation is empty or fills the
    whole volume; in that case NaN (or 0 if ``nan_for_nonexisting`` is False)
    is returned.
    """
    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    existence_flags = confusion_matrix.get_existence()
    if any(existence_flags):
        # any of: test empty/full, reference empty/full -> metric undefined
        return float("NaN") if nan_for_nonexisting else 0

    return metric.assd(confusion_matrix.test, confusion_matrix.reference, voxel_spacing, connectivity)
# Registry of all supported evaluation metrics, keyed by their display name.
# These keys appear verbatim in summary jsons and are used for lookups, so
# they must stay stable.
ALL_METRICS = {
    "False Positive Rate": false_positive_rate,
    "Dice": dice,
    "Jaccard": jaccard,
    "Hausdorff Distance": hausdorff_distance,
    "Hausdorff Distance 95": hausdorff_distance_95,
    "Precision": precision,
    "Recall": recall,
    "Avg. Symmetric Surface Distance": avg_surface_distance_symmetric,
    "Avg. Surface Distance": avg_surface_distance,
    "Accuracy": accuracy,
    "False Omission Rate": false_omission_rate,
    "Negative Predictive Value": negative_predictive_value,
    "False Negative Rate": false_negative_rate,
    "True Negative Rate": true_negative_rate,
    "False Discovery Rate": false_discovery_rate,
    "Total Positives Test": total_positives_test,
    "Total Negatives Test": total_negatives_test,
    "Total Positives Reference": total_positives_reference,
    # NOTE(review): the lower-case "total" spelling is historic; callers look
    # metrics up by this exact string, so renaming it would be a breaking change.
    "total Negatives Reference": total_negatives_reference
}
================================================
FILE: unetr_pp/evaluation/model_selection/__init__.py
================================================
from __future__ import absolute_import
from . import *
================================================
FILE: unetr_pp/evaluation/model_selection/collect_all_fold0_results_and_summarize_in_one_csv.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unetr_pp.evaluation.model_selection.summarize_results_in_one_json import summarize2
from unetr_pp.paths import network_training_output_dir
from batchgenerators.utilities.file_and_folder_operations import *
def _write_summary_csv(summary_output_folder, results_csv):
    """Parse every summary json in ``summary_output_folder`` and write one csv
    row per file: task, network, trainer, validation_folder, plans, mean Dice,
    median Dice.

    This deduplicates the two previously copy-pasted loops (fold-0 vs. all
    folds) and fixes a latent NameError: the old code executed
    ``folds = folds[:-len('.json')]`` for ensemble files too, where ``folds``
    had never been assigned (it read a stale value from a previous iteration,
    or crashed when the first file was an ensemble). ``folds`` is never
    written to the csv, so it is simply not computed for ensembles.
    """
    summary_files = subfiles(summary_output_folder, suffix='.json', join=False)
    with open(results_csv, 'w') as f:
        for s in summary_files:
            if s.find("ensemble") == -1:
                # regular model: TASK__NETWORK__TRAINER__PLANS__VALFOLDER__FOLDS.json
                task, network, trainer, plans, validation_folder, folds = s.split("__")
                folds = folds[:-len('.json')]
            else:
                # ensemble: TASK__ensemble_<model1>--<model2>.json
                n1, n2 = s.split("--")
                n1 = n1[n1.find("ensemble_") + len("ensemble_"):]
                task = s.split("__")[0]
                network = "ensemble"
                trainer = n1
                plans = n2
                validation_folder = "none"
            results = load_json(join(summary_output_folder, s))
            results_mean = results['results']['mean']['mean']['Dice']
            results_median = results['results']['median']['mean']['Dice']
            f.write("%s,%s,%s,%s,%s,%02.4f,%02.4f\n" % (task,
                                                        network, trainer, validation_folder, plans, results_mean,
                                                        results_median))


if __name__ == "__main__":
    # summarize fold 0 only
    summary_output_folder = join(network_training_output_dir, "summary_jsons_fold0_new")
    maybe_mkdir_p(summary_output_folder)
    summarize2(['all'], output_dir=summary_output_folder, folds=(0,))
    _write_summary_csv(summary_output_folder, join(network_training_output_dir, "summary_fold0.csv"))

    # summarize all folds
    summary_output_folder = join(network_training_output_dir, "summary_jsons_new")
    maybe_mkdir_p(summary_output_folder)
    summarize2(['all'], output_dir=summary_output_folder)
    _write_summary_csv(summary_output_folder, join(network_training_output_dir, "summary_allFolds.csv"))
================================================
FILE: unetr_pp/evaluation/model_selection/ensemble.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from multiprocessing.pool import Pool
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.configuration import default_num_threads
from unetr_pp.evaluation.evaluator import aggregate_scores
from unetr_pp.inference.segmentation_export import save_segmentation_nifti_from_softmax
from unetr_pp.paths import network_training_output_dir, preprocessing_output_dir
from unetr_pp.postprocessing.connected_components import determine_postprocessing
def merge(args):
    """Average the softmax predictions of two models for one case and export
    the result as a nifti segmentation.

    ``args`` is a tuple ``(file1, file2, properties_file, out_file)`` so the
    function can be used with ``Pool.map``. Existing outputs are not
    recomputed.
    """
    file1, file2, properties_file, out_file = args
    if isfile(out_file):
        return  # already ensembled, skip
    softmax1 = np.load(file1)['softmax']
    softmax2 = np.load(file2)['softmax']
    props = load_pickle(properties_file)
    averaged = np.mean((softmax1, softmax2), 0)
    # Softmax probabilities are already at target spacing so this will not do any resampling (resampling parameters
    # don't matter here)
    save_segmentation_nifti_from_softmax(averaged, out_file, props, 3, None, None, None, force_separate_z=None,
                                         interpolation_order_z=0)
def ensemble(training_output_folder1, training_output_folder2, output_folder, task, validation_folder, folds, allow_ensembling: bool = True):
    """Ensemble two trained configurations by averaging their cross-validation
    softmax predictions.

    For every fold the validation predictions (.npz softmax files) of both
    networks are paired up, averaged (via :func:`merge`), exported as nifti
    into ``output_folder/ensembled_raw`` and evaluated against the ground
    truth. If ``allow_ensembling`` is True, a postprocessing configuration is
    additionally determined for the ensemble and its postprocessed summary is
    copied into the global ``summary_jsons`` folder.

    Raises AssertionError if any validation folder or its summary.json/npz
    files are missing (i.e. validation was not run with --npz).
    """
    print("\nEnsembling folders\n", training_output_folder1, "\n", training_output_folder2)

    output_folder_base = output_folder
    output_folder = join(output_folder_base, "ensembled_raw")

    # only_keep_largest_connected_component is the same for all stages
    dataset_directory = join(preprocessing_output_dir, task)
    plans = load_pickle(join(training_output_folder1, "plans.pkl"))  # we need this only for the labels

    files1 = []
    files2 = []
    property_files = []
    out_files = []
    gt_segmentations = []

    folder_with_gt_segs = join(dataset_directory, "gt_segmentations")

    # in the correct shape and we need the original geometry to restore the niftis
    for f in folds:
        validation_folder_net1 = join(training_output_folder1, "fold_%d" % f, validation_folder)
        validation_folder_net2 = join(training_output_folder2, "fold_%d" % f, validation_folder)

        if not isdir(validation_folder_net1):
            raise AssertionError("Validation directory missing: %s. Please rerun validation with `nnFormer_train CONFIG TRAINER TASK FOLD -val --npz`" % validation_folder_net1)
        if not isdir(validation_folder_net2):
            raise AssertionError("Validation directory missing: %s. Please rerun validation with `nnFormer_train CONFIG TRAINER TASK FOLD -val --npz`" % validation_folder_net2)

        # we need to ensure the validation was successful. We can verify this via the presence of the summary.json file
        if not isfile(join(validation_folder_net1, 'summary.json')):
            raise AssertionError("Validation directory incomplete: %s. Please rerun validation with `nnFormer_train CONFIG TRAINER TASK FOLD -val --npz`" % validation_folder_net1)
        if not isfile(join(validation_folder_net2, 'summary.json')):
            raise AssertionError("Validation directory missing: %s. Please rerun validation with `nnFormer_train CONFIG TRAINER TASK FOLD -val --npz`" % validation_folder_net2)

        # case identifiers with softmax predictions available (.npz)
        patient_identifiers1_npz = [i[:-4] for i in subfiles(validation_folder_net1, False, None, 'npz', True)]
        patient_identifiers2_npz = [i[:-4] for i in subfiles(validation_folder_net2, False, None, 'npz', True)]

        # we don't do postprocessing anymore so there should not be any of that noPostProcess
        patient_identifiers1_nii = [i[:-7] for i in subfiles(validation_folder_net1, False, None, suffix='nii.gz', sort=True) if not i.endswith("noPostProcess.nii.gz") and not i.endswith('_postprocessed.nii.gz')]
        patient_identifiers2_nii = [i[:-7] for i in subfiles(validation_folder_net2, False, None, suffix='nii.gz', sort=True) if not i.endswith("noPostProcess.nii.gz") and not i.endswith('_postprocessed.nii.gz')]

        # every exported nifti must have a matching softmax npz
        if not all([i in patient_identifiers1_npz for i in patient_identifiers1_nii]):
            raise AssertionError("Missing npz files in folder %s. Please run the validation for all models and folds with the '--npz' flag." % (validation_folder_net1))
        if not all([i in patient_identifiers2_npz for i in patient_identifiers2_nii]):
            raise AssertionError("Missing npz files in folder %s. Please run the validation for all models and folds with the '--npz' flag." % (validation_folder_net2))

        patient_identifiers1_npz.sort()
        patient_identifiers2_npz.sort()

        assert all([i == j for i, j in zip(patient_identifiers1_npz, patient_identifiers2_npz)]), "npz filenames do not match. This should not happen."

        maybe_mkdir_p(output_folder)

        for p in patient_identifiers1_npz:
            files1.append(join(validation_folder_net1, p + '.npz'))
            files2.append(join(validation_folder_net2, p + '.npz'))
            # geometry/properties come from net1 (identical cases in both nets)
            property_files.append(join(validation_folder_net1, p) + ".pkl")
            out_files.append(join(output_folder, p + ".nii.gz"))
            gt_segmentations.append(join(folder_with_gt_segs, p + ".nii.gz"))

    # average + export all cases in parallel
    p = Pool(default_num_threads)
    p.map(merge, zip(files1, files2, property_files, out_files))
    p.close()
    p.join()

    # evaluate the ensembled predictions against the ground truth (skip if done)
    if not isfile(join(output_folder, "summary.json")) and len(out_files) > 0:
        aggregate_scores(tuple(zip(out_files, gt_segmentations)), labels=plans['all_classes'],
                         json_output_file=join(output_folder, "summary.json"), json_task=task,
                         json_name=task + "__" + output_folder_base.split("/")[-1], num_threads=default_num_threads)

    if allow_ensembling and not isfile(join(output_folder_base, "postprocessing.json")):
        # now lets also look at postprocessing. We cannot just take what we determined in cross-validation and apply it
        # here because things may have changed and may also be too inconsistent between the two networks
        determine_postprocessing(output_folder_base, folder_with_gt_segs, "ensembled_raw", "temp",
                                 "ensembled_postprocessed", default_num_threads, dice_threshold=0)

        out_dir_all_json = join(network_training_output_dir, "summary_jsons")
        json_out = load_json(join(output_folder_base, "ensembled_postprocessed", "summary.json"))

        json_out["experiment_name"] = output_folder_base.split("/")[-1]
        save_json(json_out, join(output_folder_base, "ensembled_postprocessed", "summary.json"))

        # mirror the postprocessed summary into the global summary_jsons folder
        maybe_mkdir_p(out_dir_all_json)
        shutil.copy(join(output_folder_base, "ensembled_postprocessed", "summary.json"),
                    join(out_dir_all_json, "%s__%s.json" % (task, output_folder_base.split("/")[-1])))
================================================
FILE: unetr_pp/evaluation/model_selection/figure_out_what_to_submit.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from itertools import combinations
import unetr_pp
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.evaluation.add_mean_dice_to_json import foreground_mean
from unetr_pp.evaluation.evaluator import evaluate_folder
from unetr_pp.evaluation.model_selection.ensemble import ensemble
from unetr_pp.paths import network_training_output_dir
import numpy as np
from subprocess import call
from unetr_pp.postprocessing.consolidate_postprocessing import consolidate_folds, collect_cv_niftis
from unetr_pp.utilities.folder_names import get_output_folder_name
from unetr_pp.paths import default_cascade_trainer, default_trainer, default_plans_identifier
def find_task_name(folder, task_id):
    """Resolve the full task folder name (``TaskXXX_<name>``) for a numeric id.

    Exactly one matching subdirectory must exist in ``folder``.
    """
    matches = subdirs(folder, prefix="Task%03.0d_" % task_id, join=False)
    assert len(matches) > 0, "no candidate for Task id %d found in folder %s" % (task_id, folder)
    assert len(matches) == 1, "more than one candidate for Task id %d found in folder %s" % (task_id, folder)
    return matches[0]
def get_mean_foreground_dice(json_file):
    """Load a summary json and return the mean Dice over foreground classes."""
    return get_foreground_mean(load_json(json_file))
def get_foreground_mean(results):
    """Mean Dice over all foreground classes.

    The background class ("0") and the precomputed "mean" entry are excluded.
    """
    per_class = results['results']['mean']
    dice_scores = [entry['Dice'] for label, entry in per_class.items() if label not in ("0", "mean")]
    return np.mean(dice_scores)
def main():
    """Identify the best model configuration per task from five-fold CV.

    For each requested task this collects the cross-validation performance of
    every trained configuration (optionally running missing postprocessing or
    re-evaluating raw CV niftis), optionally ensembles all pairs of models,
    then prints the best configuration, the prediction commands to use for
    test cases, and writes a per-task summary.csv.
    """
    import argparse
    parser = argparse.ArgumentParser(usage="This is intended to identify the best model based on the five fold "
                                           "cross-validation. Running this script requires all models to have been run "
                                           "already. This script will summarize the results of the five folds of all "
                                           "models in one json each for easy interpretability")

    parser.add_argument("-m", '--models', nargs="+", required=False, default=['2d', '3d_lowres', '3d_fullres',
                                                                              '3d_cascade_fullres'])
    parser.add_argument("-t", '--task_ids', nargs="+", required=True)

    parser.add_argument("-tr", type=str, required=False, default=default_trainer,
                        help="nnFormerTrainer class. Default: %s" % default_trainer)
    parser.add_argument("-ctr", type=str, required=False, default=default_cascade_trainer,
                        help="nnFormerTrainer class for cascade model. Default: %s" % default_cascade_trainer)
    parser.add_argument("-pl", type=str, required=False, default=default_plans_identifier,
                        help="plans name, Default: %s" % default_plans_identifier)
    parser.add_argument('-f', '--folds', nargs='+', default=(0, 1, 2, 3, 4), help="Use this if you have non-standard "
                                                                                  "folds. Experienced users only.")
    parser.add_argument('--disable_ensembling', required=False, default=False, action='store_true',
                        help='Set this flag to disable the use of ensembling. This will find the best single '
                             'configuration for each task.')
    parser.add_argument("--disable_postprocessing", required=False, default=False, action="store_true",
                        help="Set this flag if you want to disable the use of postprocessing")

    args = parser.parse_args()
    tasks = [int(i) for i in args.task_ids]

    models = args.models
    tr = args.tr
    trc = args.ctr
    pl = args.pl
    disable_ensembling = args.disable_ensembling
    disable_postprocessing = args.disable_postprocessing
    folds = tuple(int(i) for i in args.folds)

    validation_folder = "validation_raw"

    # this script now acts independently from the summary jsons. That was unnecessary
    id_task_mapping = {}  # numeric task id -> full "TaskXXX_<name>" folder name
    for t in tasks:
        # first collect pure model performance
        results = {}      # config/ensemble name -> mean foreground Dice
        all_results = {}  # config/ensemble name -> per-class results dict
        valid_models = []
        for m in models:
            # the cascade uses its own trainer class
            if m == "3d_cascade_fullres":
                trainer = trc
            else:
                trainer = tr

            if t not in id_task_mapping.keys():
                task_name = find_task_name(get_output_folder_name(m), t)
                id_task_mapping[t] = task_name

            output_folder = get_output_folder_name(m, id_task_mapping[t], trainer, pl)
            if not isdir(output_folder):
                raise RuntimeError("Output folder for model %s is missing, expected: %s" % (m, output_folder))

            if disable_postprocessing:
                # we need to collect the predicted niftis from the 5-fold cv and evaluate them against the ground truth
                cv_niftis_folder = join(output_folder, 'cv_niftis_raw')
                if not isfile(join(cv_niftis_folder, 'summary.json')):
                    print(t, m, ': collecting niftis from 5-fold cv')
                    if isdir(cv_niftis_folder):
                        shutil.rmtree(cv_niftis_folder)

                    collect_cv_niftis(output_folder, cv_niftis_folder, validation_folder, folds)

                    niftis_gt = subfiles(join(output_folder, "gt_niftis"), suffix='.nii.gz', join=False)
                    niftis_cv = subfiles(cv_niftis_folder, suffix='.nii.gz', join=False)
                    if not all([i in niftis_gt for i in niftis_cv]):
                        raise AssertionError("It does not seem like you trained all the folds! Train " \
                                             "all folds first! There are %d gt niftis in %s but only " \
                                             "%d predicted niftis in %s" % (len(niftis_gt), niftis_gt,
                                                                            len(niftis_cv), niftis_cv))

                    # load a summary file so that we can know what class labels to expect
                    summary_fold0 = load_json(join(output_folder, "fold_%d" % folds[0], validation_folder,
                                                   "summary.json"))['results']['mean']
                    # read classes from summary.json
                    classes = tuple((int(i) for i in summary_fold0.keys()))

                    # evaluate the cv niftis
                    print(t, m, ': evaluating 5-fold cv results')
                    evaluate_folder(join(output_folder, "gt_niftis"), cv_niftis_folder, classes)
            else:
                postprocessing_json = join(output_folder, "postprocessing.json")
                cv_niftis_folder = join(output_folder, "cv_niftis_raw")

                # we need cv_niftis_postprocessed to know the single model performance. And we need the
                # postprocessing_json. If either of those is missing, rerun consolidate_folds
                if not isfile(postprocessing_json) or not isdir(cv_niftis_folder):
                    print("running missing postprocessing for %s and model %s" % (id_task_mapping[t], m))
                    consolidate_folds(output_folder, folds=folds)

                assert isfile(postprocessing_json), "Postprocessing json missing, expected: %s" % postprocessing_json
                assert isdir(cv_niftis_folder), "Folder with niftis from CV missing, expected: %s" % cv_niftis_folder

            # obtain mean foreground dice
            summary_file = join(cv_niftis_folder, "summary.json")
            results[m] = get_mean_foreground_dice(summary_file)
            foreground_mean(summary_file)
            all_results[m] = load_json(summary_file)['results']['mean']
            valid_models.append(m)

        if not disable_ensembling:
            # now run ensembling and add ensembling to results
            print("\nI will now ensemble combinations of the following models:\n", valid_models)
            if len(valid_models) > 1:
                for m1, m2 in combinations(valid_models, 2):
                    trainer_m1 = trc if m1 == "3d_cascade_fullres" else tr
                    trainer_m2 = trc if m2 == "3d_cascade_fullres" else tr

                    ensemble_name = "ensemble_" + m1 + "__" + trainer_m1 + "__" + pl + "--" + m2 + "__" + trainer_m2 + "__" + pl
                    output_folder_base = join(network_training_output_dir, "ensembles", id_task_mapping[t], ensemble_name)
                    maybe_mkdir_p(output_folder_base)

                    network1_folder = get_output_folder_name(m1, id_task_mapping[t], trainer_m1, pl)
                    network2_folder = get_output_folder_name(m2, id_task_mapping[t], trainer_m2, pl)

                    print("ensembling", network1_folder, network2_folder)
                    ensemble(network1_folder, network2_folder, output_folder_base, id_task_mapping[t], validation_folder, folds, allow_ensembling=not disable_postprocessing)
                    # ensembling will automatically do postprocessingget_foreground_mean

                    # now get result of ensemble
                    results[ensemble_name] = get_mean_foreground_dice(join(output_folder_base, "ensembled_raw", "summary.json"))
                    summary_file = join(output_folder_base, "ensembled_raw", "summary.json")
                    foreground_mean(summary_file)
                    all_results[ensemble_name] = load_json(summary_file)['results']['mean']

        # now print all mean foreground dice and highlight the best
        foreground_dices = list(results.values())
        best = np.max(foreground_dices)
        for k, v in results.items():
            print(k, v)

        predict_str = ""
        best_model = None
        for k, v in results.items():
            if v == best:
                print("%s submit model %s" % (id_task_mapping[t], k), v)
                best_model = k
                print("\nHere is how you should predict test cases. Run in sequential order and replace all input and output folder names with your personalized ones\n")
                if k.startswith("ensemble"):
                    # ensemble winner: predict with both members, then ensemble
                    tmp = k[len("ensemble_"):]
                    model1, model2 = tmp.split("--")
                    m1, t1, pl1 = model1.split("__")
                    m2, t2, pl2 = model2.split("__")
                    predict_str += "nnFormer_predict -i FOLDER_WITH_TEST_CASES -o OUTPUT_FOLDER_MODEL1 -tr " + tr + " -ctr " + trc + " -m " + m1 + " -p " + pl + " -t " + \
                                   id_task_mapping[t] + "\n"
                    predict_str += "nnFormer_predict -i FOLDER_WITH_TEST_CASES -o OUTPUT_FOLDER_MODEL2 -tr " + tr + " -ctr " + trc + " -m " + m2 + " -p " + pl + " -t " + \
                                   id_task_mapping[t] + "\n"

                    if not disable_postprocessing:
                        predict_str += "nnFormer_ensemble -f OUTPUT_FOLDER_MODEL1 OUTPUT_FOLDER_MODEL2 -o OUTPUT_FOLDER -pp " + join(network_training_output_dir, "ensembles", id_task_mapping[t], k, "postprocessing.json") + "\n"
                    else:
                        predict_str += "nnFormer_ensemble -f OUTPUT_FOLDER_MODEL1 OUTPUT_FOLDER_MODEL2 -o OUTPUT_FOLDER\n"
                else:
                    # single-model winner
                    predict_str += "nnFormer_predict -i FOLDER_WITH_TEST_CASES -o OUTPUT_FOLDER_MODEL1 -tr " + tr + " -ctr " + trc + " -m " + k + " -p " + pl + " -t " + \
                                   id_task_mapping[t] + "\n"
        print(predict_str)

        # persist the prediction commands and a per-class Dice summary
        summary_folder = join(network_training_output_dir, "ensembles", id_task_mapping[t])
        maybe_mkdir_p(summary_folder)
        with open(join(summary_folder, "prediction_commands.txt"), 'w') as f:
            f.write(predict_str)

        num_classes = len([i for i in all_results[best_model].keys() if i != 'mean' and i != '0'])
        with open(join(summary_folder, "summary.csv"), 'w') as f:
            f.write("model")
            for c in range(1, num_classes + 1):
                f.write(",class%d" % c)
            f.write(",average")
            f.write("\n")
            for m in all_results.keys():
                f.write(m)
                for c in range(1, num_classes + 1):
                    f.write(",%01.4f" % all_results[m][str(c)]["Dice"])
                f.write(",%01.4f" % all_results[m]['mean']["Dice"])
                f.write("\n")


if __name__ == "__main__":
    main()
================================================
FILE: unetr_pp/evaluation/model_selection/rank_candidates.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.paths import network_training_output_dir
if __name__ == "__main__":
    # Rank all trainer/plans candidates by their mean rank across datasets,
    # using the fold-0 summary jsons produced by
    # collect_all_fold0_results_and_summarize_in_one_csv.py.
    # run collect_all_fold0_results_and_summarize_in_one_csv.py first
    summary_files_dir = join(network_training_output_dir, "summary_jsons_fold0_new")
    output_file = join(network_training_output_dir, "summary.csv")

    folds = (0, )
    # folds_str is the trailing filename component used by the summary jsons
    folds_str = ""
    for f in folds:
        folds_str += str(f)

    plans = "nnFormerPlans"

    # trainers that were run with plans other than the default; each maps to
    # the list of plans identifiers to look up for that trainer
    overwrite_plans = {
        'nnFormerTrainerV2_2': ["nnFormerPlans", "nnFormerPlansisoPatchesInVoxels"], # r
        'nnFormerTrainerV2': ["nnFormerPlansnonCT", "nnFormerPlansCT2", "nnFormerPlansallConv3x3",
                            "nnFormerPlansfixedisoPatchesInVoxels", "nnFormerPlanstargetSpacingForAnisoAxis",
                            "nnFormerPlanspoolBasedOnSpacing", "nnFormerPlansfixedisoPatchesInmm", "nnFormerPlansv2.1"],
        'nnFormerTrainerV2_warmup': ["nnFormerPlans", "nnFormerPlansv2.1", "nnFormerPlansv2.1_big", "nnFormerPlansv2.1_verybig"],
        'nnFormerTrainerV2_cycleAtEnd': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_cycleAtEnd2': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_reduceMomentumDuringTraining': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_graduallyTransitionFromCEToDice': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_independentScalePerAxis': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_Mish': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_Ranger_lr3en4': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_fp32': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_GN': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_momentum098': ["nnFormerPlans", "nnFormerPlansv2.1"],
        'nnFormerTrainerV2_momentum09': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_DP': ["nnFormerPlansv2.1_verybig"],
        'nnFormerTrainerV2_DDP': ["nnFormerPlansv2.1_verybig"],
        'nnFormerTrainerV2_FRN': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_resample33': ["nnFormerPlansv2.3"],
        'nnFormerTrainerV2_O2': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_ResencUNet': ["nnFormerPlans_FabiansResUNet_v2.1"],
        'nnFormerTrainerV2_DA2': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_allConv3x3': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_ForceBD': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_ForceSD': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_LReLU_slope_2en1': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_lReLU_convReLUIN': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_ReLU': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_ReLU_biasInSegOutput': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_ReLU_convReLUIN': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_lReLU_biasInSegOutput': ["nnFormerPlansv2.1"],
        #'nnFormerTrainerV2_Loss_MCC': ["nnFormerPlansv2.1"],
        #'nnFormerTrainerV2_Loss_MCCnoBG': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_Loss_DicewithBG': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_Loss_Dice_LR1en3': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_Loss_Dice': ["nnFormerPlans", "nnFormerPlansv2.1"],
        'nnFormerTrainerV2_Loss_DicewithBG_LR1en3': ["nnFormerPlansv2.1"],
        # 'nnFormerTrainerV2_fp32': ["nnFormerPlansv2.1"],
        # 'nnFormerTrainerV2_fp32': ["nnFormerPlansv2.1"],
        # 'nnFormerTrainerV2_fp32': ["nnFormerPlansv2.1"],
        # 'nnFormerTrainerV2_fp32': ["nnFormerPlansv2.1"],
        # 'nnFormerTrainerV2_fp32': ["nnFormerPlansv2.1"],
    }

    # all trainer variants that participated in the candidate search
    trainers = ['nnFormerTrainer'] + ['nnFormerTrainerNewCandidate%d' % i for i in range(1, 28)] + [
        'nnFormerTrainerNewCandidate24_2',
        'nnFormerTrainerNewCandidate24_3',
        'nnFormerTrainerNewCandidate26_2',
        'nnFormerTrainerNewCandidate27_2',
        'nnFormerTrainerNewCandidate23_always3DDA',
        'nnFormerTrainerNewCandidate23_corrInit',
        'nnFormerTrainerNewCandidate23_noOversampling',
        'nnFormerTrainerNewCandidate23_softDS',
        'nnFormerTrainerNewCandidate23_softDS2',
        'nnFormerTrainerNewCandidate23_softDS3',
        'nnFormerTrainerNewCandidate23_softDS4',
        'nnFormerTrainerNewCandidate23_2_fp16',
        'nnFormerTrainerNewCandidate23_2',
        'nnFormerTrainerVer2',
        'nnFormerTrainerV2_2',
        'nnFormerTrainerV2_3',
        'nnFormerTrainerV2_3_CE_GDL',
        'nnFormerTrainerV2_3_dcTopk10',
        'nnFormerTrainerV2_3_dcTopk20',
        'nnFormerTrainerV2_3_fp16',
        'nnFormerTrainerV2_3_softDS4',
        'nnFormerTrainerV2_3_softDS4_clean',
        'nnFormerTrainerV2_3_softDS4_clean_improvedDA',
        'nnFormerTrainerV2_3_softDS4_clean_improvedDA_newElDef',
        'nnFormerTrainerV2_3_softDS4_radam',
        'nnFormerTrainerV2_3_softDS4_radam_lowerLR',

        'nnFormerTrainerV2_2_schedule',
        'nnFormerTrainerV2_2_schedule2',
        'nnFormerTrainerV2_2_clean',
        'nnFormerTrainerV2_2_clean_improvedDA_newElDef',

        'nnFormerTrainerV2_2_fixes', # running
        'nnFormerTrainerV2_BN', # running
        'nnFormerTrainerV2_noDeepSupervision', # running
        'nnFormerTrainerV2_softDeepSupervision', # running
        'nnFormerTrainerV2_noDataAugmentation', # running
        'nnFormerTrainerV2_Loss_CE', # running
        'nnFormerTrainerV2_Loss_CEGDL',
        'nnFormerTrainerV2_Loss_Dice',
        'nnFormerTrainerV2_Loss_DiceTopK10',
        'nnFormerTrainerV2_Loss_TopK10',
        'nnFormerTrainerV2_Adam', # running
        'nnFormerTrainerV2_Adam_nnFormerTrainerlr', # running
        'nnFormerTrainerV2_SGD_ReduceOnPlateau', # running
        'nnFormerTrainerV2_SGD_lr1en1', # running
        'nnFormerTrainerV2_SGD_lr1en3', # running
        'nnFormerTrainerV2_fixedNonlin', # running
        'nnFormerTrainerV2_GeLU', # running
        'nnFormerTrainerV2_3ConvPerStage',
        'nnFormerTrainerV2_NoNormalization',
        'nnFormerTrainerV2_Adam_ReduceOnPlateau',
        'nnFormerTrainerV2_fp16',
        'nnFormerTrainerV2', # see overwrite_plans
        'nnFormerTrainerV2_noMirroring',
        'nnFormerTrainerV2_momentum09',
        'nnFormerTrainerV2_momentum095',
        'nnFormerTrainerV2_momentum098',
        'nnFormerTrainerV2_warmup',
        'nnFormerTrainerV2_Loss_Dice_LR1en3',
        'nnFormerTrainerV2_NoNormalization_lr1en3',
        'nnFormerTrainerV2_Loss_Dice_squared',
        'nnFormerTrainerV2_newElDef',
        'nnFormerTrainerV2_fp32',
        'nnFormerTrainerV2_cycleAtEnd',
        'nnFormerTrainerV2_reduceMomentumDuringTraining',
        'nnFormerTrainerV2_graduallyTransitionFromCEToDice',
        'nnFormerTrainerV2_insaneDA',
        'nnFormerTrainerV2_independentScalePerAxis',
        'nnFormerTrainerV2_Mish',
        'nnFormerTrainerV2_Ranger_lr3en4',
        'nnFormerTrainerV2_cycleAtEnd2',
        'nnFormerTrainerV2_GN',
        'nnFormerTrainerV2_DP',
        'nnFormerTrainerV2_FRN',
        'nnFormerTrainerV2_resample33',
        'nnFormerTrainerV2_O2',
        'nnFormerTrainerV2_ResencUNet',
        'nnFormerTrainerV2_DA2',
        'nnFormerTrainerV2_allConv3x3',
        'nnFormerTrainerV2_ForceBD',
        'nnFormerTrainerV2_ForceSD',
        'nnFormerTrainerV2_ReLU',
        'nnFormerTrainerV2_LReLU_slope_2en1',
        'nnFormerTrainerV2_lReLU_convReLUIN',
        'nnFormerTrainerV2_ReLU_biasInSegOutput',
        'nnFormerTrainerV2_ReLU_convReLUIN',
        'nnFormerTrainerV2_lReLU_biasInSegOutput',
        'nnFormerTrainerV2_Loss_DicewithBG_LR1en3',
        #'nnFormerTrainerV2_Loss_MCCnoBG',
        'nnFormerTrainerV2_Loss_DicewithBG',
        # 'nnFormerTrainerV2_Loss_Dice_LR1en3',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
    ]

    # dataset -> tuple of configurations that were trained on it
    datasets = \
        {"Task001_BrainTumour": ("3d_fullres", ),
         "Task002_Heart": ("3d_fullres",),
         #"Task024_Promise": ("3d_fullres",),
         #"Task027_ACDC": ("3d_fullres",),
         "Task003_Liver": ("3d_fullres", "3d_lowres"),
         "Task004_Hippocampus": ("3d_fullres",),
         "Task005_Prostate": ("3d_fullres",),
         "Task006_Lung": ("3d_fullres", "3d_lowres"),
         "Task007_Pancreas": ("3d_fullres", "3d_lowres"),
         "Task008_HepaticVessel": ("3d_fullres", "3d_lowres"),
         "Task009_Spleen": ("3d_fullres", "3d_lowres"),
         "Task010_Colon": ("3d_fullres", "3d_lowres"),}

    # validation folder names to probe, newest naming convention first
    expected_validation_folder = "validation_raw"
    alternative_validation_folder = "validation"
    alternative_alternative_validation_folder = "validation_tiledTrue_doMirror_True"

    interested_in = "mean"

    # dataset -> configuration -> list of Dice results (one per valid trainer)
    result_per_dataset = {}
    for d in datasets:
        result_per_dataset[d] = {}
        for c in datasets[d]:
            result_per_dataset[d][c] = []

    valid_trainers = []
    all_trainers = []

    with open(output_file, 'w') as f:
        # csv header: one column per (dataset, configuration)
        f.write("trainer,")
        for t in datasets.keys():
            s = t[4:7]
            for c in datasets[t]:
                s1 = s + "_" + c[3]
                f.write("%s," % s1)
        f.write("\n")

        for trainer in trainers:
            trainer_plans = [plans]
            if trainer in overwrite_plans.keys():
                trainer_plans = overwrite_plans[trainer]

            result_per_dataset_here = {}
            for d in datasets:
                result_per_dataset_here[d] = {}

            for p in trainer_plans:
                name = "%s__%s" % (trainer, p)
                all_present = True
                all_trainers.append(name)

                f.write("%s," % name)
                for dataset in datasets.keys():
                    for configuration in datasets[dataset]:
                        # probe the three possible validation folder namings
                        summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p, expected_validation_folder, folds_str))
                        if not isfile(summary_file):
                            summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p, alternative_validation_folder, folds_str))
                            if not isfile(summary_file):
                                summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (
                                    dataset, configuration, trainer, p, alternative_alternative_validation_folder, folds_str))
                                if not isfile(summary_file):
                                    all_present = False
                                    print(name, dataset, configuration, "has missing summary file")
                        if isfile(summary_file):
                            result = load_json(summary_file)['results'][interested_in]['mean']['Dice']
                            result_per_dataset_here[dataset][configuration] = result
                            f.write("%02.4f," % result)
                        else:
                            f.write("NA,")
                            result_per_dataset_here[dataset][configuration] = 0

                f.write("\n")

                if True:
                    # all candidates are kept (missing results counted as 0)
                    valid_trainers.append(name)
                    for d in datasets:
                        for c in datasets[d]:
                            result_per_dataset[d][c].append(result_per_dataset_here[d][c])

    invalid_trainers = [i for i in all_trainers if i not in valid_trainers]

    num_valid = len(valid_trainers)
    num_datasets = len(datasets.keys())
    # create an array that is trainer x dataset. If more than one configuration is there then use the best metric across the two
    all_res = np.zeros((num_valid, num_datasets))
    for j, d in enumerate(datasets.keys()):
        ks = list(result_per_dataset[d].keys())
        tmp = result_per_dataset[d][ks[0]]
        for k in ks[1:]:
            for i in range(len(tmp)):
                tmp[i] = max(tmp[i], result_per_dataset[d][k][i])
        all_res[:, j] = tmp

    # rank candidates per dataset (rank 0 = best Dice)
    ranks_arr = np.zeros_like(all_res)
    for d in range(ranks_arr.shape[1]):
        temp = np.argsort(all_res[:, d])[::-1] # inverse because we want the highest dice to be rank0
        ranks = np.empty_like(temp)
        ranks[temp] = np.arange(len(temp))
        ranks_arr[:, d] = ranks

    # mean rank across datasets; lower is better
    mn = np.mean(ranks_arr, 1)
    for i in np.argsort(mn):
        print(mn[i], valid_trainers[i])

    print()
    print(valid_trainers[np.argmin(mn)])
================================================
FILE: unetr_pp/evaluation/model_selection/rank_candidates_StructSeg.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.paths import network_training_output_dir
if __name__ == "__main__":
    # Ranks all trainer/plans candidates for the StructSeg2019 tasks by the
    # mean foreground Dice of the 5-fold cross-validation and writes a csv
    # overview. The best candidate (lowest mean rank) is printed last.
    # run collect_all_fold0_results_and_summarize_in_one_csv.py first
    summary_files_dir = join(network_training_output_dir, "summary_jsons_new")
    output_file = join(network_training_output_dir, "summary_structseg_5folds.csv")

    # all of these folds must be present for a candidate to be ranked
    folds = (0, 1, 2, 3, 4)
    folds_str = ""
    for f in folds:
        folds_str += str(f)

    # default plans identifier; some trainers were additionally run with the
    # custom-clip plans (see overwrite_plans)
    plans = "nnFormerPlans"

    overwrite_plans = {
        'nnFormerTrainerV2_2': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
        'nnFormerTrainerV2_2_noMirror': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
        'nnFormerTrainerV2_lessMomentum_noMirror': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
        'nnFormerTrainerV2_2_structSeg_noMirror': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
        'nnFormerTrainerV2_2_structSeg': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
        'nnFormerTrainerV2_lessMomentum_noMirror_structSeg': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
        'nnFormerTrainerV2_FabiansResUNet_structSet_NoMirror_leakyDecoder': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
        'nnFormerTrainerV2_FabiansResUNet_structSet_NoMirror': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
        'nnFormerTrainerV2_FabiansResUNet_structSet': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
    }

    trainers = ['nnFormerTrainer'] + [
        'nnFormerTrainerV2_2',
        'nnFormerTrainerV2_lessMomentum_noMirror',
        'nnFormerTrainerV2_2_noMirror',
        'nnFormerTrainerV2_2_structSeg_noMirror',
        'nnFormerTrainerV2_2_structSeg',
        'nnFormerTrainerV2_lessMomentum_noMirror_structSeg',
        'nnFormerTrainerV2_FabiansResUNet_structSet_NoMirror_leakyDecoder',
        'nnFormerTrainerV2_FabiansResUNet_structSet_NoMirror',
        'nnFormerTrainerV2_FabiansResUNet_structSet',
    ]

    # task -> configurations considered for that task
    datasets = \
        {"Task049_StructSeg2019_Task1_HaN_OAR": ("3d_fullres", "3d_lowres", "2d"),
         "Task050_StructSeg2019_Task2_Naso_GTV": ("3d_fullres", "3d_lowres", "2d"),
         "Task051_StructSeg2019_Task3_Thoracic_OAR": ("3d_fullres", "3d_lowres", "2d"),
         "Task052_StructSeg2019_Task4_Lung_GTV": ("3d_fullres", "3d_lowres", "2d"),
         }

    # summary jsons may live in one of three differently named validation folders
    expected_validation_folder = "validation_raw"
    alternative_validation_folder = "validation"
    alternative_alternative_validation_folder = "validation_tiledTrue_doMirror_True"

    # which aggregate of the per-case results to rank by ('mean' over all cases)
    interested_in = "mean"

    # result_per_dataset[task][configuration] collects one Dice per valid candidate
    result_per_dataset = {}
    for d in datasets:
        result_per_dataset[d] = {}
        for c in datasets[d]:
            result_per_dataset[d][c] = []

    valid_trainers = []
    all_trainers = []

    with open(output_file, 'w') as f:
        # header row: one column per (task, configuration) pair
        f.write("trainer,")
        for t in datasets.keys():
            s = t[4:7]  # three-digit task id, e.g. '049'
            for c in datasets[t]:
                if len(c) > 3:
                    n = c[3]  # char after '3d_': 'f' (fullres) or 'l' (lowres)
                else:
                    n = "2"  # the '2d' configuration
                s1 = s + "_" + n
                f.write("%s," % s1)
        f.write("\n")

        for trainer in trainers:
            trainer_plans = [plans]
            if trainer in overwrite_plans.keys():
                trainer_plans = overwrite_plans[trainer]

            result_per_dataset_here = {}
            for d in datasets:
                result_per_dataset_here[d] = {}

            # each trainer x plans combination is a ranking candidate
            for p in trainer_plans:
                name = "%s__%s" % (trainer, p)
                all_present = True
                all_trainers.append(name)

                f.write("%s," % name)

                for dataset in datasets.keys():
                    for configuration in datasets[dataset]:
                        # look for the summary json in all three candidate validation folders
                        summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p, expected_validation_folder, folds_str))
                        if not isfile(summary_file):
                            summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p, alternative_validation_folder, folds_str))
                            if not isfile(summary_file):
                                summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (
                                    dataset, configuration, trainer, p, alternative_alternative_validation_folder, folds_str))
                                if not isfile(summary_file):
                                    all_present = False
                                    print(name, dataset, configuration, "has missing summary file")
                        if isfile(summary_file):
                            result = load_json(summary_file)['results'][interested_in]['mean']['Dice']
                            result_per_dataset_here[dataset][configuration] = result
                            f.write("%02.4f," % result)
                        else:
                            f.write("NA,")

                f.write("\n")

                # only candidates with a complete set of summaries are ranked
                if all_present:
                    valid_trainers.append(name)
                    for d in datasets:
                        for c in datasets[d]:
                            result_per_dataset[d][c].append(result_per_dataset_here[d][c])

    invalid_trainers = [i for i in all_trainers if i not in valid_trainers]

    num_valid = len(valid_trainers)
    num_datasets = len(datasets.keys())
    # create an array that is trainer x dataset. If more than one configuration is there then use the best metric across the two
    all_res = np.zeros((num_valid, num_datasets))
    for j, d in enumerate(datasets.keys()):
        ks = list(result_per_dataset[d].keys())
        tmp = result_per_dataset[d][ks[0]]
        for k in ks[1:]:
            # keep the best configuration per candidate for this task
            for i in range(len(tmp)):
                tmp[i] = max(tmp[i], result_per_dataset[d][k][i])
        all_res[:, j] = tmp

    # convert Dice scores into per-task ranks (rank 0 = best)
    ranks_arr = np.zeros_like(all_res)
    for d in range(ranks_arr.shape[1]):
        temp = np.argsort(all_res[:, d])[::-1]  # inverse because we want the highest dice to be rank0
        ranks = np.empty_like(temp)
        ranks[temp] = np.arange(len(temp))
        ranks_arr[:, d] = ranks

    # mean rank across tasks; candidates are printed sorted best-first
    mn = np.mean(ranks_arr, 1)
    for i in np.argsort(mn):
        print(mn[i], valid_trainers[i])
    print()
    print(valid_trainers[np.argmin(mn)])
================================================
FILE: unetr_pp/evaluation/model_selection/rank_candidates_cascade.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.paths import network_training_output_dir
if __name__ == "__main__":
    # Ranks the cascade (3d_cascade_fullres) trainer candidates by the fold-0
    # mean foreground Dice of each task and writes a csv overview. The best
    # candidate (lowest mean rank) is printed last.
    # run collect_all_fold0_results_and_summarize_in_one_csv.py first
    summary_files_dir = join(network_training_output_dir, "summary_jsons_fold0_new")
    output_file = join(network_training_output_dir, "summary_cascade.csv")

    # only fold 0 is evaluated here
    folds = (0, )
    folds_str = ""
    for f in folds:
        folds_str += str(f)

    # default plans identifier; the old cascade trainer used older plans
    plans = "nnFormerPlansv2.1"

    overwrite_plans = {
        'nnFormerTrainerCascadeFullRes': ['nnFormerPlans'],
    }

    trainers = [
        'nnFormerTrainerCascadeFullRes',
        'nnFormerTrainerV2CascadeFullRes_EducatedGuess',
        'nnFormerTrainerV2CascadeFullRes_EducatedGuess2',
        'nnFormerTrainerV2CascadeFullRes_EducatedGuess3',
        'nnFormerTrainerV2CascadeFullRes_lowerLR',
        'nnFormerTrainerV2CascadeFullRes',
        'nnFormerTrainerV2CascadeFullRes_noConnComp',
        'nnFormerTrainerV2CascadeFullRes_shorter_lowerLR',
        'nnFormerTrainerV2CascadeFullRes_shorter',
        'nnFormerTrainerV2CascadeFullRes_smallerBinStrel',
        #'',
        #'',
        #'',
        #'',
        #'',
        #'',
    ]

    # tasks for which a cascade configuration exists
    datasets = \
        {
            "Task003_Liver": ("3d_cascade_fullres", ),
            "Task006_Lung": ("3d_cascade_fullres", ),
            "Task007_Pancreas": ("3d_cascade_fullres", ),
            "Task008_HepaticVessel": ("3d_cascade_fullres", ),
            "Task009_Spleen": ("3d_cascade_fullres", ),
            "Task010_Colon": ("3d_cascade_fullres", ),
            "Task017_AbdominalOrganSegmentation": ("3d_cascade_fullres", ),
            #"Task029_LITS": ("3d_cascade_fullres", ),
            "Task048_KiTS_clean": ("3d_cascade_fullres", ),
            "Task055_SegTHOR": ("3d_cascade_fullres", ),
            "Task056_VerSe": ("3d_cascade_fullres", ),
            #"": ("3d_cascade_fullres", ),
        }

    # summary jsons may live in one of three differently named validation folders
    expected_validation_folder = "validation_raw"
    alternative_validation_folder = "validation"
    alternative_alternative_validation_folder = "validation_tiledTrue_doMirror_True"

    # which aggregate of the per-case results to rank by ('mean' over all cases)
    interested_in = "mean"

    # result_per_dataset[task][configuration] collects one Dice per candidate
    result_per_dataset = {}
    for d in datasets:
        result_per_dataset[d] = {}
        for c in datasets[d]:
            result_per_dataset[d][c] = []

    valid_trainers = []
    all_trainers = []

    with open(output_file, 'w') as f:
        # header row: one column per (task, configuration) pair
        f.write("trainer,")
        for t in datasets.keys():
            s = t[4:7]  # three-digit task id, e.g. '003'
            for c in datasets[t]:
                s1 = s + "_" + c[3]  # c[3] is the char after '3d_', here always 'c'
                f.write("%s," % s1)
        f.write("\n")

        for trainer in trainers:
            trainer_plans = [plans]
            if trainer in overwrite_plans.keys():
                trainer_plans = overwrite_plans[trainer]

            result_per_dataset_here = {}
            for d in datasets:
                result_per_dataset_here[d] = {}

            # each trainer x plans combination is a ranking candidate
            for p in trainer_plans:
                name = "%s__%s" % (trainer, p)
                all_present = True
                all_trainers.append(name)

                f.write("%s," % name)

                for dataset in datasets.keys():
                    for configuration in datasets[dataset]:
                        # look for the summary json in all three candidate validation folders
                        summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p, expected_validation_folder, folds_str))
                        if not isfile(summary_file):
                            summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p, alternative_validation_folder, folds_str))
                            if not isfile(summary_file):
                                summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (
                                    dataset, configuration, trainer, p, alternative_alternative_validation_folder, folds_str))
                                if not isfile(summary_file):
                                    all_present = False
                                    print(name, dataset, configuration, "has missing summary file")
                        if isfile(summary_file):
                            result = load_json(summary_file)['results'][interested_in]['mean']['Dice']
                            result_per_dataset_here[dataset][configuration] = result
                            f.write("%02.4f," % result)
                        else:
                            f.write("NA,")
                            # missing results get a 0 placeholder so the candidate
                            # can still be ranked (see the 'if True' below)
                            result_per_dataset_here[dataset][configuration] = 0

                f.write("\n")

                # NOTE(review): deliberately 'if True' instead of 'if all_present'
                # here, presumably to keep incomplete candidates (scored 0 for
                # missing tasks) in the ranking — confirm before changing
                if True:
                    valid_trainers.append(name)
                    for d in datasets:
                        for c in datasets[d]:
                            result_per_dataset[d][c].append(result_per_dataset_here[d][c])

    invalid_trainers = [i for i in all_trainers if i not in valid_trainers]

    num_valid = len(valid_trainers)
    num_datasets = len(datasets.keys())
    # create an array that is trainer x dataset. If more than one configuration is there then use the best metric across the two
    all_res = np.zeros((num_valid, num_datasets))
    for j, d in enumerate(datasets.keys()):
        ks = list(result_per_dataset[d].keys())
        tmp = result_per_dataset[d][ks[0]]
        for k in ks[1:]:
            # keep the best configuration per candidate for this task
            for i in range(len(tmp)):
                tmp[i] = max(tmp[i], result_per_dataset[d][k][i])
        all_res[:, j] = tmp

    # convert Dice scores into per-task ranks (rank 0 = best)
    ranks_arr = np.zeros_like(all_res)
    for d in range(ranks_arr.shape[1]):
        temp = np.argsort(all_res[:, d])[::-1]  # inverse because we want the highest dice to be rank0
        ranks = np.empty_like(temp)
        ranks[temp] = np.arange(len(temp))
        ranks_arr[:, d] = ranks

    # mean rank across tasks; candidates are printed sorted best-first
    mn = np.mean(ranks_arr, 1)
    for i in np.argsort(mn):
        print(mn[i], valid_trainers[i])
    print()
    print(valid_trainers[np.argmin(mn)])
================================================
FILE: unetr_pp/evaluation/model_selection/summarize_results_in_one_json.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from unetr_pp.evaluation.add_mean_dice_to_json import foreground_mean
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.paths import network_training_output_dir
import numpy as np
def summarize(tasks, models=('2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'),
              output_dir=join(network_training_output_dir, "summary_jsons"), folds=(0, 1, 2, 3, 4)):
    """
    Averages, for every (model, task, trainer, validation folder) combination,
    the per-fold 'mean' metrics of the individual summary.json files over all
    requested folds. The combined json is written both next to the training
    output (summary_allFolds__<validation>.json) and into output_dir, and a
    foreground mean is appended to each via foreground_mean.

    :param tasks: list of task ids (ints or strings), or ["all"] for ids 0..998
    :param models: network configurations to include
    :param output_dir: where the combined jsons are collected
    :param folds: all of these folds must be present for a summary to be written
    """
    maybe_mkdir_p(output_dir)
    if len(tasks) == 1 and tasks[0] == "all":
        tasks = list(range(999))
    else:
        tasks = [int(i) for i in tasks]
    for model in models:
        for t in tasks:
            t = int(t)
            if not isdir(join(network_training_output_dir, model)):
                continue
            # resolve the task id to its (unique) output folder name
            task_name = subfolders(join(network_training_output_dir, model), prefix="Task%03.0d" % t, join=False)
            if len(task_name) != 1:
                print("did not find unique output folder for network %s and task %s" % (model, t))
                continue
            task_name = task_name[0]
            out_dir_task = join(network_training_output_dir, model, task_name)

            model_trainers = subdirs(out_dir_task, join=False)
            for trainer in model_trainers:
                if trainer.startswith("fold"):
                    continue
                out_dir = join(out_dir_task, trainer)

                # collect the names of all validation folders found in the fold
                # dirs; both the 'foldX' and 'fold_X' layouts are supported
                validation_folders = []
                for fld in folds:
                    d = join(out_dir, "fold%d" % fld)
                    if not isdir(d):
                        d = join(out_dir, "fold_%d" % fld)
                        if not isdir(d):
                            break
                    validation_folders += subfolders(d, prefix="validation", join=False)

                for v in validation_folders:
                    ok = True
                    metrics = OrderedDict()
                    for fld in folds:
                        d = join(out_dir, "fold%d" % fld)
                        if not isdir(d):
                            d = join(out_dir, "fold_%d" % fld)
                            if not isdir(d):
                                ok = False
                                break
                        validation_folder = join(d, v)

                        if not isfile(join(validation_folder, "summary.json")):
                            print("summary.json missing for net %s task %s fold %d" % (model, task_name, fld))
                            ok = False
                            break

                        # accumulate this fold's mean metrics per label and metric name
                        metrics_tmp = load_json(join(validation_folder, "summary.json"))["results"]["mean"]
                        for l in metrics_tmp.keys():
                            if metrics.get(l) is None:
                                metrics[l] = OrderedDict()
                            for m in metrics_tmp[l].keys():
                                if metrics[l].get(m) is None:
                                    metrics[l][m] = []
                                metrics[l][m].append(metrics_tmp[l][m])
                    if ok:
                        # average every metric over the folds and write the jsons
                        for l in metrics.keys():
                            for m in metrics[l].keys():
                                assert len(metrics[l][m]) == len(folds)
                                metrics[l][m] = np.mean(metrics[l][m])
                        json_out = OrderedDict()
                        json_out["results"] = OrderedDict()
                        json_out["results"]["mean"] = metrics
                        json_out["task"] = task_name
                        json_out["description"] = model + " " + task_name + " all folds summary"
                        json_out["name"] = model + " " + task_name + " all folds summary"
                        json_out["experiment_name"] = model
                        save_json(json_out, join(out_dir, "summary_allFolds__%s.json" % v))
                        save_json(json_out, join(output_dir, "%s__%s__%s__%s.json" % (task_name, model, trainer, v)))
                        foreground_mean(join(out_dir, "summary_allFolds__%s.json" % v))
                        foreground_mean(join(output_dir, "%s__%s__%s__%s.json" % (task_name, model, trainer, v)))
def summarize2(task_ids, models=('2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'),
               output_dir=join(network_training_output_dir, "summary_jsons"), folds=(0, 1, 2, 3, 4)):
    """
    Like summarize, but recomputes the mean and median metrics from the
    per-case ('all') entries of the fold summaries instead of averaging the
    precomputed per-fold means. Writes one json per (task, model, trainer,
    validation folder) into output_dir, tagged with the concatenated fold ids,
    and appends a foreground mean via foreground_mean2.

    :param task_ids: list of task ids (ints or strings), or ["all"] for ids 0..998
    :param models: network configurations to include
    :param output_dir: where the combined jsons are collected
    :param folds: all of these folds must be present for a summary to be written
    """
    maybe_mkdir_p(output_dir)
    if len(task_ids) == 1 and task_ids[0] == "all":
        task_ids = list(range(999))
    else:
        task_ids = [int(i) for i in task_ids]

    for model in models:
        for t in task_ids:
            if not isdir(join(network_training_output_dir, model)):
                continue
            # resolve the task id to its (unique) output folder name
            task_name = subfolders(join(network_training_output_dir, model), prefix="Task%03.0d" % t, join=False)
            if len(task_name) != 1:
                print("did not find unique output folder for network %s and task %s" % (model, t))
                continue
            task_name = task_name[0]
            out_dir_task = join(network_training_output_dir, model, task_name)

            model_trainers = subdirs(out_dir_task, join=False)
            for trainer in model_trainers:
                if trainer.startswith("fold"):
                    continue
                out_dir = join(out_dir_task, trainer)

                # all distinct validation folder names across the folds
                validation_folders = []
                for fld in folds:
                    fold_output_dir = join(out_dir, "fold_%d" % fld)
                    if not isdir(fold_output_dir):
                        continue
                    validation_folders += subfolders(fold_output_dir, prefix="validation", join=False)
                validation_folders = np.unique(validation_folders)

                for v in validation_folders:
                    ok = True
                    metrics = OrderedDict()
                    metrics['mean'] = OrderedDict()
                    metrics['median'] = OrderedDict()
                    metrics['all'] = OrderedDict()
                    for fld in folds:
                        fold_output_dir = join(out_dir, "fold_%d" % fld)
                        if not isdir(fold_output_dir):
                            print("fold missing", model, task_name, trainer, fld)
                            ok = False
                            break
                        validation_folder = join(fold_output_dir, v)

                        if not isdir(validation_folder):
                            print("validation folder missing", model, task_name, trainer, fld, v)
                            ok = False
                            break

                        if not isfile(join(validation_folder, "summary.json")):
                            print("summary.json missing", model, task_name, trainer, fld, v)
                            ok = False
                            break

                        all_metrics = load_json(join(validation_folder, "summary.json"))["results"]
                        # we now need to get the mean and median metrics. We use the mean metrics just to get the
                        # names of computed metrics, we ignore the precomputed mean and do it ourselves again
                        mean_metrics = all_metrics["mean"]
                        all_labels = [i for i in list(mean_metrics.keys()) if i != "mean"]

                        if len(all_labels) == 0: print(v, fld); break

                        all_metrics_names = list(mean_metrics[all_labels[0]].keys())
                        for l in all_labels:
                            # initialize the data structure, no values are copied yet
                            for k in ['mean', 'median', 'all']:
                                if metrics[k].get(l) is None:
                                    metrics[k][l] = OrderedDict()
                                for m in all_metrics_names:
                                    if metrics['all'][l].get(m) is None:
                                        metrics['all'][l][m] = []

                        # accumulate the per-case values of this fold
                        for entry in all_metrics['all']:
                            for l in all_labels:
                                for m in all_metrics_names:
                                    metrics['all'][l][m].append(entry[l][m])

                    # now compute mean and median
                    for l in metrics['all'].keys():
                        for m in metrics['all'][l].keys():
                            metrics['mean'][l][m] = np.nanmean(metrics['all'][l][m])
                            metrics['median'][l][m] = np.nanmedian(metrics['all'][l][m])

                    if ok:
                        fold_string = ""
                        for f in folds:
                            fold_string += str(f)
                        json_out = OrderedDict()
                        json_out["results"] = OrderedDict()
                        json_out["results"]["mean"] = metrics['mean']
                        json_out["results"]["median"] = metrics['median']
                        json_out["task"] = task_name
                        json_out["description"] = model + " " + task_name + "summary folds" + str(folds)
                        json_out["name"] = model + " " + task_name + "summary folds" + str(folds)
                        json_out["experiment_name"] = model
                        save_json(json_out, join(output_dir, "%s__%s__%s__%s__%s.json" % (task_name, model, trainer, v, fold_string)))
                        foreground_mean2(join(output_dir, "%s__%s__%s__%s__%s.json" % (task_name, model, trainer, v, fold_string)))
def foreground_mean2(filename):
    """
    Adds a 'mean' entry (average over all foreground classes) to both the
    'mean' and 'median' result sections of a summary json written by
    summarize2. The file is modified in place.

    The background class '0' and any pre-existing 'mean' entry are excluded
    from the average. Assumes label '1' exists, since its entry is used to
    enumerate the metric names.

    :param filename: path to the summary json file
    """
    # explicit import: the original relied on 'json' leaking in through the
    # batchgenerators star import, which is fragile
    import json

    with open(filename, 'r') as f:
        res = json.load(f)

    # numeric foreground class ids (skip background '0' and the 'mean' key)
    class_ids = np.array([int(i) for i in res['results']['mean'].keys()
                          if (i != 'mean') and i != '0'])
    metric_names = res['results']['mean']['1'].keys()

    res['results']['mean']["mean"] = OrderedDict()
    res['results']['median']["mean"] = OrderedDict()
    for m in metric_names:
        foreground_values = [res['results']['mean'][str(i)][m] for i in class_ids]
        res['results']['mean']["mean"][m] = np.nanmean(foreground_values)
        foreground_values = [res['results']['median'][str(i)][m] for i in class_ids]
        # note: this is the mean of the per-class medians
        res['results']['median']["mean"][m] = np.nanmean(foreground_values)

    with open(filename, 'w') as f:
        json.dump(res, f, indent=4, sort_keys=True)
if __name__ == "__main__":
    # CLI entry point: summarize the cross-validation results of all requested
    # tasks/models into one json per model (written to summary_jsons_new).
    import argparse

    cli = argparse.ArgumentParser(usage="This is intended to identify the best model based on the five fold "
                                        "cross-validation. Running this script requires alle models to have been run "
                                        "already. This script will summarize the results of the five folds of all "
                                        "models in one json each for easy interpretability")
    cli.add_argument("-t", '--task_ids', nargs="+", required=True, help="task id. can be 'all'")
    cli.add_argument("-f", '--folds', nargs="+", required=False, type=int, default=[0, 1, 2, 3, 4])
    cli.add_argument("-m", '--models', nargs="+", required=False,
                     default=['2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'])
    parsed = cli.parse_args()

    summarize2(parsed.task_ids, parsed.models, folds=parsed.folds,
               output_dir=join(network_training_output_dir, "summary_jsons_new"))
================================================
FILE: unetr_pp/evaluation/model_selection/summarize_results_with_plans.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.utilities.file_and_folder_operations import *
import os
from unetr_pp.evaluation.model_selection.summarize_results_in_one_json import summarize
from unetr_pp.paths import network_training_output_dir
import numpy as np
def list_to_string(l, delim=","):
    """
    Format every value in l as a three-decimal float and join with delim.

    Unlike the previous manual first-element-plus-loop concatenation, this
    also handles an empty sequence (returns "") instead of raising IndexError.

    :param l: sequence of numbers
    :param delim: separator placed between formatted values
    :return: joined string, e.g. "1.000,2.500"
    """
    return delim.join("%03.3f" % v for v in l)
def write_plans_to_file(f, plans_file, stage=0, do_linebreak_at_end=True, override_name=None):
    """
    Append one semicolon-separated csv row describing a plans file/stage to f.

    :param f: open text file handle the row is written to
    :param plans_file: path to a plans .pkl file
    :param stage: index into the sorted plans_per_stage keys
    :param do_linebreak_at_end: whether to terminate the row with a newline
    :param override_name: identifier to write instead of one derived from plans_file
    """
    plans = load_pickle(plans_file)
    stage_keys = sorted(plans['plans_per_stage'].keys())
    props = plans['plans_per_stage'][stage_keys[stage]]

    # convert voxel counts to physical extents (mm) using the stage spacing
    patch_size_in_mm = [i * j for i, j in zip(props['patch_size'], props['current_spacing'])]
    median_patient_size_in_mm = [i * j for i, j in zip(props['median_patient_size_in_voxels'],
                                                       props['current_spacing'])]

    if override_name is None:
        f.write(plans_file.split("/")[-2] + "__" + plans_file.split("/")[-1])
    else:
        f.write(override_name)

    # one csv column per property, in the order expected by the csv header
    columns = [
        "%d" % stage,
        str(props['batch_size']),
        str(props['num_pool_per_axis']),
        str(props['patch_size']),
        list_to_string(patch_size_in_mm),
        str(props['median_patient_size_in_voxels']),
        list_to_string(median_patient_size_in_mm),
        list_to_string(props['current_spacing']),
        list_to_string(props['original_spacing']),
        str(props['pool_op_kernel_sizes']),
        str(props['conv_kernel_sizes']),
    ]
    for col in columns:
        f.write(";" + col)

    if do_linebreak_at_end:
        f.write("\n")
if __name__ == "__main__":
    # Build fold-0 summaries for a fixed set of tasks, then write one csv that
    # combines the experiment plans (patch size, spacing, ...) of every trained
    # model with its mean validation Dice and global Dice.
    summarize((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 24, 27), output_dir=join(network_training_output_dir, "summary_fold0"), folds=(0,))

    base_dir = os.environ['RESULTS_FOLDER']
    nnformers = ['nnFormerV2', 'nnFormerV2_zspacing']
    task_ids = list(range(99))

    with open("summary.csv", 'w') as f:
        f.write("identifier;stage;batch_size;num_pool_per_axis;patch_size;patch_size(mm);median_patient_size_in_voxels;median_patient_size_in_mm;current_spacing;original_spacing;pool_op_kernel_sizes;conv_kernel_sizes;patient_dc;global_dc\n")
        for i in task_ids:
            for nnformer in nnformers:
                try:
                    summary_folder = join(base_dir, nnformer, "summary_fold0")
                    if isdir(summary_folder):
                        summary_files = subfiles(summary_folder, join=False, prefix="Task%03.0d_" % i, suffix=".json", sort=True)
                        for s in summary_files:
                            # summary file names look like Task__model__trainer(.plans).json
                            tmp = s.split("__")
                            trainer = tmp[2]
                            expected_output_folder = join(base_dir, nnformer, tmp[1], tmp[0], tmp[2].split(".")[0])
                            name = tmp[0] + "__" + nnformer + "__" + tmp[1] + "__" + tmp[2].split(".")[0]
                            global_dice_json = join(base_dir, nnformer, tmp[1], tmp[0], tmp[2].split(".")[0], "fold_0", "validation_tiledTrue_doMirror_True", "global_dice.json")
                            # if the trainer/plans identifier itself contained '__'
                            # the 3-part split was wrong; retry with four parts
                            if not isdir(expected_output_folder) or len(tmp) > 3:
                                if len(tmp) == 2:
                                    continue
                                expected_output_folder = join(base_dir, nnformer, tmp[1], tmp[0], tmp[2] + "__" + tmp[3].split(".")[0])
                                name = tmp[0] + "__" + nnformer + "__" + tmp[1] + "__" + tmp[2] + "__" + tmp[3].split(".")[0]
                                global_dice_json = join(base_dir, nnformer, tmp[1], tmp[0], tmp[2] + "__" + tmp[3].split(".")[0], "fold_0", "validation_tiledTrue_doMirror_True", "global_dice.json")
                            assert isdir(expected_output_folder), "expected output dir not found"

                            plans_file = join(expected_output_folder, "plans.pkl")
                            assert isfile(plans_file)
                            plans = load_pickle(plans_file)
                            num_stages = len(plans['plans_per_stage'])
                            # 3d_fullres uses the highest stage when a lowres stage exists
                            if num_stages > 1 and tmp[1] == "3d_fullres":
                                stage = 1
                            elif (num_stages == 1 and tmp[1] == "3d_fullres") or tmp[1] == "3d_lowres":
                                stage = 0
                            else:
                                print("skipping", s)
                                continue

                            g_dc = load_json(global_dice_json)
                            mn_glob_dc = np.mean(list(g_dc.values()))

                            write_plans_to_file(f, plans_file, stage, False, name)
                            # now read and add result to end of line
                            results = load_json(join(summary_folder, s))
                            mean_dc = results['results']['mean']['mean']['Dice']
                            f.write(";%03.3f" % mean_dc)
                            f.write(";%03.3f\n" % mn_glob_dc)
                            print(name, mean_dc)
                except Exception as e:
                    # best effort: a broken or missing model must not abort the whole csv
                    print(e)
================================================
FILE: unetr_pp/evaluation/region_based_evaluation.py
================================================
from copy import deepcopy
from multiprocessing.pool import Pool
from batchgenerators.utilities.file_and_folder_operations import *
from medpy import metric
import SimpleITK as sitk
import numpy as np
from unetr_pp.configuration import default_num_threads
from unetr_pp.postprocessing.consolidate_postprocessing import collect_cv_niftis
def get_brats_regions():
    """
    Region definitions (label groups) for the BraTS data as used in this
    repository, where the labels are 1, 2 and 3. Note that the original BraTS
    data use a different labeling convention!

    :return: dict mapping region name -> tuple of labels forming that region
    """
    return {
        "whole tumor": (1, 2, 3),
        "tumor core": (2, 3),
        "enhancing tumor": (3,),
    }
def get_KiTS_regions():
    """
    Region definitions (label groups) for KiTS: label 1 is kidney, label 2 is
    tumor.

    :return: dict mapping region name -> tuple of labels forming that region
    """
    return {
        "kidney incl tumor": (1, 2),
        "tumor": (2,),
    }
def create_region_from_mask(mask, join_labels: tuple):
    """
    Build a binary mask that is 1 wherever mask holds any of the labels in
    join_labels and 0 elsewhere.

    :param mask: integer label map (numpy array, any shape)
    :param join_labels: labels whose union forms the region
    :return: np.uint8 array of the same shape as mask
    """
    return np.isin(mask, np.asarray(join_labels)).astype(np.uint8)
def evaluate_case(file_pred: str, file_gt: str, regions):
    """
    Compute the Dice coefficient of every region for one prediction/ground
    truth pair of nifti files.

    A region that is empty in both images is scored as np.nan.

    :param file_pred: path to the predicted segmentation
    :param file_gt: path to the ground truth segmentation
    :param regions: iterable of label tuples (see get_brats_regions)
    :return: list with one Dice (or nan) per region, in iteration order
    """
    gt_arr = sitk.GetArrayFromImage(sitk.ReadImage(file_gt))
    pred_arr = sitk.GetArrayFromImage(sitk.ReadImage(file_pred))
    dice_scores = []
    for region_labels in regions:
        pred_mask = create_region_from_mask(pred_arr, region_labels)
        gt_mask = create_region_from_mask(gt_arr, region_labels)
        if np.sum(gt_mask) == 0 and np.sum(pred_mask) == 0:
            dice_scores.append(np.nan)
        else:
            dice_scores.append(metric.dc(pred_mask, gt_mask))
    return dice_scores
def evaluate_regions(folder_predicted: str, folder_gt: str, regions: dict, processes=default_num_threads):
    """
    Evaluates every prediction in folder_predicted against the matching ground
    truth in folder_gt using the region Dice (see evaluate_case) and writes a
    summary.csv (per-case rows plus mean/median aggregates) into
    folder_predicted.

    :param folder_predicted: folder with predicted .nii.gz segmentations
    :param folder_gt: folder with ground truth .nii.gz segmentations; file
        names must match those in folder_predicted
    :param regions: dict region name -> tuple of labels (e.g. get_brats_regions())
    :param processes: number of worker processes for the parallel evaluation
    """
    region_names = list(regions.keys())
    files_in_pred = subfiles(folder_predicted, suffix='.nii.gz', join=False)
    files_in_gt = subfiles(folder_gt, suffix='.nii.gz', join=False)

    # every prediction needs a ground truth; missing predictions are tolerated
    have_no_gt = [i for i in files_in_pred if i not in files_in_gt]
    assert len(have_no_gt) == 0, "Some files in folder_predicted have not ground truth in folder_gt"
    have_no_pred = [i for i in files_in_gt if i not in files_in_pred]
    if len(have_no_pred) > 0:
        print("WARNING! Some files in folder_gt were not predicted (not present in folder_predicted)!")

    files_in_gt.sort()
    files_in_pred.sort()

    # run for all cases
    full_filenames_gt = [join(folder_gt, i) for i in files_in_pred]
    full_filenames_pred = [join(folder_predicted, i) for i in files_in_pred]

    # one evaluate_case call per (pred, gt) pair; the regions are shared
    p = Pool(processes)
    res = p.starmap(evaluate_case, zip(full_filenames_pred, full_filenames_gt, [list(regions.values())] * len(files_in_gt)))
    p.close()
    p.join()

    all_results = {r: [] for r in region_names}
    with open(join(folder_predicted, 'summary.csv'), 'w') as f:
        # header: one column per region
        f.write("casename")
        for r in region_names:
            f.write(",%s" % r)
        f.write("\n")

        # one row per case
        for i in range(len(files_in_pred)):
            f.write(files_in_pred[i][:-7])  # strip the '.nii.gz' extension
            result_here = res[i]
            for k, r in enumerate(region_names):
                dc = result_here[k]
                f.write(",%02.4f" % dc)
                all_results[r].append(dc)
            f.write("\n")

        # aggregate rows; nan entries (region empty in both images) are either
        # ignored (nanmean/nanmedian) or counted as a perfect score of 1
        f.write('mean')
        for r in region_names:
            f.write(",%02.4f" % np.nanmean(all_results[r]))
        f.write("\n")
        f.write('median')
        for r in region_names:
            f.write(",%02.4f" % np.nanmedian(all_results[r]))
        f.write("\n")

        f.write('mean (nan is 1)')
        for r in region_names:
            tmp = np.array(all_results[r])
            tmp[np.isnan(tmp)] = 1
            f.write(",%02.4f" % np.mean(tmp))
        f.write("\n")
        f.write('median (nan is 1)')
        for r in region_names:
            tmp = np.array(all_results[r])
            tmp[np.isnan(tmp)] = 1
            f.write(",%02.4f" % np.median(tmp))
        f.write("\n")
if __name__ == '__main__':
    # collect the cross-validation predictions into one folder, then run the
    # region-based (BraTS) evaluation against the ground truth niftis
    collect_cv_niftis('./', './cv_niftis')
    evaluate_regions('./cv_niftis/', './gt_niftis/', get_brats_regions())
================================================
FILE: unetr_pp/evaluation/surface_dice.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from medpy.metric.binary import __surface_distances
def normalized_surface_dice(a: np.ndarray, b: np.ndarray, threshold: float, spacing: tuple = None, connectivity=1):
    """
    Symmetric normalized surface dice of two binary masks.

    This implementation differs from the official surface dice implementation!
    The two are not comparable!!!!! Because the measure is symmetric it does
    not matter whether a or b is the reference image. 2D and 3D images are
    supported natively; other dimensionalities depend on medpy's
    __surface_distances implementation.

    :param a: image 1, must have the same shape as b
    :param b: image 2, must have the same shape as a
    :param threshold: surface distances below this threshold count as true
        positives. The threshold is in mm, not voxels (with spacing (1, 1(, 1))
        one voxel is 1mm, so it is effectively in voxels then)
    :param spacing: physical size of one voxel per axis; None means an
        isotropic spacing of 1mm
    :param connectivity: see scipy.ndimage.generate_binary_structure; best
        left at the default
    :return: the normalized surface dice
    """
    assert all([i == j for i, j in zip(a.shape, b.shape)]), \
        "a and b must have the same shape. a.shape= %s, b.shape= %s" % (str(a.shape), str(b.shape))

    if spacing is None:
        # default: isotropic 1mm voxels
        spacing = tuple(1 for _ in range(len(a.shape)))

    dist_a_to_b = __surface_distances(a, b, spacing, connectivity)
    dist_b_to_a = __surface_distances(b, a, spacing, connectivity)

    n_surf_a = len(dist_a_to_b)
    n_surf_b = len(dist_b_to_a)

    true_pos_a = np.sum(dist_a_to_b <= threshold) / n_surf_a
    true_pos_b = np.sum(dist_b_to_a <= threshold) / n_surf_b
    false_pos = np.sum(dist_a_to_b > threshold) / n_surf_a
    false_neg = np.sum(dist_b_to_a > threshold) / n_surf_b

    # the 1e-8 only guards against division by zero
    return (true_pos_a + true_pos_b) / (true_pos_a + true_pos_b + false_pos + false_neg + 1e-8)
================================================
FILE: unetr_pp/evaluation/unetr_pp_acdc_checkpoint/unetr_pp/3d_fullres/Task001_ACDC/unetr_pp_trainer_acdc__unetr_pp_Plansv2.1/fold_0/.gitignore
================================================
# Ignore everything in this directory
*
# Except this file
!.gitignore
================================================
FILE: unetr_pp/evaluation/unetr_pp_lung_checkpoint/unetr_pp/3d_fullres/Task006_Lung/unetr_pp_trainer_lung__unetr_pp_Plansv2.1/fold_0/.gitignore
================================================
# Ignore everything in this directory
*
# Except this file
!.gitignore
================================================
FILE: unetr_pp/evaluation/unetr_pp_synapse_checkpoint/unetr_pp/3d_fullres/Task002_Synapse/unetr_pp_trainer_synapse__unetr_pp_Plansv2.1/fold_0/.gitignore
================================================
# Ignore everything in this directory
*
# Except this file
!.gitignore
================================================
FILE: unetr_pp/evaluation/unetr_pp_tumor_checkpoint/unetr_pp/3d_fullres/Task003_tumor/unetr_pp_trainer_tumor__unetr_pp_Plansv2.1/fold_0/.gitignore
================================================
# Ignore everything in this directory
*
# Except this file
!.gitignore
================================================
FILE: unetr_pp/experiment_planning/DatasetAnalyzer.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.utilities.file_and_folder_operations import *
from multiprocessing import Pool
from unetr_pp.configuration import default_num_threads
from unetr_pp.paths import nnFormer_raw_data, nnFormer_cropped_data
import numpy as np
import pickle
from unetr_pp.preprocessing.cropping import get_patient_identifiers_from_cropped_files
from skimage.morphology import label
from collections import OrderedDict
class DatasetAnalyzer(object):
    """
    Computes dataset-wide properties from a folder of cropped cases: per-case sizes and
    spacings, the set of class labels, foreground intensity statistics per modality, and
    the size reduction achieved by cropping. Results are cached as pickle files inside
    the cropped-data folder.
    """
    def __init__(self, folder_with_cropped_data, overwrite=True, num_processes=default_num_threads):
        """
        :param folder_with_cropped_data: folder containing the cropped .npz/.pkl case
        files plus dataset.json
        :param overwrite: If True then precomputed values will not be used and instead recomputed from the data.
        False will allow loading of precomputed values. This may be dangerous though if some of the code of this class
        was changed, therefore the default is True.
        :param num_processes: number of worker processes used by the multiprocessing pools
        """
        self.num_processes = num_processes
        self.overwrite = overwrite
        self.folder_with_cropped_data = folder_with_cropped_data
        # placeholders; not filled by __init__
        self.sizes = self.spacings = None
        self.patient_identifiers = get_patient_identifiers_from_cropped_files(self.folder_with_cropped_data)
        assert isfile(join(self.folder_with_cropped_data, "dataset.json")), \
            "dataset.json needs to be in folder_with_cropped_data"
        # cache files written next to the cropped data
        self.props_per_case_file = join(self.folder_with_cropped_data, "props_per_case.pkl")
        self.intensityproperties_file = join(self.folder_with_cropped_data, "intensityproperties.pkl")

    def load_properties_of_cropped(self, case_identifier):
        """Load the per-case properties pickle (<case_identifier>.pkl) written during cropping."""
        with open(join(self.folder_with_cropped_data, "%s.pkl" % case_identifier), 'rb') as f:
            properties = pickle.load(f)
        return properties

    @staticmethod
    def _check_if_all_in_one_region(seg, regions):
        """
        For each region (an iterable of class labels) check whether the union of those
        labels forms exactly one connected component in seg.

        :return: OrderedDict mapping tuple(region) -> bool (True iff exactly one component)
        """
        res = OrderedDict()
        for r in regions:
            # build a binary mask that is 1 wherever any label of this region occurs
            new_seg = np.zeros(seg.shape)
            for c in r:
                new_seg[seg == c] = 1
            labelmap, numlabels = label(new_seg, return_num=True)
            if numlabels != 1:
                res[tuple(r)] = False
            else:
                res[tuple(r)] = True
        return res

    @staticmethod
    def _collect_class_and_region_sizes(seg, all_classes, vol_per_voxel):
        """
        Per class: total volume (in units of vol_per_voxel) and the volume of each
        connected component of that class.

        :return: (volume_per_class, region_volume_per_class) where the latter maps
        class -> list of component volumes
        """
        volume_per_class = OrderedDict()
        region_volume_per_class = OrderedDict()
        for c in all_classes:
            region_volume_per_class[c] = []
            volume_per_class[c] = np.sum(seg == c) * vol_per_voxel
            labelmap, numregions = label(seg == c, return_num=True)
            # component labels are 1..numregions (0 is background)
            for l in range(1, numregions + 1):
                region_volume_per_class[c].append(np.sum(labelmap == l) * vol_per_voxel)
        return volume_per_class, region_volume_per_class

    def _get_unique_labels(self, patient_identifier):
        """Return the unique label values of one case's segmentation."""
        # the segmentation is stored as the last channel of the 'data' array in the npz
        seg = np.load(join(self.folder_with_cropped_data, patient_identifier) + ".npz")['data'][-1]
        unique_classes = np.unique(seg)
        return unique_classes

    def _load_seg_analyze_classes(self, patient_identifier, all_classes):
        """
        1) what class is in this training case?
        2) what is the size distribution for each class?
        3) what is the region size of each class?
        4) check if all in one region
        :return: (unique_classes, all_in_one_region, volume_per_class, region_sizes)
        """
        seg = np.load(join(self.folder_with_cropped_data, patient_identifier) + ".npz")['data'][-1]
        pkl = load_pickle(join(self.folder_with_cropped_data, patient_identifier) + ".pkl")
        # volume of a single voxel, derived from the itk spacing stored during cropping
        vol_per_voxel = np.prod(pkl['itk_spacing'])
        # ad 1)
        unique_classes = np.unique(seg)
        # 4) check if all in one region
        regions = list()
        regions.append(list(all_classes))  # all foreground classes taken together
        for c in all_classes:
            regions.append((c, ))  # ...and each class on its own
        all_in_one_region = self._check_if_all_in_one_region(seg, regions)
        # 2 & 3) region sizes
        volume_per_class, region_sizes = self._collect_class_and_region_sizes(seg, all_classes, vol_per_voxel)
        return unique_classes, all_in_one_region, volume_per_class, region_sizes

    def get_classes(self):
        """Return the 'labels' mapping from dataset.json."""
        datasetjson = load_json(join(self.folder_with_cropped_data, "dataset.json"))
        return datasetjson['labels']

    def analyse_segmentations(self):
        """
        Determine for every case which class labels are present. Cached in
        props_per_case_file unless self.overwrite is True.

        :return: (class_dct, props_per_patient) where props_per_patient maps
        patient_id -> {'has_classes': unique label values}
        """
        class_dct = self.get_classes()
        if self.overwrite or not isfile(self.props_per_case_file):
            p = Pool(self.num_processes)
            res = p.map(self._get_unique_labels, self.patient_identifiers)
            p.close()
            p.join()
            props_per_patient = OrderedDict()
            # NOTE(review): 'p' is reused here as the patient identifier, shadowing the
            # (already closed) Pool above - harmless but confusing
            for p, unique_classes in \
                    zip(self.patient_identifiers, res):
                props = dict()
                props['has_classes'] = unique_classes
                props_per_patient[p] = props
            save_pickle(props_per_patient, self.props_per_case_file)
        else:
            props_per_patient = load_pickle(self.props_per_case_file)
        return class_dct, props_per_patient

    def get_sizes_and_spacings_after_cropping(self):
        """Collect 'size_after_cropping' and 'original_spacing' from every case's properties."""
        sizes = []
        spacings = []
        # for c in case_identifiers:
        for c in self.patient_identifiers:
            properties = self.load_properties_of_cropped(c)
            sizes.append(properties["size_after_cropping"])
            spacings.append(properties["original_spacing"])
        return sizes, spacings

    def get_modalities(self):
        """Return {modality_index (int): modality_name} from dataset.json."""
        datasetjson = load_json(join(self.folder_with_cropped_data, "dataset.json"))
        modalities = datasetjson["modality"]
        # dataset.json stores keys as strings; convert to int indices
        modalities = {int(k): modalities[k] for k in modalities.keys()}
        return modalities

    def get_size_reduction_by_cropping(self):
        """Return {patient_id: fraction of voxels remaining after cropping}."""
        size_reduction = OrderedDict()
        for p in self.patient_identifiers:
            props = self.load_properties_of_cropped(p)
            shape_before_crop = props["original_size_of_raw_data"]
            shape_after_crop = props['size_after_cropping']
            size_red = np.prod(shape_after_crop) / np.prod(shape_before_crop)
            size_reduction[p] = size_red
        return size_reduction

    def _get_voxels_in_foreground(self, patient_identifier, modality_id):
        """Return a subsampled (every 10th) list of intensities inside the foreground (seg > 0)."""
        all_data = np.load(join(self.folder_with_cropped_data, patient_identifier) + ".npz")['data']
        modality = all_data[modality_id]
        # last channel is the segmentation; foreground is any positive label
        mask = all_data[-1] > 0
        voxels = list(modality[mask][::10])  # no need to take every voxel
        return voxels

    @staticmethod
    def _compute_stats(voxels):
        """
        Return (median, mean, sd, min, max, 99.5th percentile, 0.5th percentile) of
        voxels, or seven NaNs if voxels is empty.
        """
        if len(voxels) == 0:
            return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
        median = np.median(voxels)
        mean = np.mean(voxels)
        sd = np.std(voxels)
        mn = np.min(voxels)
        mx = np.max(voxels)
        percentile_99_5 = np.percentile(voxels, 99.5)
        percentile_00_5 = np.percentile(voxels, 00.5)
        return median, mean, sd, mn, mx, percentile_99_5, percentile_00_5

    def collect_intensity_properties(self, num_modalities):
        """
        Compute foreground-intensity statistics per modality: global statistics over the
        pooled voxels of all cases plus per-case statistics ('local_props'). Cached in
        intensityproperties_file unless self.overwrite is True.

        :param num_modalities: number of image modalities/channels to analyze
        :return: {modality_id: {'local_props': {...}, 'median': ..., 'mean': ..., ...}}
        """
        if self.overwrite or not isfile(self.intensityproperties_file):
            p = Pool(self.num_processes)
            results = OrderedDict()
            for mod_id in range(num_modalities):
                results[mod_id] = OrderedDict()
                v = p.starmap(self._get_voxels_in_foreground, zip(self.patient_identifiers,
                                                                  [mod_id] * len(self.patient_identifiers)))
                # pool all cases' voxels for the global statistics
                w = []
                for iv in v:
                    w += iv
                median, mean, sd, mn, mx, percentile_99_5, percentile_00_5 = self._compute_stats(w)
                # per-case statistics, computed in parallel on each case's voxel list
                local_props = p.map(self._compute_stats, v)
                props_per_case = OrderedDict()
                for i, pat in enumerate(self.patient_identifiers):
                    props_per_case[pat] = OrderedDict()
                    props_per_case[pat]['median'] = local_props[i][0]
                    props_per_case[pat]['mean'] = local_props[i][1]
                    props_per_case[pat]['sd'] = local_props[i][2]
                    props_per_case[pat]['mn'] = local_props[i][3]
                    props_per_case[pat]['mx'] = local_props[i][4]
                    props_per_case[pat]['percentile_99_5'] = local_props[i][5]
                    props_per_case[pat]['percentile_00_5'] = local_props[i][6]
                results[mod_id]['local_props'] = props_per_case
                results[mod_id]['median'] = median
                results[mod_id]['mean'] = mean
                results[mod_id]['sd'] = sd
                results[mod_id]['mn'] = mn
                results[mod_id]['mx'] = mx
                results[mod_id]['percentile_99_5'] = percentile_99_5
                results[mod_id]['percentile_00_5'] = percentile_00_5
            p.close()
            p.join()
            save_pickle(results, self.intensityproperties_file)
        else:
            results = load_pickle(self.intensityproperties_file)
        return results

    def analyze_dataset(self, collect_intensityproperties=True):
        """
        Run the full analysis and save the result to dataset_properties.pkl in the
        cropped-data folder.

        :param collect_intensityproperties: if False, skip the (expensive) intensity
        statistics and store None instead
        :return: the dataset_properties dict that was saved
        """
        # get all spacings and sizes
        sizes, spacings = self.get_sizes_and_spacings_after_cropping()
        # get all classes and what classes are in what patients
        # class min size
        # region size per class
        classes = self.get_classes()
        all_classes = [int(i) for i in classes.keys() if int(i) > 0]  # background (0) excluded
        # modalities
        modalities = self.get_modalities()
        # collect intensity information
        if collect_intensityproperties:
            intensityproperties = self.collect_intensity_properties(len(modalities))
        else:
            intensityproperties = None
        # size reduction by cropping
        size_reductions = self.get_size_reduction_by_cropping()
        dataset_properties = dict()
        dataset_properties['all_sizes'] = sizes
        dataset_properties['all_spacings'] = spacings
        dataset_properties['all_classes'] = all_classes
        dataset_properties['modalities'] = modalities  # {idx: modality name}
        dataset_properties['intensityproperties'] = intensityproperties
        dataset_properties['size_reductions'] = size_reductions  # {patient_id: size_reduction}
        save_pickle(dataset_properties, join(self.folder_with_cropped_data, "dataset_properties.pkl"))
        return dataset_properties
================================================
FILE: unetr_pp/experiment_planning/__init__.py
================================================
from __future__ import absolute_import
from . import *
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_11GB.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import numpy as np
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
ExperimentPlanner3D_v21
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *
class ExperimentPlanner3D_v21_11GB(ExperimentPlanner3D_v21):
    """
    Same as ExperimentPlanner3D_v21, but designed to fill a RTX2080 ti (11GB) in fp16
    """
    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner3D_v21_11GB, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        # distinct identifiers so these plans do not overwrite the default v2.1 plans
        self.data_identifier = "nnFormerData_plans_v2.1_big"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlansv2.1_big_plans_3D.pkl")

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        We need to adapt ref (the VRAM budget available to the patch/batch-size search).
        Everything else follows the base heuristic: start from an isotropic patch, then
        repeatedly shrink the largest axis (relative to the median shape) until the
        estimated VRAM consumption fits, and spend leftover budget on batch size.

        :param current_spacing: target spacing of this stage
        :param original_spacing: median original spacing of the dataset
        :param original_shape: median original shape of the dataset
        :param num_cases: number of training cases (limits the maximum batch size)
        :param num_modalities: number of input channels
        :param num_classes: number of segmentation output classes
        :return: plan dict with patch size, batch size, pooling/conv kernel setup etc.
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # the next line is what we had before as a default. The patch size had the same aspect ratio as the median shape of a patient. We swapped t
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)
        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()
        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)
        # clip it to the median shape of the dataset because patches larger than that make not much sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
                                                             self.unet_featuremap_min_edge_length,
                                                             self.unet_max_numpool)

        # use_this_for_batch_size_computation_3D = 520000000 # 505789440
        # typical ExperimentPlanner3D_v21 configurations use 7.5GB, but on a 2080ti we have 11. Allow for more space
        # to be used
        ref = Generic_UNet.use_this_for_batch_size_computation_3D * 11 / 8  # scale default budget up to 11GB
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        # shrink the patch until the estimated VRAM consumption fits the budget
        while here > ref:
            # reduce the axis that is largest relative to the median patient shape
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]

            # probe how much this axis must shrink to stay compatible with the pooling factors
            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props(current_spacing, tmp,
                                        self.unet_featuremap_min_edge_length,
                                        self.unet_max_numpool,
                                        )
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
                                                                 self.unet_featuremap_min_edge_length,
                                                                 self.unet_max_numpool,
                                                                 )

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            # print(new_shp)

        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))  # spend leftover budget on batch size

        # check if batch size is too large: cap so one batch covers at most a fixed
        # fraction of the whole dataset
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        # enable dummy 2D augmentation when the patch is strongly anisotropic
        # (first axis much smaller than the largest axis)
        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
            0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_16GB.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import numpy as np
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
ExperimentPlanner3D_v21
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *
class ExperimentPlanner3D_v21_16GB(ExperimentPlanner3D_v21):
    """
    Same as ExperimentPlanner3D_v21, but designed to fill 16GB in fp16
    """
    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner3D_v21_16GB, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        # distinct identifiers so these plans do not overwrite the default v2.1 plans
        self.data_identifier = "nnFormerData_plans_v2.1_16GB"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlansv2.1_16GB_plans_3D.pkl")

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        We need to adapt ref (the VRAM budget available to the patch/batch-size search).
        Everything else follows the base heuristic: start from an isotropic patch, then
        repeatedly shrink the largest axis (relative to the median shape) until the
        estimated VRAM consumption fits, and spend leftover budget on batch size.

        :param current_spacing: target spacing of this stage
        :param original_spacing: median original spacing of the dataset
        :param original_shape: median original shape of the dataset
        :param num_cases: number of training cases (limits the maximum batch size)
        :param num_modalities: number of input channels
        :param num_classes: number of segmentation output classes
        :return: plan dict with patch size, batch size, pooling/conv kernel setup etc.
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # the next line is what we had before as a default. The patch size had the same aspect ratio as the median shape of a patient. We swapped t
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)
        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()
        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)
        # clip it to the median shape of the dataset because patches larger than that make not much sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
                                                             self.unet_featuremap_min_edge_length,
                                                             self.unet_max_numpool)

        # use_this_for_batch_size_computation_3D = 520000000 # 505789440
        # typical ExperimentPlanner3D_v21 configurations use 7.5GB, but here we have 16GB. Allow for more space
        # to be used
        # NOTE(review): divisor is 8.5 here vs 8 in the 11GB/32GB planners - presumably
        # intentional extra headroom; confirm against upstream
        ref = Generic_UNet.use_this_for_batch_size_computation_3D * 16 / 8.5
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        # shrink the patch until the estimated VRAM consumption fits the budget
        while here > ref:
            # reduce the axis that is largest relative to the median patient shape
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]

            # probe how much this axis must shrink to stay compatible with the pooling factors
            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props(current_spacing, tmp,
                                        self.unet_featuremap_min_edge_length,
                                        self.unet_max_numpool,
                                        )
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
                                                                 self.unet_featuremap_min_edge_length,
                                                                 self.unet_max_numpool,
                                                                 )

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            # print(new_shp)

        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))  # spend leftover budget on batch size

        # check if batch size is too large: cap so one batch covers at most a fixed
        # fraction of the whole dataset
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        # enable dummy 2D augmentation when the patch is strongly anisotropic
        # (first axis much smaller than the largest axis)
        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
            0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_32GB.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import numpy as np
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
ExperimentPlanner3D_v21
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *
class ExperimentPlanner3D_v21_32GB(ExperimentPlanner3D_v21):
    """
    Same as ExperimentPlanner3D_v21, but designed to fill a V100 (32GB) in fp16
    """
    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner3D_v21_32GB, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        # distinct identifiers so these plans do not overwrite the default v2.1 plans
        self.data_identifier = "nnFormerData_plans_v2.1_verybig"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlansv2.1_verybig_plans_3D.pkl")

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        We need to adapt ref (the VRAM budget available to the patch/batch-size search).
        Everything else follows the base heuristic: start from an isotropic patch, then
        repeatedly shrink the largest axis (relative to the median shape) until the
        estimated VRAM consumption fits, and spend leftover budget on batch size.

        :param current_spacing: target spacing of this stage
        :param original_spacing: median original spacing of the dataset
        :param original_shape: median original shape of the dataset
        :param num_cases: number of training cases (limits the maximum batch size)
        :param num_modalities: number of input channels
        :param num_classes: number of segmentation output classes
        :return: plan dict with patch size, batch size, pooling/conv kernel setup etc.
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # the next line is what we had before as a default. The patch size had the same aspect ratio as the median shape of a patient. We swapped t
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)
        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()
        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)
        # clip it to the median shape of the dataset because patches larger than that make not much sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
                                                             self.unet_featuremap_min_edge_length,
                                                             self.unet_max_numpool)

        # use_this_for_batch_size_computation_3D = 520000000 # 505789440
        # typical ExperimentPlanner3D_v21 configurations use 7.5GB, but on a V100 we have 32. Allow for more space
        # to be used
        ref = Generic_UNet.use_this_for_batch_size_computation_3D * 32 / 8  # scale default budget up to 32GB
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        # shrink the patch until the estimated VRAM consumption fits the budget
        while here > ref:
            # reduce the axis that is largest relative to the median patient shape
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]

            # probe how much this axis must shrink to stay compatible with the pooling factors
            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props(current_spacing, tmp,
                                        self.unet_featuremap_min_edge_length,
                                        self.unet_max_numpool,
                                        )
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
                                                                 self.unet_featuremap_min_edge_length,
                                                                 self.unet_max_numpool,
                                                                 )

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            # print(new_shp)

        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))  # spend leftover budget on batch size

        # check if batch size is too large: cap so one batch covers at most a fixed
        # fraction of the whole dataset
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        # enable dummy 2D augmentation when the patch is strongly anisotropic
        # (first axis much smaller than the largest axis)
        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
            0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_3convperstage.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import numpy as np
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import ExperimentPlanner3D_v21
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *
class ExperimentPlanner3D_v21_3cps(ExperimentPlanner3D_v21):
    """
    Plans networks with three conv-in-lrelu units per resolution stage instead of the
    default two, while staying within the same memory budget.

    This only works with 3d fullres because we use the same data as
    ExperimentPlanner3D_v21. Lowres would require rerunning preprocessing (a different
    patch size implies a different 3d lowres target spacing).
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super().__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlansv2.1_3cps_plans_3D.pkl")
        self.unet_base_num_features = 32
        self.conv_per_stage = 3  # the whole point of this planner

    def run_preprocessing(self, num_threads):
        # Nothing to do: this planner reuses the data preprocessed by ExperimentPlanner3D_v21.
        pass
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v22.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
ExperimentPlanner3D_v21
from unetr_pp.paths import *
class ExperimentPlanner3D_v22(ExperimentPlanner3D_v21):
    """
    Variant of ExperimentPlanner3D_v21 that relaxes the resampling target spacing along
    a strongly anisotropic axis (see get_target_spacing). Writes its plans under a v2.2
    identifier.
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super().__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.data_identifier = "nnFormerData_plans_v2.2"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlansv2.2_plans_3D.pkl")

    def get_target_spacing(self):
        """
        Compute the resampling target spacing.

        Starts from the percentile-based target of the base planner. If one axis is both
        much coarser (larger spacing) and much shorter (fewer voxels) than the others,
        its target spacing is instead taken from the 10th percentile of that axis'
        spacings, clipped so it never exceeds anisotropy_threshold times the other axes.
        """
        all_spacings = self.dataset_properties['all_spacings']
        all_sizes = self.dataset_properties['all_sizes']
        stacked_spacings = np.vstack(all_spacings)
        target = np.percentile(stacked_spacings, self.target_spacing_percentile, 0)
        target_size = np.percentile(np.vstack(all_sizes), self.target_spacing_percentile, 0)
        # extent in mm per axis; computed but deliberately unused for now (see below)
        target_size_mm = np.array(target) * np.array(target_size)

        # A dataset qualifies for special treatment when:
        # - one axis has much lower resolution than the others
        # - that lowres axis has much fewer voxels than the others
        # - (the size in mm of the lowres axis is also reduced)
        lowres_axis = np.argmax(target)
        remaining_axes = [ax for ax in range(len(target)) if ax != lowres_axis]
        remaining_spacings = [target[ax] for ax in remaining_axes]
        remaining_sizes = [target_size[ax] for ax in remaining_axes]

        has_aniso_spacing = target[lowres_axis] > (self.anisotropy_threshold * max(remaining_spacings))
        has_aniso_voxels = target_size[lowres_axis] * self.anisotropy_threshold < min(remaining_sizes)
        # the mm-extent criterion is not applied for now
        # median_size_in_mm = target[target_size_mm] * RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD < max(target_size_mm)

        if has_aniso_spacing and has_aniso_voxels:
            axis_spacings = stacked_spacings[:, lowres_axis]
            chosen_spacing = np.percentile(axis_spacings, 10)
            # don't let the spacing of that axis get higher than anisotropy_threshold x the other axes
            chosen_spacing = max(max(remaining_spacings) * self.anisotropy_threshold, chosen_spacing)
            target[lowres_axis] = chosen_spacing
        return target
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v23.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
ExperimentPlanner3D_v21
from unetr_pp.paths import *
class ExperimentPlanner3D_v23(ExperimentPlanner3D_v21):
    """
    Identical to ExperimentPlanner3D_v21 except that plans are written under a v2.3
    identifier and the 'Preprocessor3DDifferentResampling' preprocessor is selected.
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super().__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.data_identifier = "nnFormerData_plans_v2.3"
        self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlansv2.3_plans_3D.pkl")
        self.preprocessor_name = "Preprocessor3DDifferentResampling"
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_residual_3DUNet_v21.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import numpy as np
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
ExperimentPlanner3D_v21
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
from unetr_pp.paths import *
from unetr_pp.network_architecture.generic_modular_residual_UNet import FabiansUNet
class ExperimentPlanner3DFabiansResUNet_v21(ExperimentPlanner3D_v21):
    """Planner that sizes patch/batch for FabiansUNet (residual UNet) instead of Generic_UNet.

    Differs from ExperimentPlanner3D_v21 in the VRAM-consumption model
    (FabiansUNet.compute_approx_vram_consumption with per-stage block counts)
    and in reusing the v2.1 preprocessed data (see run_preprocessing).
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner3DFabiansResUNet_v21, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        # Deliberately reuses the v2.1 data identifier so the preprocessed data of
        # ExperimentPlanner3D_v21 can be shared; was "nnFormerData_FabiansResUNet_v2.1".
        self.data_identifier = "nnFormerData_plans_v2.1"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlans_FabiansResUNet_v2.1_plans_3D.pkl")

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """Compute one stage's plan (patch size, batch size, pooling/conv config).

        We use FabiansUNet instead of Generic_UNet: its VRAM estimate and its
        default per-stage encoder/decoder block counts drive the sizing loop.

        :param current_spacing: target voxel spacing of this stage
        :param original_spacing: original voxel spacing (per axis); together with
            original_shape this yields `new_median_shape` — presumably the dataset
            medians, confirm against the caller
        :param original_shape: original shape (per axis)
        :param num_cases: number of training cases
        :param num_modalities: number of input channels
        :param num_classes: number of segmentation classes
        :return: dict with this stage's plan entries (includes the residual-UNet
            specific 'num_blocks_encoder'/'num_blocks_decoder' keys)
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # Previous default: patch aspect ratio matched the median patient shape.
        # We swapped to an isotropic-in-mm starting patch instead:
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)
        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()
        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)
        # clip it to the median shape of the dataset because larger patches make little sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
                                                             self.unet_featuremap_min_edge_length,
                                                             self.unet_max_numpool)

        # FabiansUNet expects an explicit no-pooling first stage.
        pool_op_kernel_sizes = [[1, 1, 1]] + pool_op_kernel_sizes
        blocks_per_stage_encoder = FabiansUNet.default_blocks_per_stage_encoder[:len(pool_op_kernel_sizes)]
        blocks_per_stage_decoder = FabiansUNet.default_blocks_per_stage_decoder[:len(pool_op_kernel_sizes) - 1]

        ref = FabiansUNet.use_this_for_3D_configuration
        # NOTE(review): the sibling planners pass the divisibility-adjusted shape
        # (new_shp) to the initial VRAM estimate; here input_patch_size is used —
        # verify this is intentional.
        here = FabiansUNet.compute_approx_vram_consumption(input_patch_size, self.unet_base_num_features,
                                                           self.unet_max_num_filters, num_modalities, num_classes,
                                                           pool_op_kernel_sizes, blocks_per_stage_encoder,
                                                           blocks_per_stage_decoder, 2, self.unet_min_batch_size,)
        # Shrink the patch until the estimated VRAM fits the reference budget.
        while here > ref:
            # reduce the axis that is currently largest relative to the median shape
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]

            # probe the shrink on a copy to learn the new divisibility constraint
            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props(current_spacing, tmp,
                                        self.unet_featuremap_min_edge_length,
                                        self.unet_max_numpool,
                                        )
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
                                                                 self.unet_featuremap_min_edge_length,
                                                                 self.unet_max_numpool,
                                                                 )

            pool_op_kernel_sizes = [[1, 1, 1]] + pool_op_kernel_sizes
            blocks_per_stage_encoder = FabiansUNet.default_blocks_per_stage_encoder[:len(pool_op_kernel_sizes)]
            blocks_per_stage_decoder = FabiansUNet.default_blocks_per_stage_decoder[:len(pool_op_kernel_sizes) - 1]
            here = FabiansUNet.compute_approx_vram_consumption(new_shp, self.unet_base_num_features,
                                                               self.unet_max_num_filters, num_modalities, num_classes,
                                                               pool_op_kernel_sizes, blocks_per_stage_encoder,
                                                               blocks_per_stage_decoder, 2, self.unet_min_batch_size)
        input_patch_size = new_shp

        batch_size = FabiansUNet.default_min_batch_size
        # scale the batch size up by any VRAM headroom left over
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))

        # cap the batch size so one batch covers at most a fixed share of the dataset
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        # dummy 2D augmentation kicks in when axis 0 is much thinner than the largest axis
        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
            0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
            'num_blocks_encoder': blocks_per_stage_encoder,
            'num_blocks_decoder': blocks_per_stage_decoder
        }
        return plan

    def run_preprocessing(self, num_threads):
        """
        On all datasets except 3d fullres on spleen the preprocessed data would look identical to
        ExperimentPlanner3D_v21 (tested on decathlon data only). Therefore we just reuse the
        preprocessed data of that other planner.
        :param num_threads: unused (preprocessing is skipped entirely)
        :return: None
        """
        pass
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_2DUNet_v21_RGB_scaleto_0_1.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unetr_pp.experiment_planning.experiment_planner_baseline_2DUNet_v21 import ExperimentPlanner2D_v21
from unetr_pp.paths import *
class ExperimentPlanner2D_v21_RGB_scaleTo_0_1(ExperimentPlanner2D_v21):
    """2D planner that scales uint8 RGB inputs into the [0, 1] range.

    Used by the tutorial unetr_pp.tutorials.custom_preprocessing.
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super().__init__(folder_with_cropped_data, preprocessed_output_folder)
        identifier = "nnFormer_RGB_scaleTo_0_1"
        self.data_identifier = identifier
        self.plans_fname = join(self.preprocessed_output_folder, identifier + "_plans_2D.pkl")
        # The custom preprocessor class we intend to use is
        # GenericPreprocessor_scale_uint8_to_0_1. It is looked up by name and must
        # live somewhere under unetr_pp.preprocessing (any file/submodule), so
        # preprocessor names must be unique across that package.
        self.preprocessor_name = 'GenericPreprocessor_scale_uint8_to_0_1'
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_3DUNet_CT2.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.paths import *
class ExperimentPlannerCT2(ExperimentPlanner):
    """Planner that preprocesses CT data with the "CT2" normalization.

    The clip range comes from the training set (0.5 and 99.5 percentile of
    foreground intensities).
    CT  = clip to range, then normalize with a global mean/std computed on the
          training-set foreground.
    CT2 = clip to range, then normalize each case separately with its own
          mean/std (computed within the area that fell inside the clip range).
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlannerCT2, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.data_identifier = "nnFormer_CT2"
        self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlans" + "CT2_plans_3D.pkl")

    def determine_normalization_scheme(self):
        """Map each modality index to "CT2" (for CT) or "nonCT" (everything else)."""
        modalities = self.dataset_properties['modalities']
        schemes = OrderedDict()
        for idx in range(len(modalities)):
            schemes[idx] = "CT2" if modalities[idx] == "CT" else "nonCT"
        return schemes
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_3DUNet_nonCT.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.paths import *
class ExperimentPlannernonCT(ExperimentPlanner):
    """Planner that preprocesses every modality in "nonCT" mode.

    nonCT is what we use for MRI per default; this planner applies it to CT
    images as well.
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlannernonCT, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.data_identifier = "nnFormer_nonCT"
        self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlans" + "nonCT_plans_3D.pkl")

    def determine_normalization_scheme(self):
        """Assign the "nonCT" scheme to every modality, CT or not."""
        modalities = self.dataset_properties['modalities']
        # The original branched on modality == "CT", but both arms chose "nonCT";
        # assigning unconditionally is equivalent.
        schemes = OrderedDict()
        for idx in range(len(modalities)):
            schemes[idx] = "nonCT"
        return schemes
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/patch_size/experiment_planner_3DUNet_isotropic_in_mm.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import numpy as np
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props_poolLateV2
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *
class ExperimentPlannerIso(ExperimentPlanner):
    """
    attempts to create patches that have an isotropic size (in mm, not voxels)
    CAREFUL!
    this one does not support transpose_forward and transpose_backward
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super().__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlans" + "fixedisoPatchesInmm_plans_3D.pkl")
        self.data_identifier = "nnFormer_isoPatchesInmm"

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """Compute one stage's plan (patch size, batch size, pooling/conv kernels).

        Starts from an isotropic-in-mm patch and iteratively shrinks the axis that
        is currently largest in mm until the estimated Generic_UNet VRAM
        consumption fits the reference budget.

        :param current_spacing: target voxel spacing for this stage
        :param original_spacing: original voxel spacing (per axis); together with
            original_shape this yields `new_median_shape` — presumably dataset
            medians, confirm against the caller
        :param original_shape: original shape (per axis)
        :param num_cases: number of training cases
        :param num_modalities: number of input channels
        :param num_classes: number of segmentation classes
        :return: dict of plan properties for this stage
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # Previous default: patch aspect ratio matched the median patient shape.
        # We swapped to an isotropic-in-mm starting patch instead:
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)
        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()
        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)
        # clip it to the median shape of the dataset because larger patches make little sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(input_patch_size,
                                                                        self.unet_featuremap_min_edge_length,
                                                                        self.unet_max_numpool,
                                                                        current_spacing)

        ref = Generic_UNet.use_this_for_batch_size_computation_3D
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        # Shrink the patch until the estimated VRAM fits the reference budget.
        while here > ref:
            # here is the difference to ExperimentPlanner. In the old version we made the
            # aspect ratio match between patch and new_median_shape, regardless of spacing.
            # It could be better to enforce isotropy (in mm) instead: always reduce the
            # axis that is currently largest in mm.
            current_patch_in_mm = new_shp * current_spacing
            axis_to_be_reduced = np.argsort(current_patch_in_mm)[-1]

            # from here on it's the same as before: probe the shrink on a copy to learn
            # the new divisibility constraint, then apply it to new_shp
            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props_poolLateV2(tmp,
                                                   self.unet_featuremap_min_edge_length,
                                                   self.unet_max_numpool,
                                                   current_spacing)
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(new_shp,
                                                                            self.unet_featuremap_min_edge_length,
                                                                            self.unet_max_numpool,
                                                                            current_spacing)

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            print(new_shp)
        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        # scale the batch size up by any VRAM headroom left over
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))

        # cap the batch size so one batch covers at most a fixed share of the dataset
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        # dummy 2D augmentation kicks in when axis 0 is much thinner than the largest axis
        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
            0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/patch_size/experiment_planner_3DUNet_isotropic_in_voxels.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import numpy as np
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props_poolLateV2
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *
class ExperimentPlanner3D_IsoPatchesInVoxels(ExperimentPlanner):
    """
    patches that are isotropic in the number of voxels (not mm), such as 128x128x128 allow more voxels to be processed
    at once because we don't have to do annoying pooling stuff
    CAREFUL!
    this one does not support transpose_forward and transpose_backward
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner3D_IsoPatchesInVoxels, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.data_identifier = "nnFormerData_isoPatchesInVoxels"
        self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlans" + "fixedisoPatchesInVoxels_plans_3D.pkl")

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """Compute one stage's plan (patch size, batch size, pooling/conv kernels).

        Starts directly from the resampled median shape and shrinks the largest
        axis (in voxels) until the estimated Generic_UNet VRAM consumption fits
        the reference budget, driving the patch toward voxel-isotropy.

        :param current_spacing: target voxel spacing for this stage
        :param original_spacing: original voxel spacing (per axis); together with
            original_shape this yields `new_median_shape` — presumably dataset
            medians, confirm against the caller
        :param original_shape: original shape (per axis)
        :param num_cases: number of training cases
        :param num_modalities: number of input channels
        :param num_classes: number of segmentation classes
        :return: dict of plan properties for this stage
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # unlike the mm-isotropic planner, start from the median shape itself
        input_patch_size = new_median_shape

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(input_patch_size,
                                                                        self.unet_featuremap_min_edge_length,
                                                                        self.unet_max_numpool,
                                                                        current_spacing)

        ref = Generic_UNet.use_this_for_batch_size_computation_3D
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        # Shrink the patch until the estimated VRAM fits the reference budget.
        while here > ref:
            # find the largest axis. If patch is isotropic, pick the axis with the largest spacing
            if len(np.unique(new_shp)) == 1:
                axis_to_be_reduced = np.argsort(current_spacing)[-1]
            else:
                axis_to_be_reduced = np.argsort(new_shp)[-1]

            # probe the shrink on a copy to learn the new divisibility constraint
            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props_poolLateV2(tmp,
                                                   self.unet_featuremap_min_edge_length,
                                                   self.unet_max_numpool,
                                                   current_spacing)
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(new_shp,
                                                                            self.unet_featuremap_min_edge_length,
                                                                            self.unet_max_numpool,
                                                                            current_spacing)

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            print(new_shp)
        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        # scale the batch size up by any VRAM headroom left over
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))

        # cap the batch size so one batch covers at most a fixed share of the dataset
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        # dummy 2D augmentation kicks in when axis 0 is much thinner than the largest axis
        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
            0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/pooling_and_convs/experiment_planner_baseline_3DUNet_allConv3x3.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import numpy as np
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props_poolLateV2
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *
class ExperimentPlannerAllConv3x3(ExperimentPlanner):
    """Baseline 3D planner that forces every convolution kernel to 3x3x3.

    Identical to ExperimentPlanner except that after planning, all
    conv_kernel_sizes entries are overwritten with 3s, and preprocessing is a
    no-op (the baseline's preprocessed data is reused).
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlannerAllConv3x3, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        # NOTE(review): data_identifier is NOT overridden, consistent with
        # run_preprocessing being a no-op — the baseline's preprocessed data is shared.
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlans" + "allConv3x3_plans_3D.pkl")

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        Computation of input patch size starts out with the new median shape (in voxels) of a dataset. This is
        opposed to prior experiments where I based it on the median size in mm. The rationale behind this is that
        for some organ of interest the acquisition method will most likely be chosen such that the field of view and
        voxel resolution go hand in hand to show the doctor what they need to see. This assumption may be violated
        for some modalities with anisotropy (cine MRI) but we will have to live with that. In future experiments I
        will try to 1) base input patch size match aspect ratio of input size in mm (instead of voxels) and 2) to
        try to enforce that we see the same 'distance' in all directions (try to maintain equal size in mm of patch)

        The patches created here attempt to keep the aspect ratio of the new_median_shape.
        After planning, every conv kernel is forced to 3x3x3 (see end of method).

        :param current_spacing: target voxel spacing for this stage
        :param original_spacing: original voxel spacing (per axis)
        :param original_shape: original shape (per axis)
        :param num_cases: number of training cases
        :return: dict of plan properties for this stage
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # Previous default: patch aspect ratio matched the median patient shape.
        # We swapped to an isotropic-in-mm starting patch instead:
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)
        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()
        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)
        # clip it to the median shape of the dataset because larger patches make little sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(input_patch_size,
                                                                        self.unet_featuremap_min_edge_length,
                                                                        self.unet_max_numpool,
                                                                        current_spacing)

        ref = Generic_UNet.use_this_for_batch_size_computation_3D
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        # Shrink the patch until the estimated VRAM fits the reference budget.
        while here > ref:
            # reduce the axis that is currently largest relative to the median shape
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]

            # probe the shrink on a copy to learn the new divisibility constraint
            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props_poolLateV2(tmp,
                                                   self.unet_featuremap_min_edge_length,
                                                   self.unet_max_numpool,
                                                   current_spacing)
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(new_shp,
                                                                            self.unet_featuremap_min_edge_length,
                                                                            self.unet_max_numpool,
                                                                            current_spacing)

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            print(new_shp)
        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        # scale the batch size up by any VRAM headroom left over
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))

        # cap the batch size so one batch covers at most a fixed share of the dataset
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        # dummy 2D augmentation kicks in when axis 0 is much thinner than the largest axis
        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
            0]) > self.anisotropy_threshold

        # the defining tweak of this planner: force every conv kernel to 3 (i.e. 3x3x3)
        for s in range(len(conv_kernel_sizes)):
            conv_kernel_sizes[s] = [3 for _ in conv_kernel_sizes[s]]

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan

    def run_preprocessing(self, num_threads):
        # Intentionally a no-op: since data_identifier is not overridden, this
        # planner presumably reuses the baseline planner's preprocessed data.
        # TODO(review): confirm the baseline preprocessing has been run first.
        pass
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/pooling_and_convs/experiment_planner_baseline_3DUNet_poolBasedOnSpacing.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import numpy as np
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *
class ExperimentPlannerPoolBasedOnSpacing(ExperimentPlanner):
    """Baseline 3D planner whose pooling order is driven by voxel spacing.

    Uses get_pool_and_conv_props (spacing-aware) instead of the base planner's
    pool-late variant; see get_properties_for_stage for the rationale.
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlannerPoolBasedOnSpacing, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.data_identifier = "nnFormerData_poolBasedOnSpacing"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlans" + "poolBasedOnSpacing_plans_3D.pkl")

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        ExperimentPlanner configures pooling so that we pool late. Meaning that if the number of pooling per axis is
        (2, 3, 3), then the first pooling operation will always pool axes 1 and 2 and not 0, irrespective of spacing.
        This can cause a larger memory footprint, so it can be beneficial to revise this.

        Here we are pooling based on the spacing of the data.

        :param current_spacing: target voxel spacing for this stage
        :param original_spacing: original voxel spacing (per axis)
        :param original_shape: original shape (per axis)
        :param num_cases: number of training cases
        :param num_modalities: number of input channels
        :param num_classes: number of segmentation classes
        :return: dict of plan properties for this stage
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # Previous default: patch aspect ratio matched the median patient shape.
        # We swapped to an isotropic-in-mm starting patch instead:
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)
        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()
        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)
        # clip it to the median shape of the dataset because larger patches make little sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        # spacing-aware variant: pooling order is derived from current_spacing
        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
                                                             self.unet_featuremap_min_edge_length,
                                                             self.unet_max_numpool)

        ref = Generic_UNet.use_this_for_batch_size_computation_3D
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        # Shrink the patch until the estimated VRAM fits the reference budget.
        while here > ref:
            # reduce the axis that is currently largest relative to the median shape
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]

            # probe the shrink on a copy to learn the new divisibility constraint
            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props(current_spacing, tmp,
                                        self.unet_featuremap_min_edge_length,
                                        self.unet_max_numpool,
                                        )
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
                                                                 self.unet_featuremap_min_edge_length,
                                                                 self.unet_max_numpool,
                                                                 )

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            print(new_shp)
        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        # scale the batch size up by any VRAM headroom left over
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))

        # cap the batch size so one batch covers at most a fixed share of the dataset
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        # dummy 2D augmentation kicks in when axis 0 is much thinner than the largest axis
        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
            0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_targetSpacingForAnisoAxis.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.paths import *
class ExperimentPlannerTargetSpacingForAnisoAxis(ExperimentPlanner):
    """Planner that lowers the target spacing along a strongly anisotropic
    low-resolution axis instead of blindly using the median spacing."""

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super().__init__(folder_with_cropped_data, preprocessed_output_folder)
        # custom identifier/plans file so this planner's output can co-exist
        # with the default planner's preprocessed data
        self.data_identifier = "nnFormerData_targetSpacingForAnisoAxis"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlans" + "targetSpacingForAnisoAxis_plans_3D.pkl")

    def get_target_spacing(self):
        """
        By default the target spacing is the 50th percentile (= median) of all
        spacings. Higher spacing yields smaller data and thus faster, easier
        training; smaller spacing yields larger data and harder training.

        For very anisotropic datasets (e.g. ACDC with (10, 1.5, 1.5)) the
        median is a poor choice: such datasets still contain cases with a
        spacing of 5 or 6 mm along the low-resolution axis, and resampling
        those to the median causes severe interpolation artifacts (due to the
        low number of slices). For such datasets the 10th percentile of the
        affected axis is used instead.
        """
        stacked_spacings = np.vstack(self.dataset_properties['all_spacings'])
        stacked_sizes = np.vstack(self.dataset_properties['all_sizes'])
        target = np.percentile(stacked_spacings, self.target_spacing_percentile, 0)
        target_size = np.percentile(stacked_sizes, self.target_spacing_percentile, 0)

        # a dataset qualifies for the alternative target spacing when BOTH:
        #   - one axis has much lower resolution than the others
        #   - that axis also has far fewer voxels than the others
        lowres_axis = int(np.argmax(target))
        other_axes = [ax for ax in range(len(target)) if ax != lowres_axis]
        spacing_is_aniso = target[lowres_axis] > self.anisotropy_threshold * max(target[ax] for ax in other_axes)
        voxels_are_aniso = target_size[lowres_axis] * self.anisotropy_threshold < max(target_size[ax] for ax in other_axes)

        if spacing_is_aniso and voxels_are_aniso:
            # replace the worst axis' target with the 10th percentile of the
            # spacings observed along that axis
            target[lowres_axis] = np.percentile(stacked_spacings[:, lowres_axis], 10)
        return target
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_v21_customTargetSpacing_2x2x2.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import ExperimentPlanner3D_v21
from unetr_pp.paths import *
class ExperimentPlanner3D_v21_customTargetSpacing_2x2x2(ExperimentPlanner3D_v21):
    """ExperimentPlanner3D_v21 variant with a fixed 2x2x2 mm target spacing."""

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        # BUG FIX: the original called super(ExperimentPlanner3D_v21, self).__init__(...),
        # which skips ExperimentPlanner3D_v21.__init__ entirely and jumps straight to
        # its parent. The zero-argument form runs the full MRO as intended.
        super().__init__(folder_with_cropped_data, preprocessed_output_folder)
        # we change the data identifier and plans_fname. This will make this experiment planner save the preprocessed
        # data in a different folder so that they can co-exist with the default (ExperimentPlanner3D_v21). We also
        # create a custom plans file that will be linked to this data
        self.data_identifier = "nnFormerData_plans_v2.1_trgSp_2x2x2"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlansv2.1_trgSp_2x2x2_plans_3D.pkl")

    def get_target_spacing(self):
        # simply return the desired spacing as np.array
        return np.array([2., 2., 2.])  # make sure this is float!!!! Not int!
================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_v21_noResampling.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from unetr_pp.experiment_planning.alternative_experiment_planning.experiment_planner_baseline_3DUNet_v21_16GB import \
ExperimentPlanner3D_v21_16GB
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
ExperimentPlanner3D_v21
from unetr_pp.paths import *
class ExperimentPlanner3D_v21_noResampling(ExperimentPlanner3D_v21):
    """ExperimentPlanner3D_v21 variant that keeps images at their original
    spacing (resampling is skipped by PreprocessorFor3D_NoResampling) and
    plans only a single 3d_fullres stage."""

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner3D_v21_noResampling, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        # separate identifier/plans file so the unresampled data can co-exist
        # with the default preprocessed data
        self.data_identifier = "nnFormerData_noRes_plans_v2.1"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlansv2.1_noRes_plans_3D.pkl")
        # preprocessor that leaves the data at its original spacing
        self.preprocessor_name = "PreprocessorFor3D_NoResampling"

    def plan_experiment(self):
        """
        DIFFERENCE TO ExperimentPlanner3D_v21: no 3d lowres
        :return:
        """
        use_nonzero_mask_for_normalization = self.determine_whether_to_use_mask_for_norm()
        print("Are we using the nonzero mask for normalizaion?", use_nonzero_mask_for_normalization)
        spacings = self.dataset_properties['all_spacings']
        sizes = self.dataset_properties['all_sizes']
        all_classes = self.dataset_properties['all_classes']
        modalities = self.dataset_properties['modalities']
        num_modalities = len(list(modalities.keys()))
        target_spacing = self.get_target_spacing()
        new_shapes = [np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)]
        # transpose so that the axis with the largest spacing comes first
        max_spacing_axis = np.argmax(target_spacing)
        remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis]
        self.transpose_forward = [max_spacing_axis] + remaining_axes
        self.transpose_backward = [np.argwhere(np.array(self.transpose_forward) == i)[0][0] for i in range(3)]
        # we base our calculations on the median shape of the datasets
        median_shape = np.median(np.vstack(new_shapes), 0)
        print("the median shape of the dataset is ", median_shape)
        max_shape = np.max(np.vstack(new_shapes), 0)
        print("the max shape in the dataset is ", max_shape)
        min_shape = np.min(np.vstack(new_shapes), 0)
        print("the min shape in the dataset is ", min_shape)
        print("we don't want feature maps smaller than ", self.unet_featuremap_min_edge_length, " in the bottleneck")
        # only one stage (3d_fullres) is ever planned by this class
        self.plans_per_stage = list()
        target_spacing_transposed = np.array(target_spacing)[self.transpose_forward]
        median_shape_transposed = np.array(median_shape)[self.transpose_forward]
        print("the transposed median shape of the dataset is ", median_shape_transposed)
        print("generating configuration for 3d_fullres")
        self.plans_per_stage.append(self.get_properties_for_stage(target_spacing_transposed, target_spacing_transposed,
                                                                  median_shape_transposed,
                                                                  len(self.list_of_cropped_npz_files),
                                                                  num_modalities, len(all_classes) + 1))
        # DEAD CODE REMOVED: the original computed a 'more' flag from
        # median_patient_size_in_voxels / patch-size voxels and then executed
        # 'if more: pass' — a no-op. No lowres stage is ever added here.
        self.plans_per_stage = self.plans_per_stage[::-1]
        self.plans_per_stage = {i: self.plans_per_stage[i] for i in range(len(self.plans_per_stage))}  # convert to dict
        print(self.plans_per_stage)
        print("transpose forward", self.transpose_forward)
        print("transpose backward", self.transpose_backward)
        normalization_schemes = self.determine_normalization_scheme()
        only_keep_largest_connected_component, min_size_per_class, min_region_size_per_class = None, None, None
        # removed training data based postprocessing. This is deprecated
        # these are independent of the stage
        plans = {'num_stages': len(list(self.plans_per_stage.keys())), 'num_modalities': num_modalities,
                 'modalities': modalities, 'normalization_schemes': normalization_schemes,
                 'dataset_properties': self.dataset_properties, 'list_of_npz_files': self.list_of_cropped_npz_files,
                 'original_spacings': spacings, 'original_sizes': sizes,
                 'preprocessed_data_folder': self.preprocessed_output_folder, 'num_classes': len(all_classes),
                 'all_classes': all_classes, 'base_num_features': self.unet_base_num_features,
                 'use_mask_for_norm': use_nonzero_mask_for_normalization,
                 'keep_only_largest_region': only_keep_largest_connected_component,
                 'min_region_size_per_class': min_region_size_per_class, 'min_size_per_class': min_size_per_class,
                 'transpose_forward': self.transpose_forward, 'transpose_backward': self.transpose_backward,
                 'data_identifier': self.data_identifier, 'plans_per_stage': self.plans_per_stage,
                 'preprocessor_name': self.preprocessor_name,
                 'conv_per_stage': self.conv_per_stage,
                 }
        self.plans = plans
        self.save_my_plans()
class ExperimentPlanner3D_v21_noResampling_16GB(ExperimentPlanner3D_v21_16GB):
    """16GB-budget variant of the no-resampling planner: images keep their
    original spacing and only a single 3d_fullres stage is planned."""

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner3D_v21_noResampling_16GB, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        # separate identifier/plans file so the unresampled data can co-exist
        # with the default preprocessed data
        self.data_identifier = "nnFormerData_noRes_plans_16GB_v2.1"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlansv2.1_noRes_16GB_plans_3D.pkl")
        # preprocessor that leaves the data at its original spacing
        self.preprocessor_name = "PreprocessorFor3D_NoResampling"

    def plan_experiment(self):
        """
        DIFFERENCE TO ExperimentPlanner3D_v21: no 3d lowres
        :return:
        """
        use_nonzero_mask_for_normalization = self.determine_whether_to_use_mask_for_norm()
        print("Are we using the nonzero mask for normalizaion?", use_nonzero_mask_for_normalization)
        spacings = self.dataset_properties['all_spacings']
        sizes = self.dataset_properties['all_sizes']
        all_classes = self.dataset_properties['all_classes']
        modalities = self.dataset_properties['modalities']
        num_modalities = len(list(modalities.keys()))
        target_spacing = self.get_target_spacing()
        new_shapes = [np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)]
        # transpose so that the axis with the largest spacing comes first
        max_spacing_axis = np.argmax(target_spacing)
        remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis]
        self.transpose_forward = [max_spacing_axis] + remaining_axes
        self.transpose_backward = [np.argwhere(np.array(self.transpose_forward) == i)[0][0] for i in range(3)]
        # we base our calculations on the median shape of the datasets
        median_shape = np.median(np.vstack(new_shapes), 0)
        print("the median shape of the dataset is ", median_shape)
        max_shape = np.max(np.vstack(new_shapes), 0)
        print("the max shape in the dataset is ", max_shape)
        min_shape = np.min(np.vstack(new_shapes), 0)
        print("the min shape in the dataset is ", min_shape)
        print("we don't want feature maps smaller than ", self.unet_featuremap_min_edge_length, " in the bottleneck")
        # only one stage (3d_fullres) is ever planned by this class
        self.plans_per_stage = list()
        target_spacing_transposed = np.array(target_spacing)[self.transpose_forward]
        median_shape_transposed = np.array(median_shape)[self.transpose_forward]
        print("the transposed median shape of the dataset is ", median_shape_transposed)
        print("generating configuration for 3d_fullres")
        self.plans_per_stage.append(self.get_properties_for_stage(target_spacing_transposed, target_spacing_transposed,
                                                                  median_shape_transposed,
                                                                  len(self.list_of_cropped_npz_files),
                                                                  num_modalities, len(all_classes) + 1))
        # DEAD CODE REMOVED: the original computed a 'more' flag from
        # median_patient_size_in_voxels / patch-size voxels and then executed
        # 'if more: pass' — a no-op. No lowres stage is ever added here.
        self.plans_per_stage = self.plans_per_stage[::-1]
        self.plans_per_stage = {i: self.plans_per_stage[i] for i in range(len(self.plans_per_stage))}  # convert to dict
        print(self.plans_per_stage)
        print("transpose forward", self.transpose_forward)
        print("transpose backward", self.transpose_backward)
        normalization_schemes = self.determine_normalization_scheme()
        only_keep_largest_connected_component, min_size_per_class, min_region_size_per_class = None, None, None
        # removed training data based postprocessing. This is deprecated
        # these are independent of the stage
        plans = {'num_stages': len(list(self.plans_per_stage.keys())), 'num_modalities': num_modalities,
                 'modalities': modalities, 'normalization_schemes': normalization_schemes,
                 'dataset_properties': self.dataset_properties, 'list_of_npz_files': self.list_of_cropped_npz_files,
                 'original_spacings': spacings, 'original_sizes': sizes,
                 'preprocessed_data_folder': self.preprocessed_output_folder, 'num_classes': len(all_classes),
                 'all_classes': all_classes, 'base_num_features': self.unet_base_num_features,
                 'use_mask_for_norm': use_nonzero_mask_for_normalization,
                 'keep_only_largest_region': only_keep_largest_connected_component,
                 'min_region_size_per_class': min_region_size_per_class, 'min_size_per_class': min_size_per_class,
                 'transpose_forward': self.transpose_forward, 'transpose_backward': self.transpose_backward,
                 'data_identifier': self.data_identifier, 'plans_per_stage': self.plans_per_stage,
                 'preprocessor_name': self.preprocessor_name,
                 'conv_per_stage': self.conv_per_stage,
                 }
        self.plans = plans
        self.save_my_plans()
================================================
FILE: unetr_pp/experiment_planning/change_batch_size.py
================================================
from batchgenerators.utilities.file_and_folder_operations import *
import numpy as np
if __name__ == '__main__':
    # One-off utility: overwrite the stage-0 batch size stored in an existing
    # nnFormer plans pickle and write the file back in place.
    plans_path = '/home/xychen/new_transformer/nnFormerFrame/DATASET/nnFormer_preprocessed/Task008_Verse1/nnFormerPlansv2.1_plans_3D.pkl'
    plans = load_pickle(plans_path)
    plans['plans_per_stage'][0]['batch_size'] = 4
    save_pickle(plans, plans_path)
================================================
FILE: unetr_pp/experiment_planning/common_utils.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from copy import deepcopy
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
import SimpleITK as sitk
import shutil
from batchgenerators.utilities.file_and_folder_operations import join
def split_4d_nifti(filename, output_folder):
    """Split a 4D NIfTI file into one 3D NIfTI per volume along the 4th axis.

    A 3D input is copied unchanged (with a "_0000" modality suffix appended);
    a 4D input is split along the first array axis, writing
    "<case>_0000.nii.gz", "<case>_0001.nii.gz", ... into output_folder.
    The filename is assumed to end in ".nii.gz" (the [:-7] slice strips it).

    :param filename: path to the input .nii.gz file
    :param output_folder: directory the output file(s) are written to
    :raises RuntimeError: if the image is neither 3D nor 4D
    """
    img_itk = sitk.ReadImage(filename)
    dim = img_itk.GetDimension()
    # NOTE(review): assumes POSIX "/" path separators — confirm callers never
    # pass Windows-style paths
    file_base = filename.split("/")[-1]
    if dim == 3:
        shutil.copy(filename, join(output_folder, file_base[:-7] + "_0000.nii.gz"))
        return
    elif dim != 4:
        raise RuntimeError("Unexpected dimensionality: %d of file %s, cannot split" % (dim, filename))
    else:
        img_npy = sitk.GetArrayFromImage(img_itk)
        spacing = img_itk.GetSpacing()
        origin = img_itk.GetOrigin()
        direction = np.array(img_itk.GetDirection()).reshape(4, 4)
        # drop the fourth dimension from the geometry metadata
        spacing = tuple(list(spacing[:-1]))
        origin = tuple(list(origin[:-1]))
        direction = tuple(direction[:-1, :-1].reshape(-1))
        # fixed: the original used 'for i, t in enumerate(range(...))' where
        # i and t were always identical; one loop variable suffices
        for t in range(img_npy.shape[0]):
            img_itk_new = sitk.GetImageFromArray(img_npy[t])
            img_itk_new.SetSpacing(spacing)
            img_itk_new.SetOrigin(origin)
            img_itk_new.SetDirection(direction)
            sitk.WriteImage(img_itk_new, join(output_folder, file_base[:-7] + "_%04.0d.nii.gz" % t))
def get_pool_and_conv_props_poolLateV2(patch_size, min_feature_map_size, max_numpool, spacing):
    """Derive pooling/conv kernel configurations, pooling axes late.

    :param spacing: voxel spacing per axis
    :param patch_size: patch size per axis
    :param min_feature_map_size: min edge length of feature maps in bottleneck
    :param max_numpool: maximum number of pooling operations per axis
    :return: (num_pool_per_axis, pool kernel sizes per step, conv kernel sizes
              per step incl. bottleneck, padded patch size, must_be_divisible_by)
    """
    dim = len(patch_size)
    coarsest_spacing = max(spacing)
    num_pool_per_axis = get_network_numpool(patch_size, max_numpool, min_feature_map_size)
    total_pool_steps = max(num_pool_per_axis)

    pool_kernels = []
    conv_kernels = []
    cur_spacing = spacing
    for step in range(total_pool_steps):
        # an axis counts as "reached" once its spacing is within a factor of 2
        # of the coarsest original spacing
        reached = [cur_spacing[ax] / coarsest_spacing > 0.5 for ax in range(dim)]
        # axes with fewer total pools start pooling later (hence "poolLate")
        pool = [2 if num_pool_per_axis[ax] + step >= total_pool_steps else 1 for ax in range(dim)]
        if all(reached):
            conv = [3] * dim
        else:
            conv = [1 if reached[ax] else 3 for ax in range(dim)]
        pool_kernels.append(pool)
        conv_kernels.append(conv)
        cur_spacing = [s * p for s, p in zip(cur_spacing, pool)]

    # one extra conv kernel for the bottleneck; always 3x3(x3)
    conv_kernels.append([3] * dim)
    must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
    padded_patch_size = pad_shape(patch_size, must_be_divisible_by)
    return num_pool_per_axis, pool_kernels, conv_kernels, padded_patch_size, must_be_divisible_by
def get_pool_and_conv_props(spacing, patch_size, min_feature_map_size, max_numpool):
    """Derive pooling/conv kernel configurations, pooling greedily while valid.

    :param spacing: voxel spacing per axis
    :param patch_size: patch size per axis
    :param min_feature_map_size: min edge length of feature maps in bottleneck
    :param max_numpool: maximum number of pooling operations per axis
    :return: (num_pool_per_axis, pool kernel sizes per step, conv kernel sizes
              per step incl. bottleneck, padded patch size, must_be_divisible_by)
    """
    dim = len(spacing)
    cur_spacing = [s for s in spacing]
    cur_size = [s for s in patch_size]
    pool_kernels = []
    conv_kernels = []
    num_pool_per_axis = [0] * dim

    while True:
        # This is a problem because sometimes we have spacing 20, 50, 50 and we
        # still want to keep pooling; here we would stop. Fixed in
        # get_pool_and_conv_props_v2.
        smallest = min(cur_spacing)
        poolable = [ax for ax in range(dim) if cur_spacing[ax] / smallest < 2]

        # conv kernels are 3 on the largest group of mutually near-isotropic
        # axes (spacing within a factor of 2 of each other), 1 elsewhere
        best_group = []
        for ax in range(dim):
            ref = cur_spacing[ax]
            group = [other for other in range(dim)
                     if cur_spacing[other] / ref < 2 and ref / cur_spacing[other] < 2]
            if len(group) > len(best_group):
                best_group = group
        conv_kernel = [3 if ax in best_group else 1 for ax in range(dim)]

        # drop axes that hit the bottleneck-size or max-numpool constraints
        poolable = [ax for ax in poolable
                    if cur_size[ax] >= 2 * min_feature_map_size and num_pool_per_axis[ax] < max_numpool]
        if not poolable:
            break

        pool_kernel = [1] * dim
        for ax in poolable:
            pool_kernel[ax] = 2
            num_pool_per_axis[ax] += 1
            cur_spacing[ax] *= 2
            cur_size[ax] = np.ceil(cur_size[ax] / 2)
        pool_kernels.append(pool_kernel)
        conv_kernels.append(conv_kernel)

    must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
    padded_patch_size = pad_shape(patch_size, must_be_divisible_by)
    # one extra conv kernel for the bottleneck; always 3x3(x3)
    conv_kernels.append([3] * dim)
    return num_pool_per_axis, pool_kernels, conv_kernels, padded_patch_size, must_be_divisible_by
def get_pool_and_conv_props_v2(spacing, patch_size, min_feature_map_size, max_numpool):
    """
    Determine per-axis pooling counts, pool/conv kernel sizes and the padded
    patch size for the given spacing and patch size.

    :param spacing: voxel spacing per axis
    :param patch_size: patch size per axis
    :param min_feature_map_size: min edge length of feature maps in bottleneck
    :param max_numpool: maximum number of pooling operations per axis
    :return: (num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes incl.
              bottleneck, padded patch_size, must_be_divisible_by)
    """
    dim = len(spacing)
    current_spacing = deepcopy(list(spacing))
    current_size = deepcopy(list(patch_size))
    pool_op_kernel_sizes = []
    conv_kernel_sizes = []
    num_pool_per_axis = [0] * dim
    # per-axis conv kernel edge; starts at 1 and is promoted to 3 permanently
    kernel_size = [1] * dim
    while True:
        # exclude axes that we cannot pool further because of min_feature_map_size constraint
        valid_axes_for_pool = [i for i in range(dim) if current_size[i] >= 2*min_feature_map_size]
        if len(valid_axes_for_pool) < 1:
            break
        spacings_of_axes = [current_spacing[i] for i in valid_axes_for_pool]
        # find axis that are within factor of 2 within smallest spacing
        min_spacing_of_valid = min(spacings_of_axes)
        valid_axes_for_pool = [i for i in valid_axes_for_pool if current_spacing[i] / min_spacing_of_valid < 2]
        # max_numpool constraint
        valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool]
        if len(valid_axes_for_pool) == 1:
            # a single remaining axis only keeps pooling while it stays well
            # above the bottleneck size limit
            if current_size[valid_axes_for_pool[0]] >= 3 * min_feature_map_size:
                pass
            else:
                break
        if len(valid_axes_for_pool) < 1:
            break
        # now we need to find kernel sizes
        # kernel sizes are initialized to 1. They are successively set to 3 when their associated axis becomes within
        # factor 2 of min_spacing. Once they are 3 they remain 3
        # NOTE(review): spacings_of_axes is indexed with d in range(dim) below,
        # but it holds one entry per axis of the (pre-filtering)
        # valid_axes_for_pool list, not per dimension. When fewer axes are
        # valid than dim this reads the wrong axis or raises IndexError —
        # verify against upstream nnU-Net before changing.
        for d in range(dim):
            if kernel_size[d] == 3:
                continue
            else:
                if spacings_of_axes[d] / min(current_spacing) < 2:
                    kernel_size[d] = 3
        other_axes = [i for i in range(dim) if i not in valid_axes_for_pool]
        pool_kernel_sizes = [0] * dim
        for v in valid_axes_for_pool:
            pool_kernel_sizes[v] = 2
            num_pool_per_axis[v] += 1
            current_spacing[v] *= 2
            current_size[v] = np.ceil(current_size[v] / 2)
        for nv in other_axes:
            pool_kernel_sizes[nv] = 1
        pool_op_kernel_sizes.append(pool_kernel_sizes)
        # deepcopy because kernel_size keeps mutating across iterations
        conv_kernel_sizes.append(deepcopy(kernel_size))
        #print(conv_kernel_sizes)
    must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
    patch_size = pad_shape(patch_size, must_be_divisible_by)
    # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here
    conv_kernel_sizes.append([3]*dim)
    return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by
def get_shape_must_be_divisible_by(net_numpool_per_axis):
    """Per-axis divisibility constraint: 2 ** (number of poolings on that axis)."""
    return np.power(2, np.asarray(net_numpool_per_axis))
def pad_shape(shape, must_be_divisible_by):
    """
    Round every entry of shape UP to the next multiple of the corresponding
    divisor; entries that are already exact multiples are left unchanged.

    :param shape: sequence of edge lengths
    :param must_be_divisible_by: scalar or per-axis sequence of divisors
    :return: np.ndarray of ints, padded shape
    """
    # a scalar divisor applies to every axis
    if not isinstance(must_be_divisible_by, (tuple, list, np.ndarray)):
        must_be_divisible_by = [must_be_divisible_by] * len(shape)
    else:
        assert len(must_be_divisible_by) == len(shape)

    padded = []
    for edge, divisor in zip(shape, must_be_divisible_by):
        remainder = edge % divisor
        if remainder == 0:
            padded.append(edge)
        else:
            padded.append(edge + divisor - remainder)
    return np.array(padded).astype(int)
def get_network_numpool(patch_size, maxpool_cap=999, min_feature_map_size=4):
    """
    Number of 2x poolings per axis so the bottleneck edge stays at or above
    min_feature_map_size, capped at maxpool_cap.

    :param patch_size: patch edge lengths per axis
    :param maxpool_cap: upper bound on poolings per axis
    :param min_feature_map_size: min edge length of bottleneck feature maps
    :return: list of per-axis pooling counts
    """
    # floor(log2(edge / min_feature_map_size)); kept as log/log(2) so the
    # floating point result matches the original formulation exactly
    return [min(int(np.floor(np.log(edge / min_feature_map_size) / np.log(2))), maxpool_cap)
            for edge in patch_size]
if __name__ == '__main__':
    # trying to fix https://github.com/MIC-DKFZ/nnFormer/issues/261
    # Manual smoke test with a strongly anisotropic case (thick slices, high
    # in-plane resolution). Results are computed but intentionally not printed.
    median_shape = [24, 504, 512]
    spacing = [5.9999094, 0.50781202, 0.50781202]
    num_pool_per_axis, net_num_pool_op_kernel_sizes, net_conv_kernel_sizes, patch_size, must_be_divisible_by = get_pool_and_conv_props_poolLateV2(median_shape, min_feature_map_size=4, max_numpool=999, spacing=spacing)
================================================
FILE: unetr_pp/experiment_planning/experiment_planner_baseline_2DUNet.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import unetr_pp
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import load_pickle, subfiles
from multiprocessing.pool import Pool
from unetr_pp.configuration import default_num_threads
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.experiment_planning.utils import add_classes_in_slice_info
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *
from unetr_pp.preprocessing.preprocessing import PreprocessorFor2D
from unetr_pp.training.model_restore import recursive_find_python_class
class ExperimentPlanner2D(ExperimentPlanner):
    """Experiment planner for the 2D U-Net configuration.

    Reuses the 3D ExperimentPlanner machinery but plans patch and batch sizes
    over the two in-plane axes only (the first, out-of-plane axis is dropped
    from the patch size).
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner2D, self).__init__(folder_with_cropped_data,
                                                  preprocessed_output_folder)
        # separate identifier/plans file so 2D preprocessed data can co-exist
        # with the 3D variants
        self.data_identifier = default_data_identifier + "_2D"
        self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlans" + "_plans_2D.pkl")
        self.unet_base_num_features = 30
        self.unet_max_num_filters = 512
        self.unet_max_numpool = 999
        self.preprocessor_name = "PreprocessorFor2D"

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """Compute patch size, batch size and network topology for one stage.

        :param current_spacing: target spacing for this stage (transposed order)
        :param original_spacing: original (median) spacing (transposed order)
        :param original_shape: median shape at original spacing (transposed order)
        :param num_cases: number of training cases
        :param num_modalities: number of input channels
        :param num_classes: number of segmentation classes (incl. background)
        :return: plan dict for this stage
        :raises RuntimeError: if even the minimum batch size does not fit
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape, dtype=np.int64) * num_cases
        # 2D: drop the first (out-of-plane) axis; patches span whole slices
        input_patch_size = new_median_shape[1:]
        network_numpool, net_pool_kernel_sizes, net_conv_kernel_sizes, input_patch_size, \
        shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing[1:], input_patch_size,
                                                             self.unet_featuremap_min_edge_length,
                                                             self.unet_max_numpool)
        # estimate VRAM for this topology and scale the reference batch size by
        # the ratio of reference consumption to estimated consumption
        estimated_gpu_ram_consumption = Generic_UNet.compute_approx_vram_consumption(input_patch_size,
                                                                                     network_numpool,
                                                                                     self.unet_base_num_features,
                                                                                     self.unet_max_num_filters,
                                                                                     num_modalities, num_classes,
                                                                                     net_pool_kernel_sizes,
                                                                                     conv_per_stage=self.conv_per_stage)
        batch_size = int(np.floor(Generic_UNet.use_this_for_batch_size_computation_2D /
                                  estimated_gpu_ram_consumption * Generic_UNet.DEFAULT_BATCH_SIZE_2D))
        if batch_size < self.unet_min_batch_size:
            raise RuntimeError("This framework is not made to process patches this large. We will add patch-based "
                               "2D networks later. Sorry for the inconvenience")
        # check if batch size is too large (more than 5 % of dataset)
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        batch_size = max(1, min(batch_size, max_batch_size))
        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_numpool,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'pool_op_kernel_sizes': net_pool_kernel_sizes,
            'conv_kernel_sizes': net_conv_kernel_sizes,
            'do_dummy_2D_data_aug': False
        }
        return plan

    def plan_experiment(self):
        """Generate the (single-stage) 2D plans dict and save it to disk."""
        use_nonzero_mask_for_normalization = self.determine_whether_to_use_mask_for_norm()
        print("Are we using the nonzero maks for normalizaion?", use_nonzero_mask_for_normalization)
        spacings = self.dataset_properties['all_spacings']
        sizes = self.dataset_properties['all_sizes']
        all_classes = self.dataset_properties['all_classes']
        modalities = self.dataset_properties['modalities']
        num_modalities = len(list(modalities.keys()))
        target_spacing = self.get_target_spacing()
        new_shapes = np.array([np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)])
        # transpose so that the axis with the largest spacing comes first
        max_spacing_axis = np.argmax(target_spacing)
        remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis]
        self.transpose_forward = [max_spacing_axis] + remaining_axes
        self.transpose_backward = [np.argwhere(np.array(self.transpose_forward) == i)[0][0] for i in range(3)]
        # we base our calculations on the median shape of the datasets
        median_shape = np.median(np.vstack(new_shapes), 0)
        print("the median shape of the dataset is ", median_shape)
        max_shape = np.max(np.vstack(new_shapes), 0)
        print("the max shape in the dataset is ", max_shape)
        min_shape = np.min(np.vstack(new_shapes), 0)
        print("the min shape in the dataset is ", min_shape)
        print("we don't want feature maps smaller than ", self.unet_featuremap_min_edge_length, " in the bottleneck")
        # how many stages will the image pyramid have? (2D: always exactly one)
        self.plans_per_stage = []
        target_spacing_transposed = np.array(target_spacing)[self.transpose_forward]
        median_shape_transposed = np.array(median_shape)[self.transpose_forward]
        print("the transposed median shape of the dataset is ", median_shape_transposed)
        self.plans_per_stage.append(
            self.get_properties_for_stage(target_spacing_transposed, target_spacing_transposed, median_shape_transposed,
                                          num_cases=len(self.list_of_cropped_npz_files),
                                          num_modalities=num_modalities,
                                          num_classes=len(all_classes) + 1),
        )
        print(self.plans_per_stage)
        self.plans_per_stage = self.plans_per_stage[::-1]
        self.plans_per_stage = {i: self.plans_per_stage[i] for i in range(len(self.plans_per_stage))}  # convert to dict
        normalization_schemes = self.determine_normalization_scheme()
        # deprecated
        only_keep_largest_connected_component, min_size_per_class, min_region_size_per_class = None, None, None
        # these are independent of the stage
        plans = {'num_stages': len(list(self.plans_per_stage.keys())), 'num_modalities': num_modalities,
                 'modalities': modalities, 'normalization_schemes': normalization_schemes,
                 'dataset_properties': self.dataset_properties, 'list_of_npz_files': self.list_of_cropped_npz_files,
                 'original_spacings': spacings, 'original_sizes': sizes,
                 'preprocessed_data_folder': self.preprocessed_output_folder, 'num_classes': len(all_classes),
                 'all_classes': all_classes, 'base_num_features': self.unet_base_num_features,
                 'use_mask_for_norm': use_nonzero_mask_for_normalization,
                 'keep_only_largest_region': only_keep_largest_connected_component,
                 'min_region_size_per_class': min_region_size_per_class, 'min_size_per_class': min_size_per_class,
                 'transpose_forward': self.transpose_forward, 'transpose_backward': self.transpose_backward,
                 'data_identifier': self.data_identifier, 'plans_per_stage': self.plans_per_stage,
                 'preprocessor_name': self.preprocessor_name,
                 }
        self.plans = plans
        self.save_my_plans()
================================================
FILE: unetr_pp/experiment_planning/experiment_planner_baseline_2DUNet_v21.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
from unetr_pp.experiment_planning.experiment_planner_baseline_2DUNet import ExperimentPlanner2D
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *
import numpy as np
class ExperimentPlanner2D_v21(ExperimentPlanner2D):
    """
    2D U-Net experiment planner, version 2.1.

    Differs from the base ``ExperimentPlanner2D`` only in the data identifier /
    plans file name and in using 32 base feature maps (multiples of 8 are
    faster with fp16/amp training).
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner2D_v21, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        # identifier under which the preprocessed data of this planner is stored
        self.data_identifier = "nnFormerData_plans_v2.1_2D"
        # destination of the pickled plans dict
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlansv2.1_plans_2D.pkl")
        self.unet_base_num_features = 32

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        Derive the 2D network/training configuration (patch size, batch size,
        pooling and conv kernel setup) for one resolution stage.

        :param current_spacing: target voxel spacing of this stage (3 values; axis 0 is dropped for 2D)
        :param original_spacing: median voxel spacing of the raw data
        :param original_shape: median shape of the raw data (in voxels)
        :param num_cases: number of training cases in the dataset
        :param num_modalities: number of input channels
        :param num_classes: number of segmentation outputs
        :return: dict describing the plan for this stage
        """
        # median dataset shape after resampling to current_spacing
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape, dtype=np.int64) * num_cases

        # 2D planning: drop the leading (out-of-plane) axis
        input_patch_size = new_median_shape[1:]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing[1:], input_patch_size,
                                                             self.unet_featuremap_min_edge_length,
                                                             self.unet_max_numpool)

        # we pretend to use 30 feature maps. This will yield the same configuration as in V1. The larger memory
        # footprint of 32 vs 30 is more than offset by the fp16 training. We make fp16 training default
        # Reason for 32 vs 30 feature maps is that 32 is faster in fp16 training (because multiple of 8)
        ref = Generic_UNet.use_this_for_batch_size_computation_2D * Generic_UNet.DEFAULT_BATCH_SIZE_2D / 2  # for batch size 2
        here = Generic_UNet.compute_approx_vram_consumption(new_shp,
                                                            network_num_pool_per_axis,
                                                            30,
                                                            self.unet_max_num_filters,
                                                            num_modalities, num_classes,
                                                            pool_op_kernel_sizes,
                                                            conv_per_stage=self.conv_per_stage)
        # shrink the patch along the axis that is largest relative to the median
        # shape until the estimated VRAM consumption fits the reference budget
        while here > ref:
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape[1:])[-1]

            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props(current_spacing[1:], tmp, self.unet_featuremap_min_edge_length,
                                        self.unet_max_numpool)
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing[1:], new_shp,
                                                                 self.unet_featuremap_min_edge_length,
                                                                 self.unet_max_numpool)

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            # print(new_shp)

        # ref/here was computed relative to batch size 2, hence the factor 2
        batch_size = int(np.floor(ref / here) * 2)
        input_patch_size = new_shp

        if batch_size < self.unet_min_batch_size:
            raise RuntimeError("This should not happen")

        # check if batch size is too large (more than 5 % of dataset)
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        batch_size = max(1, min(batch_size, max_batch_size))

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
            'do_dummy_2D_data_aug': False
        }
        return plan
================================================
FILE: unetr_pp/experiment_planning/experiment_planner_baseline_3DUNet.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from collections import OrderedDict
from copy import deepcopy
import unetr_pp
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.configuration import default_num_threads
from unetr_pp.experiment_planning.DatasetAnalyzer import DatasetAnalyzer
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props_poolLateV2
from unetr_pp.experiment_planning.utils import create_lists_from_splitted_dataset
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *
from unetr_pp.preprocessing.cropping import get_case_identifier_from_npz
from unetr_pp.training.model_restore import recursive_find_python_class
class ExperimentPlanner(object):
    """
    Base 3D experiment planner.

    Reads the ``dataset_properties.pkl`` produced by the DatasetAnalyzer,
    derives a target spacing, per-stage patch/batch/pooling configuration and
    normalization settings, and stores everything as a 'plans' dict that is
    pickled to ``self.plans_fname``. ``run_preprocessing`` then applies the
    plan to the cropped data.
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        self.folder_with_cropped_data = folder_with_cropped_data
        self.preprocessed_output_folder = preprocessed_output_folder
        # all cropped .npz cases found in the input folder
        self.list_of_cropped_npz_files = subfiles(self.folder_with_cropped_data, True, None, ".npz", True)

        # name of the preprocessor class; resolved dynamically in run_preprocessing
        self.preprocessor_name = "GenericPreprocessor"

        assert isfile(join(self.folder_with_cropped_data, "dataset_properties.pkl")), \
            "folder_with_cropped_data must contain dataset_properties.pkl"
        self.dataset_properties = load_pickle(join(self.folder_with_cropped_data, "dataset_properties.pkl"))

        self.plans_per_stage = OrderedDict()
        self.plans = OrderedDict()
        # destination of the pickled plans dict
        self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlans" + "fixed_plans_3D.pkl")
        self.data_identifier = default_data_identifier

        # axis permutation applied before/after processing; may be overwritten
        # in plan_experiment (lowest-resolution axis is moved to the front)
        self.transpose_forward = [0, 1, 2]
        self.transpose_backward = [0, 1, 2]

        # network topology / planning constraints
        self.unet_base_num_features = Generic_UNet.BASE_NUM_FEATURES_3D
        self.unet_max_num_filters = 320
        self.unet_max_numpool = 999
        self.unet_min_batch_size = 2
        self.unet_featuremap_min_edge_length = 4

        self.target_spacing_percentile = 50
        self.anisotropy_threshold = 3
        self.how_much_of_a_patient_must_the_network_see_at_stage0 = 4  # 1/4 of a patient
        self.batch_size_covers_max_percent_of_dataset = 0.05  # all samples in the batch together cannot cover more
        # than 5% of the entire dataset
        self.conv_per_stage = 2

    def get_target_spacing(self):
        """Return the per-axis target spacing: the ``target_spacing_percentile``
        (default: median) over the spacings of all training cases."""
        spacings = self.dataset_properties['all_spacings']

        # target = np.median(np.vstack(spacings), 0)
        # if target spacing is very anisotropic we may want to not downsample the axis with the worst spacing
        # uncomment after mystery task submission
        """worst_spacing_axis = np.argmax(target)
        if max(target) > (2.5 * min(target)):
            spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis]
            target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 5)
            target[worst_spacing_axis] = target_spacing_of_that_axis"""

        target = np.percentile(np.vstack(spacings), self.target_spacing_percentile, 0)
        return target

    def save_my_plans(self):
        # persist the plans dict so preprocessing/training can load it later
        with open(self.plans_fname, 'wb') as f:
            pickle.dump(self.plans, f)

    def load_my_plans(self):
        # restore a previously saved plans file and the fields derived from it
        self.plans = load_pickle(self.plans_fname)

        self.plans_per_stage = self.plans['plans_per_stage']
        self.dataset_properties = self.plans['dataset_properties']

        self.transpose_forward = self.plans['transpose_forward']
        self.transpose_backward = self.plans['transpose_backward']

    def determine_postprocessing(self):
        # intentionally a no-op; see the dead string below for the removed logic
        pass
        """
        Spoiler: This is unused, postprocessing was removed. Ignore it.
        :return:
        print("determining postprocessing...")

        props_per_patient = self.dataset_properties['segmentation_props_per_patient']

        all_region_keys = [i for k in props_per_patient.keys() for i in props_per_patient[k]['only_one_region'].keys()]
        all_region_keys = list(set(all_region_keys))

        only_keep_largest_connected_component = OrderedDict()

        for r in all_region_keys:
            all_results = [props_per_patient[k]['only_one_region'][r] for k in props_per_patient.keys()]
            only_keep_largest_connected_component[tuple(r)] = all(all_results)

        print("Postprocessing: only_keep_largest_connected_component", only_keep_largest_connected_component)

        all_classes = self.dataset_properties['all_classes']
        classes = [i for i in all_classes if i > 0]

        props_per_patient = self.dataset_properties['segmentation_props_per_patient']

        min_size_per_class = OrderedDict()
        for c in classes:
            all_num_voxels = []
            for k in props_per_patient.keys():
                all_num_voxels.append(props_per_patient[k]['volume_per_class'][c])
            if len(all_num_voxels) > 0:
                min_size_per_class[c] = np.percentile(all_num_voxels, 1) * MIN_SIZE_PER_CLASS_FACTOR
            else:
                min_size_per_class[c] = np.inf

        min_region_size_per_class = OrderedDict()
        for c in classes:
            region_sizes = [l for k in props_per_patient for l in props_per_patient[k]['region_volume_per_class'][c]]
            if len(region_sizes) > 0:
                min_region_size_per_class[c] = min(region_sizes)
                # we don't need that line but better safe than sorry, right?
                min_region_size_per_class[c] = min(min_region_size_per_class[c], min_size_per_class[c])
            else:
                min_region_size_per_class[c] = 0

        print("Postprocessing: min_size_per_class", min_size_per_class)
        print("Postprocessing: min_region_size_per_class", min_region_size_per_class)
        return only_keep_largest_connected_component, min_size_per_class, min_region_size_per_class
        """

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        Computation of input patch size starts out with the new median shape (in voxels) of a dataset. This is
        opposed to prior experiments where I based it on the median size in mm. The rationale behind this is that
        for some organ of interest the acquisition method will most likely be chosen such that the field of view and
        voxel resolution go hand in hand to show the doctor what they need to see. This assumption may be violated
        for some modalities with anisotropy (cine MRI) but we will have to live with that. In future experiments I
        will try to 1) base input patch size match aspect ratio of input size in mm (instead of voxels) and 2) to
        try to enforce that we see the same 'distance' in all directions (try to maintain equal size in mm of patch)

        The patches created here attempt keep the aspect ratio of the new_median_shape

        :param current_spacing: target voxel spacing of this stage
        :param original_spacing: median voxel spacing of the raw data
        :param original_shape: median shape of the raw data (in voxels)
        :param num_cases: number of training cases in the dataset
        :return: dict describing the plan for this stage
        """
        # median dataset shape after resampling to current_spacing
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # the next line is what we had before as a default. The patch size had the same aspect ratio as the median shape of a patient. We swapped t
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)

        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()

        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)

        # clip it to the median shape of the dataset because patches larger then that make not much sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(input_patch_size,
                                                                        self.unet_featuremap_min_edge_length,
                                                                        self.unet_max_numpool,
                                                                        current_spacing)

        ref = Generic_UNet.use_this_for_batch_size_computation_3D
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        # shrink the patch along the axis that is largest relative to the median
        # shape until the estimated VRAM consumption fits the reference budget
        while here > ref:
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]

            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props_poolLateV2(tmp,
                                                   self.unet_featuremap_min_edge_length,
                                                   self.unet_max_numpool,
                                                   current_spacing)
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(new_shp,
                                                                            self.unet_featuremap_min_edge_length,
                                                                            self.unet_max_numpool,
                                                                            current_spacing)

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            # print(new_shp)

        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))

        # check if batch size is too large
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        # use dummy 2D augmentation when the patch is strongly anisotropic
        # (first axis is much thinner than the largest axis)
        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
            0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan

    def plan_experiment(self):
        """Derive the full plan (all stages) and pickle it via ``save_my_plans``."""
        use_nonzero_mask_for_normalization = self.determine_whether_to_use_mask_for_norm()
        print("Are we using the nonzero mask for normalizaion?", use_nonzero_mask_for_normalization)
        spacings = self.dataset_properties['all_spacings']
        sizes = self.dataset_properties['all_sizes']

        all_classes = self.dataset_properties['all_classes']
        modalities = self.dataset_properties['modalities']
        num_modalities = len(list(modalities.keys()))

        target_spacing = self.get_target_spacing()
        new_shapes = [np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)]

        # move the axis with the largest spacing to the front
        max_spacing_axis = np.argmax(target_spacing)
        remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis]
        self.transpose_forward = [max_spacing_axis] + remaining_axes
        self.transpose_backward = [np.argwhere(np.array(self.transpose_forward) == i)[0][0] for i in range(3)]

        # we base our calculations on the median shape of the datasets
        median_shape = np.median(np.vstack(new_shapes), 0)
        print("the median shape of the dataset is ", median_shape)

        max_shape = np.max(np.vstack(new_shapes), 0)
        print("the max shape in the dataset is ", max_shape)
        min_shape = np.min(np.vstack(new_shapes), 0)
        print("the min shape in the dataset is ", min_shape)

        print("we don't want feature maps smaller than ", self.unet_featuremap_min_edge_length, " in the bottleneck")

        # how many stages will the image pyramid have?
        self.plans_per_stage = list()

        target_spacing_transposed = np.array(target_spacing)[self.transpose_forward]
        median_shape_transposed = np.array(median_shape)[self.transpose_forward]
        print("the transposed median shape of the dataset is ", median_shape_transposed)

        print("generating configuration for 3d_fullres")
        # fullres stage: current spacing == original (target) spacing
        self.plans_per_stage.append(self.get_properties_for_stage(target_spacing_transposed, target_spacing_transposed,
                                                                  median_shape_transposed,
                                                                  len(self.list_of_cropped_npz_files),
                                                                  num_modalities, len(all_classes) + 1))

        # thanks Zakiyi (https://github.com/MIC-DKFZ/nnFormer/issues/61) for spotting this bug :-)
        # if np.prod(self.plans_per_stage[-1]['median_patient_size_in_voxels'], dtype=np.int64) / \
        #        architecture_input_voxels < HOW_MUCH_OF_A_PATIENT_MUST_THE_NETWORK_SEE_AT_STAGE0:
        architecture_input_voxels_here = np.prod(self.plans_per_stage[-1]['patch_size'], dtype=np.int64)
        # add a lowres stage only if the fullres patch covers too small a
        # fraction of the median patient
        if np.prod(median_shape) / architecture_input_voxels_here < \
                self.how_much_of_a_patient_must_the_network_see_at_stage0:
            more = False
        else:
            more = True

        if more:
            print("generating configuration for 3d_lowres")
            # if we are doing more than one stage then we want the lowest stage to have exactly
            # HOW_MUCH_OF_A_PATIENT_MUST_THE_NETWORK_SEE_AT_STAGE0 (this is 4 by default so the number of voxels in the
            # median shape of the lowest stage must be 4 times as much as the network can process at once (128x128x128 by
            # default). Problem is that we are downsampling higher resolution axes before we start downsampling the
            # out-of-plane axis. We could probably/maybe do this analytically but I am lazy, so here
            # we do it the dumb way
            lowres_stage_spacing = deepcopy(target_spacing)
            num_voxels = np.prod(median_shape, dtype=np.float64)
            # coarsen the spacing (1% per iteration) until the median patient
            # fits into the required multiple of the network input
            while num_voxels > self.how_much_of_a_patient_must_the_network_see_at_stage0 * architecture_input_voxels_here:
                max_spacing = max(lowres_stage_spacing)
                if np.any((max_spacing / lowres_stage_spacing) > 2):
                    # only coarsen axes that lag far behind the coarsest axis
                    lowres_stage_spacing[(max_spacing / lowres_stage_spacing) > 2] \
                        *= 1.01
                else:
                    lowres_stage_spacing *= 1.01
                num_voxels = np.prod(target_spacing / lowres_stage_spacing * median_shape, dtype=np.float64)

                lowres_stage_spacing_transposed = np.array(lowres_stage_spacing)[self.transpose_forward]
                new = self.get_properties_for_stage(lowres_stage_spacing_transposed, target_spacing_transposed,
                                                    median_shape_transposed,
                                                    len(self.list_of_cropped_npz_files),
                                                    num_modalities, len(all_classes) + 1)
                architecture_input_voxels_here = np.prod(new['patch_size'], dtype=np.int64)
            # only keep the lowres stage if it is meaningfully smaller than fullres
            if 2 * np.prod(new['median_patient_size_in_voxels'], dtype=np.int64) < np.prod(
                    self.plans_per_stage[0]['median_patient_size_in_voxels'], dtype=np.int64):
                self.plans_per_stage.append(new)

        self.plans_per_stage = self.plans_per_stage[::-1]
        self.plans_per_stage = {i: self.plans_per_stage[i] for i in range(len(self.plans_per_stage))}  # convert to dict

        print(self.plans_per_stage)
        print("transpose forward", self.transpose_forward)
        print("transpose backward", self.transpose_backward)

        normalization_schemes = self.determine_normalization_scheme()
        only_keep_largest_connected_component, min_size_per_class, min_region_size_per_class = None, None, None
        # removed training data based postprocessing. This is deprecated

        # these are independent of the stage
        plans = {'num_stages': len(list(self.plans_per_stage.keys())), 'num_modalities': num_modalities,
                 'modalities': modalities, 'normalization_schemes': normalization_schemes,
                 'dataset_properties': self.dataset_properties, 'list_of_npz_files': self.list_of_cropped_npz_files,
                 'original_spacings': spacings, 'original_sizes': sizes,
                 'preprocessed_data_folder': self.preprocessed_output_folder, 'num_classes': len(all_classes),
                 'all_classes': all_classes, 'base_num_features': self.unet_base_num_features,
                 'use_mask_for_norm': use_nonzero_mask_for_normalization,
                 'keep_only_largest_region': only_keep_largest_connected_component,
                 'min_region_size_per_class': min_region_size_per_class, 'min_size_per_class': min_size_per_class,
                 'transpose_forward': self.transpose_forward, 'transpose_backward': self.transpose_backward,
                 'data_identifier': self.data_identifier, 'plans_per_stage': self.plans_per_stage,
                 'preprocessor_name': self.preprocessor_name,
                 'conv_per_stage': self.conv_per_stage,
                 }

        self.plans = plans
        self.save_my_plans()

    def determine_normalization_scheme(self):
        """Map each modality index to a normalization scheme: "CT" gets CT
        normalization, everything else "nonCT"."""
        schemes = OrderedDict()
        modalities = self.dataset_properties['modalities']
        num_modalities = len(list(modalities.keys()))

        for i in range(num_modalities):
            if modalities[i] == "CT" or modalities[i] == 'ct':
                schemes[i] = "CT"
            else:
                schemes[i] = "nonCT"
        return schemes

    def save_properties_of_cropped(self, case_identifier, properties):
        # write back the per-case properties pickle next to the cropped data
        with open(join(self.folder_with_cropped_data, "%s.pkl" % case_identifier), 'wb') as f:
            pickle.dump(properties, f)

    def load_properties_of_cropped(self, case_identifier):
        # load the per-case properties pickle from the cropped data folder
        with open(join(self.folder_with_cropped_data, "%s.pkl" % case_identifier), 'rb') as f:
            properties = pickle.load(f)
        return properties

    def determine_whether_to_use_mask_for_norm(self):
        """Decide per modality whether to normalize within the nonzero mask,
        and write the decision into each case's properties pickle."""
        # only use the nonzero mask for normalization of the cropping based on it resulted in a decrease in
        # image size (this is an indication that the data is something like brats/isles and then we want to
        # normalize in the brain region only)
        modalities = self.dataset_properties['modalities']
        num_modalities = len(list(modalities.keys()))
        use_nonzero_mask_for_norm = OrderedDict()

        for i in range(num_modalities):
            if "CT" in modalities[i]:
                use_nonzero_mask_for_norm[i] = False
            else:
                all_size_reductions = []
                for k in self.dataset_properties['size_reductions'].keys():
                    all_size_reductions.append(self.dataset_properties['size_reductions'][k])

                # cropping shrank the images substantially -> data has a large
                # zero background (brats/isles-like) -> normalize inside mask
                if np.median(all_size_reductions) < 3 / 4.:
                    print("using nonzero mask for normalization")
                    use_nonzero_mask_for_norm[i] = True
                else:
                    print("not using nonzero mask for normalization")
                    use_nonzero_mask_for_norm[i] = False

        # propagate the decision into every case's properties file
        for c in self.list_of_cropped_npz_files:
            case_identifier = get_case_identifier_from_npz(c)
            properties = self.load_properties_of_cropped(case_identifier)
            properties['use_nonzero_mask_for_norm'] = use_nonzero_mask_for_norm
            self.save_properties_of_cropped(case_identifier, properties)
        use_nonzero_mask_for_normalization = use_nonzero_mask_for_norm
        return use_nonzero_mask_for_normalization

    def write_normalization_scheme_to_patients(self):
        """
        This is used for test set preprocessing
        :return:
        """
        for c in self.list_of_cropped_npz_files:
            case_identifier = get_case_identifier_from_npz(c)
            properties = self.load_properties_of_cropped(case_identifier)
            properties['use_nonzero_mask_for_norm'] = self.plans['use_mask_for_norm']
            self.save_properties_of_cropped(case_identifier, properties)

    def run_preprocessing(self, num_threads):
        """Apply the saved plan to the cropped data.

        :param num_threads: int, or a list/tuple with one thread count per
            stage (lowres first) when the plan has more than one stage
        """
        # refresh gt_segmentations in the preprocessed output folder
        if os.path.isdir(join(self.preprocessed_output_folder, "gt_segmentations")):
            shutil.rmtree(join(self.preprocessed_output_folder, "gt_segmentations"))
        shutil.copytree(join(self.folder_with_cropped_data, "gt_segmentations"),
                        join(self.preprocessed_output_folder, "gt_segmentations"))
        normalization_schemes = self.plans['normalization_schemes']
        use_nonzero_mask_for_normalization = self.plans['use_mask_for_norm']
        intensityproperties = self.plans['dataset_properties']['intensityproperties']
        # locate the preprocessor class by name within unetr_pp.preprocessing
        preprocessor_class = recursive_find_python_class([join(unetr_pp.__path__[0], "preprocessing")],
                                                         self.preprocessor_name, current_module="unetr_pp.preprocessing")
        assert preprocessor_class is not None
        preprocessor = preprocessor_class(normalization_schemes, use_nonzero_mask_for_normalization,
                                          self.transpose_forward,
                                          intensityproperties)
        target_spacings = [i["current_spacing"] for i in self.plans_per_stage.values()]
        # normalize num_threads to match the number of stages
        if self.plans['num_stages'] > 1 and not isinstance(num_threads, (list, tuple)):
            num_threads = (default_num_threads, num_threads)
        elif self.plans['num_stages'] == 1 and isinstance(num_threads, (list, tuple)):
            num_threads = num_threads[-1]
        preprocessor.run(target_spacings, self.folder_with_cropped_data, self.preprocessed_output_folder,
                         self.plans['data_identifier'], num_threads)
# CLI entry point: plan (and optionally preprocess) one or more tasks by id.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--task_ids", nargs="+", help="list of int")
    parser.add_argument("-p", action="store_true", help="set this if you actually want to run the preprocessing. If "
                                                        "this is not set then this script will only create the plans file")
    parser.add_argument("-tl", type=int, required=False, default=8, help="num_threads_lowres")
    parser.add_argument("-tf", type=int, required=False, default=8, help="num_threads_fullres")

    args = parser.parse_args()
    task_ids = args.task_ids
    run_preprocessing = args.p
    tl = args.tl
    tf = args.tf

    # resolve each numeric task id to its unique TaskXXX_* folder name
    tasks = []
    for i in task_ids:
        i = int(i)
        candidates = subdirs(nnFormer_cropped_data, prefix="Task%03.0d" % i, join=False)
        assert len(candidates) == 1
        tasks.append(candidates[0])

    for t in tasks:
        try:
            print("\n\n\n", t)
            cropped_out_dir = os.path.join(nnFormer_cropped_data, t)
            preprocessing_output_dir_this_task = os.path.join(preprocessing_output_dir, t)
            splitted_4d_output_dir_task = os.path.join(nnFormer_raw_data, t)
            lists, modalities = create_lists_from_splitted_dataset(splitted_4d_output_dir_task)

            dataset_analyzer = DatasetAnalyzer(cropped_out_dir, overwrite=False)
            _ = dataset_analyzer.analyze_dataset()  # this will write output files that will be used by the ExperimentPlanner

            maybe_mkdir_p(preprocessing_output_dir_this_task)
            shutil.copy(join(cropped_out_dir, "dataset_properties.pkl"), preprocessing_output_dir_this_task)
            shutil.copy(join(nnFormer_raw_data, t, "dataset.json"), preprocessing_output_dir_this_task)

            threads = (tl, tf)

            print("number of threads: ", threads, "\n")

            exp_planner = ExperimentPlanner(cropped_out_dir, preprocessing_output_dir_this_task)
            exp_planner.plan_experiment()
            if run_preprocessing:
                exp_planner.run_preprocessing(threads)
        except Exception as e:
            # NOTE(review): deliberately best-effort — a failing task is
            # reported and the remaining tasks are still processed
            print(e)
================================================
FILE: unetr_pp/experiment_planning/experiment_planner_baseline_3DUNet_v21.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import numpy as np
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *
class ExperimentPlanner3D_v21(ExperimentPlanner):
    """
    Combines ExperimentPlannerPoolBasedOnSpacing and ExperimentPlannerTargetSpacingForAnisoAxis

    We also increase the base_num_features to 32. This is solely because mixed precision training with 3D convs and
    amp is A LOT faster if the number of filters is divisible by 8
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner3D_v21, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        # identifier under which the preprocessed data of this planner is stored
        self.data_identifier = "unetr_pp_Data_plans_v2.1"
        # destination of the pickled plans dict
        self.plans_fname = join(self.preprocessed_output_folder,
                                "unetr_pp_Plansv2.1_plans_3D.pkl")
        self.unet_base_num_features = 32

    def get_target_spacing(self):
        """
        per default we use the 50th percentile=median for the target spacing. Higher spacing results in smaller data
        and thus faster and easier training. Smaller spacing results in larger data and thus longer and harder training

        For some datasets the median is not a good choice. Those are the datasets where the spacing is very anisotropic
        (for example ACDC with (10, 1.5, 1.5)). These datasets still have examples with a spacing of 5 or 6 mm in the low
        resolution axis. Choosing the median here will result in bad interpolation artifacts that can substantially
        impact performance (due to the low number of slices).
        """
        spacings = self.dataset_properties['all_spacings']
        sizes = self.dataset_properties['all_sizes']

        target = np.percentile(np.vstack(spacings), self.target_spacing_percentile, 0)

        # This should be used to determine the new median shape. The old implementation is not 100% correct.
        # Fixed in 2.4
        # sizes = [np.array(i) / target * np.array(j) for i, j in zip(spacings, sizes)]

        target_size = np.percentile(np.vstack(sizes), self.target_spacing_percentile, 0)
        target_size_mm = np.array(target) * np.array(target_size)
        # we need to identify datasets for which a different target spacing could be beneficial. These datasets have
        # the following properties:
        # - one axis which much lower resolution than the others
        # - the lowres axis has much less voxels than the others
        # - (the size in mm of the lowres axis is also reduced)
        worst_spacing_axis = np.argmax(target)
        other_axes = [i for i in range(len(target)) if i != worst_spacing_axis]
        other_spacings = [target[i] for i in other_axes]
        other_sizes = [target_size[i] for i in other_axes]

        has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * max(other_spacings))
        has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes)
        # we don't use the last one for now
        # median_size_in_mm = target[target_size_mm] * RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD < max(target_size_mm)

        if has_aniso_spacing and has_aniso_voxels:
            spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis]
            # use the 10th percentile rather than the median for the aniso axis
            # to avoid severe interpolation artifacts on thin-slice cases
            target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10)
            # don't let the spacing of that axis get higher than the other axes
            if target_spacing_of_that_axis < max(other_spacings):
                target_spacing_of_that_axis = max(max(other_spacings), target_spacing_of_that_axis) + 1e-5
            target[worst_spacing_axis] = target_spacing_of_that_axis
        return target

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        ExperimentPlanner configures pooling so that we pool late. Meaning that if the number of pooling per axis is
        (2, 3, 3), then the first pooling operation will always pool axes 1 and 2 and not 0, irrespective of spacing.
        This can cause a larger memory footprint, so it can be beneficial to revise this.

        Here we are pooling based on the spacing of the data.

        :param current_spacing: target voxel spacing of this stage
        :param original_spacing: median voxel spacing of the raw data
        :param original_shape: median shape of the raw data (in voxels)
        :param num_cases: number of training cases in the dataset
        :param num_modalities: number of input channels
        :param num_classes: number of segmentation outputs
        :return: dict describing the plan for this stage
        """
        # median dataset shape after resampling to current_spacing
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # the next line is what we had before as a default. The patch size had the same aspect ratio as the median shape of a patient. We swapped t
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)

        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()

        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)

        # clip it to the median shape of the dataset because patches larger then that make not much sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
                                                             self.unet_featuremap_min_edge_length,
                                                             self.unet_max_numpool)

        # we compute as if we were using only 30 feature maps. We can do that because fp16 training is the standard
        # now. That frees up some space. The decision to go with 32 is solely due to the speedup we get (non-multiples
        # of 8 are not supported in nvidia amp)
        ref = Generic_UNet.use_this_for_batch_size_computation_3D * self.unet_base_num_features / \
              Generic_UNet.BASE_NUM_FEATURES_3D
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        # shrink the patch along the axis that is largest relative to the median
        # shape until the estimated VRAM consumption fits the reference budget
        while here > ref:
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]

            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props(current_spacing, tmp,
                                        self.unet_featuremap_min_edge_length,
                                        self.unet_max_numpool,
                                        )
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
                                                                 self.unet_featuremap_min_edge_length,
                                                                 self.unet_max_numpool,
                                                                 )

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            # print(new_shp)
            # print(here, ref)

        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))

        # check if batch size is too large
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        # use dummy 2D augmentation when the patch is strongly anisotropic
        # (first axis is much thinner than the largest axis)
        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
            0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan
================================================
FILE: unetr_pp/experiment_planning/nnFormer_convert_decathlon_task.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.configuration import default_num_threads
from unetr_pp.experiment_planning.utils import split_4d
from unetr_pp.utilities.file_endings import remove_trailing_slash
def crawl_and_remove_hidden_from_decathlon(folder):
    """Validate that *folder* looks like an MSD task folder and delete hidden files.

    Checks for the TaskXX name prefix and the imagesTr / imagesTs / labelsTr
    subfolders, then removes every dot-prefixed file (e.g. macOS ._ artifacts)
    from the task folder and those three subfolders.
    """
    folder = remove_trailing_slash(folder)
    err = "This does not seem to be a decathlon folder. Please give me a " \
          "folder that starts with TaskXX and has the subfolders imagesTr, " \
          "labelsTr and imagesTs"
    assert folder.split('/')[-1].startswith("Task"), err
    subf = subfolders(folder, join=False)
    # same validation order as before: imagesTr, imagesTs, labelsTr
    for required in ('imagesTr', 'imagesTs', 'labelsTr'):
        assert required in subf, err
    for location in (folder, join(folder, 'imagesTr'), join(folder, 'labelsTr'), join(folder, 'imagesTs')):
        for hidden_file in subfiles(location, prefix="."):
            os.remove(hidden_file)
def main():
    """CLI wrapper: sanitize a downloaded MSD task folder, then split its 4D niftis
    into one 3D nifti per modality (see ``split_4d``)."""
    import argparse
    arg_parser = argparse.ArgumentParser(description="The MSD provides data as 4D Niftis with the modality being the first"
                                                     " dimension. We think this may be cumbersome for some users and "
                                                     "therefore expect 3D niftixs instead, with one file per modality. "
                                                     "This utility will convert 4D MSD data into the format nnU-Net "
                                                     "expects")
    arg_parser.add_argument("-i", help="Input folder. Must point to a TaskXX_TASKNAME folder as downloaded from the MSD "
                                       "website", required=False,
                            default="/home/maaz/PycharmProjects/nnFormer/DATASET/nnFormer_raw/nnFormer_raw_data"
                                    "/Task02_Synapse")
    arg_parser.add_argument("-p", required=False, default=default_num_threads, type=int,
                            help="Use this to specify how many processes are used to run the script. "
                                 "Default is %d" % default_num_threads)
    arg_parser.add_argument("-output_task_id", required=False, default=None, type=int,
                            help="If specified, this will overwrite the task id in the output folder. If unspecified, the "
                                 "task id of the input folder will be used.")
    cli_args = arg_parser.parse_args()

    crawl_and_remove_hidden_from_decathlon(cli_args.i)

    split_4d(cli_args.i, cli_args.p, cli_args.output_task_id)


if __name__ == "__main__":
    main()
================================================
FILE: unetr_pp/experiment_planning/nnFormer_plan_and_preprocess.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unetr_pp
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.experiment_planning.DatasetAnalyzer import DatasetAnalyzer
from unetr_pp.experiment_planning.utils import crop
from unetr_pp.paths import *
import shutil
from unetr_pp.utilities.task_name_id_conversion import convert_id_to_task_name
from unetr_pp.preprocessing.sanity_checks import verify_dataset_integrity
from unetr_pp.training.model_restore import recursive_find_python_class
def main():
    """Entry point for nnFormer/UNETR++ experiment planning and preprocessing.

    For every requested task id this: (optionally) verifies dataset integrity,
    crops the raw data to the nonzero region, computes the dataset fingerprint,
    and then runs the configured 3D and/or 2D experiment planners, optionally
    followed by the actual preprocessing.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--task_ids", nargs="+", help="List of integers belonging to the task ids you wish to run"
                                                            " experiment planning and preprocessing for. Each of these "
                                                            "ids must, have a matching folder 'TaskXXX_' in the raw "
                                                            "data folder")
    parser.add_argument("-pl3d", "--planner3d", type=str, default="ExperimentPlanner3D_v21",
                        help="Name of the ExperimentPlanner class for the full resolution 3D U-Net and U-Net cascade. "
                             "Default is ExperimentPlanner3D_v21. Can be 'None', in which case these U-Nets will not be "
                             "configured")
    parser.add_argument("-pl2d", "--planner2d", type=str, default="ExperimentPlanner2D_v21",
                        help="Name of the ExperimentPlanner class for the 2D U-Net. Default is ExperimentPlanner2D_v21. "
                             "Can be 'None', in which case this U-Net will not be configured")
    parser.add_argument("-no_pp", action="store_true",
                        help="Set this flag if you dont want to run the preprocessing. If this is set then this script "
                             "will only run the experiment planning and create the plans file")
    parser.add_argument("-tl", type=int, required=False, default=8,
                        help="Number of processes used for preprocessing the low resolution data for the 3D low "
                             "resolution U-Net. This can be larger than -tf. Don't overdo it or you will run out of "
                             "RAM")
    parser.add_argument("-tf", type=int, required=False, default=8,
                        help="Number of processes used for preprocessing the full resolution data of the 2D U-Net and "
                             "3D U-Net. Don't overdo it or you will run out of RAM")
    parser.add_argument("--verify_dataset_integrity", required=False, default=False, action="store_true",
                        help="set this flag to check the dataset integrity. This is useful and should be done once for "
                             "each dataset!")

    args = parser.parse_args()
    task_ids = args.task_ids
    dont_run_preprocessing = args.no_pp
    tl = args.tl
    tf = args.tf
    planner_name3d = args.planner3d
    planner_name2d = args.planner2d

    # the literal string 'None' on the command line disables the corresponding planner
    if planner_name3d == "None":
        planner_name3d = None
    if planner_name2d == "None":
        planner_name2d = None

    # we need raw data
    tasks = []
    for i in task_ids:
        i = int(i)

        task_name = convert_id_to_task_name(i)

        if args.verify_dataset_integrity:
            verify_dataset_integrity(join(nnFormer_raw_data, task_name))

        # crop every case to its nonzero region before analysis/preprocessing
        crop(task_name, False, tf)

        tasks.append(task_name)

    # resolve the planner classes by name anywhere under unetr_pp/experiment_planning
    search_in = join(unetr_pp.__path__[0], "experiment_planning")

    if planner_name3d is not None:
        planner_3d = recursive_find_python_class([search_in], planner_name3d, current_module="unetr_pp.experiment_planning")
        if planner_3d is None:
            raise RuntimeError("Could not find the Planner class %s. Make sure it is located somewhere in "
                               "unetr_pp.experiment_planning" % planner_name3d)
    else:
        planner_3d = None

    if planner_name2d is not None:
        planner_2d = recursive_find_python_class([search_in], planner_name2d, current_module="unetr_pp.experiment_planning")
        if planner_2d is None:
            raise RuntimeError("Could not find the Planner class %s. Make sure it is located somewhere in "
                               "unetr_pp.experiment_planning" % planner_name2d)
    else:
        planner_2d = None

    for t in tasks:
        print("\n\n\n", t)
        cropped_out_dir = os.path.join(nnFormer_cropped_data, t)
        preprocessing_output_dir_this_task = os.path.join(preprocessing_output_dir, t)
        #splitted_4d_output_dir_task = os.path.join(nnFormer_raw_data, t)
        #lists, modalities = create_lists_from_splitted_dataset(splitted_4d_output_dir_task)

        # we need to figure out if we need the intensity propoerties. We collect them only if one of the modalities is CT
        dataset_json = load_json(join(cropped_out_dir, 'dataset.json'))
        modalities = list(dataset_json["modality"].values())
        collect_intensityproperties = True if (("CT" in modalities) or ("ct" in modalities)) else False
        dataset_analyzer = DatasetAnalyzer(cropped_out_dir, overwrite=False, num_processes=tf)  # this class creates the fingerprint
        _ = dataset_analyzer.analyze_dataset(collect_intensityproperties)  # this will write output files that will be used by the ExperimentPlanner

        maybe_mkdir_p(preprocessing_output_dir_this_task)
        # the planners read the fingerprint + dataset.json from the preprocessing output folder
        shutil.copy(join(cropped_out_dir, "dataset_properties.pkl"), preprocessing_output_dir_this_task)
        shutil.copy(join(nnFormer_raw_data, t, "dataset.json"), preprocessing_output_dir_this_task)

        threads = (tl, tf)

        print("number of threads: ", threads, "\n")

        if planner_3d is not None:
            exp_planner = planner_3d(cropped_out_dir, preprocessing_output_dir_this_task)
            exp_planner.plan_experiment()
            if not dont_run_preprocessing:  # double negative, yooo
                exp_planner.run_preprocessing(threads)
        if planner_2d is not None:
            exp_planner = planner_2d(cropped_out_dir, preprocessing_output_dir_this_task)
            exp_planner.plan_experiment()
            if not dont_run_preprocessing:  # double negative, yooo
                exp_planner.run_preprocessing(threads)


if __name__ == "__main__":
    main()
================================================
FILE: unetr_pp/experiment_planning/summarize_plans.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.paths import preprocessing_output_dir
# This file is intended to double check nnFormer's design choices. It is intended to be used for development purposes only
def summarize_plans(file):
    """Print a human-readable overview of a plans pickle (classes, normalization,
    and the per-stage plan dicts)."""
    plans = load_pickle(file)
    print("num_classes: ", plans['num_classes'])
    print("modalities: ", plans['modalities'])
    # these keys are printed as "<key> <value>" pairs
    for key in ("use_mask_for_norm", "keep_only_largest_region", "min_region_size_per_class",
                "min_size_per_class", "normalization_schemes"):
        print(key, plans[key])
    print("stages...\n")
    for stage_idx in range(len(plans['plans_per_stage'])):
        print("stage: ", stage_idx)
        print(plans['plans_per_stage'][stage_idx])
        print("")
def write_plans_to_file(f, plans_file):
    """Append one semicolon-separated csv row per stage of a plans pickle to *f*.

    Column layout matches the header written by the __main__ block of this file.

    :param f: open, writable text file handle (rows are appended; f is not closed)
    :param plans_file: path to a *_plans_*.pkl file created by an ExperimentPlanner
    """
    print(plans_file)
    a = load_pickle(plans_file)
    stages = list(a['plans_per_stage'].keys())
    stages.sort()
    for stage in stages:
        # bugfix: index plans_per_stage with the stage key itself. The previous code
        # used stages[stage] (position in the key list), which only worked by accident
        # for contiguous 0-based integer keys and raised IndexError otherwise.
        stage_plans = a['plans_per_stage'][stage]
        # convert patch size / median patient size from voxels to mm via the stage spacing
        patch_size_in_mm = [i * j for i, j in zip(stage_plans['patch_size'],
                                                  stage_plans['current_spacing'])]
        median_patient_size_in_mm = [i * j for i, j in zip(stage_plans['median_patient_size_in_voxels'],
                                                           stage_plans['current_spacing'])]
        f.write(plans_file.split("/")[-2])
        f.write(";%s" % plans_file.split("/")[-1])
        f.write(";%d" % stage)
        f.write(";%s" % str(stage_plans['batch_size']))
        f.write(";%s" % str(stage_plans['num_pool_per_axis']))
        f.write(";%s" % str(stage_plans['patch_size']))
        f.write(";%s" % str([str("%03.2f" % i) for i in patch_size_in_mm]))
        f.write(";%s" % str(stage_plans['median_patient_size_in_voxels']))
        f.write(";%s" % str([str("%03.2f" % i) for i in median_patient_size_in_mm]))
        f.write(";%s" % str([str("%03.2f" % i) for i in stage_plans['current_spacing']]))
        f.write(";%s" % str([str("%03.2f" % i) for i in stage_plans['original_spacing']]))
        f.write(";%s" % str(stage_plans['pool_op_kernel_sizes']))
        f.write(";%s" % str(stage_plans['conv_kernel_sizes']))
        f.write(";%s" % str(a['data_identifier']))
        f.write("\n")
if __name__ == "__main__":
    # developer utility: dump a csv summary of every plans file found under base_dir
    base_dir = './'  # preprocessing_output_dir
    task_dirs = [d for d in subdirs(base_dir, join=False, prefix="Task")
                 if d.find("BrainTumor") == -1 and d.find("MSSeg") == -1]
    print("found %d tasks" % len(task_dirs))
    with open("2019_02_06_plans_summary.csv", 'w') as f:
        f.write("task;plans_file;stage;batch_size;num_pool_per_axis;patch_size;patch_size(mm);median_patient_size_in_voxels;median_patient_size_in_mm;current_spacing;original_spacing;pool_op_kernel_sizes;conv_kernel_sizes\n")
        for task_dir in task_dirs:
            print(task_dir)
            full_task_dir = join(base_dir, task_dir)
            plans_files = [p for p in subfiles(full_task_dir, suffix=".pkl", join=False)
                           if p.find("_plans_") != -1 and p.find("Dgx2") == -1]
            for plans_file in plans_files:
                write_plans_to_file(f, join(full_task_dir, plans_file))
            f.write("\n")
================================================
FILE: unetr_pp/experiment_planning/utils.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import pickle
import shutil
from collections import OrderedDict
from multiprocessing import Pool
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import join, isdir, maybe_mkdir_p, subfiles, subdirs, isfile
from unetr_pp.configuration import default_num_threads
from unetr_pp.experiment_planning.DatasetAnalyzer import DatasetAnalyzer
from unetr_pp.experiment_planning.common_utils import split_4d_nifti
from unetr_pp.paths import nnFormer_raw_data, nnFormer_cropped_data, preprocessing_output_dir
from unetr_pp.preprocessing.cropping import ImageCropper
def split_4d(input_folder, num_processes=default_num_threads, overwrite_task_output_id=None):
    """Split a 4D MSD task into one 3D nifti per modality under nnFormer_raw_data.

    :param input_folder: path to a TaskXX_NAME folder containing imagesTr, labelsTr
        and dataset.json (imagesTs is processed as well if present)
    :param num_processes: number of worker processes used for the nifti splitting
    :param overwrite_task_output_id: optional integer id used for the output
        Task%03d_NAME folder name; defaults to the id parsed from the input folder
    """
    assert isdir(join(input_folder, "imagesTr")) and isdir(join(input_folder, "labelsTr")) and \
           isfile(join(input_folder, "dataset.json")), \
        "The input folder must be a valid Task folder from the Medical Segmentation Decathlon with at least the " \
        "imagesTr and labelsTr subfolders and the dataset.json file"

    while input_folder.endswith("/"):
        input_folder = input_folder[:-1]

    full_task_name = input_folder.split("/")[-1]

    assert full_task_name.startswith("Task"), "The input folder must point to a folder that starts with TaskXX_"

    first_underscore = full_task_name.find("_")
    # bugfix (message only): the code parses exactly two digits ([4:6]), but the old
    # message claimed a 3-digit id
    assert first_underscore == 6, "Input folder start with TaskXX with XX being a 2-digit id: 00, 01, 02 etc"

    input_task_id = int(full_task_name[4:6])
    if overwrite_task_output_id is None:
        overwrite_task_output_id = input_task_id

    task_name = full_task_name[7:]

    output_folder = join(nnFormer_raw_data, "Task%03.0d_" % overwrite_task_output_id + task_name)

    # a pre-existing output folder is wiped so the conversion always starts clean
    if isdir(output_folder):
        shutil.rmtree(output_folder)

    files = []
    output_dirs = []

    maybe_mkdir_p(output_folder)
    for subdir in ["imagesTr", "imagesTs"]:
        curr_out_dir = join(output_folder, subdir)
        if not isdir(curr_out_dir):
            os.mkdir(curr_out_dir)
        curr_dir = join(input_folder, subdir)
        nii_files = [join(curr_dir, i) for i in os.listdir(curr_dir) if i.endswith(".nii.gz")]
        nii_files.sort()
        for n in nii_files:
            files.append(n)
            output_dirs.append(curr_out_dir)

    shutil.copytree(join(input_folder, "labelsTr"), join(output_folder, "labelsTr"))

    # split all 4D files in parallel; close/join so every worker is done before
    # dataset.json is copied (which marks the output folder as complete)
    p = Pool(num_processes)
    p.starmap(split_4d_nifti, zip(files, output_dirs))
    p.close()
    p.join()
    shutil.copy(join(input_folder, "dataset.json"), output_folder)
def create_lists_from_splitted_dataset(base_folder_splitted):
    """Build per-case file lists from a task folder's dataset.json.

    Returns (lists, modalities): for each training case a list of its per-modality
    image files (imagesTr/<case>_0000.nii.gz, ...) followed by its label file, and
    a dict mapping modality index -> modality name.
    """
    json_file = join(base_folder_splitted, "dataset.json")
    with open(json_file) as jsn:
        d = json.load(jsn)
        training_files = d['training']
    num_modalities = len(d['modality'].keys())
    lists = []
    for tr in training_files:
        case_base = tr['image'].split("/")[-1][:-7]  # strip ".nii.gz"
        case_files = [join(base_folder_splitted, "imagesTr", case_base + "_%04.0d.nii.gz" % mod)
                      for mod in range(num_modalities)]
        case_files.append(join(base_folder_splitted, "labelsTr", tr['label'].split("/")[-1]))
        lists.append(case_files)
    modalities = {int(i): d['modality'][str(i)] for i in d['modality'].keys()}
    return lists, modalities
def create_lists_from_splitted_dataset_folder(folder):
    """
    Build per-case file lists directly from the folder contents.

    Unlike create_lists_from_splitted_dataset this does not rely on dataset.json.
    :param folder: folder containing <caseID>_XXXX.nii.gz files
    :return: one list of matching file paths per case id
    """
    case_ids = get_caseIDs_from_splitted_dataset_folder(folder)
    return [subfiles(folder, prefix=case_id, suffix=".nii.gz", join=True, sort=True)
            for case_id in case_ids]
def get_caseIDs_from_splitted_dataset_folder(folder):
    """Return the unique case identifiers found in *folder*.

    Every file is expected to be named <caseID>_XXXX.nii.gz; the trailing
    "_XXXX.nii.gz" (12 characters) is stripped and duplicates are removed.
    """
    file_names = subfiles(folder, suffix=".nii.gz", join=False)
    return np.unique([name[:-12] for name in file_names])
def crop(task_string, override=False, num_threads=default_num_threads):
    """Crop all cases of a task to their nonzero region into nnFormer_cropped_data.

    :param task_string: task folder name, e.g. Task002_Synapse
    :param override: when True, any existing cropped output is wiped and redone
    :param num_threads: worker processes for the cropping
    """
    cropped_out_dir = join(nnFormer_cropped_data, task_string)
    maybe_mkdir_p(cropped_out_dir)

    if override and isdir(cropped_out_dir):
        # start from scratch: drop previous cropping results
        shutil.rmtree(cropped_out_dir)
        maybe_mkdir_p(cropped_out_dir)

    raw_task_dir = join(nnFormer_raw_data, task_string)
    case_lists, _ = create_lists_from_splitted_dataset(raw_task_dir)

    cropper = ImageCropper(num_threads, cropped_out_dir)
    cropper.run_cropping(case_lists, overwrite_existing=override)
    shutil.copy(join(nnFormer_raw_data, task_string, "dataset.json"), cropped_out_dir)
def analyze_dataset(task_string, override=False, collect_intensityproperties=True, num_processes=default_num_threads):
    """Compute the dataset fingerprint for a task's cropped data (writes files via DatasetAnalyzer)."""
    cropped_dir = join(nnFormer_cropped_data, task_string)
    analyzer = DatasetAnalyzer(cropped_dir, overwrite=override, num_processes=num_processes)
    analyzer.analyze_dataset(collect_intensityproperties)
def plan_and_preprocess(task_string, processes_lowres=default_num_threads, processes_fullres=3, no_preprocessing=False):
    """Run 3D and 2D experiment planning (and optionally preprocessing) for one task.

    Copies the dataset fingerprint and dataset.json into the preprocessing output
    folder, runs the baseline 3D and 2D ExperimentPlanners, and finally annotates
    every preprocessed training case with per-axis slice/class information
    (needed by the 2D dataloader; see add_classes_in_slice_info).

    :param task_string: task folder name, e.g. Task002_Synapse
    :param processes_lowres: processes for low resolution preprocessing
    :param processes_fullres: processes for full resolution preprocessing
    :param no_preprocessing: if True, only the plans files are created
    """
    from unetr_pp.experiment_planning.experiment_planner_baseline_2DUNet import ExperimentPlanner2D
    from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner

    preprocessing_output_dir_this_task_train = join(preprocessing_output_dir, task_string)
    cropped_out_dir = join(nnFormer_cropped_data, task_string)
    maybe_mkdir_p(preprocessing_output_dir_this_task_train)

    # the planners read the fingerprint + dataset.json from the preprocessing output folder
    shutil.copy(join(cropped_out_dir, "dataset_properties.pkl"), preprocessing_output_dir_this_task_train)
    shutil.copy(join(nnFormer_raw_data, task_string, "dataset.json"), preprocessing_output_dir_this_task_train)

    exp_planner = ExperimentPlanner(cropped_out_dir, preprocessing_output_dir_this_task_train)
    exp_planner.plan_experiment()
    if not no_preprocessing:
        exp_planner.run_preprocessing((processes_lowres, processes_fullres))

    exp_planner = ExperimentPlanner2D(cropped_out_dir, preprocessing_output_dir_this_task_train)
    exp_planner.plan_experiment()
    if not no_preprocessing:
        exp_planner.run_preprocessing(processes_fullres)

    # write which class is in which slice to all training cases (required to speed up 2D Dataloader)
    # This is done for all data so that if we wanted to use them with 2D we could do so

    if not no_preprocessing:
        p = Pool(default_num_threads)

        # if there is more than one my_data_identifier (different branches) then this code will run for all of them if
        # they start with the same string. not problematic, but not pretty
        stages = [i for i in subdirs(preprocessing_output_dir_this_task_train, join=True, sort=True)
                  if i.split("/")[-1].find("stage") != -1]

        for s in stages:
            print(s.split("/")[-1])
            list_of_npz_files = subfiles(s, True, None, ".npz", True)
            list_of_pkl_files = [i[:-4]+".pkl" for i in list_of_npz_files]
            # read the classes present in each case from its properties pkl;
            # negative labels (e.g. -1 for "outside nonzero region") are dropped
            all_classes = []
            for pk in list_of_pkl_files:
                with open(pk, 'rb') as f:
                    props = pickle.load(f)
                all_classes_tmp = np.array(props['classes'])
                all_classes.append(all_classes_tmp[all_classes_tmp >= 0])
            p.map(add_classes_in_slice_info, zip(list_of_npz_files, list_of_pkl_files, all_classes))
        p.close()
        p.join()
def add_classes_in_slice_info(args):
    """
    Precompute slice/class lookup tables for the 2D dataloader with oversampling.

    Without this, the dataloader would have to scan an entire patient volume at
    run time just to find slices containing a given class. Here we do that scan
    once and store the result in the case's pkl file:
      * 'classes_in_slice_per_axis': {axis: {class: array of slice indices containing it}}
      * 'number_of_voxels_per_class': {class: total voxel count}

    :param args: tuple (npz_file, pkl_file, all_classes) — packed because this is
        used as a multiprocessing map target
    """
    npz_file, pkl_file, all_classes = args
    seg_map = np.load(npz_file)['data'][-1]  # segmentation is the last channel
    with open(pkl_file, 'rb') as f:
        props = pickle.load(f)
    print(pkl_file)
    classes_in_slice = OrderedDict()
    for axis in range(3):
        reduce_axes = tuple(a for a in range(3) if a != axis)
        per_class = OrderedDict()
        for c in all_classes:
            # slice indices along `axis` that contain at least one voxel of class c
            per_class[c] = np.where(np.sum(seg_map == c, axis=reduce_axes) > 0)[0]
        classes_in_slice[axis] = per_class

    number_of_voxels_per_class = OrderedDict()
    for c in all_classes:
        number_of_voxels_per_class[c] = np.sum(seg_map == c)

    props['classes_in_slice_per_axis'] = classes_in_slice
    props['number_of_voxels_per_class'] = number_of_voxels_per_class

    with open(pkl_file, 'wb') as f:
        pickle.dump(props, f)
================================================
FILE: unetr_pp/inference/__init__.py
================================================
from __future__ import absolute_import
from . import *
================================================
FILE: unetr_pp/inference/predict.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from copy import deepcopy
from typing import Tuple, Union, List
import numpy as np
from batchgenerators.augmentations.utils import resize_segmentation
from unetr_pp.inference.segmentation_export import save_segmentation_nifti_from_softmax, save_segmentation_nifti
from batchgenerators.utilities.file_and_folder_operations import *
from multiprocessing import Process, Queue
import torch
import SimpleITK as sitk
import shutil
from multiprocessing import Pool
from unetr_pp.postprocessing.connected_components import load_remove_save, load_postprocessing
from unetr_pp.training.model_restore import load_model_and_checkpoint_files
from unetr_pp.training.network_training.Trainer_acdc import Trainer_acdc
from unetr_pp.training.network_training.Trainer_synapse import Trainer_synapse
from unetr_pp.utilities.one_hot_encoding import to_one_hot
def preprocess_save_to_queue(preprocess_fn, q, list_of_lists, output_files, segs_from_prev_stage, classes,
                             transpose_forward):
    """Worker: preprocess cases and push (output_file, (data, properties)) onto q.

    For each case, runs preprocess_fn; if a previous-stage segmentation is given,
    it is resampled, one-hot encoded and stacked onto the data channels. Very
    large arrays are written to an .npy file and the file path is queued instead
    (python pipes cannot transfer objects > 2 GB). A final "end" sentinel is
    always queued so the consumer can count finished workers.

    :param preprocess_fn: callable(list_of_files) -> (data, seg, properties_dict)
    :param q: queue to push results to
    :param list_of_lists: per-case lists of input files
    :param output_files: per-case output file names (aligned with list_of_lists)
    :param segs_from_prev_stage: per-case path to a previous-stage segmentation or None
    :param classes: label values used for one-hot encoding the previous-stage seg
    :param transpose_forward: axis permutation applied to the previous-stage seg
    """
    # suppress output
    # sys.stdout = open(os.devnull, 'w')

    errors_in = []
    for i, l in enumerate(list_of_lists):
        try:
            output_file = output_files[i]
            print("preprocessing", output_file)
            d, _, dct = preprocess_fn(l)
            # print(output_file, dct)
            if segs_from_prev_stage[i] is not None:
                assert isfile(segs_from_prev_stage[i]) and segs_from_prev_stage[i].endswith(
                    ".nii.gz"), "segs_from_prev_stage" \
                                " must point to a " \
                                "segmentation file"
                seg_prev = sitk.GetArrayFromImage(sitk.ReadImage(segs_from_prev_stage[i]))
                # check to see if shapes match
                img = sitk.GetArrayFromImage(sitk.ReadImage(l[0]))
                assert all([i == j for i, j in zip(seg_prev.shape, img.shape)]), "image and segmentation from previous " \
                                                                                 "stage don't have the same pixel array " \
                                                                                 "shape! image: %s, seg_prev: %s" % \
                                                                                 (l[0], segs_from_prev_stage[i])
                seg_prev = seg_prev.transpose(transpose_forward)
                seg_reshaped = resize_segmentation(seg_prev, d.shape[1:], order=1, cval=0)
                seg_reshaped = to_one_hot(seg_reshaped, classes)
                d = np.vstack((d, seg_reshaped)).astype(np.float32)
            """There is a problem with python process communication that prevents us from communicating obejcts
            larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
            communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long
            enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
            patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
            then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
            filename or np.ndarray and will handle this automatically"""
            print(d.shape)
            if np.prod(d.shape) > (2e9 / 4 * 0.85):  # *0.85 just to be save, 4 because float32 is 4 bytes
                print(
                    "This output is too large for python process-process communication. "
                    "Saving output temporarily to disk")
                np.save(output_file[:-7] + ".npy", d)
                d = output_file[:-7] + ".npy"
            q.put((output_file, (d, dct)))
        except KeyboardInterrupt:
            raise KeyboardInterrupt
        except Exception as e:
            # best-effort: report and continue with the remaining cases
            print("error in", l)
            print(e)
            # bugfix: record the failed case so the summary below is accurate.
            # errors_in was previously never appended to, so the worker always
            # reported "no errors" even when cases failed.
            errors_in.append(l)
    q.put("end")
    if len(errors_in) > 0:
        print("There were some errors in the following cases:", errors_in)
        print("These cases were ignored.")
    else:
        print("This worker has ended successfully, no errors to report")
    # restore output
    # sys.stdout = sys.__stdout__
def preprocess_multithreaded(trainer, list_of_lists, output_files, num_processes=2, segs_from_prev_stage=None):
    """Generator that preprocesses cases in background processes and yields them.

    Spawns up to num_processes workers running preprocess_save_to_queue (cases are
    distributed round-robin via [i::num_processes] slicing) and yields
    (output_file, (data, properties)) tuples as they become available. Each worker
    pushes an "end" sentinel when done; the generator finishes once all sentinels
    were received. The finally block terminates any still-alive workers (e.g. if
    the consumer stops iterating early) and joins them.

    :param trainer: loaded trainer; provides preprocess_patient and plans
    :param list_of_lists: per-case lists of input files
    :param output_files: per-case output file names (aligned with list_of_lists)
    :param num_processes: upper bound on worker processes (capped by #cases)
    :param segs_from_prev_stage: optional per-case previous-stage segmentations
    """
    if segs_from_prev_stage is None:
        segs_from_prev_stage = [None] * len(list_of_lists)

    num_processes = min(len(list_of_lists), num_processes)

    # class labels used for one-hot encoding of previous-stage segmentations
    classes = list(range(1, trainer.num_classes))
    assert isinstance(trainer, Trainer_acdc) or isinstance(trainer, Trainer_synapse)
    # maxsize=1 keeps memory bounded: workers block until the consumer catches up
    q = Queue(1)
    processes = []
    for i in range(num_processes):
        pr = Process(target=preprocess_save_to_queue, args=(trainer.preprocess_patient, q,
                                                            list_of_lists[i::num_processes],
                                                            output_files[i::num_processes],
                                                            segs_from_prev_stage[i::num_processes],
                                                            classes, trainer.plans['transpose_forward']))
        pr.start()
        processes.append(pr)

    try:
        end_ctr = 0
        while end_ctr != num_processes:
            item = q.get()
            if item == "end":
                end_ctr += 1
                continue
            else:
                yield item
    finally:
        for p in processes:
            if p.is_alive():
                p.terminate()  # this should not happen but better safe than sorry right
            p.join()

        q.close()
def predict_cases(model, list_of_lists, output_filenames, folds, save_npz, num_threads_preprocessing,
                  num_threads_nifti_save, segs_from_prev_stage=None, do_tta=True, mixed_precision=True, overwrite_existing=False,
                  all_in_gpu=False, step_size=0.5, checkpoint_name="model_final_checkpoint",
                  segmentation_export_kwargs: dict = None):
    """
    Run inference for a list of cases and export the resulting segmentations.

    :param segmentation_export_kwargs: optional dict with keys force_separate_z,
        interpolation_order, interpolation_order_z; when None these are read from
        the trainer's plans (or fall back to None/1/0)
    :param model: folder where the model is saved, must contain fold_x subfolders
    :param list_of_lists: [[case0_0000.nii.gz, case0_0001.nii.gz], [case1_0000.nii.gz, case1_0001.nii.gz], ...]
    :param output_filenames: [output_file_case0.nii.gz, output_file_case1.nii.gz, ...]
    :param folds: default: (0, 1, 2, 3, 4) (but can also be 'all' or a subset of the five folds, for example use (0, )
    for using only fold_0
    :param save_npz: default: False; when True the softmax is additionally stored as .npz
    :param num_threads_preprocessing: worker processes for the preprocessing generator
    :param num_threads_nifti_save: worker processes for the nifti export pool
    :param segs_from_prev_stage: optional per-case previous-stage segmentations
    :param do_tta: default: True, can be set to False for a 8x speedup at the cost of a reduced segmentation quality
    :param overwrite_existing: default: False; when False, cases whose output already exists are skipped
    :param mixed_precision: if None then we take no action. If True/False we overwrite what the model has in its init
    :return: None (results are written to the output files)
    """
    assert len(list_of_lists) == len(output_filenames)
    if segs_from_prev_stage is not None: assert len(segs_from_prev_stage) == len(output_filenames)

    pool = Pool(num_threads_nifti_save)
    results = []

    # normalize output names: create parent dirs and force a .nii.gz extension
    cleaned_output_files = []
    for o in output_filenames:
        dr, f = os.path.split(o)
        if len(dr) > 0:
            maybe_mkdir_p(dr)
        if not f.endswith(".nii.gz"):
            f, _ = os.path.splitext(f)
            f = f + ".nii.gz"
        cleaned_output_files.append(join(dr, f))

    if not overwrite_existing:
        print("number of cases:", len(list_of_lists))
        # if save_npz=True then we should also check for missing npz files
        not_done_idx = [i for i, j in enumerate(cleaned_output_files) if (not isfile(j)) or (save_npz and not isfile(j[:-7] + '.npz'))]

        cleaned_output_files = [cleaned_output_files[i] for i in not_done_idx]
        list_of_lists = [list_of_lists[i] for i in not_done_idx]
        if segs_from_prev_stage is not None:
            segs_from_prev_stage = [segs_from_prev_stage[i] for i in not_done_idx]

        print("number of cases that still need to be predicted:", len(cleaned_output_files))

    print("emptying cuda cache")
    torch.cuda.empty_cache()

    print("loading parameters for folds,", folds)
    trainer, params = load_model_and_checkpoint_files(model, folds, mixed_precision=mixed_precision, checkpoint_name=checkpoint_name)

    # resolve export parameters: explicit kwargs > plans > defaults
    if segmentation_export_kwargs is None:
        if 'segmentation_export_params' in trainer.plans.keys():
            force_separate_z = trainer.plans['segmentation_export_params']['force_separate_z']
            interpolation_order = trainer.plans['segmentation_export_params']['interpolation_order']
            interpolation_order_z = trainer.plans['segmentation_export_params']['interpolation_order_z']
        else:
            force_separate_z = None
            interpolation_order = 1
            interpolation_order_z = 0
    else:
        force_separate_z = segmentation_export_kwargs['force_separate_z']
        interpolation_order = segmentation_export_kwargs['interpolation_order']
        interpolation_order_z = segmentation_export_kwargs['interpolation_order_z']

    print("starting preprocessing generator")
    preprocessing = preprocess_multithreaded(trainer, list_of_lists, cleaned_output_files, num_threads_preprocessing,
                                             segs_from_prev_stage)
    print("starting prediction...")
    all_output_files = []
    for preprocessed in preprocessing:
        output_filename, (d, dct) = preprocessed
        # bugfix: the old code appended the list to itself (all_output_files.append(all_output_files))
        all_output_files.append(output_filename)
        if isinstance(d, str):
            # the preprocessing worker wrote a large case to disk; load and delete it
            data = np.load(d)
            os.remove(d)
            d = data

        print("predicting", output_filename)
        # average the softmax over all requested folds
        softmax = []
        for p in params:
            trainer.load_checkpoint_ram(p, False)
            softmax.append(trainer.predict_preprocessed_data_return_seg_and_softmax(
                d, do_mirroring=do_tta, mirror_axes=trainer.data_aug_params['mirror_axes'], use_sliding_window=True,
                step_size=step_size, use_gaussian=True, all_in_gpu=all_in_gpu,
                mixed_precision=mixed_precision)[1][None])

        softmax = np.vstack(softmax)
        softmax_mean = np.mean(softmax, 0)

        # undo the transpose applied during preprocessing (channel axis stays first)
        transpose_forward = trainer.plans.get('transpose_forward')
        if transpose_forward is not None:
            transpose_backward = trainer.plans.get('transpose_backward')
            softmax_mean = softmax_mean.transpose([0] + [i + 1 for i in transpose_backward])

        if save_npz:
            npz_file = output_filename[:-7] + ".npz"
        else:
            npz_file = None

        if hasattr(trainer, 'regions_class_order'):
            region_class_order = trainer.regions_class_order
        else:
            region_class_order = None

        """There is a problem with python process communication that prevents us from communicating obejcts
        larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
        communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long
        enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
        patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
        then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
        filename or np.ndarray and will handle this automatically"""
        bytes_per_voxel = 4
        if all_in_gpu:
            bytes_per_voxel = 2  # if all_in_gpu then the return value is half (float16)
        if np.prod(softmax_mean.shape) > (2e9 / bytes_per_voxel * 0.85):  # * 0.85 just to be save
            print(
                "This output is too large for python process-process communication. Saving output temporarily to disk")
            np.save(output_filename[:-7] + ".npy", softmax_mean)
            softmax_mean = output_filename[:-7] + ".npy"

        results.append(pool.starmap_async(save_segmentation_nifti_from_softmax,
                                          ((softmax_mean, output_filename, dct, interpolation_order, region_class_order,
                                            None, None,
                                            npz_file, None, force_separate_z, interpolation_order_z),)
                                          ))

    print("inference done. Now waiting for the segmentation export to finish...")
    _ = [i.get() for i in results]
    # now apply postprocessing
    # first load the postprocessing properties if they are present. Else raise a well visible warning
    results = []
    pp_file = join(model, "postprocessing.json")
    if isfile(pp_file):
        print("postprocessing...")
        shutil.copy(pp_file, os.path.abspath(os.path.dirname(output_filenames[0])))
        # for_which_classes stores for which of the classes everything but the largest connected component needs to be
        # removed
        for_which_classes, min_valid_obj_size = load_postprocessing(pp_file)
        results.append(pool.starmap_async(load_remove_save,
                                          zip(output_filenames, output_filenames,
                                              [for_which_classes] * len(output_filenames),
                                              [min_valid_obj_size] * len(output_filenames))))
        _ = [i.get() for i in results]
    else:
        print("WARNING! Cannot run postprocessing because the postprocessing file is missing. Make sure to run "
              "consolidate_folds in the output folder of the model first!\nThe folder you need to run this in is "
              "%s" % model)

    pool.close()
    pool.join()
def predict_cases_fast(model, list_of_lists, output_filenames, folds, num_threads_preprocessing,
                       num_threads_nifti_save, segs_from_prev_stage=None, do_tta=True, mixed_precision=True,
                       overwrite_existing=False,
                       all_in_gpu=False, step_size=0.5, checkpoint_name="model_final_checkpoint",
                       segmentation_export_kwargs: dict = None):
    """
    Faster variant of predict_cases: instead of exporting per-fold softmax volumes, it sums the
    softmax across folds in place and exports only the argmax segmentation (via
    save_segmentation_nifti), so no softmax ever crosses the process boundary.

    :param model: folder containing the trained model (must hold postprocessing.json for the
        optional postprocessing step at the end)
    :param list_of_lists: one list of input modality files per case
    :param output_filenames: one output path per case (forced to end in .nii.gz)
    :param folds: folds whose checkpoints are ensembled
    :param segs_from_prev_stage: optional previous-stage segmentations (cascade), one per case
    :param segmentation_export_kwargs: overrides force_separate_z / interpolation_order /
        interpolation_order_z; if None these are taken from trainer.plans or defaulted
    """
    assert len(list_of_lists) == len(output_filenames)
    if segs_from_prev_stage is not None: assert len(segs_from_prev_stage) == len(output_filenames)

    # worker pool for asynchronous nifti export / postprocessing
    pool = Pool(num_threads_nifti_save)
    results = []

    # normalize output filenames: create parent dirs and enforce a .nii.gz suffix
    cleaned_output_files = []
    for o in output_filenames:
        dr, f = os.path.split(o)
        if len(dr) > 0:
            maybe_mkdir_p(dr)
        if not f.endswith(".nii.gz"):
            f, _ = os.path.splitext(f)
            f = f + ".nii.gz"
        cleaned_output_files.append(join(dr, f))

    if not overwrite_existing:
        # resume mode: keep only cases whose output file does not exist yet
        print("number of cases:", len(list_of_lists))
        not_done_idx = [i for i, j in enumerate(cleaned_output_files) if not isfile(j)]

        cleaned_output_files = [cleaned_output_files[i] for i in not_done_idx]
        list_of_lists = [list_of_lists[i] for i in not_done_idx]
        if segs_from_prev_stage is not None:
            segs_from_prev_stage = [segs_from_prev_stage[i] for i in not_done_idx]

        print("number of cases that still need to be predicted:", len(cleaned_output_files))

    print("emptying cuda cache")
    torch.cuda.empty_cache()

    print("loading parameters for folds,", folds)
    trainer, params = load_model_and_checkpoint_files(model, folds, mixed_precision=mixed_precision, checkpoint_name=checkpoint_name)

    # resolve export parameters: explicit kwargs > plans entry > hard-coded defaults
    if segmentation_export_kwargs is None:
        if 'segmentation_export_params' in trainer.plans.keys():
            force_separate_z = trainer.plans['segmentation_export_params']['force_separate_z']
            interpolation_order = trainer.plans['segmentation_export_params']['interpolation_order']
            interpolation_order_z = trainer.plans['segmentation_export_params']['interpolation_order_z']
        else:
            force_separate_z = None
            interpolation_order = 1
            interpolation_order_z = 0
    else:
        force_separate_z = segmentation_export_kwargs['force_separate_z']
        interpolation_order = segmentation_export_kwargs['interpolation_order']
        interpolation_order_z = segmentation_export_kwargs['interpolation_order_z']

    print("starting preprocessing generator")
    preprocessing = preprocess_multithreaded(trainer, list_of_lists, cleaned_output_files, num_threads_preprocessing,
                                             segs_from_prev_stage)
    print("starting prediction...")
    for preprocessed in preprocessing:
        print("getting data from preprocessor")
        output_filename, (d, dct) = preprocessed
        print("got something")
        # the preprocessor hands over a filename instead of an array when the case was too large
        # for inter-process communication; load it and delete the temp file
        if isinstance(d, str):
            print("what I got is a string, so I need to load a file")
            data = np.load(d)
            os.remove(d)
            d = data

        # preallocate the output arrays
        # same dtype as the return value in predict_preprocessed_data_return_seg_and_softmax (saves time)
        # softmax is accumulated lazily (first fold's result is kept, later folds added in place)
        softmax_aggr = None  # np.zeros((trainer.num_classes, *d.shape[1:]), dtype=np.float16)
        all_seg_outputs = np.zeros((len(params), *d.shape[1:]), dtype=int)
        print("predicting", output_filename)
        for i, p in enumerate(params):
            # load the fold's weights into the already-built network
            trainer.load_checkpoint_ram(p, False)
            # res[0] is the segmentation, res[1] the softmax (see usage below)
            res = trainer.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=do_tta,
                                                                           mirror_axes=trainer.data_aug_params['mirror_axes'],
                                                                           use_sliding_window=True,
                                                                           step_size=step_size, use_gaussian=True,
                                                                           all_in_gpu=all_in_gpu,
                                                                           mixed_precision=mixed_precision)
            if len(params) > 1:
                # otherwise we dont need this and we can save ourselves the time it takes to copy that
                print("aggregating softmax")
                if softmax_aggr is None:
                    softmax_aggr = res[1]
                else:
                    softmax_aggr += res[1]
            all_seg_outputs[i] = res[0]

        print("obtaining segmentation map")
        if len(params) > 1:
            # we dont need to normalize the softmax by 1 / len(params) because this would not change the outcome of the argmax
            seg = softmax_aggr.argmax(0)
        else:
            seg = all_seg_outputs[0]

        print("applying transpose_backward")
        # undo the axis permutation applied during preprocessing
        transpose_forward = trainer.plans.get('transpose_forward')
        if transpose_forward is not None:
            transpose_backward = trainer.plans.get('transpose_backward')
            seg = seg.transpose([i for i in transpose_backward])

        print("initializing segmentation export")
        # export runs asynchronously in the worker pool; only the integer segmentation is shipped
        results.append(pool.starmap_async(save_segmentation_nifti,
                                          ((seg, output_filename, dct, interpolation_order, force_separate_z,
                                            interpolation_order_z),)
                                          ))
        print("done")

    print("inference done. Now waiting for the segmentation export to finish...")
    _ = [i.get() for i in results]
    # now apply postprocessing
    # first load the postprocessing properties if they are present. Else raise a well visible warning
    results = []
    pp_file = join(model, "postprocessing.json")
    if isfile(pp_file):
        print("postprocessing...")
        shutil.copy(pp_file, os.path.dirname(output_filenames[0]))
        # for_which_classes stores for which of the classes everything but the largest connected component needs to be
        # removed
        for_which_classes, min_valid_obj_size = load_postprocessing(pp_file)
        results.append(pool.starmap_async(load_remove_save,
                                          zip(output_filenames, output_filenames,
                                              [for_which_classes] * len(output_filenames),
                                              [min_valid_obj_size] * len(output_filenames))))
        _ = [i.get() for i in results]
    else:
        print("WARNING! Cannot run postprocessing because the postprocessing file is missing. Make sure to run "
              "consolidate_folds in the output folder of the model first!\nThe folder you need to run this in is "
              "%s" % model)

    pool.close()
    pool.join()
def predict_cases_fastest(model, list_of_lists, output_filenames, folds, num_threads_preprocessing,
                          num_threads_nifti_save, segs_from_prev_stage=None, do_tta=True, mixed_precision=True,
                          overwrite_existing=False, all_in_gpu=True, step_size=0.5,
                          checkpoint_name="model_final_checkpoint"):
    """
    Fastest inference variant: per-fold softmax outputs are collected in a preallocated float16
    array, averaged, and only the argmax segmentation is exported — with interpolation order 0 and
    no separate-z resampling (see the save_segmentation_nifti call below), trading export quality
    for speed. Note all_in_gpu defaults to True here, unlike the other variants.

    :param model: folder containing the trained model (and optionally postprocessing.json)
    :param list_of_lists: one list of input modality files per case
    :param output_filenames: one output path per case (forced to end in .nii.gz)
    :param folds: folds whose checkpoints are ensembled
    :param segs_from_prev_stage: optional previous-stage segmentations (cascade), one per case
    """
    assert len(list_of_lists) == len(output_filenames)
    if segs_from_prev_stage is not None: assert len(segs_from_prev_stage) == len(output_filenames)

    # worker pool for asynchronous nifti export / postprocessing
    pool = Pool(num_threads_nifti_save)
    results = []

    # normalize output filenames: create parent dirs and enforce a .nii.gz suffix
    cleaned_output_files = []
    for o in output_filenames:
        dr, f = os.path.split(o)
        if len(dr) > 0:
            maybe_mkdir_p(dr)
        if not f.endswith(".nii.gz"):
            f, _ = os.path.splitext(f)
            f = f + ".nii.gz"
        cleaned_output_files.append(join(dr, f))

    if not overwrite_existing:
        # resume mode: keep only cases whose output file does not exist yet
        print("number of cases:", len(list_of_lists))
        not_done_idx = [i for i, j in enumerate(cleaned_output_files) if not isfile(j)]

        cleaned_output_files = [cleaned_output_files[i] for i in not_done_idx]
        list_of_lists = [list_of_lists[i] for i in not_done_idx]
        if segs_from_prev_stage is not None:
            segs_from_prev_stage = [segs_from_prev_stage[i] for i in not_done_idx]

        print("number of cases that still need to be predicted:", len(cleaned_output_files))

    print("emptying cuda cache")
    torch.cuda.empty_cache()

    print("loading parameters for folds,", folds)
    trainer, params = load_model_and_checkpoint_files(model, folds, mixed_precision=mixed_precision, checkpoint_name=checkpoint_name)

    print("starting preprocessing generator")
    preprocessing = preprocess_multithreaded(trainer, list_of_lists, cleaned_output_files, num_threads_preprocessing,
                                             segs_from_prev_stage)
    print("starting prediction...")
    for preprocessed in preprocessing:
        print("getting data from preprocessor")
        output_filename, (d, dct) = preprocessed
        print("got something")
        # the preprocessor hands over a filename instead of an array when the case was too large
        # for inter-process communication; load it and delete the temp file
        if isinstance(d, str):
            print("what I got is a string, so I need to load a file")
            data = np.load(d)
            os.remove(d)
            d = data

        # preallocate the output arrays
        # same dtype as the return value in predict_preprocessed_data_return_seg_and_softmax (saves time)
        all_softmax_outputs = np.zeros((len(params), trainer.num_classes, *d.shape[1:]), dtype=np.float16)
        all_seg_outputs = np.zeros((len(params), *d.shape[1:]), dtype=int)
        print("predicting", output_filename)
        for i, p in enumerate(params):
            # load the fold's weights into the already-built network
            trainer.load_checkpoint_ram(p, False)
            # res[0] is the segmentation, res[1] the softmax (see usage below)
            res = trainer.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=do_tta,
                                                                           mirror_axes=trainer.data_aug_params['mirror_axes'],
                                                                           use_sliding_window=True,
                                                                           step_size=step_size, use_gaussian=True,
                                                                           all_in_gpu=all_in_gpu,
                                                                           mixed_precision=mixed_precision)
            if len(params) > 1:
                # otherwise we dont need this and we can save ourselves the time it takes to copy that
                all_softmax_outputs[i] = res[1]
            all_seg_outputs[i] = res[0]

        print("aggregating predictions")
        if len(params) > 1:
            # ensemble: mean softmax over folds, then argmax over the class axis
            softmax_mean = np.mean(all_softmax_outputs, 0)
            seg = softmax_mean.argmax(0)
        else:
            seg = all_seg_outputs[0]

        print("applying transpose_backward")
        # undo the axis permutation applied during preprocessing
        transpose_forward = trainer.plans.get('transpose_forward')
        if transpose_forward is not None:
            transpose_backward = trainer.plans.get('transpose_backward')
            seg = seg.transpose([i for i in transpose_backward])

        print("initializing segmentation export")
        # order=0, force_separate_z=None: cheapest possible export (hence "fastest")
        results.append(pool.starmap_async(save_segmentation_nifti,
                                          ((seg, output_filename, dct, 0, None),)
                                          ))
        print("done")

    print("inference done. Now waiting for the segmentation export to finish...")
    _ = [i.get() for i in results]
    # now apply postprocessing
    # first load the postprocessing properties if they are present. Else raise a well visible warning
    results = []
    pp_file = join(model, "postprocessing.json")
    if isfile(pp_file):
        print("postprocessing...")
        shutil.copy(pp_file, os.path.dirname(output_filenames[0]))
        # for_which_classes stores for which of the classes everything but the largest connected component needs to be
        # removed
        for_which_classes, min_valid_obj_size = load_postprocessing(pp_file)
        results.append(pool.starmap_async(load_remove_save,
                                          zip(output_filenames, output_filenames,
                                              [for_which_classes] * len(output_filenames),
                                              [min_valid_obj_size] * len(output_filenames))))
        _ = [i.get() for i in results]
    else:
        print("WARNING! Cannot run postprocessing because the postprocessing file is missing. Make sure to run "
              "consolidate_folds in the output folder of the model first!\nThe folder you need to run this in is "
              "%s" % model)

    pool.close()
    pool.join()
def check_input_folder_and_return_caseIDs(input_folder, expected_num_modalities):
    """
    Validate that input_folder contains a complete set of modalities for every case and return
    the unique case identifiers.

    Case ids are derived by stripping the trailing "_XXXX.nii.gz" (12 characters) from each
    filename. For every case id, all files "<case>_0000.nii.gz" ... up to
    expected_num_modalities must exist; files that match no expected name are reported as
    unexpected leftovers. Raises RuntimeError if any expected file is missing.
    """
    print("This model expects %d input modalities for each image" % expected_num_modalities)
    files = subfiles(input_folder, suffix=".nii.gz", join=False, sort=True)

    assert len(files) > 0, "input folder did not contain any images (expected to find .nii.gz file endings)"

    # strip the "_XXXX.nii.gz" suffix (12 chars) to recover candidate case ids
    maybe_case_ids = np.unique([fname[:-12] for fname in files])

    # every file we expect to see, in (case id, modality) order
    expected = [case_id + "_%04.0d.nii.gz" % modality
                for case_id in maybe_case_ids
                for modality in range(expected_num_modalities)]

    # expected files that are absent on disk, and expected files that are present
    missing = [e for e in expected if not isfile(join(input_folder, e))]
    found = {e for e in expected if isfile(join(input_folder, e))}

    # files in the folder that no (case id, modality) combination accounts for
    remaining = [fname for fname in files if fname not in found]

    print("Found %d unique case ids, here are some examples:" % len(maybe_case_ids),
          np.random.choice(maybe_case_ids, min(len(maybe_case_ids), 10)))
    print("If they don't look right, make sure to double check your filenames. They must end with _0000.nii.gz etc")

    if len(remaining) > 0:
        print("found %d unexpected remaining files in the folder. Here are some examples:" % len(remaining),
              np.random.choice(remaining, min(len(remaining), 10)))

    if len(missing) > 0:
        print("Some files are missing:")
        print(missing)
        raise RuntimeError("missing files in input_folder")

    return maybe_case_ids
def predict_from_folder(model: str, input_folder: str, output_folder: str, folds: Union[Tuple[int], List[int]],
                        save_npz: bool, num_threads_preprocessing: int, num_threads_nifti_save: int,
                        lowres_segmentations: Union[str, None],
                        part_id: int, num_parts: int, tta: bool, mixed_precision: bool = True,
                        overwrite_existing: bool = True, mode: str = 'normal', overwrite_all_in_gpu: bool = None,
                        step_size: float = 0.5, checkpoint_name: str = "model_final_checkpoint",
                        segmentation_export_kwargs: dict = None):
    """
    here we use the standard naming scheme to generate list_of_lists and output_files needed by predict_cases

    Dispatches to predict_cases / predict_cases_fast / predict_cases_fastest depending on `mode`.
    Input files must follow the CASENAME_XXXX.nii.gz convention; part_id/num_parts slice the case
    list so several processes (e.g. one per GPU) can share the workload.

    :param model: model output folder; must contain plans.pkl (and the fold checkpoints)
    :param input_folder: folder with the input images
    :param output_folder: where predictions are written (created if needed)
    :param folds: folds to ensemble
    :param save_npz: also save softmax as .npz (only supported by mode 'normal')
    :param num_threads_preprocessing: background processes for preprocessing
    :param num_threads_nifti_save: background processes for segmentation export
    :param lowres_segmentations: folder with previous-stage (lowres) segmentations, or None
    :param part_id: index of this worker in [0, num_parts)
    :param num_parts: total number of parallel workers
    :param tta: enable test-time mirroring
    :param mixed_precision: run inference in mixed precision
    :param overwrite_existing: if not None then it will be overwritten with whatever is in there. None is default (no overwrite)
    :param mode: 'normal', 'fast' or 'fastest'
    :param overwrite_all_in_gpu: if not None, overrides the per-mode all_in_gpu default
    :return: whatever the selected predict_cases* variant returns
    """
    maybe_mkdir_p(output_folder)
    # BUGFIX: assert that plans.pkl exists BEFORE copying it. Previously shutil.copy ran first,
    # so a missing plans.pkl surfaced as an uninformative FileNotFoundError instead of this
    # assertion message.
    assert isfile(join(model, "plans.pkl")), "Folder with saved model weights must contain a plans.pkl file"
    shutil.copy(join(model, 'plans.pkl'), output_folder)
    expected_num_modalities = load_pickle(join(model, "plans.pkl"))['num_modalities']

    # check input folder integrity (raises if modalities are missing)
    case_ids = check_input_folder_and_return_caseIDs(input_folder, expected_num_modalities)

    output_files = [join(output_folder, i + ".nii.gz") for i in case_ids]
    all_files = subfiles(input_folder, suffix=".nii.gz", join=False, sort=True)
    # for each case, collect exactly the files "<case>_XXXX.nii.gz" (len(case) + 12 chars total)
    list_of_lists = [[join(input_folder, i) for i in all_files if i[:len(j)].startswith(j) and
                      len(i) == (len(j) + 12)] for j in case_ids]

    if lowres_segmentations is not None:
        assert isdir(lowres_segmentations), "if lowres_segmentations is not None then it must point to a directory"
        lowres_segmentations = [join(lowres_segmentations, i + ".nii.gz") for i in case_ids]
        assert all([isfile(i) for i in lowres_segmentations]), "not all lowres_segmentations files are present. " \
                                                              "(I was searching for case_id.nii.gz in that folder)"
        # slice to this worker's share, consistent with list_of_lists/output_files below
        lowres_segmentations = lowres_segmentations[part_id::num_parts]
    else:
        lowres_segmentations = None

    if mode == "normal":
        if overwrite_all_in_gpu is None:
            all_in_gpu = False
        else:
            all_in_gpu = overwrite_all_in_gpu

        return predict_cases(model, list_of_lists[part_id::num_parts], output_files[part_id::num_parts], folds,
                             save_npz, num_threads_preprocessing, num_threads_nifti_save, lowres_segmentations, tta,
                             mixed_precision=mixed_precision, overwrite_existing=overwrite_existing, all_in_gpu=all_in_gpu,
                             step_size=step_size, checkpoint_name=checkpoint_name,
                             segmentation_export_kwargs=segmentation_export_kwargs)
    elif mode == "fast":
        if overwrite_all_in_gpu is None:
            all_in_gpu = True
        else:
            all_in_gpu = overwrite_all_in_gpu

        # the fast variants do not produce softmax output, so save_npz is unsupported
        assert save_npz is False
        return predict_cases_fast(model, list_of_lists[part_id::num_parts], output_files[part_id::num_parts], folds,
                                  num_threads_preprocessing, num_threads_nifti_save, lowres_segmentations,
                                  tta, mixed_precision=mixed_precision, overwrite_existing=overwrite_existing, all_in_gpu=all_in_gpu,
                                  step_size=step_size, checkpoint_name=checkpoint_name,
                                  segmentation_export_kwargs=segmentation_export_kwargs)
    elif mode == "fastest":
        if overwrite_all_in_gpu is None:
            all_in_gpu = True
        else:
            all_in_gpu = overwrite_all_in_gpu

        assert save_npz is False
        return predict_cases_fastest(model, list_of_lists[part_id::num_parts], output_files[part_id::num_parts], folds,
                                     num_threads_preprocessing, num_threads_nifti_save, lowres_segmentations,
                                     tta, mixed_precision=mixed_precision, overwrite_existing=overwrite_existing, all_in_gpu=all_in_gpu,
                                     step_size=step_size, checkpoint_name=checkpoint_name)
    else:
        raise ValueError("unrecognized mode. Must be normal, fast or fastest")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-i", '--input_folder', help="Must contain all modalities for each patient in the correct"
" order (same as training). Files must be named "
"CASENAME_XXXX.nii.gz where XXXX is the modality "
"identifier (0000, 0001, etc)", required=True)
parser.add_argument('-o', "--output_folder", required=True, help="folder for saving predictions")
parser.add_argument('-m', '--model_output_folder',
help='model output folder. Will automatically discover the folds '
'that were '
'run and use those as an ensemble', required=True)
parser.add_argument('-f', '--folds', nargs='+', default='None', help="folds to use for prediction. Default is None "
"which means that folds will be detected "
"automatically in the model output folder")
parser.add_argument('-z', '--save_npz', required=False, action='store_true', help="use this if you want to ensemble"
" these predictions with those of"
" other models. Softmax "
"probabilities will be saved as "
"compresed numpy arrays in "
"output_folder and can be merged "
"between output_folders with "
"merge_predictions.py")
parser.add_argument('-l', '--lowres_segmentations', required=False, default='None', help="if model is the highres "
"stage of the cascade then you need to use -l to specify where the segmentations of the "
"corresponding lowres unet are. Here they are required to do a prediction")
parser.add_argument("--part_id", type=int, required=False, default=0, help="Used to parallelize the prediction of "
"the folder over several GPUs. If you "
"want to use n GPUs to predict this "
"folder you need to run this command "
"n times with --part_id=0, ... n-1 and "
"--num_parts=n (each with a different "
"GPU (for example via "
"CUDA_VISIBLE_DEVICES=X)")
parser.add_argument("--num_parts", type=int, required=False, default=1,
help="Used to parallelize the prediction of "
"the folder over several GPUs. If you "
"want to use n GPUs to predict this "
"folder you need to run this command "
"n times with --part_id=0, ... n-1 and "
"--num_parts=n (each with a different "
"GPU (via "
"CUDA_VISIBLE_DEVICES=X)")
parser.add_argument("--num_threads_preprocessing", required=False, default=6, type=int, help=
"Determines many background processes will be used for data preprocessing. Reduce this if you "
"run into out of memory (RAM) problems. Default: 6")
parser.add_argument("--num_threads_nifti_save", required=False, default=2, type=int, help=
"Determines many background processes will be used for segmentation export. Reduce this if you "
"run into out of memory (RAM) problems. Default: 2")
parser.add_argument("--tta", required=False, type=int, default=1, help="Set to 0 to disable test time data "
"augmentation (speedup of factor "
"4(2D)/8(3D)), "
"lower quality segmentations")
parser.add_argument("--overwrite_existing", required=False, type=int, default=1, help="Set this to 0 if you need "
"to resume a previous "
"prediction. Default: 1 "
"(=existing segmentations "
"in output_folder will be "
"overwritten)")
parser.add_argument("--mode", type=str, default="normal", required=False)
parser.add_argument("--all_in_gpu", type=str, default="None", required=False, help="can be None, False or True")
parser.add_argument("--step_size", type=float, default=0.5, required=False, help="don't touch")
# parser.add_argument("--interp_order", required=False, default=3, type=int,
# help="order of interpolation for segmentations, has no effect if mode=fastest")
# parser.add_argument("--interp_order_z", required=False, default=0, type=int,
# help="order of interpolation along z is z is done differently")
# parser.add_argument("--force_separate_z", required=False, default="None", type=str,
# help="force_separate_z resampling. Can be None, True or False, has no effect if mode=fastest")
parser.add_argument('--disable_mixed_precision', default=False, action='store_true', required=False,
help='Predictions are done with mixed precision by default. This improves speed and reduces '
'the required vram. If you want to disable mixed precision you can set this flag. Note '
'that yhis is not recommended (mixed precision is ~2x faster!)')
args = parser.parse_args()
input_folder = args.input_folder
output_folder = args.output_folder
part_id = args.part_id
num_parts = args.num_parts
model = args.model_output_folder
folds = args.folds
save_npz = args.save_npz
lowres_segmentations = args.lowres_segmentations
num_threads_preprocessing = args.num_threads_preprocessing
num_threads_nifti_save = args.num_threads_nifti_save
tta = args.tta
step_size = args.step_size
# interp_order = args.interp_order
# interp_order_z = args.interp_order_z
# force_separate_z = args.force_separate_z
# if force_separate_z == "None":
# force_separate_z = None
# elif force_separate_z == "False":
# force_separate_z = False
# elif force_separate_z == "True":
# force_separate_z = True
# else:
# raise ValueError("force_separate_z must be None, True or False. Given: %s" % force_separate_z)
overwrite = args.overwrite_existing
mode = args.mode
all_in_gpu = args.all_in_gpu
if lowres_segmentations == "None":
lowres_segmentations = None
if isinstance(folds, list):
if folds[0] == 'all' and len(folds) == 1:
pass
else:
folds = [int(i) for i in folds]
elif folds == "None":
folds = None
else:
raise ValueError("Unexpected value for argument folds")
if tta == 0:
tta = False
elif tta == 1:
tta = True
else:
raise ValueError("Unexpected value for tta, Use 1 or 0")
if overwrite == 0:
overwrite = False
elif overwrite == 1:
overwrite = True
else:
raise ValueError("Unexpected value for overwrite, Use 1 or 0")
assert all_in_gpu in ['None', 'False', 'True']
if all_in_gpu == "None":
all_in_gpu = None
elif all_in_gpu == "True":
all_in_gpu = True
elif all_in_gpu == "False":
all_in_gpu = False
predict_from_folder(model, input_folder, output_folder, folds, save_npz, num_threads_preprocessing,
num_threads_nifti_save, lowres_segmentations, part_id, num_parts, tta, mixed_precision=not args.disable_mixed_precision,
overwrite_existing=overwrite, mode=mode, overwrite_all_in_gpu=all_in_gpu, step_size=step_size)
================================================
FILE: unetr_pp/inference/predict_simple.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
from unetr_pp.inference.predict import predict_from_folder
from unetr_pp.paths import default_plans_identifier, network_training_output_dir, default_trainer
from batchgenerators.utilities.file_and_folder_operations import join, isdir
from unetr_pp.utilities.task_name_id_conversion import convert_id_to_task_name
from unetr_pp.paths import network_training_output_dir, preprocessing_output_dir, default_plans_identifier
from unetr_pp.run.default_configuration import get_default_configuration
def main():
    """
    CLI entry point (predict_simple): resolves the task name and model folder from the nnFormer
    directory layout, optionally runs the 3d_lowres stage of a cascade first, ensures plans.pkl
    is present next to the checkpoints, then calls predict_from_folder.
    All argument help strings are user-visible and kept exactly as-is.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", '--input_folder', help="Must contain all modalities for each patient in the correct"
                                                     " order (same as training). Files must be named "
                                                     "CASENAME_XXXX.nii.gz where XXXX is the modality "
                                                     "identifier (0000, 0001, etc)", required=True)
    parser.add_argument('-o', "--output_folder", required=True, help="folder for saving predictions")
    parser.add_argument('-t', '--task_name', help='task name or task ID, required.',
                        default=default_plans_identifier, required=True)
    parser.add_argument('-tr', '--trainer_class_name',
                        help='Name of the nnFormerTrainer used for 2D U-Net, full resolution 3D U-Net and low resolution '
                             'U-Net. The default is %s. If you are running inference with the cascade and the folder '
                             'pointed to by --lowres_segmentations does not contain the segmentation maps generated by '
                             'the low resolution U-Net then the low resolution segmentation maps will be automatically '
                             'generated. For this case, make sure to set the trainer class here that matches your '
                             '--cascade_trainer_class_name (this part can be ignored if defaults are used).'
                             % default_trainer,
                        required=False,
                        default=default_trainer)
    parser.add_argument('-ctr', '--cascade_trainer_class_name',
                        help="Trainer class name used for predicting the 3D full resolution U-Net part of the cascade."
                             "Default is %s" % default_trainer, required=False,
                        default=default_trainer)
    parser.add_argument('-m', '--model', help="2d, 3d_lowres, 3d_fullres or 3d_cascade_fullres. Default: 3d_fullres",
                        default="3d_fullres", required=False)
    parser.add_argument('-p', '--plans_identifier', help='do not touch this unless you know what you are doing',
                        default=default_plans_identifier, required=False)
    parser.add_argument('-f', '--folds', nargs='+', default='None',
                        help="folds to use for prediction. Default is None which means that folds will be detected "
                             "automatically in the model output folder")
    parser.add_argument('-z', '--save_npz', required=False, action='store_true',
                        help="use this if you want to ensemble these predictions with those of other models. Softmax "
                             "probabilities will be saved as compressed numpy arrays in output_folder and can be "
                             "merged between output_folders with nnFormer_ensemble_predictions")
    parser.add_argument('-l', '--lowres_segmentations', required=False, default='None',
                        help="if model is the highres stage of the cascade then you can use this folder to provide "
                             "predictions from the low resolution 3D U-Net. If this is left at default, the "
                             "predictions will be generated automatically (provided that the 3D low resolution U-Net "
                             "network weights are present")
    parser.add_argument("--part_id", type=int, required=False, default=0, help="Used to parallelize the prediction of "
                                                                               "the folder over several GPUs. If you "
                                                                               "want to use n GPUs to predict this "
                                                                               "folder you need to run this command "
                                                                               "n times with --part_id=0, ... n-1 and "
                                                                               "--num_parts=n (each with a different "
                                                                               "GPU (for example via "
                                                                               "CUDA_VISIBLE_DEVICES=X)")
    parser.add_argument("--num_parts", type=int, required=False, default=1,
                        help="Used to parallelize the prediction of "
                             "the folder over several GPUs. If you "
                             "want to use n GPUs to predict this "
                             "folder you need to run this command "
                             "n times with --part_id=0, ... n-1 and "
                             "--num_parts=n (each with a different "
                             "GPU (via "
                             "CUDA_VISIBLE_DEVICES=X)")
    parser.add_argument("--num_threads_preprocessing", required=False, default=6, type=int, help=
    "Determines many background processes will be used for data preprocessing. Reduce this if you "
    "run into out of memory (RAM) problems. Default: 6")
    parser.add_argument("--num_threads_nifti_save", required=False, default=2, type=int, help=
    "Determines many background processes will be used for segmentation export. Reduce this if you "
    "run into out of memory (RAM) problems. Default: 2")
    parser.add_argument("--disable_tta", required=False, default=False, action="store_true",
                        help="set this flag to disable test time data augmentation via mirroring. Speeds up inference "
                             "by roughly factor 4 (2D) or 8 (3D)")
    parser.add_argument("--overwrite_existing", required=False, default=False, action="store_true",
                        help="Set this flag if the target folder contains predictions that you would like to overwrite")
    parser.add_argument("--mode", type=str, default="normal", required=False, help="Hands off!")
    parser.add_argument("--all_in_gpu", type=str, default="None", required=False, help="can be None, False or True. "
                                                                                       "Do not touch.")
    parser.add_argument("--step_size", type=float, default=0.5, required=False, help="don't touch")
    # parser.add_argument("--interp_order", required=False, default=3, type=int,
    #                     help="order of interpolation for segmentations, has no effect if mode=fastest. Do not touch this.")
    # parser.add_argument("--interp_order_z", required=False, default=0, type=int,
    #                     help="order of interpolation along z is z is done differently. Do not touch this.")
    # parser.add_argument("--force_separate_z", required=False, default="None", type=str,
    #                     help="force_separate_z resampling. Can be None, True or False, has no effect if mode=fastest. "
    #                          "Do not touch this.")
    parser.add_argument('-chk',
                        help='checkpoint name, default: model_final_checkpoint',
                        required=False,
                        default='model_final_checkpoint')
    parser.add_argument('--disable_mixed_precision', default=False, action='store_true', required=False,
                        help='Predictions are done with mixed precision by default. This improves speed and reduces '
                             'the required vram. If you want to disable mixed precision you can set this flag. Note '
                             'that yhis is not recommended (mixed precision is ~2x faster!)')

    args = parser.parse_args()
    input_folder = args.input_folder
    output_folder = args.output_folder
    part_id = args.part_id
    num_parts = args.num_parts
    folds = args.folds
    save_npz = args.save_npz
    lowres_segmentations = args.lowres_segmentations
    num_threads_preprocessing = args.num_threads_preprocessing
    num_threads_nifti_save = args.num_threads_nifti_save
    disable_tta = args.disable_tta
    step_size = args.step_size
    # interp_order = args.interp_order
    # interp_order_z = args.interp_order_z
    # force_separate_z = args.force_separate_z
    overwrite_existing = args.overwrite_existing
    mode = args.mode
    all_in_gpu = args.all_in_gpu
    model = args.model
    trainer_class_name = args.trainer_class_name
    cascade_trainer_class_name = args.cascade_trainer_class_name

    task_name = args.task_name

    # a numeric task id is translated to its canonical "TaskXXX_name" folder name
    if not task_name.startswith("Task"):
        task_id = int(task_name)
        task_name = convert_id_to_task_name(task_id)

    assert model in ["2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"], "-m must be 2d, 3d_lowres, 3d_fullres or " \
                                                                             "3d_cascade_fullres"

    # if force_separate_z == "None":
    #     force_separate_z = None
    # elif force_separate_z == "False":
    #     force_separate_z = False
    # elif force_separate_z == "True":
    #     force_separate_z = True
    # else:
    #     raise ValueError("force_separate_z must be None, True or False. Given: %s" % force_separate_z)

    if lowres_segmentations == "None":
        lowres_segmentations = None

    # folds: either the literal 'all', a list of ints, or None (auto-detect)
    if isinstance(folds, list):
        if folds[0] == 'all' and len(folds) == 1:
            pass
        else:
            folds = [int(i) for i in folds]
    elif folds == "None":
        folds = None
    else:
        raise ValueError("Unexpected value for argument folds")

    # all_in_gpu is a tri-state string: None means "use the per-mode default"
    assert all_in_gpu in ['None', 'False', 'True']
    if all_in_gpu == "None":
        all_in_gpu = None
    elif all_in_gpu == "True":
        all_in_gpu = True
    elif all_in_gpu == "False":
        all_in_gpu = False

    # we need to catch the case where model is 3d cascade fullres and the low resolution folder has not been set.
    # In that case we need to try and predict with 3d low res first
    if model == "3d_cascade_fullres" and lowres_segmentations is None:
        print("lowres_segmentations is None. Attempting to predict 3d_lowres first...")
        assert part_id == 0 and num_parts == 1, "if you don't specify a --lowres_segmentations folder for the " \
                                                "inference of the cascade, custom values for part_id and num_parts " \
                                                "are not supported. If you wish to have multiple parts, please " \
                                                "run the 3d_lowres inference first (separately)"
        model_folder_name = join(network_training_output_dir, "3d_lowres", task_name, trainer_class_name + "__" +
                                 args.plans_identifier)
        assert isdir(model_folder_name), "model output folder not found. Expected: %s" % model_folder_name
        lowres_output_folder = join(output_folder, "3d_lowres_predictions")
        # first stage: no npz saving, no previous-stage segs
        predict_from_folder(model_folder_name, input_folder, lowres_output_folder, folds, False,
                            num_threads_preprocessing, num_threads_nifti_save, None, part_id, num_parts, not disable_tta,
                            overwrite_existing=overwrite_existing, mode=mode, overwrite_all_in_gpu=all_in_gpu,
                            mixed_precision=not args.disable_mixed_precision,
                            step_size=step_size)
        # feed the lowres predictions into the fullres stage below
        lowres_segmentations = lowres_output_folder
        torch.cuda.empty_cache()
        print("3d_lowres done")

    # the cascade's fullres stage may use a different trainer class than the other models
    if model == "3d_cascade_fullres":
        trainer = cascade_trainer_class_name
    else:
        trainer = trainer_class_name

    model_folder_name = join(network_training_output_dir, model, task_name, trainer + "__" +
                             args.plans_identifier)
    print("using model stored in ", model_folder_name)
    assert isdir(model_folder_name), "model output folder not found. Expected: %s" % model_folder_name

    import os
    import shutil
    # if the checkpoint folder lacks plans.pkl, copy it over from the preprocessed data folder
    # (predict_from_folder requires it next to the checkpoints)
    if not os.path.exists(join(model_folder_name, "plans.pkl")):
        plans_file = join(preprocessing_output_dir, task_name, args.plans_identifier + "_plans_3D.pkl")
        shutil.copy(plans_file, join(model_folder_name, "plans.pkl"))
    # side effect only; presumably validates that the configuration is resolvable — TODO confirm
    _ = get_default_configuration(model, task_name, trainer_class_name, args.plans_identifier)

    predict_from_folder(model_folder_name, input_folder, output_folder, folds, save_npz, num_threads_preprocessing,
                        num_threads_nifti_save, lowres_segmentations, part_id, num_parts, not disable_tta,
                        overwrite_existing=overwrite_existing, mode=mode, overwrite_all_in_gpu=all_in_gpu,
                        mixed_precision=not args.disable_mixed_precision,
                        step_size=step_size, checkpoint_name=args.chk)
# Script entry point: parse CLI arguments and run inference (see main() above).
if __name__ == "__main__":
    main()
================================================
FILE: unetr_pp/inference/segmentation_export.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from copy import deepcopy
from typing import Union, Tuple
import numpy as np
import SimpleITK as sitk
from batchgenerators.augmentations.utils import resize_segmentation
from unetr_pp.preprocessing.preprocessing import get_lowres_axis, get_do_separate_z, resample_data_or_seg
from batchgenerators.utilities.file_and_folder_operations import *
def save_segmentation_nifti_from_softmax(segmentation_softmax: Union[str, np.ndarray], out_fname: str,
                                         properties_dict: dict, order: int = 1,
                                         region_class_order: Tuple[Tuple[int]] = None,
                                         seg_postprogess_fn: callable = None, seg_postprocess_args: tuple = None,
                                         resampled_npz_fname: str = None,
                                         non_postprocessed_fname: str = None, force_separate_z: bool = None,
                                         interpolation_order_z: int = 0, verbose: bool = True):
    """
    Utility for writing a softmax prediction to nifti (and optionally npz). Requires the data to have been
    preprocessed by GenericPreprocessor because it relies on the properties dictionary to know the geometry of
    the original data. segmentation_softmax does not have to have the same pixel size as the original data; it
    will be resampled to match. This is generally useful because the spacings our networks operate on are most
    of the time not the native spacings of the image data.

    If seg_postprogess_fn is not None then seg_postprogess_fn(segmentation, *seg_postprocess_args)
    will be called before nifti export.

    segmentation_softmax may also be the filename of a .npy file instead of an array; in that case the file is
    loaded and then DELETED. This works around the ~2 GB limit of Python multiprocessing pipes: large arrays
    are handed between processes via a temp file rather than being pickled through the Pipe.

    :param segmentation_softmax: softmax array of shape (c, x, y, z), or path to a .npy file containing it
        (the file is removed after loading)
    :param out_fname: output nifti filename
    :param properties_dict: geometry info produced by the preprocessor (crop bbox, spacings, itk metadata)
    :param order: interpolation order for resampling
    :param region_class_order: if not None, output labels are derived by thresholding each channel at 0.5
        (region-based / sigmoid training) instead of channel argmax
    :param seg_postprogess_fn: optional postprocessing callable applied to the segmentation before export
    :param seg_postprocess_args: extra positional args for seg_postprogess_fn
    :param resampled_npz_fname: if not None, the resampled softmax (float16) is also saved here (for ensembling)
    :param non_postprocessed_fname: if not None (and postprocessing is used), the raw segmentation is saved here
    :param force_separate_z: if None then we dynamically decide how to resample along z, if True/False then always
        /never resample along z separately. Do not touch unless you know what you are doing
    :param interpolation_order_z: if separate z resampling is done then this is the order for resampling in z
    :param verbose:
    :return:
    """
    if verbose: print("force_separate_z:", force_separate_z, "interpolation order:", order)
    if isinstance(segmentation_softmax, str):
        assert isfile(segmentation_softmax), "If isinstance(segmentation_softmax, str) then " \
                                             "isfile(segmentation_softmax) must be True"
        del_file = deepcopy(segmentation_softmax)
        segmentation_softmax = np.load(segmentation_softmax)
        os.remove(del_file)

    # first resample, then put result into bbox of cropping, then save
    current_shape = segmentation_softmax.shape
    shape_original_after_cropping = properties_dict.get('size_after_cropping')
    shape_original_before_cropping = properties_dict.get('original_size_of_raw_data')

    # compare spatial dims only (current_shape[0] is the channel axis)
    if np.any([i != j for i, j in zip(np.array(current_shape[1:]), np.array(shape_original_after_cropping))]):
        if force_separate_z is None:
            # decide from the spacings whether z is anisotropic enough to resample separately
            if get_do_separate_z(properties_dict.get('original_spacing')):
                do_separate_z = True
                lowres_axis = get_lowres_axis(properties_dict.get('original_spacing'))
            elif get_do_separate_z(properties_dict.get('spacing_after_resampling')):
                do_separate_z = True
                lowres_axis = get_lowres_axis(properties_dict.get('spacing_after_resampling'))
            else:
                do_separate_z = False
                lowres_axis = None
        else:
            do_separate_z = force_separate_z
            if do_separate_z:
                lowres_axis = get_lowres_axis(properties_dict.get('original_spacing'))
            else:
                lowres_axis = None

        if lowres_axis is not None and len(lowres_axis) != 1:
            # this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample
            # separately in the out of plane axis
            do_separate_z = False

        if verbose: print("separate z:", do_separate_z, "lowres axis", lowres_axis)
        seg_old_spacing = resample_data_or_seg(segmentation_softmax, shape_original_after_cropping, is_seg=False,
                                               axis=lowres_axis, order=order, do_separate_z=do_separate_z, cval=0,
                                               order_z=interpolation_order_z)
    else:
        if verbose: print("no resampling necessary")
        seg_old_spacing = segmentation_softmax

    if resampled_npz_fname is not None:
        # save the resampled softmax as float16 to save disk space (used later for ensembling)
        np.savez_compressed(resampled_npz_fname, softmax=seg_old_spacing.astype(np.float16))
        # this is needed for ensembling if the nonlinearity is sigmoid
        if region_class_order is not None:
            properties_dict['regions_class_order'] = region_class_order
        save_pickle(properties_dict, resampled_npz_fname[:-4] + ".pkl")

    if region_class_order is None:
        # standard softmax: hard labels via argmax over the channel axis
        seg_old_spacing = seg_old_spacing.argmax(0)
    else:
        # region-based (sigmoid) output: threshold each channel at 0.5 and paint labels in the given order,
        # later entries overwrite earlier ones
        seg_old_spacing_final = np.zeros(seg_old_spacing.shape[1:])
        for i, c in enumerate(region_class_order):
            seg_old_spacing_final[seg_old_spacing[i] > 0.5] = c
        seg_old_spacing = seg_old_spacing_final

    bbox = properties_dict.get('crop_bbox')
    if bbox is not None:
        # paste the segmentation back into the uncropped original volume
        # NOTE(review): the bbox upper bounds are clipped in place, which mutates
        # properties_dict['crop_bbox'] for the caller
        seg_old_size = np.zeros(shape_original_before_cropping)
        for c in range(3):
            bbox[c][1] = np.min((bbox[c][0] + seg_old_spacing.shape[c], shape_original_before_cropping[c]))
        seg_old_size[bbox[0][0]:bbox[0][1],
                     bbox[1][0]:bbox[1][1],
                     bbox[2][0]:bbox[2][1]] = seg_old_spacing
    else:
        seg_old_size = seg_old_spacing

    if seg_postprogess_fn is not None:
        seg_old_size_postprocessed = seg_postprogess_fn(np.copy(seg_old_size), *seg_postprocess_args)
    else:
        seg_old_size_postprocessed = seg_old_size

    seg_resized_itk = sitk.GetImageFromArray(seg_old_size_postprocessed.astype(np.uint8))
    seg_resized_itk.SetSpacing(properties_dict['itk_spacing'])
    seg_resized_itk.SetOrigin(properties_dict['itk_origin'])
    seg_resized_itk.SetDirection(properties_dict['itk_direction'])
    sitk.WriteImage(seg_resized_itk, out_fname)

    if (non_postprocessed_fname is not None) and (seg_postprogess_fn is not None):
        # also export the segmentation without postprocessing applied
        seg_resized_itk = sitk.GetImageFromArray(seg_old_size.astype(np.uint8))
        seg_resized_itk.SetSpacing(properties_dict['itk_spacing'])
        seg_resized_itk.SetOrigin(properties_dict['itk_origin'])
        seg_resized_itk.SetDirection(properties_dict['itk_direction'])
        sitk.WriteImage(seg_resized_itk, non_postprocessed_fname)
def save_segmentation_nifti(segmentation, out_fname, dct, order=1, force_separate_z=None, order_z=0):
    """
    Faster and uses less RAM than save_segmentation_nifti_from_softmax, but maybe less precise and also does not
    support softmax export (which is needed for ensembling). So it's a niche function that may be useful in some
    cases.

    :param segmentation: integer segmentation array, or path to a .npy file containing it
        (the file is deleted after loading)
    :param out_fname: output nifti filename
    :param dct: properties dict produced by the preprocessor (crop bbox, spacings, itk metadata)
    :param order: interpolation order for resampling (0 uses resize_segmentation directly)
    :param force_separate_z: if None we decide dynamically whether to resample along z separately,
        if True/False we always/never do
    :param order_z: interpolation order along z when separate z resampling is done
    :return:
    """
    print("force_separate_z:", force_separate_z, "interpolation order:", order)
    # Suppress all prints emitted by the resampling code below.
    # FIX: the original rebound sys.stdout without try/finally, so any exception left stdout
    # pointing at devnull for the rest of the process; the devnull handle was also never closed.
    devnull = open(os.devnull, 'w')
    sys.stdout = devnull
    try:
        if isinstance(segmentation, str):
            assert isfile(segmentation), "If isinstance(segmentation_softmax, str) then " \
                                         "isfile(segmentation_softmax) must be True"
            del_file = deepcopy(segmentation)
            segmentation = np.load(segmentation)
            os.remove(del_file)

        # first resample, then put result into bbox of cropping, then save
        current_shape = segmentation.shape
        shape_original_after_cropping = dct.get('size_after_cropping')
        shape_original_before_cropping = dct.get('original_size_of_raw_data')

        if np.any(np.array(current_shape) != np.array(shape_original_after_cropping)):
            if order == 0:
                seg_old_spacing = resize_segmentation(segmentation, shape_original_after_cropping, 0, 0)
            else:
                if force_separate_z is None:
                    # decide from the spacings whether z is anisotropic enough to resample separately
                    if get_do_separate_z(dct.get('original_spacing')):
                        do_separate_z = True
                        lowres_axis = get_lowres_axis(dct.get('original_spacing'))
                    elif get_do_separate_z(dct.get('spacing_after_resampling')):
                        do_separate_z = True
                        lowres_axis = get_lowres_axis(dct.get('spacing_after_resampling'))
                    else:
                        do_separate_z = False
                        lowres_axis = None
                else:
                    do_separate_z = force_separate_z
                    if do_separate_z:
                        lowres_axis = get_lowres_axis(dct.get('original_spacing'))
                    else:
                        lowres_axis = None

                print("separate z:", do_separate_z, "lowres axis", lowres_axis)
                seg_old_spacing = resample_data_or_seg(segmentation[None], shape_original_after_cropping,
                                                       is_seg=True, axis=lowres_axis, order=order,
                                                       do_separate_z=do_separate_z, cval=0, order_z=order_z)[0]
        else:
            seg_old_spacing = segmentation

        bbox = dct.get('crop_bbox')
        if bbox is not None:
            # paste the segmentation back into the uncropped original volume
            seg_old_size = np.zeros(shape_original_before_cropping)
            for c in range(3):
                bbox[c][1] = np.min((bbox[c][0] + seg_old_spacing.shape[c], shape_original_before_cropping[c]))
            seg_old_size[bbox[0][0]:bbox[0][1],
                         bbox[1][0]:bbox[1][1],
                         bbox[2][0]:bbox[2][1]] = seg_old_spacing
        else:
            seg_old_size = seg_old_spacing

        seg_resized_itk = sitk.GetImageFromArray(seg_old_size.astype(np.uint8))
        seg_resized_itk.SetSpacing(dct['itk_spacing'])
        seg_resized_itk.SetOrigin(dct['itk_origin'])
        seg_resized_itk.SetDirection(dct['itk_direction'])
        sitk.WriteImage(seg_resized_itk, out_fname)
    finally:
        # always restore stdout, even if resampling/export raised
        sys.stdout = sys.__stdout__
        devnull.close()
================================================
FILE: unetr_pp/inference_acdc.py
================================================
import glob
import os
import SimpleITK as sitk
import numpy as np
from medpy.metric import binary
from sklearn.neighbors import KDTree
from scipy import ndimage
import argparse
def read_nii(path):
    """Load a NIfTI file; return (voxel array, voxel spacing)."""
    image = sitk.ReadImage(path)
    spacing = np.array(image.GetSpacing())
    return sitk.GetArrayFromImage(image), spacing
def dice(pred, label):
    """Binary Dice coefficient; defined as 1 when both masks are empty."""
    denom = pred.sum() + label.sum()
    if denom == 0:
        return 1
    return 2. * np.logical_and(pred, label).sum() / denom
def process_label(label):
    """Split an ACDC label map into boolean masks.

    Returns (rv, myo, lv): right ventricle (label 1), myocardium (label 2),
    left ventricle (label 3).
    """
    return label == 1, label == 2, label == 3
'''
def hd(pred,gt):
pred[pred > 0] = 1
gt[gt > 0] = 1
if pred.sum() > 0 and gt.sum()>0:
dice = binary.dc(pred, gt)
hd95 = binary.hd95(pred, gt)
return dice, hd95
elif pred.sum() > 0 and gt.sum()==0:
return 1, 0
else:
return 0, 0
'''
def hd(pred, gt):
    """95th-percentile Hausdorff distance between two binary masks.

    Returns 0 when either mask is empty (hd95 is undefined in that case).
    """
    if pred.sum() > 0 and gt.sum() > 0:
        distance = binary.hd95(pred, gt)
        print(distance)
        return distance
    return 0
def test(fold):
    """
    Evaluate ACDC predictions against ground-truth labels.

    Computes per-case and mean Dice and HD95 for RV / MYO / LV and writes a
    report to <pred_path>/inferTs/<fold>/dice_pre.txt.

    :param fold: fold name; only used to name the results sub-directory
    """
    label_path = './'
    pred_path = '/'
    label_list = sorted(glob.glob(os.path.join(label_path, '*nii.gz')))
    infer_list = sorted(glob.glob(os.path.join(pred_path, '*nii.gz')))
    print("loading success...")
    print(label_list)
    print(infer_list)

    structures = ('rv', 'myo', 'lv')
    dices = {s: [] for s in structures}
    hds = {s: [] for s in structures}

    # BUG FIX: the original used the undefined name `path` here, which raised a
    # NameError. Presumably the prediction root was intended (matching the
    # synapse/tumor scripts) -- confirm against how this script is invoked.
    out_dir = os.path.join(pred_path, 'inferTs', fold)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # FIX: the report file is now closed deterministically via the context manager
    with open(os.path.join(out_dir, 'dice_pre.txt'), 'w') as fw:
        for lab_file, inf_file in zip(label_list, infer_list):
            print(lab_file.split('/')[-1])
            print(inf_file.split('/')[-1])
            label, _ = read_nii(lab_file)
            infer, _ = read_nii(inf_file)
            label_masks = process_label(label)
            infer_masks = process_label(infer)
            for s, lm, im in zip(structures, label_masks, infer_masks):
                dices[s].append(dice(im, lm))
            for s, lm, im in zip(structures, label_masks, infer_masks):
                hds[s].append(hd(im, lm))

            # per-case report (same layout as the original: HD block, then Dice+HD block)
            fw.write('*' * 20 + '\n')
            fw.write(inf_file.split('/')[-1] + '\n')
            for s in structures:
                fw.write('hd_{}: {:.4f}\n'.format(s, hds[s][-1]))
            fw.write('*' * 20 + '\n')
            fw.write(inf_file.split('/')[-1] + '\n')
            for s in structures:
                fw.write('Dice_{}: {:.4f}\n'.format(s, dices[s][-1]))
            for s in structures:
                fw.write('hd_{}: {:.4f}\n'.format(s, hds[s][-1]))
            fw.write('*' * 20 + '\n')

        # summary over all cases
        fw.write('*' * 20 + '\n')
        fw.write('Mean_Dice\n')
        for s in structures:
            fw.write('Dice_' + s + str(np.mean(dices[s])) + '\n')
        fw.write('Mean_HD\n')
        for s in structures:
            fw.write('HD_' + s + str(np.mean(hds[s])) + '\n')
        fw.write('*' * 20 + '\n')

        mean_dice = [np.mean(dices[s]) for s in structures]
        mean_hd = [np.mean(hds[s]) for s in structures]
        fw.write('avg_hd:' + str(np.mean(mean_hd)) + '\n')
        fw.write('DSC:' + str(np.mean(mean_dice)) + '\n')
        fw.write('HD:' + str(np.mean(mean_hd)) + '\n')
    print('done')
if __name__ == '__main__':
    # CLI entry point: the single positional argument names the fold whose
    # predictions are evaluated (also used as the results sub-directory name).
    parser = argparse.ArgumentParser()
    parser.add_argument("fold", help="fold name")
    args = parser.parse_args()
    fold = args.fold
    test(fold)
================================================
FILE: unetr_pp/inference_synapse.py
================================================
import glob
import os
import SimpleITK as sitk
import numpy as np
import argparse
from medpy import metric
def read_nii(path):
    """Read a NIfTI file and return its voxel data as a numpy array."""
    image = sitk.ReadImage(path)
    return sitk.GetArrayFromImage(image)
def dice(pred, label):
    """Binary Dice coefficient, with the empty/empty case defined as 1."""
    total = pred.sum() + label.sum()
    if total == 0:
        return 1
    overlap = np.logical_and(pred, label).sum()
    return 2. * overlap / total
def hd(pred, gt):
    """95th-percentile Hausdorff distance; 0 when either mask is empty."""
    if pred.sum() == 0 or gt.sum() == 0:
        return 0
    return metric.binary.hd95(pred, gt)
def process_label(label):
    """Split a Synapse label map into 8 boolean organ masks.

    Returns (spleen, right_kidney, left_kidney, gallbladder, liver, stomach,
    aorta, pancreas) for label ids (1, 2, 3, 4, 6, 7, 8, 11).
    """
    organ_ids = (1, 2, 3, 4, 6, 7, 8, 11)
    return tuple(label == i for i in organ_ids)
def test(fold):
    """
    Evaluate Synapse multi-organ predictions against ground-truth labels.

    Computes per-case and mean Dice and HD95 for the 8 evaluated organs and
    appends a report to <infer_path>/inferTs/<fold>/dice_pre.txt.

    :param fold: fold name; only used to name the results sub-directory
    """
    label_path = None  # Replace None by full path of "DATASET/unetr_pp_raw/unetr_pp_raw_data/Task002_Synapse/"
    infer_path = None  # Replace None by full path of "output_synapse"
    label_list = sorted(glob.glob(os.path.join(label_path, 'labelsTs', '*nii.gz')))
    infer_list = sorted(glob.glob(os.path.join(infer_path, 'inferTs', '*nii.gz')))
    print("loading success...")
    print(label_list)
    print(infer_list)

    organs = ('spleen', 'right_kidney', 'left_kidney', 'gallbladder',
              'liver', 'stomach', 'aorta', 'pancreas')
    dices = {o: [] for o in organs}
    hds = {o: [] for o in organs}

    # FIX: use os.path.join instead of string concatenation; the original
    # `infer_path + 'inferTs/' + fold` produced ".../output_synapseinferTs/<fold>"
    # whenever infer_path had no trailing slash.
    out_dir = os.path.join(infer_path, 'inferTs', fold)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # FIX: close the report file deterministically; loop variables no longer
    # shadow infer_path (the original reused it as the loop variable).
    with open(os.path.join(out_dir, 'dice_pre.txt'), 'a') as fw:
        for lab_file, inf_file in zip(label_list, infer_list):
            print(lab_file.split('/')[-1])
            print(inf_file.split('/')[-1])
            label, infer = read_nii(lab_file), read_nii(inf_file)
            label_masks = process_label(label)
            infer_masks = process_label(infer)
            for o, lm, im in zip(organs, label_masks, infer_masks):
                dices[o].append(dice(im, lm))
            for o, lm, im in zip(organs, label_masks, infer_masks):
                hds[o].append(hd(im, lm))

            # per-case report (same layout/content as the original)
            fw.write('*' * 20 + '\n')
            fw.write(inf_file.split('/')[-1] + '\n')
            for o in organs:
                fw.write('Dice_{}: {:.4f}\n'.format(o, dices[o][-1]))
            for o in organs:
                fw.write('hd_{}: {:.4f}\n'.format(o, hds[o][-1]))
            case_dsc = [dices[o][-1] for o in organs]
            fw.write('DSC:' + str(np.mean(case_dsc)) + '\n')
            case_hd = [hds[o][-1] for o in organs]
            fw.write('hd:' + str(np.mean(case_hd)) + '\n')

        # summary over all cases
        fw.write('*' * 20 + '\n')
        fw.write('Mean_Dice\n')
        for o in organs:
            fw.write('Dice_' + o + str(np.mean(dices[o])) + '\n')
        fw.write('Mean_hd\n')
        for o in organs:
            fw.write('hd_' + o + str(np.mean(hds[o])) + '\n')
        fw.write('*' * 20 + '\n')

        mean_dsc = [np.mean(dices[o]) for o in organs]
        fw.write('dsc:' + str(np.mean(mean_dsc)) + '\n')
        mean_hd = [np.mean(hds[o]) for o in organs]
        fw.write('hd:' + str(np.mean(mean_hd)) + '\n')
    print('done')
if __name__ == '__main__':
    # CLI entry point: the single positional argument names the fold whose
    # predictions are evaluated (also used as the results sub-directory name).
    parser = argparse.ArgumentParser()
    parser.add_argument("fold", help="fold name")
    args = parser.parse_args()
    fold = args.fold
    test(fold)
================================================
FILE: unetr_pp/inference_tumor.py
================================================
import glob
import os
import SimpleITK as sitk
import numpy as np
import argparse
from medpy.metric import binary
def read_nii(path):
    """Load a NIfTI volume and return it as a numpy array."""
    itk_image = sitk.ReadImage(path)
    return sitk.GetArrayFromImage(itk_image)
def new_dice(pred, label):
    """Dice from hard TP/FP/FN counts for the positive (== 1) class.

    :param pred: predicted label array
    :param label: ground-truth label array
    :return: 2*TP / (2*TP + FP + FN); NaN if there are no positives in either array
    """
    # FIX: `np.float` was deprecated in NumPy 1.20 and removed in 1.24; the
    # builtin float gives the same float64 dtype.
    tp_hard = np.sum((pred == 1).astype(float) * (label == 1).astype(float))
    fp_hard = np.sum((pred == 1).astype(float) * (label != 1).astype(float))
    fn_hard = np.sum((pred != 1).astype(float) * (label == 1).astype(float))
    return 2 * tp_hard / (2 * tp_hard + fp_hard + fn_hard)
def dice(pred, label):
    """Binary Dice coefficient; returns 1 when both masks are empty."""
    if (pred.sum() + label.sum()) == 0:
        return 1
    intersection = np.logical_and(pred, label).sum()
    return 2. * intersection / (pred.sum() + label.sum())
def hd(pred, gt):
    """hd95 between two binary masks; 0 if either mask is empty."""
    if not (pred.sum() > 0 and gt.sum() > 0):
        return 0
    return binary.hd95(pred, gt)
def process_label(label):
    """Build BraTS-style tumor region masks from a label map.

    Label ids: 1 = edema (ED), 2 = necrotic/non-enhancing (NET), 3 = enhancing (ET).
    Returns (ET, TC, WT, ED, NET) where TC = NET|ET and WT = NET|ET|ED.
    """
    ed = label == 1
    net = label == 2
    et = label == 3
    tc = net | et
    wt = tc | ed
    return et, tc, wt, ed, net
def test(fold):
    """
    Evaluate brain-tumor predictions against ground-truth labels.

    Computes per-case and mean Dice and HD95 for ET / TC / WT / ED / NET and
    writes a report to <infer_path>/inferTs/<fold>/dice_five.txt.

    :param fold: fold name; only used to name the results sub-directory
    """
    # path='./'
    path = None  # Replace None by the full path of : unetr_plus_plus/DATASET_Tumor/unetr_pp_raw/unetr_pp_raw_data/Task03_tumor/"
    label_list = sorted(glob.glob(os.path.join(path, 'labelsTs', '*nii.gz')))
    infer_path = None  # Replace None by the full path of : unetr_plus_plus/unetr_pp/evaluation/unetr_pp_tumor_checkpoint/"
    infer_list = sorted(glob.glob(os.path.join(infer_path, 'inferTs', '*nii.gz')))
    print("loading success...")

    regions = ('et', 'tc', 'wt', 'ed', 'net')
    dices = {r: [] for r in regions}
    hds = {r: [] for r in regions}

    # FIX: use os.path.join instead of string concatenation; the original
    # `infer_path + 'inferTs/' + fold` broke when infer_path had no trailing slash.
    out_dir = os.path.join(infer_path, 'inferTs', fold)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # FIX: close the report file deterministically; loop variables no longer
    # shadow infer_path (the original reused it as the loop variable).
    with open(os.path.join(out_dir, 'dice_five.txt'), 'w') as fw:
        for lab_file, inf_file in zip(label_list, infer_list):
            print(lab_file.split('/')[-1])
            print(inf_file.split('/')[-1])
            label, infer = read_nii(lab_file), read_nii(inf_file)
            label_masks = process_label(label)
            infer_masks = process_label(infer)
            for r, lm, im in zip(regions, label_masks, infer_masks):
                dices[r].append(dice(im, lm))
            for r, lm, im in zip(regions, label_masks, infer_masks):
                hds[r].append(hd(im, lm))

            # per-case report (same layout/content as the original)
            fw.write('*' * 20 + '\n')
            fw.write(inf_file.split('/')[-1] + '\n')
            for r in regions:
                fw.write('hd_{}: {:.4f}\n'.format(r, hds[r][-1]))
            fw.write('*' * 20 + '\n')
            for r in regions:
                fw.write('Dice_{}: {:.4f}\n'.format(r, dices[r][-1]))

        # summary over all cases
        mean_dice = [np.mean(dices[r]) for r in regions]
        mean_hd = [np.mean(hds[r]) for r in regions]
        for r in regions:
            fw.write('Dice_' + r + str(np.mean(dices[r])) + ' ' + '\n')
        for r in regions:
            fw.write('HD_' + r + str(np.mean(hds[r])) + ' ' + '\n')
        fw.write('Dice' + str(np.mean(mean_dice)) + ' ' + '\n')
        fw.write('HD' + str(np.mean(mean_hd)) + ' ' + '\n')
if __name__ == '__main__':
    # CLI entry point: the single positional argument names the fold whose
    # predictions are evaluated (also used as the results sub-directory name).
    parser = argparse.ArgumentParser()
    parser.add_argument("fold", help="fold name")
    args = parser.parse_args()
    fold = args.fold
    test(fold)
================================================
FILE: unetr_pp/network_architecture/README.md
================================================
You can change the batch size and the input data size in the default configuration; see:
```
https://github.com/282857341/nnFormer/blob/6e36d76f9b7d0bea522e1cd05adf502ba85480e6/nnformer/run/default_configuration.py#L49-L68
```
================================================
FILE: unetr_pp/network_architecture/__init__.py
================================================
from __future__ import absolute_import
# NOTE(review): `from . import *` inside a package __init__ is unusual -- with no
# __all__ defined it re-imports little by itself; confirm it is intended.
from . import *
================================================
FILE: unetr_pp/network_architecture/acdc/__init__.py
================================================
================================================
FILE: unetr_pp/network_architecture/acdc/model_components.py
================================================
from torch import nn
from timm.models.layers import trunc_normal_
from typing import Sequence, Tuple, Union
from monai.networks.layers.utils import get_norm_layer
from monai.utils import optional_import
from unetr_pp.network_architecture.layers import LayerNorm
from unetr_pp.network_architecture.acdc.transformerblock import TransformerBlock
from unetr_pp.network_architecture.dynunet_block import get_conv_layer, UnetResBlock
einops, _ = optional_import("einops")
class UnetrPPEncoder(nn.Module):
    """UNETR++ encoder: a convolutional stem plus 3 strided-conv downsampling layers,
    each followed by a stage of EPA-based TransformerBlocks. forward() returns the
    last-stage features and the list of all 4 stage outputs (for decoder skips).
    """

    def __init__(self, input_size=(16 * 40 * 40, 8 * 20 * 20, 4 * 10 * 10, 2 * 5 * 5), dims=(32, 64, 128, 256),
                 proj_size=(64, 64, 64, 32), depths=(3, 3, 3, 3), num_heads=4, spatial_dims=3, in_channels=1,
                 dropout=0.0, transformer_dropout_rate=0.1, **kwargs):
        """
        Args:
            input_size: number of spatial tokens per stage (H*W*D after each downsampling).
            dims: channel count per stage.
            proj_size: projection size for keys/values in each stage's spatial attention.
            depths: number of TransformerBlocks per stage.
            num_heads: number of attention heads inside each EPA module.
            spatial_dims: number of spatial dimensions (3 for volumetric data).
            in_channels: number of input image channels.
            dropout: dropout for the convolutional layers.
            transformer_dropout_rate: dropout inside the TransformerBlocks.
        """
        # FIX: default arguments are now tuples instead of mutable lists
        # (they are only indexed, so behavior is unchanged).
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem_layer = nn.Sequential(
            get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=(1, 4, 4), stride=(1, 4, 4),
                           dropout=dropout, conv_only=True, ),
            get_norm_layer(name=("group", {"num_groups": in_channels}), channels=dims[0]),
        )
        self.downsample_layers.append(stem_layer)
        for i in range(3):
            downsample_layer = nn.Sequential(
                get_conv_layer(spatial_dims, dims[i], dims[i + 1], kernel_size=(2, 2, 2), stride=(2, 2, 2),
                               dropout=dropout, conv_only=True, ),
                get_norm_layer(name=("group", {"num_groups": dims[i]}), channels=dims[i + 1]),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple Transformer blocks
        for i in range(4):
            stage_blocks = []
            for j in range(depths[i]):
                stage_blocks.append(TransformerBlock(input_size=input_size[i], hidden_size=dims[i],
                                                     proj_size=proj_size[i], num_heads=num_heads,
                                                     dropout_rate=transformer_dropout_rate, pos_embed=True))
            self.stages.append(nn.Sequential(*stage_blocks))
        # NOTE(review): this attribute is never used (forward_features builds a
        # local list); kept for backward compatibility.
        self.hidden_states = []
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # truncated-normal init for conv/linear weights, standard LayerNorm init
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (LayerNorm, nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """Run all 4 stages; returns (last stage output, list of all stage outputs)."""
        hidden_states = []
        x = self.downsample_layers[0](x)
        x = self.stages[0](x)
        hidden_states.append(x)
        for i in range(1, 4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i == 3:  # Reshape the output of the last stage into a token sequence (b, h*w*d, c)
                x = einops.rearrange(x, "b c h w d -> b (h w d) c")
            hidden_states.append(x)
        return x, hidden_states

    def forward(self, x):
        x, hidden_states = self.forward_features(x)
        return x, hidden_states
class UnetrUpBlock(nn.Module):
    """UNETR++ decoder stage: transposed-conv upsampling, additive skip connection,
    then either a stack of TransformerBlocks or (for the last decoder) a plain conv block."""

    def __init__(
            self,
            spatial_dims: int,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[Sequence[int], int],
            upsample_kernel_size: Union[Sequence[int], int],
            norm_name: Union[Tuple, str],
            proj_size: int = 64,
            num_heads: int = 4,
            out_size: int = 0,
            depth: int = 3,
            conv_decoder: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
            norm_name: feature normalization type and arguments.
            proj_size: projection size for keys and values in the spatial attention module.
            num_heads: number of heads inside each EPA module.
            out_size: spatial size for each decoder.
            depth: number of blocks for the current decoder stage.
            conv_decoder: if True, use a UnetResBlock instead of TransformerBlocks
                (used for the last decoder; see suppl. material in the paper).
        """
        super().__init__()
        # transposed conv with stride == kernel size, i.e. non-overlapping upsampling
        upsample_stride = upsample_kernel_size
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )

        # 4 feature resolution stages, each consisting of multiple residual blocks
        self.decoder_block = nn.ModuleList()

        # If this is the last decoder, use ConvBlock(UnetResBlock) instead of EPA_Block
        # (see suppl. material in the paper)
        if conv_decoder == True:
            self.decoder_block.append(
                UnetResBlock(spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1,
                             norm_name=norm_name, ))
        else:
            stage_blocks = []
            for j in range(depth):
                stage_blocks.append(TransformerBlock(input_size=out_size, hidden_size=out_channels,
                                                     proj_size=proj_size, num_heads=num_heads,
                                                     dropout_rate=0.1, pos_embed=True))
            self.decoder_block.append(nn.Sequential(*stage_blocks))

    def _init_weights(self, m):
        # NOTE(review): defined but never registered (self.apply is not called in
        # this class, unlike UnetrPPEncoder) -- appears unused; confirm.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, inp, skip):
        # upsample, fuse with the encoder skip by addition, then refine
        out = self.transp_conv(inp)
        out = out + skip
        out = self.decoder_block[0](out)
        return out
================================================
FILE: unetr_pp/network_architecture/acdc/transformerblock.py
================================================
import torch.nn as nn
import torch
from unetr_pp.network_architecture.dynunet_block import UnetResBlock
class TransformerBlock(nn.Module):
    """
    A transformer block, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            proj_size: int,
            num_heads: int,
            dropout_rate: float = 0.0,
            pos_embed=False,
    ) -> None:
        """
        Args:
            input_size: the size of the input for each stage (number of spatial tokens, H*W*D).
            hidden_size: dimension of hidden layer.
            proj_size: projection size for keys and values in the spatial attention module.
            num_heads: number of attention heads.
            dropout_rate: fraction of the input units to drop.
            pos_embed: bool argument to determine if positional embedding is used.
        """
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")
        if hidden_size % num_heads != 0:
            print("Hidden size is ", hidden_size)
            print("Num heads is ", num_heads)
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.norm = nn.LayerNorm(hidden_size)
        # learnable per-channel scaling of the attention branch, initialised small
        # so the block starts close to an identity mapping
        self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True)
        # EPA is defined later in this module
        self.epa_block = EPA(input_size=input_size, hidden_size=hidden_size, proj_size=proj_size,
                             num_heads=num_heads,
                             channel_attn_drop=dropout_rate, spatial_attn_drop=dropout_rate)
        self.conv51 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch")
        self.conv52 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch")
        self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1))

        self.pos_embed = None
        if pos_embed:
            # learned positional embedding over the flattened spatial tokens
            self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size))

    def forward(self, x):
        # x: (B, C, H, W, D) volume -> flatten spatial dims into a token sequence (B, H*W*D, C)
        B, C, H, W, D = x.shape
        x = x.reshape(B, C, H * W * D).permute(0, 2, 1)

        if self.pos_embed is not None:
            x = x + self.pos_embed
        # pre-norm attention with residual connection scaled by gamma
        attn = x + self.gamma * self.epa_block(self.norm(x))

        attn_skip = attn.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3)  # (B, C, H, W, D)
        # two residual conv blocks followed by a 1x1x1 projection, with a skip around them
        attn = self.conv51(attn_skip)
        attn = self.conv52(attn)
        x = attn_skip + self.conv8(attn)
        return x
class EPA(nn.Module):
    """
    Efficient Paired Attention Block, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
    """

    def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False,
                 channel_attn_drop=0.1, spatial_attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        # Learnable per-head softmax temperatures, one for each attention branch.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))

        # Single fused projection yielding query_shared, key_shared, value_spatial, value_channel.
        self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias)

        # E and F project spatial-branch keys and values from the HWD token axis down to proj_size.
        self.E = nn.Linear(input_size, proj_size)
        self.F = nn.Linear(input_size, proj_size)

        self.attn_drop = nn.Dropout(channel_attn_drop)
        self.attn_drop_2 = nn.Dropout(spatial_attn_drop)

        # Each branch keeps half of the channels; the halves are concatenated in forward().
        self.out_proj = nn.Linear(hidden_size, int(hidden_size // 2))
        self.out_proj2 = nn.Linear(hidden_size, int(hidden_size // 2))

    def forward(self, x):
        batch, tokens, channels = x.shape

        fused = self.qkvv(x).reshape(batch, tokens, 4, self.num_heads, channels // self.num_heads)
        fused = fused.permute(2, 0, 3, 1, 4)
        query, key, value_channel, value_spatial = fused.unbind(0)

        # Move to (batch, heads, head_dim, tokens) layout.
        query = query.transpose(-2, -1)
        key = key.transpose(-2, -1)
        value_channel = value_channel.transpose(-2, -1)
        value_spatial = value_spatial.transpose(-2, -1)

        key_projected = self.E(key)
        value_spatial_projected = self.F(value_spatial)

        query = torch.nn.functional.normalize(query, dim=-1)
        key = torch.nn.functional.normalize(key, dim=-1)

        # Channel attention: similarity computed across the channel dimension.
        channel_scores = self.attn_drop(
            ((query @ key.transpose(-2, -1)) * self.temperature).softmax(dim=-1))
        out_channel = (channel_scores @ value_channel).permute(0, 3, 1, 2).reshape(batch, tokens, channels)

        # Spatial attention: tokens attend over the projected (proj_size) key/value pairs.
        spatial_scores = self.attn_drop_2(
            ((query.permute(0, 1, 3, 2) @ key_projected) * self.temperature2).softmax(dim=-1))
        out_spatial = (spatial_scores @ value_spatial_projected.transpose(-2, -1)
                       ).permute(0, 3, 1, 2).reshape(batch, tokens, channels)

        # Concat fusion: half the output channels from each branch.
        return torch.cat((self.out_proj(out_spatial), self.out_proj2(out_channel)), dim=-1)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'temperature', 'temperature2'}
================================================
FILE: unetr_pp/network_architecture/acdc/unetr_pp_acdc.py
================================================
from torch import nn
from typing import Tuple, Union
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.network_architecture.dynunet_block import UnetOutBlock, UnetResBlock
from unetr_pp.network_architecture.acdc.model_components import UnetrPPEncoder, UnetrUpBlock
class UNETR_PP(SegmentationNetwork):
    """
    UNETR++ for ACDC, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            feature_size: int = 16,
            hidden_size: int = 256,
            num_heads: int = 4,
            pos_embed: str = "perceptron",
            norm_name: Union[Tuple, str] = "instance",
            dropout_rate: float = 0.0,
            depths=None,
            dims=None,
            conv_op=nn.Conv3d,
            do_ds=True,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels (number of segmentation classes).
            feature_size: dimension of network feature size.
            hidden_size: dimensions of the last encoder.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type ("conv" or "perceptron").
            norm_name: feature normalization type and arguments.
            dropout_rate: fraction of the input units to drop.
            depths: number of blocks for each stage.
            dims: number of channel maps for the stages.
            conv_op: type of convolution operation.
            do_ds: use deep supervision to compute the loss.
        """
        super().__init__()
        if depths is None:
            depths = [3, 3, 3, 3]
        self.do_ds = do_ds
        self.conv_op = conv_op
        self.num_classes = out_channels
        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")
        if pos_embed not in ["conv", "perceptron"]:
            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

        # Bottleneck spatial size (D, H, W); hard-coded for the ACDC patch geometry.
        self.feat_size = (2, 5, 5,)
        self.hidden_size = hidden_size

        self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads)
        self.encoder1 = UnetResBlock(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
        )

        # (attribute name, in_ch, out_ch, upsample kernel, decoder token count, conv-only stage)
        decoder_specs = [
            ("decoder5", feature_size * 16, feature_size * 8, 2, 4 * 10 * 10, False),
            ("decoder4", feature_size * 8, feature_size * 4, 2, 8 * 20 * 20, False),
            ("decoder3", feature_size * 4, feature_size * 2, 2, 16 * 40 * 40, False),
            ("decoder2", feature_size * 2, feature_size, (1, 4, 4), 16 * 160 * 160, True),
        ]
        for attr, in_ch, out_ch, up_kernel, out_size, conv_only_stage in decoder_specs:
            setattr(self, attr, UnetrUpBlock(
                spatial_dims=3,
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=3,
                upsample_kernel_size=up_kernel,
                norm_name=norm_name,
                out_size=out_size,
                conv_decoder=conv_only_stage,
            ))

        self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels)
        if self.do_ds:
            # Extra heads for deep supervision at coarser decoder resolutions.
            self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels)
            self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels)

    def proj_feat(self, x, hidden_size, feat_size):
        """Reshape a flattened token tensor back to a (B, C, D, H, W) volume."""
        x = x.view(x.size(0), *feat_size, hidden_size)
        return x.permute(0, 4, 1, 2, 3).contiguous()

    def forward(self, x_in):
        x_output, hidden_states = self.unetr_pp_encoder(x_in)
        convBlock = self.encoder1(x_in)

        # Encoder feature maps, shallowest to deepest.
        enc1, enc2, enc3, enc4 = hidden_states[0], hidden_states[1], hidden_states[2], hidden_states[3]

        # Decode, fusing each stage with the matching encoder skip.
        dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size)
        dec3 = self.decoder5(dec4, enc3)
        dec2 = self.decoder4(dec3, enc2)
        dec1 = self.decoder3(dec2, enc1)
        out = self.decoder2(dec1, convBlock)

        if self.do_ds:
            return [self.out1(out), self.out2(dec1), self.out3(dec2)]
        return self.out1(out)
================================================
FILE: unetr_pp/network_architecture/dynunet_block.py
================================================
from typing import Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from monai.networks.blocks.convolutions import Convolution
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.utils import get_act_layer, get_norm_layer
class UnetResBlock(nn.Module):
    """
    Residual convolution block (conv-norm-act, conv-norm, + skip, act) used by DynUNet, based on:
    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_.
    `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        norm_name: feature normalization type and arguments.
        act_name: activation layer type and arguments.
        dropout: dropout probability.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
        dropout: Optional[Union[Tuple, str, float]] = None,
    ):
        super().__init__()
        self.conv1 = get_conv_layer(
            spatial_dims, in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, dropout=dropout, conv_only=True,
        )
        self.conv2 = get_conv_layer(
            spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True
        )
        self.lrelu = get_act_layer(name=act_name)
        self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
        self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)

        # The identity path needs a 1x1 projection whenever channels or resolution change.
        self.downsample = (in_channels != out_channels) or not np.all(np.atleast_1d(stride) == 1)
        if self.downsample:
            self.conv3 = get_conv_layer(
                spatial_dims, in_channels, out_channels, kernel_size=1, stride=stride, dropout=dropout, conv_only=True
            )
            self.norm3 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)

    def forward(self, inp):
        residual = inp
        out = self.lrelu(self.norm1(self.conv1(inp)))
        out = self.norm2(self.conv2(out))
        # conv3/norm3 only exist when the residual path must be projected.
        if hasattr(self, "conv3"):
            residual = self.conv3(residual)
        if hasattr(self, "norm3"):
            residual = self.norm3(residual)
        out += residual
        return self.lrelu(out)
class UnetBasicBlock(nn.Module):
    """
    Plain two-convolution block (conv-norm-act twice, no residual) used by DynUNet, based on:
    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_.
    `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        norm_name: feature normalization type and arguments.
        act_name: activation layer type and arguments.
        dropout: dropout probability.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
        dropout: Optional[Union[Tuple, str, float]] = None,
    ):
        super().__init__()
        # Only the first conv may change resolution (stride); the second always keeps it.
        self.conv1 = get_conv_layer(
            spatial_dims, in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, dropout=dropout, conv_only=True,
        )
        self.conv2 = get_conv_layer(
            spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True
        )
        self.lrelu = get_act_layer(name=act_name)
        self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
        self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)

    def forward(self, inp):
        stage1 = self.lrelu(self.norm1(self.conv1(inp)))
        return self.lrelu(self.norm2(self.conv2(stage1)))
class UnetUpBlock(nn.Module):
    """
    Upsampling block (transposed conv + skip concatenation + basic conv block) for DynUNet, based on:
    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_.
    `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        upsample_kernel_size: convolution kernel size for transposed convolution layers.
        norm_name: feature normalization type and arguments.
        act_name: activation layer type and arguments.
        dropout: dropout probability.
        trans_bias: transposed convolution bias.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
        dropout: Optional[Union[Tuple, str, float]] = None,
        trans_bias: bool = False,
    ):
        super().__init__()
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_kernel_size,  # stride equals the kernel for the transposed conv
            dropout=dropout,
            bias=trans_bias,
            conv_only=True,
            is_transposed=True,
        )
        # Concatenating the skip connection doubles the channel count before the conv block.
        self.conv_block = UnetBasicBlock(
            spatial_dims,
            out_channels * 2,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            dropout=dropout,
            norm_name=norm_name,
            act_name=act_name,
        )

    def forward(self, inp, skip):
        # number of channels for skip should equals to out_channels
        upsampled = self.transp_conv(inp)
        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv_block(merged)
class UnetOutBlock(nn.Module):
    """Final 1x1 convolution mapping feature maps to per-class logits."""

    def __init__(
        self, spatial_dims: int, in_channels: int, out_channels: int, dropout: Optional[Union[Tuple, str, float]] = None
    ):
        super().__init__()
        self.conv = get_conv_layer(
            spatial_dims, in_channels, out_channels,
            kernel_size=1, stride=1, dropout=dropout, bias=True, conv_only=True,
        )

    def forward(self, inp):
        return self.conv(inp)
def get_conv_layer(
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[Sequence[int], int] = 3,
    stride: Union[Sequence[int], int] = 1,
    act: Optional[Union[Tuple, str]] = Act.PRELU,
    norm: Union[Tuple, str] = Norm.INSTANCE,
    dropout: Optional[Union[Tuple, str, float]] = None,
    bias: bool = False,
    conv_only: bool = True,
    is_transposed: bool = False,
):
    """Build a MONAI ``Convolution`` with padding chosen so spatial size scales exactly by ``stride``."""
    padding = get_padding(kernel_size, stride)
    # Transposed convolutions additionally need an output padding to invert the shape math.
    output_padding = get_output_padding(kernel_size, stride, padding) if is_transposed else None
    return Convolution(
        spatial_dims,
        in_channels,
        out_channels,
        strides=stride,
        kernel_size=kernel_size,
        act=act,
        norm=norm,
        dropout=dropout,
        bias=bias,
        conv_only=conv_only,
        is_transposed=is_transposed,
        padding=padding,
        output_padding=output_padding,
    )
def get_padding(
    kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int]
) -> Union[Tuple[int, ...], int]:
    """Return per-dimension "same"-style padding, as a scalar for 1-D inputs or a tuple otherwise."""
    kernel = np.atleast_1d(kernel_size)
    strides = np.atleast_1d(stride)
    pad = (kernel - strides + 1) / 2
    if np.min(pad) < 0:
        raise AssertionError("padding value should not be negative, please change the kernel size and/or stride.")
    # Truncate toward zero, matching conv-layer integer padding semantics.
    values = tuple(int(p) for p in pad)
    return values[0] if len(values) == 1 else values
def get_output_padding(
    kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], padding: Union[Sequence[int], int]
) -> Union[Tuple[int, ...], int]:
    """Return the transposed-conv output padding; scalar for 1-D inputs, tuple otherwise."""
    kernel = np.atleast_1d(kernel_size)
    strides = np.atleast_1d(stride)
    pad = np.atleast_1d(padding)
    out_pad = 2 * pad + strides - kernel
    if np.min(out_pad) < 0:
        raise AssertionError("out_padding value should not be negative, please change the kernel size and/or stride.")
    values = tuple(int(p) for p in out_pad)
    return values[0] if len(values) == 1 else values
================================================
FILE: unetr_pp/network_architecture/generic_UNet.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from unetr_pp.utilities.nd_softmax import softmax_helper
from torch import nn
import torch
import numpy as np
from unetr_pp.network_architecture.initialization import InitWeights_He
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
import torch.nn.functional
class ConvDropoutNormNonlin(nn.Module):
    """
    Conv -> (optional) Dropout -> Norm -> Nonlinearity.

    fixes a bug in ConvDropoutNormNonlin where lrelu was used regardless of nonlin. Bad.
    """

    def __init__(self, input_channels, output_channels,
                 conv_op=nn.Conv2d, conv_kwargs=None,
                 norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
                 dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None):
        super(ConvDropoutNormNonlin, self).__init__()
        # Fill in nnU-Net's standard defaults for any kwargs left as None.
        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {'p': 0.5, 'inplace': True}
        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
        if conv_kwargs is None:
            conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}

        self.nonlin_kwargs = nonlin_kwargs
        self.nonlin = nonlin
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.norm_op_kwargs = norm_op_kwargs
        self.conv_kwargs = conv_kwargs
        self.conv_op = conv_op
        self.norm_op = norm_op

        self.conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs)
        # Only instantiate dropout when an op is given and p is a positive number.
        wants_dropout = (self.dropout_op is not None
                         and self.dropout_op_kwargs['p'] is not None
                         and self.dropout_op_kwargs['p'] > 0)
        self.dropout = self.dropout_op(**self.dropout_op_kwargs) if wants_dropout else None
        self.instnorm = self.norm_op(output_channels, **self.norm_op_kwargs)
        self.lrelu = self.nonlin(**self.nonlin_kwargs)

    def forward(self, x):
        out = self.conv(x)
        if self.dropout is not None:
            out = self.dropout(out)
        return self.lrelu(self.instnorm(out))
class ConvDropoutNonlinNorm(ConvDropoutNormNonlin):
    """Variant of ConvDropoutNormNonlin that applies the nonlinearity *before* normalization."""

    def forward(self, x):
        out = self.conv(x)
        if self.dropout is not None:
            out = self.dropout(out)
        return self.instnorm(self.lrelu(out))
class StackedConvLayers(nn.Module):
    def __init__(self, input_feature_channels, output_feature_channels, num_convs,
                 conv_op=nn.Conv2d, conv_kwargs=None,
                 norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
                 dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None, first_stride=None, basic_block=ConvDropoutNormNonlin):
        '''
        Stacks ``num_convs`` ConvDropoutNormLReLU-style blocks. ``first_stride`` is applied only to
        the first block in the stack; all remaining parameters affect every block.
        '''
        # NOTE: plain (non-module) attributes are intentionally set before super().__init__();
        # nn.Module submodules are only assigned afterwards.
        self.input_channels = input_feature_channels
        self.output_channels = output_feature_channels

        # nnU-Net's standard defaults for any kwargs left as None.
        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {'p': 0.5, 'inplace': True}
        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
        if conv_kwargs is None:
            conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}

        self.nonlin_kwargs = nonlin_kwargs
        self.nonlin = nonlin
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.norm_op_kwargs = norm_op_kwargs
        self.conv_kwargs = conv_kwargs
        self.conv_op = conv_op
        self.norm_op = norm_op

        # Copy the conv kwargs for the first block so a custom stride doesn't leak into the rest.
        if first_stride is not None:
            self.conv_kwargs_first_conv = deepcopy(conv_kwargs)
            self.conv_kwargs_first_conv['stride'] = first_stride
        else:
            self.conv_kwargs_first_conv = conv_kwargs

        super(StackedConvLayers, self).__init__()

        first = basic_block(input_feature_channels, output_feature_channels, self.conv_op,
                            self.conv_kwargs_first_conv, self.norm_op, self.norm_op_kwargs,
                            self.dropout_op, self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs)
        rest = [basic_block(output_feature_channels, output_feature_channels, self.conv_op,
                            self.conv_kwargs, self.norm_op, self.norm_op_kwargs,
                            self.dropout_op, self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs)
                for _ in range(num_convs - 1)]
        self.blocks = nn.Sequential(first, *rest)

    def forward(self, x):
        return self.blocks(x)
def print_module_training_status(module):
    """Print the module and its ``training`` flag if it is a conv, dropout, or norm layer.

    Useful when debugging train/eval state of dropout and normalization layers;
    other module types are ignored.
    """
    # A single isinstance() with a tuple of types replaces the original chain of
    # eleven or-ed isinstance checks — identical behavior, far more readable.
    if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.Dropout3d, nn.Dropout2d, nn.Dropout,
                           nn.InstanceNorm3d, nn.InstanceNorm2d, nn.InstanceNorm1d,
                           nn.BatchNorm2d, nn.BatchNorm3d, nn.BatchNorm1d)):
        print(str(module), module.training)
class Upsample(nn.Module):
    """Module wrapper around ``nn.functional.interpolate`` with fixed size/scale/mode settings."""

    # Modes for which interpolate() accepts an align_corners argument.
    _INTERPOLATING_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=False):
        super(Upsample, self).__init__()
        self.align_corners = align_corners
        self.mode = mode
        self.scale_factor = scale_factor
        self.size = size

    def forward(self, x):
        # interpolate() raises if align_corners is set for non-interpolating modes
        # (e.g. the default mode='nearest'), so only forward it when the mode supports
        # it. The previous code always passed align_corners and crashed for 'nearest'.
        align_corners = self.align_corners if self.mode in self._INTERPOLATING_MODES else None
        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor,
                                         mode=self.mode, align_corners=align_corners)
class Generic_UNet(SegmentationNetwork):
    """
    Configurable U-Net as used by nnU-Net: a context (downsampling) pathway, a
    bottleneck, and a localization (upsampling) pathway, with optional deep
    supervision heads at multiple resolutions.
    """
    # Defaults and caps. Only the MAX_NUM_FILTERS_3D / MAX_FILTERS_2D constants are
    # read inside this class; the others are presumably consumed by external
    # experiment-planning code — TODO confirm against callers.
    DEFAULT_BATCH_SIZE_3D = 2
    DEFAULT_PATCH_SIZE_3D = (64, 192, 160)
    SPACING_FACTOR_BETWEEN_STAGES = 2
    BASE_NUM_FEATURES_3D = 30
    MAX_NUMPOOL_3D = 999
    MAX_NUM_FILTERS_3D = 320

    DEFAULT_PATCH_SIZE_2D = (256, 256)
    BASE_NUM_FEATURES_2D = 30
    DEFAULT_BATCH_SIZE_2D = 50
    MAX_NUMPOOL_2D = 999
    MAX_FILTERS_2D = 480

    use_this_for_batch_size_computation_2D = 19739648
    use_this_for_batch_size_computation_3D = 520000000  # 505789440

    def __init__(self, input_channels, base_num_features, num_classes, num_pool, num_conv_per_stage=2,
                 feat_map_mul_on_downscale=2, conv_op=nn.Conv2d,
                 norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
                 dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None, deep_supervision=True, dropout_in_localization=False,
                 final_nonlin=softmax_helper, weightInitializer=InitWeights_He(1e-2), pool_op_kernel_sizes=None,
                 conv_kernel_sizes=None,
                 upscale_logits=False, convolutional_pooling=False, convolutional_upsampling=False,
                 max_num_features=None, basic_block=ConvDropoutNormNonlin,
                 seg_output_use_bias=False):
        """
        basically more flexible than v1, architecture is the same

        Does this look complicated? Nah bro. Functionality > usability

        This does everything you need, including world peace.

        Questions? -> f.isensee@dkfz.de
        """
        super(Generic_UNet, self).__init__()
        self.convolutional_upsampling = convolutional_upsampling
        self.convolutional_pooling = convolutional_pooling
        self.upscale_logits = upscale_logits
        # nnU-Net's standard defaults for any kwargs left as None.
        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {'p': 0.5, 'inplace': True}
        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}

        # NOTE: self.conv_kwargs is mutated per stage below ('kernel_size'/'padding'
        # are rewritten before each StackedConvLayers is built) — order matters.
        self.conv_kwargs = {'stride': 1, 'dilation': 1, 'bias': True}

        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs
        self.dropout_op_kwargs = dropout_op_kwargs
        self.norm_op_kwargs = norm_op_kwargs
        self.weightInitializer = weightInitializer
        self.conv_op = conv_op
        self.norm_op = norm_op
        self.dropout_op = dropout_op
        self.num_classes = num_classes
        self.final_nonlin = final_nonlin
        self._deep_supervision = deep_supervision
        self.do_ds = deep_supervision

        # Pick 2D or 3D counterparts (upsampling mode, pooling, transposed conv) from conv_op.
        if conv_op == nn.Conv2d:
            upsample_mode = 'bilinear'
            pool_op = nn.MaxPool2d
            transpconv = nn.ConvTranspose2d
            if pool_op_kernel_sizes is None:
                pool_op_kernel_sizes = [(2, 2)] * num_pool
            if conv_kernel_sizes is None:
                conv_kernel_sizes = [(3, 3)] * (num_pool + 1)
        elif conv_op == nn.Conv3d:
            upsample_mode = 'trilinear'
            pool_op = nn.MaxPool3d
            transpconv = nn.ConvTranspose3d
            if pool_op_kernel_sizes is None:
                pool_op_kernel_sizes = [(2, 2, 2)] * num_pool
            if conv_kernel_sizes is None:
                conv_kernel_sizes = [(3, 3, 3)] * (num_pool + 1)
        else:
            raise ValueError("unknown convolution dimensionality, conv op: %s" % str(conv_op))

        # Product of all pooling kernels per axis = required input divisibility.
        self.input_shape_must_be_divisible_by = np.prod(pool_op_kernel_sizes, 0, dtype=np.int64)
        self.pool_op_kernel_sizes = pool_op_kernel_sizes
        self.conv_kernel_sizes = conv_kernel_sizes

        # Padding 1 for kernel 3, 0 otherwise (per axis).
        self.conv_pad_sizes = []
        for krnl in self.conv_kernel_sizes:
            self.conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl])

        if max_num_features is None:
            if self.conv_op == nn.Conv3d:
                self.max_num_features = self.MAX_NUM_FILTERS_3D
            else:
                self.max_num_features = self.MAX_FILTERS_2D
        else:
            self.max_num_features = max_num_features

        # Plain lists during construction; converted to nn.ModuleList at the end.
        self.conv_blocks_context = []
        self.conv_blocks_localization = []
        self.td = []
        self.tu = []
        self.seg_outputs = []

        output_features = base_num_features
        input_features = input_channels

        # ---- context (downsampling) pathway ----
        for d in range(num_pool):
            # determine the first stride
            if d != 0 and self.convolutional_pooling:
                first_stride = pool_op_kernel_sizes[d - 1]
            else:
                first_stride = None

            self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[d]
            self.conv_kwargs['padding'] = self.conv_pad_sizes[d]
            # add convolutions
            self.conv_blocks_context.append(StackedConvLayers(input_features, output_features, num_conv_per_stage,
                                                              self.conv_op, self.conv_kwargs, self.norm_op,
                                                              self.norm_op_kwargs, self.dropout_op,
                                                              self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs,
                                                              first_stride, basic_block=basic_block))
            if not self.convolutional_pooling:
                self.td.append(pool_op(pool_op_kernel_sizes[d]))
            input_features = output_features
            output_features = int(np.round(output_features * feat_map_mul_on_downscale))

            # Cap the feature-map growth.
            output_features = min(output_features, self.max_num_features)

        # now the bottleneck.
        # determine the first stride
        if self.convolutional_pooling:
            first_stride = pool_op_kernel_sizes[-1]
        else:
            first_stride = None

        # the output of the last conv must match the number of features from the skip connection if we are not using
        # convolutional upsampling. If we use convolutional upsampling then the reduction in feature maps will be
        # done by the transposed conv
        if self.convolutional_upsampling:
            final_num_features = output_features
        else:
            final_num_features = self.conv_blocks_context[-1].output_channels

        self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[num_pool]
        self.conv_kwargs['padding'] = self.conv_pad_sizes[num_pool]
        self.conv_blocks_context.append(nn.Sequential(
            StackedConvLayers(input_features, output_features, num_conv_per_stage - 1, self.conv_op, self.conv_kwargs,
                              self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin,
                              self.nonlin_kwargs, first_stride, basic_block=basic_block),
            StackedConvLayers(output_features, final_num_features, 1, self.conv_op, self.conv_kwargs,
                              self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin,
                              self.nonlin_kwargs, basic_block=basic_block)))

        # if we don't want to do dropout in the localization pathway then we set the dropout prob to zero here
        if not dropout_in_localization:
            old_dropout_p = self.dropout_op_kwargs['p']
            self.dropout_op_kwargs['p'] = 0.0

        # now lets build the localization pathway
        for u in range(num_pool):
            nfeatures_from_down = final_num_features
            nfeatures_from_skip = self.conv_blocks_context[
                -(2 + u)].output_channels  # self.conv_blocks_context[-1] is bottleneck, so start with -2
            n_features_after_tu_and_concat = nfeatures_from_skip * 2

            # the first conv reduces the number of features to match those of skip
            # the following convs work on that number of features
            # if not convolutional upsampling then the final conv reduces the num of features again
            if u != num_pool - 1 and not self.convolutional_upsampling:
                final_num_features = self.conv_blocks_context[-(3 + u)].output_channels
            else:
                final_num_features = nfeatures_from_skip

            if not self.convolutional_upsampling:
                self.tu.append(Upsample(scale_factor=pool_op_kernel_sizes[-(u + 1)], mode=upsample_mode))
            else:
                self.tu.append(transpconv(nfeatures_from_down, nfeatures_from_skip, pool_op_kernel_sizes[-(u + 1)],
                                          pool_op_kernel_sizes[-(u + 1)], bias=False))

            self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[- (u + 1)]
            self.conv_kwargs['padding'] = self.conv_pad_sizes[- (u + 1)]
            self.conv_blocks_localization.append(nn.Sequential(
                StackedConvLayers(n_features_after_tu_and_concat, nfeatures_from_skip, num_conv_per_stage - 1,
                                  self.conv_op, self.conv_kwargs, self.norm_op, self.norm_op_kwargs, self.dropout_op,
                                  self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, basic_block=basic_block),
                StackedConvLayers(nfeatures_from_skip, final_num_features, 1, self.conv_op, self.conv_kwargs,
                                  self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
                                  self.nonlin, self.nonlin_kwargs, basic_block=basic_block)
            ))

        # One 1x1 segmentation head per decoder stage.
        for ds in range(len(self.conv_blocks_localization)):
            self.seg_outputs.append(conv_op(self.conv_blocks_localization[ds][-1].output_channels, num_classes,
                                            1, 1, 0, 1, 1, seg_output_use_bias))

        # Ops that upscale the auxiliary (deep-supervision) logits to full resolution,
        # or identity lambdas when upscale_logits is disabled.
        self.upscale_logits_ops = []
        cum_upsample = np.cumprod(np.vstack(pool_op_kernel_sizes), axis=0)[::-1]
        for usl in range(num_pool - 1):
            if self.upscale_logits:
                self.upscale_logits_ops.append(Upsample(scale_factor=tuple([int(i) for i in cum_upsample[usl + 1]]),
                                                        mode=upsample_mode))
            else:
                self.upscale_logits_ops.append(lambda x: x)

        # Restore the dropout probability that was zeroed for the localization pathway.
        if not dropout_in_localization:
            self.dropout_op_kwargs['p'] = old_dropout_p

        # register all modules properly
        self.conv_blocks_localization = nn.ModuleList(self.conv_blocks_localization)
        self.conv_blocks_context = nn.ModuleList(self.conv_blocks_context)
        self.td = nn.ModuleList(self.td)
        self.tu = nn.ModuleList(self.tu)
        self.seg_outputs = nn.ModuleList(self.seg_outputs)
        if self.upscale_logits:
            self.upscale_logits_ops = nn.ModuleList(
                self.upscale_logits_ops)  # lambda x:x is not a Module so we need to distinguish here

        if self.weightInitializer is not None:
            self.apply(self.weightInitializer)
            # self.apply(print_module_training_status)

    def forward(self, x):
        """Encode, decode with skip connections; return one output or a deep-supervision tuple."""
        # print(x.shape)
        skips = []
        seg_outputs = []
        # Context pathway: every stage except the bottleneck feeds a skip connection.
        for d in range(len(self.conv_blocks_context) - 1):
            x = self.conv_blocks_context[d](x)
            skips.append(x)
            if not self.convolutional_pooling:
                x = self.td[d](x)

        x = self.conv_blocks_context[-1](x)

        # Localization pathway: upsample, concat the matching skip, convolve, predict.
        for u in range(len(self.tu)):
            x = self.tu[u](x)
            x = torch.cat((x, skips[-(u + 1)]), dim=1)
            x = self.conv_blocks_localization[u](x)
            seg_outputs.append(self.final_nonlin(self.seg_outputs[u](x)))

        if self._deep_supervision and self.do_ds:
            # Highest-resolution output first, then the (optionally upscaled) coarser ones.
            return tuple([seg_outputs[-1]] + [i(j) for i, j in
                                              zip(list(self.upscale_logits_ops)[::-1], seg_outputs[:-1][::-1])])
        else:
            return seg_outputs[-1]

    @staticmethod
    def compute_approx_vram_consumption(patch_size, num_pool_per_axis, base_num_features, max_num_features,
                                        num_modalities, num_classes, pool_op_kernel_sizes, deep_supervision=False,
                                        conv_per_stage=2):
        """
        This only applies for num_conv_per_stage and convolutional_upsampling=True
        not real vram consumption. just a constant term to which the vram consumption will be approx proportional
        (+ offset for parameter storage)
        :param deep_supervision:
        :param patch_size:
        :param num_pool_per_axis:
        :param base_num_features:
        :param max_num_features:
        :param num_modalities:
        :param num_classes:
        :param pool_op_kernel_sizes:
        :return:
        """
        if not isinstance(num_pool_per_axis, np.ndarray):
            num_pool_per_axis = np.array(num_pool_per_axis)

        npool = len(pool_op_kernel_sizes)

        map_size = np.array(patch_size)
        # Full-resolution term: conv feature maps plus the input and output tensors.
        tmp = np.int64((conv_per_stage * 2 + 1) * np.prod(map_size, dtype=np.int64) * base_num_features +
                       num_modalities * np.prod(map_size, dtype=np.int64) +
                       num_classes * np.prod(map_size, dtype=np.int64))

        num_feat = base_num_features

        # Accumulate the per-stage contribution while shrinking map_size by each pooling kernel.
        for p in range(npool):
            for pi in range(len(num_pool_per_axis)):
                map_size[pi] /= pool_op_kernel_sizes[p][pi]
            num_feat = min(num_feat * 2, max_num_features)
            num_blocks = (conv_per_stage * 2 + 1) if p < (
                    npool - 1) else conv_per_stage  # conv_per_stage + conv_per_stage for the convs of encode/decode and 1 for transposed conv
            tmp += num_blocks * np.prod(map_size, dtype=np.int64) * num_feat
            if deep_supervision and p < (npool - 2):
                tmp += np.prod(map_size, dtype=np.int64) * num_classes
            # print(p, map_size, num_feat, tmp)

        return tmp
================================================
FILE: unetr_pp/network_architecture/initialization.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import nn
class InitWeights_He(object):
    """He (Kaiming) normal initialisation for conv layers, with zero bias.

    Intended to be used via ``network.apply(InitWeights_He(neg_slope))``.
    """

    def __init__(self, neg_slope=1e-2):
        # negative slope of the leaky ReLU assumed by the Kaiming formula
        self.neg_slope = neg_slope

    def __call__(self, module):
        # only (transposed) convolutions are touched; everything else is left alone
        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
            module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
            if module.bias is not None:
                module.bias = nn.init.constant_(module.bias, 0)
class InitWeights_XavierUniform(object):
    """Xavier/Glorot uniform initialisation for conv layers, with zero bias.

    Intended to be used via ``network.apply(InitWeights_XavierUniform(gain))``.
    """

    def __init__(self, gain=1):
        # scaling factor applied to the Xavier bound
        self.gain = gain

    def __call__(self, module):
        # only (transposed) convolutions are touched; everything else is left alone
        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
            module.weight = nn.init.xavier_uniform_(module.weight, self.gain)
            if module.bias is not None:
                module.bias = nn.init.constant_(module.bias, 0)
================================================
FILE: unetr_pp/network_architecture/layers.py
================================================
import torch
import torch.nn.functional as F
from torch import nn
import math
class LayerNorm(nn.Module):
    """LayerNorm supporting channels_last (N, ..., C) and channels_first (N, C, H, W).

    channels_last delegates to F.layer_norm over the trailing dimension;
    channels_first normalises over dim 1 by hand and applies the affine
    parameters broadcast over exactly two trailing spatial dims.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # normalise across the channel dimension manually
            mean = x.mean(1, keepdim=True)
            var = (x - mean).pow(2).mean(1, keepdim=True)
            normed = (x - mean) / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * normed + self.bias[:, None, None]
class PositionalEncodingFourier(nn.Module):
    """2D sine/cosine (Fourier) positional encoding, DETR-style.

    Produces a (B, dim, H, W) positional map: normalised cumulative x/y
    coordinates are expanded with sin/cos at ``hidden_dim`` frequencies per
    axis and projected to ``dim`` channels with a 1x1 convolution.
    """

    def __init__(self, hidden_dim=32, dim=768, temperature=10000):
        super().__init__()
        # 1x1 conv mapping the 2*hidden_dim Fourier features to dim channels
        self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1)
        self.scale = 2 * math.pi
        self.temperature = temperature  # base of the geometric frequency ladder
        self.hidden_dim = hidden_dim
        self.dim = dim

    def forward(self, B, H, W):
        # no padding is used here, so the mask is all-False (every position valid)
        mask = torch.zeros(B, H, W).bool().to(self.token_projection.weight.device)
        not_mask = ~mask
        # cumulative sums yield 1-based positions along y (dim 1) and x (dim 2)
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        eps = 1e-6
        # normalise positions to [0, 2*pi]
        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        # geometric frequencies, paired so each sin/cos couple shares one frequency
        dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=mask.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.hidden_dim)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # interleave sin (even feature indices) with cos (odd), then flatten
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(),
                             pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(),
                             pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        # (B, H, W, 2*hidden_dim) -> (B, 2*hidden_dim, H, W)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        pos = self.token_projection(pos)
        return pos
================================================
FILE: unetr_pp/network_architecture/lung/__init__.py
================================================
================================================
FILE: unetr_pp/network_architecture/lung/model_components.py
================================================
from torch import nn
from timm.models.layers import trunc_normal_
from typing import Sequence, Tuple, Union
from monai.networks.layers.utils import get_norm_layer
from monai.utils import optional_import
from unetr_pp.network_architecture.layers import LayerNorm
from unetr_pp.network_architecture.lung.transformerblock import TransformerBlock
from unetr_pp.network_architecture.dynunet_block import get_conv_layer, UnetResBlock
einops, _ = optional_import("einops")
#input_size=[28 * 24 * 20, 14 * 12 * 10, 7 * 6 * 5, ],dims=[32, 64, 128, 256]
class UnetrPPEncoder(nn.Module):
    """UNETR++ encoder (lung variant): a conv stem, three strided downsampling
    convolutions, and four stages of EPA transformer blocks.

    forward returns the final (flattened) feature map together with the list of
    per-stage hidden states consumed by the decoder.
    """

    def __init__(self, input_size=[32*48*48, 16 * 24 * 24, 8 * 12 * 12, 4*6*6],dims=[32, 64, 128, 256],
                 proj_size =[64,64,64,32], depths=[3, 3, 3, 3], num_heads=4, spatial_dims=3, in_channels=1,
                 dropout=0.0, transformer_dropout_rate=0.1 ,**kwargs):
        super().__init__()

        # stem + 3 strided convolutions, each followed by a group norm
        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(nn.Sequential(
            get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=(1, 4, 4), stride=(1, 4, 4),
                           dropout=dropout, conv_only=True, ),
            get_norm_layer(name=("group", {"num_groups": in_channels}), channels=dims[0]),
        ))
        for idx in range(3):
            self.downsample_layers.append(nn.Sequential(
                get_conv_layer(spatial_dims, dims[idx], dims[idx + 1], kernel_size=(2, 2, 2), stride=(2, 2, 2),
                               dropout=dropout, conv_only=True, ),
                get_norm_layer(name=("group", {"num_groups": dims[idx]}), channels=dims[idx + 1]),
            ))

        # one stack of TransformerBlocks per resolution stage
        self.stages = nn.ModuleList()
        for idx in range(4):
            blocks = [TransformerBlock(input_size=input_size[idx], hidden_size=dims[idx],
                                       proj_size=proj_size[idx], num_heads=num_heads,
                                       dropout_rate=transformer_dropout_rate, pos_embed=True)
                      for _ in range(depths[idx])]
            self.stages.append(nn.Sequential(*blocks))
        self.hidden_states = []  # kept for interface compatibility; forward uses a local list
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Truncated-normal init for conv/linear weights, standard LN affine init."""
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            trunc_normal_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, (LayerNorm, nn.LayerNorm)):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

    def forward_features(self, x):
        features = []
        x = self.stages[0](self.downsample_layers[0](x))
        features.append(x)
        for stage_idx in range(1, 4):
            x = self.downsample_layers[stage_idx](x)
            x = self.stages[stage_idx](x)
            if stage_idx == 3:
                # flatten the deepest stage to (B, tokens, C) for the decoder
                x = einops.rearrange(x, "b c h w d -> b (h w d) c")
            features.append(x)
        return x, features

    def forward(self, x):
        return self.forward_features(x)
class UnetrUpBlock(nn.Module):
    def __init__(
            self,
            spatial_dims: int,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[Sequence[int], int],
            upsample_kernel_size: Union[Sequence[int], int],
            norm_name: Union[Tuple, str],
            proj_size: int = 64,
            num_heads: int = 4,
            out_size: int = 0,
            depth: int = 3,
            conv_decoder: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            upsample_kernel_size: kernel size (and stride) of the transposed convolution.
            norm_name: feature normalization type and arguments.
            proj_size: projection size for keys and values in the spatial attention module.
            num_heads: number of heads inside each EPA module.
            out_size: spatial size for each decoder.
            depth: number of blocks for the current decoder stage.
            conv_decoder: if True, use a plain residual conv block instead of EPA blocks
                (used for the last decoder; see suppl. material in the paper).
        """
        super().__init__()
        # transposed conv with stride == kernel for exact spatial upsampling
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_kernel_size,
            conv_only=True,
            is_transposed=True,
        )
        self.decoder_block = nn.ModuleList()
        if conv_decoder:
            self.decoder_block.append(
                UnetResBlock(spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1,
                             norm_name=norm_name, ))
        else:
            blocks = [TransformerBlock(input_size=out_size, hidden_size=out_channels,
                                       proj_size=proj_size, num_heads=num_heads,
                                       dropout_rate=0.1, pos_embed=True)
                      for _ in range(depth)]
            self.decoder_block.append(nn.Sequential(*blocks))

    def _init_weights(self, module):
        """Truncated-normal init for conv/linear weights, standard LN affine init.

        NOTE(review): this is never passed to self.apply() here, so it only runs
        if a caller applies it explicitly -- confirm intent.
        """
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            trunc_normal_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, (nn.LayerNorm)):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

    def forward(self, inp, skip):
        # upsample, fuse with the encoder skip by addition, then refine
        out = self.transp_conv(inp) + skip
        return self.decoder_block[0](out)
================================================
FILE: unetr_pp/network_architecture/lung/transformerblock.py
================================================
import torch.nn as nn
import torch
from unetr_pp.network_architecture.dynunet_block import UnetResBlock
class TransformerBlock(nn.Module):
    """
    A transformer block, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"

    Applies pre-norm EPA attention with a layer-scale residual, followed by two
    residual conv blocks and a dropout+1x1-conv feed-forward path.
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            proj_size: int,
            num_heads: int,
            dropout_rate: float = 0.0,
            pos_embed=False,
    ) -> None:
        """
        Args:
            input_size: the size of the input for each stage (number of H*W*D tokens).
            hidden_size: dimension of hidden layer.
            proj_size: projection size for keys and values in the spatial attention module.
            num_heads: number of attention heads.
            dropout_rate: fraction of the input units to drop.
            pos_embed: bool argument to determine if positional embedding is used.
        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            print("Hidden size is ", hidden_size)
            print("Num heads is ", num_heads)
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.norm = nn.LayerNorm(hidden_size)
        # layer-scale: learnable per-channel weighting of the attention branch
        self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True)
        self.epa_block = EPA(input_size=input_size, hidden_size=hidden_size, proj_size=proj_size, num_heads=num_heads,
                             channel_attn_drop=dropout_rate,spatial_attn_drop=dropout_rate)
        self.conv51 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch")
        self.conv52 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch")
        self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1))

        # learned positional embedding over the flattened token sequence
        self.pos_embed = None
        if pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size))

    def forward(self, x):
        B, C, H, W, D = x.shape

        # flatten spatial dims to a token sequence: (B, C, H, W, D) -> (B, N, C)
        x = x.reshape(B, C, H * W * D).permute(0, 2, 1)

        if self.pos_embed is not None:
            x = x + self.pos_embed
        # pre-norm EPA attention with layer-scale residual
        attn = x + self.gamma * self.epa_block(self.norm(x))

        # back to volumetric layout for the convolutional refinement
        attn_skip = attn.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3)  # (B, C, H, W, D)
        attn = self.conv51(attn_skip)
        attn = self.conv52(attn)
        x = attn_skip + self.conv8(attn)

        return x
class EPA(nn.Module):
    """
    Efficient Paired Attention Block, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"

    Runs channel attention and (projected) spatial attention over shared
    queries/keys and fuses the two halves by concatenation.
    """

    def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False,
                 channel_attn_drop=0.1, spatial_attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        # learnable softmax temperatures, one per attention branch
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))

        # one fused projection: query_shared, key_shared, value_spatial, value_channel
        self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias)

        # E and F project keys/values from the HWD token dimension to proj_size
        self.E = nn.Linear(input_size, proj_size)
        self.F = nn.Linear(input_size, proj_size)

        self.attn_drop = nn.Dropout(channel_attn_drop)
        self.attn_drop_2 = nn.Dropout(spatial_attn_drop)

        # each branch contributes half of the output channels
        self.out_proj = nn.Linear(hidden_size, int(hidden_size // 2))
        self.out_proj2 = nn.Linear(hidden_size, int(hidden_size // 2))

    def forward(self, x):
        B, N, C = x.shape
        head_dim = C // self.num_heads

        # (B, N, 4C) -> four tensors, each (B, heads, head_dim, N)
        qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q_shared, k_shared, v_CA, v_SA = (t.transpose(-2, -1) for t in qkvv)

        # project keys/values along the token axis (done before normalisation)
        k_projected = self.E(k_shared)
        v_SA_projected = self.F(v_SA)

        q_norm = torch.nn.functional.normalize(q_shared, dim=-1)
        k_norm = torch.nn.functional.normalize(k_shared, dim=-1)

        # channel attention: head_dim x head_dim similarity
        attn_CA = self.attn_drop(((q_norm @ k_norm.transpose(-2, -1)) * self.temperature).softmax(dim=-1))
        x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C)

        # spatial attention: N x proj_size similarity
        attn_SA = self.attn_drop_2(((q_norm.permute(0, 1, 3, 2) @ k_projected) * self.temperature2).softmax(dim=-1))
        x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C)

        # concat fusion of the two halves
        return torch.cat((self.out_proj(x_SA), self.out_proj2(x_CA)), dim=-1)

    @torch.jit.ignore
    def no_weight_decay(self):
        # keep the learnable temperatures out of weight decay
        return {'temperature', 'temperature2'}
================================================
FILE: unetr_pp/network_architecture/lung/unetr_pp_lung.py
================================================
from torch import nn
from typing import Tuple, Union
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.network_architecture.dynunet_block import UnetOutBlock, UnetResBlock
from unetr_pp.network_architecture.lung.model_components import UnetrPPEncoder, UnetrUpBlock
class UNETR_PP(SegmentationNetwork):
    """
    UNETR++ based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"

    Lung configuration: an EPA transformer encoder plus a convolutional stem
    encoder, four decoder stages with skip connections, and (optionally) two
    auxiliary deep-supervision heads.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            feature_size: int = 16,
            hidden_size: int = 256,
            num_heads: int = 4,
            pos_embed: str = "perceptron",
            norm_name: Union[Tuple, str] = "instance",
            dropout_rate: float = 0.0,
            depths=None,
            dims=None,
            conv_op=nn.Conv3d,
            do_ds=True,

    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            feature_size: dimension of network feature size.
            hidden_size: dimensions of  the last encoder.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type (only validated here;
                the encoder always uses its learned embedding).
            norm_name: feature normalization type and arguments.
            dropout_rate: fraction of the input units to drop.
            depths: number of blocks for each stage.
            dims: number of channel maps for the stages.
            conv_op: type of convolution operation.
            do_ds: use deep supervision to compute the loss.
        """
        super().__init__()
        if depths is None:
            depths = [3, 3, 3, 3]
        self.do_ds = do_ds
        self.conv_op = conv_op
        self.num_classes = out_channels
        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")

        if pos_embed not in ["conv", "perceptron"]:
            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

        # spatial size of the bottleneck feature map
        # (assumes the lung input patch of 32x192x192 -- TODO confirm for other inputs)
        self.feat_size = (4, 6, 6,)
        self.hidden_size = hidden_size

        self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads)

        # parallel convolutional stem at full resolution, fused in the last decoder
        self.encoder1 = UnetResBlock(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
        )
        # decoders 5..2 progressively upsample; out_size is the token count of
        # each stage's spatial resolution
        self.decoder5 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 16,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=8*12*12,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=16*24*24,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=32*48*48,
        )
        # last decoder: undoes the (1, 4, 4) stem stride; uses a conv block
        self.decoder2 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=(1, 4, 4),
            norm_name=norm_name,
            out_size=32*192*192,
            conv_decoder=True,
        )
        self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels)
        if self.do_ds:
            # auxiliary heads on the two next-coarser decoder outputs
            self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels)
            self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels)

    def proj_feat(self, x, hidden_size, feat_size):
        # (B, tokens, hidden) -> (B, hidden, *feat_size)
        x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

    def forward(self, x_in):
        x_output, hidden_states = self.unetr_pp_encoder(x_in)

        # full-resolution convolutional features for the final fusion
        convBlock = self.encoder1(x_in)

        # Four encoders
        enc1 = hidden_states[0]
        enc2 = hidden_states[1]
        enc3 = hidden_states[2]
        enc4 = hidden_states[3]  # flattened tokens; reshaped below

        # Four decoders
        dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size)
        dec3 = self.decoder5(dec4, enc3)
        dec2 = self.decoder4(dec3, enc2)
        dec1 = self.decoder3(dec2, enc1)

        out = self.decoder2(dec1, convBlock)
        if self.do_ds:
            # deep supervision: full resolution first, then coarser outputs
            logits = [self.out1(out), self.out2(dec1), self.out3(dec2)]
        else:
            logits = self.out1(out)

        return logits
================================================
FILE: unetr_pp/network_architecture/neural_network.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from batchgenerators.augmentations.utils import pad_nd_image
from unetr_pp.utilities.random_stuff import no_op
from unetr_pp.utilities.to_torch import to_cuda, maybe_to_torch
from torch import nn
import torch
from scipy.ndimage.filters import gaussian_filter
from typing import Union, Tuple, List
from torch.cuda.amp import autocast
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
def get_device(self):
if next(self.parameters()).device == "cpu":
return "cpu"
else:
return next(self.parameters()).device.index
def set_device(self, device):
if device == "cpu":
self.cpu()
else:
self.cuda(device)
def forward(self, x):
raise NotImplementedError
class SegmentationNetwork(NeuralNetwork):
def __init__(self):
super(NeuralNetwork, self).__init__()
# if we have 5 pooling then our patch size must be divisible by 2**5
self.input_shape_must_be_divisible_by = None # for example in a 2d network that does 5 pool in x and 6 pool
# in y this would be (32, 64)
# we need to know this because we need to know if we are a 2d or a 3d netowrk
self.conv_op = None # nn.Conv2d or nn.Conv3d
# this tells us how many channely we have in the output. Important for preallocation in inference
self.num_classes = None # number of channels in the output
# depending on the loss, we do not hard code a nonlinearity into the architecture. To aggregate predictions
# during inference, we need to apply the nonlinearity, however. So it is important to let the newtork know what
# to apply in inference. For the most part this will be softmax
self.inference_apply_nonlin = lambda x: x # softmax_helper
# This is for saving a gaussian importance map for inference. It weights voxels higher that are closer to the
# center. Prediction at the borders are often less accurate and are thus downweighted. Creating these Gaussians
# can be expensive, so it makes sense to save and reuse them.
self._gaussian_3d = self._patch_size_for_gaussian_3d = None
self._gaussian_2d = self._patch_size_for_gaussian_2d = None
    def predict_3D(self, x: np.ndarray, do_mirroring: bool, mirror_axes: Tuple[int, ...] = (0, 1, 2),
                   use_sliding_window: bool = False,
                   step_size: float = 0.5, patch_size: Tuple[int, ...] = None, regions_class_order: Tuple[int, ...] = None,
                   use_gaussian: bool = False, pad_border_mode: str = "constant",
                   pad_kwargs: dict = None, all_in_gpu: bool = False,
                   verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """
        Use this function to predict a 3D image. It does not matter whether the network is a 2D or 3D U-Net, it will
        detect that automatically and run the appropriate code.

        When running predictions, you need to specify whether you want to run fully convolutional or sliding window
        based inference. We very strongly recommend you use sliding window with the default settings.

        It is the responsibility of the user to make sure the network is in the proper mode (eval for inference!). If
        the network is not in eval mode it will print a warning.

        :param x: Your input data. Must be a np.ndarray of shape (c, x, y, z).
        :param do_mirroring: If True, use test time data augmentation in the form of mirroring
        :param mirror_axes: Determines which axes to use for mirroring. Per default, mirroring is done along all three
        axes
        :param use_sliding_window: if True, run sliding window prediction. Heavily recommended! This is also the default
        :param step_size: When running sliding window prediction, the step size determines the distance between adjacent
        predictions. The smaller the step size, the denser the predictions (and the longer it takes!). Step size is given
        as a fraction of the patch_size. 0.5 is the default and means that we advance by patch_size * 0.5 between
        predictions. step_size cannot be larger than 1!
        :param patch_size: The patch size that was used for training the network. Do not use different patch sizes here,
        this will either crash or give potentially less accurate segmentations
        :param regions_class_order: Fabian only
        :param use_gaussian: (Only applies to sliding window prediction) If True, uses a Gaussian importance weighting
         to weigh predictions closer to the center of the current patch higher than those at the borders. The reason
         behind this is that the segmentation accuracy decreases towards the borders. Default (and recommended): True
        :param pad_border_mode: leave this alone
        :param pad_kwargs: leave this alone
        :param all_in_gpu: experimental. You probably want to leave this as it is
        :param verbose: Do you want a wall of text? If yes then set this to True
        :param mixed_precision: if True, will run inference in mixed precision with autocast()
        :return: (predicted segmentation, class probabilities)
        """
        torch.cuda.empty_cache()

        assert step_size <= 1, 'step_size must be smaller than 1. Otherwise there will be a gap between consecutive ' \
                               'predictions'

        if verbose: print("debug: mirroring", do_mirroring, "mirror_axes", mirror_axes)

        assert self.get_device() != "cpu", "CPU not implemented"

        if pad_kwargs is None:
            pad_kwargs = {'constant_values': 0}

        # A very long time ago the mirror axes were (2, 3, 4) for a 3d network. This is just to intercept any old
        # code that uses this convention
        # NOTE(review): conv_op is unconditionally overwritten with nn.Conv3d here,
        # which makes the Conv2d branches below unreachable. Presumably intentional
        # for this repo (all UNETR++ variants are 3d) -- confirm before reusing for 2d.
        self.conv_op = nn.Conv3d
        if len(mirror_axes):
            if self.conv_op == nn.Conv2d:
                if max(mirror_axes) > 1:
                    raise ValueError("mirror axes. duh")
            if self.conv_op == nn.Conv3d:
                if max(mirror_axes) > 2:
                    raise ValueError("mirror axes. duh")

        if self.training:
            print('WARNING! Network is in train mode during inference. This may be intended, or not...')

        assert len(x.shape) == 4, "data must have shape (c,x,y,z)"

        # mixed precision only wraps the no_grad block; no_op is a do-nothing context manager
        if mixed_precision:
            context = autocast
        else:
            context = no_op

        with context():
            with torch.no_grad():
                if self.conv_op == nn.Conv3d:
                    if use_sliding_window:
                        res = self._internal_predict_3D_3Dconv_tiled(x, step_size, do_mirroring, mirror_axes, patch_size,
                                                                     regions_class_order, use_gaussian, pad_border_mode,
                                                                     pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu,
                                                                     verbose=verbose)
                    else:
                        res = self._internal_predict_3D_3Dconv(x, patch_size, do_mirroring, mirror_axes, regions_class_order,
                                                               pad_border_mode, pad_kwargs=pad_kwargs, verbose=verbose)
                elif self.conv_op == nn.Conv2d:
                    # 2d network predicting a 3d volume: predict slice by slice
                    if use_sliding_window:
                        res = self._internal_predict_3D_2Dconv_tiled(x, patch_size, do_mirroring, mirror_axes, step_size,
                                                                     regions_class_order, use_gaussian, pad_border_mode,
                                                                     pad_kwargs, all_in_gpu, False)
                    else:
                        res = self._internal_predict_3D_2Dconv(x, patch_size, do_mirroring, mirror_axes, regions_class_order,
                                                               pad_border_mode, pad_kwargs, all_in_gpu, False)
                else:
                    raise RuntimeError("Invalid conv op, cannot determine what dimensionality (2d/3d) the network is")

        return res
    def predict_2D(self, x, do_mirroring: bool, mirror_axes: tuple = (0, 1, 2), use_sliding_window: bool = False,
                   step_size: float = 0.5, patch_size: tuple = None, regions_class_order: tuple = None,
                   use_gaussian: bool = False, pad_border_mode: str = "constant",
                   pad_kwargs: dict = None, all_in_gpu: bool = False,
                   verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """
        Use this function to predict a 2D image. If this is a 3D U-Net it will crash because you cannot predict a 2D
        image with that (you dummy).

        When running predictions, you need to specify whether you want to run fully convolutional or sliding window
        based inference. We very strongly recommend you use sliding window with the default settings.

        It is the responsibility of the user to make sure the network is in the proper mode (eval for inference!). If
        the network is not in eval mode it will print a warning.

        :param x: Your input data. Must be a np.ndarray of shape (c, x, y).
        :param do_mirroring: If True, use test time data augmentation in the form of mirroring
        :param mirror_axes: Determines which axes to use for mirroring. Per default, mirroring is done along all axes
        :param use_sliding_window: if True, run sliding window prediction. Heavily recommended! This is also the default
        :param step_size: When running sliding window prediction, the step size determines the distance between adjacent
        predictions. The smaller the step size, the denser the predictions (and the longer it takes!). Step size is given
        as a fraction of the patch_size. 0.5 is the default and means that we advance by patch_size * 0.5 between
        predictions. step_size cannot be larger than 1!
        :param patch_size: The patch size that was used for training the network. Do not use different patch sizes here,
        this will either crash or give potentially less accurate segmentations
        :param regions_class_order: Fabian only
        :param use_gaussian: (Only applies to sliding window prediction) If True, uses a Gaussian importance weighting
         to weigh predictions closer to the center of the current patch higher than those at the borders. The reason
         behind this is that the segmentation accuracy decreases towards the borders. Default (and recommended): True
        :param pad_border_mode: leave this alone
        :param pad_kwargs: leave this alone
        :param all_in_gpu: experimental. You probably want to leave this as it is
        :param verbose: Do you want a wall of text? If yes then set this to True
        :param mixed_precision: if True, will run inference in mixed precision with autocast()
        :return: (predicted segmentation, class probabilities)
        """
        torch.cuda.empty_cache()

        assert step_size <= 1, 'step_size must be smaler than 1. Otherwise there will be a gap between consecutive ' \
                               'predictions'

        if self.conv_op == nn.Conv3d:
            raise RuntimeError("Cannot predict 2d if the network is 3d. Dummy.")

        if verbose: print("debug: mirroring", do_mirroring, "mirror_axes", mirror_axes)

        assert self.get_device() != "cpu", "CPU not implemented"

        if pad_kwargs is None:
            pad_kwargs = {'constant_values': 0}

        # A very long time ago the mirror axes were (2, 3) for a 2d network. This is just to intercept any old
        # code that uses this convention
        if len(mirror_axes):
            if max(mirror_axes) > 1:
                raise ValueError("mirror axes. duh")

        if self.training:
            print('WARNING! Network is in train mode during inference. This may be intended, or not...')

        assert len(x.shape) == 3, "data must have shape (c,x,y)"

        # mixed precision only wraps the no_grad block; no_op is a do-nothing context manager
        if mixed_precision:
            context = autocast
        else:
            context = no_op

        with context():
            with torch.no_grad():
                if self.conv_op == nn.Conv2d:
                    if use_sliding_window:
                        res = self._internal_predict_2D_2Dconv_tiled(x, step_size, do_mirroring, mirror_axes, patch_size,
                                                                     regions_class_order, use_gaussian, pad_border_mode,
                                                                     pad_kwargs, all_in_gpu, verbose)
                    else:
                        res = self._internal_predict_2D_2Dconv(x, patch_size, do_mirroring, mirror_axes, regions_class_order,
                                                               pad_border_mode, pad_kwargs, verbose)
                else:
                    raise RuntimeError("Invalid conv op, cannot determine what dimensionality (2d/3d) the network is")

        return res
@staticmethod
def _get_gaussian(patch_size, sigma_scale=1. / 8) -> np.ndarray:
tmp = np.zeros(patch_size)
center_coords = [i // 2 for i in patch_size]
sigmas = [i * sigma_scale for i in patch_size]
tmp[tuple(center_coords)] = 1
gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map) * 1
gaussian_importance_map = gaussian_importance_map.astype(np.float32)
# gaussian_importance_map cannot be 0, otherwise we may end up with nans!
gaussian_importance_map[gaussian_importance_map == 0] = np.min(
gaussian_importance_map[gaussian_importance_map != 0])
return gaussian_importance_map
@staticmethod
def _compute_steps_for_sliding_window(patch_size: Tuple[int, ...], image_size: Tuple[int, ...], step_size: float) -> List[List[int]]:
assert [i >= j for i, j in zip(image_size, patch_size)], "image size must be as large or larger than patch_size"
assert 0 < step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1'
# our step width is patch_size*step_size at most, but can be narrower. For example if we have image size of
# 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps starting at coordinate 0, 23, 46
target_step_sizes_in_voxels = [i * step_size for i in patch_size]
num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, patch_size)]
steps = []
for dim in range(len(patch_size)):
# the highest step value for this dimension is
max_step_value = image_size[dim] - patch_size[dim]
if num_steps[dim] > 1:
actual_step_size = max_step_value / (num_steps[dim] - 1)
else:
actual_step_size = 99999999999 # does not matter because there is only one step at 0
steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])]
steps.append(steps_here)
return steps
def _internal_predict_3D_3Dconv_tiled(self, x: np.ndarray, step_size: float, do_mirroring: bool, mirror_axes: tuple,
                                      patch_size: tuple, regions_class_order: tuple, use_gaussian: bool,
                                      pad_border_mode: str, pad_kwargs: dict, all_in_gpu: bool,
                                      verbose: bool) -> Tuple[np.ndarray, np.ndarray]:
    """Sliding-window 3D inference over a (c, x, y, z) volume.

    The volume is padded to at least patch_size, predicted patch by patch
    (optionally weighted with a Gaussian importance map and with test-time
    mirroring), and the overlapping predictions are averaged.

    :param x: input volume, shape (c, x, y, z)
    :param step_size: patch step as a fraction of patch_size (0 < step_size <= 1)
    :param regions_class_order: if not None, channels are thresholded at 0.5 and
        painted in this order instead of taking an argmax
    :param all_in_gpu: keep all accumulation buffers on the GPU (half precision)
    :return: (predicted_segmentation, class_probabilities), both numpy arrays
    """
    # better safe than sorry
    assert len(x.shape) == 4, "x must be (c, x, y, z)"
    assert self.get_device() != "cpu"
    if verbose: print("step_size:", step_size)
    if verbose: print("do mirror:", do_mirroring)
    assert patch_size is not None, "patch_size cannot be None for tiled prediction"
    # for sliding window inference the image must at least be as large as the patch size. It does not matter
    # whether the shape is divisible by 2**num_pool as long as the patch size is
    data, slicer = pad_nd_image(x, patch_size, pad_border_mode, pad_kwargs, True, None)
    data_shape = data.shape  # still c, x, y, z
    # compute the steps for sliding window
    steps = self._compute_steps_for_sliding_window(patch_size, data_shape[1:], step_size)
    num_tiles = len(steps[0]) * len(steps[1]) * len(steps[2])
    if verbose:
        print("data shape:", data_shape)
        print("patch size:", patch_size)
        print("steps (x, y, and z):", steps)
        print("number of tiles:", num_tiles)
    # we only need to compute that once. It can take a while to compute this due to the large sigma in
    # gaussian_filter
    if use_gaussian and num_tiles > 1:
        # the Gaussian map is cached on self and only recomputed when patch_size changes
        if self._gaussian_3d is None or not all(
                [i == j for i, j in zip(patch_size, self._patch_size_for_gaussian_3d)]):
            if verbose: print('computing Gaussian')
            gaussian_importance_map = self._get_gaussian(patch_size, sigma_scale=1. / 8)
            self._gaussian_3d = gaussian_importance_map
            self._patch_size_for_gaussian_3d = patch_size
        else:
            if verbose: print("using precomputed Gaussian")
            gaussian_importance_map = self._gaussian_3d
        gaussian_importance_map = torch.from_numpy(gaussian_importance_map).cuda(self.get_device(),
                                                                                non_blocking=True)
    else:
        gaussian_importance_map = None
    if all_in_gpu:
        # If we run the inference in GPU only (meaning all tensors are allocated on the GPU, this reduces
        # CPU-GPU communication but required more GPU memory) we need to preallocate a few things on GPU
        if use_gaussian and num_tiles > 1:
            # half precision for the outputs should be good enough. If the outputs here are half, the
            # gaussian_importance_map should be as well
            gaussian_importance_map = gaussian_importance_map.half()
            # make sure we did not round anything to 0
            gaussian_importance_map[gaussian_importance_map == 0] = gaussian_importance_map[
                gaussian_importance_map != 0].min()
            add_for_nb_of_preds = gaussian_importance_map
        else:
            add_for_nb_of_preds = torch.ones(data.shape[1:], device=self.get_device())
        if verbose: print("initializing result array (on GPU)")
        aggregated_results = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
                                         device=self.get_device())
        if verbose: print("moving data to GPU")
        data = torch.from_numpy(data).cuda(self.get_device(), non_blocking=True)
        if verbose: print("initializing result_numsamples (on GPU)")
        aggregated_nb_of_predictions = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
                                                   device=self.get_device())
    else:
        if use_gaussian and num_tiles > 1:
            # NOTE(review): this reuses the cached numpy map (self._gaussian_3d), not the cuda
            # tensor built above — presumably intentional for the CPU-side accumulators; confirm
            add_for_nb_of_preds = self._gaussian_3d
        else:
            add_for_nb_of_preds = np.ones(data.shape[1:], dtype=np.float32)
        aggregated_results = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
        aggregated_nb_of_predictions = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
    # NOTE: the loop variable x shadows the input volume from here on; only `data` is used below
    for x in steps[0]:
        lb_x = x
        ub_x = x + patch_size[0]
        for y in steps[1]:
            lb_y = y
            ub_y = y + patch_size[1]
            for z in steps[2]:
                lb_z = z
                ub_z = z + patch_size[2]
                predicted_patch = self._internal_maybe_mirror_and_pred_3D(
                    data[None, :, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z], mirror_axes, do_mirroring,
                    gaussian_importance_map)[0]
                if all_in_gpu:
                    predicted_patch = predicted_patch.half()
                else:
                    predicted_patch = predicted_patch.cpu().numpy()
                aggregated_results[:, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z] += predicted_patch
                aggregated_nb_of_predictions[:, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z] += add_for_nb_of_preds
    # we reverse the padding here (remeber that we padded the input to be at least as large as the patch size
    slicer = tuple(
        [slice(0, aggregated_results.shape[i]) for i in
         range(len(aggregated_results.shape) - (len(slicer) - 1))] + slicer[1:])
    aggregated_results = aggregated_results[slicer]
    aggregated_nb_of_predictions = aggregated_nb_of_predictions[slicer]
    # computing the class_probabilities by dividing the aggregated result with result_numsamples
    class_probabilities = aggregated_results / aggregated_nb_of_predictions
    if regions_class_order is None:
        predicted_segmentation = class_probabilities.argmax(0)
    else:
        # region-based training: one sigmoid channel per region, thresholded at 0.5
        if all_in_gpu:
            class_probabilities_here = class_probabilities.detach().cpu().numpy()
        else:
            class_probabilities_here = class_probabilities
        predicted_segmentation = np.zeros(class_probabilities_here.shape[1:], dtype=np.float32)
        for i, c in enumerate(regions_class_order):
            predicted_segmentation[class_probabilities_here[i] > 0.5] = c
    if all_in_gpu:
        if verbose: print("copying results to CPU")
        if regions_class_order is None:
            # in this case the argmax result is still a cuda tensor
            predicted_segmentation = predicted_segmentation.detach().cpu().numpy()
        class_probabilities = class_probabilities.detach().cpu().numpy()
    if verbose: print("prediction done")
    return predicted_segmentation, class_probabilities
def _internal_predict_2D_2Dconv(self, x: np.ndarray, min_size: Tuple[int, int], do_mirroring: bool,
                                mirror_axes: tuple = (0, 1, 2), regions_class_order: tuple = None,
                                pad_border_mode: str = "constant", pad_kwargs: dict = None,
                                verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fully convolutional 2D inference on one (c, x, y) image, no sliding window:
    pad to min_size / the divisibility constraint, predict in a single forward
    pass (optionally mirrored), then crop the padding away again.
    """
    assert len(x.shape) == 3, "x must be (c, x, y)"
    assert self.get_device() != "cpu"
    assert self.input_shape_must_be_divisible_by is not None, 'input_shape_must_be_divisible_by must be set to ' \
                                                              'run _internal_predict_2D_2Dconv'
    if verbose:
        print("do mirror:", do_mirroring)

    padded, pad_slicer = pad_nd_image(x, min_size, pad_border_mode, pad_kwargs, True,
                                      self.input_shape_must_be_divisible_by)

    probs = self._internal_maybe_mirror_and_pred_2D(padded[None], mirror_axes, do_mirroring, None)[0]

    # keep the leading class axis in full, crop the spatial axes back to the input size
    n_keep = len(probs.shape) - (len(pad_slicer) - 1)
    probs = probs[tuple([slice(0, probs.shape[i]) for i in range(n_keep)] + pad_slicer[1:])]

    if regions_class_order is None:
        seg = probs.argmax(0).detach().cpu().numpy()
        probs = probs.detach().cpu().numpy()
    else:
        # region-based output: threshold each channel and paint labels in order
        probs = probs.detach().cpu().numpy()
        seg = np.zeros(probs.shape[1:], dtype=np.float32)
        for channel, label in enumerate(regions_class_order):
            seg[probs[channel] > 0.5] = label
    return seg, probs
def _internal_predict_3D_3Dconv(self, x: np.ndarray, min_size: Tuple[int, ...], do_mirroring: bool,
                                mirror_axes: tuple = (0, 1, 2), regions_class_order: tuple = None,
                                pad_border_mode: str = "constant", pad_kwargs: dict = None,
                                verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fully convolutional 3D inference (no sliding window): the whole volume is
    padded so the network accepts it, predicted in one forward pass (optionally
    mirrored), and the padding is removed afterwards.
    """
    assert len(x.shape) == 4, "x must be (c, x, y, z)"
    assert self.get_device() != "cpu"
    assert self.input_shape_must_be_divisible_by is not None, 'input_shape_must_be_divisible_by must be set to ' \
                                                              'run _internal_predict_3D_3Dconv'
    if verbose:
        print("do mirror:", do_mirroring)

    padded, pad_slicer = pad_nd_image(x, min_size, pad_border_mode, pad_kwargs, True,
                                      self.input_shape_must_be_divisible_by)

    probs = self._internal_maybe_mirror_and_pred_3D(padded[None], mirror_axes, do_mirroring, None)[0]

    # crop the spatial padding away; the class axis stays untouched
    keep = len(probs.shape) - (len(pad_slicer) - 1)
    crop = tuple([slice(0, probs.shape[i]) for i in range(keep)] + pad_slicer[1:])
    probs = probs[crop]

    if regions_class_order is None:
        seg = probs.argmax(0).detach().cpu().numpy()
        probs = probs.detach().cpu().numpy()
    else:
        # region-based output: threshold each channel and paint labels in order
        probs = probs.detach().cpu().numpy()
        seg = np.zeros(probs.shape[1:], dtype=np.float32)
        for channel, label in enumerate(regions_class_order):
            seg[probs[channel] > 0.5] = label
    return seg, probs
def _internal_maybe_mirror_and_pred_3D(self, x: Union[np.ndarray, torch.tensor], mirror_axes: tuple,
                                       do_mirroring: bool = True,
                                       mult: np.ndarray or torch.tensor = None) -> torch.tensor:
    """
    Run the network on x and average the predictions over every requested
    mirror (flip) combination. Everything happens on the GPU; the return value
    is a cuda tensor of shape (1, num_classes, x, y, z), optionally multiplied
    by `mult`. mirror_axes uses spatial indices (0, 1, 2) == tensor dims (2, 3, 4).
    """
    assert len(x.shape) == 5, 'x must be (b, c, x, y, z)'
    # move input (and the optional weighting map) to the GPU if necessary
    x = to_cuda(maybe_to_torch(x), gpu_id=self.get_device())
    result_torch = torch.zeros([1, self.num_classes] + list(x.shape[2:]),
                               dtype=torch.float).cuda(self.get_device(), non_blocking=True)
    if mult is not None:
        mult = to_cuda(maybe_to_torch(mult), gpu_id=self.get_device())

    # (required spatial mirror axes, tensor dims to flip); the order matches the
    # original 8-way enumeration so the floating point summation order is stable
    flip_plan = [
        ((), ()),
        ((2,), (4,)),
        ((1,), (3,)),
        ((1, 2), (4, 3)),
        ((0,), (2,)),
        ((0, 2), (4, 2)),
        ((0, 1), (3, 2)),
        ((0, 1, 2), (4, 3, 2)),
    ]
    if do_mirroring:
        num_results = 2 ** len(mirror_axes)
    else:
        num_results = 1
        flip_plan = flip_plan[:1]

    for required_axes, flip_dims in flip_plan:
        if any(a not in mirror_axes for a in required_axes):
            continue
        net_in = torch.flip(x, flip_dims) if flip_dims else x
        pred = self.inference_apply_nonlin(self(net_in))
        if flip_dims:
            # flip the prediction back so all contributions are aligned
            pred = torch.flip(pred, flip_dims)
        result_torch += 1 / num_results * pred

    if mult is not None:
        result_torch[:, :] *= mult
    return result_torch
def _internal_maybe_mirror_and_pred_2D(self, x: Union[np.ndarray, torch.tensor], mirror_axes: tuple,
                                       do_mirroring: bool = True,
                                       mult: np.ndarray or torch.tensor = None) -> torch.tensor:
    """
    2D counterpart of _internal_maybe_mirror_and_pred_3D: run the network on x
    and average the predictions over every requested mirror combination.
    Returns a cuda tensor of shape (b, num_classes, x, y), optionally
    multiplied by `mult`. mirror_axes uses spatial indices (0, 1) == dims (2, 3).
    """
    assert len(x.shape) == 4, 'x must be (b, c, x, y)'
    x = to_cuda(maybe_to_torch(x), gpu_id=self.get_device())
    result_torch = torch.zeros([x.shape[0], self.num_classes] + list(x.shape[2:]),
                               dtype=torch.float).cuda(self.get_device(), non_blocking=True)
    if mult is not None:
        mult = to_cuda(maybe_to_torch(mult), gpu_id=self.get_device())

    # (required spatial mirror axes, tensor dims to flip), in the original order
    flip_plan = [
        ((), ()),
        ((1,), (3,)),
        ((0,), (2,)),
        ((0, 1), (3, 2)),
    ]
    if do_mirroring:
        num_results = 2 ** len(mirror_axes)
    else:
        num_results = 1
        flip_plan = flip_plan[:1]

    for required_axes, flip_dims in flip_plan:
        if any(a not in mirror_axes for a in required_axes):
            continue
        net_in = torch.flip(x, flip_dims) if flip_dims else x
        pred = self.inference_apply_nonlin(self(net_in))
        if flip_dims:
            # undo the flip so all contributions are spatially aligned
            pred = torch.flip(pred, flip_dims)
        result_torch += 1 / num_results * pred

    if mult is not None:
        result_torch[:, :] *= mult
    return result_torch
def _internal_predict_2D_2Dconv_tiled(self, x: np.ndarray, step_size: float, do_mirroring: bool, mirror_axes: tuple,
                                      patch_size: tuple, regions_class_order: tuple, use_gaussian: bool,
                                      pad_border_mode: str, pad_kwargs: dict, all_in_gpu: bool,
                                      verbose: bool) -> Tuple[np.ndarray, np.ndarray]:
    """Sliding-window 2D inference over a (c, x, y) image.

    2D counterpart of _internal_predict_3D_3Dconv_tiled: pad to at least
    patch_size, predict tile by tile (optionally Gaussian-weighted and with
    test-time mirroring), average the overlapping predictions, undo the padding.

    :return: (predicted_segmentation, class_probabilities), both numpy arrays
    """
    # better safe than sorry
    assert len(x.shape) == 3, "x must be (c, x, y)"
    assert self.get_device() != "cpu"
    if verbose: print("step_size:", step_size)
    if verbose: print("do mirror:", do_mirroring)
    assert patch_size is not None, "patch_size cannot be None for tiled prediction"
    # for sliding window inference the image must at least be as large as the patch size. It does not matter
    # whether the shape is divisible by 2**num_pool as long as the patch size is
    data, slicer = pad_nd_image(x, patch_size, pad_border_mode, pad_kwargs, True, None)
    data_shape = data.shape  # still c, x, y
    # compute the steps for sliding window
    steps = self._compute_steps_for_sliding_window(patch_size, data_shape[1:], step_size)
    num_tiles = len(steps[0]) * len(steps[1])
    if verbose:
        print("data shape:", data_shape)
        print("patch size:", patch_size)
        # NOTE(review): label copied from the 3D path; only x and y steps exist here
        print("steps (x, y, and z):", steps)
        print("number of tiles:", num_tiles)
    # we only need to compute that once. It can take a while to compute this due to the large sigma in
    # gaussian_filter
    if use_gaussian and num_tiles > 1:
        # the Gaussian map is cached on self and only recomputed when patch_size changes
        if self._gaussian_2d is None or not all(
                [i == j for i, j in zip(patch_size, self._patch_size_for_gaussian_2d)]):
            if verbose: print('computing Gaussian')
            gaussian_importance_map = self._get_gaussian(patch_size, sigma_scale=1. / 8)
            self._gaussian_2d = gaussian_importance_map
            self._patch_size_for_gaussian_2d = patch_size
        else:
            if verbose: print("using precomputed Gaussian")
            gaussian_importance_map = self._gaussian_2d
        gaussian_importance_map = torch.from_numpy(gaussian_importance_map).cuda(self.get_device(),
                                                                                non_blocking=True)
    else:
        gaussian_importance_map = None
    if all_in_gpu:
        # If we run the inference in GPU only (meaning all tensors are allocated on the GPU, this reduces
        # CPU-GPU communication but required more GPU memory) we need to preallocate a few things on GPU
        if use_gaussian and num_tiles > 1:
            # half precision for the outputs should be good enough. If the outputs here are half, the
            # gaussian_importance_map should be as well
            gaussian_importance_map = gaussian_importance_map.half()
            # make sure we did not round anything to 0
            gaussian_importance_map[gaussian_importance_map == 0] = gaussian_importance_map[
                gaussian_importance_map != 0].min()
            add_for_nb_of_preds = gaussian_importance_map
        else:
            add_for_nb_of_preds = torch.ones(data.shape[1:], device=self.get_device())
        if verbose: print("initializing result array (on GPU)")
        aggregated_results = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
                                         device=self.get_device())
        if verbose: print("moving data to GPU")
        data = torch.from_numpy(data).cuda(self.get_device(), non_blocking=True)
        if verbose: print("initializing result_numsamples (on GPU)")
        aggregated_nb_of_predictions = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
                                                   device=self.get_device())
    else:
        if use_gaussian and num_tiles > 1:
            # NOTE(review): reuses the cached numpy map (self._gaussian_2d), not the cuda
            # tensor built above — presumably intentional for the CPU-side accumulators; confirm
            add_for_nb_of_preds = self._gaussian_2d
        else:
            add_for_nb_of_preds = np.ones(data.shape[1:], dtype=np.float32)
        aggregated_results = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
        aggregated_nb_of_predictions = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
    # NOTE: the loop variable x shadows the input image from here on; only `data` is used below
    for x in steps[0]:
        lb_x = x
        ub_x = x + patch_size[0]
        for y in steps[1]:
            lb_y = y
            ub_y = y + patch_size[1]
            predicted_patch = self._internal_maybe_mirror_and_pred_2D(
                data[None, :, lb_x:ub_x, lb_y:ub_y], mirror_axes, do_mirroring,
                gaussian_importance_map)[0]
            if all_in_gpu:
                predicted_patch = predicted_patch.half()
            else:
                predicted_patch = predicted_patch.cpu().numpy()
            aggregated_results[:, lb_x:ub_x, lb_y:ub_y] += predicted_patch
            aggregated_nb_of_predictions[:, lb_x:ub_x, lb_y:ub_y] += add_for_nb_of_preds
    # we reverse the padding here (remeber that we padded the input to be at least as large as the patch size
    slicer = tuple(
        [slice(0, aggregated_results.shape[i]) for i in
         range(len(aggregated_results.shape) - (len(slicer) - 1))] + slicer[1:])
    aggregated_results = aggregated_results[slicer]
    aggregated_nb_of_predictions = aggregated_nb_of_predictions[slicer]
    # computing the class_probabilities by dividing the aggregated result with result_numsamples
    class_probabilities = aggregated_results / aggregated_nb_of_predictions
    if regions_class_order is None:
        predicted_segmentation = class_probabilities.argmax(0)
    else:
        # region-based training: one sigmoid channel per region, thresholded at 0.5
        if all_in_gpu:
            class_probabilities_here = class_probabilities.detach().cpu().numpy()
        else:
            class_probabilities_here = class_probabilities
        predicted_segmentation = np.zeros(class_probabilities_here.shape[1:], dtype=np.float32)
        for i, c in enumerate(regions_class_order):
            predicted_segmentation[class_probabilities_here[i] > 0.5] = c
    if all_in_gpu:
        if verbose: print("copying results to CPU")
        if regions_class_order is None:
            # in this case the argmax result is still a cuda tensor
            predicted_segmentation = predicted_segmentation.detach().cpu().numpy()
        class_probabilities = class_probabilities.detach().cpu().numpy()
    if verbose: print("prediction done")
    return predicted_segmentation, class_probabilities
def _internal_predict_3D_2Dconv(self, x: np.ndarray, min_size: Tuple[int, int], do_mirroring: bool,
                                mirror_axes: tuple = (0, 1), regions_class_order: tuple = None,
                                pad_border_mode: str = "constant", pad_kwargs: dict = None,
                                all_in_gpu: bool = False, verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
    """
    Predict a 3D volume with a 2D network: each slice along the first spatial
    axis is run through _internal_predict_2D_2Dconv and the per-slice results
    are stacked back into a volume. Softmax output is returned as (c, x, y, z).
    """
    if all_in_gpu:
        raise NotImplementedError
    assert len(x.shape) == 4, "data must be c, x, y, z"
    seg_slices = []
    softmax_slices = []
    for slice_idx in range(x.shape[1]):
        cur_seg, cur_softmax = self._internal_predict_2D_2Dconv(
            x[:, slice_idx], min_size, do_mirroring, mirror_axes, regions_class_order, pad_border_mode,
            pad_kwargs, verbose)
        seg_slices.append(cur_seg[None])
        softmax_slices.append(cur_softmax[None])
    predicted_segmentation = np.vstack(seg_slices)
    # (slices, c, x, y) -> (c, slices, x, y)
    softmax_pred = np.vstack(softmax_slices).transpose((1, 0, 2, 3))
    return predicted_segmentation, softmax_pred
def predict_3D_pseudo3D_2Dconv(self, x: np.ndarray, min_size: Tuple[int, int], do_mirroring: bool,
                               mirror_axes: tuple = (0, 1), regions_class_order: tuple = None,
                               pseudo3D_slices: int = 5, all_in_gpu: bool = False,
                               pad_border_mode: str = "constant", pad_kwargs: dict = None,
                               verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
    """
    Pseudo-3D inference: each slice is predicted together with its
    (pseudo3D_slices - 1) / 2 neighbours on both sides, folded into the
    channel dimension. The volume is zero-padded along the slice axis so
    border slices also get a full neighbourhood.
    """
    if all_in_gpu:
        raise NotImplementedError
    assert len(x.shape) == 4, "data must be c, x, y, z"
    assert pseudo3D_slices % 2 == 1, "pseudo3D_slices must be odd"
    extra_slices = (pseudo3D_slices - 1) // 2

    # zero padding along the slice axis only
    pad_shape = np.array(x.shape)
    pad_shape[1] = extra_slices
    zero_pad = np.zeros(pad_shape, dtype=np.float32)
    data = np.concatenate((zero_pad, x, zero_pad), 1)

    seg_slices = []
    softmax_slices = []
    for s in range(extra_slices, data.shape[1] - extra_slices):
        neighbourhood = data[:, (s - extra_slices):(s + extra_slices + 1)]
        # fold the neighbouring slices into the channel axis
        neighbourhood = neighbourhood.reshape((-1, neighbourhood.shape[-2], neighbourhood.shape[-1]))
        cur_seg, cur_softmax = \
            self._internal_predict_2D_2Dconv(neighbourhood, min_size, do_mirroring, mirror_axes,
                                             regions_class_order, pad_border_mode, pad_kwargs, verbose)
        seg_slices.append(cur_seg[None])
        softmax_slices.append(cur_softmax[None])
    predicted_segmentation = np.vstack(seg_slices)
    # (slices, c, x, y) -> (c, slices, x, y)
    softmax_pred = np.vstack(softmax_slices).transpose((1, 0, 2, 3))
    return predicted_segmentation, softmax_pred
def _internal_predict_3D_2Dconv_tiled(self, x: np.ndarray, patch_size: Tuple[int, int], do_mirroring: bool,
                                      mirror_axes: tuple = (0, 1), step_size: float = 0.5,
                                      regions_class_order: tuple = None, use_gaussian: bool = False,
                                      pad_border_mode: str = "edge", pad_kwargs: dict = None,
                                      all_in_gpu: bool = False,
                                      verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
    """
    Tiled 2D prediction of a 3D volume: every slice along the first spatial
    axis goes through the sliding-window 2D path and the per-slice results
    are stacked back into a volume. Softmax comes back as (c, x, y, z).
    """
    if all_in_gpu:
        raise NotImplementedError
    assert len(x.shape) == 4, "data must be c, x, y, z"
    seg_slices = []
    softmax_slices = []
    for slice_idx in range(x.shape[1]):
        cur_seg, cur_softmax = self._internal_predict_2D_2Dconv_tiled(
            x[:, slice_idx], step_size, do_mirroring, mirror_axes, patch_size, regions_class_order,
            use_gaussian, pad_border_mode, pad_kwargs, all_in_gpu, verbose)
        seg_slices.append(cur_seg[None])
        softmax_slices.append(cur_softmax[None])
    predicted_segmentation = np.vstack(seg_slices)
    # (slices, c, x, y) -> (c, slices, x, y)
    softmax_pred = np.vstack(softmax_slices).transpose((1, 0, 2, 3))
    return predicted_segmentation, softmax_pred
if __name__ == '__main__':
    # quick manual sanity checks for the sliding-window step computation
    _cases = [
        ((30, 224, 224), (162, 529, 529), 0.5),
        ((30, 224, 224), (162, 529, 529), 1),
        ((30, 224, 224), (162, 529, 529), 0.1),
        ((30, 224, 224), (60, 448, 224), 1),
        ((30, 224, 224), (60, 448, 224), 0.5),
        ((30, 224, 224), (30, 224, 224), 1),
        ((30, 224, 224), (30, 224, 224), 0.125),
        ((123, 54, 123), (246, 162, 369), 0.25),
    ]
    for _patch, _image, _step in _cases:
        print(SegmentationNetwork._compute_steps_for_sliding_window(_patch, _image, _step))
================================================
FILE: unetr_pp/network_architecture/synapse/__init__.py
================================================
================================================
FILE: unetr_pp/network_architecture/synapse/model_components.py
================================================
from torch import nn
from timm.models.layers import trunc_normal_
from typing import Sequence, Tuple, Union
from monai.networks.layers.utils import get_norm_layer
from monai.utils import optional_import
from unetr_pp.network_architecture.layers import LayerNorm
from unetr_pp.network_architecture.synapse.transformerblock import TransformerBlock
from unetr_pp.network_architecture.dynunet_block import get_conv_layer, UnetResBlock
einops, _ = optional_import("einops")
class UnetrPPEncoder(nn.Module):
    """UNETR++ encoder: a conv stem plus three conv downsampling layers, each
    followed by a stage of TransformerBlocks (EPA blocks).

    Args:
        input_size: number of tokens (H*W*D) at each of the 4 stages.
        dims: channel width of each stage.
        proj_size: spatial-attention projection size at each stage.
        depths: number of TransformerBlocks per stage.
        num_heads: attention heads inside each EPA module.
        spatial_dims: dimensionality of the convolutions (3 for volumes).
        in_channels: number of channels of the network input.
        dropout: dropout rate of the conv layers.
        transformer_dropout_rate: attention dropout inside the EPA blocks.
    """

    # fix: defaults are tuples (not lists) so the shared default object can never
    # be mutated across instances (mutable-default-argument pitfall); callers may
    # still pass lists — only indexing is performed on these arguments
    def __init__(self, input_size=(32 * 32 * 32, 16 * 16 * 16, 8 * 8 * 8, 4 * 4 * 4), dims=(32, 64, 128, 256),
                 proj_size=(64, 64, 64, 32), depths=(3, 3, 3, 3), num_heads=4, spatial_dims=3, in_channels=1,
                 dropout=0.0, transformer_dropout_rate=0.15, **kwargs):
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem_layer = nn.Sequential(
            get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=(2, 4, 4), stride=(2, 4, 4),
                           dropout=dropout, conv_only=True, ),
            get_norm_layer(name=("group", {"num_groups": in_channels}), channels=dims[0]),
        )
        self.downsample_layers.append(stem_layer)
        for i in range(3):
            downsample_layer = nn.Sequential(
                get_conv_layer(spatial_dims, dims[i], dims[i + 1], kernel_size=(2, 2, 2), stride=(2, 2, 2),
                               dropout=dropout, conv_only=True, ),
                get_norm_layer(name=("group", {"num_groups": dims[i]}), channels=dims[i + 1]),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple Transformer blocks
        for i in range(4):
            stage_blocks = []
            for j in range(depths[i]):
                stage_blocks.append(TransformerBlock(input_size=input_size[i], hidden_size=dims[i],
                                                     proj_size=proj_size[i], num_heads=num_heads,
                                                     dropout_rate=transformer_dropout_rate, pos_embed=True))
            self.stages.append(nn.Sequential(*stage_blocks))
        self.hidden_states = []  # kept for backward compatibility; forward_features builds its own list
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Conv2d/Linear, constant init for LayerNorms."""
        # NOTE(review): only Conv2d/Linear get trunc_normal init; the encoder's
        # Conv3d layers keep their default init — presumably intentional, confirm
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (LayerNorm, nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """Return the stage-4 output plus the list of all 4 stage outputs.

        The last hidden state is flattened to (b, tokens, c); the earlier ones
        stay in (b, c, h, w, d) layout.
        """
        hidden_states = []
        x = self.downsample_layers[0](x)
        x = self.stages[0](x)
        hidden_states.append(x)
        for i in range(1, 4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i == 3:  # reshape the output of the last stage to token layout
                x = einops.rearrange(x, "b c h w d -> b (h w d) c")
            hidden_states.append(x)
        return x, hidden_states

    def forward(self, x):
        x, hidden_states = self.forward_features(x)
        return x, hidden_states
class UnetrUpBlock(nn.Module):
    """UNETR++ decoder stage: transposed-conv upsampling, addition of the
    encoder skip connection, then either EPA TransformerBlocks or a plain
    conv residual block (used for the last, full-resolution stage)."""

    def __init__(
            self,
            spatial_dims: int,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[Sequence[int], int],
            upsample_kernel_size: Union[Sequence[int], int],
            norm_name: Union[Tuple, str],
            proj_size: int = 64,
            num_heads: int = 4,
            out_size: int = 0,
            depth: int = 3,
            conv_decoder: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
            norm_name: feature normalization type and arguments.
            proj_size: projection size for keys and values in the spatial attention module.
            num_heads: number of heads inside each EPA module.
            out_size: spatial size for each decoder.
            depth: number of blocks for the current decoder stage.
            conv_decoder: use a plain conv residual block instead of EPA blocks.
        """
        super().__init__()
        upsample_stride = upsample_kernel_size
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )

        # one decoder stage, holding either a conv block or several EPA blocks
        self.decoder_block = nn.ModuleList()

        # If this is the last decoder, use ConvBlock(UnetResBlock) instead of EPA_Block (see suppl. material in the paper)
        if conv_decoder:  # idiom fix: truth test instead of `== True`
            self.decoder_block.append(
                UnetResBlock(spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1,
                             norm_name=norm_name, ))
        else:
            stage_blocks = []
            for j in range(depth):
                stage_blocks.append(TransformerBlock(input_size=out_size, hidden_size=out_channels,
                                                     proj_size=proj_size, num_heads=num_heads,
                                                     dropout_rate=0.15, pos_embed=True))
            self.decoder_block.append(nn.Sequential(*stage_blocks))

    def _init_weights(self, m):
        # NOTE(review): this initializer is never registered via self.apply(), so it
        # currently has no effect — confirm whether that is intentional before wiring it up
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, inp, skip):
        """Upsample `inp`, add the encoder skip feature map, then refine."""
        out = self.transp_conv(inp)
        out = out + skip
        out = self.decoder_block[0](out)
        return out
================================================
FILE: unetr_pp/network_architecture/synapse/transformerblock.py
================================================
import torch.nn as nn
import torch
from unetr_pp.network_architecture.dynunet_block import UnetResBlock
class TransformerBlock(nn.Module):
    """
    A transformer block, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"

    Applies an EPA (efficient paired attention) block on the flattened voxel
    token sequence, then a 3D conv refinement on the re-assembled volume.
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            proj_size: int,
            num_heads: int,
            dropout_rate: float = 0.0,
            pos_embed=False,
    ) -> None:
        """
        Args:
            input_size: the size of the input for each stage.
            hidden_size: dimension of hidden layer.
            proj_size: projection size for keys and values in the spatial attention module.
            num_heads: number of attention heads.
            dropout_rate: faction of the input units to drop.
            pos_embed: bool argument to determine if positional embedding is used.
        """
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")
        if hidden_size % num_heads != 0:
            print("Hidden size is ", hidden_size)
            print("Num heads is ", num_heads)
            raise ValueError("hidden_size should be divisible by num_heads.")

        # submodule creation order is kept stable so seeded initialisation matches
        self.norm = nn.LayerNorm(hidden_size)
        # layer-scale factor applied to the attention residual
        self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True)
        self.epa_block = EPA(input_size=input_size, hidden_size=hidden_size, proj_size=proj_size,
                             num_heads=num_heads, channel_attn_drop=dropout_rate, spatial_attn_drop=dropout_rate)
        self.conv51 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch")
        self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1))

        self.pos_embed = None
        if pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size))

    def forward(self, x):
        batch, channels, height, width, depth = x.shape

        # (B, C, H, W, D) -> (B, H*W*D, C): one token per voxel
        tokens = x.reshape(batch, channels, height * width * depth).permute(0, 2, 1)
        if self.pos_embed is not None:
            tokens = tokens + self.pos_embed

        # attention with a layer-scaled residual connection
        attn = tokens + self.gamma * self.epa_block(self.norm(tokens))

        # back to volumetric layout for the convolutional refinement
        attn_skip = attn.reshape(batch, height, width, depth, channels).permute(0, 4, 1, 2, 3)  # (B, C, H, W, D)
        refined = self.conv51(attn_skip)
        return attn_skip + self.conv8(refined)
class EPA(nn.Module):
    """
    Efficient Paired Attention Block, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"

    Runs channel attention and spatial attention in parallel on shared queries
    and keys, then fuses the two branches by concatenation.
    """

    def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False,
                 channel_attn_drop=0.1, spatial_attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        # learnable softmax temperatures, one per attention branch
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))

        # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel)
        self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias)

        # E and F are projection matrices with shared weights used in spatial attention module to project
        # keys and values from HWD-dimension to P-dimension
        self.E = self.F = nn.Linear(input_size, proj_size)

        self.attn_drop = nn.Dropout(channel_attn_drop)
        self.attn_drop_2 = nn.Dropout(spatial_attn_drop)

        # each branch contributes half of the output channels
        self.out_proj = nn.Linear(hidden_size, hidden_size // 2)
        self.out_proj2 = nn.Linear(hidden_size, hidden_size // 2)

    def forward(self, x):
        batch, tokens, channels = x.shape

        # split the fused projection into shared q/k and the two value sets
        fused = self.qkvv(x).reshape(batch, tokens, 4, self.num_heads, channels // self.num_heads)
        fused = fused.permute(2, 0, 3, 1, 4)
        q_shared, k_shared, v_ca, v_sa = (t.transpose(-2, -1) for t in fused)

        # project keys/values of the spatial branch from N tokens down to proj_size
        k_projected = self.E(k_shared)
        v_sa_projected = self.F(v_sa)

        q_shared = torch.nn.functional.normalize(q_shared, dim=-1)
        k_shared = torch.nn.functional.normalize(k_shared, dim=-1)

        # channel attention: affinity between channel slices
        attn_ca = self.attn_drop((q_shared @ k_shared.transpose(-2, -1) * self.temperature).softmax(dim=-1))
        x_ca = (attn_ca @ v_ca).permute(0, 3, 1, 2).reshape(batch, tokens, channels)

        # spatial attention: token-vs-projected-key affinity
        attn_sa = self.attn_drop_2(
            (q_shared.permute(0, 1, 3, 2) @ k_projected * self.temperature2).softmax(dim=-1))
        x_sa = (attn_sa @ v_sa_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(batch, tokens, channels)

        # Concat fusion
        return torch.cat((self.out_proj(x_sa), self.out_proj2(x_ca)), dim=-1)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'temperature', 'temperature2'}
================================================
FILE: unetr_pp/network_architecture/synapse/unetr_pp_synapse.py
================================================
from torch import nn
from typing import Tuple, Union
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.network_architecture.dynunet_block import UnetOutBlock, UnetResBlock
from unetr_pp.network_architecture.synapse.model_components import UnetrPPEncoder, UnetrUpBlock
class UNETR_PP(SegmentationNetwork):
    """
    UNETR++ based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            img_size: Tuple[int, int, int] = (64, 128, 128),
            feature_size: int = 16,
            hidden_size: int = 256,
            num_heads: int = 4,
            pos_embed: str = "perceptron",  # TODO: Remove the argument
            norm_name: Union[Tuple, str] = "instance",
            dropout_rate: float = 0.0,
            depths=None,
            dims=None,
            conv_op=nn.Conv3d,
            do_ds=True,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            img_size: dimension of input image (previously annotated with the bare
                list literal ``[64, 128, 128]``, which is not a valid type hint;
                it now carries a proper annotation and an equivalent default).
            feature_size: dimension of network feature size.
            hidden_size: dimension of the last encoder.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type.
            norm_name: feature normalization type and arguments.
            dropout_rate: fraction of the input units to drop.
            depths: number of blocks for each stage (defaults to [3, 3, 3, 3]).
            dims: number of channel maps for the stages (defaults to [32, 64, 128, 256]).
            conv_op: type of convolution operation.
            do_ds: use deep supervision to compute the loss.

        Examples::

            # for single channel input 4-channel output with patch size of (64, 128, 128), feature size of 16, batch
            norm and depths of [3, 3, 3, 3] with output channels [32, 64, 128, 256], 4 heads, and 14 classes with
            deep supervision:
            >>> net = UNETR_PP(in_channels=1, out_channels=14, img_size=(64, 128, 128), feature_size=16, num_heads=4,
            >>>                norm_name='batch', depths=[3, 3, 3, 3], dims=[32, 64, 128, 256], do_ds=True)
        """
        super().__init__()
        if depths is None:
            depths = [3, 3, 3, 3]
        if dims is None:
            # Previously dims=None was forwarded to the encoder unchecked and crashed
            # on dims[0]; default to the encoder's standard channel widths instead.
            dims = [32, 64, 128, 256]
        self.do_ds = do_ds
        self.conv_op = conv_op
        self.num_classes = out_channels
        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")

        if pos_embed not in ["conv", "perceptron"]:
            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

        self.patch_size = (2, 4, 4)
        # 8 is the extra downsampling accumulated through the four encoder stages.
        self.feat_size = (
            img_size[0] // self.patch_size[0] // 8,
            img_size[1] // self.patch_size[1] // 8,
            img_size[2] // self.patch_size[2] // 8,
        )
        self.hidden_size = hidden_size

        self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads)

        # Full-resolution convolutional stem used as the innermost skip connection.
        self.encoder1 = UnetResBlock(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 16,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=8 * 8 * 8,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=16 * 16 * 16,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=32 * 32 * 32,
        )
        # Last decoder upsamples by the stem patch size and uses a plain conv block.
        self.decoder2 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=(2, 4, 4),
            norm_name=norm_name,
            out_size=64 * 128 * 128,
            conv_decoder=True,
        )
        self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels)
        if self.do_ds:
            # Auxiliary heads at the two coarser decoder resolutions for deep supervision.
            self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels)
            self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels)

    def proj_feat(self, x, hidden_size, feat_size):
        """Reshape a (B, N, C) token sequence back to a (B, C, D, H, W) volume."""
        x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

    def forward(self, x_in):
        x_output, hidden_states = self.unetr_pp_encoder(x_in)

        convBlock = self.encoder1(x_in)

        # Four encoders
        enc1 = hidden_states[0]
        enc2 = hidden_states[1]
        enc3 = hidden_states[2]
        enc4 = hidden_states[3]

        # Four decoders; enc4 arrives as a token sequence and is re-volumized first.
        dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size)
        dec3 = self.decoder5(dec4, enc3)
        dec2 = self.decoder4(dec3, enc2)
        dec1 = self.decoder3(dec2, enc1)

        out = self.decoder2(dec1, convBlock)
        if self.do_ds:
            logits = [self.out1(out), self.out2(dec1), self.out3(dec2)]
        else:
            logits = self.out1(out)

        return logits
================================================
FILE: unetr_pp/network_architecture/tumor/__init__.py
================================================
================================================
FILE: unetr_pp/network_architecture/tumor/model_components.py
================================================
from torch import nn
from timm.models.layers import trunc_normal_
from typing import Sequence, Tuple, Union
from monai.networks.layers.utils import get_norm_layer
from monai.utils import optional_import
from unetr_pp.network_architecture.layers import LayerNorm
from unetr_pp.network_architecture.tumor.transformerblock import TransformerBlock
from unetr_pp.network_architecture.dynunet_block import get_conv_layer, UnetResBlock
einops, _ = optional_import("einops")
class UnetrPPEncoder(nn.Module):
    """Hierarchical UNETR++ encoder (tumor task): a conv stem plus three conv
    downsampling layers, each followed by a stage of TransformerBlocks."""

    def __init__(self, input_size=(32 * 32 * 32, 16 * 16 * 16, 8 * 8 * 8, 4 * 4 * 4), dims=(32, 64, 128, 256),
                 proj_size=(64, 64, 64, 32), depths=(3, 3, 3, 3), num_heads=4, spatial_dims=3, in_channels=4,
                 dropout=0.0, transformer_dropout_rate=0.1, **kwargs):
        """
        Args:
            input_size: token count (H*W*D) at each of the 4 stages.
            dims: channel width of each stage.
            proj_size: spatial-attention projection size per stage.
            depths: number of TransformerBlocks per stage.
            num_heads: attention heads per EPA block.
            spatial_dims: number of spatial dimensions (3 for volumetric input).
            in_channels: input image channels (presumably 4 MRI modalities for the
                tumor task -- TODO confirm against the trainer).
            dropout: dropout for the conv layers.
            transformer_dropout_rate: attention dropout inside TransformerBlocks.

        Note: the sequence defaults are tuples rather than lists to avoid the
        shared-mutable-default-argument pitfall; indexing behaviour is unchanged.
        """
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem_layer = nn.Sequential(
            get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=(4, 4, 4), stride=(4, 4, 4),
                           dropout=dropout, conv_only=True, ),
            get_norm_layer(name=("group", {"num_groups": in_channels}), channels=dims[0]),
        )
        self.downsample_layers.append(stem_layer)
        for i in range(3):
            downsample_layer = nn.Sequential(
                get_conv_layer(spatial_dims, dims[i], dims[i + 1], kernel_size=(2, 2, 2), stride=(2, 2, 2),
                               dropout=dropout, conv_only=True, ),
                get_norm_layer(name=("group", {"num_groups": dims[i]}), channels=dims[i + 1]),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple Transformer blocks
        for i in range(4):
            stage_blocks = []
            for j in range(depths[i]):
                stage_blocks.append(TransformerBlock(input_size=input_size[i], hidden_size=dims[i],
                                                     proj_size=proj_size[i], num_heads=num_heads,
                                                     dropout_rate=transformer_dropout_rate, pos_embed=True))
            self.stages.append(nn.Sequential(*stage_blocks))
        # NOTE(review): never read anywhere; forward_features() builds its own
        # local list. Kept for interface compatibility.
        self.hidden_states = []
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # NOTE(review): only Conv2d/Linear are matched although the network is 3D,
        # so Conv3d layers keep their default initialization -- confirm intended.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (LayerNorm, nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        hidden_states = []

        x = self.downsample_layers[0](x)
        x = self.stages[0](x)

        hidden_states.append(x)

        for i in range(1, 4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i == 3:  # Reshape the output of the last stage into a token sequence
                x = einops.rearrange(x, "b c h w d -> b (h w d) c")
            hidden_states.append(x)
        return x, hidden_states

    def forward(self, x):
        x, hidden_states = self.forward_features(x)
        return x, hidden_states
class UnetrUpBlock(nn.Module):
    """UNETR++ decoder stage: transposed-conv upsampling, additive skip fusion,
    then either TransformerBlocks or a plain conv block (last decoder)."""

    def __init__(
            self,
            spatial_dims: int,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[Sequence[int], int],
            upsample_kernel_size: Union[Sequence[int], int],
            norm_name: Union[Tuple, str],
            proj_size: int = 64,
            num_heads: int = 4,
            out_size: int = 0,
            depth: int = 3,
            conv_decoder: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
            norm_name: feature normalization type and arguments.
            proj_size: projection size for keys and values in the spatial attention module.
            num_heads: number of heads inside each EPA module.
            out_size: spatial size for each decoder.
            depth: number of blocks for the current decoder stage.
            conv_decoder: if True, use a ConvBlock (UnetResBlock) instead of EPA
                blocks -- used by the last decoder (see suppl. material in the paper).
        """
        super().__init__()
        upsample_stride = upsample_kernel_size
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )

        # 4 feature resolution stages, each consisting of multiple residual blocks
        self.decoder_block = nn.ModuleList()

        # If this is the last decoder, use ConvBlock(UnetResBlock) instead of EPA_Block
        # (see suppl. material in the paper)
        if conv_decoder:  # idiomatic boolean test (was `conv_decoder == True`)
            self.decoder_block.append(
                UnetResBlock(spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1,
                             norm_name=norm_name, ))
        else:
            stage_blocks = []
            for _ in range(depth):
                stage_blocks.append(TransformerBlock(input_size=out_size, hidden_size=out_channels,
                                                     proj_size=proj_size, num_heads=num_heads,
                                                     dropout_rate=0.1, pos_embed=True))
            self.decoder_block.append(nn.Sequential(*stage_blocks))

    def _init_weights(self, m):
        # NOTE(review): this method is never registered via self.apply() (unlike
        # UnetrPPEncoder), so it is currently dead code -- confirm whether it was
        # meant to be applied.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, inp, skip):
        # Upsample, fuse with the encoder skip connection by addition, then refine.
        out = self.transp_conv(inp)
        out = out + skip
        out = self.decoder_block[0](out)
        return out
================================================
FILE: unetr_pp/network_architecture/tumor/transformerblock.py
================================================
import torch.nn as nn
import torch
from unetr_pp.network_architecture.dynunet_block import UnetResBlock
import math
class TransformerBlock(nn.Module):
    """
    A transformer block, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            proj_size: int,
            num_heads: int,
            dropout_rate: float = 0.0,
            pos_embed=False,
    ) -> None:
        """
        Args:
            input_size: the size of the input for each stage.
            hidden_size: dimension of hidden layer.
            proj_size: projection size for keys and values in the spatial attention module.
            num_heads: number of attention heads.
            dropout_rate: fraction of the input units to drop.
            pos_embed: bool argument to determine if positional embedding is used.
        """
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            # The offending values are now part of the exception message instead
            # of being print()ed to stdout just before raising.
            raise ValueError(
                f"hidden_size ({hidden_size}) should be divisible by num_heads ({num_heads})."
            )

        self.norm = nn.LayerNorm(hidden_size)
        # Learnable layer-scale applied to the attention branch.
        self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True)
        self.epa_block = EPA(input_size=input_size, hidden_size=hidden_size, proj_size=proj_size,
                             num_heads=num_heads, channel_attn_drop=dropout_rate, spatial_attn_drop=dropout_rate)
        self.conv51 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch")
        self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1))

        self.pos_embed = None
        if pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size))

    def forward(self, x):
        B, C, H, W, D = x.shape
        # Flatten the volume into a token sequence: (B, C, H*W*D) -> (B, N, C).
        x = x.reshape(B, C, H * W * D).permute(0, 2, 1)

        if self.pos_embed is not None:
            x = x + self.pos_embed
        # Pre-norm EPA attention with a layer-scaled residual connection.
        attn = x + self.gamma * self.epa_block(self.norm(x))

        # Back to volumetric layout for the convolutional feed-forward part.
        attn_skip = attn.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3)  # (B, C, H, W, D)
        attn = self.conv51(attn_skip)
        x = attn_skip + self.conv8(attn)
        return x
def init_(tensor):
    """In-place uniform init in [-1/sqrt(d), 1/sqrt(d)], d = last-axis size.

    Returns the same tensor object to allow inline use.
    """
    bound = 1 / math.sqrt(tensor.shape[-1])
    return tensor.uniform_(-bound, bound)
class EPA(nn.Module):
    """
    Efficient Paired Attention Block, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
    """

    def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False,
                 channel_attn_drop=0.1, spatial_attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        # Learnable per-head temperatures for the channel (temperature) and
        # spatial (temperature2) attention logits.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))

        # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel)
        self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias)

        # E and F are projection matrices with shared weights used in spatial attention module to project
        # keys and values from HWD-dimension to P-dimension. Here a single
        # parameter matrix plays both roles (applied to keys and values below).
        self.EF = nn.Parameter(init_(torch.zeros(input_size, proj_size)))

        self.attn_drop = nn.Dropout(channel_attn_drop)
        self.attn_drop_2 = nn.Dropout(spatial_attn_drop)

    def forward(self, x):
        # x: (B, N, C) token sequence, N = H*W*D.
        B, N, C = x.shape

        # (B, N, 4*C) -> (4, B, num_heads, N, C // num_heads)
        qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads)

        qkvv = qkvv.permute(2, 0, 3, 1, 4)

        q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3]

        # Transpose to (B, heads, head_dim, N) so channel attention contracts
        # over the token axis.
        q_shared = q_shared.transpose(-2, -1)
        k_shared = k_shared.transpose(-2, -1)
        v_CA = v_CA.transpose(-2, -1)
        v_SA = v_SA.transpose(-2, -1)

        # Project keys and values along the token axis (N -> proj_size) with
        # the shared EF matrix (same parameter for both projections).
        proj_e_f = lambda args: torch.einsum('bhdn,nk->bhdk', *args)
        k_shared_projected, v_SA_projected = map(proj_e_f, zip((k_shared, v_SA), (self.EF, self.EF)))

        # L2-normalize so the channel-attention logits are cosine similarities
        # scaled by the learned temperature.
        q_shared = torch.nn.functional.normalize(q_shared, dim=-1)
        k_shared = torch.nn.functional.normalize(k_shared, dim=-1)

        # Channel attention: (head_dim x head_dim) map over feature channels.
        attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature

        attn_CA = attn_CA.softmax(dim=-1)
        attn_CA = self.attn_drop(attn_CA)

        x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C)

        # Spatial attention: (N x proj_size) map over (projected) tokens.
        attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2

        attn_SA = attn_SA.softmax(dim=-1)
        attn_SA = self.attn_drop_2(attn_SA)

        x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C)

        # Unlike the synapse variant (concat + output projections), the two
        # branches are fused by simple addition here.
        return x_CA + x_SA

    @torch.jit.ignore
    def no_weight_decay(self):
        # Temperatures are scale factors and should not be weight-decayed.
        return {'temperature', 'temperature2'}
================================================
FILE: unetr_pp/network_architecture/tumor/unetr_pp_tumor.py
================================================
from torch import nn
from typing import Tuple, Union
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.network_architecture.dynunet_block import UnetOutBlock, UnetResBlock
from unetr_pp.network_architecture.tumor.model_components import UnetrPPEncoder, UnetrUpBlock
class UNETR_PP(SegmentationNetwork):
    """
    UNETR++ based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            feature_size: int = 16,
            hidden_size: int = 256,
            num_heads: int = 4,
            pos_embed: str = "perceptron",
            norm_name: Union[Tuple, str] = "instance",
            dropout_rate: float = 0.0,
            depths=None,
            dims=None,
            conv_op=nn.Conv3d,
            do_ds=True,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            feature_size: dimension of network feature size.
            hidden_size: dimensions of the last encoder.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type.
            norm_name: feature normalization type and arguments.
            dropout_rate: fraction of the input units to drop.
            depths: number of blocks for each stage (defaults to [3, 3, 3, 3]).
            dims: number of channel maps for the stages (defaults to [32, 64, 128, 256]).
            conv_op: type of convolution operation.
            do_ds: use deep supervision to compute the loss.
        """
        super().__init__()
        if depths is None:
            depths = [3, 3, 3, 3]
        if dims is None:
            # Previously dims=None was forwarded to the encoder unchecked and crashed
            # on dims[0]; default to the encoder's standard channel widths instead.
            dims = [32, 64, 128, 256]
        self.do_ds = do_ds
        self.conv_op = conv_op
        self.num_classes = out_channels
        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")

        if pos_embed not in ["conv", "perceptron"]:
            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

        # Bottleneck feature-map size (fixed for the tumor configuration).
        self.feat_size = (4, 4, 4,)
        self.hidden_size = hidden_size

        self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads)

        # Full-resolution convolutional stem used as the innermost skip connection.
        self.encoder1 = UnetResBlock(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 16,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=8 * 8 * 8,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=16 * 16 * 16,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=32 * 32 * 32,
        )
        # Last decoder upsamples by the stem patch size and uses a plain conv block.
        self.decoder2 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=(4, 4, 4),
            norm_name=norm_name,
            out_size=128 * 128 * 128,
            conv_decoder=True,
        )
        self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels)
        if self.do_ds:
            # Auxiliary heads at the two coarser decoder resolutions for deep supervision.
            self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels)
            self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels)

    def proj_feat(self, x, hidden_size, feat_size):
        """Reshape a (B, N, C) token sequence back to a (B, C, D, H, W) volume."""
        x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

    def forward(self, x_in):
        x_output, hidden_states = self.unetr_pp_encoder(x_in)

        convBlock = self.encoder1(x_in)

        # Four encoders
        enc1 = hidden_states[0]
        enc2 = hidden_states[1]
        enc3 = hidden_states[2]
        enc4 = hidden_states[3]

        # Four decoders; enc4 arrives as a token sequence and is re-volumized first.
        dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size)
        dec3 = self.decoder5(dec4, enc3)
        dec2 = self.decoder4(dec3, enc2)
        dec1 = self.decoder3(dec2, enc1)

        out = self.decoder2(dec1, convBlock)
        if self.do_ds:
            logits = [self.out1(out), self.out2(dec1), self.out3(dec2)]
        else:
            logits = self.out1(out)

        return logits
================================================
FILE: unetr_pp/paths.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from batchgenerators.utilities.file_and_folder_operations import maybe_mkdir_p, join
# do not modify these unless you know what you are doing
# do not modify these unless you know what you are doing
my_output_identifier = "unetr_pp"
default_plans_identifier = "unetr_pp_Plansv2.1"
default_data_identifier = 'unetr_pp_Data_plans_v2.1'
default_trainer = "unetr_pp_trainer_synapse"

"""
PLEASE READ paths.md FOR INFORMATION TO HOW TO SET THIS UP
"""

# os.environ.get() replaces the previous `'key' in os.environ.keys()` ternaries;
# semantics are identical (None when the variable is unset).
base = os.environ.get('unetr_pp_raw_data_base')
preprocessing_output_dir = os.environ.get('unetr_pp_preprocessed')
# The original wrapped the single path in os.path.join(), which is a no-op for
# one argument and has been dropped.
network_training_output_dir_base = os.environ.get('RESULTS_FOLDER')

if base is not None:
    nnFormer_raw_data = join(base, "unetr_pp_raw_data")
    nnFormer_cropped_data = join(base, "unetr_pp_cropped_data")
    maybe_mkdir_p(nnFormer_raw_data)
    maybe_mkdir_p(nnFormer_cropped_data)
else:
    print("unetr_pp_raw_data_base is not defined and model can only be used on data for which preprocessed files "
          "are already present on your system. model cannot be used for experiment planning and preprocessing like "
          "this. If this is not intended, please read run_training_synapse.sh/run_training_acdc.sh "
          "for information on how to set this up.")
    nnFormer_cropped_data = nnFormer_raw_data = None

if preprocessing_output_dir is not None:
    maybe_mkdir_p(preprocessing_output_dir)
else:
    print("unetr_pp_preprocessed is not defined and model can not be used for preprocessing "
          "or training. If this is not intended, please read documentation/setting_up_paths.md for "
          "information on how to set this up.")
    preprocessing_output_dir = None

if network_training_output_dir_base is not None:
    network_training_output_dir = join(network_training_output_dir_base, my_output_identifier)
    maybe_mkdir_p(network_training_output_dir)
else:
    print("RESULTS_FOLDER is not defined and nnU-Net cannot be used for training or "
          "inference. If this is not intended behavior, please read run_training_synapse.sh/run_training_acdc.sh "
          "for information on how to set this up.")
    network_training_output_dir = None
================================================
FILE: unetr_pp/postprocessing/connected_components.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
from copy import deepcopy
from multiprocessing.pool import Pool
import numpy as np
from unetr_pp.configuration import default_num_threads
from unetr_pp.evaluation.evaluator import aggregate_scores
from scipy.ndimage import label
import SimpleITK as sitk
from unetr_pp.utilities.sitk_stuff import copy_geometry
from batchgenerators.utilities.file_and_folder_operations import *
import shutil
def load_remove_save(input_file: str, output_file: str, for_which_classes: list,
                     minimum_valid_object_size: dict = None):
    """Load a segmentation, strip small connected components, save the result.

    Only objects smaller than minimum_valid_object_size are removed. Keys in
    minimum_valid_object_size must match entries in for_which_classes.

    Returns:
        (largest_removed, kept_size) as produced by
        remove_all_but_the_largest_connected_component.
    """
    source_itk = sitk.ReadImage(input_file)
    seg_array = sitk.GetArrayFromImage(source_itk)
    voxel_volume = float(np.prod(source_itk.GetSpacing(), dtype=np.float64))

    cleaned, largest_removed, kept_size = remove_all_but_the_largest_connected_component(
        seg_array, for_which_classes, voxel_volume, minimum_valid_object_size)

    # Re-attach origin/spacing/direction from the source image before writing.
    result_itk = copy_geometry(sitk.GetImageFromArray(cleaned), source_itk)
    sitk.WriteImage(result_itk, output_file)
    return largest_removed, kept_size
def remove_all_but_the_largest_connected_component(image: np.ndarray, for_which_classes: list, volume_per_voxel: float,
                                                   minimum_valid_object_size: dict = None):
    """
    removes all but the largest connected component, individually for each class
    :param image: label map; modified in place and also returned
    :param for_which_classes: can be None. Should be list of int. Can also be something like [(1, 2), 2, 4].
    Here (1, 2) will be treated as a joint region, not individual classes (example LiTS here we can use (1, 2)
    to use all foreground classes together)
    :param minimum_valid_object_size: Only objects larger than minimum_valid_object_size will be removed. Keys in
    minimum_valid_object_size must match entries in for_which_classes
    :return: (image, largest_removed, kept_size)
    """
    if for_which_classes is None:
        present = np.unique(image)
        for_which_classes = present[present > 0]

    assert 0 not in for_which_classes, "cannot remove background"

    largest_removed = {}
    kept_size = {}
    for c in for_which_classes:
        if isinstance(c, (list, tuple)):
            c = tuple(c)  # otherwise it cant be used as key in the dict
            mask = np.isin(image, c)  # joint region over all classes in the tuple
        else:
            mask = image == c

        # label each connected foreground object
        lmap, num_objects = label(mask.astype(int))

        # physical volume of every object
        object_sizes = {obj_id: (lmap == obj_id).sum() * volume_per_voxel
                        for obj_id in range(1, num_objects + 1)}

        largest_removed[c] = None
        kept_size[c] = None

        if num_objects == 0:
            continue

        # the largest object is always kept (ties are all kept)
        maximum_size = max(object_sizes.values())
        kept_size[c] = maximum_size

        for obj_id, size in object_sizes.items():
            if size == maximum_size:
                continue
            # objects at or above the per-class minimum size survive
            if minimum_valid_object_size is not None and size >= minimum_valid_object_size[c]:
                continue
            image[(lmap == obj_id) & mask] = 0
            if largest_removed[c] is None:
                largest_removed[c] = size
            else:
                largest_removed[c] = max(largest_removed[c], size)
    return image, largest_removed, kept_size
def load_postprocessing(json_file):
    """
    loads the relevant part of the postprocessing json file that is needed for applying postprocessing
    (the original docstring referred to a "pkl file" / :param pkl_file:, but the function reads json)
    :param json_file: path to the json written by determine_postprocessing (e.g. postprocessing.json)
    :return: (for_which_classes, min_valid_object_sizes)
    """
    a = load_json(json_file)
    if 'min_valid_object_sizes' in a:  # membership test on the dict directly, no .keys()
        # stored via str() because json cannot serialize tuple keys; parse it back safely
        min_valid_object_sizes = ast.literal_eval(a['min_valid_object_sizes'])
    else:
        min_valid_object_sizes = None
    return a['for_which_classes'], min_valid_object_sizes
def determine_postprocessing(base, gt_labels_folder, raw_subfolder_name="validation_raw",
                             temp_folder="temp",
                             final_subf_name="validation_final", processes=default_num_threads,
                             dice_threshold=0, debug=False,
                             advanced_postprocessing=False,
                             pp_filename="postprocessing.json"):
    """
    Decides (by trial on the validation set) whether keeping only the largest
    connected component helps, first for all foreground classes jointly, then
    per class, and writes the decision plus final segmentations to disk.

    :param base: folder containing the raw validation results and gt labels subfolders
    :param gt_labels_folder: subfolder of base with niftis of ground truth labels
    :param raw_subfolder_name: subfolder of base with niftis of predicted (non-postprocessed) segmentations
    :param temp_folder: used to store temporary data, will be deleted after we are done here unless debug=True
    :param final_subf_name: final results will be stored here (subfolder of base)
    :param processes: number of worker processes for the multiprocessing pool
    :param dice_threshold: only apply postprocessing if results is better than old_result+dice_threshold (can be used as eps)
    :param debug: if True then the temporary files will not be deleted
    :param advanced_postprocessing: if True, additionally derive minimum valid object sizes from what removal kept/discarded
    :param pp_filename: name of the json file (in base) the decision is saved to
    :return: None (results are written to disk)
    """
    # lets see what classes are in the dataset
    classes = [int(i) for i in load_json(join(base, raw_subfolder_name, "summary.json"))['results']['mean'].keys() if
               int(i) != 0]

    folder_all_classes_as_fg = join(base, temp_folder + "_allClasses")
    folder_per_class = join(base, temp_folder + "_perClass")

    if isdir(folder_all_classes_as_fg):
        shutil.rmtree(folder_all_classes_as_fg)
    if isdir(folder_per_class):
        shutil.rmtree(folder_per_class)

    # multiprocessing rules
    p = Pool(processes)

    assert isfile(join(base, raw_subfolder_name, "summary.json")), "join(base, raw_subfolder_name) does not " \
                                                                   "contain a summary.json"

    # these are all the files we will be dealing with
    fnames = subfiles(join(base, raw_subfolder_name), suffix=".nii.gz", join=False)

    # make output and temp dir
    maybe_mkdir_p(folder_all_classes_as_fg)
    maybe_mkdir_p(folder_per_class)
    maybe_mkdir_p(join(base, final_subf_name))

    pp_results = {}
    pp_results['dc_per_class_raw'] = {}
    pp_results['dc_per_class_pp_all'] = {}  # dice scores after treating all foreground classes as one
    pp_results['dc_per_class_pp_per_class'] = {}  # dice scores after removing everything except larges cc
    # independently for each class after we already did dc_per_class_pp_all
    pp_results['for_which_classes'] = []
    pp_results['min_valid_object_sizes'] = {}

    validation_result_raw = load_json(join(base, raw_subfolder_name, "summary.json"))['results']
    pp_results['num_samples'] = len(validation_result_raw['all'])
    validation_result_raw = validation_result_raw['mean']

    if advanced_postprocessing:
        # first treat all foreground classes as one and remove all but the largest foreground connected component
        results = []
        for f in fnames:
            predicted_segmentation = join(base, raw_subfolder_name, f)
            # now remove all but the largest connected component for each class
            output_file = join(folder_all_classes_as_fg, f)
            results.append(p.starmap_async(load_remove_save, ((predicted_segmentation, output_file, (classes,)),)))

        results = [i.get() for i in results]

        # aggregate max_size_removed and min_size_kept
        max_size_removed = {}
        min_size_kept = {}
        for tmp in results:
            mx_rem, min_kept = tmp[0]
            for k in mx_rem:
                if mx_rem[k] is not None:
                    if max_size_removed.get(k) is None:
                        max_size_removed[k] = mx_rem[k]
                    else:
                        max_size_removed[k] = max(max_size_removed[k], mx_rem[k])
            for k in min_kept:
                if min_kept[k] is not None:
                    if min_size_kept.get(k) is None:
                        min_size_kept[k] = min_kept[k]
                    else:
                        min_size_kept[k] = min(min_size_kept[k], min_kept[k])

        print("foreground vs background, smallest valid object size was", min_size_kept[tuple(classes)])
        print("removing only objects smaller than that...")
    else:
        min_size_kept = None

    # we need to rerun the step from above, now with the size constraint
    pred_gt_tuples = []
    results = []
    # first treat all foreground classes as one and remove all but the largest foreground connected component
    for f in fnames:
        predicted_segmentation = join(base, raw_subfolder_name, f)
        # now remove all but the largest connected component for each class
        output_file = join(folder_all_classes_as_fg, f)
        results.append(
            p.starmap_async(load_remove_save, ((predicted_segmentation, output_file, (classes,), min_size_kept),)))
        pred_gt_tuples.append([output_file, join(gt_labels_folder, f)])

    _ = [i.get() for i in results]

    # evaluate postprocessed predictions
    _ = aggregate_scores(pred_gt_tuples, labels=classes,
                         json_output_file=join(folder_all_classes_as_fg, "summary.json"),
                         json_author="Fabian", num_threads=processes)

    # now we need to figure out if doing this improved the dice scores. We will implement that defensively in so far
    # that if a single class got worse as a result we won't do this. We can change this in the future but right now I
    # prefer to do it this way
    validation_result_PP_test = load_json(join(folder_all_classes_as_fg, "summary.json"))['results']['mean']

    for c in classes:
        dc_raw = validation_result_raw[str(c)]['Dice']
        dc_pp = validation_result_PP_test[str(c)]['Dice']
        pp_results['dc_per_class_raw'][str(c)] = dc_raw
        pp_results['dc_per_class_pp_all'][str(c)] = dc_pp

    # true if new is better
    do_fg_cc = False
    comp = [pp_results['dc_per_class_pp_all'][str(cl)] > (pp_results['dc_per_class_raw'][str(cl)] + dice_threshold) for
            cl in classes]
    before = np.mean([pp_results['dc_per_class_raw'][str(cl)] for cl in classes])
    after = np.mean([pp_results['dc_per_class_pp_all'][str(cl)] for cl in classes])
    print("Foreground vs background")
    print("before:", before)
    print("after: ", after)
    if any(comp):
        # at least one class improved - yay!
        # now check if another got worse
        # true if new is worse
        any_worse = any(
            [pp_results['dc_per_class_pp_all'][str(cl)] < pp_results['dc_per_class_raw'][str(cl)] for cl in classes])
        if not any_worse:
            pp_results['for_which_classes'].append(classes)
            if min_size_kept is not None:
                pp_results['min_valid_object_sizes'].update(deepcopy(min_size_kept))
            do_fg_cc = True
            print("Removing all but the largest foreground region improved results!")
            print('for_which_classes', classes)
            print('min_valid_object_sizes', min_size_kept)
    else:
        # did not improve things - don't do it
        pass

    if len(classes) > 1:
        # now depending on whether we do remove all but the largest foreground connected component we define the source dir
        # for the next one to be the raw or the temp dir
        if do_fg_cc:
            source = folder_all_classes_as_fg
        else:
            source = join(base, raw_subfolder_name)

        if advanced_postprocessing:
            # now run this for each class separately
            results = []
            for f in fnames:
                predicted_segmentation = join(source, f)
                output_file = join(folder_per_class, f)
                results.append(p.starmap_async(load_remove_save, ((predicted_segmentation, output_file, classes),)))

            results = [i.get() for i in results]

            # aggregate max_size_removed and min_size_kept
            max_size_removed = {}
            min_size_kept = {}
            for tmp in results:
                mx_rem, min_kept = tmp[0]
                for k in mx_rem:
                    if mx_rem[k] is not None:
                        if max_size_removed.get(k) is None:
                            max_size_removed[k] = mx_rem[k]
                        else:
                            max_size_removed[k] = max(max_size_removed[k], mx_rem[k])
                for k in min_kept:
                    if min_kept[k] is not None:
                        if min_size_kept.get(k) is None:
                            min_size_kept[k] = min_kept[k]
                        else:
                            min_size_kept[k] = min(min_size_kept[k], min_kept[k])

            print("classes treated separately, smallest valid object sizes are")
            print(min_size_kept)
            print("removing only objects smaller than that")
        else:
            min_size_kept = None

        # rerun with the size thresholds from above
        pred_gt_tuples = []
        results = []
        for f in fnames:
            predicted_segmentation = join(source, f)
            output_file = join(folder_per_class, f)
            results.append(p.starmap_async(load_remove_save, ((predicted_segmentation, output_file, classes, min_size_kept),)))
            pred_gt_tuples.append([output_file, join(gt_labels_folder, f)])

        _ = [i.get() for i in results]

        # evaluate postprocessed predictions
        _ = aggregate_scores(pred_gt_tuples, labels=classes,
                             json_output_file=join(folder_per_class, "summary.json"),
                             json_author="Fabian", num_threads=processes)

        if do_fg_cc:
            old_res = deepcopy(validation_result_PP_test)
        else:
            old_res = validation_result_raw

        # these are the new dice scores
        validation_result_PP_test = load_json(join(folder_per_class, "summary.json"))['results']['mean']

        for c in classes:
            dc_raw = old_res[str(c)]['Dice']
            dc_pp = validation_result_PP_test[str(c)]['Dice']
            pp_results['dc_per_class_pp_per_class'][str(c)] = dc_pp
            print(c)
            print("before:", dc_raw)
            print("after: ", dc_pp)

            if dc_pp > (dc_raw + dice_threshold):
                pp_results['for_which_classes'].append(int(c))
                if min_size_kept is not None:
                    pp_results['min_valid_object_sizes'].update({c: min_size_kept[c]})
                print("Removing all but the largest region for class %d improved results!" % c)
                print('min_valid_object_sizes', min_size_kept)
    else:
        print("Only one class present, no need to do each class separately as this is covered in fg vs bg")

    if not advanced_postprocessing:
        pp_results['min_valid_object_sizes'] = None

    print("done")
    print("for which classes:")
    print(pp_results['for_which_classes'])
    print("min_object_sizes")
    print(pp_results['min_valid_object_sizes'])

    pp_results['validation_raw'] = raw_subfolder_name
    pp_results['validation_final'] = final_subf_name

    # now that we have a proper for_which_classes, apply that
    pred_gt_tuples = []
    results = []
    for f in fnames:
        predicted_segmentation = join(base, raw_subfolder_name, f)

        # now remove all but the largest connected component for each class
        output_file = join(base, final_subf_name, f)
        results.append(p.starmap_async(load_remove_save, (
            (predicted_segmentation, output_file, pp_results['for_which_classes'],
             pp_results['min_valid_object_sizes']),)))

        pred_gt_tuples.append([output_file,
                               join(gt_labels_folder, f)])

    _ = [i.get() for i in results]

    # evaluate postprocessed predictions
    _ = aggregate_scores(pred_gt_tuples, labels=classes,
                         json_output_file=join(base, final_subf_name, "summary.json"),
                         json_author="Fabian", num_threads=processes)

    # stringify because json cannot serialize dicts with tuple keys; parsed back
    # by load_postprocessing via ast.literal_eval
    pp_results['min_valid_object_sizes'] = str(pp_results['min_valid_object_sizes'])

    save_json(pp_results, join(base, pp_filename))

    # delete temp
    if not debug:
        shutil.rmtree(folder_per_class)
        shutil.rmtree(folder_all_classes_as_fg)

    p.close()
    p.join()
    print("done")
def apply_postprocessing_to_folder(input_folder: str, output_folder: str, for_which_classes: list,
                                   min_valid_object_size: dict = None, num_processes=8):
    """
    Applies removal of all but the largest connected component (per class, or per group of classes
    in ``for_which_classes``) to all .nii.gz files in a folder.

    :param input_folder: folder containing the .nii.gz segmentations to postprocess
    :param output_folder: where the postprocessed segmentations are written (created if missing)
    :param for_which_classes: classes (or tuples of classes treated jointly) to postprocess
    :param min_valid_object_size: optional dict mapping class -> minimum component size (voxels);
        smaller components are removed entirely
    :param num_processes: number of worker processes
    :return: None
    """
    maybe_mkdir_p(output_folder)
    p = Pool(num_processes)
    nii_files = subfiles(input_folder, suffix=".nii.gz", join=False)
    input_files = [join(input_folder, i) for i in nii_files]
    out_files = [join(output_folder, i) for i in nii_files]
    results = p.starmap_async(load_remove_save, zip(input_files, out_files,
                                                    [for_which_classes] * len(input_files),
                                                    [min_valid_object_size] * len(input_files)))
    # block until all workers are done so that errors raised in workers surface here
    results.get()
    p.close()
    p.join()
if __name__ == "__main__":
    # example run: keep only the largest joint component of classes 1 and 2
    apply_postprocessing_to_folder(
        "/media/fabian/DKFZ/predictions_Fabian/Liver_and_LiverTumor",
        "/media/fabian/DKFZ/predictions_Fabian/Liver_and_LiverTumor_postprocessed",
        [(1, 2), ],
    )
================================================
FILE: unetr_pp/postprocessing/consolidate_all_for_paper.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unetr_pp.utilities.folder_names import get_output_folder_name
def get_datasets():
    """Return, per MSD/derived task, the tuple of model configurations to consolidate (order matters)."""
    fullres_2d = ("3d_fullres", "2d")
    cascade_first = ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d")
    cascade_lowres_first = ("3d_cascade_fullres", "3d_lowres", "3d_fullres", "2d")
    cascade_fullres_last = ("3d_cascade_fullres", "3d_lowres", "2d", "3d_fullres")
    configurations_all = {
        "Task01_BrainTumour": fullres_2d,
        "Task02_Heart": fullres_2d,
        "Task03_Liver": cascade_first,
        "Task04_Hippocampus": fullres_2d,
        "Task05_Prostate": fullres_2d,
        "Task06_Lung": cascade_first,
        "Task07_Pancreas": cascade_first,
        "Task08_HepaticVessel": cascade_first,
        "Task09_Spleen": cascade_first,
        "Task10_Colon": cascade_first,
        "Task48_KiTS_clean": cascade_lowres_first,
        "Task27_ACDC": fullres_2d,
        "Task24_Promise": fullres_2d,
        "Task35_ISBILesionSegmentation": fullres_2d,
        "Task38_CHAOS_Task_3_5_Variant2": fullres_2d,
        "Task29_LITS": cascade_fullres_last,
        "Task17_AbdominalOrganSegmentation": cascade_fullres_last,
        "Task55_SegTHOR": cascade_lowres_first,
        "Task56_VerSe": cascade_lowres_first,
    }
    return configurations_all
def get_commands(configurations, regular_trainer="nnFormerTrainerV2", cascade_trainer="nnFormerTrainerV2CascadeFullRes",
                 plans="nnFormerPlansv2.1"):
    """Print one bsub command per (task, model) pair, assigning GPU nodes round-robin from the pool."""
    node_pool = ["hdf18-gpu%02.0d" % i for i in range(1, 21)] \
                + ["hdf19-gpu%02.0d" % i for i in range(1, 8)] \
                + ["hdf19-gpu%02.0d" % i for i in range(11, 16)]
    counter = 0
    for task, models in configurations.items():
        for model in models:
            # the cascade stage needs its dedicated trainer class
            trainer = cascade_trainer if model == "3d_cascade_fullres" else regular_trainer
            folder = get_output_folder_name(model, task, trainer, plans,
                                            overwrite_training_output_dir="/datasets/datasets_fabian/results/nnFormer")
            node = node_pool[counter % len(node_pool)]
            print("bsub -m %s -q gputest -L /bin/bash \"source ~/.bashrc && python postprocessing/"
                  "consolidate_postprocessing.py -f" % node, folder, "\"")
            counter += 1
================================================
FILE: unetr_pp/postprocessing/consolidate_postprocessing.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from typing import Tuple
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.configuration import default_num_threads
from unetr_pp.evaluation.evaluator import aggregate_scores
from unetr_pp.postprocessing.connected_components import determine_postprocessing
import argparse
def collect_cv_niftis(cv_folder: str, output_folder: str, validation_folder_name: str = 'validation_raw',
                      folds: tuple = (0, 1, 2, 3, 4)):
    """
    Copy the raw validation niftis of all requested folds into one folder.

    :param cv_folder: experiment folder containing fold_0, fold_1, ... subfolders
    :param output_folder: destination folder (created if missing)
    :param validation_folder_name: name of the validation subfolder inside each fold
    :param folds: fold ids to collect
    :raises RuntimeError: if any requested fold's validation folder is missing
    """
    validation_raw_folders = [join(cv_folder, "fold_%d" % i, validation_folder_name) for i in folds]
    exist = [isdir(i) for i in validation_raw_folders]

    if not all(exist):
        raise RuntimeError("some folds are missing. Please run the full 5-fold cross-validation. "
                           "The following folds seem to be missing: %s" %
                           [i for j, i in enumerate(folds) if not exist[j]])

    # now copy all raw niftis into cv_niftis_raw
    maybe_mkdir_p(output_folder)
    # BUGFIX: index by position, not by fold id. With the original `validation_raw_folders[f]`,
    # any folds tuple other than (0, 1, 2, 3, 4) (e.g. (2, 3)) would index the wrong folder
    # or raise IndexError.
    for idx in range(len(folds)):
        niftis = subfiles(validation_raw_folders[idx], suffix=".nii.gz")
        for n in niftis:
            shutil.copy(n, join(output_folder))
def consolidate_folds(output_folder_base, validation_folder_name: str = 'validation_raw',
                      advanced_postprocessing: bool = False, folds: Tuple[int] = (0, 1, 2, 3, 4)):
    """
    Used to determine the postprocessing for an experiment after all five folds have been completed. In the validation of
    each fold, the postprocessing can only be determined on the cases within that fold. This can result in different
    postprocessing decisions for different folds. In the end, we can only decide for one postprocessing per experiment,
    so we have to rerun it.

    :param output_folder_base: experiment output folder (fold_0, fold_1, etc must be subfolders of the given folder)
    :param validation_folder_name: name of the per-fold validation subfolder (dont use this)
    :param advanced_postprocessing: passed through to determine_postprocessing
    :param folds: fold ids to consolidate
    :return: None (writes cv_niftis_raw, cv_niftis_postprocessed and postprocessing.json)
    """
    output_folder_raw = join(output_folder_base, "cv_niftis_raw")
    if isdir(output_folder_raw):
        shutil.rmtree(output_folder_raw)

    output_folder_gt = join(output_folder_base, "gt_niftis")
    collect_cv_niftis(output_folder_base, output_folder_raw, validation_folder_name,
                      folds)

    # sanity check: the consolidated predictions must cover exactly the ground truth cases
    # (reuse output_folder_gt instead of rebuilding the same path)
    num_niftis_gt = len(subfiles(output_folder_gt, suffix='.nii.gz'))
    # count niftis in there
    num_niftis = len(subfiles(output_folder_raw, suffix='.nii.gz'))
    if num_niftis != num_niftis_gt:
        # fixed typo in the original message ("If does not seem" -> "It does not seem")
        raise AssertionError("It does not seem like you trained all the folds! Train all folds first!")

    # load a summary file so that we can know what class labels to expect
    summary_fold0 = load_json(join(output_folder_base, "fold_0", validation_folder_name, "summary.json"))['results'][
        'mean']
    classes = [int(i) for i in summary_fold0.keys()]
    niftis = subfiles(output_folder_raw, join=False, suffix=".nii.gz")
    test_pred_pairs = [(join(output_folder_gt, i), join(output_folder_raw, i)) for i in niftis]

    # determine_postprocessing needs a summary.json file in the folder where the raw predictions are. We could compute
    # that from the summary files of the five folds but I am feeling lazy today
    aggregate_scores(test_pred_pairs, labels=classes, json_output_file=join(output_folder_raw, "summary.json"),
                     num_threads=default_num_threads)

    determine_postprocessing(output_folder_base, output_folder_gt, 'cv_niftis_raw',
                             final_subf_name="cv_niftis_postprocessed", processes=default_num_threads,
                             advanced_postprocessing=advanced_postprocessing)
    # determine_postprocessing will create a postprocessing.json file that can be used for inference
if __name__ == "__main__":
    # CLI: consolidate the per-fold postprocessing decisions of one experiment folder
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", type=str, required=True, help="experiment output folder (fold_0, fold_1, "
                                                            "etc must be subfolders of the given folder)")
    consolidate_folds(parser.parse_args().f)
================================================
FILE: unetr_pp/postprocessing/consolidate_postprocessing_simple.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from unetr_pp.postprocessing.consolidate_postprocessing import consolidate_folds
from unetr_pp.utilities.folder_names import get_output_folder_name
from unetr_pp.utilities.task_name_id_conversion import convert_id_to_task_name
from unetr_pp.paths import default_cascade_trainer, default_trainer, default_plans_identifier
def main():
    """
    CLI entry point: determine the consolidated (cross-fold) postprocessing of a trained model,
    useful when the best configuration (2d, 3d_fullres, ...) was selected manually.
    """
    argparser = argparse.ArgumentParser(usage="Used to determine the postprocessing for a trained model. Useful for "
                                              "when the best configuration (2d, 3d_fullres etc) as selected manually.")
    argparser.add_argument("-m", type=str, required=True, help="U-Net model (2d, 3d_lowres, 3d_fullres or "
                                                               "3d_cascade_fullres)")
    argparser.add_argument("-t", type=str, required=True, help="Task name or id")
    argparser.add_argument("-tr", type=str, required=False, default=None,
                           help="nnFormerTrainer class. Default: %s, unless 3d_cascade_fullres "
                                "(then it's %s)" % (default_trainer, default_cascade_trainer))
    argparser.add_argument("-pl", type=str, required=False, default=default_plans_identifier,
                           help="Plans name, Default=%s" % default_plans_identifier)
    argparser.add_argument("-val", type=str, required=False, default="validation_raw",
                           help="Validation folder name. Default: validation_raw")

    args = argparser.parse_args()
    model = args.m
    task = args.t
    trainer = args.tr
    plans = args.pl
    val = args.val

    if not task.startswith("Task"):
        # a numeric id was given; resolve it to the full task name
        task = convert_id_to_task_name(int(task))

    if trainer is None:
        # CONSISTENCY FIX: the -tr help text promises default_trainer/default_cascade_trainer,
        # but the original code hard-coded "nnFormerTrainerV2*" names; use the imported defaults.
        trainer = default_cascade_trainer if model == "3d_cascade_fullres" else default_trainer

    folder = get_output_folder_name(model, task, trainer, plans, None)
    consolidate_folds(folder, val)
# allow running this module directly as a script
if __name__ == "__main__":
    main()
================================================
FILE: unetr_pp/preprocessing/cropping.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import SimpleITK as sitk
import numpy as np
import shutil
from batchgenerators.utilities.file_and_folder_operations import *
from multiprocessing import Pool
from collections import OrderedDict
def create_nonzero_mask(data):
    """Return a boolean mask that is True wherever any channel of ``data`` is nonzero, with interior holes filled."""
    from scipy.ndimage import binary_fill_holes
    assert data.ndim in (3, 4), "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
    # OR over the channel axis: a voxel counts as foreground if any modality is nonzero there
    combined = np.any(data != 0, axis=0)
    return binary_fill_holes(combined)
def get_bbox_from_mask(mask, outside_value=0):
    """Return the tight half-open bounding box [[z0, z1], [x0, x1], [y0, y1]] of voxels != outside_value."""
    coords = np.where(mask != outside_value)
    # one [min, max+1] pair per spatial axis (3d masks only)
    return [[int(coords[d].min()), int(coords[d].max()) + 1] for d in range(3)]
def crop_to_bbox(image, bbox):
    """Crop a 3d image to bbox given as [[z0, z1], [x0, x1], [y0, y1]] (half-open intervals)."""
    assert len(image.shape) == 3, "only supports 3d images"
    window = tuple(slice(lo, hi) for lo, hi in bbox)
    return image[window]
def get_case_identifier(case):
    """Derive the case id from the first data file: strip path, '.nii.gz', and the 5-char '_0000' modality suffix."""
    filename = case[0].split("/")[-1]
    stem = filename.split(".nii.gz")[0]
    return stem[:-5]
def get_case_identifier_from_npz(case):
    """Derive the case id from an .npz path: basename without its 4-char extension."""
    basename = case.split("/")[-1]
    return basename[:-4]
def load_case_from_list_of_files(data_files, seg_file=None):
    """
    Read all modality niftis of one case (and optionally its segmentation).

    Returns (data, seg, properties): data is (C, Z, Y, X) float32, seg is (1, Z, Y, X) float32 or
    None, properties records size/spacing/origin/direction metadata of the first modality.
    """
    assert isinstance(data_files, (list, tuple)), "case must be either a list or a tuple"
    images = [sitk.ReadImage(f) for f in data_files]
    reference = images[0]

    properties = OrderedDict()
    # SimpleITK reports (x, y, z); reverse to (z, y, x) to match the array layout
    properties["original_size_of_raw_data"] = np.array(reference.GetSize())[[2, 1, 0]]
    properties["original_spacing"] = np.array(reference.GetSpacing())[[2, 1, 0]]
    properties["list_of_data_files"] = data_files
    properties["seg_file"] = seg_file
    properties["itk_origin"] = reference.GetOrigin()
    properties["itk_spacing"] = reference.GetSpacing()
    properties["itk_direction"] = reference.GetDirection()

    # stack modalities along a new leading channel axis
    data_npy = np.vstack([sitk.GetArrayFromImage(img)[None] for img in images])

    seg_npy = None
    if seg_file is not None:
        seg_npy = sitk.GetArrayFromImage(sitk.ReadImage(seg_file))[None].astype(np.float32)

    return data_npy.astype(np.float32), seg_npy, properties
def crop_to_nonzero(data, seg=None, nonzero_label=-1):
    """
    Crop data (and seg, if given) to the bounding box of the joint nonzero region.

    :param data: array of shape (C, X, Y, Z)
    :param seg: optional segmentation of shape (C, X, Y, Z)
    :param nonzero_label: this will be written into the segmentation map for background voxels
        outside the nonzero mask
    :return: (cropped data, seg with nonzero_label outside the mask, bbox)
    """
    nonzero_mask = create_nonzero_mask(data)
    bbox = get_bbox_from_mask(nonzero_mask, 0)

    data = np.vstack([crop_to_bbox(channel, bbox)[None] for channel in data])
    if seg is not None:
        seg = np.vstack([crop_to_bbox(channel, bbox)[None] for channel in seg])

    nonzero_mask = crop_to_bbox(nonzero_mask, bbox)[None]
    if seg is not None:
        # background voxels outside the nonzero mask are marked with nonzero_label
        seg[(seg == 0) & (nonzero_mask == 0)] = nonzero_label
    else:
        # no seg given: synthesize one that is 0 inside the mask and nonzero_label outside
        nonzero_mask = nonzero_mask.astype(int)
        nonzero_mask[nonzero_mask == 0] = nonzero_label
        nonzero_mask[nonzero_mask > 0] = 0
        seg = nonzero_mask

    return data, seg, bbox
def get_patient_identifiers_from_cropped_files(folder):
    """Return the case identifiers (basename without '.npz') of all cropped files in folder."""
    npz_files = subfiles(folder, join=True, suffix=".npz")
    return [f.split("/")[-1][:-4] for f in npz_files]
class ImageCropper(object):
    # Crops cases to the bounding box of their nonzero region and stores each case as
    # <id>.npz (data + seg stacked along the channel axis) plus <id>.pkl (properties dict).
    def __init__(self, num_threads, output_folder=None):
        """
        This one finds a mask of nonzero elements (must be nonzero in all modalities) and crops the image to that mask.
        In the case of BRaTS and ISLES data this results in a significant reduction in image size
        :param num_threads: number of worker processes used by run_cropping
        :param output_folder: where to store the cropped data
        """
        self.output_folder = output_folder
        self.num_threads = num_threads

        if self.output_folder is not None:
            maybe_mkdir_p(self.output_folder)

    @staticmethod
    def crop(data, properties, seg=None):
        """Crop data/seg to the nonzero bounding box and record crop metadata in properties."""
        shape_before = data.shape
        # voxels outside the nonzero mask are marked with -1 in seg
        data, seg, bbox = crop_to_nonzero(data, seg, nonzero_label=-1)
        shape_after = data.shape
        print("before crop:", shape_before, "after crop:", shape_after, "spacing:",
              np.array(properties["original_spacing"]), "\n")

        properties["crop_bbox"] = bbox
        properties['classes'] = np.unique(seg)
        # clamp invalid labels (< -1) to background
        seg[seg < -1] = 0
        properties["size_after_cropping"] = data[0].shape
        return data, seg, properties

    @staticmethod
    def crop_from_list_of_files(data_files, seg_file=None):
        """Load one case from its nifti files, then crop it to the nonzero region."""
        data, seg, properties = load_case_from_list_of_files(data_files, seg_file)
        return ImageCropper.crop(data, properties, seg)

    def load_crop_save(self, case, case_identifier, overwrite_existing=False):
        """Crop one case (last entry of `case` is the seg file) and save <id>.npz + <id>.pkl."""
        try:
            print(case_identifier)
            # skip cases that already have both output files unless overwrite is requested
            if overwrite_existing \
                    or (not os.path.isfile(os.path.join(self.output_folder, "%s.npz" % case_identifier))
                        or not os.path.isfile(os.path.join(self.output_folder, "%s.pkl" % case_identifier))):

                data, seg, properties = self.crop_from_list_of_files(case[:-1], case[-1])

                # stack seg as an extra channel so everything lives in one array
                all_data = np.vstack((data, seg))
                np.savez_compressed(os.path.join(self.output_folder, "%s.npz" % case_identifier), data=all_data)
                with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'wb') as f:
                    pickle.dump(properties, f)
        except Exception as e:
            # print the failing case id for debugging, then re-raise so the pool surfaces the error
            print("Exception in", case_identifier, ":")
            print(e)
            raise e

    def get_list_of_cropped_files(self):
        # all cropped cases currently stored in the output folder
        return subfiles(self.output_folder, join=True, suffix=".npz")

    def get_patient_identifiers_from_cropped_files(self):
        # basename without the '.npz' extension
        return [i.split("/")[-1][:-4] for i in self.get_list_of_cropped_files()]

    def run_cropping(self, list_of_files, overwrite_existing=False, output_folder=None):
        """
        also copied ground truth nifti segmentation into the preprocessed folder so that we can use them for evaluation
        on the cluster
        :param list_of_files: list of list of files [[PATIENTID_TIMESTEP_0000.nii.gz], [PATIENTID_TIMESTEP_0000.nii.gz]]
        :param overwrite_existing: if True, re-crop cases even if their output files already exist
        :param output_folder: overrides the output folder given at construction time
        :return:
        """
        if output_folder is not None:
            self.output_folder = output_folder

        output_folder_gt = os.path.join(self.output_folder, "gt_segmentations")
        maybe_mkdir_p(output_folder_gt)
        # the last entry of each case is the segmentation file (may be None)
        for j, case in enumerate(list_of_files):
            if case[-1] is not None:
                shutil.copy(case[-1], output_folder_gt)

        list_of_args = []
        for j, case in enumerate(list_of_files):
            case_identifier = get_case_identifier(case)
            list_of_args.append((case, case_identifier, overwrite_existing))

        # crop all cases in parallel
        p = Pool(self.num_threads)
        p.starmap(self.load_crop_save, list_of_args)
        p.close()
        p.join()

    def load_properties(self, case_identifier):
        # read the pickled properties dict of one case
        with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'rb') as f:
            properties = pickle.load(f)
        return properties

    def save_properties(self, case_identifier, properties):
        # write the properties dict of one case
        with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'wb') as f:
            pickle.dump(properties, f)
================================================
FILE: unetr_pp/preprocessing/custom_preprocessors/preprocessor_scale_RGB_to_0_1.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from unetr_pp.preprocessing.preprocessing import PreprocessorFor2D, resample_patient
class GenericPreprocessor_scale_uint8_to_0_1(PreprocessorFor2D):
    """
    For RGB images with a value range of [0, 255]. This preprocessor overwrites the default normalization scheme by
    normalizing intensity values through a simple division by 255 which rescales them to [0, 1]
    NOTE THAT THIS INHERITS FROM PreprocessorFor2D, SO ITS WRITTEN FOR 2D ONLY! WHEN CREATING A PREPROCESSOR FOR 3D
    DATA, USE GenericPreprocessor AS PARENT!
    """
    def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
        """
        Resample data/seg to target_spacing, then normalize by dividing by 255.

        :param data: case data (c, x, y, z), already transposed by transpose_forward
        :param target_spacing: spacing to resample to; its first entry is overwritten so the
            out-of-plane axis is never resampled (2d preprocessing)
        :param properties: case properties dict, updated in place with post-resampling size/spacing
        :param seg: optional segmentation
        :param force_separate_z: if None, decide automatically whether to resample along z separately
        :return: (data, seg, properties)
        """
        ############ THIS PART IS IDENTICAL TO PARENT CLASS ################

        original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
        before = {
            'spacing': properties["original_spacing"],
            'spacing_transposed': original_spacing_transposed,
            'data.shape (data is transposed)': data.shape
        }
        # keep the original out-of-plane spacing: 2d preprocessing never resamples along axis 0
        target_spacing[0] = original_spacing_transposed[0]
        data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1,
                                     force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0,
                                     separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
        after = {
            'spacing': target_spacing,
            'data.shape (data is resampled)': data.shape
        }
        print("before:", before, "\nafter: ", after, "\n")

        if seg is not None:  # hippocampus 243 has one voxel with -2 as label. wtf?
            seg[seg < -1] = 0

        properties["size_after_resampling"] = data[0].shape
        properties["spacing_after_resampling"] = target_spacing
        use_nonzero_mask = self.use_nonzero_mask

        assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
                                                                         "must have as many entries as data has " \
                                                                         "modalities"
        assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
                                                        " has modalities"

        print("normalization...")

        ############ HERE IS WHERE WE START CHANGING THINGS!!!!!!!################
        # this is where the normalization takes place. We ignore use_nonzero_mask and normalization_scheme_per_modality
        for c in range(len(data)):
            data[c] = data[c].astype(np.float32) / 255.
        return data, seg, properties
================================================
FILE: unetr_pp/preprocessing/preprocessing.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from copy import deepcopy
from batchgenerators.augmentations.utils import resize_segmentation
from unetr_pp.configuration import default_num_threads, RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD
from unetr_pp.preprocessing.cropping import get_case_identifier_from_npz, ImageCropper
from skimage.transform import resize
from scipy.ndimage.interpolation import map_coordinates
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from multiprocessing.pool import Pool
def get_do_separate_z(spacing, anisotropy_threshold=RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD):
    """Return True when the spacing anisotropy (max/min ratio) exceeds the threshold, i.e. the low-res axis should be resampled separately."""
    anisotropy = np.max(spacing) / np.min(spacing)
    return anisotropy > anisotropy_threshold
def get_lowres_axis(new_spacing):
    """Return the indices of the axes whose spacing equals the coarsest (largest) spacing."""
    spacing_arr = np.array(new_spacing)
    # axes where spacing == max(spacing) are the anisotropic / low-resolution ones
    return np.where(max(new_spacing) / spacing_arr == 1)[0]
def resample_patient(data, seg, original_spacing, target_spacing, order_data=3, order_seg=0, force_separate_z=False,
                     cval_data=0, cval_seg=-1, order_z_data=0, order_z_seg=0,
                     separate_z_anisotropy_threshold=RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD):
    """
    Resample data and/or seg from original_spacing to target_spacing.

    :param data: (c, x, y, z) array or None
    :param seg: (c, x, y, z) array or None (data and seg may not both be None)
    :param original_spacing:
    :param target_spacing:
    :param order_data: spline order for data
    :param order_seg: spline order for seg
    :param force_separate_z: if None then we dynamically decide how to resample along z, if True/False then always
    /never resample along z separately
    :param cval_data: fill value for data
    :param cval_seg: fill value for seg
    :param order_z_data: only applies if do_separate_z is True
    :param order_z_seg: only applies if do_separate_z is True
    :param separate_z_anisotropy_threshold: if max_spacing > separate_z_anisotropy_threshold * min_spacing (per axis)
    then resample along lowres axis with order_z_data/order_z_seg instead of order_data/order_seg
    :return: (resampled data or None, resampled seg or None)
    """
    assert not ((data is None) and (seg is None))
    if data is not None:
        assert len(data.shape) == 4, "data must be c x y z"
    if seg is not None:
        assert len(seg.shape) == 4, "seg must be c x y z"

    reference = data if data is not None else seg
    shape = np.array(reference[0].shape)
    new_shape = np.round(((np.array(original_spacing) / np.array(target_spacing)).astype(float) * shape)).astype(int)

    # decide whether (and along which axis) to resample the low-res direction separately
    if force_separate_z is not None:
        do_separate_z = force_separate_z
        axis = get_lowres_axis(original_spacing) if force_separate_z else None
    elif get_do_separate_z(original_spacing, separate_z_anisotropy_threshold):
        do_separate_z = True
        axis = get_lowres_axis(original_spacing)
    elif get_do_separate_z(target_spacing, separate_z_anisotropy_threshold):
        do_separate_z = True
        axis = get_lowres_axis(target_spacing)
    else:
        do_separate_z = False
        axis = None

    if axis is not None and len(axis) >= 2:
        # two axes share the max spacing (e.g. (0.24, 1.25, 1.25)) or all three do (isotropic):
        # in either case there is no single out-of-plane axis to treat separately
        do_separate_z = False

    data_reshaped = None
    if data is not None:
        data_reshaped = resample_data_or_seg(data, new_shape, False, axis, order_data, do_separate_z, cval=cval_data,
                                             order_z=order_z_data)
    seg_reshaped = None
    if seg is not None:
        seg_reshaped = resample_data_or_seg(seg, new_shape, True, axis, order_seg, do_separate_z, cval=cval_seg,
                                            order_z=order_z_seg)
    return data_reshaped, seg_reshaped
def resample_data_or_seg(data, new_shape, is_seg, axis=None, order=3, do_separate_z=False, cval=0, order_z=0):
    """
    separate_z=True will resample with order 0 along z
    :param data: array of shape (c, x, y, z)
    :param new_shape: target spatial shape (x, y, z)
    :param is_seg: if True, resize with resize_segmentation so label values are not blended
    :param axis: indices of the anisotropic (low resolution) axis; only used when do_separate_z
    :param order: spline order for the in-plane resize
    :param do_separate_z: if True, resize slice-wise in-plane first, then resample along the low-res axis
    :param cval: fill value passed to the resize functions
    :param order_z: only applies if do_separate_z is True
    :return: resampled array, cast back to the input dtype
    """
    assert len(data.shape) == 4, "data must be (c, x, y, z)"
    if is_seg:
        resize_fn = resize_segmentation
        kwargs = OrderedDict()
    else:
        resize_fn = resize
        kwargs = {'mode': 'edge', 'anti_aliasing': False}
    dtype_data = data.dtype
    data = data.astype(float)
    shape = np.array(data[0].shape)
    new_shape = np.array(new_shape)
    if np.any(shape != new_shape):
        if do_separate_z:
            print("separate z, order in z is", order_z, "order inplane is", order)
            assert len(axis) == 1, "only one anisotropic axis supported"
            axis = axis[0]
            # in-plane target shape: drop the low-res axis from new_shape
            if axis == 0:
                new_shape_2d = new_shape[1:]
            elif axis == 1:
                new_shape_2d = new_shape[[0, 2]]
            else:
                new_shape_2d = new_shape[:-1]

            reshaped_final_data = []
            for c in range(data.shape[0]):
                # resize every slice along the low-res axis to the in-plane target shape
                reshaped_data = []
                for slice_id in range(shape[axis]):
                    if axis == 0:
                        reshaped_data.append(resize_fn(data[c, slice_id], new_shape_2d, order, cval=cval, **kwargs))
                    elif axis == 1:
                        reshaped_data.append(resize_fn(data[c, :, slice_id], new_shape_2d, order, cval=cval, **kwargs))
                    else:
                        reshaped_data.append(resize_fn(data[c, :, :, slice_id], new_shape_2d, order, cval=cval,
                                                       **kwargs))
                reshaped_data = np.stack(reshaped_data, axis)
                if shape[axis] != new_shape[axis]:
                    # the number of slices changes too: interpolate along the low-res axis with order_z
                    # The following few lines are blatantly copied and modified from sklearn's resize()
                    rows, cols, dim = new_shape[0], new_shape[1], new_shape[2]
                    orig_rows, orig_cols, orig_dim = reshaped_data.shape

                    row_scale = float(orig_rows) / rows
                    col_scale = float(orig_cols) / cols
                    dim_scale = float(orig_dim) / dim

                    map_rows, map_cols, map_dims = np.mgrid[:rows, :cols, :dim]
                    map_rows = row_scale * (map_rows + 0.5) - 0.5
                    map_cols = col_scale * (map_cols + 0.5) - 0.5
                    map_dims = dim_scale * (map_dims + 0.5) - 0.5

                    coord_map = np.array([map_rows, map_cols, map_dims])
                    if not is_seg or order_z == 0:
                        # data, or nearest-neighbor seg: a single interpolation pass suffices
                        reshaped_final_data.append(map_coordinates(reshaped_data, coord_map, order=order_z, cval=cval,
                                                                   mode='nearest')[None])
                    else:
                        # seg with order_z > 0: interpolate each label as its own binary channel so
                        # spline interpolation never invents intermediate label values
                        unique_labels = np.unique(reshaped_data)
                        reshaped = np.zeros(new_shape, dtype=dtype_data)

                        for i, cl in enumerate(unique_labels):
                            reshaped_multihot = np.round(
                                map_coordinates((reshaped_data == cl).astype(float), coord_map, order=order_z,
                                                cval=cval, mode='nearest'))
                            reshaped[reshaped_multihot > 0.5] = cl
                        reshaped_final_data.append(reshaped[None])
                else:
                    # slice count unchanged: the in-plane resize already produced the final result
                    reshaped_final_data.append(reshaped_data[None])
            reshaped_final_data = np.vstack(reshaped_final_data)
        else:
            print("no separate z, order", order)
            reshaped = []
            for c in range(data.shape[0]):
                reshaped.append(resize_fn(data[c], new_shape, order, cval=cval, **kwargs)[None])
            reshaped_final_data = np.vstack(reshaped)
        return reshaped_final_data.astype(dtype_data)
    else:
        print("no resampling necessary")
        return data
class GenericPreprocessor(object):
    def __init__(self, normalization_scheme_per_modality, use_nonzero_mask, transpose_forward: (tuple, list), intensityproperties=None):
        """
        :param normalization_scheme_per_modality: dict {0:'nonCT'}
        :param use_nonzero_mask: {0:False}
        :param transpose_forward: axis permutation applied to the spatial axes before resampling
        :param intensityproperties: per-modality intensity statistics; required for 'CT'/'CT2' normalization
        """
        self.transpose_forward = transpose_forward
        self.intensityproperties = intensityproperties
        self.normalization_scheme_per_modality = normalization_scheme_per_modality
        self.use_nonzero_mask = use_nonzero_mask

        # spacing anisotropy ratio above which the low-res axis is resampled separately
        self.resample_separate_z_anisotropy_threshold = RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD
@staticmethod
def load_cropped(cropped_output_dir, case_identifier):
all_data = np.load(os.path.join(cropped_output_dir, "%s.npz" % case_identifier))['data']
data = all_data[:-1].astype(np.float32)
seg = all_data[-1:]
with open(os.path.join(cropped_output_dir, "%s.pkl" % case_identifier), 'rb') as f:
properties = pickle.load(f)
return data, seg, properties
    def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
        """
        data and seg must already have been transposed by transpose_forward. properties are the un-transposed values
        (spacing etc)
        :param data: case data of shape (c, x, y, z)
        :param target_spacing: spacing to resample to (already transposed)
        :param properties: case properties dict; size/spacing after resampling are recorded here
        :param seg: optional segmentation, same spatial shape as data
        :param force_separate_z: if None, decide automatically whether to resample along z separately
        :return: (data, seg, properties)
        """
        # target_spacing is already transposed, properties["original_spacing"] is not so we need to transpose it!
        # data, seg are already transposed. Double check this using the properties
        original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
        before = {
            'spacing': properties["original_spacing"],
            'spacing_transposed': original_spacing_transposed,
            'data.shape (data is transposed)': data.shape
        }

        # remove nans
        data[np.isnan(data)] = 0

        data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1,
                                     force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0,
                                     separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
        after = {
            'spacing': target_spacing,
            'data.shape (data is resampled)': data.shape
        }
        print("before:", before, "\nafter: ", after, "\n")

        if seg is not None:  # hippocampus 243 has one voxel with -2 as label. wtf?
            seg[seg < -1] = 0

        properties["size_after_resampling"] = data[0].shape
        properties["spacing_after_resampling"] = target_spacing
        use_nonzero_mask = self.use_nonzero_mask

        assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
                                                                         "must have as many entries as data has " \
                                                                         "modalities"
        assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
                                                        " has modalities"

        # per-modality normalization; the scheme decides the statistics used
        for c in range(len(data)):
            scheme = self.normalization_scheme_per_modality[c]
            if scheme == "CT":
                # clip to lb and ub from train data foreground and use foreground mn and sd from training data
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                mean_intensity = self.intensityproperties[c]['mean']
                std_intensity = self.intensityproperties[c]['sd']
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                data[c] = (data[c] - mean_intensity) / std_intensity
                if use_nonzero_mask[c]:
                    # voxels outside the nonzero mask are marked < 0 in seg's last channel
                    data[c][seg[-1] < 0] = 0
            elif scheme == "CT2":
                # clip to lb and ub from train data foreground, use mn and sd form each case for normalization
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                mask = (data[c] > lower_bound) & (data[c] < upper_bound)
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                mn = data[c][mask].mean()
                sd = data[c][mask].std()
                data[c] = (data[c] - mn) / sd
                if use_nonzero_mask[c]:
                    data[c][seg[-1] < 0] = 0
            else:
                # default: per-case z-scoring, optionally restricted to the nonzero mask
                if use_nonzero_mask[c]:
                    mask = seg[-1] >= 0
                else:
                    mask = np.ones(seg.shape[1:], dtype=bool)
                data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8)
                data[c][mask == 0] = 0
        return data, seg, properties
def preprocess_test_case(self, data_files, target_spacing, seg_file=None, force_separate_z=None):
    """
    Run the full preprocessing pipeline (crop -> transpose -> resample/normalize)
    on a single case given as raw files.

    :param data_files: list of image files (one per modality) for this case
    :param target_spacing: spacing to resample to (transposed order)
    :param seg_file: optional segmentation file for this case
    :param force_separate_z: forwarded to resample_and_normalize (None = auto)
    :return: (data as float32, seg, properties)
    """
    cropped, cropped_seg, props = ImageCropper.crop_from_list_of_files(data_files, seg_file)

    # channel axis stays first; spatial axes are permuted according to transpose_forward
    axis_order = (0, *[ax + 1 for ax in self.transpose_forward])
    cropped = cropped.transpose(axis_order)
    cropped_seg = cropped_seg.transpose(axis_order)

    cropped, cropped_seg, props = self.resample_and_normalize(
        cropped, target_spacing, props, cropped_seg, force_separate_z=force_separate_z)

    return cropped.astype(np.float32), cropped_seg, props
def _run_internal(self, target_spacing, case_identifier, output_folder_stage, cropped_output_dir, force_separate_z,
                  all_classes):
    """
    Preprocess one cropped case: resample + normalize, sample per-class foreground
    coordinates (for foreground oversampling during training) and save the result
    as <case>.npz (image + seg stacked) plus <case>.pkl (properties).

    :param target_spacing: spacing to resample this stage to (transposed order)
    :param case_identifier: case name, used for the output file names
    :param output_folder_stage: destination folder of this resolution stage
    :param cropped_output_dir: folder containing the cropped npz input
    :param force_separate_z: forwarded to resample_and_normalize
    :param all_classes: label values for which foreground locations are precomputed
    """
    data, seg, properties = self.load_cropped(cropped_output_dir, case_identifier)

    # channel axis stays first; spatial axes are permuted by transpose_forward
    data = data.transpose((0, *[i + 1 for i in self.transpose_forward]))
    seg = seg.transpose((0, *[i + 1 for i in self.transpose_forward]))

    data, seg, properties = self.resample_and_normalize(data, target_spacing,
                                                        properties, seg, force_separate_z)

    # stack seg as the last channel so one array holds everything
    all_data = np.vstack((data, seg)).astype(np.float32)

    # we need to find out where the classes are and sample some random locations
    # let's do 10.000 samples per class
    # seed this for reproducibility!
    num_samples = 10000
    min_percent_coverage = 0.01 # at least 1% of the class voxels need to be selected, otherwise it may be too sparse
    rndst = np.random.RandomState(1234)
    class_locs = {}
    for c in all_classes:
        all_locs = np.argwhere(all_data[-1] == c)
        if len(all_locs) == 0:
            # class not present in this case; store empty list so lookups don't fail
            class_locs[c] = []
            continue
        # cap at num_samples but never below 1% of this class' voxels
        target_num_samples = min(num_samples, len(all_locs))
        target_num_samples = max(target_num_samples, int(np.ceil(len(all_locs) * min_percent_coverage)))

        selected = all_locs[rndst.choice(len(all_locs), target_num_samples, replace=False)]
        class_locs[c] = selected
        print(c, target_num_samples)
    properties['class_locations'] = class_locs

    print("saving: ", os.path.join(output_folder_stage, "%s.npz" % case_identifier))
    np.savez_compressed(os.path.join(output_folder_stage, "%s.npz" % case_identifier),
                        data=all_data.astype(np.float32))
    with open(os.path.join(output_folder_stage, "%s.pkl" % case_identifier), 'wb') as f:
        pickle.dump(properties, f)
def run(self, target_spacings, input_folder_with_cropped_npz, output_folder, data_identifier,
        num_threads=default_num_threads, force_separate_z=None):
    """
    Resample + normalize every cropped case for every resolution stage and write
    the results (npz + pkl per case) into per-stage subfolders of output_folder.

    :param target_spacings: list of spacings, one per stage, e.g. [[1.25, 1.25, 5]]
    :param input_folder_with_cropped_npz: dim: c, x, y, z | npz_file['data'] np.savez_compressed(fname.npz, data=arr)
    :param output_folder: root folder for the preprocessed stages
    :param data_identifier: prefix of the per-stage subfolder names
    :param num_threads: int (same pool size for all stages) or one int per stage
    :param force_separate_z: None = decide automatically; forwarded to the workers
    :return: None
    """
    print("Initializing to run preprocessing")
    print("npz folder:", input_folder_with_cropped_npz)
    print("output_folder:", output_folder)
    list_of_cropped_npz_files = subfiles(input_folder_with_cropped_npz, True, None, ".npz", True)
    maybe_mkdir_p(output_folder)
    num_stages = len(target_spacings)

    # normalize num_threads to one entry per stage
    if not isinstance(num_threads, (list, tuple, np.ndarray)):
        num_threads = [num_threads] * num_stages
    assert len(num_threads) == num_stages

    # we need to know which classes are present in this dataset so that we can precompute where these classes are
    # located. This is needed for oversampling foreground
    all_classes = load_pickle(join(input_folder_with_cropped_npz, 'dataset_properties.pkl'))['all_classes']

    for i in range(num_stages):
        all_args = []
        output_folder_stage = os.path.join(output_folder, data_identifier + "_stage%d" % i)
        maybe_mkdir_p(output_folder_stage)
        spacing = target_spacings[i]
        for j, case in enumerate(list_of_cropped_npz_files):
            case_identifier = get_case_identifier_from_npz(case)
            args = spacing, case_identifier, output_folder_stage, input_folder_with_cropped_npz, force_separate_z, all_classes
            all_args.append(args)
        # context manager guarantees the pool is torn down even if a worker raises
        # (the previous close()/join() pair leaked the pool on exceptions);
        # starmap blocks until all tasks have finished, so nothing is cut short
        with Pool(num_threads[i]) as p:
            p.starmap(self._run_internal, all_args)
class Preprocessor3DDifferentResampling(GenericPreprocessor):
    """GenericPreprocessor variant that differs only in the resampling arguments:
    here order_z_data=3 and order_z_seg=1 are passed to resample_patient (see below)."""

    def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
        """
        data and seg must already have been transposed by transpose_forward. properties are the un-transposed values
        (spacing etc)

        :param data: image array (modalities, x, y, z), already transposed
        :param target_spacing: spacing to resample to (already transposed)
        :param properties: per-case properties dict; updated in place with post-resampling size/spacing
        :param seg: optional segmentation; seg[-1] is used as the nonzero mask (negative = outside)
        :param force_separate_z: None = decide automatically, True/False = force behaviour
        :return: (data, seg, properties)
        """
        # target_spacing is already transposed, properties["original_spacing"] is not so we need to transpose it!
        # data, seg are already transposed. Double check this using the properties
        original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
        before = {
            'spacing': properties["original_spacing"],
            'spacing_transposed': original_spacing_transposed,
            'data.shape (data is transposed)': data.shape
        }

        # remove nans
        data[np.isnan(data)] = 0

        data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1,
                                     force_separate_z=force_separate_z, order_z_data=3, order_z_seg=1,
                                     separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
        after = {
            'spacing': target_spacing,
            'data.shape (data is resampled)': data.shape
        }
        print("before:", before, "\nafter: ", after, "\n")

        if seg is not None:  # hippocampus 243 has one voxel with -2 as label. wtf?
            seg[seg < -1] = 0

        properties["size_after_resampling"] = data[0].shape
        properties["spacing_after_resampling"] = target_spacing
        use_nonzero_mask = self.use_nonzero_mask

        assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
                                                                         "must have as many entries as data has " \
                                                                         "modalities"
        assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
                                                        " has modalities"
        for c in range(len(data)):
            scheme = self.normalization_scheme_per_modality[c]
            if scheme == "CT":
                # clip to lb and ub from train data foreground and use foreground mn and sd from training data
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                mean_intensity = self.intensityproperties[c]['mean']
                std_intensity = self.intensityproperties[c]['sd']
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                data[c] = (data[c] - mean_intensity) / std_intensity
                if use_nonzero_mask[c]:
                    # zero out everything outside the nonzero region (negative seg labels mark "outside")
                    data[c][seg[-1] < 0] = 0
            elif scheme == "CT2":
                # clip to lb and ub from train data foreground, use mn and sd form each case for normalization
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                # mask is computed BEFORE clipping so the statistics use strictly in-range voxels only
                mask = (data[c] > lower_bound) & (data[c] < upper_bound)
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                mn = data[c][mask].mean()
                sd = data[c][mask].std()
                data[c] = (data[c] - mn) / sd
                if use_nonzero_mask[c]:
                    data[c][seg[-1] < 0] = 0
            else:
                # default: per-case z-score normalization, optionally restricted to the nonzero mask
                if use_nonzero_mask[c]:
                    mask = seg[-1] >= 0
                else:
                    mask = np.ones(seg.shape[1:], dtype=bool)
                data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8)
                data[c][mask == 0] = 0
        return data, seg, properties
class Preprocessor3DBetterResampling(GenericPreprocessor):
    """
    This preprocessor always uses force_separate_z=False. It does resampling to the target spacing with third
    order spline for data (just like GenericPreprocessor) and seg (unlike GenericPreprocessor). It never does separate
    resampling in z.
    """

    def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=False):
        """
        data and seg must already have been transposed by transpose_forward. properties are the un-transposed values
        (spacing etc)

        :param data: image array (modalities, x, y, z), already transposed
        :param target_spacing: spacing to resample to (already transposed)
        :param properties: per-case properties dict; updated in place with post-resampling size/spacing
        :param seg: optional segmentation; seg[-1] is used as the nonzero mask (negative = outside)
        :param force_separate_z: ignored; this class always forces False (a warning is printed otherwise)
        :return: (data, seg, properties)
        """
        if force_separate_z is not False:
            print("WARNING: Preprocessor3DBetterResampling always uses force_separate_z=False. "
                  "You specified %s. Your choice is overwritten" % str(force_separate_z))
            force_separate_z = False

        # be safe
        assert force_separate_z is False

        # target_spacing is already transposed, properties["original_spacing"] is not so we need to transpose it!
        # data, seg are already transposed. Double check this using the properties
        original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
        before = {
            'spacing': properties["original_spacing"],
            'spacing_transposed': original_spacing_transposed,
            'data.shape (data is transposed)': data.shape
        }

        # remove nans
        data[np.isnan(data)] = 0

        # third-order spline for BOTH data and seg; order_z_* set to 99999 —
        # presumably an unused sentinel since separate z resampling is disabled,
        # TODO confirm against resample_patient
        data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 3,
                                     force_separate_z=force_separate_z, order_z_data=99999, order_z_seg=99999,
                                     separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
        after = {
            'spacing': target_spacing,
            'data.shape (data is resampled)': data.shape
        }
        print("before:", before, "\nafter: ", after, "\n")

        if seg is not None:  # hippocampus 243 has one voxel with -2 as label. wtf?
            seg[seg < -1] = 0

        properties["size_after_resampling"] = data[0].shape
        properties["spacing_after_resampling"] = target_spacing
        use_nonzero_mask = self.use_nonzero_mask

        assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
                                                                         "must have as many entries as data has " \
                                                                         "modalities"
        assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
                                                        " has modalities"
        for c in range(len(data)):
            scheme = self.normalization_scheme_per_modality[c]
            if scheme == "CT":
                # clip to lb and ub from train data foreground and use foreground mn and sd from training data
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                mean_intensity = self.intensityproperties[c]['mean']
                std_intensity = self.intensityproperties[c]['sd']
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                data[c] = (data[c] - mean_intensity) / std_intensity
                if use_nonzero_mask[c]:
                    # zero out everything outside the nonzero region (negative seg labels mark "outside")
                    data[c][seg[-1] < 0] = 0
            elif scheme == "CT2":
                # clip to lb and ub from train data foreground, use mn and sd form each case for normalization
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                # mask is computed BEFORE clipping so the statistics use strictly in-range voxels only
                mask = (data[c] > lower_bound) & (data[c] < upper_bound)
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                mn = data[c][mask].mean()
                sd = data[c][mask].std()
                data[c] = (data[c] - mn) / sd
                if use_nonzero_mask[c]:
                    data[c][seg[-1] < 0] = 0
            else:
                # default: per-case z-score normalization, optionally restricted to the nonzero mask
                if use_nonzero_mask[c]:
                    mask = seg[-1] >= 0
                else:
                    mask = np.ones(seg.shape[1:], dtype=bool)
                data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8)
                data[c][mask == 0] = 0
        return data, seg, properties
class PreprocessorFor2D(GenericPreprocessor):
    """GenericPreprocessor variant for 2D networks: the z component of
    target_spacing is forced to the original (transposed) z spacing in
    resample_and_normalize, so resampling is in-plane only."""

    def __init__(self, normalization_scheme_per_modality, use_nonzero_mask, transpose_forward: (tuple, list), intensityproperties=None):
        # no 2D-specific state; simply forwards to GenericPreprocessor
        super(PreprocessorFor2D, self).__init__(normalization_scheme_per_modality, use_nonzero_mask,
                                                transpose_forward, intensityproperties)

    def run(self, target_spacings, input_folder_with_cropped_npz, output_folder, data_identifier,
            num_threads=default_num_threads, force_separate_z=None):
        """
        Preprocess all cropped cases for all stages (cf. GenericPreprocessor.run).
        Unlike the base class, args for ALL stages are collected first and then
        processed by a single worker pool.

        :param target_spacings: list of spacings, one per stage
        :param input_folder_with_cropped_npz: folder with the cropped npz inputs
        :param output_folder: root folder for the preprocessed stages
        :param data_identifier: prefix of the per-stage subfolder names
        :param num_threads: size of the worker pool (single int here)
        :param force_separate_z: forwarded to the workers
        """
        print("Initializing to run preprocessing")
        print("npz folder:", input_folder_with_cropped_npz)
        print("output_folder:", output_folder)
        list_of_cropped_npz_files = subfiles(input_folder_with_cropped_npz, True, None, ".npz", True)
        assert len(list_of_cropped_npz_files) != 0, "set list of files first"
        maybe_mkdir_p(output_folder)
        all_args = []
        num_stages = len(target_spacings)

        # we need to know which classes are present in this dataset so that we can precompute where these classes are
        # located. This is needed for oversampling foreground
        all_classes = load_pickle(join(input_folder_with_cropped_npz, 'dataset_properties.pkl'))['all_classes']

        for i in range(num_stages):
            output_folder_stage = os.path.join(output_folder, data_identifier + "_stage%d" % i)
            maybe_mkdir_p(output_folder_stage)
            spacing = target_spacings[i]
            for j, case in enumerate(list_of_cropped_npz_files):
                case_identifier = get_case_identifier_from_npz(case)
                args = spacing, case_identifier, output_folder_stage, input_folder_with_cropped_npz, force_separate_z, all_classes
                all_args.append(args)
        p = Pool(num_threads)
        p.starmap(self._run_internal, all_args)
        p.close()
        p.join()

    def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
        """
        Same contract as GenericPreprocessor.resample_and_normalize, but the z
        component of target_spacing is overwritten with the original (transposed)
        z spacing below, so only in-plane resolution changes.

        :param data: image array (modalities, x, y, z), already transposed
        :param target_spacing: spacing to resample to; index 0 is overwritten in place
        :param properties: per-case properties dict; updated in place
        :param seg: optional segmentation; seg[-1] is used as the nonzero mask
        :param force_separate_z: forwarded to resample_patient
        :return: (data, seg, properties)
        """
        original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
        before = {
            'spacing': properties["original_spacing"],
            'spacing_transposed': original_spacing_transposed,
            'data.shape (data is transposed)': data.shape
        }
        # keep the original z spacing -> no resampling along the first (z) axis
        target_spacing[0] = original_spacing_transposed[0]
        data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1,
                                     force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0,
                                     separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
        after = {
            'spacing': target_spacing,
            'data.shape (data is resampled)': data.shape
        }
        print("before:", before, "\nafter: ", after, "\n")

        if seg is not None:  # hippocampus 243 has one voxel with -2 as label. wtf?
            seg[seg < -1] = 0

        properties["size_after_resampling"] = data[0].shape
        properties["spacing_after_resampling"] = target_spacing
        use_nonzero_mask = self.use_nonzero_mask

        assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
                                                                         "must have as many entries as data has " \
                                                                         "modalities"
        assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
                                                        " has modalities"
        print("normalization...")
        for c in range(len(data)):
            scheme = self.normalization_scheme_per_modality[c]
            if scheme == "CT":
                # clip to lb and ub from train data foreground and use foreground mn and sd from training data
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                mean_intensity = self.intensityproperties[c]['mean']
                std_intensity = self.intensityproperties[c]['sd']
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                data[c] = (data[c] - mean_intensity) / std_intensity
                if use_nonzero_mask[c]:
                    # zero out everything outside the nonzero region (negative seg labels mark "outside")
                    data[c][seg[-1] < 0] = 0
            elif scheme == "CT2":
                # clip to lb and ub from train data foreground, use mn and sd form each case for normalization
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                # mask is computed BEFORE clipping so the statistics use strictly in-range voxels only
                mask = (data[c] > lower_bound) & (data[c] < upper_bound)
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                mn = data[c][mask].mean()
                sd = data[c][mask].std()
                data[c] = (data[c] - mn) / sd
                if use_nonzero_mask[c]:
                    data[c][seg[-1] < 0] = 0
            else:
                # default: per-case z-score normalization, optionally restricted to the nonzero mask
                if use_nonzero_mask[c]:
                    mask = seg[-1] >= 0
                else:
                    mask = np.ones(seg.shape[1:], dtype=bool)
                data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8)
                data[c][mask == 0] = 0
        print("normalization done")
        return data, seg, properties
class PreprocessorFor3D_NoResampling(GenericPreprocessor):
    """GenericPreprocessor variant that keeps the original voxel spacing
    (target_spacing is replaced with the original transposed spacing below)."""

    def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
        """
        if target_spacing[0] is None or nan we use original_spacing_transposed[0] (no resampling along z)

        NOTE(review): despite the sentence above, target_spacing is unconditionally
        replaced with the original (transposed) spacing below, so no resampling
        happens on ANY axis — confirm which behaviour is intended.

        :param data: image array (modalities, x, y, z), already transposed
        :param target_spacing: nominally the spacing to resample to; overridden below
        :param properties: per-case properties dict; updated in place
        :param seg: optional segmentation; seg[-1] is used as the nonzero mask
        :param force_separate_z: forwarded to resample_patient
        :return: (data, seg, properties)
        """
        original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
        before = {
            'spacing': properties["original_spacing"],
            'spacing_transposed': original_spacing_transposed,
            'data.shape (data is transposed)': data.shape
        }

        # remove nans
        data[np.isnan(data)] = 0

        # keep the original spacing -> the resample call below is effectively a no-op resampling target
        target_spacing = deepcopy(original_spacing_transposed)
        #print(target_spacing, original_spacing_transposed)
        data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1,
                                     force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0,
                                     separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
        after = {
            'spacing': target_spacing,
            'data.shape (data is resampled)': data.shape
        }
        st = "before:" + str(before) + '\nafter' + str(after) + "\n"
        print(st)

        if seg is not None:  # hippocampus 243 has one voxel with -2 as label. wtf?
            seg[seg < -1] = 0

        properties["size_after_resampling"] = data[0].shape
        properties["spacing_after_resampling"] = target_spacing
        use_nonzero_mask = self.use_nonzero_mask

        assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
                                                                         "must have as many entries as data has " \
                                                                         "modalities"
        assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
                                                        " has modalities"
        for c in range(len(data)):
            scheme = self.normalization_scheme_per_modality[c]
            if scheme == "CT":
                # clip to lb and ub from train data foreground and use foreground mn and sd from training data
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                mean_intensity = self.intensityproperties[c]['mean']
                std_intensity = self.intensityproperties[c]['sd']
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                data[c] = (data[c] - mean_intensity) / std_intensity
                if use_nonzero_mask[c]:
                    # zero out everything outside the nonzero region (negative seg labels mark "outside")
                    data[c][seg[-1] < 0] = 0
            elif scheme == "CT2":
                # clip to lb and ub from train data foreground, use mn and sd form each case for normalization
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                # mask is computed BEFORE clipping so the statistics use strictly in-range voxels only
                mask = (data[c] > lower_bound) & (data[c] < upper_bound)
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                mn = data[c][mask].mean()
                sd = data[c][mask].std()
                data[c] = (data[c] - mn) / sd
                if use_nonzero_mask[c]:
                    data[c][seg[-1] < 0] = 0
            else:
                # default: per-case z-score normalization, optionally restricted to the nonzero mask
                if use_nonzero_mask[c]:
                    mask = seg[-1] >= 0
                else:
                    mask = np.ones(seg.shape[1:], dtype=bool)
                data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8)
                data[c][mask == 0] = 0
        return data, seg, properties
class PreprocessorFor2D_noNormalization(GenericPreprocessor):
    """2D preprocessor that resamples in-plane only (z spacing is preserved)
    and skips intensity normalization entirely."""

    def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
        """
        Resample data/seg in-plane to target_spacing (z spacing kept) and record
        the resulting size/spacing in properties. No intensity normalization.

        :param data: image array (modalities, x, y, z), already transposed
        :param target_spacing: spacing to resample to; index 0 is overwritten in place
        :param properties: per-case properties dict; updated in place
        :param seg: optional segmentation
        :param force_separate_z: forwarded to resample_patient
        :return: (data, seg, properties)
        """
        spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
        before = {
            'spacing': properties["original_spacing"],
            'spacing_transposed': spacing_transposed,
            'data.shape (data is transposed)': data.shape
        }

        # keep the (transposed) z spacing: only resample within the 2D plane
        target_spacing[0] = spacing_transposed[0]
        data, seg = resample_patient(data, seg, np.array(spacing_transposed), target_spacing, 3, 1,
                                     force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0,
                                     separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
        after = {
            'spacing': target_spacing,
            'data.shape (data is resampled)': data.shape
        }
        print("before:", before, "\nafter: ", after, "\n")

        if seg is not None:  # hippocampus 243 has one voxel with -2 as label. wtf?
            seg[seg < -1] = 0

        properties["size_after_resampling"] = data[0].shape
        properties["spacing_after_resampling"] = target_spacing

        # sanity checks only; the normalization loop of the base class is intentionally absent
        n_modalities = len(data)
        assert len(self.normalization_scheme_per_modality) == n_modalities, \
            "self.normalization_scheme_per_modality must have as many entries as data has modalities"
        assert len(self.use_nonzero_mask) == n_modalities, \
            "self.use_nonzero_mask must have as many entries as data has modalities"
        return data, seg, properties
================================================
FILE: unetr_pp/preprocessing/sanity_checks.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from multiprocessing import Pool
import SimpleITK as sitk
import nibabel as nib
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.configuration import default_num_threads
def verify_all_same_orientation(folder):
    """
    This should run after cropping

    Check whether every .nii.gz file in *folder* shares one axis orientation
    (as reported by nibabel's aff2axcodes on the affine).

    :param folder: folder containing .nii.gz files
    :return: (all_same: bool, unique_orientations: np.ndarray of axis-code rows)
    """
    axcodes = [
        nib.aff2axcodes(nib.load(fname).affine)
        for fname in subfiles(folder, suffix=".nii.gz", join=True)
    ]
    # collapse to the distinct orientations; exactly one row means they all agree
    unique_orientations = np.unique(np.array(axcodes), axis=0)
    return len(unique_orientations) == 1, unique_orientations
def verify_same_geometry(img_1: "sitk.Image", img_2: "sitk.Image"):
    """
    Compare origin, spacing, direction and size of two images.

    Every mismatching attribute is reported on stdout (all four are always
    checked, no early exit).

    :param img_1: first image
    :param img_2: second image
    :return: True if all four attributes match (within np.isclose tolerance), else False
    """
    comparisons = [
        ("the origin does not match between the images:", img_1.GetOrigin(), img_2.GetOrigin()),
        ("the spacing does not match between the images", img_1.GetSpacing(), img_2.GetSpacing()),
        ("the direction does not match between the images", img_1.GetDirection(), img_2.GetDirection()),
        ("the size does not match between the images", img_1.GetSize(), img_2.GetSize()),
    ]
    all_match = True
    for message, attr_1, attr_2 in comparisons:
        if not np.all(np.isclose(attr_1, attr_2)):
            print(message)
            print(attr_1)
            print(attr_2)
            all_match = False
    return all_match
def verify_contains_only_expected_labels(itk_img: str, valid_labels: (tuple, list)):
    """
    Check that a segmentation image contains only the declared label values.

    :param itk_img: path to the segmentation image (readable by SimpleITK)
    :param valid_labels: label values that are allowed to appear
    :return: (ok: bool, invalid_values: list of values present but not allowed)
    """
    present_values = np.unique(sitk.GetArrayFromImage(sitk.ReadImage(itk_img)))
    unexpected = [value for value in present_values if value not in valid_labels]
    return len(unexpected) == 0, unexpected
def verify_dataset_integrity(folder):
    """
    folder needs the imagesTr, imagesTs and labelsTr subfolders. There also needs to be a dataset.json
    checks if all training cases and labels are present
    checks if all test cases (if any) are present
    for each case, checks whether all modalities are present
    for each case, checks whether the pixel grids are aligned
    checks whether the labels really only contain values they should

    :param folder: raw dataset (Task) folder
    :return: None; raises AssertionError/RuntimeError/Warning when problems are found
    """
    assert isfile(join(folder, "dataset.json")), "There needs to be a dataset.json file in folder, folder=%s" % folder
    assert isdir(join(folder, "imagesTr")), "There needs to be a imagesTr subfolder in folder, folder=%s" % folder
    assert isdir(join(folder, "labelsTr")), "There needs to be a labelsTr subfolder in folder, folder=%s" % folder
    dataset = load_json(join(folder, "dataset.json"))
    training_cases = dataset['training']
    num_modalities = len(dataset['modality'].keys())
    test_cases = dataset['test']
    # identifiers are the file names with the ".nii.gz" suffix (7 chars) stripped
    expected_train_identifiers = [i['image'].split("/")[-1][:-7] for i in training_cases]
    expected_test_identifiers = [i.split("/")[-1][:-7] for i in test_cases]

    ## check training set
    nii_files_in_imagesTr = subfiles((join(folder, "imagesTr")), suffix=".nii.gz", join=False)
    nii_files_in_labelsTr = subfiles((join(folder, "labelsTr")), suffix=".nii.gz", join=False)

    label_files = []
    geometries_OK = True
    has_nan = False

    # dataset.json must not list the same training case twice
    if len(expected_train_identifiers) != len(np.unique(expected_train_identifiers)):
        raise RuntimeError("found duplicate training cases in dataset.json")

    print("Verifying training set")
    for c in expected_train_identifiers:
        print("checking case", c)
        # check if all files are present
        expected_label_file = join(folder, "labelsTr", c + ".nii.gz")
        label_files.append(expected_label_file)

        expected_image_files = [join(folder, "imagesTr", c + "_%04.0d.nii.gz" % i) for i in range(num_modalities)]
        assert isfile(expected_label_file), "could not find label file for case %s. Expected file: \n%s" % (
            c, expected_label_file)
        assert all([isfile(i) for i in
                    expected_image_files]), "some image files are missing for case %s. Expected files:\n %s" % (
            c, expected_image_files)

        # verify that all modalities and the label have the same shape and geometry.
        label_itk = sitk.ReadImage(expected_label_file)

        nans_in_seg = np.any(np.isnan(sitk.GetArrayFromImage(label_itk)))
        has_nan = has_nan | nans_in_seg
        if nans_in_seg:
            print("There are NAN values in segmentation %s" % expected_label_file)

        images_itk = [sitk.ReadImage(i) for i in expected_image_files]
        for i, img in enumerate(images_itk):
            nans_in_image = np.any(np.isnan(sitk.GetArrayFromImage(img)))
            has_nan = has_nan | nans_in_image
            same_geometry = verify_same_geometry(img, label_itk)
            if not same_geometry:
                geometries_OK = False
                print("The geometry of the image %s does not match the geometry of the label file. The pixel arrays "
                      "will not be aligned and nnU-Net cannot use this data. Please make sure your image modalities "
                      "are coregistered and have the same geometry as the label" % expected_image_files[0][:-12])
            if nans_in_image:
                print("There are NAN values in image %s" % expected_image_files[i])

        # now remove checked files from the lists nii_files_in_imagesTr and nii_files_in_labelsTr
        for i in expected_image_files:
            nii_files_in_imagesTr.remove(os.path.basename(i))
        nii_files_in_labelsTr.remove(os.path.basename(expected_label_file))

    # check for stragglers (files present on disk but not declared in dataset.json)
    assert len(
        nii_files_in_imagesTr) == 0, "there are training cases in imagesTr that are not listed in dataset.json: %s" % nii_files_in_imagesTr
    assert len(
        nii_files_in_labelsTr) == 0, "there are training cases in labelsTr that are not listed in dataset.json: %s" % nii_files_in_labelsTr

    # verify that only properly declared values are present in the labels
    print("Verifying label values")
    expected_labels = list(int(i) for i in dataset['labels'].keys())

    # check if labels are in consecutive order
    assert expected_labels[0] == 0, 'The first label must be 0 and maps to the background'
    labels_valid_consecutive = np.ediff1d(expected_labels) == 1
    assert all(labels_valid_consecutive), f'Labels must be in consecutive order (0, 1, 2, ...). The labels {np.array(expected_labels)[1:][~labels_valid_consecutive]} do not satisfy this restriction'

    # context manager guarantees the worker pool is cleaned up even if a check raises
    with Pool(default_num_threads) as p:
        results = p.starmap(verify_contains_only_expected_labels,
                            zip(label_files, [expected_labels] * len(label_files)))

    fail = False
    print("Expected label values are", expected_labels)
    for i, r in enumerate(results):
        if not r[0]:
            print("Unexpected labels found in file %s. Found these unexpected values (they should not be there) %s" % (
                label_files[i], r[1]))
            fail = True

    if fail:
        raise AssertionError(
            "Found unexpected labels in the training dataset. Please correct that or adjust your dataset.json accordingly")
    else:
        print("Labels OK")

    # check test set, but only if there actually is a test set
    if len(expected_test_identifiers) > 0:
        print("Verifying test set")
        nii_files_in_imagesTs = subfiles((join(folder, "imagesTs")), suffix=".nii.gz", join=False)

        for c in expected_test_identifiers:
            # check if all files are present
            expected_image_files = [join(folder, "imagesTs", c + "_%04.0d.nii.gz" % i) for i in range(num_modalities)]
            assert all([isfile(i) for i in
                        expected_image_files]), "some image files are missing for case %s. Expected files:\n %s" % (
                c, expected_image_files)

            # verify that all modalities and the label have the same geometry. We use the affine for this
            if num_modalities > 1:
                images_itk = [sitk.ReadImage(i) for i in expected_image_files]
                reference_img = images_itk[0]
                for i, img in enumerate(images_itk[1:]):
                    assert verify_same_geometry(img, reference_img), "The modalities of the image %s do not seem to be " \
                                                                     "registered. Please coregister your modalities." % (
                                                                         expected_image_files[i])

            # now remove checked files from the list nii_files_in_imagesTs
            for i in expected_image_files:
                nii_files_in_imagesTs.remove(os.path.basename(i))

        # BUGFIX: this message previously interpolated nii_files_in_imagesTr (the
        # training list, already asserted empty above) and called the files
        # "training cases"; report the actual leftover TEST files instead.
        assert len(
            nii_files_in_imagesTs) == 0, "there are test cases in imagesTs that are not listed in dataset.json: %s" % nii_files_in_imagesTs

    all_same, unique_orientations = verify_all_same_orientation(join(folder, "imagesTr"))
    if not all_same:
        print(
            "WARNING: Not all images in the dataset have the same axis ordering. We very strongly recommend you correct that by reorienting the data. fslreorient2std should do the trick")
    # save unique orientations to dataset.json
    if not geometries_OK:
        raise Warning("GEOMETRY MISMATCH FOUND! CHECK THE TEXT OUTPUT! This does not cause an error at this point but you should definitely check whether your geometries are alright!")
    else:
        print("Dataset OK")
    if has_nan:
        raise RuntimeError("Some images have nan values in them. This will break the training. See text output above to see which ones")
def reorient_to_RAS(img_fname: str, output_fname: str = None):
    """
    Reorient a nifti image to nibabel's closest canonical (RAS) orientation.

    :param img_fname: path of the image to reorient
    :param output_fname: where to save the result; overwrites img_fname when None
    """
    destination = img_fname if output_fname is None else output_fname
    canonical_img = nib.as_closest_canonical(nib.load(img_fname))
    nib.save(canonical_img, destination)
if __name__ == "__main__":
    # NOTE(review): debugging scratch code with hard-coded local paths; only runs
    # when this module is executed directly, never on import.
    # investigate geometry issues
    import SimpleITK as sitk

    # load image
    gt_itk = sitk.ReadImage(
        "/media/fabian/Results/nnFormer/3d_fullres/Task064_KiTS_labelsFixed/nnFormerTrainerV2__nnFormerPlansv2.1/gt_niftis/case_00085.nii.gz")

    # get numpy array
    pred_npy = sitk.GetArrayFromImage(gt_itk)

    # create new image from numpy array
    prek_itk_new = sitk.GetImageFromArray(pred_npy)

    # copy geometry
    prek_itk_new.CopyInformation(gt_itk)
    # prek_itk_new = copy_geometry(prek_itk_new, gt_itk)

    # save
    sitk.WriteImage(prek_itk_new, "test.mnc")

    # load images in nib
    gt = nib.load(
        "/media/fabian/Results/nnFormer/3d_fullres/Task064_KiTS_labelsFixed/nnFormerTrainerV2__nnFormerPlansv2.1/gt_niftis/case_00085.nii.gz")
    pred_nib = nib.load("test.mnc")

    new_img_sitk = sitk.ReadImage("test.mnc")

    # arrays kept around for interactive inspection; nothing is asserted here
    np1 = sitk.GetArrayFromImage(gt_itk)
    np2 = sitk.GetArrayFromImage(prek_itk_new)
================================================
FILE: unetr_pp/run/__init__.py
================================================
from __future__ import absolute_import
# NOTE(review): star-importing from within the package's own __init__ is unusual
# and effectively a no-op unless submodules define __all__ -- confirm nothing
# relies on names being re-exported here.
from . import *
================================================
FILE: unetr_pp/run/default_configuration.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unetr_pp
from unetr_pp.paths import network_training_output_dir, preprocessing_output_dir, default_plans_identifier
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.experiment_planning.summarize_plans import summarize_plans
from unetr_pp.training.model_restore import recursive_find_python_class
import numpy as np
import pickle
def get_configuration_from_output_folder(folder):
    """Parse a training output folder into its configuration components.

    The folder is expected to live under ``network_training_output_dir`` and
    follow the layout ``<configuration>/<task>/<trainer>__<plans_identifier>``.

    :param folder: absolute path of a trainer output folder
    :return: tuple (configuration, task, trainer, plans_identifier)
    """
    # Drop the common output-root prefix, plus a single leading slash if left over.
    relative = folder[len(network_training_output_dir):]
    if relative.startswith("/"):
        relative = relative[1:]
    configuration, task, combined_name = relative.split("/")
    trainer, plans_identifier = combined_name.split("__")
    return configuration, task, trainer, plans_identifier
def get_default_configuration(network, task, network_trainer, plans_identifier=default_plans_identifier,
                              search_in=(unetr_pp.__path__[0], "training", "network_training"),
                              base_module='unetr_pp.training.network_training'):
    """Resolve everything needed to instantiate a trainer for (network, task).

    :param network: one of '2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'
    :param task: task folder name, e.g. 'Task002_Synapse'
    :param network_trainer: class name of the trainer, looked up under ``base_module``
    :param plans_identifier: identifier of the experiment plans to load
    :param search_in: path components in which trainer classes are searched
    :param base_module: module path passed to recursive_find_python_class
    :return: (plans_file, output_folder_name, dataset_directory, batch_dice, stage, trainer_class)
    """
    # fixed: the old assertion message listed '3d' although the valid option is '2d'
    assert network in ['2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'], \
        "network can only be one of the following: '2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'"

    dataset_directory = join(preprocessing_output_dir, task)

    if network == '2d':
        plans_file = join(preprocessing_output_dir, task, plans_identifier + "_plans_2D.pkl")
    else:
        plans_file = join(preprocessing_output_dir, task, plans_identifier + "_plans_3D.pkl")

    plans = load_pickle(plans_file)

    # If the plans contain two stages (lowres + fullres), patch the later one.
    stage_to_patch = 1 if len(plans['plans_per_stage']) == 2 else 0

    # NOTE: these task-specific overrides are persisted back into the plans file
    # on disk (side effect kept from the original implementation).
    if task == 'Task001_ACDC':
        plans['plans_per_stage'][stage_to_patch]['batch_size'] = 4
        plans['plans_per_stage'][stage_to_patch]['patch_size'] = np.array([16, 160, 160])
        with open(plans_file, 'wb') as pickle_file:  # context manager guarantees the handle is closed
            pickle.dump(plans, pickle_file)
    elif task == 'Task002_Synapse':
        plans['plans_per_stage'][stage_to_patch]['batch_size'] = 2
        plans['plans_per_stage'][stage_to_patch]['patch_size'] = np.array([64, 128, 128])
        plans['plans_per_stage'][stage_to_patch]['pool_op_kernel_sizes'] = [[2, 2, 2], [2, 2, 2],
                                                                           [2, 2, 2]]  # for deep supervision
        with open(plans_file, 'wb') as pickle_file:
            pickle.dump(plans, pickle_file)

    possible_stages = list(plans['plans_per_stage'].keys())

    if (network == '3d_cascade_fullres' or network == "3d_lowres") and len(possible_stages) == 1:
        raise RuntimeError("3d_lowres/3d_cascade_fullres only applies if there is more than one stage. This task does "
                           "not require the cascade. Run 3d_fullres instead")

    # 2d and 3d_lowres always use the first stage; everything else the last one.
    if network == '2d' or network == "3d_lowres":
        stage = 0
    else:
        stage = possible_stages[-1]

    trainer_class = recursive_find_python_class([join(*search_in)], network_trainer,
                                                current_module=base_module)

    output_folder_name = join(network_training_output_dir, network, task, network_trainer + "__" + plans_identifier)

    print("###############################################")
    print("I am running the following nnFormer: %s" % network)
    print("My trainer class is: ", trainer_class)
    print("For that I will be using the following configuration:")
    summarize_plans(plans_file)
    print("I am using stage %d from these plans" % stage)

    # batch dice is used whenever patches are small relative to the objects
    # (2d, or any multi-stage plan except its lowres configuration).
    if (network == '2d' or len(possible_stages) > 1) and not network == '3d_lowres':
        batch_dice = True
        print("I am using batch dice + CE loss")
    else:
        batch_dice = False
        print("I am using sample dice + CE loss")

    print("\nI am using data from this folder: ", join(dataset_directory, plans['data_identifier']))
    print("###############################################")
    return plans_file, output_folder_name, dataset_directory, batch_dice, stage, trainer_class
================================================
FILE: unetr_pp/run/run_training.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.run.default_configuration import get_default_configuration
from unetr_pp.paths import default_plans_identifier
from unetr_pp.utilities.task_name_id_conversion import convert_id_to_task_name
import numpy as np
import torch
# Global seeding for reproducibility of numpy and torch RNG streams.
# fixed: np.random.seed(seed) was previously called twice (redundant).
seed = 42
np.random.seed(seed)
# NOTE: hash randomization is fixed at interpreter startup, so setting
# PYTHONHASHSEED here only affects child processes, not this interpreter.
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# cudnn autotuning / non-deterministic kernels are kept enabled for speed;
# full determinism would require deterministic=True and benchmark=False.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
def main():
    """CLI entry point: parse arguments, build the trainer for the requested
    network/task/fold and either run training or validation-only."""
    parser = argparse.ArgumentParser()
    parser.add_argument("network")
    parser.add_argument("network_trainer")
    parser.add_argument("task", help="can be task name or task id")
    parser.add_argument("fold", help='0, 1, ..., 5 or \'all\'')
    parser.add_argument("-val", "--validation_only", help="use this if you want to only run the validation",
                        action="store_true")
    parser.add_argument("-c", "--continue_training", help="use this if you want to continue a training",
                        action="store_true")
    parser.add_argument("-p", help="plans identifier. Only change this if you created a custom experiment planner",
                        default=default_plans_identifier, required=False)
    parser.add_argument("--use_compressed_data", default=False, action="store_true",
                        help="If you set use_compressed_data, the training cases will not be decompressed. "
                             "Reading compressed data is much more CPU and RAM intensive and should only be used if "
                             "you know what you are doing", required=False)
    parser.add_argument("--deterministic",
                        help="Makes training deterministic, but reduces training speed substantially. I (Fabian) think "
                             "this is not necessary. Deterministic training will make you overfit to some random seed. "
                             "Don't use that.",
                        required=False, default=False, action="store_true")
    parser.add_argument("--npz", required=False, default=False, action="store_true", help="if set then model will "
                                                                                          "export npz files of "
                                                                                          "predicted segmentations "
                                                                                          "in the validation as well. "
                                                                                          "This is needed to run the "
                                                                                          "ensembling (if needed)")
    parser.add_argument("--find_lr", required=False, default=False, action="store_true",
                        help="not used here, just for fun")
    parser.add_argument("--valbest", required=False, default=False, action="store_true",
                        help="hands off. This is not intended to be used")
    # NOTE(review): default=True combined with action="store_true" means fp32 is
    # always True regardless of whether --fp32 is passed, so run_mixed_precision
    # below is always False (mixed precision effectively disabled) — confirm intent.
    parser.add_argument("--fp32", required=False, default=True, action="store_true",
                        help="disable mixed precision training and run old school fp32")
    parser.add_argument("--val_folder", required=False, default="validation_raw",
                        help="name of the validation folder. No need to use this for most people")
    parser.add_argument("--disable_saving", required=False, action='store_true',
                        help="If set nnU-Net will not save any parameter files (except a temporary checkpoint that "
                             "will be removed at the end of the training). Useful for development when you are "
                             "only interested in the results and want to save some disk space")
    parser.add_argument("--disable_postprocessing_on_folds", required=False, action='store_true',
                        help="Running postprocessing on each fold only makes sense when developing with nnU-Net and "
                             "closely observing the model performance on specific configurations. You do not need it "
                             "when applying nnU-Net because the postprocessing for this will be determined only once "
                             "Usually running postprocessing on each fold is computationally cheap, but some users have "
                             "reported issues with very large images. If your images are large (>600x600x600 voxels) "
                             "you should consider setting this flag.")
    parser.add_argument('--val_disable_overwrite', action='store_false', default=True,
                        help='Validation does not overwrite existing segmentations')
    parser.add_argument('--disable_next_stage_pred', action='store_true', default=False,
                        help='do not predict next stage')
    parser.add_argument('-pretrained_weights', type=str, required=False, default=None,
                        help='path to nnU-Net checkpoint file to be used as pretrained model (use .model '
                             'file, for example model_final_checkpoint.model). Will only be used when actually training. '
                             'Optional. Beta. Use with caution.')

    args = parser.parse_args()

    # Unpack CLI arguments into locals for readability below.
    task = args.task
    fold = args.fold
    network = args.network
    network_trainer = args.network_trainer
    validation_only = args.validation_only
    plans_identifier = args.p
    find_lr = args.find_lr
    disable_postprocessing_on_folds = args.disable_postprocessing_on_folds

    use_compressed_data = args.use_compressed_data
    decompress_data = not use_compressed_data

    deterministic = args.deterministic
    valbest = args.valbest

    fp32 = args.fp32
    run_mixed_precision = not fp32

    val_folder = args.val_folder

    # A numeric task id (e.g. "2") is mapped to its full name (e.g. "Task002_Synapse").
    if not task.startswith("Task"):
        task_id = int(task)
        task = convert_id_to_task_name(task_id)

    # fold stays the string 'all' or becomes an int index.
    if fold == 'all':
        pass
    else:
        fold = int(fold)

    plans_file, output_folder_name, dataset_directory, batch_dice, stage, \
    trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier)

    if trainer_class is None:
        raise RuntimeError("Could not find trainer class in unetr_pp.training.network_training")

    trainer = trainer_class(plans_file, fold, output_folder=output_folder_name, dataset_directory=dataset_directory,
                            batch_dice=batch_dice, stage=stage, unpack_data=decompress_data,
                            deterministic=deterministic, fp16=run_mixed_precision)
    if args.disable_saving:
        trainer.save_final_checkpoint = False  # whether or not to save the final checkpoint
        trainer.save_best_checkpoint = False  # whether or not to save the best checkpoint according to
        # self.best_val_eval_criterion_MA
        trainer.save_intermediate_checkpoints = True  # whether or not to save checkpoint_latest. We need that in case
        # the training chashes
        trainer.save_latest_only = True  # if false it will not store/overwrite _latest but separate files each

    trainer.initialize(not validation_only)

    if find_lr:
        trainer.find_lr()
    else:
        if not validation_only:
            if args.continue_training:
                # -c was set, continue a previous training and ignore pretrained weights
                trainer.load_latest_checkpoint()
            else:
                # new training without pretraine weights, do nothing
                pass

            trainer.run_training()
        else:
            # validation-only: restore either the best or the final checkpoint.
            if valbest:
                trainer.load_best_checkpoint(train=False)
            else:
                trainer.load_final_checkpoint(train=False)

        trainer.network.eval()

        # predict validation
        trainer.validate(save_softmax=args.npz, validation_folder_name=val_folder,
                         run_postprocessing_on_folds=not disable_postprocessing_on_folds,
                         overwrite=args.val_disable_overwrite)


if __name__ == "__main__":
    main()
================================================
FILE: unetr_pp/training/__init__.py
================================================
# Package initializer for unetr_pp.training: re-exports the package's submodules.
from __future__ import absolute_import
from . import *
================================================
FILE: unetr_pp/training/cascade_stuff/__init__.py
================================================
# Package initializer for unetr_pp.training.cascade_stuff: re-exports submodules.
from __future__ import absolute_import
from . import *
================================================
FILE: unetr_pp/training/cascade_stuff/predict_next_stage.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
import argparse
from unetr_pp.preprocessing.preprocessing import resample_data_or_seg
from batchgenerators.utilities.file_and_folder_operations import maybe_mkdir_p
import unetr_pp
from unetr_pp.run.default_configuration import get_default_configuration
from multiprocessing import Pool
from unetr_pp.training.model_restore import recursive_find_python_class
from unetr_pp.training.network_training.Trainer_acdc import Trainer_acdc
def resample_and_save(predicted, target_shape, output_file, force_separate_z=False,
                      interpolation_order=1, interpolation_order_z=0):
    """Resample softmax probabilities to ``target_shape``, argmax them and store
    the resulting segmentation as a compressed .npz under ``output_file``.

    ``predicted`` is either an in-memory array or the path of an .npy file; in
    the latter case the file is loaded and removed from disk afterwards.
    """
    if isinstance(predicted, str):
        assert isfile(predicted), "If isinstance(segmentation_softmax, str) then " \
                                  "isfile(segmentation_softmax) must be True"
        # Remember the path so we can delete the file once it has been loaded.
        path_on_disk = deepcopy(predicted)
        predicted = np.load(predicted)
        os.remove(path_on_disk)

    resampled = resample_data_or_seg(predicted, target_shape, False, order=interpolation_order,
                                     do_separate_z=force_separate_z, cval=0, order_z=interpolation_order_z)
    # Collapse the class axis to a hard segmentation before saving.
    segmentation = resampled.argmax(0)
    np.savez_compressed(output_file, data=segmentation.astype(np.uint8))
def predict_next_stage(trainer, stage_to_be_predicted_folder):
    """Predict the trainer's validation cases and export them, resampled to the
    next cascade stage's geometry, into a sibling 'pred_next_stage' folder.

    :param trainer: an initialized trainer with dataset_val, plans and a loaded network
    :param stage_to_be_predicted_folder: preprocessed-data folder of the next stage
    """
    # assumes pardir() maps the trainer output folder to its parent — provided
    # by the batchgenerators star import above; TODO confirm
    output_folder = join(pardir(trainer.output_folder), "pred_next_stage")
    maybe_mkdir_p(output_folder)

    # Export parameters come from the plans if present, otherwise defaults.
    if 'segmentation_export_params' in trainer.plans.keys():
        force_separate_z = trainer.plans['segmentation_export_params']['force_separate_z']
        interpolation_order = trainer.plans['segmentation_export_params']['interpolation_order']
        interpolation_order_z = trainer.plans['segmentation_export_params']['interpolation_order_z']
    else:
        force_separate_z = None
        interpolation_order = 1
        interpolation_order_z = 0

    # Resampling/saving runs in a small worker pool while the GPU keeps predicting.
    export_pool = Pool(2)
    results = []

    for pat in trainer.dataset_val.keys():
        print(pat)
        data_file = trainer.dataset_val[pat]['data_file']
        # Last channel of the preprocessed array is the segmentation -> drop it.
        data_preprocessed = np.load(data_file)['data'][:-1]

        # [1] selects the softmax output of (segmentation, softmax).
        predicted_probabilities = trainer.predict_preprocessed_data_return_seg_and_softmax(
            data_preprocessed, do_mirroring=trainer.data_aug_params["do_mirror"],
            mirror_axes=trainer.data_aug_params['mirror_axes'], mixed_precision=trainer.fp16)[1]

        data_file_nofolder = data_file.split("/")[-1]
        data_file_nextstage = join(stage_to_be_predicted_folder, data_file_nofolder)
        data_nextstage = np.load(data_file_nextstage)['data']
        # Spatial shape of the next stage (channel axis excluded).
        target_shp = data_nextstage.shape[1:]
        output_file = join(output_folder, data_file_nextstage.split("/")[-1][:-4] + "_segFromPrevStage.npz")

        # Very large softmax arrays are spilled to disk and passed to the worker
        # as a path to avoid pickling gigabytes through the pool.
        if np.prod(predicted_probabilities.shape) > (2e9 / 4 * 0.85):  # *0.85 just to be save
            np.save(output_file[:-4] + ".npy", predicted_probabilities)
            predicted_probabilities = output_file[:-4] + ".npy"

        results.append(export_pool.starmap_async(resample_and_save, [(predicted_probabilities, target_shp, output_file,
                                                                      force_separate_z, interpolation_order,
                                                                      interpolation_order_z)]))

    # Block until all exports finished, then shut the pool down cleanly.
    _ = [i.get() for i in results]
    export_pool.close()
    export_pool.join()
if __name__ == "__main__":
    """
    RUNNING THIS SCRIPT MANUALLY IS USUALLY NOT NECESSARY. USE THE run_training.py FILE!
    This script is intended for predicting all the low resolution predictions of 3d_lowres for the next stage of the
    cascade. It needs to run once for each fold so that the segmentation is only generated for the validation set
    and not on the data the network was trained on. Run it with
    python predict_next_stage TRAINERCLASS TASK FOLD"""
    parser = argparse.ArgumentParser()
    parser.add_argument("network_trainer")
    parser.add_argument("task")
    parser.add_argument("fold", type=int)

    args = parser.parse_args()

    trainerclass = args.network_trainer
    task = args.task
    fold = args.fold

    # fixed: get_default_configuration requires the trainer class name and returns
    # (plans_file, output_folder_name, dataset_directory, batch_dice, stage,
    # trainer_class); the old call passed too few arguments and unpacked the
    # 6-tuple into the wrong names.
    plans_file, output_folder_name, dataset_directory, batch_dice, stage, trainer_class = \
        get_default_configuration("3d_lowres", task, trainerclass)

    if trainer_class is None:
        raise RuntimeError("Could not find trainer class in unetr_pp.training.network_training")
    else:
        assert issubclass(trainer_class,
                          Trainer_acdc), "network_trainer was found but is not derived from nnFormerTrainer"

    # Construct the trainer the same way run_training.py does (the old code passed
    # a nonexistent 'folder_with_preprocessed_data' as a positional argument).
    trainer = trainer_class(plans_file, fold, output_folder=output_folder_name,
                            dataset_directory=dataset_directory, batch_dice=batch_dice, stage=stage)

    trainer.initialize(False)
    trainer.load_dataset()
    trainer.do_split()
    trainer.load_best_checkpoint(train=False)

    # Lowres predictions go to the fullres stage's preprocessed-data folder.
    stage_to_be_predicted_folder = join(dataset_directory, trainer.plans['data_identifier'] + "_stage%d" % 1)
    output_folder = join(pardir(trainer.output_folder), "pred_next_stage")
    maybe_mkdir_p(output_folder)

    predict_next_stage(trainer, stage_to_be_predicted_folder)
================================================
FILE: unetr_pp/training/data_augmentation/__init__.py
================================================
from __future__ import absolute_import
from . import *
================================================
FILE: unetr_pp/training/data_augmentation/custom_transforms.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from batchgenerators.transforms import AbstractTransform
class RemoveKeyTransform(AbstractTransform):
    """Transform that drops a single entry from the data dictionary, if present."""

    def __init__(self, key_to_remove):
        # Name of the dictionary entry to discard on every call.
        self.key_to_remove = key_to_remove

    def __call__(self, **data_dict):
        data_dict.pop(self.key_to_remove, None)
        return data_dict
class MaskTransform(AbstractTransform):
    def __init__(self, dct_for_where_it_was_used, mask_idx_in_seg=1, set_outside_to=0, data_key="data", seg_key="seg"):
        """
        Sets everything outside the mask to ``set_outside_to``.
        CAREFUL! outside is defined as < 0, not == 0 (in the mask channel)!!!

        :param dct_for_where_it_was_used: per-channel mapping of bools; only data
            channels whose entry is truthy are masked
        :param mask_idx_in_seg: channel of ``seg`` that contains the mask
        :param set_outside_to: value written to voxels where mask < 0
        :param data_key: dictionary key of the image data
        :param seg_key: dictionary key of the segmentation
        """
        self.dct_for_where_it_was_used = dct_for_where_it_was_used
        self.seg_key = seg_key
        self.data_key = data_key
        self.set_outside_to = set_outside_to
        self.mask_idx_in_seg = mask_idx_in_seg

    def __call__(self, **data_dict):
        seg = data_dict.get(self.seg_key)
        # The mask channel exists only if mask_idx_in_seg < seg.shape[1].
        # fixed: the old check used '<' and therefore let the
        # seg.shape[1] == mask_idx_in_seg case fall through to an IndexError below.
        if seg is None or seg.shape[1] <= self.mask_idx_in_seg:
            raise Warning("mask not found, seg may be missing or seg[:, mask_idx_in_seg] may not exist")
        data = data_dict.get(self.data_key)
        for b in range(data.shape[0]):
            mask = seg[b, self.mask_idx_in_seg]
            for c in range(data.shape[1]):
                if self.dct_for_where_it_was_used[c]:
                    # Zero (or set_outside_to) everything the mask marks as outside.
                    data[b, c][mask < 0] = self.set_outside_to
        data_dict[self.data_key] = data
        return data_dict
def convert_3d_to_2d_generator(data_dict):
    """Fold the depth axis into the channel axis of 'data' and 'seg' so a
    (b, c, z, y, x) batch can pass through 2D transforms; the original shapes
    are stashed under 'orig_shape_data' / 'orig_shape_seg' for the inverse."""
    original_data_shape = data_dict['data'].shape
    b, c, z, y, x = original_data_shape
    data_dict['data'] = data_dict['data'].reshape((b, c * z, y, x))
    data_dict['orig_shape_data'] = original_data_shape

    original_seg_shape = data_dict['seg'].shape
    b, c, z, y, x = original_seg_shape
    data_dict['seg'] = data_dict['seg'].reshape((b, c * z, y, x))
    data_dict['orig_shape_seg'] = original_seg_shape
    return data_dict
def convert_2d_to_3d_generator(data_dict):
    """Inverse of convert_3d_to_2d_generator: restore the (b, c, z) layout that
    was recorded in 'orig_shape_data' / 'orig_shape_seg', keeping the current
    (possibly augmented) in-plane size."""
    b, c, z = data_dict['orig_shape_data'][:3]
    y, x = data_dict['data'].shape[-2:]
    data_dict['data'] = data_dict['data'].reshape((b, c, z, y, x))

    b, c, z = data_dict['orig_shape_seg'][:3]
    y, x = data_dict['seg'].shape[-2:]
    data_dict['seg'] = data_dict['seg'].reshape((b, c, z, y, x))
    return data_dict
class Convert3DTo2DTransform(AbstractTransform):
    # Stateless wrapper that applies convert_3d_to_2d_generator (folds the depth
    # axis into the channel axis so 2D transforms can run on 3D batches).
    def __init__(self):
        pass

    def __call__(self, **data_dict):
        return convert_3d_to_2d_generator(data_dict)
class Convert2DTo3DTransform(AbstractTransform):
    # Stateless wrapper that applies convert_2d_to_3d_generator (restores the
    # original 3D layout recorded by Convert3DTo2DTransform).
    def __init__(self):
        pass

    def __call__(self, **data_dict):
        return convert_2d_to_3d_generator(data_dict)
class ConvertSegmentationToRegionsTransform(AbstractTransform):
    def __init__(self, regions: dict, seg_key: str = "seg", output_key: str = "seg", seg_channel: int = 0):
        """
        Convert a label map into binary region maps, one output channel per region.

        ``regions`` is a dict mapping a region identifier to the iterable of class
        indices merged into that region, example: regions = {'a': (1, 2), 'b': (2,)}
        yields 2 channels: one covering labels 1 & 2 and the other just label 2.
        (Output channel order follows the dict's key order.)

        :param regions: dict of region id -> iterable of class labels
        :param seg_key: key of the input segmentation in the data dict
        :param output_key: key the region maps are written to (may equal seg_key)
        :param seg_channel: which channel of the segmentation to read labels from
        """
        self.seg_channel = seg_channel
        self.output_key = output_key
        self.seg_key = seg_key
        self.regions = regions

    def __call__(self, **data_dict):
        seg = data_dict.get(self.seg_key)
        num_regions = len(self.regions)
        # If no segmentation is present the dict passes through unchanged.
        if seg is not None:
            seg_shp = seg.shape
            output_shape = list(seg_shp)
            output_shape[1] = num_regions
            region_output = np.zeros(output_shape, dtype=seg.dtype)
            for b in range(seg_shp[0]):
                for r, k in enumerate(self.regions.keys()):
                    # Mark every voxel whose label belongs to region k.
                    for l in self.regions[k]:
                        region_output[b, r][seg[b, self.seg_channel] == l] = 1
            data_dict[self.output_key] = region_output
        return data_dict
================================================
FILE: unetr_pp/training/data_augmentation/data_augmentation_insaneDA.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.dataloading import MultiThreadedAugmenter
from batchgenerators.transforms import DataChannelSelectionTransform, SegChannelSelectionTransform, SpatialTransform, \
GammaTransform, MirrorTransform, Compose
from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
ContrastAugmentationTransform, BrightnessTransform
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
from unetr_pp.training.data_augmentation.custom_transforms import Convert3DTo2DTransform, Convert2DTo3DTransform, \
MaskTransform, ConvertSegmentationToRegionsTransform
from unetr_pp.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params
from unetr_pp.training.data_augmentation.downsampling import DownsampleSegForDSTransform3, DownsampleSegForDSTransform2
from unetr_pp.training.data_augmentation.pyramid_augmentations import MoveSegAsOneHotToData, \
ApplyRandomBinaryOperatorTransform, \
RemoveRandomConnectedComponentFromOneHotEncodingTransform
try:
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
except ImportError as ie:
NonDetMultiThreadedAugmenter = None
def get_insaneDA_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
                              border_val_seg=-1,
                              seeds_train=None, seeds_val=None, order_seg=1, order_data=3, deep_supervision_scales=None,
                              soft_ds=False,
                              classes=None, pin_memory=True, regions=None):
    """Build the aggressive ('insane') train/val augmentation pipelines and wrap
    the dataloaders in MultiThreadedAugmenters.

    :return: (batchgenerator_train, batchgenerator_val)
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        ignore_axes = (0,)
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(SpatialTransform(
        patch_size, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=order_data,
        border_mode_seg="constant", border_cval_seg=border_val_seg,
        order_seg=order_seg, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis"),
        p_independent_scale_per_axis=params.get("p_independent_scale_per_axis")
    ))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
    # channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(GaussianBlurTransform((0.5, 1.5), different_sigma_per_channel=True, p_per_sample=0.2,
                                               p_per_channel=0.5))
    tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3), p_per_sample=0.15))
    tr_transforms.append(ContrastAugmentationTransform(contrast_range=(0.65, 1.5), p_per_sample=0.15))
    tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
                                                        p_per_channel=0.5,
                                                        order_downsample=0, order_upsample=3, p_per_sample=0.25,
                                                        ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"), True, True, retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))  # inverted gamma

    if params.get("do_additive_brightness"):
        tr_transforms.append(BrightnessTransform(params.get("additive_brightness_mu"),
                                                 params.get("additive_brightness_sigma"),
                                                 True, p_per_sample=params.get("additive_brightness_p_per_sample"),
                                                 p_per_channel=params.get("additive_brightness_p_per_channel")))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    # fixed: this previously read "... and not None and ..." which is a typo for
    # the None-guard used everywhere else in this function.
    if params.get("cascade_do_cascade_augmentations") is not None and params.get(
            "cascade_do_cascade_augmentations"):
        if params.get("cascade_random_binary_transform_p") > 0:
            tr_transforms.append(ApplyRandomBinaryOperatorTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                p_per_sample=params.get("cascade_random_binary_transform_p"),
                key="data",
                strel_size=params.get("cascade_random_binary_transform_size")))
        if params.get("cascade_remove_conn_comp_p") > 0:
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                    key="data",
                    p_per_sample=params.get("cascade_remove_conn_comp_p"),
                    # fixed: these two parameters were swapped — the
                    # fill_with_other_class probability got the max-size threshold
                    # and vice versa.
                    fill_with_other_class_p=params.get("cascade_remove_conn_comp_fill_with_other_class_p"),
                    dont_do_if_covers_more_than_X_percent=params.get(
                        "cascade_remove_conn_comp_max_size_percent_threshold")))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
        else:
            tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target',
                                                              output_key='target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"),
                                                  seeds=seeds_train, pin_memory=pin_memory)

    # Validation pipeline: no intensity/spatial augmentation, only bookkeeping.
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
        else:
            val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target',
                                                               output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"),
                                                seeds=seeds_val, pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
================================================
FILE: unetr_pp/training/data_augmentation/data_augmentation_insaneDA2.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.dataloading import MultiThreadedAugmenter
from batchgenerators.transforms import DataChannelSelectionTransform, SegChannelSelectionTransform, \
GammaTransform, MirrorTransform, Compose
from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
ContrastAugmentationTransform, BrightnessTransform
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
from unetr_pp.training.data_augmentation.custom_transforms import Convert3DTo2DTransform, Convert2DTo3DTransform, \
MaskTransform, ConvertSegmentationToRegionsTransform
from unetr_pp.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params
from unetr_pp.training.data_augmentation.downsampling import DownsampleSegForDSTransform3, DownsampleSegForDSTransform2
from unetr_pp.training.data_augmentation.pyramid_augmentations import MoveSegAsOneHotToData, \
ApplyRandomBinaryOperatorTransform, \
RemoveRandomConnectedComponentFromOneHotEncodingTransform
from batchgenerators.transforms.spatial_transforms import SpatialTransform_2
try:
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
except ImportError as ie:
NonDetMultiThreadedAugmenter = None
def get_insaneDA_augmentation2(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
                               border_val_seg=-1,
                               seeds_train=None, seeds_val=None, order_seg=1, order_data=3, deep_supervision_scales=None,
                               soft_ds=False,
                               classes=None, pin_memory=True, regions=None):
    """
    Build the aggressive ("insane") train/val augmentation pipelines and wrap both
    dataloaders in MultiThreadedAugmenter instances.

    :param dataloader_train: batchgenerators dataloader yielding dicts with 'data' and 'seg'
    :param dataloader_val: validation dataloader of the same kind
    :param patch_size: spatial target size for SpatialTransform_2
    :param params: augmentation hyperparameter dict (see default_3D_augmentation_params)
    :param border_val_seg: border fill value for the segmentation during deformation
    :param seeds_train: per-worker seeds for the training augmenter
    :param seeds_val: per-worker seeds for the validation augmenter
    :param order_seg: interpolation order for segmentations
    :param order_data: interpolation order for image data
    :param deep_supervision_scales: if not None, 'target' is downsampled to these scales
    :param soft_ds: use one-hot ("soft") downsampling; requires `classes`
    :param classes: class labels, required when soft_ds is True
    :param pin_memory: forwarded to the augmenters
    :param regions: if not None, the segmentation is converted into region masks
    :return: (batchgenerator_train, batchgenerator_val)
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        ignore_axes = (0,)
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(SpatialTransform_2(
        patch_size, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
        deformation_scale=params.get("eldef_deformation_scale"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=order_data,
        border_mode_seg="constant", border_cval_seg=border_val_seg,
        order_seg=order_seg, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis"),
        p_independent_scale_per_axis=params.get("p_independent_scale_per_axis")
    ))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
    # channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(GaussianBlurTransform((0.5, 1.5), different_sigma_per_channel=True, p_per_sample=0.2,
                                               p_per_channel=0.5))
    tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3), p_per_sample=0.15))
    tr_transforms.append(ContrastAugmentationTransform(contrast_range=(0.65, 1.5), p_per_sample=0.15))
    tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
                                                        p_per_channel=0.5,
                                                        order_downsample=0, order_upsample=3, p_per_sample=0.25,
                                                        ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"), True, True, retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))  # inverted gamma

    if params.get("do_additive_brightness"):
        tr_transforms.append(BrightnessTransform(params.get("additive_brightness_mu"),
                                                 params.get("additive_brightness_sigma"),
                                                 True, p_per_sample=params.get("additive_brightness_p_per_sample"),
                                                 p_per_channel=params.get("additive_brightness_p_per_channel")))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
        # FIX: was `params.get(...) and not None and params.get(...)` — `and not None` is always
        # True, so the expression was a no-op typo. Use `is not None` like get_moreDA_augmentation.
        if params.get("cascade_do_cascade_augmentations") is not None and params.get(
                "cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(ApplyRandomBinaryOperatorTransform(
                    channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                    p_per_sample=params.get("cascade_random_binary_transform_p"),
                    key="data",
                    strel_size=params.get("cascade_random_binary_transform_size")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                # NOTE(review): 'cascade_remove_conn_comp_max_size_percent_threshold' and
                # 'cascade_remove_conn_comp_fill_with_other_class_p' look swapped w.r.t. the
                # keyword arguments they feed (inherited from upstream nnU-Net). Left unchanged
                # to preserve training behavior — confirm against upstream before changing.
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        fill_with_other_class_p=params.get("cascade_remove_conn_comp_max_size_percent_threshold"),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p")))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
        else:
            tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target',
                                                              output_key='target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"),
                                                  seeds=seeds_train, pin_memory=pin_memory)

    # Validation uses only the lightweight label/channel bookkeeping, no augmentation.
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
        else:
            val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target',
                                                               output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"),
                                                seeds=seeds_val, pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
================================================
FILE: unetr_pp/training/data_augmentation/data_augmentation_moreDA.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.dataloading import MultiThreadedAugmenter
from batchgenerators.transforms import DataChannelSelectionTransform, SegChannelSelectionTransform, SpatialTransform, \
GammaTransform, MirrorTransform, Compose
from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
ContrastAugmentationTransform, BrightnessTransform
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
from unetr_pp.training.data_augmentation.custom_transforms import Convert3DTo2DTransform, Convert2DTo3DTransform, \
MaskTransform, ConvertSegmentationToRegionsTransform
from unetr_pp.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params
from unetr_pp.training.data_augmentation.downsampling import DownsampleSegForDSTransform3, DownsampleSegForDSTransform2
from unetr_pp.training.data_augmentation.pyramid_augmentations import MoveSegAsOneHotToData, \
ApplyRandomBinaryOperatorTransform, \
RemoveRandomConnectedComponentFromOneHotEncodingTransform
try:
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
except ImportError as ie:
NonDetMultiThreadedAugmenter = None
def get_moreDA_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
                            border_val_seg=-1,
                            seeds_train=None, seeds_val=None, order_seg=1, order_data=3, deep_supervision_scales=None,
                            soft_ds=False,
                            classes=None, pin_memory=True, regions=None,
                            use_nondetMultiThreadedAugmenter: bool = False):
    """
    Build the "moreDA" train/val augmentation pipelines and wrap both dataloaders
    in (possibly non-deterministic) multi-threaded augmenters.

    :param dataloader_train: batchgenerators dataloader yielding dicts with 'data' and 'seg'
    :param dataloader_val: validation dataloader of the same kind
    :param patch_size: spatial target size for SpatialTransform
    :param params: augmentation hyperparameter dict (see default_3D_augmentation_params)
    :param border_val_seg: border fill value for the segmentation during deformation
    :param seeds_train: per-worker seeds for the training augmenter
    :param seeds_val: per-worker seeds for the validation augmenter
    :param order_seg: interpolation order for segmentations
    :param order_data: interpolation order for image data
    :param deep_supervision_scales: if not None, 'target' is downsampled to these scales
    :param soft_ds: use one-hot ("soft") downsampling; requires `classes`
    :param classes: class labels, required when soft_ds is True
    :param pin_memory: forwarded to the augmenters
    :param regions: if not None, the segmentation is converted into region masks
    :param use_nondetMultiThreadedAugmenter: use NonDetMultiThreadedAugmenter (faster,
        non-reproducible); raises RuntimeError if that class is unavailable
    :return: (batchgenerator_train, batchgenerator_val)
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        ignore_axes = (0,)
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    # spatial augmentation: elastic deformation, rotation, scaling, (optional) random crop
    tr_transforms.append(SpatialTransform(
        patch_size, patch_center_dist_from_border=None,
        do_elastic_deform=params.get("do_elastic"), alpha=params.get("elastic_deform_alpha"),
        sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"), p_rot_per_axis=params.get("rotation_p_per_axis"),
        do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=order_data,
        border_mode_seg="constant", border_cval_seg=border_val_seg,
        order_seg=order_seg, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
    ))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
    # channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
    tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2,
                                               p_per_channel=0.5))
    tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15))

    if params.get("do_additive_brightness"):
        tr_transforms.append(BrightnessTransform(params.get("additive_brightness_mu"),
                                                 params.get("additive_brightness_sigma"),
                                                 True, p_per_sample=params.get("additive_brightness_p_per_sample"),
                                                 p_per_channel=params.get("additive_brightness_p_per_channel")))

    tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
    tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
                                                        p_per_channel=0.5,
                                                        order_downsample=0, order_upsample=3, p_per_sample=0.25,
                                                        ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"), True, True, retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.1))  # inverted gamma

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    # cascade training: expose the previous stage's segmentation as extra data channels
    # and optionally corrupt it so the network does not blindly trust it
    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
        if params.get("cascade_do_cascade_augmentations") is not None and params.get(
                "cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(ApplyRandomBinaryOperatorTransform(
                    channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                    p_per_sample=params.get("cascade_random_binary_transform_p"),
                    key="data",
                    strel_size=params.get("cascade_random_binary_transform_size"),
                    p_per_label=params.get("cascade_random_binary_transform_p_per_label")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                # NOTE(review): the 'max_size_percent_threshold' and 'fill_with_other_class_p'
                # params appear to feed swapped keyword arguments (same as upstream nnU-Net).
                # Left as-is to preserve training behavior — confirm against upstream.
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        fill_with_other_class_p=params.get("cascade_remove_conn_comp_max_size_percent_threshold"),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p")))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
        else:
            tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target',
                                                              output_key='target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    if use_nondetMultiThreadedAugmenter:
        if NonDetMultiThreadedAugmenter is None:
            raise RuntimeError('NonDetMultiThreadedAugmenter is not yet available')
        batchgenerator_train = NonDetMultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                            params.get("num_cached_per_thread"), seeds=seeds_train,
                                                            pin_memory=pin_memory)
    else:
        batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                      params.get("num_cached_per_thread"),
                                                      seeds=seeds_train, pin_memory=pin_memory)
    # batchgenerator_train = SingleThreadedAugmenter(dataloader_train, tr_transforms)
    # import IPython;IPython.embed()

    # Validation uses only the lightweight label/channel bookkeeping, no augmentation.
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
        else:
            val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target',
                                                               output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    if use_nondetMultiThreadedAugmenter:
        if NonDetMultiThreadedAugmenter is None:
            raise RuntimeError('NonDetMultiThreadedAugmenter is not yet available')
        batchgenerator_val = NonDetMultiThreadedAugmenter(dataloader_val, val_transforms,
                                                          max(params.get('num_threads') // 2, 1),
                                                          params.get("num_cached_per_thread"),
                                                          seeds=seeds_val, pin_memory=pin_memory)
    else:
        batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms,
                                                    max(params.get('num_threads') // 2, 1),
                                                    params.get("num_cached_per_thread"),
                                                    seeds=seeds_val, pin_memory=pin_memory)
    # batchgenerator_val = SingleThreadedAugmenter(dataloader_val, val_transforms)
    return batchgenerator_train, batchgenerator_val
================================================
FILE: unetr_pp/training/data_augmentation/data_augmentation_noDA.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.dataloading import MultiThreadedAugmenter
from batchgenerators.transforms import DataChannelSelectionTransform, SegChannelSelectionTransform, Compose
from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
from unetr_pp.training.data_augmentation.custom_transforms import ConvertSegmentationToRegionsTransform
from unetr_pp.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params
from unetr_pp.training.data_augmentation.downsampling import DownsampleSegForDSTransform3, DownsampleSegForDSTransform2
try:
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
except ImportError as ie:
NonDetMultiThreadedAugmenter = None
def get_no_augmentation(dataloader_train, dataloader_val, params=default_3D_augmentation_params,
                        deep_supervision_scales=None, soft_ds=False,
                        classes=None, pin_memory=True, regions=None):
    """
    use this instead of get_default_augmentation (drop in replacement) to turn off all data augmentation
    """

    def _append_channel_selection(pipeline):
        # optional restriction to a subset of data/seg channels
        if params.get("selected_data_channels") is not None:
            pipeline.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
        if params.get("selected_seg_channels") is not None:
            pipeline.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    def _append_tail(pipeline):
        # region conversion, deep-supervision downsampling and tensor conversion,
        # identical for the training and validation pipelines
        if regions is not None:
            pipeline.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
        if deep_supervision_scales is not None:
            if soft_ds:
                assert classes is not None
                pipeline.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
            else:
                pipeline.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target',
                                                             output_key='target'))
        pipeline.append(NumpyToTensor(['data', 'target'], 'float'))

    n_train_workers = params.get('num_threads')
    n_val_workers = max(n_train_workers // 2, 1)

    # training pipeline: bookkeeping only, no actual augmentation
    train_pipeline = []
    _append_channel_selection(train_pipeline)
    train_pipeline.append(RemoveLabelTransform(-1, 0))
    train_pipeline.append(RenameTransform('seg', 'target', True))
    _append_tail(train_pipeline)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, Compose(train_pipeline), n_train_workers,
                                                  params.get("num_cached_per_thread"),
                                                  seeds=range(n_train_workers), pin_memory=pin_memory)
    batchgenerator_train.restart()

    # validation pipeline (note: label removal happens before channel selection here)
    val_pipeline = [RemoveLabelTransform(-1, 0)]
    _append_channel_selection(val_pipeline)
    val_pipeline.append(RenameTransform('seg', 'target', True))
    _append_tail(val_pipeline)

    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, Compose(val_pipeline), n_val_workers,
                                                params.get("num_cached_per_thread"),
                                                seeds=range(n_val_workers),
                                                pin_memory=pin_memory)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val
================================================
FILE: unetr_pp/training/data_augmentation/default_data_augmentation.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
import numpy as np
from batchgenerators.dataloading import MultiThreadedAugmenter
from batchgenerators.transforms import DataChannelSelectionTransform, SegChannelSelectionTransform, SpatialTransform, \
GammaTransform, MirrorTransform, Compose
from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
from unetr_pp.training.data_augmentation.custom_transforms import Convert3DTo2DTransform, Convert2DTo3DTransform, \
MaskTransform, ConvertSegmentationToRegionsTransform
from unetr_pp.training.data_augmentation.pyramid_augmentations import MoveSegAsOneHotToData, \
ApplyRandomBinaryOperatorTransform, \
RemoveRandomConnectedComponentFromOneHotEncodingTransform
try:
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
except ImportError as ie:
NonDetMultiThreadedAugmenter = None
# Default augmentation hyperparameters for 3D training. These dicts are consumed via
# params.get(...) by the get_*_augmentation factories in this package.
default_3D_augmentation_params = {
    # channel selection (None = keep all channels)
    "selected_data_channels": None,
    "selected_seg_channels": None,
    # elastic deformation
    "do_elastic": True,
    "elastic_deform_alpha": (0., 900.),
    "elastic_deform_sigma": (9., 13.),
    "p_eldef": 0.2,
    # scaling
    "do_scaling": True,
    "scale_range": (0.85, 1.25),
    "independent_scale_factor_for_each_axis": False,
    "p_independent_scale_per_axis": 1,
    "p_scale": 0.2,
    # rotation (angles in radians; +/-15 degrees per axis)
    "do_rotation": True,
    "rotation_x": (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
    "rotation_y": (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
    "rotation_z": (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
    "rotation_p_per_axis": 1,
    "p_rot": 0.2,
    # cropping
    "random_crop": False,
    "random_crop_dist_to_border": None,
    # gamma augmentation
    "do_gamma": True,
    "gamma_retain_stats": True,
    "gamma_range": (0.7, 1.5),
    "p_gamma": 0.3,
    # mirroring
    "do_mirror": True,
    "mirror_axes": (0, 1, 2),
    # treat 3D data as stacked 2D slices during spatial augmentation
    "dummy_2D": False,
    "mask_was_used_for_normalization": None,
    "border_mode_data": "constant",
    "all_segmentation_labels": None,  # used for cascade
    "move_last_seg_chanel_to_data": False,  # used for cascade
    "cascade_do_cascade_augmentations": False,  # used for cascade
    "cascade_random_binary_transform_p": 0.4,
    "cascade_random_binary_transform_p_per_label": 1,
    "cascade_random_binary_transform_size": (1, 8),
    "cascade_remove_conn_comp_p": 0.2,
    "cascade_remove_conn_comp_max_size_percent_threshold": 0.15,
    "cascade_remove_conn_comp_fill_with_other_class_p": 0.0,
    # additive brightness
    "do_additive_brightness": False,
    "additive_brightness_p_per_sample": 0.15,
    "additive_brightness_p_per_channel": 0.5,
    "additive_brightness_mu": 0.0,
    "additive_brightness_sigma": 0.1,
    # multithreading for the augmenters (overridable via environment variable)
    "num_threads": 12 if 'nnFormer_n_proc_DA' not in os.environ else int(os.environ['nnFormer_n_proc_DA']),
    "num_cached_per_thread": 1,
}

# 2D defaults are derived from the 3D defaults with weaker elastic deformation and
# in-plane-only rotation.
default_2D_augmentation_params = deepcopy(default_3D_augmentation_params)

default_2D_augmentation_params["elastic_deform_alpha"] = (0., 200.)
default_2D_augmentation_params["elastic_deform_sigma"] = (9., 13.)
default_2D_augmentation_params["rotation_x"] = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
default_2D_augmentation_params["rotation_y"] = (-0. / 360 * 2. * np.pi, 0. / 360 * 2. * np.pi)
default_2D_augmentation_params["rotation_z"] = (-0. / 360 * 2. * np.pi, 0. / 360 * 2. * np.pi)

# sometimes you have 3d data and a 3d net but cannot augment them properly in 3d due to anisotropy (which is currently
# not supported in batchgenerators). In that case you can 'cheat' and transfer your 3d data into 2d data and
# transform them back after augmentation
default_2D_augmentation_params["dummy_2D"] = False
default_2D_augmentation_params["mirror_axes"] = (0, 1)  # this can be (0, 1, 2) if dummy_2D=True
def get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range):
    """
    Compute the (larger) patch size that must be sampled so that a patch of
    `final_patch_size` can be rotated by up to the given angles (capped at 90 degrees)
    and scaled down by min(scale_range) without reading outside the sampled region.

    Angles may be scalars or (min, max) ranges; for ranges the worst-case magnitude
    is used. Returns an integer numpy array of the padded patch size.
    """
    from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d

    def _worst_case_angle(angle):
        # a tuple/list denotes a sampling range -> take the largest magnitude,
        # then cap at 90 degrees (beyond that the bounding box stops growing)
        if isinstance(angle, (tuple, list)):
            angle = max(np.abs(angle))
        return min(90 / 360 * 2. * np.pi, angle)

    rot_x = _worst_case_angle(rot_x)
    rot_y = _worst_case_angle(rot_y)
    rot_z = _worst_case_angle(rot_z)

    coords = np.array(final_patch_size)
    final_shape = np.copy(coords)
    if len(coords) == 3:
        # grow the bounding box for each single-axis worst-case rotation
        for rotated in (rotate_coords_3d(coords, rot_x, 0, 0),
                        rotate_coords_3d(coords, 0, rot_y, 0),
                        rotate_coords_3d(coords, 0, 0, rot_z)):
            final_shape = np.max(np.vstack((np.abs(rotated), final_shape)), 0)
    elif len(coords) == 2:
        final_shape = np.max(np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0)

    # account for the maximum zoom-out (smallest scale factor)
    final_shape /= min(scale_range)
    return final_shape.astype(int)
def get_default_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
                             border_val_seg=-1, pin_memory=True,
                             seeds_train=None, seeds_val=None, regions=None):
    """
    Build the standard train/val augmentation pipelines and wrap both dataloaders
    in MultiThreadedAugmenter instances.

    :param dataloader_train: batchgenerators dataloader yielding dicts with 'data' and 'seg'
    :param dataloader_val: validation dataloader of the same kind
    :param patch_size: spatial target size for SpatialTransform
    :param params: augmentation hyperparameter dict (see default_3D_augmentation_params)
    :param border_val_seg: border fill value for the segmentation during deformation
    :param pin_memory: forwarded to the augmenters
    :param seeds_train: per-worker seeds for the training augmenter
    :param seeds_val: per-worker seeds for the validation augmenter
    :param regions: if not None, the segmentation is converted into region masks
    :return: (batchgenerator_train, batchgenerator_val)
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(SpatialTransform(
        patch_size, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=3, border_mode_seg="constant",
        border_cval_seg=border_val_seg,
        order_seg=1, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
    ))
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
        # FIX: was `params.get(...) and not None and params.get(...)` — `and not None` is always
        # True, so the expression was a no-op typo. Use `is not None` like get_moreDA_augmentation.
        if params.get("cascade_do_cascade_augmentations") is not None and params.get(
                "cascade_do_cascade_augmentations"):
            tr_transforms.append(ApplyRandomBinaryOperatorTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                p_per_sample=params.get("cascade_random_binary_transform_p"),
                key="data",
                strel_size=params.get("cascade_random_binary_transform_size")))
            # NOTE(review): 'cascade_remove_conn_comp_max_size_percent_threshold' and
            # 'cascade_remove_conn_comp_fill_with_other_class_p' look swapped w.r.t. the keyword
            # arguments they feed (inherited from upstream nnU-Net). Left unchanged to preserve
            # training behavior — confirm against upstream before changing.
            tr_transforms.append(RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                key="data",
                p_per_sample=params.get("cascade_remove_conn_comp_p"),
                fill_with_other_class_p=params.get("cascade_remove_conn_comp_max_size_percent_threshold"),
                dont_do_if_covers_more_than_X_percent=params.get("cascade_remove_conn_comp_fill_with_other_class_p")))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"), seeds=seeds_train,
                                                  pin_memory=pin_memory)

    # Validation uses only the lightweight label/channel bookkeeping, no augmentation.
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"), seeds=seeds_val,
                                                pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
if __name__ == "__main__":
    # Smoke test: load a preprocessed dataset, compute the padded patch size required
    # for rotation/scaling augmentation, and build the default train/val augmenters.
    from unetr_pp.training.dataloading.dataset_loading import DataLoader3D, load_dataset
    from unetr_pp.paths import preprocessing_output_dir
    import os
    import pickle

    t = "Task002_Heart"
    p = os.path.join(preprocessing_output_dir, t)
    dataset = load_dataset(p, 0)
    with open(os.path.join(p, "plans.pkl"), 'rb') as f:
        plans = pickle.load(f)

    # pad the target patch size so worst-case rotation/scaling stays inside the sample
    basic_patch_size = get_patch_size(np.array(plans['stage_properties'][0].patch_size),
                                      default_3D_augmentation_params['rotation_x'],
                                      default_3D_augmentation_params['rotation_y'],
                                      default_3D_augmentation_params['rotation_z'],
                                      default_3D_augmentation_params['scale_range'])

    dl = DataLoader3D(dataset, basic_patch_size, np.array(plans['stage_properties'][0].patch_size).astype(int), 1)
    tr, val = get_default_augmentation(dl, dl, np.array(plans['stage_properties'][0].patch_size).astype(int))
================================================
FILE: unetr_pp/training/data_augmentation/downsampling.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from batchgenerators.augmentations.utils import convert_seg_image_to_one_hot_encoding_batched, resize_segmentation
from batchgenerators.transforms import AbstractTransform
from torch.nn.functional import avg_pool2d, avg_pool3d
import numpy as np
class DownsampleSegForDSTransform3(AbstractTransform):
    """
    Deep-supervision target generator (soft variant).

    For every scale in ``ds_scales`` one decoder target is produced: the full-resolution
    entry stays a plain label map (no one hot), while downsampled entries are smooth
    (average-pooled) one hot encodings, i.e. not 0/1. Returns torch tensors, not numpy
    arrays. Only seg channel 0 is used. Always pass ``classes`` explicitly, otherwise
    weird stuff may happen.
    """

    def __init__(self, ds_scales=(1, 0.5, 0.25), input_key="seg", output_key="seg", classes=None):
        self.ds_scales = ds_scales
        self.input_key = input_key
        self.output_key = output_key
        self.classes = classes

    def __call__(self, **data_dict):
        full_res_seg = data_dict[self.input_key][:, 0]  # always seg channel 0
        data_dict[self.output_key] = downsample_seg_for_ds_transform3(full_res_seg, self.ds_scales, self.classes)
        return data_dict
def downsample_seg_for_ds_transform3(seg, ds_scales=((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)), classes=None):
    """
    Build one deep-supervision target per entry of ds_scales.

    :param seg: numpy label map of shape (b, x, y) or (b, x, y, z)
    :param ds_scales: one tuple of per-axis downsampling factors per decoder resolution;
        a factor of 1 on every axis means "full resolution"
    :param classes: the label values to one hot encode; should always be given explicitly
    :return: list of torch tensors, one per scale. The full-resolution entry is the label
        map itself; downsampled entries are smooth (average-pooled) one hot encodings.
    :raises RuntimeError: if a scale tuple is neither 2D nor 3D
    """
    output = []
    # computed lazily: the one hot encoding is only needed when an actual downsampling
    # occurs, and it is (num_classes x) larger than the label map
    one_hot = None
    for s in ds_scales:
        if all(i == 1 for i in s):
            # full resolution: return the label map itself (no one hot)
            output.append(torch.from_numpy(seg))
        else:
            if one_hot is None:
                one_hot = torch.from_numpy(convert_seg_image_to_one_hot_encoding_batched(seg, classes))  # b, c, ...
            # a factor of 0.5 corresponds to average pooling with kernel/stride 2
            kernel_size = tuple(int(1 / i) for i in s)
            stride = kernel_size
            pad = tuple((i - 1) // 2 for i in kernel_size)
            if len(s) == 2:
                pool_op = avg_pool2d
            elif len(s) == 3:
                pool_op = avg_pool3d
            else:
                raise RuntimeError("only 2D and 3D downsampling factors are supported, got %s" % str(s))
            pooled = pool_op(one_hot, kernel_size, stride, pad, count_include_pad=False, ceil_mode=False)
            output.append(pooled)
    return output
class DownsampleSegForDSTransform2(AbstractTransform):
    """
    Deep-supervision target generator (hard variant).

    data_dict[output_key] becomes a list with one resized segmentation per entry of
    ``ds_scales`` (integer labels, resized via ``downsample_seg_for_ds_transform2``).
    """

    def __init__(self, ds_scales=(1, 0.5, 0.25), order=0, cval=0, input_key="seg", output_key="seg", axes=None):
        self.ds_scales = ds_scales
        self.order = order
        self.cval = cval
        self.input_key = input_key
        self.output_key = output_key
        self.axes = axes

    def __call__(self, **data_dict):
        seg = data_dict[self.input_key]
        data_dict[self.output_key] = downsample_seg_for_ds_transform2(seg, self.ds_scales, self.order,
                                                                      self.cval, self.axes)
        return data_dict
def downsample_seg_for_ds_transform2(seg, ds_scales=((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)), order=0, cval=0, axes=None):
    """
    Resize an integer segmentation to every scale in ``ds_scales``.

    :param seg: array of shape (b, c, x, y[, z])
    :param ds_scales: per-resolution tuples of per-axis scale factors
    :param order: interpolation order forwarded to resize_segmentation
    :param cval: fill value forwarded to resize_segmentation
    :param axes: which axes the scale factors apply to; defaults to all spatial axes
    :return: list with one array per scale; identity scales return ``seg`` itself
    """
    if axes is None:
        axes = list(range(2, len(seg.shape)))
    downsampled = []
    for scale in ds_scales:
        if all(factor == 1 for factor in scale):
            # identity scale: hand back the input unchanged (no copy)
            downsampled.append(seg)
            continue
        target_shape = np.array(seg.shape).astype(float)
        for i, a in enumerate(axes):
            target_shape[a] *= scale[i]
        target_shape = np.round(target_shape).astype(int)
        resized = np.zeros(target_shape, dtype=seg.dtype)
        # resize each (batch, channel) slice independently to preserve label values
        for b in range(seg.shape[0]):
            for c in range(seg.shape[1]):
                resized[b, c] = resize_segmentation(seg[b, c], target_shape[2:], order, cval)
        downsampled.append(resized)
    return downsampled
================================================
FILE: unetr_pp/training/data_augmentation/pyramid_augmentations.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from skimage.morphology import label, ball
from skimage.morphology.binary import binary_erosion, binary_dilation, binary_closing, binary_opening
import numpy as np
from batchgenerators.transforms import AbstractTransform
class RemoveRandomConnectedComponentFromOneHotEncodingTransform(AbstractTransform):
    """
    Randomly removes a connected component from one hot encoded channels.

    Per sample (with probability p_per_sample) and per channel (with probability
    p_per_label), one connected component of the channel is chosen at random and set
    to 0; optionally (fill_with_other_class_p) the removed region is assigned to a
    randomly chosen other channel instead of becoming background. Components covering
    more than dont_do_if_covers_more_than_X_percent of the volume are never removed.
    Presumably used for cascade training to simulate missing structures in the
    previous-stage segmentation — TODO confirm against callers.
    """
    def __init__(self, channel_idx, key="data", p_per_sample=0.2, fill_with_other_class_p=0.25,
                 dont_do_if_covers_more_than_X_percent=0.25, p_per_label=1):
        """
        :param channel_idx: can be list or int; the one hot channels to operate on
        :param key: data_dict key holding the one hot array (b, c, x, y[, z])
        :param p_per_sample: probability of applying the transform to a sample
        :param fill_with_other_class_p: probability of assigning the removed region to another channel
        :param dont_do_if_covers_more_than_X_percent: 0.25 means 25%! components larger than this are kept
        :param p_per_label: probability of applying the removal to each channel
        """
        self.p_per_label = p_per_label
        self.dont_do_if_covers_more_than_X_percent = dont_do_if_covers_more_than_X_percent
        self.fill_with_other_class_p = fill_with_other_class_p
        self.p_per_sample = p_per_sample
        self.key = key
        if not isinstance(channel_idx, (list, tuple)):
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx

    def __call__(self, **data_dict):
        data = data_dict.get(self.key)
        for b in range(data.shape[0]):
            if np.random.uniform() < self.p_per_sample:
                for c in self.channel_idx:
                    if np.random.uniform() < self.p_per_label:
                        # work on a copy so the labeling is not affected by in-place edits below
                        workon = np.copy(data[b, c])
                        num_voxels = np.prod(workon.shape, dtype=np.uint64)
                        lab, num_comp = label(workon, return_num=True)
                        if num_comp > 0:
                            component_ids = []
                            component_sizes = []
                            for i in range(1, num_comp + 1):
                                component_ids.append(i)
                                component_sizes.append(np.sum(lab == i))
                            # keep only components small enough to be removed
                            component_ids = [i for i, j in zip(component_ids, component_sizes) if j < num_voxels*self.dont_do_if_covers_more_than_X_percent]
                            if len(component_ids) > 0:
                                random_component = np.random.choice(component_ids)
                                data[b, c][lab == random_component] = 0
                                if np.random.uniform() < self.fill_with_other_class_p:
                                    # hand the removed region over to a random other channel
                                    # to keep the encoding one hot
                                    other_ch = [i for i in self.channel_idx if i != c]
                                    if len(other_ch) > 0:
                                        other_class = np.random.choice(other_ch)
                                        data[b, other_class][lab == random_component] = 1
        data_dict[self.key] = data
        return data_dict
class MoveSegAsOneHotToData(AbstractTransform):
    """
    Moves one segmentation channel from ``key_origin`` into ``key_target`` as a one hot
    encoding (one channel per label in ``all_seg_labels``, appended along axis 1).
    If ``remove_from_origin`` the source channel is dropped from ``key_origin``.
    """

    def __init__(self, channel_id, all_seg_labels, key_origin="seg", key_target="data", remove_from_origin=True):
        self.channel_id = channel_id
        self.all_seg_labels = all_seg_labels
        self.key_origin = key_origin
        self.key_target = key_target
        self.remove_from_origin = remove_from_origin

    def __call__(self, **data_dict):
        origin = data_dict.get(self.key_origin)
        target = data_dict.get(self.key_target)
        # slice keeps the channel axis: shape (b, 1, x, y[, z])
        seg = origin[:, self.channel_id:self.channel_id + 1]
        onehot_shape = (seg.shape[0], len(self.all_seg_labels), *seg.shape[2:])
        seg_onehot = np.zeros(onehot_shape, dtype=seg.dtype)
        for channel, lbl in enumerate(self.all_seg_labels):
            seg_onehot[:, channel][seg[:, 0] == lbl] = 1
        data_dict[self.key_target] = np.concatenate((target, seg_onehot), 1)
        if self.remove_from_origin:
            keep = [i for i in range(origin.shape[1]) if i != self.channel_id]
            data_dict[self.key_origin] = origin[:, keep]
        return data_dict
class ApplyRandomBinaryOperatorTransform(AbstractTransform):
    """
    Applies a randomly chosen binary morphological operation (dilation, erosion,
    closing or opening by default) with a random ball structure element to randomly
    selected one hot channels, while keeping the encoding one hot (voxels added to one
    channel are cleared in all others).
    """
    def __init__(self, channel_idx, p_per_sample=0.3, any_of_these=(binary_dilation, binary_erosion, binary_closing,
                                                                    binary_opening),
                 key="data", strel_size=(1, 10), p_per_label=1):
        """
        :param channel_idx: int or list of ints; the one hot channels to operate on (no tuples)
        :param p_per_sample: probability of applying the transform to a sample
        :param any_of_these: candidate morphological operations, one is drawn per channel
        :param key: data_dict key holding the one hot array (b, c, x, y, z)
        :param strel_size: (min, max) radius range for the ball structure element
        :param p_per_label: probability of applying an operation to each channel
        """
        self.p_per_label = p_per_label
        self.strel_size = strel_size
        self.key = key
        self.any_of_these = any_of_these
        self.p_per_sample = p_per_sample
        assert not isinstance(channel_idx, tuple), "bäh"
        if not isinstance(channel_idx, list):
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx

    def __call__(self, **data_dict):
        data = data_dict.get(self.key)
        for b in range(data.shape[0]):
            if np.random.uniform() < self.p_per_sample:
                # visit channels in random order so no channel is systematically favored
                ch = deepcopy(self.channel_idx)
                np.random.shuffle(ch)
                for c in ch:
                    if np.random.uniform() < self.p_per_label:
                        operation = np.random.choice(self.any_of_these)
                        selem = ball(np.random.uniform(*self.strel_size))
                        workon = np.copy(data[b, c]).astype(int)
                        res = operation(workon, selem).astype(workon.dtype)
                        data[b, c] = res
                        # if the operation added voxels to this channel, remove them from ALL
                        # other channels to keep the one hot encoding property
                        other_ch = [i for i in ch if i != c]
                        if len(other_ch) > 0:
                            was_added_mask = (res - workon) > 0
                            for oc in other_ch:
                                data[b, oc][was_added_mask] = 0
                        # if the operation removed voxels, they are simply left as background
        data_dict[self.key] = data
        return data_dict
class ApplyRandomBinaryOperatorTransform2(AbstractTransform):
    """
    Variant of ApplyRandomBinaryOperatorTransform restricted (by default) to expanding
    operations (dilation, closing) and with an independent p_per_label default.
    Original author note (2019_11_22): "I have no idea what the purpose of this was...";
    behavior is otherwise identical to the transform above.
    """
    def __init__(self, channel_idx, p_per_sample=0.3, p_per_label=0.3, any_of_these=(binary_dilation, binary_closing),
                 key="data", strel_size=(1, 10)):
        """
        :param channel_idx: int or list of ints; the one hot channels to operate on (no tuples)
        :param p_per_sample: probability of applying the transform to a sample
        :param p_per_label: probability of applying an operation to each channel
        :param any_of_these: candidate operations; should be expanding ones (expansions replace other labels)
        :param key: data_dict key holding the one hot array (b, c, x, y, z)
        :param strel_size: (min, max) radius range for the ball structure element
        """
        self.strel_size = strel_size
        self.key = key
        self.any_of_these = any_of_these
        self.p_per_sample = p_per_sample
        self.p_per_label = p_per_label
        assert not isinstance(channel_idx, tuple), "bäh"
        if not isinstance(channel_idx, list):
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx

    def __call__(self, **data_dict):
        data = data_dict.get(self.key)
        for b in range(data.shape[0]):
            if np.random.uniform() < self.p_per_sample:
                # visit channels in random order so no channel is systematically favored
                ch = deepcopy(self.channel_idx)
                np.random.shuffle(ch)
                for c in ch:
                    if np.random.uniform() < self.p_per_label:
                        operation = np.random.choice(self.any_of_these)
                        selem = ball(np.random.uniform(*self.strel_size))
                        workon = np.copy(data[b, c]).astype(int)
                        res = operation(workon, selem).astype(workon.dtype)
                        data[b, c] = res
                        # if the operation added voxels to this channel, remove them from ALL
                        # other channels to keep the one hot encoding property
                        other_ch = [i for i in ch if i != c]
                        if len(other_ch) > 0:
                            was_added_mask = (res - workon) > 0
                            for oc in other_ch:
                                data[b, oc][was_added_mask] = 0
                        # if the operation removed voxels, they are simply left as background
        data_dict[self.key] = data
        return data_dict
================================================
FILE: unetr_pp/training/dataloading/__init__.py
================================================
from __future__ import absolute_import
from . import *
================================================
FILE: unetr_pp/training/dataloading/dataset_loading.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from batchgenerators.augmentations.utils import random_crop_2D_image_batched, pad_nd_image
import numpy as np
from batchgenerators.dataloading import SlimDataLoaderBase
from multiprocessing import Pool
from unetr_pp.configuration import default_num_threads
from unetr_pp.paths import preprocessing_output_dir
from batchgenerators.utilities.file_and_folder_operations import *
def get_case_identifiers(folder):
    """
    Return the case names of all preprocessed .npz files in ``folder``.

    Files belonging to the cascade ("segFromPrevStage") are excluded.
    """
    identifiers = []
    for fname in os.listdir(folder):
        if fname.endswith("npz") and "segFromPrevStage" not in fname:
            identifiers.append(fname[:-4])  # strip ".npz"
    return identifiers
def get_case_identifiers_from_raw_folder(folder):
    """
    Return the unique case names of all raw .nii.gz images in ``folder``.

    The modality suffix and extension ("_0000.nii.gz", 12 characters) are stripped;
    cascade files ("segFromPrevStage") are excluded.
    """
    names = [fname[:-12] for fname in os.listdir(folder)
             if fname.endswith(".nii.gz") and "segFromPrevStage" not in fname]
    return np.unique(names)
def convert_to_npy(args):
    """
    Unpack a single .npz file to an uncompressed .npy file next to it (skipped if the
    .npy already exists).

    :param args: either the npz file path (key defaults to "data") or a (path, key) tuple
    """
    if not isinstance(args, tuple):
        key = "data"
        npz_file = args
    else:
        npz_file, key = args
    npy_file = npz_file[:-3] + "npy"
    if not os.path.isfile(npy_file):
        # np.load on an .npz returns an NpzFile holding an open file handle; use the
        # context manager so the handle is closed even if saving fails
        with np.load(npz_file) as content:
            np.save(npy_file, content[key])
def save_as_npz(args):
    """
    Re-pack a single .npy file into a compressed .npz file next to it.

    :param args: either the npy file path (key defaults to "data") or a (path, key) tuple
    """
    if isinstance(args, tuple):
        npy_file, key = args
    else:
        npy_file = args
        key = "data"
    arr = np.load(npy_file)
    np.savez_compressed(npy_file[:-3] + "npz", **{key: arr})
def unpack_dataset(folder, threads=default_num_threads, key="data"):
    """
    unpacks all npz files in a folder to npy (whatever you want to have unpacked must be saved unter key)
    :param folder:
    :param threads: number of worker processes
    :param key: array key inside the npz files
    :return:
    """
    npz_files = subfiles(folder, True, None, ".npz", True)
    # context manager guarantees the pool is torn down even if a worker raises
    # (the previous close()/join() calls were skipped on exception, leaking processes)
    with Pool(threads) as p:
        p.map(convert_to_npy, zip(npz_files, [key] * len(npz_files)))
def pack_dataset(folder, threads=default_num_threads, key="data"):
    """
    Inverse of unpack_dataset: re-compresses all .npy files in ``folder`` to .npz.

    :param folder:
    :param threads: number of worker processes
    :param key: array key to store inside the npz files
    """
    npy_files = subfiles(folder, True, None, ".npy", True)
    # context manager guarantees the pool is torn down even if a worker raises
    # (the previous close()/join() calls were skipped on exception, leaking processes)
    with Pool(threads) as p:
        p.map(save_as_npz, zip(npy_files, [key] * len(npy_files)))
def delete_npy(folder):
    """
    Remove the unpacked .npy files of all cases in ``folder`` (the compressed .npz
    files are kept). Call this after training to free disk space.
    """
    for case in get_case_identifiers(folder):
        npy_path = join(folder, case + ".npy")
        if isfile(npy_path):
            os.remove(npy_path)
def load_dataset(folder, num_cases_properties_loading_threshold=1000):
    """
    Build the dataset dict for a preprocessed folder: case id -> file locations.

    The actual image data is NOT loaded; only file paths are stored. If the dataset has
    at most ``num_cases_properties_loading_threshold`` cases, the per-case .pkl
    properties are preloaded into the dict as well (key 'properties').
    """
    # we don't load the actual data but instead return the filename to the np file.
    print('loading dataset')
    case_identifiers = get_case_identifiers(folder)
    case_identifiers.sort()
    dataset = OrderedDict()
    for c in case_identifiers:
        dataset[c] = OrderedDict()
        dataset[c]['data_file'] = join(folder, "%s.npz" % c)
        dataset[c]['properties_file'] = join(folder, "%s.pkl" % c)
        # NOTE(review): this condition can never be True — dataset[c] was just created
        # above with only 'data_file' and 'properties_file' keys, so the branch is dead.
        # Presumably 'seg_from_prev_stage_file' was meant to be set unconditionally for
        # cascade training — confirm against DataLoader3D's has_prev_stage usage.
        if dataset[c].get('seg_from_prev_stage_file') is not None:
            dataset[c]['seg_from_prev_stage_file'] = join(folder, "%s_segs.npz" % c)
    if len(case_identifiers) <= num_cases_properties_loading_threshold:
        print('loading all case properties')
        for i in dataset.keys():
            dataset[i]['properties'] = load_pickle(dataset[i]['properties_file'])
    return dataset
def crop_2D_image_force_fg(img, crop_size, valid_voxels):
    """
    img must be [c, x, y]
    img[-1] must be the segmentation with segmentation>0 being foreground
    :param img:
    :param crop_size: int or per-axis (x, y) tuple/list
    :param valid_voxels: voxels belonging to the selected class, shape (N, 2) of (x, y) coords
    :return: cropped image [c, crop_size[0], crop_size[1]]
    """
    assert len(valid_voxels.shape) == 2
    if type(crop_size) not in (tuple, list):
        crop_size = [crop_size] * (len(img.shape) - 1)
    else:
        assert len(crop_size) == (len(
            img.shape) - 1), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)"
    # we need to find the center coords that we can crop to without exceeding the image border
    lb_x = crop_size[0] // 2
    ub_x = img.shape[1] - crop_size[0] // 2 - crop_size[0] % 2
    lb_y = crop_size[1] // 2
    ub_y = img.shape[2] - crop_size[1] // 2 - crop_size[1] % 2
    if len(valid_voxels) == 0:
        # np.random.random_integers is deprecated/removed in modern NumPy; randint's
        # upper bound is exclusive, hence the +1 to keep the same inclusive range
        selected_center_voxel = (np.random.randint(lb_x, ub_x + 1),
                                 np.random.randint(lb_y, ub_y + 1))
    else:
        # BUGFIX: sample a row index along axis 0 (number of voxels); the previous code
        # drew from shape[1] (number of coordinates, always 2), which could raise an
        # IndexError whenever fewer than 2 voxels were available
        selected_center_voxel = valid_voxels[np.random.choice(valid_voxels.shape[0]), :]
    selected_center_voxel = np.array(selected_center_voxel)
    # clamp the center so the crop stays fully inside the image
    for i in range(2):
        selected_center_voxel[i] = max(crop_size[i] // 2, selected_center_voxel[i])
        selected_center_voxel[i] = min(img.shape[i + 1] - crop_size[i] // 2 - crop_size[i] % 2,
                                       selected_center_voxel[i])
    result = img[:, (selected_center_voxel[0] - crop_size[0] // 2):(
            selected_center_voxel[0] + crop_size[0] // 2 + crop_size[0] % 2),
             (selected_center_voxel[1] - crop_size[1] // 2):(
                     selected_center_voxel[1] + crop_size[1] // 2 + crop_size[1] % 2)]
    return result
class DataLoader3D(SlimDataLoaderBase):
    def __init__(self, data, patch_size, final_patch_size, batch_size, has_prev_stage=False,
                 oversample_foreground_percent=0.0, memmap_mode="r", pad_mode="edge", pad_kwargs_data=None,
                 pad_sides=None):
        """
        This is the basic data loader for 3D networks. It uses preprocessed data as produced by my (Fabian)
        preprocessing. You can load the data with load_dataset(folder) where folder is the folder where the npz
        files are located. If there are only npz files present in that folder, the data loader will unpack them on
        the fly. This may take a while and increase CPU usage. Therefore, I advise you to call
        unpack_dataset(folder) first, which will unpack all npz to npy. Don't forget to call delete_npy(folder)
        after you are done with training.
        Why all the hassle? Well the decathlon dataset is huge. Using npy for everything will consume >1 TB and
        that is uncool given that I (Fabian) will have to store that permanently on /datasets and my local computer.
        With this strategy all data is stored in a compressed format (factor 10 smaller) and only unpacked when
        needed.
        :param data: get this with load_dataset(folder, stage=0). Plug the return value in here and you are g2g
        (good to go)
        :param patch_size: what patch size will this data loader return? it is common practice to first load larger
        patches so that a central crop after data augmentation can be done to reduce border artifacts. If unsure,
        use get_patch_size() from data_augmentation.default_data_augmentation
        :param final_patch_size: what will the patch finally be cropped to (after data augmentation)? this is the
        patch size that goes into your network. We need this here because we will pad patients in here so that
        patches at the border of patients are sampled properly
        :param batch_size:
        :param has_prev_stage: cascade training: additionally load the previous stage's segmentation into seg
        channel 1
        :param oversample_foreground_percent: this fraction of every batch (the trailing entries) is forced to
        contain at least some foreground (equal prob for each of the foreground classes)
        :param memmap_mode: mmap mode handed to np.load for unpacked .npy files
        :param pad_mode: forwarded to np.pad for the image data (seg is always padded with constant -1)
        :param pad_kwargs_data: extra kwargs forwarded to np.pad for the image data
        :param pad_sides: optional additional padding, added onto need_to_pad
        """
        super(DataLoader3D, self).__init__(data, batch_size, None)
        if pad_kwargs_data is None:
            pad_kwargs_data = OrderedDict()
        self.pad_kwargs_data = pad_kwargs_data
        self.pad_mode = pad_mode
        self.oversample_foreground_percent = oversample_foreground_percent
        self.final_patch_size = final_patch_size
        self.has_prev_stage = has_prev_stage
        self.patch_size = patch_size
        self.list_of_keys = list(self._data.keys())
        # need_to_pad denotes by how much we need to pad the data so that if we sample a patch of size
        # final_patch_size (which is what the network will get) these patches will also cover the border of
        # the patients
        self.need_to_pad = (np.array(patch_size) - np.array(final_patch_size)).astype(int)
        if pad_sides is not None:
            if not isinstance(pad_sides, np.ndarray):
                pad_sides = np.array(pad_sides)
            self.need_to_pad += pad_sides
        self.memmap_mode = memmap_mode
        self.num_channels = None
        self.pad_sides = pad_sides
        self.data_shape, self.seg_shape = self.determine_shapes()

    def get_do_oversample(self, batch_idx):
        # the trailing oversample_foreground_percent fraction of each batch is forced to contain foreground
        return not batch_idx < round(self.batch_size * (1 - self.oversample_foreground_percent))

    def determine_shapes(self):
        """Infer the (data, seg) batch array shapes from the first case on disk."""
        if self.has_prev_stage:
            num_seg = 2
        else:
            num_seg = 1
        k = list(self._data.keys())[0]
        if isfile(self._data[k]['data_file'][:-4] + ".npy"):
            case_all_data = np.load(self._data[k]['data_file'][:-4] + ".npy", self.memmap_mode)
        else:
            case_all_data = np.load(self._data[k]['data_file'])['data']
        # last channel of case_all_data is the segmentation
        num_color_channels = case_all_data.shape[0] - 1
        data_shape = (self.batch_size, num_color_channels, *self.patch_size)
        seg_shape = (self.batch_size, num_seg, *self.patch_size)
        return data_shape, seg_shape

    def generate_train_batch(self):
        """Sample one batch of (padded) patches, optionally forcing foreground for some entries."""
        selected_keys = np.random.choice(self.list_of_keys, self.batch_size, True, None)
        data = np.zeros(self.data_shape, dtype=np.float32)
        seg = np.zeros(self.seg_shape, dtype=np.float32)
        case_properties = []
        for j, i in enumerate(selected_keys):
            # oversampling foreground will improve stability of model training, especially if many patches are
            # empty (Lung for example)
            if self.get_do_oversample(j):
                force_fg = True
            else:
                force_fg = False
            if 'properties' in self._data[i].keys():
                properties = self._data[i]['properties']
            else:
                properties = load_pickle(self._data[i]['properties_file'])
            case_properties.append(properties)
            # cases are stored as npz, but we require unpack_dataset to be run. This will decompress them into
            # npy which is much faster to access
            if isfile(self._data[i]['data_file'][:-4] + ".npy"):
                case_all_data = np.load(self._data[i]['data_file'][:-4] + ".npy", self.memmap_mode)
            else:
                case_all_data = np.load(self._data[i]['data_file'])['data']
            # If we are doing the cascade then we will also need to load the segmentation of the previous stage
            # and concatenate it. Here it will be concatenated to the segmentation because the augmentations need
            # to be applied to it in segmentation mode. Later in the data augmentation we move it from the
            # segmentations to the last channel of the data
            if self.has_prev_stage:
                if isfile(self._data[i]['seg_from_prev_stage_file'][:-4] + ".npy"):
                    segs_from_previous_stage = np.load(self._data[i]['seg_from_prev_stage_file'][:-4] + ".npy",
                                                       mmap_mode=self.memmap_mode)[None]
                else:
                    segs_from_previous_stage = np.load(self._data[i]['seg_from_prev_stage_file'])['data'][None]
                # we theoretically support several possible previous segmentations from which only one is
                # sampled. But in practice this feature was never used so it's always only one segmentation
                seg_key = np.random.choice(segs_from_previous_stage.shape[0])
                seg_from_previous_stage = segs_from_previous_stage[seg_key:seg_key + 1]
                assert all([i == j for i, j in zip(seg_from_previous_stage.shape[1:], case_all_data.shape[1:])]), \
                    "seg_from_previous_stage does not match the shape of case_all_data: %s vs %s" % \
                    (str(seg_from_previous_stage.shape[1:]), str(case_all_data.shape[1:]))
            else:
                seg_from_previous_stage = None
            # BUGFIX: take a copy. The array may be enlarged below for cases smaller than the patch size;
            # without the copy that mutation persisted on self.need_to_pad and silently inflated the padding
            # for every subsequent batch.
            need_to_pad = self.need_to_pad.copy()
            for d in range(3):
                # if case_all_data.shape + need_to_pad is still < patch size we need to pad more! We pad on both
                # sides always
                if need_to_pad[d] + case_all_data.shape[d + 1] < self.patch_size[d]:
                    need_to_pad[d] = self.patch_size[d] - case_all_data.shape[d + 1]
            # we can now choose the bbox from -need_to_pad // 2 to shape - patch_size + need_to_pad // 2. Here
            # we define what the upper and lower bound can be to then sample from them with np.random.randint
            shape = case_all_data.shape[1:]
            lb_x = - need_to_pad[0] // 2
            ub_x = shape[0] + need_to_pad[0] // 2 + need_to_pad[0] % 2 - self.patch_size[0]
            lb_y = - need_to_pad[1] // 2
            ub_y = shape[1] + need_to_pad[1] // 2 + need_to_pad[1] % 2 - self.patch_size[1]
            lb_z = - need_to_pad[2] // 2
            ub_z = shape[2] + need_to_pad[2] // 2 + need_to_pad[2] % 2 - self.patch_size[2]
            # if not force_fg then we can just sample the bbox randomly from lb and ub. Else we need to make
            # sure we get at least one of the foreground classes in the patch
            if not force_fg:
                bbox_x_lb = np.random.randint(lb_x, ub_x + 1)
                bbox_y_lb = np.random.randint(lb_y, ub_y + 1)
                bbox_z_lb = np.random.randint(lb_z, ub_z + 1)
            else:
                # these values should have been precomputed
                if 'class_locations' not in properties.keys():
                    raise RuntimeError("Please rerun the preprocessing with the newest version of nnU-Net!")
                # this saves us a np.unique. Preprocessing already did that for all cases. Neat.
                foreground_classes = np.array(
                    [i for i in properties['class_locations'].keys() if len(properties['class_locations'][i]) != 0])
                foreground_classes = foreground_classes[foreground_classes > 0]
                if len(foreground_classes) == 0:
                    # this only happens if some image does not contain foreground voxels at all
                    selected_class = None
                    voxels_of_that_class = None
                    print('case does not contain any foreground classes', i)
                else:
                    selected_class = np.random.choice(foreground_classes)
                    voxels_of_that_class = properties['class_locations'][selected_class]
                if voxels_of_that_class is not None:
                    selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))]
                    # selected voxel is center voxel. Subtract half the patch size to get lower bbox voxel.
                    # Make sure it is within the bounds of lb and ub
                    bbox_x_lb = max(lb_x, selected_voxel[0] - self.patch_size[0] // 2)
                    bbox_y_lb = max(lb_y, selected_voxel[1] - self.patch_size[1] // 2)
                    bbox_z_lb = max(lb_z, selected_voxel[2] - self.patch_size[2] // 2)
                else:
                    # If the image does not contain any foreground classes, we fall back to random cropping
                    bbox_x_lb = np.random.randint(lb_x, ub_x + 1)
                    bbox_y_lb = np.random.randint(lb_y, ub_y + 1)
                    bbox_z_lb = np.random.randint(lb_z, ub_z + 1)
            bbox_x_ub = bbox_x_lb + self.patch_size[0]
            bbox_y_ub = bbox_y_lb + self.patch_size[1]
            bbox_z_ub = bbox_z_lb + self.patch_size[2]
            # We first crop the data to the region of the bbox that actually lies within the data. This will
            # result in a smaller array which is then faster to pad. valid_bbox is just the coord that lies
            # within the data cube. It will be padded to match the patch size later
            valid_bbox_x_lb = max(0, bbox_x_lb)
            valid_bbox_x_ub = min(shape[0], bbox_x_ub)
            valid_bbox_y_lb = max(0, bbox_y_lb)
            valid_bbox_y_ub = min(shape[1], bbox_y_ub)
            valid_bbox_z_lb = max(0, bbox_z_lb)
            valid_bbox_z_ub = min(shape[2], bbox_z_ub)
            # seg is treated differently from seg_from_previous_stage because seg needs to be padded with -1
            # constant whereas seg_from_previous_stage needs to be padded with 0s (we could also remove label -1
            # in the data augmentation but this way it is less error prone)
            case_all_data = np.copy(case_all_data[:, valid_bbox_x_lb:valid_bbox_x_ub,
                                    valid_bbox_y_lb:valid_bbox_y_ub,
                                    valid_bbox_z_lb:valid_bbox_z_ub])
            if seg_from_previous_stage is not None:
                seg_from_previous_stage = seg_from_previous_stage[:, valid_bbox_x_lb:valid_bbox_x_ub,
                                          valid_bbox_y_lb:valid_bbox_y_ub,
                                          valid_bbox_z_lb:valid_bbox_z_ub]
            data[j] = np.pad(case_all_data[:-1], ((0, 0),
                                                  (-min(0, bbox_x_lb), max(bbox_x_ub - shape[0], 0)),
                                                  (-min(0, bbox_y_lb), max(bbox_y_ub - shape[1], 0)),
                                                  (-min(0, bbox_z_lb), max(bbox_z_ub - shape[2], 0))),
                             self.pad_mode, **self.pad_kwargs_data)
            seg[j, 0] = np.pad(case_all_data[-1:], ((0, 0),
                                                    (-min(0, bbox_x_lb), max(bbox_x_ub - shape[0], 0)),
                                                    (-min(0, bbox_y_lb), max(bbox_y_ub - shape[1], 0)),
                                                    (-min(0, bbox_z_lb), max(bbox_z_ub - shape[2], 0))),
                               'constant', **{'constant_values': -1})
            if seg_from_previous_stage is not None:
                seg[j, 1] = np.pad(seg_from_previous_stage, ((0, 0),
                                                             (-min(0, bbox_x_lb),
                                                              max(bbox_x_ub - shape[0], 0)),
                                                             (-min(0, bbox_y_lb),
                                                              max(bbox_y_ub - shape[1], 0)),
                                                             (-min(0, bbox_z_lb),
                                                              max(bbox_z_ub - shape[2], 0))),
                                   'constant', **{'constant_values': 0})
        return {'data': data, 'seg': seg, 'properties': case_properties, 'keys': selected_keys}
class DataLoader2D(SlimDataLoaderBase):
    def __init__(self, data, patch_size, final_patch_size, batch_size, oversample_foreground_percent=0.0,
                 memmap_mode="r", pseudo_3d_slices=1, pad_mode="edge",
                 pad_kwargs_data=None, pad_sides=None):
        """
        This is the basic data loader for 2D networks. It uses preprocessed data as produced by my (Fabian)
        preprocessing. You can load the data with load_dataset(folder) where folder is the folder where the npz
        files are located. If there are only npz files present in that folder, the data loader will unpack them
        on the fly. This may take a while and increase CPU usage. Therefore, I advise you to call
        unpack_dataset(folder) first, which will unpack all npz to npy. Don't forget to call delete_npy(folder)
        after you are done with training.
        Why all the hassle? Well the decathlon dataset is huge. Using npy for everything will consume >1 TB and
        that is uncool given that I (Fabian) will have to store that permanently on /datasets and my local
        computer. With this strategy all data is stored in a compressed format (factor 10 smaller) and only
        unpacked when needed.
        :param data: get this with load_dataset(folder, stage=0). Plug the return value in here and you are g2g
        (good to go)
        :param patch_size: what patch size will this data loader return? it is common practice to first load
        larger patches so that a central crop after data augmentation can be done to reduce border artifacts.
        If unsure, use get_patch_size() from data_augmentation.default_data_augmentation
        :param final_patch_size: what will the patch finally be cropped to (after data augmentation)? this is
        the patch size that goes into your network. We need this here because we will pad patients in here so
        that patches at the border of patients are sampled properly
        :param batch_size:
        :param oversample_foreground_percent: fraction of each batch forced to contain at least some foreground
        :param memmap_mode: mmap mode handed to np.load for unpacked .npy files
        :param pseudo_3d_slices: 7 = 3 below and 3 above the center slice
        :param pad_mode / pad_kwargs_data: forwarded to np.pad for the image data
        :param pad_sides: optional additional padding, added onto need_to_pad
        """
        super(DataLoader2D, self).__init__(data, batch_size, None)
        if pad_kwargs_data is None:
            pad_kwargs_data = OrderedDict()
        self.pad_kwargs_data = pad_kwargs_data
        self.pad_mode = pad_mode
        self.pseudo_3d_slices = pseudo_3d_slices
        self.oversample_foreground_percent = oversample_foreground_percent
        self.final_patch_size = final_patch_size
        self.patch_size = patch_size
        self.list_of_keys = list(self._data.keys())
        # how much each sampled patch must be padded so that final_patch_size patches also cover the borders
        self.need_to_pad = np.array(patch_size) - np.array(final_patch_size)
        self.memmap_mode = memmap_mode
        if pad_sides is not None:
            if not isinstance(pad_sides, np.ndarray):
                pad_sides = np.array(pad_sides)
            self.need_to_pad += pad_sides
        self.pad_sides = pad_sides
        self.data_shape, self.seg_shape = self.determine_shapes()
def determine_shapes(self):
num_seg = 1
k = list(self._data.keys())[0]
if isfile(self._data[k]['data_file'][:-4] + ".npy"):
case_all_data = np.load(self._data[k]['data_file'][:-4] + ".npy", self.memmap_mode)
else:
case_all_data = np.load(self._data[k]['data_file'])['data']
num_color_channels = case_all_data.shape[0] - num_seg
data_shape = (self.batch_size, num_color_channels, *self.patch_size)
seg_shape = (self.batch_size, num_seg, *self.patch_size)
return data_shape, seg_shape
def get_do_oversample(self, batch_idx):
return not batch_idx < round(self.batch_size * (1 - self.oversample_foreground_percent))
def generate_train_batch(self):
    """Sample one training batch of 2D patches.

    Samples self.batch_size cases with replacement; the trailing portion of the
    batch (see get_do_oversample) is forced to contain foreground by centering
    the patch on a precomputed foreground voxel. Image data is padded with
    self.pad_mode, segmentation with constant -1.

    :return: dict with 'data' (b, c, *patch_size) float32, 'seg' (b, 1, *patch_size)
        float32, 'properties' (list of per-case dicts) and 'keys' (sampled case ids)
    """
    selected_keys = np.random.choice(self.list_of_keys, self.batch_size, True, None)
    data = np.zeros(self.data_shape, dtype=np.float32)
    seg = np.zeros(self.seg_shape, dtype=np.float32)
    case_properties = []
    for j, i in enumerate(selected_keys):
        if 'properties' in self._data[i].keys():
            properties = self._data[i]['properties']
        else:
            properties = load_pickle(self._data[i]['properties_file'])
        case_properties.append(properties)

        if self.get_do_oversample(j):
            force_fg = True
        else:
            force_fg = False

        if not isfile(self._data[i]['data_file'][:-4] + ".npy"):
            # lets hope you know what you're doing
            case_all_data = np.load(self._data[i]['data_file'][:-4] + ".npz")['data']
        else:
            case_all_data = np.load(self._data[i]['data_file'][:-4] + ".npy", self.memmap_mode)

        # this is for when there is just a 2d slice in case_all_data (2d support)
        if len(case_all_data.shape) == 3:
            case_all_data = case_all_data[:, None]

        # first select a slice. This can be either random (no force fg) or guaranteed to contain some class
        if not force_fg:
            random_slice = np.random.choice(case_all_data.shape[1])
            selected_class = None
        else:
            # these values should have been precomputed
            if 'class_locations' not in properties.keys():
                raise RuntimeError("Please rerun the preprocessing with the newest version of nnU-Net!")

            foreground_classes = np.array(
                [i for i in properties['class_locations'].keys() if len(properties['class_locations'][i]) != 0])
            foreground_classes = foreground_classes[foreground_classes > 0]
            if len(foreground_classes) == 0:
                selected_class = None
                random_slice = np.random.choice(case_all_data.shape[1])
                print('case does not contain any foreground classes', i)
            else:
                selected_class = np.random.choice(foreground_classes)

                voxels_of_that_class = properties['class_locations'][selected_class]
                valid_slices = np.unique(voxels_of_that_class[:, 0])
                random_slice = np.random.choice(valid_slices)
                voxels_of_that_class = voxels_of_that_class[voxels_of_that_class[:, 0] == random_slice]
                voxels_of_that_class = voxels_of_that_class[:, 1:]

        # now crop case_all_data to contain just the slice of interest. If we want additional slice above and
        # below the current slice, here is where we get them. We stack those as additional color channels
        if self.pseudo_3d_slices == 1:
            case_all_data = case_all_data[:, random_slice]
        else:
            # this is very deprecated and will probably not work anymore. If you intend to use this you need to
            # check this!
            mn = random_slice - (self.pseudo_3d_slices - 1) // 2
            mx = random_slice + (self.pseudo_3d_slices - 1) // 2 + 1
            valid_mn = max(mn, 0)
            valid_mx = min(mx, case_all_data.shape[1])
            case_all_seg = case_all_data[-1:]
            case_all_data = case_all_data[:-1]
            case_all_data = case_all_data[:, valid_mn:valid_mx]
            case_all_seg = case_all_seg[:, random_slice]
            need_to_pad_below = valid_mn - mn
            need_to_pad_above = mx - valid_mx
            if need_to_pad_below > 0:
                shp_for_pad = np.array(case_all_data.shape)
                shp_for_pad[1] = need_to_pad_below
                case_all_data = np.concatenate((np.zeros(shp_for_pad), case_all_data), 1)
            if need_to_pad_above > 0:
                shp_for_pad = np.array(case_all_data.shape)
                shp_for_pad[1] = need_to_pad_above
                case_all_data = np.concatenate((case_all_data, np.zeros(shp_for_pad)), 1)
            case_all_data = case_all_data.reshape((-1, case_all_data.shape[-2], case_all_data.shape[-1]))
            case_all_data = np.concatenate((case_all_data, case_all_seg), 0)

        # case all data should now be (c, x, y)
        assert len(case_all_data.shape) == 3

        # we can now choose the bbox from -need_to_pad // 2 to shape - patch_size + need_to_pad // 2. Here we
        # define what the upper and lower bound can be to then sample from them with np.random.randint
        # BUGFIX: work on a copy. The previous code aliased self.need_to_pad, so the
        # in-place enlargement below permanently inflated the padding for every
        # subsequently sampled case (upstream nnU-Net applies the same .copy() fix).
        need_to_pad = self.need_to_pad.copy()
        for d in range(2):
            # if case_all_data.shape + need_to_pad is still < patch size we need to pad more! We pad on both sides
            # always
            if need_to_pad[d] + case_all_data.shape[d + 1] < self.patch_size[d]:
                need_to_pad[d] = self.patch_size[d] - case_all_data.shape[d + 1]

        shape = case_all_data.shape[1:]
        lb_x = - need_to_pad[0] // 2
        ub_x = shape[0] + need_to_pad[0] // 2 + need_to_pad[0] % 2 - self.patch_size[0]
        lb_y = - need_to_pad[1] // 2
        ub_y = shape[1] + need_to_pad[1] // 2 + need_to_pad[1] % 2 - self.patch_size[1]

        # if not force_fg then we can just sample the bbox randomly from lb and ub. Else we need to make sure we get
        # at least one of the foreground classes in the patch
        if not force_fg or selected_class is None:
            bbox_x_lb = np.random.randint(lb_x, ub_x + 1)
            bbox_y_lb = np.random.randint(lb_y, ub_y + 1)
        else:
            # this saves us a np.unique. Preprocessing already did that for all cases. Neat.
            selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))]

            # selected voxel is center voxel. Subtract half the patch size to get lower bbox voxel.
            # Make sure it is within the bounds of lb and ub
            bbox_x_lb = max(lb_x, selected_voxel[0] - self.patch_size[0] // 2)
            bbox_y_lb = max(lb_y, selected_voxel[1] - self.patch_size[1] // 2)

        bbox_x_ub = bbox_x_lb + self.patch_size[0]
        bbox_y_ub = bbox_y_lb + self.patch_size[1]

        # whoever wrote this knew what he was doing (hint: it was me). We first crop the data to the region of the
        # bbox that actually lies within the data. This will result in a smaller array which is then faster to pad.
        # valid_bbox is just the coord that lied within the data cube. It will be padded to match the patch size
        # later
        valid_bbox_x_lb = max(0, bbox_x_lb)
        valid_bbox_x_ub = min(shape[0], bbox_x_ub)
        valid_bbox_y_lb = max(0, bbox_y_lb)
        valid_bbox_y_ub = min(shape[1], bbox_y_ub)

        # At this point you might ask yourself why we would treat seg differently from seg_from_previous_stage.
        # Why not just concatenate them here and forget about the if statements? Well that's because seg needs to
        # be padded with -1 constant whereas seg_from_previous_stage needs to be padded with 0s (we could also
        # remove label -1 in the data augmentation but this way it is less error prone)
        case_all_data = case_all_data[:, valid_bbox_x_lb:valid_bbox_x_ub,
                        valid_bbox_y_lb:valid_bbox_y_ub]

        case_all_data_donly = np.pad(case_all_data[:-1], ((0, 0),
                                                          (-min(0, bbox_x_lb), max(bbox_x_ub - shape[0], 0)),
                                                          (-min(0, bbox_y_lb), max(bbox_y_ub - shape[1], 0))),
                                     self.pad_mode, **self.pad_kwargs_data)

        case_all_data_segonly = np.pad(case_all_data[-1:], ((0, 0),
                                                            (-min(0, bbox_x_lb), max(bbox_x_ub - shape[0], 0)),
                                                            (-min(0, bbox_y_lb), max(bbox_y_ub - shape[1], 0))),
                                       'constant', **{'constant_values': -1})

        data[j] = case_all_data_donly
        seg[j] = case_all_data_segonly

    keys = selected_keys
    return {'data': data, 'seg': seg, 'properties': case_properties, "keys": keys}
if __name__ == "__main__":
    # Smoke test: build 3D and 2D loaders from a preprocessed task.
    t = "Task002_Heart"
    p = join(preprocessing_output_dir, t, "stage1")
    dataset = load_dataset(p)
    with open(join(join(preprocessing_output_dir, t), "plans_stage1.pkl"), 'rb') as f:
        plans = pickle.load(f)
    # decompress the .npz cases to .npy so the loaders can memmap them
    unpack_dataset(p)
    # first with a fixed small patch size, then with the plans-derived one
    dl = DataLoader3D(dataset, (32, 32, 32), (32, 32, 32), 2, oversample_foreground_percent=0.33)
    dl = DataLoader3D(dataset, np.array(plans['patch_size']).astype(int), np.array(plans['patch_size']).astype(int), 2,
                      oversample_foreground_percent=0.33)
    dl2d = DataLoader2D(dataset, (64, 64), np.array(plans['patch_size']).astype(int)[1:], 12,
                        oversample_foreground_percent=0.33)
================================================
FILE: unetr_pp/training/learning_rate/poly_lr.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    """Polynomial learning-rate decay: initial_lr * (1 - epoch/max_epochs) ** exponent."""
    decay = (1 - epoch / max_epochs) ** exponent
    return initial_lr * decay
================================================
FILE: unetr_pp/training/loss_functions/TopK_loss.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from unetr_pp.training.loss_functions.crossentropy import RobustCrossEntropyLoss
class TopKLoss(RobustCrossEntropyLoss):
    """
    Cross entropy averaged over only the hardest k percent of voxels.
    Network has to have NO LINEARITY!
    """
    def __init__(self, weight=None, ignore_index=-100, k=10):
        # k: percentage (0-100) of voxels kept for the loss
        self.k = k
        super(TopKLoss, self).__init__(weight, False, ignore_index, reduce=False)

    def forward(self, inp, target):
        # target arrives as (b, 1, ...) float; squeeze the channel and cast for CE
        target = target[:, 0].long()
        pixel_losses = super(TopKLoss, self).forward(inp, target)
        n_voxels = np.prod(pixel_losses.shape, dtype=np.int64)
        n_keep = int(n_voxels * self.k / 100)
        # keep the n_keep largest per-voxel losses and average them
        hardest, _ = torch.topk(pixel_losses.view((-1, )), n_keep, sorted=False)
        return hardest.mean()
================================================
FILE: unetr_pp/training/loss_functions/__init__.py
================================================
from __future__ import absolute_import
from . import *
================================================
FILE: unetr_pp/training/loss_functions/crossentropy.py
================================================
from torch import nn, Tensor
class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
    """
    Compatibility wrapper around CrossEntropyLoss: the target may be a float
    tensor with an extra singleton channel axis; it is squeezed and cast to
    long before the standard CE is applied.
    """
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        if target.ndim == input.ndim:
            # target carries a channel axis of size 1 -> drop it
            assert target.shape[1] == 1
            target = target[:, 0]
        return super().forward(input, target.long())
================================================
FILE: unetr_pp/training/loss_functions/deep_supervision.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import nn
class MultipleOutputLoss2(nn.Module):
    def __init__(self, loss, weight_factors=None):
        """
        Weighted sum of `loss` applied pairwise to two equal-length sequences
        (x[0] vs y[0], x[1] vs y[1], ...), e.g. for deep supervision outputs.
        :param loss: callable applied to each (output, target) pair
        :param weight_factors: per-pair weights; None means weight 1 everywhere
        """
        super(MultipleOutputLoss2, self).__init__()
        self.weight_factors = weight_factors
        self.loss = loss

    def forward(self, x, y):
        assert isinstance(x, (tuple, list)), "x must be either tuple or list"
        assert isinstance(y, (tuple, list)), "y must be either tuple or list"
        weights = self.weight_factors if self.weight_factors is not None else [1] * len(x)
        # first pair always contributes; later pairs are skipped when weighted 0
        total = weights[0] * self.loss(x[0], y[0])
        for idx in range(1, len(x)):
            if weights[idx] != 0:
                total = total + weights[idx] * self.loss(x[idx], y[idx])
        return total
================================================
FILE: unetr_pp/training/loss_functions/dice_loss.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from unetr_pp.training.loss_functions.TopK_loss import TopKLoss
from unetr_pp.training.loss_functions.crossentropy import RobustCrossEntropyLoss
from unetr_pp.utilities.nd_softmax import softmax_helper
from unetr_pp.utilities.tensor_utilities import sum_tensor
from torch import nn
import numpy as np
class GDL(nn.Module):
    def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
                 square=False, square_volumes=False):
        """
        Generalized Dice Loss: per-class tp/fp/fn are weighted by the inverse
        class volume (1/V) before the dice is computed. Returns the negative
        mean dice (to be minimized).
        square_volumes will square the weight term. The paper recommends square_volumes=True; I don't (just an intuition)
        """
        super(GDL, self).__init__()
        self.square_volumes = square_volumes
        self.square = square  # square tp/fp/fn before summation (forwarded to get_tp_fp_fn_tn)
        self.do_bg = do_bg  # include the background (first) channel in the loss?
        self.batch_dice = batch_dice  # additionally reduce over the batch axis
        self.apply_nonlin = apply_nonlin  # e.g. softmax, applied to x before the dice
        self.smooth = smooth

    def forward(self, x, y, loss_mask=None):
        # x: (b, c, x, y(, z)) network output; y: label map or matching one-hot
        shp_x = x.shape
        shp_y = y.shape

        # axes over which tp/fp/fn are summed (spatial axes; plus batch axis if batch_dice)
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if len(shp_x) != len(shp_y):
            # target lacks the channel axis: insert it -> (b, 1, ...)
            y = y.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(x.shape, y.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = y
        else:
            # build the one-hot encoding from the label map
            gt = y.long()
            y_onehot = torch.zeros(shp_x)
            if x.device.type == "cuda":
                # NOTE(review): only CUDA is handled; other device types keep the CPU tensor -- verify
                y_onehot = y_onehot.cuda(x.device.index)
            y_onehot.scatter_(1, gt, 1)

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        if not self.do_bg:
            # drop the background (first) channel from prediction and target
            x = x[:, 1:]
            y_onehot = y_onehot[:, 1:]

        tp, fp, fn, _ = get_tp_fp_fn_tn(x, y_onehot, axes, loss_mask, self.square)

        # GDL weight computation, we use 1/V
        volumes = sum_tensor(y_onehot, axes) + 1e-6  # add some eps to prevent div by zero

        if self.square_volumes:
            volumes = volumes ** 2

        # apply weights
        tp = tp / volumes
        fp = fp / volumes
        fn = fn / volumes

        # sum over classes
        if self.batch_dice:
            axis = 0
        else:
            axis = 1

        tp = tp.sum(axis, keepdim=False)
        fp = fp.sum(axis, keepdim=False)
        fn = fn.sum(axis, keepdim=False)

        # compute dice
        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)

        dc = dc.mean()

        # negate so that maximizing dice minimizes the loss
        return -dc
def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
    """
    Compute soft true positives, false positives, false negatives and true
    negatives per class, optionally summed over the given axes.
    net_output must be (b, c, x, y(, z)))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z)))
    :param net_output:
    :param gt:
    :param axes: can be (, ) = no summation
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then fp, tp and fn will be squared before summation
    :return: tp, fp, fn, tn (summed over `axes` when any were given)
    """
    if axes is None:
        # default: sum over all spatial axes
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape

    # build a one-hot ground truth; no gradients needed for this part
    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            # target lacks the channel axis: insert it -> (b, 1, ...)
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            y_onehot = torch.zeros(shp_x)
            if net_output.device.type == "cuda":
                # NOTE(review): only CUDA is handled; other device types keep the CPU tensor -- verify
                y_onehot = y_onehot.cuda(net_output.device.index)
            y_onehot.scatter_(1, gt, 1)

    # soft counts; elementwise products keep this differentiable w.r.t. net_output
    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot
    tn = (1 - net_output) * (1 - y_onehot)

    if mask is not None:
        # zero out invalid voxels; the single-channel mask is applied to every class channel
        tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
        fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
        fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
        tn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tn, dim=1)), dim=1)

    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2
        tn = tn ** 2

    if len(axes) > 0:
        tp = sum_tensor(tp, axes, keepdim=False)
        fp = sum_tensor(fp, axes, keepdim=False)
        fn = sum_tensor(fn, axes, keepdim=False)
        tn = sum_tensor(tn, axes, keepdim=False)

    return tp, fp, fn, tn
class SoftDiceLoss(nn.Module):
    def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.):
        """
        Soft (differentiable) Dice loss; returns the negative mean dice.
        :param apply_nonlin: e.g. softmax, applied to the prediction first
        :param batch_dice: also reduce over the batch axis
        :param do_bg: include the background (first) channel
        :param smooth: additive smoothing for numerator and denominator
        """
        super(SoftDiceLoss, self).__init__()
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth

    def forward(self, x, y, loss_mask=None):
        # spatial axes; prepend the batch axis when batch_dice is on
        reduce_axes = list(range(2, len(x.shape)))
        if self.batch_dice:
            reduce_axes = [0] + reduce_axes

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        tp, fp, fn, _ = get_tp_fp_fn_tn(x, y, reduce_axes, loss_mask, False)

        numerator = 2 * tp + self.smooth
        denominator = 2 * tp + fp + fn + self.smooth
        dice = numerator / (denominator + 1e-8)

        if not self.do_bg:
            # drop the background (first) class before averaging
            dice = dice[1:] if self.batch_dice else dice[:, 1:]

        return -dice.mean()
class MCCLoss(nn.Module):
    def __init__(self, apply_nonlin=None, batch_mcc=False, do_bg=True, smooth=0.0):
        """
        based on matthews correlation coefficient
        https://en.wikipedia.org/wiki/Matthews_correlation_coefficient

        Does not work. Really unstable. F this.
        """
        super(MCCLoss, self).__init__()
        self.smooth = smooth
        self.do_bg = do_bg  # include the background (first) channel?
        self.batch_mcc = batch_mcc  # additionally reduce over the batch axis
        self.apply_nonlin = apply_nonlin  # e.g. softmax, applied to x first

    def forward(self, x, y, loss_mask=None):
        # returns the negative mean MCC (to be minimized)
        shp_x = x.shape
        voxels = np.prod(shp_x[2:])  # number of spatial voxels, used to normalize the counts
        if self.batch_mcc:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        tp, fp, fn, tn = get_tp_fp_fn_tn(x, y, axes, loss_mask, False)

        # normalize by voxel count to keep the products in the denominator manageable
        # (in-place is safe here: get_tp_fp_fn_tn returns freshly computed tensors)
        tp /= voxels
        fp /= voxels
        fn /= voxels
        tn /= voxels

        nominator = tp * tn - fp * fn + self.smooth
        denominator = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5 + self.smooth

        mcc = nominator / denominator

        if not self.do_bg:
            # drop the background (first) class before averaging
            if self.batch_mcc:
                mcc = mcc[1:]
            else:
                mcc = mcc[:, 1:]
        mcc = mcc.mean()

        return -mcc
class SoftDiceLossSquared(nn.Module):
    def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.):
        """
        squares the terms in the denominator as proposed by Milletari et al.
        Returns the negative mean dice.
        """
        super(SoftDiceLossSquared, self).__init__()
        self.do_bg = do_bg  # include the background (first) channel?
        self.batch_dice = batch_dice  # additionally reduce over the batch axis
        self.apply_nonlin = apply_nonlin  # e.g. softmax, applied to x first
        self.smooth = smooth

    def forward(self, x, y, loss_mask=None):
        # NOTE: loss_mask is accepted for interface parity but is not used here
        shp_x = x.shape
        shp_y = y.shape

        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        # build a one-hot ground truth without tracking gradients
        with torch.no_grad():
            if len(shp_x) != len(shp_y):
                # target lacks the channel axis: insert it -> (b, 1, ...)
                y = y.view((shp_y[0], 1, *shp_y[1:]))

            if all([i == j for i, j in zip(x.shape, y.shape)]):
                # if this is the case then gt is probably already a one hot encoding
                y_onehot = y
            else:
                y = y.long()
                y_onehot = torch.zeros(shp_x)
                if x.device.type == "cuda":
                    y_onehot = y_onehot.cuda(x.device.index)
                # scatter_ is in-place; the .float() return value is discarded
                # (harmless: y_onehot is created as a float tensor by torch.zeros)
                y_onehot.scatter_(1, y, 1).float()

        intersect = x * y_onehot
        # values in the denominator get smoothed
        denominator = x ** 2 + y_onehot ** 2

        # aggregation was previously done in get_tp_fp_fn, but needs to be done here now (needs to be done after
        # squaring)
        intersect = sum_tensor(intersect, axes, False) + self.smooth
        denominator = sum_tensor(denominator, axes, False) + self.smooth

        dc = 2 * intersect / denominator

        if not self.do_bg:
            # drop the background (first) class before averaging
            if self.batch_dice:
                dc = dc[1:]
            else:
                dc = dc[:, 1:]
        dc = dc.mean()

        return -dc
class DC_and_CE_loss(nn.Module):
    def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum", square_dice=False, weight_ce=1, weight_dice=1,
                 log_dice=False, ignore_label=None):
        """
        Weighted sum of (soft) Dice loss and cross entropy.
        CAREFUL. Weights for CE and Dice do not need to sum to one. You can set whatever you want.
        :param soft_dice_kwargs:
        :param ce_kwargs:
        :param aggregate: only "sum" is implemented
        :param square_dice: use SoftDiceLossSquared instead of SoftDiceLoss
        :param weight_ce: weight of the CE term; 0 skips computing it entirely
        :param weight_dice: weight of the Dice term; 0 skips computing it entirely
        :param log_dice: if True the dice term becomes -log(-dice_loss)
        :param ignore_label: label value excluded from both loss terms (not implemented for square_dice)
        """
        super(DC_and_CE_loss, self).__init__()
        if ignore_label is not None:
            assert not square_dice, 'not implemented'
            # CE must stay unreduced so it can be masked per voxel in forward()
            ce_kwargs['reduction'] = 'none'
        self.log_dice = log_dice
        self.weight_dice = weight_dice
        self.weight_ce = weight_ce
        self.aggregate = aggregate
        self.ce = RobustCrossEntropyLoss(**ce_kwargs)

        self.ignore_label = ignore_label

        if not square_dice:
            self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs)
        else:
            self.dc = SoftDiceLossSquared(apply_nonlin=softmax_helper, **soft_dice_kwargs)

    def forward(self, net_output, target):
        """
        target must be b, c, x, y(, z) with c=1
        :param net_output:
        :param target:
        :return: weight_ce * CE + weight_dice * Dice
        """
        if self.ignore_label is not None:
            assert target.shape[1] == 1, 'not implemented for one hot encoding'
            mask = target != self.ignore_label
            # NOTE(review): this zeroes the ignored voxels in the caller's target
            # tensor in place -- confirm callers do not reuse `target` afterwards
            target[~mask] = 0
            mask = mask.float()
        else:
            mask = None

        dc_loss = self.dc(net_output, target, loss_mask=mask) if self.weight_dice != 0 else 0
        if self.log_dice:
            dc_loss = -torch.log(-dc_loss)

        ce_loss = self.ce(net_output, target[:, 0].long()) if self.weight_ce != 0 else 0
        if self.ignore_label is not None:
            # average the unreduced CE only over the non-ignored voxels
            ce_loss *= mask[:, 0]
            ce_loss = ce_loss.sum() / mask.sum()

        if self.aggregate == "sum":
            result = self.weight_ce * ce_loss + self.weight_dice * dc_loss
        else:
            raise NotImplementedError("nah son")  # reserved for other stuff (later)
        return result
class DC_and_BCE_loss(nn.Module):
    def __init__(self, bce_kwargs, soft_dice_kwargs, aggregate="sum"):
        """
        Sum of BCE-with-logits and sigmoid soft Dice.
        DO NOT APPLY NONLINEARITY IN YOUR NETWORK!
        THIS LOSS IS INTENDED TO BE USED FOR BRATS REGIONS ONLY
        :param soft_dice_kwargs:
        :param bce_kwargs:
        :param aggregate:
        """
        super(DC_and_BCE_loss, self).__init__()
        self.aggregate = aggregate
        self.ce = nn.BCEWithLogitsLoss(**bce_kwargs)
        self.dc = SoftDiceLoss(apply_nonlin=torch.sigmoid, **soft_dice_kwargs)

    def forward(self, net_output, target):
        """Return BCE + Dice; any aggregate other than 'sum' raises."""
        bce_part = self.ce(net_output, target)
        dice_part = self.dc(net_output, target)
        if self.aggregate != "sum":
            raise NotImplementedError("nah son")  # reserved for other stuff (later)
        return bce_part + dice_part
class GDL_and_CE_loss(nn.Module):
    """Sum of generalized Dice loss (with softmax) and robust cross entropy."""

    def __init__(self, gdl_dice_kwargs, ce_kwargs, aggregate="sum"):
        super(GDL_and_CE_loss, self).__init__()
        self.aggregate = aggregate
        self.ce = RobustCrossEntropyLoss(**ce_kwargs)
        self.dc = GDL(softmax_helper, **gdl_dice_kwargs)

    def forward(self, net_output, target):
        """Return CE + GDL; any aggregate other than 'sum' raises."""
        gdl_part = self.dc(net_output, target)
        ce_part = self.ce(net_output, target)
        if self.aggregate != "sum":
            raise NotImplementedError("nah son")  # reserved for other stuff (later)
        return ce_part + gdl_part
class DC_and_topk_loss(nn.Module):
    """Sum of soft Dice (with softmax) and top-k cross entropy."""

    def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum", square_dice=False):
        super(DC_and_topk_loss, self).__init__()
        self.aggregate = aggregate
        self.ce = TopKLoss(**ce_kwargs)
        dice_cls = SoftDiceLossSquared if square_dice else SoftDiceLoss
        self.dc = dice_cls(apply_nonlin=softmax_helper, **soft_dice_kwargs)

    def forward(self, net_output, target):
        """Return top-k CE + Dice; any aggregate other than 'sum' raises."""
        dice_part = self.dc(net_output, target)
        topk_part = self.ce(net_output, target)
        if self.aggregate != "sum":
            raise NotImplementedError("nah son")  # reserved for other stuff (later?)
        return topk_part + dice_part
================================================
FILE: unetr_pp/training/model_restore.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unetr_pp
import torch
from batchgenerators.utilities.file_and_folder_operations import *
import importlib
import pkgutil
from unetr_pp.paths import network_training_output_dir, preprocessing_output_dir, default_plans_identifier
def recursive_find_python_class(folder, trainer_name, current_module):
    """
    Search the package at `folder` (a one-element list with the package path)
    for a module attribute named `trainer_name`. Plain modules of the package
    are checked first; subpackages are descended into only if nothing was
    found at this level.
    :return: the attribute (class) object, or None if it is nowhere defined
    """
    # first pass: plain modules directly inside this package
    for _, modname, ispkg in pkgutil.iter_modules(folder):
        if ispkg:
            continue
        module = importlib.import_module(current_module + "." + modname)
        if hasattr(module, trainer_name):
            return getattr(module, trainer_name)
    # second pass: recurse into subpackages
    for _, modname, ispkg in pkgutil.iter_modules(folder):
        if not ispkg:
            continue
        found = recursive_find_python_class([join(folder[0], modname)], trainer_name,
                                            current_module=current_module + "." + modname)
        if found is not None:
            return found
    return None
def restore_model(pkl_file, checkpoint=None, train=False, fp16=None, folder=None):
    """
    This is a utility function to load any nnFormer trainer from a pkl. It will recursively search
    unetr_pp.training.network_training for the file that contains the trainer and instantiate it with the arguments
    saved in the pkl file. If checkpoint is specified, it will furthermore load the checkpoint file in train/test
    mode (as specified by train).
    The pkl file required here is the one that will be saved automatically when calling nnFormerTrainer.save_checkpoint.
    :param pkl_file: path to the <checkpoint>.model.pkl metadata file
    :param checkpoint: path to the matching .model checkpoint, or None to skip loading weights
    :param train: forwarded to trainer.load_checkpoint (restore in training mode?)
    :param fp16: if None then we take no action. If True/False we overwrite what the model has in its init
    :param folder: model output folder (".../<network>/<task>/<trainer>__<plans>"). When given, the plans file
        path stored in the pkl is replaced with the local preprocessing path so checkpoints can be shared
        between machines (update of 2022.6.23). When None, the paths stored in the pkl are used unchanged.
    :return: the restored trainer instance
    """
    info = load_pickle(pkl_file)
    init = info['init']
    name = info['name']
    if folder is not None:
        # BUGFIX: this block previously ran unconditionally and crashed with
        # AttributeError on folder=None (the default, used e.g. by
        # load_best_model_for_inference).
        task = folder.split('/')[-2]
        network = folder.split('/')[-3]
        if network == '2d':
            plans_file = join(preprocessing_output_dir, task, default_plans_identifier + "_plans_2D.pkl")
        else:
            plans_file = join(preprocessing_output_dir, task, default_plans_identifier + "_plans_3D.pkl")
        info['init'] = list(info['init'])
        info['init'][0] = plans_file
        info['init'] = tuple(info['init'])
        # BUGFIX: refresh the local alias; previously `init` still pointed at the
        # old tuple, so the rewritten plans path was never passed to the trainer.
        init = info['init']
    if 'nnUNet' in name:
        # legacy checkpoints were written under the nnUNet trainer class names
        name = name.replace('nnUNet', 'nnFormer')
    if len(init) > 10:
        # NOTE(review): drops two legacy init arguments from very old checkpoints --
        # verify the positions against the historical trainer signature
        init = list(init)
        del init[2]
        del init[-2]
    search_in = join(unetr_pp.__path__[0], "training", "network_training")
    tr = recursive_find_python_class([search_in], name, current_module="unetr_pp.training.network_training")

    if tr is None:
        """
        Fabian only. This will trigger searching for trainer classes in other repositories as well
        """
        try:
            import meddec
            search_in = join(meddec.__path__[0], "model_training")
            tr = recursive_find_python_class([search_in], name, current_module="meddec.model_training")
        except ImportError:
            pass

    if tr is None:
        raise RuntimeError("Could not find the model trainer specified in checkpoint in unetr_pp.training.network_training. If it "
                           "is not located there, please move it or change the code of restore_model. Your model "
                           "trainer can be located in any directory within unetr_pp.training.network_training (search is recursive)."
                           "\nDebug info: \ncheckpoint file: %s\nName of trainer: %s " % (checkpoint, name))

    trainer = tr(*init)

    # We can hack fp16 overwriting into the trainer without changing the init arguments because nothing happens with
    # fp16 in the init, it just saves it to a member variable
    if fp16 is not None:
        trainer.fp16 = fp16

    trainer.process_plans(info['plans'])

    if checkpoint is not None:
        trainer.load_checkpoint(checkpoint, train)
    return trainer
def load_best_model_for_inference(folder):
    # Restore the trainer (with weights) from <folder>/model_best.model for inference.
    # NOTE(review): restore_model is called without its `folder` argument here --
    # verify this code path against restore_model's handling of folder=None.
    checkpoint = join(folder, "model_best.model")
    pkl_file = checkpoint + ".pkl"
    return restore_model(pkl_file, checkpoint, False)
def load_model_and_checkpoint_files(folder, folds=None, mixed_precision=None, checkpoint_name="model_best"):
    """
    used for if you need to ensemble the five models of a cross-validation. This will restore the model from the
    checkpoint in fold 0, load all parameters of the five folds in ram and return both. This will allow for fast
    switching between parameters (as opposed to loading them form disk each time).
    This is best used for inference and test prediction
    :param folder:
    :param folds:
    :param mixed_precision: if None then we take no action. If True/False we overwrite what the model has in its init
    :return: (trainer restored from fold 0, list of per-fold state dicts on CPU)
    """
    # resolve `folds` (str / list / int / None) into a list of fold directories
    if isinstance(folds, str):
        folds = [join(folder, "all")]
        assert isdir(folds[0]), "no output folder for fold %s found" % folds
    elif isinstance(folds, (list, tuple)):
        if len(folds) == 1 and folds[0] == "all":
            folds = [join(folder, "all")]
        else:
            folds = [join(folder, "fold_%d" % i) for i in folds]
        assert all([isdir(i) for i in folds]), "list of folds specified but not all output folders are present"
    elif isinstance(folds, int):
        folds = [join(folder, "fold_%d" % folds)]
        assert all([isdir(i) for i in folds]), "output folder missing for fold %d" % folds
    elif folds is None:
        print("folds is None so we will automatically look for output folders (not using \'all\'!)")
        folds = subfolders(folder, prefix="fold")
        print("found the following folds: ", folds)
    else:
        raise ValueError("Unknown value for folds. Type: %s. Expected: list of int, int, str or None", str(type(folds)))

    # restore the trainer once (from fold 0)...
    trainer = restore_model(join(folds[0], "%s.model.pkl" % checkpoint_name), fp16=mixed_precision, folder=folder)
    trainer.output_folder = folder
    trainer.output_folder_base = folder
    trainer.update_fold(0)
    trainer.initialize(False)
    # ...then preload every fold's weights on CPU for fast parameter switching
    checkpoint_paths = [join(fold_dir, "%s.model" % checkpoint_name) for fold_dir in folds]
    print("using the following model files: ", checkpoint_paths)
    all_params = [torch.load(p, map_location=torch.device('cpu')) for p in checkpoint_paths]
    return trainer, all_params
if __name__ == "__main__":
    # Example usage (Fabian's local paths): restore a trainer from its pkl metadata.
    pkl = "/home/fabian/PhD/results/nnFormerV2/nnFormerV2_3D_fullres/Task004_Hippocampus/fold0/model_best.model.pkl"
    checkpoint = pkl[:-4]  # strip ".pkl" -> path of the matching .model checkpoint
    train = False
    trainer = restore_model(pkl, checkpoint, train)
================================================
FILE: unetr_pp/training/network_training/Trainer_acdc.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from collections import OrderedDict
from multiprocessing import Pool
from time import sleep
from typing import Tuple, List
import matplotlib
import unetr_pp
import numpy as np
import torch
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.configuration import default_num_threads
from unetr_pp.evaluation.evaluator import aggregate_scores
from unetr_pp.inference.segmentation_export import save_segmentation_nifti_from_softmax
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.network_architecture.initialization import InitWeights_He
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.postprocessing.connected_components import determine_postprocessing
from unetr_pp.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \
default_2D_augmentation_params, get_default_augmentation, get_patch_size
from unetr_pp.training.dataloading.dataset_loading import load_dataset, DataLoader3D, DataLoader2D, unpack_dataset
from unetr_pp.training.loss_functions.dice_loss import DC_and_CE_loss
from unetr_pp.training.network_training.network_trainer_acdc import NetworkTrainer_acdc
from unetr_pp.utilities.nd_softmax import softmax_helper
from unetr_pp.utilities.tensor_utilities import sum_tensor
from torch import nn
from torch.optim import lr_scheduler
matplotlib.use("agg")
class Trainer_acdc(NetworkTrainer_acdc):
    """nnU-Net-style trainer for the ACDC task.

    Responsibilities: load the experiment plans file, configure data
    augmentation, build the Generic_UNet, set up optimizer/scheduler,
    run training via the parent class, and perform validation inference
    with optional postprocessing determination.
    """

    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, fp16=False):
        """
        :param deterministic:
        :param fold: can be either [0 ... 5) for cross-validation, 'all' to train on all available training data or
        None if you wish to load some checkpoint and do inference only
        :param plans_file: the pkl file generated by preprocessing. This file will determine all design choices
        :param subfolder_with_preprocessed_data: must be a subfolder of dataset_directory (just the name of the folder,
        not the entire path). This is where the preprocessed data lies that will be used for network training. We made
        this explicitly available so that differently preprocessed data can coexist and the user can choose what to use.
        Can be None if you are doing inference only.
        :param output_folder: where to store parameters, plot progress and to the validation
        :param dataset_directory: the parent directory in which the preprocessed Task data is stored. This is required
        because the split information is stored in this directory. For running prediction only this input is not
        required and may be set to None
        :param batch_dice: compute dice loss for each sample and average over all samples in the batch or pretend the
        batch is a pseudo volume?
        :param stage: The plans file may contain several stages (used for lowres / highres / pyramid). Stage must be
        specified for training:
        if stage 1 exists then stage 1 is the high resolution stage, otherwise it's 0
        :param unpack_data: if False, npz preprocessed data will not be unpacked to npy. This consumes less space but
        is considerably slower! Running unpack_data=False with 2d should never be done!
        IMPORTANT: If you inherit from nnFormerTrainer and the init args change then you need to redefine self.init_args
        in your init accordingly. Otherwise checkpoints won't load properly!
        """
        super(Trainer_acdc, self).__init__(deterministic, fp16)
        self.unpack_data = unpack_data
        # init_args is persisted alongside checkpoints so the trainer can be
        # reconstructed with identical arguments when restoring.
        self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                          deterministic, fp16)
        # set through arguments from init
        self.stage = stage
        self.experiment_name = self.__class__.__name__
        self.plans_file = plans_file
        self.output_folder = output_folder
        self.dataset_directory = dataset_directory
        self.output_folder_base = self.output_folder
        self.fold = fold

        self.plans = None

        # if we are running inference only then the self.dataset_directory is set (due to checkpoint loading) but it
        # irrelevant
        if self.dataset_directory is not None and isdir(self.dataset_directory):
            self.gt_niftis_folder = join(self.dataset_directory, "gt_segmentations")
        else:
            self.gt_niftis_folder = None

        self.folder_with_preprocessed_data = None

        # set in self.initialize()
        self.dl_tr = self.dl_val = None
        self.num_input_channels = self.num_classes = self.net_pool_per_axis = self.patch_size = self.batch_size = \
            self.threeD = self.base_num_features = self.intensity_properties = self.normalization_schemes = \
            self.net_num_pool_op_kernel_sizes = self.net_conv_kernel_sizes = None  # loaded automatically from plans_file
        self.basic_generator_patch_size = self.data_aug_params = self.transpose_forward = self.transpose_backward = None

        self.batch_dice = batch_dice
        self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {})

        self.online_eval_foreground_dc = []
        self.online_eval_tp = []
        self.online_eval_fp = []
        self.online_eval_fn = []

        self.classes = self.do_dummy_2D_aug = self.use_mask_for_norm = self.only_keep_largest_connected_component = \
            self.min_region_size_per_class = self.min_size_per_class = None

        self.inference_pad_border_mode = "constant"
        self.inference_pad_kwargs = {'constant_values': 0}

        self.update_fold(fold)
        self.pad_all_sides = None

        self.lr_scheduler_eps = 1e-3
        self.lr_scheduler_patience = 30
        self.initial_lr = 3e-4
        self.weight_decay = 3e-5

        self.oversample_foreground_percent = 0.33

        self.conv_per_stage = None
        self.regions_class_order = None

    def update_fold(self, fold):
        """
        used to swap between folds for inference (ensemble of models from cross-validation)
        DO NOT USE DURING TRAINING AS THIS WILL NOT UPDATE THE DATASET SPLIT AND THE DATA AUGMENTATION GENERATORS
        :param fold:
        :return:
        """
        if fold is not None:
            if isinstance(fold, str):
                assert fold == "all", "if self.fold is a string then it must be \'all\'"
                # if the output folder already ends in the old fold name, strip it
                # back to the base folder before appending the new one
                if self.output_folder.endswith("%s" % str(self.fold)):
                    self.output_folder = self.output_folder_base
                self.output_folder = join(self.output_folder, "%s" % str(fold))
            else:
                if self.output_folder.endswith("fold_%s" % str(self.fold)):
                    self.output_folder = self.output_folder_base
                self.output_folder = join(self.output_folder, "fold_%s" % str(fold))
            self.fold = fold

    def setup_DA_params(self):
        """Configure self.data_aug_params and the (possibly enlarged) patch size
        used by the generators so that rotations/scaling do not introduce border
        artifacts inside the final patch."""
        if self.threeD:
            self.data_aug_params = default_3D_augmentation_params
            if self.do_dummy_2D_aug:
                self.data_aug_params["dummy_2D"] = True
                self.print_to_log_file("Using dummy2d data augmentation")
                self.data_aug_params["elastic_deform_alpha"] = \
                    default_2D_augmentation_params["elastic_deform_alpha"]
                self.data_aug_params["elastic_deform_sigma"] = \
                    default_2D_augmentation_params["elastic_deform_sigma"]
                self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
        else:
            self.do_dummy_2D_aug = False
            # anisotropic 2D patches get a reduced rotation range
            if max(self.patch_size) / min(self.patch_size) > 1.5:
                default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
            self.data_aug_params = default_2D_augmentation_params
        self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm

        if self.do_dummy_2D_aug:
            # spatial transforms are applied in-plane only; the first (through-plane)
            # axis keeps the original patch extent
            self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
                                                             self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
            patch_size_for_spatialtransform = self.patch_size[1:]
        else:
            self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            patch_size_for_spatialtransform = self.patch_size

        self.data_aug_params['selected_seg_channels'] = [0]
        self.data_aug_params['patch_size_for_spatialtransform'] = patch_size_for_spatialtransform

    def initialize(self, training=True, force_load_plans=False):
        """
        For prediction of test cases just set training=False, this will prevent loading of training data and
        training batchgenerator initialization
        :param training:
        :return:
        """
        maybe_mkdir_p(self.output_folder)

        if force_load_plans or (self.plans is None):
            self.load_plans_file()

        self.process_plans(self.plans)

        self.setup_DA_params()

        if training:
            self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
                                                      "_stage%d" % self.stage)

            self.dl_tr, self.dl_val = self.get_basic_generators()
            if self.unpack_data:
                self.print_to_log_file("unpacking dataset")
                unpack_dataset(self.folder_with_preprocessed_data)
                self.print_to_log_file("done")
            else:
                self.print_to_log_file(
                    "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                    "will wait all winter for your model to finish!")
            self.tr_gen, self.val_gen = get_default_augmentation(self.dl_tr, self.dl_val,
                                                                 self.data_aug_params[
                                                                     'patch_size_for_spatialtransform'],
                                                                 self.data_aug_params)
            self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                   also_print_to_console=False)
            self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                   also_print_to_console=False)
        else:
            pass

        self.initialize_network()
        self.initialize_optimizer_and_scheduler()
        # assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
        self.was_initialized = True

    def initialize_network(self):
        """
        This is specific to the U-Net and must be adapted for other network architectures
        :return:
        """
        # self.print_to_log_file(self.net_num_pool_op_kernel_sizes)
        # self.print_to_log_file(self.net_conv_kernel_sizes)

        net_numpool = len(self.net_num_pool_op_kernel_sizes)

        # pick 2D or 3D building blocks depending on the patch dimensionality
        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.InstanceNorm3d
        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.InstanceNorm2d

        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes, net_numpool,
                                    self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
                                    dropout_op_kwargs,
                                    net_nonlin, net_nonlin_kwargs, False, False, lambda x: x, InitWeights_He(1e-2),
                                    self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
        self.network.inference_apply_nonlin = softmax_helper

        if torch.cuda.is_available():
            self.network.cuda()

    def initialize_optimizer_and_scheduler(self):
        """Adam with amsgrad plus an abs-threshold ReduceLROnPlateau scheduler."""
        assert self.network is not None, "self.initialize_network must be called first"
        self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
                                          amsgrad=True)
        self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2,
                                                           patience=self.lr_scheduler_patience,
                                                           verbose=True, threshold=self.lr_scheduler_eps,
                                                           threshold_mode="abs")

    def plot_network_architecture(self):
        """Best effort: render the network graph to a PDF via hiddenlayer; on any
        failure just log the network repr instead."""
        try:
            from batchgenerators.utilities.file_and_folder_operations import join
            import hiddenlayer as hl
            if torch.cuda.is_available():
                g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)).cuda(),
                                   transforms=None)
            else:
                g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)),
                                   transforms=None)
            g.save(join(self.output_folder, "network_architecture.pdf"))
            del g
        except Exception as e:
            self.print_to_log_file("Unable to plot network architecture:")
            self.print_to_log_file(e)

            self.print_to_log_file("\nprinting the network instead:\n")
            self.print_to_log_file(self.network)
            self.print_to_log_file("\n")
        finally:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def save_debug_information(self):
        """Dump all non-callable trainer attributes to debug.json (minus the bulky
        plans/dataset entries) and copy the plans file next to the output."""
        dct = OrderedDict()
        for k in self.__dir__():
            if not k.startswith("__"):
                if not callable(getattr(self, k)):
                    dct[k] = str(getattr(self, k))
        # these are large and already stored elsewhere
        del dct['plans']
        del dct['intensity_properties']
        del dct['dataset']
        del dct['dataset_tr']
        del dct['dataset_val']
        save_json(dct, join(self.output_folder, "debug.json"))

        # shutil is imported at module level
        shutil.copy(self.plans_file, join(self.output_folder_base, "plans.pkl"))

    def run_training(self):
        self.save_debug_information()
        super(Trainer_acdc, self).run_training()

    def load_plans_file(self):
        """
        This is what actually configures the entire experiment. The plans file is generated by experiment planning
        :return:
        """
        self.plans = load_pickle(self.plans_file)

    def process_plans(self, plans):
        """Unpack the plans dict into trainer attributes (patch/batch size, pooling
        and conv kernel sizes, normalization, class info, transpose axes, ...)."""
        if self.stage is None:
            assert len(list(plans['plans_per_stage'].keys())) == 1, \
                "If self.stage is None then there can be only one stage in the plans file. That seems to not be the " \
                "case. Please specify which stage of the cascade must be trained"
            self.stage = list(plans['plans_per_stage'].keys())[0]
        self.plans = plans

        stage_plans = self.plans['plans_per_stage'][self.stage]
        self.batch_size = stage_plans['batch_size']
        self.net_pool_per_axis = stage_plans['num_pool_per_axis']
        self.patch_size = np.array(stage_plans['patch_size']).astype(int)
        self.do_dummy_2D_aug = stage_plans['do_dummy_2D_data_aug']

        if 'pool_op_kernel_sizes' not in stage_plans.keys():
            assert 'num_pool_per_axis' in stage_plans.keys()
            self.print_to_log_file("WARNING! old plans file with missing pool_op_kernel_sizes. Attempting to fix it...")
            # reconstruct per-stage pooling kernels from the per-axis pool counts:
            # an axis pools (kernel 2) only in the last `num_pool` stages
            self.net_num_pool_op_kernel_sizes = []
            for i in range(max(self.net_pool_per_axis)):
                curr = []
                for j in self.net_pool_per_axis:
                    if (max(self.net_pool_per_axis) - j) <= i:
                        curr.append(2)
                    else:
                        curr.append(1)
                self.net_num_pool_op_kernel_sizes.append(curr)
        else:
            self.net_num_pool_op_kernel_sizes = stage_plans['pool_op_kernel_sizes']

        if 'conv_kernel_sizes' not in stage_plans.keys():
            self.print_to_log_file("WARNING! old plans file with missing conv_kernel_sizes. Attempting to fix it...")
            self.net_conv_kernel_sizes = [[3] * len(self.net_pool_per_axis)] * (max(self.net_pool_per_axis) + 1)
        else:
            self.net_conv_kernel_sizes = stage_plans['conv_kernel_sizes']

        self.pad_all_sides = None  # self.patch_size
        self.intensity_properties = plans['dataset_properties']['intensityproperties']
        self.normalization_schemes = plans['normalization_schemes']
        self.base_num_features = plans['base_num_features']
        self.num_input_channels = plans['num_modalities']
        self.num_classes = plans['num_classes'] + 1  # background is no longer in num_classes
        self.classes = plans['all_classes']
        self.use_mask_for_norm = plans['use_mask_for_norm']
        self.only_keep_largest_connected_component = plans['keep_only_largest_region']
        self.min_region_size_per_class = plans['min_region_size_per_class']
        self.min_size_per_class = None  # DONT USE THIS. plans['min_size_per_class']

        if plans.get('transpose_forward') is None or plans.get('transpose_backward') is None:
            print("WARNING! You seem to have data that was preprocessed with a previous version of nnU-Net. "
                  "You should rerun preprocessing. We will proceed and assume that both transpose_foward "
                  "and transpose_backward are [0, 1, 2]. If that is not correct then weird things will happen!")
            plans['transpose_forward'] = [0, 1, 2]
            plans['transpose_backward'] = [0, 1, 2]
        self.transpose_forward = plans['transpose_forward']
        self.transpose_backward = plans['transpose_backward']

        if len(self.patch_size) == 2:
            self.threeD = False
        elif len(self.patch_size) == 3:
            self.threeD = True
        else:
            raise RuntimeError("invalid patch size in plans file: %s" % str(self.patch_size))

        if "conv_per_stage" in plans.keys():  # this ha sbeen added to the plans only recently
            self.conv_per_stage = plans['conv_per_stage']
        else:
            self.conv_per_stage = 2

    def load_dataset(self):
        self.dataset = load_dataset(self.folder_with_preprocessed_data)

    def get_basic_generators(self):
        """Create the (non-augmenting) 2D/3D train and val dataloaders after
        loading the dataset and applying the fold split."""
        self.load_dataset()
        self.do_split()

        if self.threeD:
            # training uses the enlarged generator patch size so that subsequent
            # spatial augmentation can crop back to self.patch_size without padding
            dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
                                 False, oversample_foreground_percent=self.oversample_foreground_percent,
                                 pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
            dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, False,
                                  oversample_foreground_percent=self.oversample_foreground_percent,
                                  pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
        else:
            dl_tr = DataLoader2D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
                                 oversample_foreground_percent=self.oversample_foreground_percent,
                                 pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
            dl_val = DataLoader2D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size,
                                  oversample_foreground_percent=self.oversample_foreground_percent,
                                  pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
        return dl_tr, dl_val

    def preprocess_patient(self, input_files):
        """
        Used to predict new unseen data. Not used for the preprocessing of the training/test data
        :param input_files:
        :return:
        """
        from unetr_pp.training.model_restore import recursive_find_python_class
        preprocessor_name = self.plans.get('preprocessor_name')
        if preprocessor_name is None:
            if self.threeD:
                preprocessor_name = "GenericPreprocessor"
            else:
                preprocessor_name = "PreprocessorFor2D"

        print("using preprocessor", preprocessor_name)
        preprocessor_class = recursive_find_python_class([join(unetr_pp.__path__[0], "preprocessing")],
                                                         preprocessor_name,
                                                         current_module="unetr_pp.preprocessing")
        assert preprocessor_class is not None, "Could not find preprocessor %s in unetr_pp.preprocessing" % \
                                               preprocessor_name
        preprocessor = preprocessor_class(self.normalization_schemes, self.use_mask_for_norm,
                                          self.transpose_forward, self.intensity_properties)

        d, s, properties = preprocessor.preprocess_test_case(input_files,
                                                             self.plans['plans_per_stage'][self.stage][
                                                                 'current_spacing'])
        return d, s, properties

    def preprocess_predict_nifti(self, input_files: List[str], output_file: str = None,
                                 softmax_ouput_file: str = None, mixed_precision: bool = True) -> None:
        """
        Use this to predict new data
        :param input_files:
        :param output_file:
        :param softmax_ouput_file:
        :param mixed_precision:
        :return:
        """
        print("preprocessing...")
        d, s, properties = self.preprocess_patient(input_files)
        print("predicting...")
        pred = self.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=self.data_aug_params["do_mirror"],
                                                                     mirror_axes=self.data_aug_params['mirror_axes'],
                                                                     use_sliding_window=True, step_size=0.5,
                                                                     use_gaussian=True, pad_border_mode='constant',
                                                                     pad_kwargs={'constant_values': 0},
                                                                     verbose=True, all_in_gpu=False,
                                                                     mixed_precision=mixed_precision)[1]
        # undo the preprocessing axis transpose before export
        pred = pred.transpose([0] + [i + 1 for i in self.transpose_backward])

        if 'segmentation_export_params' in self.plans.keys():
            force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
            interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
            interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
        else:
            force_separate_z = None
            interpolation_order = 1
            interpolation_order_z = 0

        print("resampling to original spacing and nifti export...")
        save_segmentation_nifti_from_softmax(pred, output_file, properties, interpolation_order,
                                             self.regions_class_order, None, None, softmax_ouput_file,
                                             None, force_separate_z=force_separate_z,
                                             interpolation_order_z=interpolation_order_z)
        print("done")

    def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
                                                         mirror_axes: Tuple[int] = None,
                                                         use_sliding_window: bool = True, step_size: float = 0.5,
                                                         use_gaussian: bool = True, pad_border_mode: str = 'constant',
                                                         pad_kwargs: dict = None, all_in_gpu: bool = False,
                                                         verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """
        :param data:
        :param do_mirroring:
        :param mirror_axes:
        :param use_sliding_window:
        :param step_size:
        :param use_gaussian:
        :param pad_border_mode:
        :param pad_kwargs:
        :param all_in_gpu:
        :param verbose:
        :return:
        """
        if pad_border_mode == 'constant' and pad_kwargs is None:
            pad_kwargs = {'constant_values': 0}

        if do_mirroring and mirror_axes is None:
            mirror_axes = self.data_aug_params['mirror_axes']

        if do_mirroring:
            assert self.data_aug_params["do_mirror"], "Cannot do mirroring as test time augmentation when training " \
                                                      "was done without mirroring"

        assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))

        # temporarily switch to eval mode, restore the caller's mode afterwards
        current_mode = self.network.training
        self.network.eval()
        ret = self.network.predict_3D(data, do_mirroring=do_mirroring, mirror_axes=mirror_axes,
                                      use_sliding_window=use_sliding_window, step_size=step_size,
                                      patch_size=self.patch_size, regions_class_order=self.regions_class_order,
                                      use_gaussian=use_gaussian, pad_border_mode=pad_border_mode,
                                      pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, verbose=verbose,
                                      mixed_precision=mixed_precision)
        self.network.train(current_mode)
        return ret

    def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
                 save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
                 segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
        """
        if debug=True then the temporary files generated for postprocessing determination will be kept
        """
        current_mode = self.network.training
        self.network.eval()

        assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"
        if self.dataset_val is None:
            self.load_dataset()
            self.do_split()

        if segmentation_export_kwargs is None:
            if 'segmentation_export_params' in self.plans.keys():
                force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
                interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
                interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
            else:
                force_separate_z = None
                interpolation_order = 1
                interpolation_order_z = 0
        else:
            force_separate_z = segmentation_export_kwargs['force_separate_z']
            interpolation_order = segmentation_export_kwargs['interpolation_order']
            interpolation_order_z = segmentation_export_kwargs['interpolation_order_z']

        # predictions as they come from the network go here
        output_folder = join(self.output_folder, validation_folder_name)
        maybe_mkdir_p(output_folder)
        # this is for debug purposes
        my_input_args = {'do_mirroring': do_mirroring,
                         'use_sliding_window': use_sliding_window,
                         'step_size': step_size,
                         'save_softmax': save_softmax,
                         'use_gaussian': use_gaussian,
                         'overwrite': overwrite,
                         'validation_folder_name': validation_folder_name,
                         'debug': debug,
                         'all_in_gpu': all_in_gpu,
                         'segmentation_export_kwargs': segmentation_export_kwargs,
                         }
        save_json(my_input_args, join(output_folder, "validation_args.json"))

        if do_mirroring:
            if not self.data_aug_params['do_mirror']:
                raise RuntimeError("We did not train with mirroring so you cannot do inference with mirroring enabled")
            mirror_axes = self.data_aug_params['mirror_axes']
        else:
            mirror_axes = ()

        pred_gt_tuples = []

        export_pool = Pool(default_num_threads)
        results = []

        for k in self.dataset_val.keys():
            properties = load_pickle(self.dataset[k]['properties_file'])
            fname = properties['list_of_data_files'][0].split("/")[-1][:-12]
            if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \
                    (save_softmax and not isfile(join(output_folder, fname + ".npz"))):
                data = np.load(self.dataset[k]['data_file'])['data']

                print(k, data.shape)
                # last channel is the segmentation; map the "outside image" label (-1) to background
                data[-1][data[-1] == -1] = 0

                softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data[:-1],
                                                                                     do_mirroring=do_mirroring,
                                                                                     mirror_axes=mirror_axes,
                                                                                     use_sliding_window=use_sliding_window,
                                                                                     step_size=step_size,
                                                                                     use_gaussian=use_gaussian,
                                                                                     all_in_gpu=all_in_gpu,
                                                                                     mixed_precision=self.fp16)[1]

                softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in self.transpose_backward])

                if save_softmax:
                    softmax_fname = join(output_folder, fname + ".npz")
                else:
                    softmax_fname = None

                """There is a problem with python process communication that prevents us from communicating obejcts
                larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
                communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long
                enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
                patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
                then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
                filename or np.ndarray and will handle this automatically"""
                if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85):  # *0.85 just to be save
                    np.save(join(output_folder, fname + ".npy"), softmax_pred)
                    softmax_pred = join(output_folder, fname + ".npy")

                results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax,
                                                         ((softmax_pred, join(output_folder, fname + ".nii.gz"),
                                                           properties, interpolation_order, self.regions_class_order,
                                                           None, None,
                                                           softmax_fname, None, force_separate_z,
                                                           interpolation_order_z),
                                                          )
                                                         )
                               )

            pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"),
                                   join(self.gt_niftis_folder, fname + ".nii.gz")])

        _ = [i.get() for i in results]
        self.print_to_log_file("finished prediction")

        # evaluate raw predictions
        self.print_to_log_file("evaluation of raw predictions")
        task = self.dataset_directory.split("/")[-1]
        job_name = self.experiment_name
        _ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)),
                             json_output_file=join(output_folder, "summary.json"),
                             json_name=job_name + " val tiled %s" % (str(use_sliding_window)),
                             json_author="Fabian",
                             json_task=task, num_threads=default_num_threads)

        if run_postprocessing_on_folds:
            # in the old unetr_pp we would stop here. Now we add a postprocessing. This postprocessing can remove everything
            # except the largest connected component for each class. To see if this improves results, we do this for all
            # classes and then rerun the evaluation. Those classes for which this resulted in an improved dice score will
            # have this applied during inference as well
            self.print_to_log_file("determining postprocessing")
            determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name,
                                     final_subf_name=validation_folder_name + "_postprocessed", debug=debug)
            # after this the final predictions for the vlaidation set can be found in validation_folder_name_base + "_postprocessed"
            # They are always in that folder, even if no postprocessing as applied!

        # detemining postprocesing on a per-fold basis may be OK for this fold but what if another fold finds another
        # postprocesing to be better? In this case we need to consolidate. At the time the consolidation is going to be
        # done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a separate folder to
        # be used later
        gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
        maybe_mkdir_p(gt_nifti_folder)
        for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
            success = False
            attempts = 0
            last_error = None
            while not success and attempts < 10:
                try:
                    shutil.copy(f, gt_nifti_folder)
                    success = True
                except OSError as copy_error:
                    # BUGFIX: 'except ... as e' unbinds the name when the clause
                    # exits (Python 3), so the error must be stored in a separate
                    # variable to be re-raisable after the retry loop
                    last_error = copy_error
                    attempts += 1
                    sleep(1)
            if not success:
                print("Could not copy gt nifti file %s into folder %s" % (f, gt_nifti_folder))
                if last_error is not None:
                    raise last_error

        self.network.train(current_mode)

    def run_online_evaluation(self, output, target):
        """Accumulate per-class hard TP/FP/FN and a foreground pseudo-Dice for the
        current batch (used for the moving validation metric during training)."""
        with torch.no_grad():
            num_classes = output.shape[1]
            output_softmax = softmax_helper(output)
            output_seg = output_softmax.argmax(1)
            target = target[:, 0]
            axes = tuple(range(1, len(target.shape)))
            tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
            fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
            fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
            # class 0 is background and excluded from the online metric
            for c in range(1, num_classes):
                tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes)
                fp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes)
                fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes)

            tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy()
            fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy()
            fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy()

            self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
            self.online_eval_tp.append(list(tp_hard))
            self.online_eval_fp.append(list(fp_hard))
            self.online_eval_fn.append(list(fn_hard))

    def finish_online_evaluation(self):
        """Reduce the accumulated TP/FP/FN into a global per-class Dice estimate for
        the epoch, log it, and reset the accumulators."""
        self.online_eval_tp = np.sum(self.online_eval_tp, 0)
        self.online_eval_fp = np.sum(self.online_eval_fp, 0)
        self.online_eval_fn = np.sum(self.online_eval_fn, 0)

        # NaNs arise for classes absent from both prediction and target; drop them
        global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in
                                           zip(self.online_eval_tp, self.online_eval_fp, self.online_eval_fn)]
                               if not np.isnan(i)]
        self.all_val_eval_metrics.append(np.mean(global_dc_per_class))

        self.print_to_log_file("Average global foreground Dice:", str(global_dc_per_class))
        self.print_to_log_file("(interpret this as an estimate for the Dice of the different classes. This is not "
                               "exact.)")

        self.online_eval_foreground_dc = []
        self.online_eval_tp = []
        self.online_eval_fp = []
        self.online_eval_fn = []

    def save_checkpoint(self, fname, save_optimizer=True):
        """Save the checkpoint via the parent class, then persist trainer metadata
        (init args, class name, plans) next to it as <fname>.pkl."""
        super(Trainer_acdc, self).save_checkpoint(fname, save_optimizer)
        info = OrderedDict()
        info['init'] = self.init_args
        info['name'] = self.__class__.__name__
        info['class'] = str(self.__class__)
        info['plans'] = self.plans
        write_pickle(info, fname + ".pkl")
================================================
FILE: unetr_pp/training/network_training/Trainer_lung.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from collections import OrderedDict
from multiprocessing import Pool
from time import sleep
from typing import Tuple, List
import matplotlib
import unetr_pp
import numpy as np
import torch
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.configuration import default_num_threads
from unetr_pp.evaluation.evaluator import aggregate_scores
from unetr_pp.inference.segmentation_export import save_segmentation_nifti_from_softmax
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.network_architecture.initialization import InitWeights_He
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.postprocessing.connected_components import determine_postprocessing
from unetr_pp.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \
default_2D_augmentation_params, get_default_augmentation, get_patch_size
from unetr_pp.training.dataloading.dataset_loading import load_dataset, DataLoader3D, DataLoader2D, unpack_dataset
from unetr_pp.training.loss_functions.dice_loss import DC_and_CE_loss
from unetr_pp.training.network_training.network_trainer_lung import NetworkTrainer_lung
from unetr_pp.utilities.nd_softmax import softmax_helper
from unetr_pp.utilities.tensor_utilities import sum_tensor
from torch import nn
from torch.optim import lr_scheduler
matplotlib.use("agg")
class Trainer_lung(NetworkTrainer_lung):
    """nnU-Net style trainer for the Lung task.

    Wires together plans-file parsing, data augmentation setup, the
    Generic_UNet architecture and an Adam + ReduceLROnPlateau optimization
    scheme on top of the generic training loop in NetworkTrainer_lung.
    """

    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, fp16=False):
        """
        :param deterministic:
        :param fold: can be either [0 ... 5) for cross-validation, 'all' to train on all available training data or
        None if you wish to load some checkpoint and do inference only
        :param plans_file: the pkl file generated by preprocessing. This file will determine all design choices
        :param subfolder_with_preprocessed_data: must be a subfolder of dataset_directory (just the name of the folder,
        not the entire path). This is where the preprocessed data lies that will be used for network training. We made
        this explicitly available so that differently preprocessed data can coexist and the user can choose what to use.
        Can be None if you are doing inference only.
        :param output_folder: where to store parameters, plot progress and to the validation
        :param dataset_directory: the parent directory in which the preprocessed Task data is stored. This is required
        because the split information is stored in this directory. For running prediction only this input is not
        required and may be set to None
        :param batch_dice: compute dice loss for each sample and average over all samples in the batch or pretend the
        batch is a pseudo volume?
        :param stage: The plans file may contain several stages (used for lowres / highres / pyramid). Stage must be
        specified for training:
        if stage 1 exists then stage 1 is the high resolution stage, otherwise it's 0
        :param unpack_data: if False, npz preprocessed data will not be unpacked to npy. This consumes less space but
        is considerably slower! Running unpack_data=False with 2d should never be done!

        IMPORTANT: If you inherit from nnFormerTrainer and the init args change then you need to redefine self.init_args
        in your init accordingly. Otherwise checkpoints won't load properly!
        """
        super(Trainer_lung, self).__init__(deterministic, fp16)
        self.unpack_data = unpack_data
        # init_args mirrors the constructor signature exactly; save_checkpoint
        # pickles it next to every checkpoint so the trainer can be rebuilt on restore.
        self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                          deterministic, fp16)
        # set through arguments from init
        self.stage = stage
        self.experiment_name = self.__class__.__name__
        self.plans_file = plans_file
        self.output_folder = output_folder
        self.dataset_directory = dataset_directory
        # output_folder_base must be captured before update_fold() below appends
        # the fold subdirectory to output_folder
        self.output_folder_base = self.output_folder
        self.fold = fold

        self.plans = None

        # if we are running inference only then the self.dataset_directory is set (due to checkpoint loading) but it
        # irrelevant
        if self.dataset_directory is not None and isdir(self.dataset_directory):
            self.gt_niftis_folder = join(self.dataset_directory, "gt_segmentations")
        else:
            self.gt_niftis_folder = None

        self.folder_with_preprocessed_data = None

        # set in self.initialize()

        self.dl_tr = self.dl_val = None
        self.num_input_channels = self.num_classes = self.net_pool_per_axis = self.patch_size = self.batch_size = \
            self.threeD = self.base_num_features = self.intensity_properties = self.normalization_schemes = \
            self.net_num_pool_op_kernel_sizes = self.net_conv_kernel_sizes = None  # loaded automatically from plans_file
        self.basic_generator_patch_size = self.data_aug_params = self.transpose_forward = self.transpose_backward = None

        self.batch_dice = batch_dice
        self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {})

        # running buffers filled by run_online_evaluation(), reset by finish_online_evaluation()
        self.online_eval_foreground_dc = []
        self.online_eval_tp = []
        self.online_eval_fp = []
        self.online_eval_fn = []

        self.classes = self.do_dummy_2D_aug = self.use_mask_for_norm = self.only_keep_largest_connected_component = \
            self.min_region_size_per_class = self.min_size_per_class = None

        self.inference_pad_border_mode = "constant"
        self.inference_pad_kwargs = {'constant_values': 0}

        self.update_fold(fold)
        self.pad_all_sides = None

        # optimizer / scheduler hyperparameters used by initialize_optimizer_and_scheduler()
        self.lr_scheduler_eps = 1e-3
        self.lr_scheduler_patience = 30
        self.initial_lr = 3e-4
        self.weight_decay = 3e-5

        self.oversample_foreground_percent = 0.33

        self.conv_per_stage = None
        self.regions_class_order = None
def update_fold(self, fold):
"""
used to swap between folds for inference (ensemble of models from cross-validation)
DO NOT USE DURING TRAINING AS THIS WILL NOT UPDATE THE DATASET SPLIT AND THE DATA AUGMENTATION GENERATORS
:param fold:
:return:
"""
if fold is not None:
if isinstance(fold, str):
assert fold == "all", "if self.fold is a string then it must be \'all\'"
if self.output_folder.endswith("%s" % str(self.fold)):
self.output_folder = self.output_folder_base
self.output_folder = join(self.output_folder, "%s" % str(fold))
else:
if self.output_folder.endswith("fold_%s" % str(self.fold)):
self.output_folder = self.output_folder_base
self.output_folder = join(self.output_folder, "fold_%s" % str(fold))
self.fold = fold
    def setup_DA_params(self):
        """Populate self.data_aug_params and self.basic_generator_patch_size.

        Selects the 3D or 2D default augmentation parameter set depending on
        self.threeD, optionally enables "dummy 2D" augmentation (in-plane-only
        transforms for very anisotropic 3D data), and enlarges the generator
        patch size so that rotated/scaled patches can be cropped back to
        self.patch_size without border artifacts.
        """
        if self.threeD:
            self.data_aug_params = default_3D_augmentation_params
            if self.do_dummy_2D_aug:
                self.data_aug_params["dummy_2D"] = True
                self.print_to_log_file("Using dummy2d data augmentation")
                # dummy 2D borrows the in-plane deformation/rotation settings from the 2D defaults
                self.data_aug_params["elastic_deform_alpha"] = \
                    default_2D_augmentation_params["elastic_deform_alpha"]
                self.data_aug_params["elastic_deform_sigma"] = \
                    default_2D_augmentation_params["elastic_deform_sigma"]
                self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
        else:
            self.do_dummy_2D_aug = False
            if max(self.patch_size) / min(self.patch_size) > 1.5:
                # NOTE: this mutates the module-level default dict, limiting rotation
                # to +-15 degrees for strongly anisotropic 2D patches
                default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
            self.data_aug_params = default_2D_augmentation_params
        self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm

        if self.do_dummy_2D_aug:
            # augment in-plane only: the first (through-plane) axis keeps its original size
            self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
                                                             self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
            patch_size_for_spatialtransform = self.patch_size[1:]
        else:
            self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            patch_size_for_spatialtransform = self.patch_size

        self.data_aug_params['selected_seg_channels'] = [0]
        self.data_aug_params['patch_size_for_spatialtransform'] = patch_size_for_spatialtransform
    def initialize(self, training=True, force_load_plans=False):
        """
        For prediction of test cases just set training=False, this will prevent loading of training data and
        training batchgenerator initialization

        :param training: when True, build the data loaders, unpack the dataset and
            create the augmented batch generators; when False only the network and
            optimizer are created
        :param force_load_plans: reload the plans file even if self.plans is already set
        :return:
        """
        maybe_mkdir_p(self.output_folder)

        if force_load_plans or (self.plans is None):
            self.load_plans_file()

        self.process_plans(self.plans)

        self.setup_DA_params()

        if training:
            self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
                                                      "_stage%d" % self.stage)

            self.dl_tr, self.dl_val = self.get_basic_generators()
            if self.unpack_data:
                # convert npz archives to plain npy once so training reads are fast
                self.print_to_log_file("unpacking dataset")
                unpack_dataset(self.folder_with_preprocessed_data)
                self.print_to_log_file("done")
            else:
                self.print_to_log_file(
                    "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                    "will wait all winter for your model to finish!")
            # wrap the raw loaders with the augmentation pipeline
            self.tr_gen, self.val_gen = get_default_augmentation(self.dl_tr, self.dl_val,
                                                                 self.data_aug_params[
                                                                     'patch_size_for_spatialtransform'],
                                                                 self.data_aug_params)
            self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                   also_print_to_console=False)
            self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                   also_print_to_console=False)
        else:
            pass
        self.initialize_network()
        self.initialize_optimizer_and_scheduler()
        # assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
        self.was_initialized = True
def initialize_network(self):
"""
This is specific to the U-Net and must be adapted for other network architectures
:return:
"""
# self.print_to_log_file(self.net_num_pool_op_kernel_sizes)
# self.print_to_log_file(self.net_conv_kernel_sizes)
net_numpool = len(self.net_num_pool_op_kernel_sizes)
if self.threeD:
conv_op = nn.Conv3d
dropout_op = nn.Dropout3d
norm_op = nn.InstanceNorm3d
else:
conv_op = nn.Conv2d
dropout_op = nn.Dropout2d
norm_op = nn.InstanceNorm2d
norm_op_kwargs = {'eps': 1e-5, 'affine': True}
dropout_op_kwargs = {'p': 0, 'inplace': True}
net_nonlin = nn.LeakyReLU
net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes, net_numpool,
self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
dropout_op_kwargs,
net_nonlin, net_nonlin_kwargs, False, False, lambda x: x, InitWeights_He(1e-2),
self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
self.network.inference_apply_nonlin = softmax_helper
if torch.cuda.is_available():
self.network.cuda()
def initialize_optimizer_and_scheduler(self):
assert self.network is not None, "self.initialize_network must be called first"
self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
amsgrad=True)
self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2,
patience=self.lr_scheduler_patience,
verbose=True, threshold=self.lr_scheduler_eps,
threshold_mode="abs")
def plot_network_architecture(self):
try:
from batchgenerators.utilities.file_and_folder_operations import join
import hiddenlayer as hl
if torch.cuda.is_available():
g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)).cuda(),
transforms=None)
else:
g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)),
transforms=None)
g.save(join(self.output_folder, "network_architecture.pdf"))
del g
except Exception as e:
self.print_to_log_file("Unable to plot network architecture:")
self.print_to_log_file(e)
self.print_to_log_file("\nprinting the network instead:\n")
self.print_to_log_file(self.network)
self.print_to_log_file("\n")
finally:
if torch.cuda.is_available():
torch.cuda.empty_cache()
def save_debug_information(self):
# saving some debug information
dct = OrderedDict()
for k in self.__dir__():
if not k.startswith("__"):
if not callable(getattr(self, k)):
dct[k] = str(getattr(self, k))
del dct['plans']
del dct['intensity_properties']
del dct['dataset']
del dct['dataset_tr']
del dct['dataset_val']
save_json(dct, join(self.output_folder, "debug.json"))
import shutil
shutil.copy(self.plans_file, join(self.output_folder_base, "plans.pkl"))
    def run_training(self):
        """Persist debug information about this trainer configuration, then
        delegate to the generic training loop of NetworkTrainer_lung."""
        self.save_debug_information()
        super(Trainer_lung, self).run_training()
    def load_plans_file(self):
        """
        This is what actually configures the entire experiment. The plans file is generated by experiment planning
        :return:
        """
        # plans_file is a pickle written by the experiment planner; parsed by process_plans()
        self.plans = load_pickle(self.plans_file)
    def process_plans(self, plans):
        """Unpack the plans dictionary into trainer attributes.

        Resolves the stage to train, reads batch/patch size and pooling
        configuration (reconstructing missing entries for old plans files),
        normalization/intensity settings and the transpose axes, and derives
        self.threeD from the patch size dimensionality.

        :param plans: dict as produced by experiment planning (loaded pickle)
        :raises RuntimeError: if the patch size is neither 2D nor 3D
        """
        if self.stage is None:
            assert len(list(plans['plans_per_stage'].keys())) == 1, \
                "If self.stage is None then there can be only one stage in the plans file. That seems to not be the " \
                "case. Please specify which stage of the cascade must be trained"
            self.stage = list(plans['plans_per_stage'].keys())[0]
        self.plans = plans

        stage_plans = self.plans['plans_per_stage'][self.stage]
        self.batch_size = stage_plans['batch_size']
        self.net_pool_per_axis = stage_plans['num_pool_per_axis']
        self.patch_size = np.array(stage_plans['patch_size']).astype(int)
        self.do_dummy_2D_aug = stage_plans['do_dummy_2D_data_aug']

        if 'pool_op_kernel_sizes' not in stage_plans.keys():
            assert 'num_pool_per_axis' in stage_plans.keys()
            self.print_to_log_file("WARNING! old plans file with missing pool_op_kernel_sizes. Attempting to fix it...")
            # reconstruct per-stage pooling: axes with fewer pool ops only start
            # pooling (kernel 2) in the later stages
            self.net_num_pool_op_kernel_sizes = []
            for i in range(max(self.net_pool_per_axis)):
                curr = []
                for j in self.net_pool_per_axis:
                    if (max(self.net_pool_per_axis) - j) <= i:
                        curr.append(2)
                    else:
                        curr.append(1)
                self.net_num_pool_op_kernel_sizes.append(curr)
        else:
            self.net_num_pool_op_kernel_sizes = stage_plans['pool_op_kernel_sizes']

        if 'conv_kernel_sizes' not in stage_plans.keys():
            self.print_to_log_file("WARNING! old plans file with missing conv_kernel_sizes. Attempting to fix it...")
            # fall back to 3x3(x3) kernels for every resolution stage
            self.net_conv_kernel_sizes = [[3] * len(self.net_pool_per_axis)] * (max(self.net_pool_per_axis) + 1)
        else:
            self.net_conv_kernel_sizes = stage_plans['conv_kernel_sizes']

        self.pad_all_sides = None  # self.patch_size
        self.intensity_properties = plans['dataset_properties']['intensityproperties']
        self.normalization_schemes = plans['normalization_schemes']
        self.base_num_features = plans['base_num_features']
        self.num_input_channels = plans['num_modalities']
        self.num_classes = plans['num_classes'] + 1  # background is no longer in num_classes
        self.classes = plans['all_classes']
        self.use_mask_for_norm = plans['use_mask_for_norm']
        self.only_keep_largest_connected_component = plans['keep_only_largest_region']
        self.min_region_size_per_class = plans['min_region_size_per_class']
        self.min_size_per_class = None  # DONT USE THIS. plans['min_size_per_class']

        if plans.get('transpose_forward') is None or plans.get('transpose_backward') is None:
            print("WARNING! You seem to have data that was preprocessed with a previous version of nnU-Net. "
                  "You should rerun preprocessing. We will proceed and assume that both transpose_foward "
                  "and transpose_backward are [0, 1, 2]. If that is not correct then weird things will happen!")
            plans['transpose_forward'] = [0, 1, 2]
            plans['transpose_backward'] = [0, 1, 2]
        self.transpose_forward = plans['transpose_forward']
        self.transpose_backward = plans['transpose_backward']

        if len(self.patch_size) == 2:
            self.threeD = False
        elif len(self.patch_size) == 3:
            self.threeD = True
        else:
            raise RuntimeError("invalid patch size in plans file: %s" % str(self.patch_size))

        if "conv_per_stage" in plans.keys():  # this has been added to the plans only recently
            self.conv_per_stage = plans['conv_per_stage']
        else:
            self.conv_per_stage = 2
    def load_dataset(self):
        """Load the case dictionary for the preprocessed data folder into self.dataset."""
        self.dataset = load_dataset(self.folder_with_preprocessed_data)
    def get_basic_generators(self):
        """Create the raw (un-augmented) training and validation data loaders.

        Loads the dataset, applies the train/val split for the current fold
        and builds 3D or 2D loaders depending on the plan. The training loader
        samples the (larger) augmentation patch size; the validation loader
        samples the network patch size directly.

        :return: tuple (dl_tr, dl_val)
        """
        self.load_dataset()
        self.do_split()

        if self.threeD:
            dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
                                 False, oversample_foreground_percent=self.oversample_foreground_percent,
                                 pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
            dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, False,
                                  oversample_foreground_percent=self.oversample_foreground_percent,
                                  pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
        else:
            dl_tr = DataLoader2D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
                                 oversample_foreground_percent=self.oversample_foreground_percent,
                                 pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
            dl_val = DataLoader2D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size,
                                  oversample_foreground_percent=self.oversample_foreground_percent,
                                  pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
        return dl_tr, dl_val
    def preprocess_patient(self, input_files):
        """
        Used to predict new unseen data. Not used for the preprocessing of the training/test data
        :param input_files: list of nifti files (one per modality) for a single case
        :return: (data, seg, properties) as returned by the preprocessor's preprocess_test_case
        """
        from unetr_pp.training.model_restore import recursive_find_python_class
        preprocessor_name = self.plans.get('preprocessor_name')
        if preprocessor_name is None:
            # older plans files do not store the preprocessor name; fall back by dimensionality
            if self.threeD:
                preprocessor_name = "GenericPreprocessor"
            else:
                preprocessor_name = "PreprocessorFor2D"

        print("using preprocessor", preprocessor_name)
        # look up the preprocessor class by name inside the unetr_pp.preprocessing package
        preprocessor_class = recursive_find_python_class([join(unetr_pp.__path__[0], "preprocessing")],
                                                         preprocessor_name,
                                                         current_module="unetr_pp.preprocessing")
        assert preprocessor_class is not None, "Could not find preprocessor %s in unetr_pp.preprocessing" % \
                                               preprocessor_name
        preprocessor = preprocessor_class(self.normalization_schemes, self.use_mask_for_norm,
                                          self.transpose_forward, self.intensity_properties)

        d, s, properties = preprocessor.preprocess_test_case(input_files,
                                                             self.plans['plans_per_stage'][self.stage][
                                                                 'current_spacing'])
        return d, s, properties
    def preprocess_predict_nifti(self, input_files: List[str], output_file: str = None,
                                 softmax_ouput_file: str = None, mixed_precision: bool = True) -> None:
        """
        Use this to predict new data: preprocess, run inference and export the
        resampled segmentation as nifti.

        :param input_files: raw nifti files for one case (one per modality)
        :param output_file: path for the exported segmentation nifti
        :param softmax_ouput_file: optional path to also store the softmax output
            (NOTE: the "ouput" typo is part of the public signature and must stay)
        :param mixed_precision: run the forward passes with autocast
        :return:
        """
        print("preprocessing...")
        d, s, properties = self.preprocess_patient(input_files)
        print("predicting...")
        # [1] selects the softmax output of (segmentation, softmax)
        pred = self.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=self.data_aug_params["do_mirror"],
                                                                     mirror_axes=self.data_aug_params['mirror_axes'],
                                                                     use_sliding_window=True, step_size=0.5,
                                                                     use_gaussian=True, pad_border_mode='constant',
                                                                     pad_kwargs={'constant_values': 0},
                                                                     verbose=True, all_in_gpu=False,
                                                                     mixed_precision=mixed_precision)[1]
        # undo the axis transpose that preprocessing applied (channel axis stays first)
        pred = pred.transpose([0] + [i + 1 for i in self.transpose_backward])

        if 'segmentation_export_params' in self.plans.keys():
            force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
            interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
            interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
        else:
            force_separate_z = None
            interpolation_order = 1
            interpolation_order_z = 0

        print("resampling to original spacing and nifti export...")
        save_segmentation_nifti_from_softmax(pred, output_file, properties, interpolation_order,
                                             self.regions_class_order, None, None, softmax_ouput_file,
                                             None, force_separate_z=force_separate_z,
                                             interpolation_order_z=interpolation_order_z)
        print("done")
def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
mirror_axes: Tuple[int] = None,
use_sliding_window: bool = True, step_size: float = 0.5,
use_gaussian: bool = True, pad_border_mode: str = 'constant',
pad_kwargs: dict = None, all_in_gpu: bool = False,
verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
"""
:param data:
:param do_mirroring:
:param mirror_axes:
:param use_sliding_window:
:param step_size:
:param use_gaussian:
:param pad_border_mode:
:param pad_kwargs:
:param all_in_gpu:
:param verbose:
:return:
"""
if pad_border_mode == 'constant' and pad_kwargs is None:
pad_kwargs = {'constant_values': 0}
if do_mirroring and mirror_axes is None:
mirror_axes = self.data_aug_params['mirror_axes']
if do_mirroring:
assert self.data_aug_params["do_mirror"], "Cannot do mirroring as test time augmentation when training " \
"was done without mirroring"
valid = list((SegmentationNetwork, nn.DataParallel))
assert isinstance(self.network, tuple(valid))
current_mode = self.network.training
self.network.eval()
ret = self.network.predict_3D(data, do_mirroring=do_mirroring, mirror_axes=mirror_axes,
use_sliding_window=use_sliding_window, step_size=step_size,
patch_size=self.patch_size, regions_class_order=self.regions_class_order,
use_gaussian=use_gaussian, pad_border_mode=pad_border_mode,
pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, verbose=verbose,
mixed_precision=mixed_precision)
self.network.train(current_mode)
return ret
def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
"""
if debug=True then the temporary files generated for postprocessing determination will be kept
"""
current_mode = self.network.training
self.network.eval()
assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"
if self.dataset_val is None:
self.load_dataset()
self.do_split()
if segmentation_export_kwargs is None:
if 'segmentation_export_params' in self.plans.keys():
force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
else:
force_separate_z = None
interpolation_order = 1
interpolation_order_z = 0
else:
force_separate_z = segmentation_export_kwargs['force_separate_z']
interpolation_order = segmentation_export_kwargs['interpolation_order']
interpolation_order_z = segmentation_export_kwargs['interpolation_order_z']
# predictions as they come from the network go here
output_folder = join(self.output_folder, validation_folder_name)
maybe_mkdir_p(output_folder)
# this is for debug purposes
my_input_args = {'do_mirroring': do_mirroring,
'use_sliding_window': use_sliding_window,
'step_size': step_size,
'save_softmax': save_softmax,
'use_gaussian': use_gaussian,
'overwrite': overwrite,
'validation_folder_name': validation_folder_name,
'debug': debug,
'all_in_gpu': all_in_gpu,
'segmentation_export_kwargs': segmentation_export_kwargs,
}
save_json(my_input_args, join(output_folder, "validation_args.json"))
if do_mirroring:
if not self.data_aug_params['do_mirror']:
raise RuntimeError("We did not train with mirroring so you cannot do inference with mirroring enabled")
mirror_axes = self.data_aug_params['mirror_axes']
else:
mirror_axes = ()
pred_gt_tuples = []
export_pool = Pool(default_num_threads)
results = []
for k in self.dataset_val.keys():
properties = load_pickle(self.dataset[k]['properties_file'])
fname = properties['list_of_data_files'][0].split("/")[-1][:-12]
if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \
(save_softmax and not isfile(join(output_folder, fname + ".npz"))):
data = np.load(self.dataset[k]['data_file'])['data']
print(k, data.shape)
data[-1][data[-1] == -1] = 0
softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data[:-1],
do_mirroring=do_mirroring,
mirror_axes=mirror_axes,
use_sliding_window=use_sliding_window,
step_size=step_size,
use_gaussian=use_gaussian,
all_in_gpu=all_in_gpu,
mixed_precision=self.fp16)[1]
softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in self.transpose_backward])
if save_softmax:
softmax_fname = join(output_folder, fname + ".npz")
else:
softmax_fname = None
"""There is a problem with python process communication that prevents us from communicating obejcts
larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long
enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
filename or np.ndarray and will handle this automatically"""
if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85): # *0.85 just to be save
np.save(join(output_folder, fname + ".npy"), softmax_pred)
softmax_pred = join(output_folder, fname + ".npy")
results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax,
((softmax_pred, join(output_folder, fname + ".nii.gz"),
properties, interpolation_order, self.regions_class_order,
None, None,
softmax_fname, None, force_separate_z,
interpolation_order_z),
)
)
)
pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"),
join(self.gt_niftis_folder, fname + ".nii.gz")])
_ = [i.get() for i in results]
self.print_to_log_file("finished prediction")
# evaluate raw predictions
self.print_to_log_file("evaluation of raw predictions")
task = self.dataset_directory.split("/")[-1]
job_name = self.experiment_name
_ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)),
json_output_file=join(output_folder, "summary.json"),
json_name=job_name + " val tiled %s" % (str(use_sliding_window)),
json_author="Fabian",
json_task=task, num_threads=default_num_threads)
if run_postprocessing_on_folds:
# in the old unetr_pp we would stop here. Now we add a postprocessing. This postprocessing can remove everything
# except the largest connected component for each class. To see if this improves results, we do this for all
# classes and then rerun the evaluation. Those classes for which this resulted in an improved dice score will
# have this applied during inference as well
self.print_to_log_file("determining postprocessing")
determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name,
final_subf_name=validation_folder_name + "_postprocessed", debug=debug)
# after this the final predictions for the vlaidation set can be found in validation_folder_name_base + "_postprocessed"
# They are always in that folder, even if no postprocessing as applied!
# detemining postprocesing on a per-fold basis may be OK for this fold but what if another fold finds another
# postprocesing to be better? In this case we need to consolidate. At the time the consolidation is going to be
# done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a separate folder to
# be used later
gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
maybe_mkdir_p(gt_nifti_folder)
for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
success = False
attempts = 0
e = None
while not success and attempts < 10:
try:
shutil.copy(f, gt_nifti_folder)
success = True
except OSError as e:
attempts += 1
sleep(1)
if not success:
print("Could not copy gt nifti file %s into folder %s" % (f, gt_nifti_folder))
if e is not None:
raise e
self.network.train(current_mode)
def run_online_evaluation(self, output, target):
with torch.no_grad():
num_classes = output.shape[1]
output_softmax = softmax_helper(output)
output_seg = output_softmax.argmax(1)
target = target[:, 0]
axes = tuple(range(1, len(target.shape)))
tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
for c in range(1, num_classes):
tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes)
fp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes)
fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes)
tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy()
fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy()
fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy()
self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
self.online_eval_tp.append(list(tp_hard))
self.online_eval_fp.append(list(fp_hard))
self.online_eval_fn.append(list(fn_hard))
def finish_online_evaluation(self):
self.online_eval_tp = np.sum(self.online_eval_tp, 0)
self.online_eval_fp = np.sum(self.online_eval_fp, 0)
self.online_eval_fn = np.sum(self.online_eval_fn, 0)
global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in
zip(self.online_eval_tp, self.online_eval_fp, self.online_eval_fn)]
if not np.isnan(i)]
self.all_val_eval_metrics.append(np.mean(global_dc_per_class))
self.print_to_log_file("Average global foreground Dice:", str(global_dc_per_class))
self.print_to_log_file("(interpret this as an estimate for the Dice of the different classes. This is not "
"exact.)")
self.online_eval_foreground_dc = []
self.online_eval_tp = []
self.online_eval_fp = []
self.online_eval_fn = []
    def save_checkpoint(self, fname, save_optimizer=True):
        """Save model (and optionally optimizer) state via the base class, then
        write a companion <fname>.pkl holding everything needed to rebuild this
        trainer: the constructor args, class name/path and the plans dict.

        :param fname: checkpoint file name (the pickle is written to fname + ".pkl")
        :param save_optimizer: also persist the optimizer state
        """
        super(Trainer_lung, self).save_checkpoint(fname, save_optimizer)
        info = OrderedDict()
        info['init'] = self.init_args
        info['name'] = self.__class__.__name__
        info['class'] = str(self.__class__)
        info['plans'] = self.plans
        write_pickle(info, fname + ".pkl")
================================================
FILE: unetr_pp/training/network_training/Trainer_synapse.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import time
from collections import OrderedDict
from multiprocessing import Pool
from time import sleep
from typing import Tuple, List
import matplotlib
import unetr_pp
import numpy as np
import torch
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.configuration import default_num_threads
from unetr_pp.evaluation.evaluator import aggregate_scores
from unetr_pp.inference.segmentation_export import save_segmentation_nifti_from_softmax
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.network_architecture.initialization import InitWeights_He
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.postprocessing.connected_components import determine_postprocessing
from unetr_pp.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \
default_2D_augmentation_params, get_default_augmentation, get_patch_size
from unetr_pp.training.dataloading.dataset_loading import load_dataset, DataLoader3D, DataLoader2D, unpack_dataset
from unetr_pp.training.loss_functions.dice_loss import DC_and_CE_loss
from unetr_pp.training.network_training.network_trainer_synapse import NetworkTrainer_synapse
from unetr_pp.utilities.nd_softmax import softmax_helper
from unetr_pp.utilities.tensor_utilities import sum_tensor
from torch import nn
from torch.optim import lr_scheduler
matplotlib.use("agg")
class Trainer_synapse(NetworkTrainer_synapse):
    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, fp16=False):
        """
        Trainer for the Synapse multi-organ task. Sets up all configuration state; heavy objects
        (network, dataloaders, optimizer) are created later in self.initialize().

        :param deterministic:
        :param fold: can be either [0 ... 5) for cross-validation, 'all' to train on all available training data or
        None if you wish to load some checkpoint and do inference only
        :param plans_file: the pkl file generated by preprocessing. This file will determine all design choices
        :param subfolder_with_preprocessed_data: must be a subfolder of dataset_directory (just the name of the folder,
        not the entire path). This is where the preprocessed data lies that will be used for network training. We made
        this explicitly available so that differently preprocessed data can coexist and the user can choose what to use.
        Can be None if you are doing inference only.
        :param output_folder: where to store parameters, plot progress and to the validation
        :param dataset_directory: the parent directory in which the preprocessed Task data is stored. This is required
        because the split information is stored in this directory. For running prediction only this input is not
        required and may be set to None
        :param batch_dice: compute dice loss for each sample and average over all samples in the batch or pretend the
        batch is a pseudo volume?
        :param stage: The plans file may contain several stages (used for lowres / highres / pyramid). Stage must be
        specified for training:
        if stage 1 exists then stage 1 is the high resolution stage, otherwise it's 0
        :param unpack_data: if False, npz preprocessed data will not be unpacked to npy. This consumes less space but
        is considerably slower! Running unpack_data=False with 2d should never be done!
        IMPORTANT: If you inherit from nnFormerTrainer and the init args change then you need to redefine self.init_args
        in your init accordingly. Otherwise checkpoints won't load properly!
        """
        super(Trainer_synapse, self).__init__(deterministic, fp16)
        self.unpack_data = unpack_data
        # init_args is pickled next to every checkpoint (see save_checkpoint) so the trainer can be
        # re-instantiated with identical arguments when the checkpoint is restored
        self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                          deterministic, fp16)
        # set through arguments from init
        self.stage = stage
        self.experiment_name = self.__class__.__name__
        self.plans_file = plans_file
        self.output_folder = output_folder
        self.dataset_directory = dataset_directory
        # fold-independent root folder; update_fold() swaps the trailing "fold_X" component
        self.output_folder_base = self.output_folder
        self.fold = fold
        self.plans = None
        # if we are running inference only then the self.dataset_directory is set (due to checkpoint loading) but it
        # irrelevant
        if self.dataset_directory is not None and isdir(self.dataset_directory):
            self.gt_niftis_folder = join(self.dataset_directory, "gt_segmentations")
        else:
            self.gt_niftis_folder = None
        self.folder_with_preprocessed_data = None
        # set in self.initialize()
        self.dl_tr = self.dl_val = None
        self.num_input_channels = self.num_classes = self.net_pool_per_axis = self.patch_size = self.batch_size = \
            self.threeD = self.base_num_features = self.intensity_properties = self.normalization_schemes = \
            self.net_num_pool_op_kernel_sizes = self.net_conv_kernel_sizes = None  # loaded automatically from plans_file
        self.basic_generator_patch_size = self.data_aug_params = self.transpose_forward = self.transpose_backward = None
        self.batch_dice = batch_dice
        # combined Dice + cross-entropy loss; background is excluded from the Dice term (do_bg=False)
        self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {})
        # per-batch accumulators filled by run_online_evaluation, consumed by finish_online_evaluation
        self.online_eval_foreground_dc = []
        self.online_eval_tp = []
        self.online_eval_fp = []
        self.online_eval_fn = []
        self.classes = self.do_dummy_2D_aug = self.use_mask_for_norm = self.only_keep_largest_connected_component = \
            self.min_region_size_per_class = self.min_size_per_class = None
        self.inference_pad_border_mode = "constant"
        self.inference_pad_kwargs = {'constant_values': 0}
        self.update_fold(fold)
        self.pad_all_sides = None
        # hyperparameters for the optimizer / ReduceLROnPlateau scheduler
        self.lr_scheduler_eps = 1e-3
        self.lr_scheduler_patience = 30
        self.initial_lr = 3e-4
        self.weight_decay = 3e-5
        # fraction of training patches forced to contain foreground
        self.oversample_foreground_percent = 0.33
        self.conv_per_stage = None
        self.regions_class_order = None
def update_fold(self, fold):
"""
used to swap between folds for inference (ensemble of models from cross-validation)
DO NOT USE DURING TRAINING AS THIS WILL NOT UPDATE THE DATASET SPLIT AND THE DATA AUGMENTATION GENERATORS
:param fold:
:return:
"""
if fold is not None:
if isinstance(fold, str):
assert fold == "all", "if self.fold is a string then it must be \'all\'"
if self.output_folder.endswith("%s" % str(self.fold)):
self.output_folder = self.output_folder_base
self.output_folder = join(self.output_folder, "%s" % str(fold))
else:
if self.output_folder.endswith("fold_%s" % str(self.fold)):
self.output_folder = self.output_folder_base
self.output_folder = join(self.output_folder, "fold_%s" % str(fold))
self.fold = fold
    def setup_DA_params(self):
        """
        Configure self.data_aug_params (rotation / elastic / scaling ranges etc.) and compute
        self.basic_generator_patch_size, the enlarged patch the dataloaders must sample so that
        rotated/scaled crops can be cut back to self.patch_size without border artefacts.
        Requires process_plans() to have run (reads self.threeD, self.patch_size, ...).
        """
        if self.threeD:
            self.data_aug_params = default_3D_augmentation_params
            if self.do_dummy_2D_aug:
                # dummy 2D: augment slice-wise with 2D parameters (used for strongly anisotropic data)
                self.data_aug_params["dummy_2D"] = True
                self.print_to_log_file("Using dummy2d data augmentation")
                self.data_aug_params["elastic_deform_alpha"] = \
                    default_2D_augmentation_params["elastic_deform_alpha"]
                self.data_aug_params["elastic_deform_sigma"] = \
                    default_2D_augmentation_params["elastic_deform_sigma"]
                self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
        else:
            self.do_dummy_2D_aug = False
            if max(self.patch_size) / min(self.patch_size) > 1.5:
                # NOTE(review): this mutates the module-level default_2D_augmentation_params dict, so the
                # reduced rotation range leaks into any trainer created later in the same process — confirm intended
                default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
            self.data_aug_params = default_2D_augmentation_params
        self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm
        if self.do_dummy_2D_aug:
            # only the in-plane dimensions are enlarged; the first (through-plane) axis keeps its size
            self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
                                                             self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
            patch_size_for_spatialtransform = self.patch_size[1:]
        else:
            self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            patch_size_for_spatialtransform = self.patch_size
        # only the first seg channel is fed through the augmentation pipeline
        self.data_aug_params['selected_seg_channels'] = [0]
        self.data_aug_params['patch_size_for_spatialtransform'] = patch_size_for_spatialtransform
    def initialize(self, training=True, force_load_plans=False):
        """
        For prediction of test cases just set training=False, this will prevent loading of training data and
        training batchgenerator initialization
        :param training: if True, build dataloaders and augmentation pipelines as well
        :param force_load_plans: re-read the plans file even if self.plans is already set
        :return:
        """
        maybe_mkdir_p(self.output_folder)
        if force_load_plans or (self.plans is None):
            self.load_plans_file()
        self.process_plans(self.plans)
        self.setup_DA_params()
        if training:
            # preprocessed data for this stage lives in <dataset_directory>/<data_identifier>_stage<stage>
            self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
                                                      "_stage%d" % self.stage)
            self.dl_tr, self.dl_val = self.get_basic_generators()
            if self.unpack_data:
                # unpack npz -> npy once up front; training reads are much faster from npy
                self.print_to_log_file("unpacking dataset")
                unpack_dataset(self.folder_with_preprocessed_data)
                self.print_to_log_file("done")
            else:
                self.print_to_log_file(
                    "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                    "will wait all winter for your model to finish!")
            # wrap the raw dataloaders with the augmentation pipeline
            self.tr_gen, self.val_gen = get_default_augmentation(self.dl_tr, self.dl_val,
                                                                 self.data_aug_params[
                                                                     'patch_size_for_spatialtransform'],
                                                                 self.data_aug_params)
            self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                   also_print_to_console=False)
            self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                   also_print_to_console=False)
        else:
            pass
        self.initialize_network()
        self.initialize_optimizer_and_scheduler()
        # assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
        self.was_initialized = True
def initialize_network(self):
"""
This is specific to the U-Net and must be adapted for other network architectures
:return:
"""
# self.print_to_log_file(self.net_num_pool_op_kernel_sizes)
# self.print_to_log_file(self.net_conv_kernel_sizes)
net_numpool = len(self.net_num_pool_op_kernel_sizes)
if self.threeD:
conv_op = nn.Conv3d
dropout_op = nn.Dropout3d
norm_op = nn.InstanceNorm3d
else:
conv_op = nn.Conv2d
dropout_op = nn.Dropout2d
norm_op = nn.InstanceNorm2d
norm_op_kwargs = {'eps': 1e-5, 'affine': True}
dropout_op_kwargs = {'p': 0, 'inplace': True}
net_nonlin = nn.LeakyReLU
net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes, net_numpool,
self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
dropout_op_kwargs,
net_nonlin, net_nonlin_kwargs, False, False, lambda x: x, InitWeights_He(1e-2),
self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
self.network.inference_apply_nonlin = softmax_helper
if torch.cuda.is_available():
self.network.cuda()
def initialize_optimizer_and_scheduler(self):
assert self.network is not None, "self.initialize_network must be called first"
self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
amsgrad=True)
self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2,
patience=self.lr_scheduler_patience,
verbose=True, threshold=self.lr_scheduler_eps,
threshold_mode="abs")
def plot_network_architecture(self):
try:
from batchgenerators.utilities.file_and_folder_operations import join
import hiddenlayer as hl
if torch.cuda.is_available():
g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)).cuda(),
transforms=None)
else:
g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)),
transforms=None)
g.save(join(self.output_folder, "network_architecture.pdf"))
del g
except Exception as e:
self.print_to_log_file("Unable to plot network architecture:")
self.print_to_log_file(e)
self.print_to_log_file("\nprinting the network instead:\n")
self.print_to_log_file(self.network)
self.print_to_log_file("\n")
finally:
if torch.cuda.is_available():
torch.cuda.empty_cache()
def save_debug_information(self):
# saving some debug information
dct = OrderedDict()
for k in self.__dir__():
if not k.startswith("__"):
if not callable(getattr(self, k)):
dct[k] = str(getattr(self, k))
del dct['plans']
del dct['intensity_properties']
del dct['dataset']
del dct['dataset_tr']
del dct['dataset_val']
save_json(dct, join(self.output_folder, "debug.json"))
import shutil
shutil.copy(self.plans_file, join(self.output_folder_base, "plans.pkl"))
def run_training(self):
self.save_debug_information()
super(Trainer_synapse, self).run_training()
    def load_plans_file(self):
        """
        This is what actually configures the entire experiment. The plans file is generated by experiment planning.
        Loads the pickled plans dict into self.plans; process_plans() then copies its values onto the trainer.
        :return:
        """
        self.plans = load_pickle(self.plans_file)
    def process_plans(self, plans):
        """
        Copy all relevant settings from the plans dict onto this trainer (batch/patch size, pooling
        and conv kernel configuration, normalization, class info, transpose axes). Also resolves
        self.stage when it was left as None (only valid for single-stage plans).

        :param plans: the dict loaded by load_plans_file()
        :raises RuntimeError: if the patch size is neither 2D nor 3D
        """
        if self.stage is None:
            assert len(list(plans['plans_per_stage'].keys())) == 1, \
                "If self.stage is None then there can be only one stage in the plans file. That seems to not be the " \
                "case. Please specify which stage of the cascade must be trained"
            self.stage = list(plans['plans_per_stage'].keys())[0]
        self.plans = plans
        stage_plans = self.plans['plans_per_stage'][self.stage]
        self.batch_size = stage_plans['batch_size']
        self.net_pool_per_axis = stage_plans['num_pool_per_axis']
        self.patch_size = np.array(stage_plans['patch_size']).astype(int)
        self.do_dummy_2D_aug = stage_plans['do_dummy_2D_data_aug']
        if 'pool_op_kernel_sizes' not in stage_plans.keys():
            assert 'num_pool_per_axis' in stage_plans.keys()
            self.print_to_log_file("WARNING! old plans file with missing pool_op_kernel_sizes. Attempting to fix it...")
            # reconstruct per-stage pooling kernels from per-axis pool counts: an axis with j pools
            # is pooled (kernel 2) in the last j stages, unpooled (kernel 1) before that
            self.net_num_pool_op_kernel_sizes = []
            for i in range(max(self.net_pool_per_axis)):
                curr = []
                for j in self.net_pool_per_axis:
                    if (max(self.net_pool_per_axis) - j) <= i:
                        curr.append(2)
                    else:
                        curr.append(1)
                self.net_num_pool_op_kernel_sizes.append(curr)
        else:
            self.net_num_pool_op_kernel_sizes = stage_plans['pool_op_kernel_sizes']
        if 'conv_kernel_sizes' not in stage_plans.keys():
            self.print_to_log_file("WARNING! old plans file with missing conv_kernel_sizes. Attempting to fix it...")
            # fallback: 3x3(x3) convolutions at every resolution level
            self.net_conv_kernel_sizes = [[3] * len(self.net_pool_per_axis)] * (max(self.net_pool_per_axis) + 1)
        else:
            self.net_conv_kernel_sizes = stage_plans['conv_kernel_sizes']
        self.pad_all_sides = None  # self.patch_size
        self.intensity_properties = plans['dataset_properties']['intensityproperties']
        self.normalization_schemes = plans['normalization_schemes']
        self.base_num_features = plans['base_num_features']
        self.num_input_channels = plans['num_modalities']
        self.num_classes = plans['num_classes'] + 1  # background is no longer in num_classes
        self.classes = plans['all_classes']
        self.use_mask_for_norm = plans['use_mask_for_norm']
        self.only_keep_largest_connected_component = plans['keep_only_largest_region']
        self.min_region_size_per_class = plans['min_region_size_per_class']
        self.min_size_per_class = None  # DONT USE THIS. plans['min_size_per_class']
        if plans.get('transpose_forward') is None or plans.get('transpose_backward') is None:
            print("WARNING! You seem to have data that was preprocessed with a previous version of nnU-Net. "
                  "You should rerun preprocessing. We will proceed and assume that both transpose_foward "
                  "and transpose_backward are [0, 1, 2]. If that is not correct then weird things will happen!")
            plans['transpose_forward'] = [0, 1, 2]
            plans['transpose_backward'] = [0, 1, 2]
        self.transpose_forward = plans['transpose_forward']
        self.transpose_backward = plans['transpose_backward']
        # dimensionality is inferred from the patch size (2 entries -> 2D, 3 entries -> 3D)
        if len(self.patch_size) == 2:
            self.threeD = False
        elif len(self.patch_size) == 3:
            self.threeD = True
        else:
            raise RuntimeError("invalid patch size in plans file: %s" % str(self.patch_size))
        if "conv_per_stage" in plans.keys():  # this has been added to the plans only recently
            self.conv_per_stage = plans['conv_per_stage']
        else:
            self.conv_per_stage = 2
    def load_dataset(self):
        """Populate self.dataset with the case-id -> file-info mapping of the preprocessed data folder."""
        self.dataset = load_dataset(self.folder_with_preprocessed_data)
def get_basic_generators(self):
self.load_dataset()
self.do_split()
if self.threeD:
dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
False, oversample_foreground_percent=self.oversample_foreground_percent,
pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, False,
oversample_foreground_percent=self.oversample_foreground_percent,
pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
else:
dl_tr = DataLoader2D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
oversample_foreground_percent=self.oversample_foreground_percent,
pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
dl_val = DataLoader2D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size,
oversample_foreground_percent=self.oversample_foreground_percent,
pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
return dl_tr, dl_val
def preprocess_patient(self, input_files):
"""
Used to predict new unseen data. Not used for the preprocessing of the training/test data
:param input_files:
:return:
"""
from unetr_pp.training.model_restore import recursive_find_python_class
preprocessor_name = self.plans.get('preprocessor_name')
if preprocessor_name is None:
if self.threeD:
preprocessor_name = "GenericPreprocessor"
else:
preprocessor_name = "PreprocessorFor2D"
print("using preprocessor", preprocessor_name)
preprocessor_class = recursive_find_python_class([join(unetr_pp.__path__[0], "preprocessing")],
preprocessor_name,
current_module="unetr_pp.preprocessing")
assert preprocessor_class is not None, "Could not find preprocessor %s in unetr_pp.preprocessing" % \
preprocessor_name
preprocessor = preprocessor_class(self.normalization_schemes, self.use_mask_for_norm,
self.transpose_forward, self.intensity_properties)
d, s, properties = preprocessor.preprocess_test_case(input_files,
self.plans['plans_per_stage'][self.stage][
'current_spacing'])
return d, s, properties
    def preprocess_predict_nifti(self, input_files: List[str], output_file: str = None,
                                 softmax_ouput_file: str = None, mixed_precision: bool = True) -> None:
        """
        Use this to predict new data: preprocess -> predict -> undo transpose -> export as nifti.
        :param input_files: raw image files belonging to one case
        :param output_file: where the predicted segmentation nifti is written
        :param softmax_ouput_file: optional path for the softmax probabilities (parameter name is
            misspelled but kept as-is for backward compatibility with existing callers)
        :param mixed_precision: run inference with mixed precision if True
        :return:
        """
        print("preprocessing...")
        d, s, properties = self.preprocess_patient(input_files)
        print("predicting...")
        # index [1] selects the softmax output (index [0] would be the hard segmentation)
        pred = self.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=self.data_aug_params["do_mirror"],
                                                                     mirror_axes=self.data_aug_params['mirror_axes'],
                                                                     use_sliding_window=True, step_size=0.5,
                                                                     use_gaussian=True, pad_border_mode='constant',
                                                                     pad_kwargs={'constant_values': 0},
                                                                     verbose=True, all_in_gpu=False,
                                                                     mixed_precision=mixed_precision)[1]
        # undo the axis transpose that was applied during preprocessing
        pred = pred.transpose([0] + [i + 1 for i in self.transpose_backward])
        # export parameters from the plans file if present, otherwise sensible defaults
        if 'segmentation_export_params' in self.plans.keys():
            force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
            interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
            interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
        else:
            force_separate_z = None
            interpolation_order = 1
            interpolation_order_z = 0
        print("resampling to original spacing and nifti export...")
        save_segmentation_nifti_from_softmax(pred, output_file, properties, interpolation_order,
                                             self.regions_class_order, None, None, softmax_ouput_file,
                                             None, force_separate_z=force_separate_z,
                                             interpolation_order_z=interpolation_order_z)
        print("done")
def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
mirror_axes: Tuple[int] = None,
use_sliding_window: bool = True, step_size: float = 0.5,
use_gaussian: bool = True, pad_border_mode: str = 'constant',
pad_kwargs: dict = None, all_in_gpu: bool = False,
verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
"""
:param data:
:param do_mirroring:
:param mirror_axes:
:param use_sliding_window:
:param step_size:
:param use_gaussian:
:param pad_border_mode:
:param pad_kwargs:
:param all_in_gpu:
:param verbose:
:return:
"""
if pad_border_mode == 'constant' and pad_kwargs is None:
pad_kwargs = {'constant_values': 0}
if do_mirroring and mirror_axes is None:
mirror_axes = self.data_aug_params['mirror_axes']
if do_mirroring:
assert self.data_aug_params["do_mirror"], "Cannot do mirroring as test time augmentation when training " \
"was done without mirroring"
valid = list((SegmentationNetwork, nn.DataParallel))
assert isinstance(self.network, tuple(valid))
current_mode = self.network.training
self.network.eval()
ret = self.network.predict_3D(data, do_mirroring=do_mirroring, mirror_axes=mirror_axes,
use_sliding_window=use_sliding_window, step_size=step_size,
patch_size=self.patch_size, regions_class_order=self.regions_class_order,
use_gaussian=use_gaussian, pad_border_mode=pad_border_mode,
pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, verbose=verbose,
mixed_precision=mixed_precision)
self.network.train(current_mode)
return ret
def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
"""
if debug=True then the temporary files generated for postprocessing determination will be kept
"""
current_mode = self.network.training
self.network.eval()
assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"
if self.dataset_val is None:
self.load_dataset()
self.do_split()
if segmentation_export_kwargs is None:
if 'segmentation_export_params' in self.plans.keys():
force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
else:
force_separate_z = None
interpolation_order = 1
interpolation_order_z = 0
else:
force_separate_z = segmentation_export_kwargs['force_separate_z']
interpolation_order = segmentation_export_kwargs['interpolation_order']
interpolation_order_z = segmentation_export_kwargs['interpolation_order_z']
# predictions as they come from the network go here
output_folder = join(self.output_folder, validation_folder_name)
maybe_mkdir_p(output_folder)
# this is for debug purposes
my_input_args = {'do_mirroring': do_mirroring,
'use_sliding_window': use_sliding_window,
'step_size': step_size,
'save_softmax': save_softmax,
'use_gaussian': use_gaussian,
'overwrite': overwrite,
'validation_folder_name': validation_folder_name,
'debug': debug,
'all_in_gpu': all_in_gpu,
'segmentation_export_kwargs': segmentation_export_kwargs,
}
save_json(my_input_args, join(output_folder, "validation_args.json"))
if do_mirroring:
if not self.data_aug_params['do_mirror']:
raise RuntimeError("We did not train with mirroring so you cannot do inference with mirroring enabled")
mirror_axes = self.data_aug_params['mirror_axes']
else:
mirror_axes = ()
pred_gt_tuples = []
export_pool = Pool(default_num_threads)
results = []
inference_times = []
for k in self.dataset_val.keys():
properties = load_pickle(self.dataset[k]['properties_file'])
fname = properties['list_of_data_files'][0].split("/")[-1][:-12]
if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \
(save_softmax and not isfile(join(output_folder, fname + ".npz"))):
start = time.time()
data = np.load(self.dataset[k]['data_file'])['data']
print(k, data.shape)
data[-1][data[-1] == -1] = 0
softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data[:-1],
do_mirroring=do_mirroring,
mirror_axes=mirror_axes,
use_sliding_window=use_sliding_window,
step_size=step_size,
use_gaussian=use_gaussian,
all_in_gpu=all_in_gpu,
mixed_precision=self.fp16)[1]
softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in self.transpose_backward])
end = time.time()
print(f"Inference Time (k={k}): {end - start} sec.")
inference_times.append(end-start)
if save_softmax:
softmax_fname = join(output_folder, fname + ".npz")
else:
softmax_fname = None
"""There is a problem with python process communication that prevents us from communicating obejcts
larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long
enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
filename or np.ndarray and will handle this automatically"""
if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85): # *0.85 just to be save
np.save(join(output_folder, fname + ".npy"), softmax_pred)
softmax_pred = join(output_folder, fname + ".npy")
results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax,
((softmax_pred, join(output_folder, fname + ".nii.gz"),
properties, interpolation_order, self.regions_class_order,
None, None,
softmax_fname, None, force_separate_z,
interpolation_order_z),
)
)
)
pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"),
join(self.gt_niftis_folder, fname + ".nii.gz")])
# Print inference times
print(f"Total Images: {len(inference_times)}")
print(f"Total Time: {sum(inference_times)} sec.")
print(f"Time per Image: {sum(inference_times) / len(inference_times)} sec.")
_ = [i.get() for i in results]
self.print_to_log_file("finished prediction")
# evaluate raw predictions
self.print_to_log_file("evaluation of raw predictions")
task = self.dataset_directory.split("/")[-1]
job_name = self.experiment_name
_ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)),
json_output_file=join(output_folder, "summary.json"),
json_name=job_name + " val tiled %s" % (str(use_sliding_window)),
json_author="Fabian",
json_task=task, num_threads=default_num_threads)
if run_postprocessing_on_folds:
# in the old unetr_pp we would stop here. Now we add a postprocessing. This postprocessing can remove everything
# except the largest connected component for each class. To see if this improves results, we do this for all
# classes and then rerun the evaluation. Those classes for which this resulted in an improved dice score will
# have this applied during inference as well
self.print_to_log_file("determining postprocessing")
determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name,
final_subf_name=validation_folder_name + "_postprocessed", debug=debug)
# after this the final predictions for the vlaidation set can be found in validation_folder_name_base + "_postprocessed"
# They are always in that folder, even if no postprocessing as applied!
# detemining postprocesing on a per-fold basis may be OK for this fold but what if another fold finds another
# postprocesing to be better? In this case we need to consolidate. At the time the consolidation is going to be
# done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a separate folder to
# be used later
gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
maybe_mkdir_p(gt_nifti_folder)
for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
success = False
attempts = 0
e = None
while not success and attempts < 10:
try:
shutil.copy(f, gt_nifti_folder)
success = True
except OSError as e:
attempts += 1
sleep(1)
if not success:
print("Could not copy gt nifti file %s into folder %s" % (f, gt_nifti_folder))
if e is not None:
raise e
self.network.train(current_mode)
def run_online_evaluation(self, output, target):
with torch.no_grad():
num_classes = output.shape[1]
output_softmax = softmax_helper(output)
output_seg = output_softmax.argmax(1)
target = target[:, 0]
axes = tuple(range(1, len(target.shape)))
#tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
#fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
#fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
tp_hard = torch.zeros((target.shape[0], 8)).to(output_seg.device.index)
fp_hard = torch.zeros((target.shape[0], 8)).to(output_seg.device.index)
fn_hard = torch.zeros((target.shape[0], 8)).to(output_seg.device.index)
i=0
for c in range(1, num_classes):
if c in [10,12,13,5,9]:
continue
tp_hard[:, i] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes)
fp_hard[:, i] = sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes)
fn_hard[:, i] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes)
i+=1
tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy()
fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy()
fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy()
self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
self.online_eval_tp.append(list(tp_hard))
self.online_eval_fp.append(list(fp_hard))
self.online_eval_fn.append(list(fn_hard))
def finish_online_evaluation(self):
self.online_eval_tp = np.sum(self.online_eval_tp, 0)
self.online_eval_fp = np.sum(self.online_eval_fp, 0)
self.online_eval_fn = np.sum(self.online_eval_fn, 0)
global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in
zip(self.online_eval_tp, self.online_eval_fp, self.online_eval_fn)]
if not np.isnan(i)]
self.all_val_eval_metrics.append(np.mean(global_dc_per_class))
self.print_to_log_file("Average global foreground Dice:", str(global_dc_per_class))
self.print_to_log_file("(interpret this as an estimate for the Dice of the different classes. This is not "
"exact.)")
self.online_eval_foreground_dc = []
self.online_eval_tp = []
self.online_eval_fp = []
self.online_eval_fn = []
def save_checkpoint(self, fname, save_optimizer=True):
super(Trainer_synapse, self).save_checkpoint(fname, save_optimizer)
info = OrderedDict()
info['init'] = self.init_args
info['name'] = self.__class__.__name__
info['class'] = str(self.__class__)
info['plans'] = self.plans
write_pickle(info, fname + ".pkl")
================================================
FILE: unetr_pp/training/network_training/Trainer_tumor.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import time
from collections import OrderedDict
from multiprocessing import Pool
from time import sleep
from typing import Tuple, List
import matplotlib
import unetr_pp
import numpy as np
import torch
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.configuration import default_num_threads
from unetr_pp.evaluation.evaluator import aggregate_scores
from unetr_pp.inference.segmentation_export import save_segmentation_nifti_from_softmax
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.network_architecture.initialization import InitWeights_He
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.postprocessing.connected_components import determine_postprocessing
from unetr_pp.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \
default_2D_augmentation_params, get_default_augmentation, get_patch_size
from unetr_pp.training.dataloading.dataset_loading import load_dataset, DataLoader3D, DataLoader2D, unpack_dataset
from unetr_pp.training.loss_functions.dice_loss import DC_and_CE_loss
from unetr_pp.training.network_training.network_trainer_tumor import NetworkTrainer_tumor
from unetr_pp.utilities.nd_softmax import softmax_helper
from unetr_pp.utilities.tensor_utilities import sum_tensor
from torch import nn
from torch.optim import lr_scheduler
matplotlib.use("agg")
class Trainer_tumor(NetworkTrainer_tumor):
    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, fp16=False):
        """
        :param deterministic:
        :param fold: can be either [0 ... 5) for cross-validation, 'all' to train on all available training data or
        None if you wish to load some checkpoint and do inference only
        :param plans_file: the pkl file generated by preprocessing. This file will determine all design choices
        :param subfolder_with_preprocessed_data: must be a subfolder of dataset_directory (just the name of the folder,
        not the entire path). This is where the preprocessed data lies that will be used for network training. We made
        this explicitly available so that differently preprocessed data can coexist and the user can choose what to use.
        Can be None if you are doing inference only.
        :param output_folder: where to store parameters, plot progress and to the validation
        :param dataset_directory: the parent directory in which the preprocessed Task data is stored. This is required
        because the split information is stored in this directory. For running prediction only this input is not
        required and may be set to None
        :param batch_dice: compute dice loss for each sample and average over all samples in the batch or pretend the
        batch is a pseudo volume?
        :param stage: The plans file may contain several stages (used for lowres / highres / pyramid). Stage must be
        specified for training:
        if stage 1 exists then stage 1 is the high resolution stage, otherwise it's 0
        :param unpack_data: if False, npz preprocessed data will not be unpacked to npy. This consumes less space but
        is considerably slower! Running unpack_data=False with 2d should never be done!
        IMPORTANT: If you inherit from nnFormerTrainer and the init args change then you need to redefine self.init_args
        in your init accordingly. Otherwise checkpoints won't load properly!
        """
        super(Trainer_tumor, self).__init__(deterministic, fp16)
        self.unpack_data = unpack_data
        # init_args is pickled next to every checkpoint (see save_checkpoint) and replayed on
        # restore, so it must mirror this constructor's signature exactly
        self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                          deterministic, fp16)
        # set through arguments from init
        self.stage = stage
        self.experiment_name = self.__class__.__name__
        self.plans_file = plans_file
        self.output_folder = output_folder
        self.dataset_directory = dataset_directory
        # fold-independent base folder; update_fold() appends the fold subdirectory to it
        self.output_folder_base = self.output_folder
        self.fold = fold
        self.plans = None
        # if we are running inference only then the self.dataset_directory is set (due to checkpoint loading) but it
        # irrelevant
        if self.dataset_directory is not None and isdir(self.dataset_directory):
            self.gt_niftis_folder = join(self.dataset_directory, "gt_segmentations")
        else:
            self.gt_niftis_folder = None
        self.folder_with_preprocessed_data = None
        # set in self.initialize()
        self.dl_tr = self.dl_val = None
        self.num_input_channels = self.num_classes = self.net_pool_per_axis = self.patch_size = self.batch_size = \
            self.threeD = self.base_num_features = self.intensity_properties = self.normalization_schemes = \
            self.net_num_pool_op_kernel_sizes = self.net_conv_kernel_sizes = None  # loaded automatically from plans_file
        self.basic_generator_patch_size = self.data_aug_params = self.transpose_forward = self.transpose_backward = None
        self.batch_dice = batch_dice
        self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {})
        # accumulators for the per-epoch online Dice estimate (see run_online_evaluation)
        self.online_eval_foreground_dc = []
        self.online_eval_tp = []
        self.online_eval_fp = []
        self.online_eval_fn = []
        # filled in by process_plans()
        self.classes = self.do_dummy_2D_aug = self.use_mask_for_norm = self.only_keep_largest_connected_component = \
            self.min_region_size_per_class = self.min_size_per_class = None
        self.inference_pad_border_mode = "constant"
        self.inference_pad_kwargs = {'constant_values': 0}
        self.update_fold(fold)
        self.pad_all_sides = None
        # optimizer / scheduler hyperparameters
        self.lr_scheduler_eps = 1e-3
        self.lr_scheduler_patience = 30
        self.initial_lr = 3e-4
        self.weight_decay = 3e-5
        # fraction of batch samples forced to contain foreground voxels
        self.oversample_foreground_percent = 0.33
        self.conv_per_stage = None
        self.regions_class_order = None
def update_fold(self, fold):
"""
used to swap between folds for inference (ensemble of models from cross-validation)
DO NOT USE DURING TRAINING AS THIS WILL NOT UPDATE THE DATASET SPLIT AND THE DATA AUGMENTATION GENERATORS
:param fold:
:return:
"""
if fold is not None:
if isinstance(fold, str):
assert fold == "all", "if self.fold is a string then it must be \'all\'"
if self.output_folder.endswith("%s" % str(self.fold)):
self.output_folder = self.output_folder_base
self.output_folder = join(self.output_folder, "%s" % str(fold))
else:
if self.output_folder.endswith("fold_%s" % str(self.fold)):
self.output_folder = self.output_folder_base
self.output_folder = join(self.output_folder, "fold_%s" % str(fold))
self.fold = fold
    def setup_DA_params(self):
        """
        Configure self.data_aug_params and self.basic_generator_patch_size from the processed
        plans. Must run after process_plans() (reads self.threeD, self.patch_size,
        self.do_dummy_2D_aug and self.use_mask_for_norm).
        """
        if self.threeD:
            self.data_aug_params = default_3D_augmentation_params
            if self.do_dummy_2D_aug:
                # anisotropic data: augment slice-wise with the 2D parameters instead of full 3D rotations
                self.data_aug_params["dummy_2D"] = True
                self.print_to_log_file("Using dummy2d data augmentation")
                self.data_aug_params["elastic_deform_alpha"] = \
                    default_2D_augmentation_params["elastic_deform_alpha"]
                self.data_aug_params["elastic_deform_sigma"] = \
                    default_2D_augmentation_params["elastic_deform_sigma"]
                self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
        else:
            self.do_dummy_2D_aug = False
            # NOTE(review): this mutates the module-level default_2D_augmentation_params dict,
            # which is shared by every trainer instance in this process
            if max(self.patch_size) / min(self.patch_size) > 1.5:
                default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
            self.data_aug_params = default_2D_augmentation_params
        self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm
        if self.do_dummy_2D_aug:
            # rotations/scaling need a larger patch to crop from; the first axis is not spatially augmented
            self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
                                                             self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
            patch_size_for_spatialtransform = self.patch_size[1:]
        else:
            self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            patch_size_for_spatialtransform = self.patch_size
        self.data_aug_params['selected_seg_channels'] = [0]
        self.data_aug_params['patch_size_for_spatialtransform'] = patch_size_for_spatialtransform
def initialize(self, training=True, force_load_plans=False):
"""
For prediction of test cases just set training=False, this will prevent loading of training data and
training batchgenerator initialization
:param training:
:return:
"""
maybe_mkdir_p(self.output_folder)
if force_load_plans or (self.plans is None):
self.load_plans_file()
self.process_plans(self.plans)
self.setup_DA_params()
if training:
self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
"_stage%d" % self.stage)
self.dl_tr, self.dl_val = self.get_basic_generators()
if self.unpack_data:
self.print_to_log_file("unpacking dataset")
unpack_dataset(self.folder_with_preprocessed_data)
self.print_to_log_file("done")
else:
self.print_to_log_file(
"INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
"will wait all winter for your model to finish!")
self.tr_gen, self.val_gen = get_default_augmentation(self.dl_tr, self.dl_val,
self.data_aug_params[
'patch_size_for_spatialtransform'],
self.data_aug_params)
self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
also_print_to_console=False)
self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
also_print_to_console=False)
else:
pass
self.initialize_network()
self.initialize_optimizer_and_scheduler()
# assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
self.was_initialized = True
def initialize_network(self):
"""
This is specific to the U-Net and must be adapted for other network architectures
:return:
"""
# self.print_to_log_file(self.net_num_pool_op_kernel_sizes)
# self.print_to_log_file(self.net_conv_kernel_sizes)
net_numpool = len(self.net_num_pool_op_kernel_sizes)
if self.threeD:
conv_op = nn.Conv3d
dropout_op = nn.Dropout3d
norm_op = nn.InstanceNorm3d
else:
conv_op = nn.Conv2d
dropout_op = nn.Dropout2d
norm_op = nn.InstanceNorm2d
norm_op_kwargs = {'eps': 1e-5, 'affine': True}
dropout_op_kwargs = {'p': 0, 'inplace': True}
net_nonlin = nn.LeakyReLU
net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes, net_numpool,
self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
dropout_op_kwargs,
net_nonlin, net_nonlin_kwargs, False, False, lambda x: x, InitWeights_He(1e-2),
self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
self.network.inference_apply_nonlin = softmax_helper
if torch.cuda.is_available():
self.network.cuda()
def initialize_optimizer_and_scheduler(self):
assert self.network is not None, "self.initialize_network must be called first"
self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
amsgrad=True)
self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2,
patience=self.lr_scheduler_patience,
verbose=True, threshold=self.lr_scheduler_eps,
threshold_mode="abs")
def plot_network_architecture(self):
try:
from batchgenerators.utilities.file_and_folder_operations import join
import hiddenlayer as hl
if torch.cuda.is_available():
g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)).cuda(),
transforms=None)
else:
g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)),
transforms=None)
g.save(join(self.output_folder, "network_architecture.pdf"))
del g
except Exception as e:
self.print_to_log_file("Unable to plot network architecture:")
self.print_to_log_file(e)
self.print_to_log_file("\nprinting the network instead:\n")
self.print_to_log_file(self.network)
self.print_to_log_file("\n")
finally:
if torch.cuda.is_available():
torch.cuda.empty_cache()
def save_debug_information(self):
# saving some debug information
dct = OrderedDict()
for k in self.__dir__():
if not k.startswith("__"):
if not callable(getattr(self, k)):
dct[k] = str(getattr(self, k))
del dct['plans']
del dct['intensity_properties']
del dct['dataset']
del dct['dataset_tr']
del dct['dataset_val']
save_json(dct, join(self.output_folder, "debug.json"))
import shutil
shutil.copy(self.plans_file, join(self.output_folder_base, "plans.pkl"))
def run_training(self):
self.save_debug_information()
super(Trainer_tumor, self).run_training()
def load_plans_file(self):
"""
This is what actually configures the entire experiment. The plans file is generated by experiment planning
:return:
"""
self.plans = load_pickle(self.plans_file)
#self.plans = load_pickle("../../../DATASET_Tumor/unetr_pp_raw/unetr_pp_raw_data/Task03_tumor/Task003_tumor/unetr_pp_Plansv2.1_plans_3D.pkl")
    def process_plans(self, plans):
        """
        Unpack the plans dict into trainer attributes (patch/batch size, pooling and conv
        kernel sizes, normalization schemes, class info, transpose axes, ...). Fields
        missing from plans files written by older versions are reconstructed with a warning.

        :param plans: the deserialized plans dict (see load_plans_file)
        :raises RuntimeError: if the patch size is neither 2D nor 3D
        """
        if self.stage is None:
            assert len(list(plans['plans_per_stage'].keys())) == 1, \
                "If self.stage is None then there can be only one stage in the plans file. That seems to not be the " \
                "case. Please specify which stage of the cascade must be trained"
            self.stage = list(plans['plans_per_stage'].keys())[0]
        self.plans = plans
        stage_plans = self.plans['plans_per_stage'][self.stage]
        self.batch_size = stage_plans['batch_size']
        self.net_pool_per_axis = stage_plans['num_pool_per_axis']
        self.patch_size = np.array(stage_plans['patch_size']).astype(int)
        self.do_dummy_2D_aug = stage_plans['do_dummy_2D_data_aug']
        if 'pool_op_kernel_sizes' not in stage_plans.keys():
            assert 'num_pool_per_axis' in stage_plans.keys()
            self.print_to_log_file("WARNING! old plans file with missing pool_op_kernel_sizes. Attempting to fix it...")
            # reconstruct per-step pooling kernels: axis j participates (kernel 2) in the
            # last net_pool_per_axis[j] of the max(net_pool_per_axis) pooling steps
            self.net_num_pool_op_kernel_sizes = []
            for i in range(max(self.net_pool_per_axis)):
                curr = []
                for j in self.net_pool_per_axis:
                    if (max(self.net_pool_per_axis) - j) <= i:
                        curr.append(2)
                    else:
                        curr.append(1)
                self.net_num_pool_op_kernel_sizes.append(curr)
        else:
            self.net_num_pool_op_kernel_sizes = stage_plans['pool_op_kernel_sizes']
        if 'conv_kernel_sizes' not in stage_plans.keys():
            self.print_to_log_file("WARNING! old plans file with missing conv_kernel_sizes. Attempting to fix it...")
            self.net_conv_kernel_sizes = [[3] * len(self.net_pool_per_axis)] * (max(self.net_pool_per_axis) + 1)
        else:
            self.net_conv_kernel_sizes = stage_plans['conv_kernel_sizes']
        self.pad_all_sides = None  # self.patch_size
        self.intensity_properties = plans['dataset_properties']['intensityproperties']
        self.normalization_schemes = plans['normalization_schemes']
        self.base_num_features = plans['base_num_features']
        self.num_input_channels = plans['num_modalities']
        self.num_classes = plans['num_classes'] + 1  # background is no longer in num_classes
        self.classes = plans['all_classes']
        self.use_mask_for_norm = plans['use_mask_for_norm']
        self.only_keep_largest_connected_component = plans['keep_only_largest_region']
        self.min_region_size_per_class = plans['min_region_size_per_class']
        self.min_size_per_class = None  # DONT USE THIS. plans['min_size_per_class']
        if plans.get('transpose_forward') is None or plans.get('transpose_backward') is None:
            print("WARNING! You seem to have data that was preprocessed with a previous version of nnU-Net. "
                  "You should rerun preprocessing. We will proceed and assume that both transpose_foward "
                  "and transpose_backward are [0, 1, 2]. If that is not correct then weird things will happen!")
            plans['transpose_forward'] = [0, 1, 2]
            plans['transpose_backward'] = [0, 1, 2]
        self.transpose_forward = plans['transpose_forward']
        self.transpose_backward = plans['transpose_backward']
        # dimensionality is inferred from the patch size
        if len(self.patch_size) == 2:
            self.threeD = False
        elif len(self.patch_size) == 3:
            self.threeD = True
        else:
            raise RuntimeError("invalid patch size in plans file: %s" % str(self.patch_size))
        if "conv_per_stage" in plans.keys():  # this has been added to the plans only recently
            self.conv_per_stage = plans['conv_per_stage']
        else:
            self.conv_per_stage = 2
    def load_dataset(self):
        # populate self.dataset from the preprocessed data folder (case id -> file info)
        self.dataset = load_dataset(self.folder_with_preprocessed_data)
def get_basic_generators(self):
self.load_dataset()
self.do_split()
if self.threeD:
dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
False, oversample_foreground_percent=self.oversample_foreground_percent,
pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, False,
oversample_foreground_percent=self.oversample_foreground_percent,
pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
else:
dl_tr = DataLoader2D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
oversample_foreground_percent=self.oversample_foreground_percent,
pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
dl_val = DataLoader2D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size,
oversample_foreground_percent=self.oversample_foreground_percent,
pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
return dl_tr, dl_val
def preprocess_patient(self, input_files):
"""
Used to predict new unseen data. Not used for the preprocessing of the training/test data
:param input_files:
:return:
"""
from unetr_pp.training.model_restore import recursive_find_python_class
preprocessor_name = self.plans.get('preprocessor_name')
if preprocessor_name is None:
if self.threeD:
preprocessor_name = "GenericPreprocessor"
else:
preprocessor_name = "PreprocessorFor2D"
print("using preprocessor", preprocessor_name)
preprocessor_class = recursive_find_python_class([join(unetr_pp.__path__[0], "preprocessing")],
preprocessor_name,
current_module="unetr_pp.preprocessing")
assert preprocessor_class is not None, "Could not find preprocessor %s in unetr_pp.preprocessing" % \
preprocessor_name
preprocessor = preprocessor_class(self.normalization_schemes, self.use_mask_for_norm,
self.transpose_forward, self.intensity_properties)
d, s, properties = preprocessor.preprocess_test_case(input_files,
self.plans['plans_per_stage'][self.stage][
'current_spacing'])
return d, s, properties
    def preprocess_predict_nifti(self, input_files: List[str], output_file: str = None,
                                 softmax_ouput_file: str = None, mixed_precision: bool = True) -> None:
        """
        Use this to predict new data: preprocess the case, run inference and export the
        result as nifti.
        :param input_files: nifti files of one case (one file per modality)
        :param output_file: destination path for the exported segmentation nifti
        :param softmax_ouput_file: optional destination for the softmax output
            (parameter name keeps its historical typo for backward compatibility)
        :param mixed_precision: run inference with mixed precision
        :return:
        """
        print("preprocessing...")
        d, s, properties = self.preprocess_patient(input_files)
        print("predicting...")
        # [1] selects the softmax volume of the (segmentation, softmax) tuple
        pred = self.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=self.data_aug_params["do_mirror"],
                                                                     mirror_axes=self.data_aug_params['mirror_axes'],
                                                                     use_sliding_window=True, step_size=0.5,
                                                                     use_gaussian=True, pad_border_mode='constant',
                                                                     pad_kwargs={'constant_values': 0},
                                                                     verbose=True, all_in_gpu=False,
                                                                     mixed_precision=mixed_precision)[1]
        # undo the axis transpose applied during preprocessing (channel axis stays in front)
        pred = pred.transpose([0] + [i + 1 for i in self.transpose_backward])
        if 'segmentation_export_params' in self.plans.keys():
            force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
            interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
            interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
        else:
            # defaults used when the plans file does not specify export parameters
            force_separate_z = None
            interpolation_order = 1
            interpolation_order_z = 0
        print("resampling to original spacing and nifti export...")
        save_segmentation_nifti_from_softmax(pred, output_file, properties, interpolation_order,
                                             self.regions_class_order, None, None, softmax_ouput_file,
                                             None, force_separate_z=force_separate_z,
                                             interpolation_order_z=interpolation_order_z)
        print("done")
def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
mirror_axes: Tuple[int] = None,
use_sliding_window: bool = True, step_size: float = 0.5,
use_gaussian: bool = True, pad_border_mode: str = 'constant',
pad_kwargs: dict = None, all_in_gpu: bool = False,
verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
"""
:param data:
:param do_mirroring:
:param mirror_axes:
:param use_sliding_window:
:param step_size:
:param use_gaussian:
:param pad_border_mode:
:param pad_kwargs:
:param all_in_gpu:
:param verbose:
:return:
"""
if pad_border_mode == 'constant' and pad_kwargs is None:
pad_kwargs = {'constant_values': 0}
if do_mirroring and mirror_axes is None:
mirror_axes = self.data_aug_params['mirror_axes']
if do_mirroring:
assert self.data_aug_params["do_mirror"], "Cannot do mirroring as test time augmentation when training " \
"was done without mirroring"
valid = list((SegmentationNetwork, nn.DataParallel))
assert isinstance(self.network, tuple(valid))
current_mode = self.network.training
self.network.eval()
ret = self.network.predict_3D(data, do_mirroring=do_mirroring, mirror_axes=mirror_axes,
use_sliding_window=use_sliding_window, step_size=step_size,
patch_size=self.patch_size, regions_class_order=self.regions_class_order,
use_gaussian=use_gaussian, pad_border_mode=pad_border_mode,
pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, verbose=verbose,
mixed_precision=mixed_precision)
self.network.train(current_mode)
return ret
def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
"""
if debug=True then the temporary files generated for postprocessing determination will be kept
"""
current_mode = self.network.training
self.network.eval()
assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"
if self.dataset_val is None:
self.load_dataset()
self.do_split()
if segmentation_export_kwargs is None:
if 'segmentation_export_params' in self.plans.keys():
force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
else:
force_separate_z = None
interpolation_order = 1
interpolation_order_z = 0
else:
force_separate_z = segmentation_export_kwargs['force_separate_z']
interpolation_order = segmentation_export_kwargs['interpolation_order']
interpolation_order_z = segmentation_export_kwargs['interpolation_order_z']
# predictions as they come from the network go here
output_folder = join(self.output_folder, validation_folder_name)
maybe_mkdir_p(output_folder)
# this is for debug purposes
my_input_args = {'do_mirroring': do_mirroring,
'use_sliding_window': use_sliding_window,
'step_size': step_size,
'save_softmax': save_softmax,
'use_gaussian': use_gaussian,
'overwrite': overwrite,
'validation_folder_name': validation_folder_name,
'debug': debug,
'all_in_gpu': all_in_gpu,
'segmentation_export_kwargs': segmentation_export_kwargs,
}
save_json(my_input_args, join(output_folder, "validation_args.json"))
if do_mirroring:
if not self.data_aug_params['do_mirror']:
raise RuntimeError("We did not train with mirroring so you cannot do inference with mirroring enabled")
mirror_axes = self.data_aug_params['mirror_axes']
else:
mirror_axes = ()
pred_gt_tuples = []
export_pool = Pool(default_num_threads)
results = []
inference_times = []
for k in self.dataset_val.keys():
properties = load_pickle(self.dataset[k]['properties_file'])
fname = properties['list_of_data_files'][0].split("/")[-1][:-12]
if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \
(save_softmax and not isfile(join(output_folder, fname + ".npz"))):
start = time.time()
data = np.load(self.dataset[k]['data_file'])['data']
print(k, data.shape)
data[-1][data[-1] == -1] = 0
softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data[:-1],
do_mirroring=do_mirroring,
mirror_axes=mirror_axes,
use_sliding_window=use_sliding_window,
step_size=step_size,
use_gaussian=use_gaussian,
all_in_gpu=all_in_gpu,
mixed_precision=self.fp16)[1]
softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in self.transpose_backward])
end = time.time()
print(f"Inference Time (k={k}): {end - start} sec.")
inference_times.append(end-start)
if save_softmax:
softmax_fname = join(output_folder, fname + ".npz")
else:
softmax_fname = None
"""There is a problem with python process communication that prevents us from communicating obejcts
larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long
enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
filename or np.ndarray and will handle this automatically"""
if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85): # *0.85 just to be save
np.save(join(output_folder, fname + ".npy"), softmax_pred)
softmax_pred = join(output_folder, fname + ".npy")
results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax,
((softmax_pred, join(output_folder, fname + ".nii.gz"),
properties, interpolation_order, self.regions_class_order,
None, None,
softmax_fname, None, force_separate_z,
interpolation_order_z),
)
)
)
pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"),
join(self.gt_niftis_folder, fname + ".nii.gz")])
# Print inference times
print(f"Total Images: {len(inference_times)}")
print(f"Total Time: {sum(inference_times)} sec.")
print(f"Time per Image: {sum(inference_times) / len(inference_times)} sec.")
_ = [i.get() for i in results]
self.print_to_log_file("finished prediction")
# evaluate raw predictions
self.print_to_log_file("evaluation of raw predictions")
task = self.dataset_directory.split("/")[-1]
job_name = self.experiment_name
_ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)),
json_output_file=join(output_folder, "summary.json"),
json_name=job_name + " val tiled %s" % (str(use_sliding_window)),
json_author="Fabian",
json_task=task, num_threads=default_num_threads)
if run_postprocessing_on_folds:
# in the old unetr_pp we would stop here. Now we add a postprocessing. This postprocessing can remove everything
# except the largest connected component for each class. To see if this improves results, we do this for all
# classes and then rerun the evaluation. Those classes for which this resulted in an improved dice score will
# have this applied during inference as well
self.print_to_log_file("determining postprocessing")
determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name,
final_subf_name=validation_folder_name + "_postprocessed", debug=debug)
# after this the final predictions for the vlaidation set can be found in validation_folder_name_base + "_postprocessed"
# They are always in that folder, even if no postprocessing as applied!
# detemining postprocesing on a per-fold basis may be OK for this fold but what if another fold finds another
# postprocesing to be better? In this case we need to consolidate. At the time the consolidation is going to be
# done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a separate folder to
# be used later
gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
maybe_mkdir_p(gt_nifti_folder)
for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
success = False
attempts = 0
e = None
while not success and attempts < 10:
try:
shutil.copy(f, gt_nifti_folder)
success = True
except OSError as e:
attempts += 1
sleep(1)
if not success:
print("Could not copy gt nifti file %s into folder %s" % (f, gt_nifti_folder))
if e is not None:
raise e
self.network.train(current_mode)
def run_online_evaluation(self, output, target):
with torch.no_grad():
num_classes = output.shape[1]
output_softmax = softmax_helper(output)
output_seg = output_softmax.argmax(1)
target = target[:, 0]
axes = tuple(range(1, len(target.shape)))
tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
for c in range(1, num_classes):
tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes)
fp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes)
fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes)
tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy()
fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy()
fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy()
self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
self.online_eval_tp.append(list(tp_hard))
self.online_eval_fp.append(list(fp_hard))
self.online_eval_fn.append(list(fn_hard))
def finish_online_evaluation(self):
self.online_eval_tp = np.sum(self.online_eval_tp, 0)
self.online_eval_fp = np.sum(self.online_eval_fp, 0)
self.online_eval_fn = np.sum(self.online_eval_fn, 0)
global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in
zip(self.online_eval_tp, self.online_eval_fp, self.online_eval_fn)]
if not np.isnan(i)]
self.all_val_eval_metrics.append(np.mean(global_dc_per_class))
self.print_to_log_file("Average global foreground Dice:", str(global_dc_per_class))
self.print_to_log_file("(interpret this as an estimate for the Dice of the different classes. This is not "
"exact.)")
self.online_eval_foreground_dc = []
self.online_eval_tp = []
self.online_eval_fp = []
self.online_eval_fn = []
def save_checkpoint(self, fname, save_optimizer=True):
super(Trainer_tumor, self).save_checkpoint(fname, save_optimizer)
info = OrderedDict()
info['init'] = self.init_args
info['name'] = self.__class__.__name__
info['class'] = str(self.__class__)
info['plans'] = self.plans
write_pickle(info, fname + ".pkl")
================================================
FILE: unetr_pp/training/network_training/network_trainer_acdc.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from _warnings import warn
from typing import Tuple
import matplotlib
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from sklearn.model_selection import KFold
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import _LRScheduler
matplotlib.use("agg")
from time import time, sleep
import torch
import numpy as np
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
import sys
from collections import OrderedDict
import torch.backends.cudnn as cudnn
from abc import abstractmethod
from datetime import datetime
from tqdm import trange
from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda
class NetworkTrainer_acdc(object):
    def __init__(self, deterministic=True, fp16=False):
        """
        A generic class that can train almost any neural network (RNNs excluded). It provides basic functionality such
        as the training loop, tracking of training and validation losses (and the target metric if you implement it)
        Training can be terminated early if the validation loss (or the target metric if implemented) do not improve
        anymore. This is based on a moving average (MA) of the loss/metric instead of the raw values to get more smooth
        results.
        What you need to override:
        - __init__
        - initialize
        - run_online_evaluation (optional)
        - finish_online_evaluation (optional)
        - validate
        - predict_test_case
        :param deterministic: seed all RNGs and force deterministic cudnn (slower than benchmark mode)
        :param fp16: enable mixed-precision training
        """
        self.fp16 = fp16
        self.amp_grad_scaler = None  # GradScaler; created when fp16 training is set up
        if deterministic:
            np.random.seed(12345)
            torch.manual_seed(12345)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(12345)
            cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True
        ################# SET THESE IN self.initialize() ###################################
        self.network: Tuple[SegmentationNetwork, nn.DataParallel] = None
        self.optimizer = None
        self.lr_scheduler = None
        self.tr_gen = self.val_gen = None
        self.was_initialized = False
        ################# SET THESE IN INIT ################################################
        self.output_folder = None
        self.fold = None
        self.loss = None
        self.dataset_directory = None
        ################# SET THESE IN LOAD_DATASET OR DO_SPLIT ############################
        self.dataset = None  # these can be None for inference mode
        self.dataset_tr = self.dataset_val = None  # do not need to be used, they just appear if you are using the suggested load_dataset_and_do_split
        ################# THESE DO NOT NECESSARILY NEED TO BE MODIFIED #####################
        self.patience = 50
        self.val_eval_criterion_alpha = 0.9  # alpha * old + (1-alpha) * new
        # if this is too low then the moving average will be too noisy and the training may terminate early. If it is
        # too high the training will take forever
        self.train_loss_MA_alpha = 0.93  # alpha * old + (1-alpha) * new
        self.train_loss_MA_eps = 5e-4  # new MA must be at least this much better (smaller)
        self.max_num_epochs = 1000
        self.num_batches_per_epoch = 250
        self.num_val_batches_per_epoch = 50
        self.also_val_in_tr_mode = False
        self.lr_threshold = 1e-6  # the network will not terminate training if the lr is still above this threshold
        ################# LEAVE THESE ALONE ################################################
        self.val_eval_criterion_MA = None
        self.train_loss_MA = None
        self.best_val_eval_criterion_MA = None
        self.best_MA_tr_loss_for_patience = None
        self.best_epoch_based_on_MA_tr_loss = None
        self.all_tr_losses = []
        self.all_val_losses = []
        self.all_val_losses_tr_mode = []
        self.all_val_eval_metrics = []  # does not have to be used
        self.epoch = 0
        self.log_file = None
        self.deterministic = deterministic
        # progress bar can be toggled via the nnformer_use_progress_bar environment variable
        self.use_progress_bar = False
        if 'nnformer_use_progress_bar' in os.environ.keys():
            self.use_progress_bar = bool(int(os.environ['nnformer_use_progress_bar']))
        ################# Settings for saving checkpoints ##################################
        self.save_every = 50
        self.save_latest_only = False  # if false it will not store/overwrite _latest but separate files each
        # time an intermediate checkpoint is created
        self.save_intermediate_checkpoints = False  # whether or not to save checkpoint_latest
        self.save_best_checkpoint = True  # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
        self.save_final_checkpoint = True  # whether or not to save the final checkpoint
@abstractmethod
def initialize(self, training=True):
"""
create self.output_folder
modify self.output_folder if you are doing cross-validation (one folder per fold)
set self.tr_gen and self.val_gen
call self.initialize_network and self.initialize_optimizer_and_scheduler (important!)
finally set self.was_initialized to True
:param training:
:return:
"""
@abstractmethod
def load_dataset(self):
pass
def do_split(self):
"""
This is a suggestion for if your dataset is a dictionary (my personal standard)
:return:
"""
splits_file = join(self.dataset_directory, "splits_final.pkl")
if not isfile(splits_file):
self.print_to_log_file("Creating new split...")
splits = []
all_keys_sorted = np.sort(list(self.dataset.keys()))
kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
train_keys = np.array(all_keys_sorted)[train_idx]
test_keys = np.array(all_keys_sorted)[test_idx]
splits.append(OrderedDict())
splits[-1]['train'] = train_keys
splits[-1]['val'] = test_keys
save_pickle(splits, splits_file)
splits = load_pickle(splits_file)
if self.fold == "all":
tr_keys = val_keys = list(self.dataset.keys())
else:
tr_keys = splits[self.fold]['train']
val_keys = splits[self.fold]['val']
tr_keys.sort()
val_keys.sort()
self.dataset_tr = OrderedDict()
for i in tr_keys:
self.dataset_tr[i] = self.dataset[i]
self.dataset_val = OrderedDict()
for i in val_keys:
self.dataset_val[i] = self.dataset[i]
    def plot_progress(self):
        """
        Render progress.png in self.output_folder: training/validation losses on the
        left y-axis and (when one value per epoch exists) the evaluation metric on a
        twin right y-axis.
        :return:
        """
        try:
            font = {'weight': 'normal',
                    'size': 18}

            matplotlib.rc('font', **font)

            fig = plt.figure(figsize=(30, 24))
            ax = fig.add_subplot(111)
            # twin axis so the metric (different scale) shares the epoch axis
            ax2 = ax.twinx()

            x_values = list(range(self.epoch + 1))

            ax.plot(x_values, self.all_tr_losses, color='b', ls='-', label="loss_tr")

            ax.plot(x_values, self.all_val_losses, color='r', ls='-', label="loss_val, train=False")

            if len(self.all_val_losses_tr_mode) > 0:
                ax.plot(x_values, self.all_val_losses_tr_mode, color='g', ls='-', label="loss_val, train=True")
            # only plot the metric when there is exactly one value per epoch,
            # otherwise the x axis would not line up
            if len(self.all_val_eval_metrics) == len(x_values):
                ax2.plot(x_values, self.all_val_eval_metrics, color='g', ls='--', label="evaluation metric")

            ax.set_xlabel("epoch")
            ax.set_ylabel("loss")
            ax2.set_ylabel("evaluation metric")
            ax.legend()
            ax2.legend(loc=9)

            fig.savefig(join(self.output_folder, "progress.png"))
            plt.close()
            # the triple-quoted block below is dead code kept from upstream (per-class dice plot)
            '''
            #below is et ..dice
            font = {'weight': 'normal',
                    'size': 18}
            matplotlib.rc('font', **font)
            fig = plt.figure(figsize=(30, 24))
            ax = fig.add_subplot(111)
            ax2 = ax.twinx()
            x_values = list(range(self.epoch + 1))
            ax.plot(x_values, self.plot_et, color='b', ls='-', label="dice_et")
            ax.plot(x_values, self.plot_tc, color='g', ls='-', label="dice_tc")
            ax.plot(x_values, self.plot_wt, color='r', ls='-', label="dice_wt")
            ax.set_xlabel("epoch")
            ax.set_ylabel("dice")
            ax.legend()
            fig.savefig(join(self.output_folder, "dice_progress.png"))
            plt.close()
            '''
        except IOError:
            # plotting must never kill a training run
            self.print_to_log_file("failed to plot: ", sys.exc_info())
def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True):
timestamp = time()
dt_object = datetime.fromtimestamp(timestamp)
if add_timestamp:
args = ("%s:" % dt_object, *args)
if self.log_file is None:
maybe_mkdir_p(self.output_folder)
timestamp = datetime.now()
self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" %
(timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute,
timestamp.second))
with open(self.log_file, 'w') as f:
f.write("Starting... \n")
successful = False
max_attempts = 5
ctr = 0
while not successful and ctr < max_attempts:
try:
with open(self.log_file, 'a+') as f:
for a in args:
f.write(str(a))
f.write(" ")
f.write("\n")
successful = True
except IOError:
print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info())
sleep(0.5)
ctr += 1
if also_print_to_console:
print(*args)
    def save_checkpoint(self, fname, save_optimizer=True):
        """
        Serialize the full training state (weights, optimizer, lr scheduler,
        loss/metric history, AMP scaler) to fname via torch.save.

        :param fname: destination file path
        :param save_optimizer: when False, 'optimizer_state_dict' is stored as None
        :return:
        """
        start_time = time()
        state_dict = self.network.state_dict()
        # move all weights to CPU so the checkpoint can be loaded on any device
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].cpu()
        lr_sched_state_dct = None
        # some schedulers (e.g. polynomial decay implemented by hand) may not have state
        if self.lr_scheduler is not None and hasattr(self.lr_scheduler,
                                                     'state_dict'):  # not isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
            lr_sched_state_dct = self.lr_scheduler.state_dict()
            # WTF is this!?
            # for key in lr_sched_state_dct.keys():
            #     lr_sched_state_dct[key] = lr_sched_state_dct[key]
        if save_optimizer:
            optimizer_state_dict = self.optimizer.state_dict()
        else:
            optimizer_state_dict = None

        self.print_to_log_file("saving checkpoint...")
        # epoch is stored +1 because it is incremented after on_epoch_end;
        # load_checkpoint_ram compensates when histories disagree
        save_this = {
            'epoch': self.epoch + 1,
            'state_dict': state_dict,
            'optimizer_state_dict': optimizer_state_dict,
            'lr_scheduler_state_dict': lr_sched_state_dct,
            'plot_stuff': (self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode,
                           self.all_val_eval_metrics),
            'best_stuff' : (self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA)}
        if self.amp_grad_scaler is not None:
            save_this['amp_grad_scaler'] = self.amp_grad_scaler.state_dict()

        torch.save(save_this, fname)
        self.print_to_log_file("done, saving took %.2f seconds" % (time() - start_time))
def load_best_checkpoint(self, train=True):
if self.fold is None:
raise RuntimeError("Cannot load best checkpoint if self.fold is None")
if isfile(join(self.output_folder, "model_best.model")):
self.load_checkpoint(join(self.output_folder, "model_best.model"), train=train)
else:
self.print_to_log_file("WARNING! model_best.model does not exist! Cannot load best checkpoint. Falling "
"back to load_latest_checkpoint")
self.load_latest_checkpoint(train)
def load_latest_checkpoint(self, train=True):
if isfile(join(self.output_folder, "model_final_checkpoint.model")):
return self.load_checkpoint(join(self.output_folder, "model_final_checkpoint.model"), train=train)
if isfile(join(self.output_folder, "model_latest.model")):
return self.load_checkpoint(join(self.output_folder, "model_latest.model"), train=train)
if isfile(join(self.output_folder, "model_best.model")):
return self.load_best_checkpoint(train)
raise RuntimeError("No checkpoint found")
def load_final_checkpoint(self, train=False):
filename = join(self.output_folder, "model_final_checkpoint.model")
if not isfile(filename):
raise RuntimeError("Final checkpoint not found. Expected: %s. Please finish the training first." % filename)
return self.load_checkpoint(filename, train=train)
def load_checkpoint(self, fname, train=True):
self.print_to_log_file("loading checkpoint", fname, "train=", train)
if not self.was_initialized:
self.initialize(train)
# saved_model = torch.load(fname, map_location=torch.device('cuda', torch.cuda.current_device()))
saved_model = torch.load(fname, map_location=torch.device('cpu'))
self.load_checkpoint_ram(saved_model, train)
@abstractmethod
def initialize_network(self):
"""
initialize self.network here
:return:
"""
pass
@abstractmethod
def initialize_optimizer_and_scheduler(self):
"""
initialize self.optimizer and self.lr_scheduler (if applicable) here
:return:
"""
pass
def load_checkpoint_ram(self, checkpoint, train=True):
"""
used for if the checkpoint is already in ram
:param checkpoint:
:param train:
:return:
"""
print('I am here !!!')
if not self.was_initialized:
self.initialize(train)
new_state_dict = OrderedDict()
curr_state_dict_keys = list(self.network.state_dict().keys())
# if state dict comes form nn.DataParallel but we use non-parallel model here then the state dict keys do not
# match. Use heuristic to make it match
for k, value in checkpoint['state_dict'].items():
key = k
if key not in curr_state_dict_keys and key.startswith('module.'):
key = key[7:]
new_state_dict[key] = value
if self.fp16:
self._maybe_init_amp()
if 'amp_grad_scaler' in checkpoint.keys():
self.amp_grad_scaler.load_state_dict(checkpoint['amp_grad_scaler'])
self.network.load_state_dict(new_state_dict)
self.epoch = checkpoint['epoch']
if train:
optimizer_state_dict = checkpoint['optimizer_state_dict']
if optimizer_state_dict is not None:
self.optimizer.load_state_dict(optimizer_state_dict)
if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'load_state_dict') and checkpoint[
'lr_scheduler_state_dict'] is not None:
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
if issubclass(self.lr_scheduler.__class__, _LRScheduler):
self.lr_scheduler.step(self.epoch)
self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics = checkpoint[
'plot_stuff']
# load best loss (if present)
print('best_stuff' in checkpoint.keys())
if 'best_stuff' in checkpoint.keys():
self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA = checkpoint[
'best_stuff']
# after the training is done, the epoch is incremented one more time in my old code. This results in
# self.epoch = 1001 for old trained models when the epoch is actually 1000. This causes issues because
# len(self.all_tr_losses) = 1000 and the plot function will fail. We can easily detect and correct that here
if self.epoch != len(self.all_tr_losses):
self.print_to_log_file("WARNING in loading checkpoint: self.epoch != len(self.all_tr_losses). This is "
"due to an old bug and should only appear when you are loading old models. New "
"models should have this fixed! self.epoch is now set to len(self.all_tr_losses)")
self.epoch = len(self.all_tr_losses)
self.all_tr_losses = self.all_tr_losses[:self.epoch]
self.all_val_losses = self.all_val_losses[:self.epoch]
self.all_val_losses_tr_mode = self.all_val_losses_tr_mode[:self.epoch]
self.all_val_eval_metrics = self.all_val_eval_metrics[:self.epoch]
self._maybe_init_amp()
def _maybe_init_amp(self):
if self.fp16 and self.amp_grad_scaler is None:
self.amp_grad_scaler = GradScaler()
def plot_network_architecture(self):
"""
can be implemented (see nnFormerTrainer) but does not have to. Not implemented here because it imposes stronger
assumptions on the presence of class variables
:return:
"""
pass
    def run_training(self):
        """
        Main training loop: iterate epochs until max_num_epochs is reached or
        on_epoch_end requests early stopping, then write the final checkpoint.
        :return:
        """
        if not torch.cuda.is_available():
            self.print_to_log_file("WARNING!!! You are attempting to run training on a CPU (torch.cuda.is_available() is False). This can be VERY slow!")

        # pull one batch from each generator up front so background workers start producing
        _ = self.tr_gen.next()
        _ = self.val_gen.next()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self._maybe_init_amp()

        maybe_mkdir_p(self.output_folder)
        self.plot_network_architecture()

        if cudnn.benchmark and cudnn.deterministic:
            warn("torch.backends.cudnn.deterministic is True indicating a deterministic training is desired. "
                 "But torch.backends.cudnn.benchmark is True as well and this will prevent deterministic training! "
                 "If you want deterministic then set benchmark=False")

        if not self.was_initialized:
            self.initialize(True)

        while self.epoch < self.max_num_epochs:
            self.print_to_log_file("\nepoch: ", self.epoch)
            epoch_start_time = time()
            train_losses_epoch = []

            # train one epoch
            self.network.train()

            if self.use_progress_bar:
                with trange(self.num_batches_per_epoch) as tbar:
                    for b in tbar:
                        tbar.set_description("Epoch {}/{}".format(self.epoch+1, self.max_num_epochs))
                        l = self.run_iteration(self.tr_gen, True)
                        tbar.set_postfix(loss=l)
                        train_losses_epoch.append(l)
            else:
                for _ in range(self.num_batches_per_epoch):
                    l = self.run_iteration(self.tr_gen, True)
                    train_losses_epoch.append(l)

            self.all_tr_losses.append(np.mean(train_losses_epoch))
            self.print_to_log_file("train loss : %.4f" % self.all_tr_losses[-1])

            with torch.no_grad():
                # validation with train=False
                self.network.eval()
                val_losses = []
                for b in range(self.num_val_batches_per_epoch):
                    l = self.run_iteration(self.val_gen, False, True)
                    val_losses.append(l)
                self.all_val_losses.append(np.mean(val_losses))
                self.print_to_log_file("validation loss: %.4f" % self.all_val_losses[-1])

                if self.also_val_in_tr_mode:
                    self.network.train()
                    # validation with train=True
                    val_losses = []
                    for b in range(self.num_val_batches_per_epoch):
                        l = self.run_iteration(self.val_gen, False)
                        val_losses.append(l)
                    self.all_val_losses_tr_mode.append(np.mean(val_losses))
                    self.print_to_log_file("validation loss (train=True): %.4f" % self.all_val_losses_tr_mode[-1])

            self.update_train_loss_MA()  # needed for lr scheduler and stopping of training

            continue_training = self.on_epoch_end()

            epoch_end_time = time()

            if not continue_training:
                # allows for early stopping
                break

            self.epoch += 1
            self.print_to_log_file("This epoch took %f s\n" % (epoch_end_time - epoch_start_time))

        self.epoch -= 1  # if we don't do this we can get a problem with loading model_final_checkpoint.

        if self.save_final_checkpoint: self.save_checkpoint(join(self.output_folder, "model_final_checkpoint.model"))
        # now we can delete latest as it will be identical with final
        if isfile(join(self.output_folder, "model_latest.model")):
            os.remove(join(self.output_folder, "model_latest.model"))
        if isfile(join(self.output_folder, "model_latest.model.pkl")):
            os.remove(join(self.output_folder, "model_latest.model.pkl"))
def maybe_update_lr(self):
# maybe update learning rate
if self.lr_scheduler is not None:
assert isinstance(self.lr_scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))
if isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
# lr scheduler is updated with moving average val loss. should be more robust
self.lr_scheduler.step(self.val_eval_criterion_MA)
else:
self.lr_scheduler.step(self.epoch + 1)
self.print_to_log_file("lr is now (scheduler) %s" % str(self.optimizer.param_groups[0]['lr']))
def maybe_save_checkpoint(self):
"""
Saves a checkpoint every save_ever epochs.
:return:
"""
if self.save_intermediate_checkpoints and (self.epoch % self.save_every == (self.save_every - 1)):
self.print_to_log_file("saving scheduled checkpoint file...")
if not self.save_latest_only:
self.save_checkpoint(join(self.output_folder, "model_ep_%03.0d.model" % (self.epoch + 1)))
self.save_checkpoint(join(self.output_folder, "model_latest.model"))
self.print_to_log_file("done")
def update_eval_criterion_MA(self):
"""
If self.all_val_eval_metrics is unused (len=0) then we fall back to using -self.all_val_losses for the MA to determine early stopping
(not a minimization, but a maximization of a metric and therefore the - in the latter case)
:return:
"""
if self.val_eval_criterion_MA is None:
if len(self.all_val_eval_metrics) == 0:
self.val_eval_criterion_MA = - self.all_val_losses[-1]
else:
self.val_eval_criterion_MA = self.all_val_eval_metrics[-1]
else:
if len(self.all_val_eval_metrics) == 0:
"""
We here use alpha * old - (1 - alpha) * new because new in this case is the vlaidation loss and lower
is better, so we need to negate it.
"""
self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA - (
1 - self.val_eval_criterion_alpha) * \
self.all_val_losses[-1]
else:
self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA + (
1 - self.val_eval_criterion_alpha) * \
self.all_val_eval_metrics[-1]
def manage_patience(self):
# update patience
continue_training = True
if self.patience is not None:
# if best_MA_tr_loss_for_patience and best_epoch_based_on_MA_tr_loss were not yet initialized,
# initialize them
if self.best_MA_tr_loss_for_patience is None:
self.best_MA_tr_loss_for_patience = self.train_loss_MA
if self.best_epoch_based_on_MA_tr_loss is None:
self.best_epoch_based_on_MA_tr_loss = self.epoch
if self.best_val_eval_criterion_MA is None:
self.best_val_eval_criterion_MA = self.val_eval_criterion_MA
# check if the current epoch is the best one according to moving average of validation criterion. If so
# then save 'best' model
# Do not use this for validation. This is intended for test set prediction only.
self.print_to_log_file("current best_val_eval_criterion_MA is %.4f0" % self.best_val_eval_criterion_MA)
self.print_to_log_file("current val_eval_criterion_MA is %.4f" % self.val_eval_criterion_MA)
if self.val_eval_criterion_MA > self.best_val_eval_criterion_MA:
self.best_val_eval_criterion_MA = self.val_eval_criterion_MA
#self.print_to_log_file("saving best epoch checkpoint...")
if self.save_best_checkpoint: self.save_checkpoint(join(self.output_folder, "model_best.model"))
# Now see if the moving average of the train loss has improved. If yes then reset patience, else
# increase patience
if self.train_loss_MA + self.train_loss_MA_eps < self.best_MA_tr_loss_for_patience:
self.best_MA_tr_loss_for_patience = self.train_loss_MA
self.best_epoch_based_on_MA_tr_loss = self.epoch
#self.print_to_log_file("New best epoch (train loss MA): %03.4f" % self.best_MA_tr_loss_for_patience)
else:
pass
#self.print_to_log_file("No improvement: current train MA %03.4f, best: %03.4f, eps is %03.4f" %
# (self.train_loss_MA, self.best_MA_tr_loss_for_patience, self.train_loss_MA_eps))
# if patience has reached its maximum then finish training (provided lr is low enough)
if self.epoch - self.best_epoch_based_on_MA_tr_loss > self.patience:
if self.optimizer.param_groups[0]['lr'] > self.lr_threshold:
#self.print_to_log_file("My patience ended, but I believe I need more time (lr > 1e-6)")
self.best_epoch_based_on_MA_tr_loss = self.epoch - self.patience // 2
else:
#self.print_to_log_file("My patience ended")
continue_training = False
else:
pass
#self.print_to_log_file(
# "Patience: %d/%d" % (self.epoch - self.best_epoch_based_on_MA_tr_loss, self.patience))
return continue_training
def on_epoch_end(self):
self.finish_online_evaluation() # does not have to do anything, but can be used to update self.all_val_eval_
# metrics
self.plot_progress()
self.maybe_save_checkpoint()
self.update_eval_criterion_MA()
self.maybe_update_lr()
continue_training = self.manage_patience()
return continue_training
def update_train_loss_MA(self):
if self.train_loss_MA is None:
self.train_loss_MA = self.all_tr_losses[-1]
else:
self.train_loss_MA = self.train_loss_MA_alpha * self.train_loss_MA + (1 - self.train_loss_MA_alpha) * \
self.all_tr_losses[-1]
    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        """
        Process one batch: forward pass, loss and (optionally) backprop and online evaluation.

        :param data_generator: yields dicts with at least the keys 'data' and 'target'
        :param do_backprop: when True, compute gradients and take an optimizer step
        :param run_online_evaluation: when True, feed (output, target) to run_online_evaluation
        :return: the batch loss as a numpy value
        """
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        if self.fp16:
            # mixed-precision path: forward under autocast, scaled backward via GradScaler
            with autocast():
                output = self.network(data)
                del data  # free the input early; only output/target are needed below
                l = self.loss(output, target)

            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            output = self.network(data)
            del data
            l = self.loss(output, target)

            if do_backprop:
                l.backward()
                self.optimizer.step()

        if run_online_evaluation:
            self.run_online_evaluation(output, target)

        del target

        return l.detach().cpu().numpy()
def run_online_evaluation(self, *args, **kwargs):
"""
Can be implemented, does not have to
:param output_torch:
:param target_npy:
:return:
"""
pass
def finish_online_evaluation(self):
"""
Can be implemented, does not have to
:return:
"""
pass
@abstractmethod
def validate(self, *args, **kwargs):
pass
    def find_lr(self, num_iters=1000, init_value=1e-6, final_value=10., beta=0.98):
        """
        LR range test: sweep the learning rate geometrically from init_value to
        final_value, recording the smoothed loss at each step; the saved curve
        (lr_finder.png) helps pick a good learning rate.

        stolen and adapted from here: https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
        :param num_iters: number of training iterations in the sweep
        :param init_value: starting learning rate
        :param final_value: final learning rate
        :param beta: smoothing factor for the EMA of the loss
        :return: (log10 of the tested lrs, smoothed losses)
        """
        import math
        self._maybe_init_amp()
        # multiplicative step so the sweep is geometric in lr
        mult = (final_value / init_value) ** (1 / num_iters)
        lr = init_value
        self.optimizer.param_groups[0]['lr'] = lr
        avg_loss = 0.
        best_loss = 0.
        losses = []
        log_lrs = []

        for batch_num in range(1, num_iters + 1):
            # +1 because this one here is not designed to have negative loss...
            loss = self.run_iteration(self.tr_gen, do_backprop=True, run_online_evaluation=False).data.item() + 1

            # Compute the smoothed loss (bias-corrected EMA)
            avg_loss = beta * avg_loss + (1 - beta) * loss
            smoothed_loss = avg_loss / (1 - beta ** batch_num)

            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > 4 * best_loss:
                break

            # Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss

            # Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))

            # Update the lr for the next step
            lr *= mult
            self.optimizer.param_groups[0]['lr'] = lr

        import matplotlib.pyplot as plt
        lrs = [10 ** i for i in log_lrs]
        fig = plt.figure()
        plt.xscale('log')
        # drop the first/last few points, which are typically off-scale
        plt.plot(lrs[10:-5], losses[10:-5])
        plt.savefig(join(self.output_folder, "lr_finder.png"))
        plt.close()
        return log_lrs, losses
================================================
FILE: unetr_pp/training/network_training/network_trainer_lung.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from _warnings import warn
from typing import Tuple
import matplotlib
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from sklearn.model_selection import KFold
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import _LRScheduler
matplotlib.use("agg")
from time import time, sleep
import torch
import numpy as np
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
import sys
from collections import OrderedDict
import torch.backends.cudnn as cudnn
from abc import abstractmethod
from datetime import datetime
from tqdm import trange
from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda
class NetworkTrainer_lung(object):
    """Generic training-loop base class (lung task variant of the nnUNet-style trainer)."""

    def __init__(self, deterministic=True, fp16=False):
        """
        A generic class that can train almost any neural network (RNNs excluded). It provides basic functionality such
        as the training loop, tracking of training and validation losses (and the target metric if you implement it)
        Training can be terminated early if the validation loss (or the target metric if implemented) do not improve
        anymore. This is based on a moving average (MA) of the loss/metric instead of the raw values to get more smooth
        results.

        What you need to override:
        - __init__
        - initialize
        - run_online_evaluation (optional)
        - finish_online_evaluation (optional)
        - validate
        - predict_test_case
        """
        self.fp16 = fp16
        self.amp_grad_scaler = None  # created lazily in _maybe_init_amp when fp16 is enabled

        if deterministic:
            # fixed seeds + deterministic cudnn for reproducible (but slower) runs
            np.random.seed(12345)
            torch.manual_seed(12345)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(12345)
            cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True

        ################# SET THESE IN self.initialize() ###################################
        # NOTE(review): the annotation presumably means Union[SegmentationNetwork, nn.DataParallel];
        # Tuple looks like a typo inherited from upstream — verify before relying on it
        self.network: Tuple[SegmentationNetwork, nn.DataParallel] = None
        self.optimizer = None
        self.lr_scheduler = None
        self.tr_gen = self.val_gen = None  # training / validation batch generators
        self.was_initialized = False

        ################# SET THESE IN INIT ################################################
        self.output_folder = None
        self.fold = None
        self.loss = None
        self.dataset_directory = None

        ################# SET THESE IN LOAD_DATASET OR DO_SPLIT ############################
        self.dataset = None  # these can be None for inference mode
        self.dataset_tr = self.dataset_val = None  # do not need to be used, they just appear if you are using the suggested load_dataset_and_do_split

        ################# THESE DO NOT NECESSARILY NEED TO BE MODIFIED #####################
        self.patience = 50
        self.val_eval_criterion_alpha = 0.9  # alpha * old + (1-alpha) * new
        # if this is too low then the moving average will be too noisy and the training may terminate early. If it is
        # too high the training will take forever
        self.train_loss_MA_alpha = 0.93  # alpha * old + (1-alpha) * new
        self.train_loss_MA_eps = 5e-4  # new MA must be at least this much better (smaller)
        self.max_num_epochs = 1000
        self.num_batches_per_epoch = 250
        self.num_val_batches_per_epoch = 50
        self.also_val_in_tr_mode = False
        self.lr_threshold = 1e-6  # the network will not terminate training if the lr is still above this threshold

        ################# LEAVE THESE ALONE ################################################
        self.val_eval_criterion_MA = None
        self.train_loss_MA = None
        self.best_val_eval_criterion_MA = None
        self.best_MA_tr_loss_for_patience = None
        self.best_epoch_based_on_MA_tr_loss = None
        self.all_tr_losses = []
        self.all_val_losses = []
        self.all_val_losses_tr_mode = []
        self.all_val_eval_metrics = []  # does not have to be used
        self.epoch = 0
        self.log_file = None
        self.deterministic = deterministic

        self.use_progress_bar = False
        if 'nnformer_use_progress_bar' in os.environ.keys():
            self.use_progress_bar = bool(int(os.environ['nnformer_use_progress_bar']))

        ################# Settings for saving checkpoints ##################################
        self.save_every = 50
        self.save_latest_only = False  # if false it will not store/overwrite _latest but separate files each
        # time an intermediate checkpoint is created
        self.save_intermediate_checkpoints = False  # whether or not to save checkpoint_latest
        self.save_best_checkpoint = True  # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
        self.save_final_checkpoint = True  # whether or not to save the final checkpoint
@abstractmethod
def initialize(self, training=True):
"""
create self.output_folder
modify self.output_folder if you are doing cross-validation (one folder per fold)
set self.tr_gen and self.val_gen
call self.initialize_network and self.initialize_optimizer_and_scheduler (important!)
finally set self.was_initialized to True
:param training:
:return:
"""
@abstractmethod
def load_dataset(self):
pass
def do_split(self):
"""
This is a suggestion for if your dataset is a dictionary (my personal standard)
:return:
"""
splits_file = join(self.dataset_directory, "splits_final.pkl")
if not isfile(splits_file):
self.print_to_log_file("Creating new split...")
splits = []
all_keys_sorted = np.sort(list(self.dataset.keys()))
kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
train_keys = np.array(all_keys_sorted)[train_idx]
test_keys = np.array(all_keys_sorted)[test_idx]
splits.append(OrderedDict())
splits[-1]['train'] = train_keys
splits[-1]['val'] = test_keys
save_pickle(splits, splits_file)
splits = load_pickle(splits_file)
if self.fold == "all":
tr_keys = val_keys = list(self.dataset.keys())
else:
tr_keys = splits[self.fold]['train']
val_keys = splits[self.fold]['val']
tr_keys.sort()
val_keys.sort()
self.dataset_tr = OrderedDict()
for i in tr_keys:
self.dataset_tr[i] = self.dataset[i]
self.dataset_val = OrderedDict()
for i in val_keys:
self.dataset_val[i] = self.dataset[i]
    def plot_progress(self):
        """
        Render progress.png in self.output_folder: training/validation losses on the
        left y-axis and (when one value per epoch exists) the evaluation metric on a
        twin right y-axis.
        :return:
        """
        try:
            font = {'weight': 'normal',
                    'size': 18}

            matplotlib.rc('font', **font)

            fig = plt.figure(figsize=(30, 24))
            ax = fig.add_subplot(111)
            # twin axis so the metric (different scale) shares the epoch axis
            ax2 = ax.twinx()

            x_values = list(range(self.epoch + 1))

            ax.plot(x_values, self.all_tr_losses, color='b', ls='-', label="loss_tr")

            ax.plot(x_values, self.all_val_losses, color='r', ls='-', label="loss_val, train=False")

            if len(self.all_val_losses_tr_mode) > 0:
                ax.plot(x_values, self.all_val_losses_tr_mode, color='g', ls='-', label="loss_val, train=True")
            # only plot the metric when there is exactly one value per epoch,
            # otherwise the x axis would not line up
            if len(self.all_val_eval_metrics) == len(x_values):
                ax2.plot(x_values, self.all_val_eval_metrics, color='g', ls='--', label="evaluation metric")

            ax.set_xlabel("epoch")
            ax.set_ylabel("loss")
            ax2.set_ylabel("evaluation metric")
            ax.legend()
            ax2.legend(loc=9)

            fig.savefig(join(self.output_folder, "progress.png"))
            plt.close()
            # the triple-quoted block below is dead code kept from upstream (per-class dice plot)
            '''
            #below is et ..dice
            font = {'weight': 'normal',
                    'size': 18}
            matplotlib.rc('font', **font)
            fig = plt.figure(figsize=(30, 24))
            ax = fig.add_subplot(111)
            ax2 = ax.twinx()
            x_values = list(range(self.epoch + 1))
            ax.plot(x_values, self.plot_et, color='b', ls='-', label="dice_et")
            ax.plot(x_values, self.plot_tc, color='g', ls='-', label="dice_tc")
            ax.plot(x_values, self.plot_wt, color='r', ls='-', label="dice_wt")
            ax.set_xlabel("epoch")
            ax.set_ylabel("dice")
            ax.legend()
            fig.savefig(join(self.output_folder, "dice_progress.png"))
            plt.close()
            '''
        except IOError:
            # plotting must never kill a training run
            self.print_to_log_file("failed to plot: ", sys.exc_info())
def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True):
timestamp = time()
dt_object = datetime.fromtimestamp(timestamp)
if add_timestamp:
args = ("%s:" % dt_object, *args)
if self.log_file is None:
maybe_mkdir_p(self.output_folder)
timestamp = datetime.now()
self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" %
(timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute,
timestamp.second))
with open(self.log_file, 'w') as f:
f.write("Starting... \n")
successful = False
max_attempts = 5
ctr = 0
while not successful and ctr < max_attempts:
try:
with open(self.log_file, 'a+') as f:
for a in args:
f.write(str(a))
f.write(" ")
f.write("\n")
successful = True
except IOError:
print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info())
sleep(0.5)
ctr += 1
if also_print_to_console:
print(*args)
    def save_checkpoint(self, fname, save_optimizer=True):
        """
        Serialize the full training state (weights, optimizer, lr scheduler,
        loss/metric history, AMP scaler) to fname via torch.save.

        :param fname: destination file path
        :param save_optimizer: when False, 'optimizer_state_dict' is stored as None
        :return:
        """
        start_time = time()
        state_dict = self.network.state_dict()
        # move all weights to CPU so the checkpoint can be loaded on any device
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].cpu()
        lr_sched_state_dct = None
        # some schedulers may not expose state at all
        if self.lr_scheduler is not None and hasattr(self.lr_scheduler,
                                                     'state_dict'):  # not isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
            lr_sched_state_dct = self.lr_scheduler.state_dict()
            # WTF is this!?
            # for key in lr_sched_state_dct.keys():
            #     lr_sched_state_dct[key] = lr_sched_state_dct[key]
        if save_optimizer:
            optimizer_state_dict = self.optimizer.state_dict()
        else:
            optimizer_state_dict = None

        self.print_to_log_file("saving checkpoint...")
        # epoch is stored +1 because it is incremented after on_epoch_end;
        # load_checkpoint_ram compensates when histories disagree
        save_this = {
            'epoch': self.epoch + 1,
            'state_dict': state_dict,
            'optimizer_state_dict': optimizer_state_dict,
            'lr_scheduler_state_dict': lr_sched_state_dct,
            'plot_stuff': (self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode,
                           self.all_val_eval_metrics),
            'best_stuff' : (self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA)}
        if self.amp_grad_scaler is not None:
            save_this['amp_grad_scaler'] = self.amp_grad_scaler.state_dict()

        torch.save(save_this, fname)
        self.print_to_log_file("done, saving took %.2f seconds" % (time() - start_time))
def load_best_checkpoint(self, train=True):
if self.fold is None:
raise RuntimeError("Cannot load best checkpoint if self.fold is None")
if isfile(join(self.output_folder, "model_best.model")):
self.load_checkpoint(join(self.output_folder, "model_best.model"), train=train)
else:
self.print_to_log_file("WARNING! model_best.model does not exist! Cannot load best checkpoint. Falling "
"back to load_latest_checkpoint")
self.load_latest_checkpoint(train)
def load_latest_checkpoint(self, train=True):
if isfile(join(self.output_folder, "model_final_checkpoint.model")):
return self.load_checkpoint(join(self.output_folder, "model_final_checkpoint.model"), train=train)
if isfile(join(self.output_folder, "model_latest.model")):
return self.load_checkpoint(join(self.output_folder, "model_latest.model"), train=train)
if isfile(join(self.output_folder, "model_best.model")):
return self.load_best_checkpoint(train)
raise RuntimeError("No checkpoint found")
def load_final_checkpoint(self, train=False):
filename = join(self.output_folder, "model_final_checkpoint.model")
if not isfile(filename):
raise RuntimeError("Final checkpoint not found. Expected: %s. Please finish the training first." % filename)
return self.load_checkpoint(filename, train=train)
def load_checkpoint(self, fname, train=True):
self.print_to_log_file("loading checkpoint", fname, "train=", train)
if not self.was_initialized:
self.initialize(train)
# saved_model = torch.load(fname, map_location=torch.device('cuda', torch.cuda.current_device()))
saved_model = torch.load(fname, map_location=torch.device('cpu'))
self.load_checkpoint_ram(saved_model, train)
    @abstractmethod
    def initialize_network(self):
        """
        initialize self.network here
        Subclasses must create the model and assign it to self.network.
        :return:
        """
        pass
    @abstractmethod
    def initialize_optimizer_and_scheduler(self):
        """
        initialize self.optimizer and self.lr_scheduler (if applicable) here
        Subclasses must create these; self.lr_scheduler may stay None.
        :return:
        """
        pass
def load_checkpoint_ram(self, checkpoint, train=True):
"""
used for if the checkpoint is already in ram
:param checkpoint:
:param train:
:return:
"""
print('I am here !!!')
if not self.was_initialized:
self.initialize(train)
new_state_dict = OrderedDict()
curr_state_dict_keys = list(self.network.state_dict().keys())
# if state dict comes form nn.DataParallel but we use non-parallel model here then the state dict keys do not
# match. Use heuristic to make it match
for k, value in checkpoint['state_dict'].items():
key = k
if key not in curr_state_dict_keys and key.startswith('module.'):
key = key[7:]
new_state_dict[key] = value
if self.fp16:
self._maybe_init_amp()
if 'amp_grad_scaler' in checkpoint.keys():
self.amp_grad_scaler.load_state_dict(checkpoint['amp_grad_scaler'])
self.network.load_state_dict(new_state_dict)
self.epoch = checkpoint['epoch']
if train:
optimizer_state_dict = checkpoint['optimizer_state_dict']
if optimizer_state_dict is not None:
self.optimizer.load_state_dict(optimizer_state_dict)
if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'load_state_dict') and checkpoint[
'lr_scheduler_state_dict'] is not None:
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
if issubclass(self.lr_scheduler.__class__, _LRScheduler):
self.lr_scheduler.step(self.epoch)
self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics = checkpoint[
'plot_stuff']
# load best loss (if present)
print('best_stuff' in checkpoint.keys())
if 'best_stuff' in checkpoint.keys():
self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA = checkpoint[
'best_stuff']
# after the training is done, the epoch is incremented one more time in my old code. This results in
# self.epoch = 1001 for old trained models when the epoch is actually 1000. This causes issues because
# len(self.all_tr_losses) = 1000 and the plot function will fail. We can easily detect and correct that here
if self.epoch != len(self.all_tr_losses):
self.print_to_log_file("WARNING in loading checkpoint: self.epoch != len(self.all_tr_losses). This is "
"due to an old bug and should only appear when you are loading old models. New "
"models should have this fixed! self.epoch is now set to len(self.all_tr_losses)")
self.epoch = len(self.all_tr_losses)
self.all_tr_losses = self.all_tr_losses[:self.epoch]
self.all_val_losses = self.all_val_losses[:self.epoch]
self.all_val_losses_tr_mode = self.all_val_losses_tr_mode[:self.epoch]
self.all_val_eval_metrics = self.all_val_eval_metrics[:self.epoch]
self._maybe_init_amp()
def _maybe_init_amp(self):
if self.fp16 and self.amp_grad_scaler is None:
self.amp_grad_scaler = GradScaler()
    def plot_network_architecture(self):
        """
        can be implemented (see nnFormerTrainer) but does not have to. Not implemented here because it imposes stronger
        assumptions on the presence of class variables
        Called once at the start of run_training; no-op by default.
        :return:
        """
        pass
    def run_training(self):
        """
        Main training loop: runs epochs of training and validation batches, tracks
        losses and moving averages, and manages checkpoints until max_num_epochs is
        reached or on_epoch_end requests early stopping.
        """
        if not torch.cuda.is_available():
            self.print_to_log_file("WARNING!!! You are attempting to run training on a CPU (torch.cuda.is_available() is False). This can be VERY slow!")
        # warm up the data generators before the loop starts
        _ = self.tr_gen.next()
        _ = self.val_gen.next()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        self._maybe_init_amp()
        maybe_mkdir_p(self.output_folder)
        self.plot_network_architecture()
        if cudnn.benchmark and cudnn.deterministic:
            warn("torch.backends.cudnn.deterministic is True indicating a deterministic training is desired. "
                 "But torch.backends.cudnn.benchmark is True as well and this will prevent deterministic training! "
                 "If you want deterministic then set benchmark=False")
        if not self.was_initialized:
            self.initialize(True)
        while self.epoch < self.max_num_epochs:
            self.print_to_log_file("\nepoch: ", self.epoch)
            epoch_start_time = time()
            train_losses_epoch = []
            # train one epoch
            self.network.train()
            if self.use_progress_bar:
                with trange(self.num_batches_per_epoch) as tbar:
                    for b in tbar:
                        tbar.set_description("Epoch {}/{}".format(self.epoch+1, self.max_num_epochs))
                        l = self.run_iteration(self.tr_gen, True)
                        tbar.set_postfix(loss=l)
                        train_losses_epoch.append(l)
            else:
                for _ in range(self.num_batches_per_epoch):
                    l = self.run_iteration(self.tr_gen, True)
                    train_losses_epoch.append(l)
            self.all_tr_losses.append(np.mean(train_losses_epoch))
            self.print_to_log_file("train loss : %.4f" % self.all_tr_losses[-1])
            with torch.no_grad():
                # validation with train=False
                self.network.eval()
                val_losses = []
                for b in range(self.num_val_batches_per_epoch):
                    l = self.run_iteration(self.val_gen, False, True)
                    val_losses.append(l)
                self.all_val_losses.append(np.mean(val_losses))
                self.print_to_log_file("validation loss: %.4f" % self.all_val_losses[-1])
                if self.also_val_in_tr_mode:
                    self.network.train()
                    # validation with train=True
                    val_losses = []
                    for b in range(self.num_val_batches_per_epoch):
                        l = self.run_iteration(self.val_gen, False)
                        val_losses.append(l)
                    self.all_val_losses_tr_mode.append(np.mean(val_losses))
                    self.print_to_log_file("validation loss (train=True): %.4f" % self.all_val_losses_tr_mode[-1])
            self.update_train_loss_MA()  # needed for lr scheduler and stopping of training
            continue_training = self.on_epoch_end()
            epoch_end_time = time()
            if not continue_training:
                # allows for early stopping
                break
            self.epoch += 1
            self.print_to_log_file("This epoch took %f s\n" % (epoch_end_time - epoch_start_time))
        self.epoch -= 1  # if we don't do this we can get a problem with loading model_final_checkpoint.
        if self.save_final_checkpoint: self.save_checkpoint(join(self.output_folder, "model_final_checkpoint.model"))
        # now we can delete latest as it will be identical with final
        if isfile(join(self.output_folder, "model_latest.model")):
            os.remove(join(self.output_folder, "model_latest.model"))
        if isfile(join(self.output_folder, "model_latest.model.pkl")):
            os.remove(join(self.output_folder, "model_latest.model.pkl"))
    def maybe_update_lr(self):
        """Step the lr scheduler (if any) once per epoch and log the current lr."""
        # maybe update learning rate
        if self.lr_scheduler is not None:
            assert isinstance(self.lr_scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))
            if isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
                # lr scheduler is updated with moving average val loss. should be more robust
                self.lr_scheduler.step(self.val_eval_criterion_MA)
            else:
                self.lr_scheduler.step(self.epoch + 1)
        self.print_to_log_file("lr is now (scheduler) %s" % str(self.optimizer.param_groups[0]['lr']))
    def maybe_save_checkpoint(self):
        """
        Saves a checkpoint every save_every epochs (only when
        self.save_intermediate_checkpoints is True).
        :return:
        """
        if self.save_intermediate_checkpoints and (self.epoch % self.save_every == (self.save_every - 1)):
            self.print_to_log_file("saving scheduled checkpoint file...")
            if not self.save_latest_only:
                # keep a separate, numbered copy in addition to model_latest
                self.save_checkpoint(join(self.output_folder, "model_ep_%03.0d.model" % (self.epoch + 1)))
            self.save_checkpoint(join(self.output_folder, "model_latest.model"))
            self.print_to_log_file("done")
def update_eval_criterion_MA(self):
"""
If self.all_val_eval_metrics is unused (len=0) then we fall back to using -self.all_val_losses for the MA to determine early stopping
(not a minimization, but a maximization of a metric and therefore the - in the latter case)
:return:
"""
if self.val_eval_criterion_MA is None:
if len(self.all_val_eval_metrics) == 0:
self.val_eval_criterion_MA = - self.all_val_losses[-1]
else:
self.val_eval_criterion_MA = self.all_val_eval_metrics[-1]
else:
if len(self.all_val_eval_metrics) == 0:
"""
We here use alpha * old - (1 - alpha) * new because new in this case is the vlaidation loss and lower
is better, so we need to negate it.
"""
self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA - (
1 - self.val_eval_criterion_alpha) * \
self.all_val_losses[-1]
else:
self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA + (
1 - self.val_eval_criterion_alpha) * \
self.all_val_eval_metrics[-1]
    def manage_patience(self):
        """
        Early-stopping bookkeeping: saves the best checkpoint when the validation
        criterion MA improves, tracks the best train-loss MA, and stops training
        once no improvement has been seen for self.patience epochs (provided the
        lr has already dropped to self.lr_threshold or below).
        :return: bool, whether training should continue
        """
        # update patience
        continue_training = True
        if self.patience is not None:
            # if best_MA_tr_loss_for_patience and best_epoch_based_on_MA_tr_loss were not yet initialized,
            # initialize them
            if self.best_MA_tr_loss_for_patience is None:
                self.best_MA_tr_loss_for_patience = self.train_loss_MA
            if self.best_epoch_based_on_MA_tr_loss is None:
                self.best_epoch_based_on_MA_tr_loss = self.epoch
            if self.best_val_eval_criterion_MA is None:
                self.best_val_eval_criterion_MA = self.val_eval_criterion_MA
            # check if the current epoch is the best one according to moving average of validation criterion. If so
            # then save 'best' model
            # Do not use this for validation. This is intended for test set prediction only.
            self.print_to_log_file("current best_val_eval_criterion_MA is %.4f0" % self.best_val_eval_criterion_MA)
            self.print_to_log_file("current val_eval_criterion_MA is %.4f" % self.val_eval_criterion_MA)
            if self.val_eval_criterion_MA > self.best_val_eval_criterion_MA:
                self.best_val_eval_criterion_MA = self.val_eval_criterion_MA
                #self.print_to_log_file("saving best epoch checkpoint...")
                if self.save_best_checkpoint: self.save_checkpoint(join(self.output_folder, "model_best.model"))
            # Now see if the moving average of the train loss has improved. If yes then reset patience, else
            # increase patience
            if self.train_loss_MA + self.train_loss_MA_eps < self.best_MA_tr_loss_for_patience:
                self.best_MA_tr_loss_for_patience = self.train_loss_MA
                self.best_epoch_based_on_MA_tr_loss = self.epoch
                #self.print_to_log_file("New best epoch (train loss MA): %03.4f" % self.best_MA_tr_loss_for_patience)
            else:
                pass
                #self.print_to_log_file("No improvement: current train MA %03.4f, best: %03.4f, eps is %03.4f" %
                #                       (self.train_loss_MA, self.best_MA_tr_loss_for_patience, self.train_loss_MA_eps))
            # if patience has reached its maximum then finish training (provided lr is low enough)
            if self.epoch - self.best_epoch_based_on_MA_tr_loss > self.patience:
                if self.optimizer.param_groups[0]['lr'] > self.lr_threshold:
                    # lr is still high, grant half the patience again before stopping
                    #self.print_to_log_file("My patience ended, but I believe I need more time (lr > 1e-6)")
                    self.best_epoch_based_on_MA_tr_loss = self.epoch - self.patience // 2
                else:
                    #self.print_to_log_file("My patience ended")
                    continue_training = False
            else:
                pass
                #self.print_to_log_file(
                #    "Patience: %d/%d" % (self.epoch - self.best_epoch_based_on_MA_tr_loss, self.patience))
        return continue_training
def on_epoch_end(self):
self.finish_online_evaluation() # does not have to do anything, but can be used to update self.all_val_eval_
# metrics
self.plot_progress()
self.maybe_save_checkpoint()
self.update_eval_criterion_MA()
self.maybe_update_lr()
continue_training = self.manage_patience()
return continue_training
def update_train_loss_MA(self):
if self.train_loss_MA is None:
self.train_loss_MA = self.all_tr_losses[-1]
else:
self.train_loss_MA = self.train_loss_MA_alpha * self.train_loss_MA + (1 - self.train_loss_MA_alpha) * \
self.all_tr_losses[-1]
    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        """
        Process one batch: forward pass, loss computation and (optionally) the
        backward pass / optimizer step.

        :param data_generator: yields dicts with keys 'data' and 'target'
        :param do_backprop: if True, backpropagate and step the optimizer
        :param run_online_evaluation: if True, pass (output, target) to run_online_evaluation
        :return: the loss as a detached numpy scalar
        """
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']
        data = maybe_to_torch(data)
        target = maybe_to_torch(target)
        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)
        self.optimizer.zero_grad()
        if self.fp16:
            # mixed-precision path: forward/loss under autocast, scaled backward
            with autocast():
                output = self.network(data)
                del data  # free the input as early as possible
                l = self.loss(output, target)
            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            output = self.network(data)
            del data
            l = self.loss(output, target)
            if do_backprop:
                l.backward()
                self.optimizer.step()
        if run_online_evaluation:
            self.run_online_evaluation(output, target)
        del target
        return l.detach().cpu().numpy()
    def run_online_evaluation(self, *args, **kwargs):
        """
        Can be implemented, does not have to. Called once per validation batch
        from run_iteration with (network output, target); subclasses may override
        it to accumulate online evaluation statistics. No-op by default.
        :return:
        """
        pass
    def finish_online_evaluation(self):
        """
        Can be implemented, does not have to. Called from on_epoch_end; subclasses
        may override it to aggregate the per-batch statistics collected by
        run_online_evaluation (e.g. into self.all_val_eval_metrics). No-op by default.
        :return:
        """
        pass
    @abstractmethod
    def validate(self, *args, **kwargs):
        """Run validation/prediction on the validation set; implemented by subclasses."""
        pass
    def find_lr(self, num_iters=1000, init_value=1e-6, final_value=10., beta=0.98):
        """
        stolen and adapted from here: https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
        LR range test: sweep the lr geometrically from init_value to final_value,
        record the smoothed loss per step, and save a loss-vs-lr plot.
        :param num_iters: number of training iterations in the sweep
        :param init_value: starting learning rate
        :param final_value: final learning rate
        :param beta: smoothing factor for the exponential moving average of the loss
        :return: (log10 learning rates, smoothed losses)
        """
        import math
        self._maybe_init_amp()
        # per-step multiplicative lr factor so that num_iters steps span the range
        mult = (final_value / init_value) ** (1 / num_iters)
        lr = init_value
        self.optimizer.param_groups[0]['lr'] = lr
        avg_loss = 0.
        best_loss = 0.
        losses = []
        log_lrs = []
        for batch_num in range(1, num_iters + 1):
            # +1 because this one here is not designed to have negative loss...
            loss = self.run_iteration(self.tr_gen, do_backprop=True, run_online_evaluation=False).data.item() + 1
            # Compute the smoothed loss (bias-corrected exponential moving average)
            avg_loss = beta * avg_loss + (1 - beta) * loss
            smoothed_loss = avg_loss / (1 - beta ** batch_num)
            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > 4 * best_loss:
                break
            # Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            # Update the lr for the next step
            lr *= mult
            self.optimizer.param_groups[0]['lr'] = lr
        import matplotlib.pyplot as plt
        lrs = [10 ** i for i in log_lrs]
        fig = plt.figure()
        plt.xscale('log')
        # drop the first/last few points, which are typically off-scale
        plt.plot(lrs[10:-5], losses[10:-5])
        plt.savefig(join(self.output_folder, "lr_finder.png"))
        plt.close()
        return log_lrs, losses
================================================
FILE: unetr_pp/training/network_training/network_trainer_synapse.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from _warnings import warn
from typing import Tuple
import matplotlib
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from sklearn.model_selection import KFold
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import _LRScheduler
from tqdm import tqdm
matplotlib.use("agg")
from time import time, sleep
import torch
import numpy as np
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
import sys
from collections import OrderedDict
import torch.backends.cudnn as cudnn
from abc import abstractmethod
from datetime import datetime
from tqdm import trange
from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda
class NetworkTrainer_synapse(object):
    def __init__(self, deterministic=True, fp16=False):
        """
        A generic class that can train almost any neural network (RNNs excluded). It provides basic functionality such
        as the training loop, tracking of training and validation losses (and the target metric if you implement it)
        Training can be terminated early if the validation loss (or the target metric if implemented) do not improve
        anymore. This is based on a moving average (MA) of the loss/metric instead of the raw values to get more smooth
        results.

        What you need to override:
        - __init__
        - initialize
        - run_online_evaluation (optional)
        - finish_online_evaluation (optional)
        - validate
        - predict_test_case

        :param deterministic: if True, seed numpy/torch and disable cudnn benchmarking
        :param fp16: if True, train with mixed precision (AMP)
        """
        self.fp16 = fp16
        self.amp_grad_scaler = None  # created lazily by _maybe_init_amp
        if deterministic:
            # fix all RNG seeds and force deterministic cudnn behavior
            np.random.seed(12345)
            torch.manual_seed(12345)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(12345)
            cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True
        ################# SET THESE IN self.initialize() ###################################
        self.network: Tuple[SegmentationNetwork, nn.DataParallel] = None
        self.optimizer = None
        self.lr_scheduler = None
        self.tr_gen = self.val_gen = None
        self.was_initialized = False
        ################# SET THESE IN INIT ################################################
        self.output_folder = None
        self.fold = None
        self.loss = None
        self.dataset_directory = None
        ################# SET THESE IN LOAD_DATASET OR DO_SPLIT ############################
        self.dataset = None  # these can be None for inference mode
        self.dataset_tr = self.dataset_val = None  # do not need to be used, they just appear if you are using the suggested load_dataset_and_do_split
        ################# THESE DO NOT NECESSARILY NEED TO BE MODIFIED #####################
        self.patience = 50
        self.val_eval_criterion_alpha = 0.9  # alpha * old + (1-alpha) * new
        # if this is too low then the moving average will be too noisy and the training may terminate early. If it is
        # too high the training will take forever
        self.train_loss_MA_alpha = 0.93  # alpha * old + (1-alpha) * new
        self.train_loss_MA_eps = 5e-4  # new MA must be at least this much better (smaller)
        self.max_num_epochs = 1000
        self.num_batches_per_epoch = 250
        self.num_val_batches_per_epoch = 50
        self.also_val_in_tr_mode = False
        self.lr_threshold = 1e-6  # the network will not terminate training if the lr is still above this threshold
        ################# LEAVE THESE ALONE ################################################
        self.val_eval_criterion_MA = None
        self.train_loss_MA = None
        self.best_val_eval_criterion_MA = None
        self.best_MA_tr_loss_for_patience = None
        self.best_epoch_based_on_MA_tr_loss = None
        self.all_tr_losses = []
        self.all_val_losses = []
        self.all_val_losses_tr_mode = []
        self.all_val_eval_metrics = []  # does not have to be used
        self.epoch = 0
        self.log_file = None
        self.deterministic = deterministic
        # opt-in progress bar via environment variable
        self.use_progress_bar = False
        if 'nnformer_use_progress_bar' in os.environ.keys():
            self.use_progress_bar = bool(int(os.environ['nnformer_use_progress_bar']))
        ################# Settings for saving checkpoints ##################################
        self.save_every = 50
        self.save_latest_only = False  # if false it will not store/overwrite _latest but separate files each
        # time an intermediate checkpoint is created
        self.save_intermediate_checkpoints = False  # whether or not to save checkpoint_latest
        self.save_best_checkpoint = True  # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
        self.save_final_checkpoint = True  # whether or not to save the final checkpoint
    @abstractmethod
    def initialize(self, training=True):
        """
        create self.output_folder
        modify self.output_folder if you are doing cross-validation (one folder per fold)
        set self.tr_gen and self.val_gen
        call self.initialize_network and self.initialize_optimizer_and_scheduler (important!)
        finally set self.was_initialized to True
        :param training: True when setting up for training, False for inference only
        :return:
        """
    @abstractmethod
    def load_dataset(self):
        """Populate self.dataset; implemented by subclasses."""
        pass
    def do_split(self):
        """
        This is a suggestion for if your dataset is a dictionary (my personal standard)
        Creates (or reuses) a 5-fold split saved as splits_final.pkl in
        self.dataset_directory, then fills self.dataset_tr / self.dataset_val for
        self.fold. fold == "all" uses every case for both training and validation.
        :return:
        """
        splits_file = join(self.dataset_directory, "splits_final.pkl")
        if not isfile(splits_file):
            self.print_to_log_file("Creating new split...")
            splits = []
            all_keys_sorted = np.sort(list(self.dataset.keys()))
            # fixed random_state so the split is reproducible across runs
            kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
            for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
                train_keys = np.array(all_keys_sorted)[train_idx]
                test_keys = np.array(all_keys_sorted)[test_idx]
                splits.append(OrderedDict())
                splits[-1]['train'] = train_keys
                splits[-1]['val'] = test_keys
            save_pickle(splits, splits_file)
        splits = load_pickle(splits_file)
        if self.fold == "all":
            tr_keys = val_keys = list(self.dataset.keys())
        else:
            tr_keys = splits[self.fold]['train']
            val_keys = splits[self.fold]['val']
        tr_keys.sort()
        val_keys.sort()
        self.dataset_tr = OrderedDict()
        for i in tr_keys:
            self.dataset_tr[i] = self.dataset[i]
        self.dataset_val = OrderedDict()
        for i in val_keys:
            self.dataset_val[i] = self.dataset[i]
    def plot_progress(self):
        """
        Plot training/validation loss curves (and the evaluation metric, if it is
        tracked for every epoch) and save them as progress.png in the output folder.
        :return:
        """
        try:
            font = {'weight': 'normal',
                    'size': 18}
            matplotlib.rc('font', **font)
            fig = plt.figure(figsize=(30, 24))
            ax = fig.add_subplot(111)
            ax2 = ax.twinx()  # second y axis for the evaluation metric
            x_values = list(range(self.epoch + 1))
            ax.plot(x_values, self.all_tr_losses, color='b', ls='-', label="loss_tr")
            ax.plot(x_values, self.all_val_losses, color='r', ls='-', label="loss_val, train=False")
            if len(self.all_val_losses_tr_mode) > 0:
                ax.plot(x_values, self.all_val_losses_tr_mode, color='g', ls='-', label="loss_val, train=True")
            if len(self.all_val_eval_metrics) == len(x_values):
                ax2.plot(x_values, self.all_val_eval_metrics, color='g', ls='--', label="evaluation metric")
            ax.set_xlabel("epoch")
            ax.set_ylabel("loss")
            ax2.set_ylabel("evaluation metric")
            ax.legend()
            ax2.legend(loc=9)
            fig.savefig(join(self.output_folder, "progress.png"))
            plt.close()
            '''
            #below is et ..dice
            font = {'weight': 'normal',
                    'size': 18}
            matplotlib.rc('font', **font)
            fig = plt.figure(figsize=(30, 24))
            ax = fig.add_subplot(111)
            ax2 = ax.twinx()
            x_values = list(range(self.epoch + 1))
            ax.plot(x_values, self.plot_et, color='b', ls='-', label="dice_et")
            ax.plot(x_values, self.plot_tc, color='g', ls='-', label="dice_tc")
            ax.plot(x_values, self.plot_wt, color='r', ls='-', label="dice_wt")
            ax.set_xlabel("epoch")
            ax.set_ylabel("dice")
            ax.legend()
            fig.savefig(join(self.output_folder, "dice_progress.png"))
            plt.close()
            '''
        except IOError:
            self.print_to_log_file("failed to plot: ", sys.exc_info())
    def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True):
        """
        Append *args to the training log file (created lazily in self.output_folder)
        and optionally echo them to the console.

        :param also_print_to_console: also print the message via print()
        :param add_timestamp: prefix the message with the current date/time
        """
        timestamp = time()
        dt_object = datetime.fromtimestamp(timestamp)
        if add_timestamp:
            args = ("%s:" % dt_object, *args)
        if self.log_file is None:
            # first call: create a timestamped log file in the output folder
            maybe_mkdir_p(self.output_folder)
            timestamp = datetime.now()
            self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" %
                                 (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute,
                                  timestamp.second))
            with open(self.log_file, 'w') as f:
                f.write("Starting... \n")
        successful = False
        max_attempts = 5
        ctr = 0
        # retry a few times in case the file is temporarily unavailable (e.g. network fs)
        while not successful and ctr < max_attempts:
            try:
                with open(self.log_file, 'a+') as f:
                    for a in args:
                        f.write(str(a))
                        f.write(" ")
                    f.write("\n")
                successful = True
            except IOError:
                print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info())
                sleep(0.5)
                ctr += 1
        if also_print_to_console:
            print(*args)
def save_checkpoint(self, fname, save_optimizer=True):
start_time = time()
state_dict = self.network.state_dict()
for key in state_dict.keys():
state_dict[key] = state_dict[key].cpu()
lr_sched_state_dct = None
if self.lr_scheduler is not None and hasattr(self.lr_scheduler,
'state_dict'): # not isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
lr_sched_state_dct = self.lr_scheduler.state_dict()
# WTF is this!?
# for key in lr_sched_state_dct.keys():
# lr_sched_state_dct[key] = lr_sched_state_dct[key]
if save_optimizer:
optimizer_state_dict = self.optimizer.state_dict()
else:
optimizer_state_dict = None
self.print_to_log_file("saving checkpoint...")
save_this = {
'epoch': self.epoch + 1,
'state_dict': state_dict,
'optimizer_state_dict': optimizer_state_dict,
'lr_scheduler_state_dict': lr_sched_state_dct,
'plot_stuff': (self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode,
self.all_val_eval_metrics),
'best_stuff' : (self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA)}
if self.amp_grad_scaler is not None:
save_this['amp_grad_scaler'] = self.amp_grad_scaler.state_dict()
torch.save(save_this, fname)
self.print_to_log_file("done, saving took %.2f seconds" % (time() - start_time))
def load_best_checkpoint(self, train=True):
if self.fold is None:
raise RuntimeError("Cannot load best checkpoint if self.fold is None")
if isfile(join(self.output_folder, "model_best.model")):
self.load_checkpoint(join(self.output_folder, "model_best.model"), train=train)
else:
self.print_to_log_file("WARNING! model_best.model does not exist! Cannot load best checkpoint. Falling "
"back to load_latest_checkpoint")
self.load_latest_checkpoint(train)
def load_latest_checkpoint(self, train=True):
if isfile(join(self.output_folder, "model_final_checkpoint.model")):
return self.load_checkpoint(join(self.output_folder, "model_final_checkpoint.model"), train=train)
if isfile(join(self.output_folder, "model_latest.model")):
return self.load_checkpoint(join(self.output_folder, "model_latest.model"), train=train)
if isfile(join(self.output_folder, "model_best.model")):
return self.load_best_checkpoint(train)
raise RuntimeError("No checkpoint found")
def load_final_checkpoint(self, train=False):
filename = join(self.output_folder, "model_final_checkpoint.model")
if not isfile(filename):
raise RuntimeError("Final checkpoint not found. Expected: %s. Please finish the training first." % filename)
return self.load_checkpoint(filename, train=train)
def load_checkpoint(self, fname, train=True):
self.print_to_log_file("loading checkpoint", fname, "train=", train)
if not self.was_initialized:
self.initialize(train)
# saved_model = torch.load(fname, map_location=torch.device('cuda', torch.cuda.current_device()))
saved_model = torch.load(fname, map_location=torch.device('cpu'))
self.load_checkpoint_ram(saved_model, train)
    @abstractmethod
    def initialize_network(self):
        """
        initialize self.network here
        Subclasses must create the model and assign it to self.network.
        :return:
        """
        pass
    @abstractmethod
    def initialize_optimizer_and_scheduler(self):
        """
        initialize self.optimizer and self.lr_scheduler (if applicable) here
        Subclasses must create these; self.lr_scheduler may stay None.
        :return:
        """
        pass
def load_checkpoint_ram(self, checkpoint, train=True):
"""
used for if the checkpoint is already in ram
:param checkpoint:
:param train:
:return:
"""
print('I am here !!!')
if not self.was_initialized:
self.initialize(train)
new_state_dict = OrderedDict()
curr_state_dict_keys = list(self.network.state_dict().keys())
# if state dict comes form nn.DataParallel but we use non-parallel model here then the state dict keys do not
# match. Use heuristic to make it match
for k, value in checkpoint['state_dict'].items():
key = k
if key not in curr_state_dict_keys and key.startswith('module.'):
key = key[7:]
new_state_dict[key] = value
if self.fp16:
self._maybe_init_amp()
if 'amp_grad_scaler' in checkpoint.keys():
self.amp_grad_scaler.load_state_dict(checkpoint['amp_grad_scaler'])
self.network.load_state_dict(new_state_dict)
self.epoch = checkpoint['epoch']
if train:
optimizer_state_dict = checkpoint['optimizer_state_dict']
if optimizer_state_dict is not None:
self.optimizer.load_state_dict(optimizer_state_dict)
if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'load_state_dict') and checkpoint[
'lr_scheduler_state_dict'] is not None:
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
if issubclass(self.lr_scheduler.__class__, _LRScheduler):
self.lr_scheduler.step(self.epoch)
self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics = checkpoint[
'plot_stuff']
# load best loss (if present)
print('best_stuff' in checkpoint.keys())
if 'best_stuff' in checkpoint.keys():
self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA = checkpoint[
'best_stuff']
# after the training is done, the epoch is incremented one more time in my old code. This results in
# self.epoch = 1001 for old trained models when the epoch is actually 1000. This causes issues because
# len(self.all_tr_losses) = 1000 and the plot function will fail. We can easily detect and correct that here
if self.epoch != len(self.all_tr_losses):
self.print_to_log_file("WARNING in loading checkpoint: self.epoch != len(self.all_tr_losses). This is "
"due to an old bug and should only appear when you are loading old models. New "
"models should have this fixed! self.epoch is now set to len(self.all_tr_losses)")
self.epoch = len(self.all_tr_losses)
self.all_tr_losses = self.all_tr_losses[:self.epoch]
self.all_val_losses = self.all_val_losses[:self.epoch]
self.all_val_losses_tr_mode = self.all_val_losses_tr_mode[:self.epoch]
self.all_val_eval_metrics = self.all_val_eval_metrics[:self.epoch]
self._maybe_init_amp()
def _maybe_init_amp(self):
if self.fp16 and self.amp_grad_scaler is None:
self.amp_grad_scaler = GradScaler()
    def plot_network_architecture(self):
        """
        can be implemented (see nnFormerTrainer) but does not have to. Not implemented here because it imposes stronger
        assumptions on the presence of class variables
        Called once at the start of run_training; no-op by default.
        :return:
        """
        pass
    def run_training(self):
        """
        Main training loop: runs epochs of training and validation batches, tracks
        losses and moving averages, and manages checkpoints until max_num_epochs is
        reached or on_epoch_end requests early stopping.
        """
        if not torch.cuda.is_available():
            self.print_to_log_file("WARNING!!! You are attempting to run training on a CPU (torch.cuda.is_available() is False). This can be VERY slow!")
        # warm up the data generators before the loop starts
        _ = self.tr_gen.next()
        _ = self.val_gen.next()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        self._maybe_init_amp()
        maybe_mkdir_p(self.output_folder)
        self.plot_network_architecture()
        if cudnn.benchmark and cudnn.deterministic:
            warn("torch.backends.cudnn.deterministic is True indicating a deterministic training is desired. "
                 "But torch.backends.cudnn.benchmark is True as well and this will prevent deterministic training! "
                 "If you want deterministic then set benchmark=False")
        if not self.was_initialized:
            self.initialize(True)
        while self.epoch < self.max_num_epochs:
            self.print_to_log_file("\nepoch: ", self.epoch)
            epoch_start_time = time()
            train_losses_epoch = []
            # train one epoch
            self.network.train()
            if self.use_progress_bar:
                with trange(self.num_batches_per_epoch) as tbar:
                    for b in tbar:
                        tbar.set_description("Epoch {}/{}".format(self.epoch+1, self.max_num_epochs))
                        l = self.run_iteration(self.tr_gen, True)
                        tbar.set_postfix(loss=l)
                        train_losses_epoch.append(l)
            else:
                for _ in range(self.num_batches_per_epoch):
                    l = self.run_iteration(self.tr_gen, True)
                    train_losses_epoch.append(l)
            self.all_tr_losses.append(np.mean(train_losses_epoch))
            self.print_to_log_file("train loss : %.4f" % self.all_tr_losses[-1])
            with torch.no_grad():
                # validation with train=False
                self.network.eval()
                val_losses = []
                for b in range(self.num_val_batches_per_epoch):
                    l = self.run_iteration(self.val_gen, False, True)
                    val_losses.append(l)
                self.all_val_losses.append(np.mean(val_losses))
                self.print_to_log_file("validation loss: %.4f" % self.all_val_losses[-1])
                if self.also_val_in_tr_mode:
                    self.network.train()
                    # validation with train=True
                    val_losses = []
                    for b in range(self.num_val_batches_per_epoch):
                        l = self.run_iteration(self.val_gen, False)
                        val_losses.append(l)
                    self.all_val_losses_tr_mode.append(np.mean(val_losses))
                    self.print_to_log_file("validation loss (train=True): %.4f" % self.all_val_losses_tr_mode[-1])
            self.update_train_loss_MA()  # needed for lr scheduler and stopping of training
            continue_training = self.on_epoch_end()
            epoch_end_time = time()
            if not continue_training:
                # allows for early stopping
                break
            self.epoch += 1
            self.print_to_log_file("This epoch took %f s\n" % (epoch_end_time - epoch_start_time))
        self.epoch -= 1  # if we don't do this we can get a problem with loading model_final_checkpoint.
        if self.save_final_checkpoint: self.save_checkpoint(join(self.output_folder, "model_final_checkpoint.model"))
        # now we can delete latest as it will be identical with final
        if isfile(join(self.output_folder, "model_latest.model")):
            os.remove(join(self.output_folder, "model_latest.model"))
        if isfile(join(self.output_folder, "model_latest.model.pkl")):
            os.remove(join(self.output_folder, "model_latest.model.pkl"))
def maybe_update_lr(self):
# maybe update learning rate
if self.lr_scheduler is not None:
assert isinstance(self.lr_scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))
if isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
# lr scheduler is updated with moving average val loss. should be more robust
self.lr_scheduler.step(self.val_eval_criterion_MA)
else:
self.lr_scheduler.step(self.epoch + 1)
self.print_to_log_file("lr is now (scheduler) %s" % str(self.optimizer.param_groups[0]['lr']))
def maybe_save_checkpoint(self):
"""
Saves a checkpoint every save_ever epochs.
:return:
"""
if self.save_intermediate_checkpoints and (self.epoch % self.save_every == (self.save_every - 1)):
self.print_to_log_file("saving scheduled checkpoint file...")
if not self.save_latest_only:
self.save_checkpoint(join(self.output_folder, "model_ep_%03.0d.model" % (self.epoch + 1)))
self.save_checkpoint(join(self.output_folder, "model_latest.model"))
self.print_to_log_file("done")
def update_eval_criterion_MA(self):
"""
If self.all_val_eval_metrics is unused (len=0) then we fall back to using -self.all_val_losses for the MA to determine early stopping
(not a minimization, but a maximization of a metric and therefore the - in the latter case)
:return:
"""
if self.val_eval_criterion_MA is None:
if len(self.all_val_eval_metrics) == 0:
self.val_eval_criterion_MA = - self.all_val_losses[-1]
else:
self.val_eval_criterion_MA = self.all_val_eval_metrics[-1]
else:
if len(self.all_val_eval_metrics) == 0:
"""
We here use alpha * old - (1 - alpha) * new because new in this case is the vlaidation loss and lower
is better, so we need to negate it.
"""
self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA - (
1 - self.val_eval_criterion_alpha) * \
self.all_val_losses[-1]
else:
self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA + (
1 - self.val_eval_criterion_alpha) * \
self.all_val_eval_metrics[-1]
def manage_patience(self):
# update patience
continue_training = True
if self.patience is not None:
# if best_MA_tr_loss_for_patience and best_epoch_based_on_MA_tr_loss were not yet initialized,
# initialize them
if self.best_MA_tr_loss_for_patience is None:
self.best_MA_tr_loss_for_patience = self.train_loss_MA
if self.best_epoch_based_on_MA_tr_loss is None:
self.best_epoch_based_on_MA_tr_loss = self.epoch
if self.best_val_eval_criterion_MA is None:
self.best_val_eval_criterion_MA = self.val_eval_criterion_MA
# check if the current epoch is the best one according to moving average of validation criterion. If so
# then save 'best' model
# Do not use this for validation. This is intended for test set prediction only.
self.print_to_log_file("current best_val_eval_criterion_MA is %.4f0" % self.best_val_eval_criterion_MA)
self.print_to_log_file("current val_eval_criterion_MA is %.4f" % self.val_eval_criterion_MA)
if self.val_eval_criterion_MA > self.best_val_eval_criterion_MA:
self.best_val_eval_criterion_MA = self.val_eval_criterion_MA
#self.print_to_log_file("saving best epoch checkpoint...")
if self.save_best_checkpoint: self.save_checkpoint(join(self.output_folder, "model_best.model"))
# Now see if the moving average of the train loss has improved. If yes then reset patience, else
# increase patience
if self.train_loss_MA + self.train_loss_MA_eps < self.best_MA_tr_loss_for_patience:
self.best_MA_tr_loss_for_patience = self.train_loss_MA
self.best_epoch_based_on_MA_tr_loss = self.epoch
#self.print_to_log_file("New best epoch (train loss MA): %03.4f" % self.best_MA_tr_loss_for_patience)
else:
pass
#self.print_to_log_file("No improvement: current train MA %03.4f, best: %03.4f, eps is %03.4f" %
# (self.train_loss_MA, self.best_MA_tr_loss_for_patience, self.train_loss_MA_eps))
# if patience has reached its maximum then finish training (provided lr is low enough)
if self.epoch - self.best_epoch_based_on_MA_tr_loss > self.patience:
if self.optimizer.param_groups[0]['lr'] > self.lr_threshold:
#self.print_to_log_file("My patience ended, but I believe I need more time (lr > 1e-6)")
self.best_epoch_based_on_MA_tr_loss = self.epoch - self.patience // 2
else:
#self.print_to_log_file("My patience ended")
continue_training = False
else:
pass
#self.print_to_log_file(
# "Patience: %d/%d" % (self.epoch - self.best_epoch_based_on_MA_tr_loss, self.patience))
return continue_training
def on_epoch_end(self):
self.finish_online_evaluation() # does not have to do anything, but can be used to update self.all_val_eval_
# metrics
self.plot_progress()
self.maybe_save_checkpoint()
self.update_eval_criterion_MA()
self.maybe_update_lr()
continue_training = self.manage_patience()
return continue_training
def update_train_loss_MA(self):
if self.train_loss_MA is None:
self.train_loss_MA = self.all_tr_losses[-1]
else:
self.train_loss_MA = self.train_loss_MA_alpha * self.train_loss_MA + (1 - self.train_loss_MA_alpha) * \
self.all_tr_losses[-1]
    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        """
        Run a single forward (and optionally backward + optimizer) pass on one batch.

        :param data_generator: iterator yielding dicts with 'data' and 'target' entries
        :param do_backprop: if True, backpropagate the loss and step the optimizer
        :param run_online_evaluation: if True, pass output/target to self.run_online_evaluation
        :return: the detached loss value as a numpy scalar
        """
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']
        data = maybe_to_torch(data)
        target = maybe_to_torch(target)
        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)
        self.optimizer.zero_grad()
        if self.fp16:
            # mixed precision path: forward/loss under autocast, gradients via GradScaler
            with autocast():
                output = self.network(data)
                del data  # drop the input reference early to reduce GPU memory pressure
                l = self.loss(output, target)
            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            output = self.network(data)
            del data
            l = self.loss(output, target)
            if do_backprop:
                l.backward()
                self.optimizer.step()
        if run_online_evaluation:
            self.run_online_evaluation(output, target)
        del target
        return l.detach().cpu().numpy()
    def run_online_evaluation(self, *args, **kwargs):
        """
        Per-batch hook for accumulating validation statistics (e.g. Dice).
        Can be implemented by subclasses, does not have to be.
        :param output_torch:
        :param target_npy:
        :return:
        """
        pass
    def finish_online_evaluation(self):
        """
        Per-epoch hook called from on_epoch_end; can aggregate the values collected
        by run_online_evaluation into self.all_val_eval_metrics.
        Can be implemented by subclasses, does not have to be.
        :return:
        """
        pass
    @abstractmethod
    def validate(self, *args, **kwargs):
        """Run full validation/prediction on the validation set; must be implemented by subclasses."""
        pass
    def find_lr(self, num_iters=1000, init_value=1e-6, final_value=10., beta=0.98):
        """
        LR range test: exponentially sweep the learning rate from init_value to
        final_value over num_iters training iterations, track the smoothed loss,
        and save a loss-vs-lr plot (lr_finder.png) to the output folder.

        stolen and adapted from here: https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
        :param num_iters: number of training iterations used for the sweep
        :param init_value: starting learning rate
        :param final_value: final learning rate
        :param beta: smoothing factor for the exponential moving average of the loss
        :return: (log10 learning rates, smoothed losses)
        """
        import math
        self._maybe_init_amp()
        mult = (final_value / init_value) ** (1 / num_iters)
        lr = init_value
        self.optimizer.param_groups[0]['lr'] = lr
        avg_loss = 0.
        best_loss = 0.
        losses = []
        log_lrs = []
        # NOTE(review): relies on a module-level `tqdm` import that is not visible in this chunk — confirm
        for batch_num in tqdm(range(1, num_iters + 1)):
            # +1 because this one here is not designed to have negative loss...
            loss = self.run_iteration(self.tr_gen, do_backprop=True, run_online_evaluation=False) + 1
            # Compute the smoothed loss
            avg_loss = beta * avg_loss + (1 - beta) * loss
            smoothed_loss = avg_loss / (1 - beta ** batch_num)
            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > 4 * best_loss:
                break
            # Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            # Update the lr for the next step
            lr *= mult
            self.optimizer.param_groups[0]['lr'] = lr
        import matplotlib.pyplot as plt
        lrs = [10 ** i for i in log_lrs]
        fig = plt.figure()
        plt.xscale('log')
        # drop the first/last few points, which are dominated by warmup/divergence
        plt.plot(lrs[10:-5], losses[10:-5])
        plt.savefig(join(self.output_folder, "lr_finder.png"))
        plt.close()
        import numpy as np
        index = np.argmin(losses)
        print(f"The best LR is {lrs[index]}")
        return log_lrs, losses
================================================
FILE: unetr_pp/training/network_training/network_trainer_tumor.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from _warnings import warn
from typing import Tuple
import matplotlib
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from sklearn.model_selection import KFold
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import _LRScheduler
matplotlib.use("agg")
from time import time, sleep
import torch
import numpy as np
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
import sys
from collections import OrderedDict
import torch.backends.cudnn as cudnn
from abc import abstractmethod
from datetime import datetime
from tqdm import trange
from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda
class NetworkTrainer_tumor(object):
    # Generic training-loop base class; subclasses provide the network, data
    # generators, optimizer and validation logic (see docstring of __init__).
    def __init__(self, deterministic=True, fp16=False):
        """
        A generic class that can train almost any neural network (RNNs excluded). It provides basic functionality such
        as the training loop, tracking of training and validation losses (and the target metric if you implement it)
        Training can be terminated early if the validation loss (or the target metric if implemented) do not improve
        anymore. This is based on a moving average (MA) of the loss/metric instead of the raw values to get more smooth
        results.
        What you need to override:
        - __init__
        - initialize
        - run_online_evaluation (optional)
        - finish_online_evaluation (optional)
        - validate
        - predict_test_case

        :param deterministic: seed all RNGs and disable cudnn benchmarking for reproducibility
        :param fp16: enable mixed-precision training via torch.cuda.amp
        """
        self.fp16 = fp16
        self.amp_grad_scaler = None  # created lazily by _maybe_init_amp when fp16 is on
        if deterministic:
            np.random.seed(12345)
            torch.manual_seed(12345)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(12345)
            cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True
        ################# SET THESE IN self.initialize() ###################################
        self.network: Tuple[SegmentationNetwork, nn.DataParallel] = None
        self.optimizer = None
        self.lr_scheduler = None
        self.tr_gen = self.val_gen = None  # training / validation batch generators
        self.was_initialized = False
        ################# SET THESE IN INIT ################################################
        self.output_folder = None
        self.fold = None
        self.loss = None
        self.dataset_directory = None
        ################# SET THESE IN LOAD_DATASET OR DO_SPLIT ############################
        self.dataset = None  # these can be None for inference mode
        self.dataset_tr = self.dataset_val = None  # do not need to be used, they just appear if you are using the suggested load_dataset_and_do_split
        ################# THESE DO NOT NECESSARILY NEED TO BE MODIFIED #####################
        self.patience = 50
        self.val_eval_criterion_alpha = 0.9  # alpha * old + (1-alpha) * new
        # if this is too low then the moving average will be too noisy and the training may terminate early. If it is
        # too high the training will take forever
        self.train_loss_MA_alpha = 0.93  # alpha * old + (1-alpha) * new
        self.train_loss_MA_eps = 5e-4  # new MA must be at least this much better (smaller)
        self.max_num_epochs = 1000
        self.num_batches_per_epoch = 250
        self.num_val_batches_per_epoch = 50
        self.also_val_in_tr_mode = False
        self.lr_threshold = 1e-6  # the network will not terminate training if the lr is still above this threshold
        ################# LEAVE THESE ALONE ################################################
        self.val_eval_criterion_MA = None
        self.train_loss_MA = None
        self.best_val_eval_criterion_MA = None
        self.best_MA_tr_loss_for_patience = None
        self.best_epoch_based_on_MA_tr_loss = None
        self.all_tr_losses = []
        self.all_val_losses = []
        self.all_val_losses_tr_mode = []
        self.all_val_eval_metrics = []  # does not have to be used
        self.epoch = 0
        self.log_file = None
        self.deterministic = deterministic
        self.use_progress_bar = False
        if 'nnformer_use_progress_bar' in os.environ.keys():
            self.use_progress_bar = bool(int(os.environ['nnformer_use_progress_bar']))
        ################# Settings for saving checkpoints ##################################
        self.save_every = 50
        self.save_latest_only = False  # if false it will not store/overwrite _latest but separate files each
        # time an intermediate checkpoint is created
        self.save_intermediate_checkpoints = False  # whether or not to save checkpoint_latest
        self.save_best_checkpoint = True  # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
        self.save_final_checkpoint = True  # whether or not to save the final checkpoint
    @abstractmethod
    def initialize(self, training=True):
        """
        Set up everything needed for training/inference. Subclasses must:
        - create self.output_folder
        - modify self.output_folder if you are doing cross-validation (one folder per fold)
        - set self.tr_gen and self.val_gen
        - call self.initialize_network and self.initialize_optimizer_and_scheduler (important!)
        - finally set self.was_initialized to True
        :param training: whether training data generators are required
        :return:
        """
    @abstractmethod
    def load_dataset(self):
        """Populate self.dataset (case identifier -> data reference); implemented by subclasses."""
        pass
def do_split(self):
"""
This is a suggestion for if your dataset is a dictionary (my personal standard)
:return:
"""
splits_file = join(self.dataset_directory, "splits_final.pkl")
if not isfile(splits_file):
self.print_to_log_file("Creating new split...")
splits = []
all_keys_sorted = np.sort(list(self.dataset.keys()))
kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
train_keys = np.array(all_keys_sorted)[train_idx]
test_keys = np.array(all_keys_sorted)[test_idx]
splits.append(OrderedDict())
splits[-1]['train'] = train_keys
splits[-1]['val'] = test_keys
save_pickle(splits, splits_file)
splits = load_pickle(splits_file)
if self.fold == "all":
tr_keys = val_keys = list(self.dataset.keys())
else:
tr_keys = splits[self.fold]['train']
val_keys = splits[self.fold]['val']
tr_keys.sort()
val_keys.sort()
self.dataset_tr = OrderedDict()
for i in tr_keys:
self.dataset_tr[i] = self.dataset[i]
self.dataset_val = OrderedDict()
for i in val_keys:
self.dataset_val[i] = self.dataset[i]
    def plot_progress(self):
        """
        Render the training/validation loss curves (and the evaluation metric,
        if tracked) to progress.png in the output folder.
        :return:
        """
        try:
            font = {'weight': 'normal',
                    'size': 18}
            matplotlib.rc('font', **font)
            fig = plt.figure(figsize=(30, 24))
            ax = fig.add_subplot(111)
            ax2 = ax.twinx()  # second y axis for the evaluation metric
            x_values = list(range(self.epoch + 1))
            ax.plot(x_values, self.all_tr_losses, color='b', ls='-', label="loss_tr")
            ax.plot(x_values, self.all_val_losses, color='r', ls='-', label="loss_val, train=False")
            if len(self.all_val_losses_tr_mode) > 0:
                ax.plot(x_values, self.all_val_losses_tr_mode, color='g', ls='-', label="loss_val, train=True")
            # the metric is only plotted once one value per epoch has been recorded
            if len(self.all_val_eval_metrics) == len(x_values):
                ax2.plot(x_values, self.all_val_eval_metrics, color='g', ls='--', label="evaluation metric")
            ax.set_xlabel("epoch")
            ax.set_ylabel("loss")
            ax2.set_ylabel("evaluation metric")
            ax.legend()
            ax2.legend(loc=9)
            fig.savefig(join(self.output_folder, "progress.png"))
            plt.close()
            '''
            #below is et ..dice
            font = {'weight': 'normal',
                    'size': 18}
            matplotlib.rc('font', **font)
            fig = plt.figure(figsize=(30, 24))
            ax = fig.add_subplot(111)
            ax2 = ax.twinx()
            x_values = list(range(self.epoch + 1))
            ax.plot(x_values, self.plot_et, color='b', ls='-', label="dice_et")
            ax.plot(x_values, self.plot_tc, color='g', ls='-', label="dice_tc")
            ax.plot(x_values, self.plot_wt, color='r', ls='-', label="dice_wt")
            ax.set_xlabel("epoch")
            ax.set_ylabel("dice")
            ax.legend()
            fig.savefig(join(self.output_folder, "dice_progress.png"))
            plt.close()
            '''
        except IOError:
            self.print_to_log_file("failed to plot: ", sys.exc_info())
def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True):
timestamp = time()
dt_object = datetime.fromtimestamp(timestamp)
if add_timestamp:
args = ("%s:" % dt_object, *args)
if self.log_file is None:
maybe_mkdir_p(self.output_folder)
timestamp = datetime.now()
self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" %
(timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute,
timestamp.second))
with open(self.log_file, 'w') as f:
f.write("Starting... \n")
successful = False
max_attempts = 5
ctr = 0
while not successful and ctr < max_attempts:
try:
with open(self.log_file, 'a+') as f:
for a in args:
f.write(str(a))
f.write(" ")
f.write("\n")
successful = True
except IOError:
print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info())
sleep(0.5)
ctr += 1
if also_print_to_console:
print(*args)
    def save_checkpoint(self, fname, save_optimizer=True):
        """
        Serialize the full trainer state (network weights, optimizer, lr scheduler,
        loss curves, best-epoch bookkeeping and, if used, the AMP grad scaler) to fname.

        :param fname: destination file path
        :param save_optimizer: if False, the optimizer state is omitted (smaller file)
        """
        start_time = time()
        state_dict = self.network.state_dict()
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].cpu()  # store weights on CPU so loading never requires a GPU
        lr_sched_state_dct = None
        if self.lr_scheduler is not None and hasattr(self.lr_scheduler,
                                                     'state_dict'):  # not isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
            lr_sched_state_dct = self.lr_scheduler.state_dict()
        if save_optimizer:
            optimizer_state_dict = self.optimizer.state_dict()
        else:
            optimizer_state_dict = None
        self.print_to_log_file("saving checkpoint...")
        # 'epoch' is stored as epoch + 1 (the epoch to resume at);
        # load_checkpoint_ram relies on this key layout
        save_this = {
            'epoch': self.epoch + 1,
            'state_dict': state_dict,
            'optimizer_state_dict': optimizer_state_dict,
            'lr_scheduler_state_dict': lr_sched_state_dct,
            'plot_stuff': (self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode,
                           self.all_val_eval_metrics),
            'best_stuff' : (self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA)}
        if self.amp_grad_scaler is not None:
            save_this['amp_grad_scaler'] = self.amp_grad_scaler.state_dict()
        torch.save(save_this, fname)
        self.print_to_log_file("done, saving took %.2f seconds" % (time() - start_time))
def load_best_checkpoint(self, train=True):
if self.fold is None:
raise RuntimeError("Cannot load best checkpoint if self.fold is None")
if isfile(join(self.output_folder, "model_best.model")):
self.load_checkpoint(join(self.output_folder, "model_best.model"), train=train)
else:
self.print_to_log_file("WARNING! model_best.model does not exist! Cannot load best checkpoint. Falling "
"back to load_latest_checkpoint")
self.load_latest_checkpoint(train)
def load_latest_checkpoint(self, train=True):
if isfile(join(self.output_folder, "model_final_checkpoint.model")):
return self.load_checkpoint(join(self.output_folder, "model_final_checkpoint.model"), train=train)
if isfile(join(self.output_folder, "model_latest.model")):
return self.load_checkpoint(join(self.output_folder, "model_latest.model"), train=train)
if isfile(join(self.output_folder, "model_best.model")):
return self.load_best_checkpoint(train)
raise RuntimeError("No checkpoint found")
def load_final_checkpoint(self, train=False):
filename = join(self.output_folder, "model_final_checkpoint.model")
if not isfile(filename):
raise RuntimeError("Final checkpoint not found. Expected: %s. Please finish the training first." % filename)
return self.load_checkpoint(filename, train=train)
def load_checkpoint(self, fname, train=True):
self.print_to_log_file("loading checkpoint", fname, "train=", train)
if not self.was_initialized:
self.initialize(train)
# saved_model = torch.load(fname, map_location=torch.device('cuda', torch.cuda.current_device()))
saved_model = torch.load(fname, map_location=torch.device('cpu'))
self.load_checkpoint_ram(saved_model, train)
    @abstractmethod
    def initialize_network(self):
        """
        Create the model and assign it to self.network. Implemented by subclasses.
        :return:
        """
        pass
    @abstractmethod
    def initialize_optimizer_and_scheduler(self):
        """
        Create self.optimizer and self.lr_scheduler (if applicable). Implemented by subclasses.
        :return:
        """
        pass
def load_checkpoint_ram(self, checkpoint, train=True):
"""
used for if the checkpoint is already in ram
:param checkpoint:
:param train:
:return:
"""
print('I am here !!!')
if not self.was_initialized:
self.initialize(train)
new_state_dict = OrderedDict()
curr_state_dict_keys = list(self.network.state_dict().keys())
# if state dict comes form nn.DataParallel but we use non-parallel model here then the state dict keys do not
# match. Use heuristic to make it match
for k, value in checkpoint['state_dict'].items():
key = k
if key not in curr_state_dict_keys and key.startswith('module.'):
key = key[7:]
new_state_dict[key] = value
if self.fp16:
self._maybe_init_amp()
if 'amp_grad_scaler' in checkpoint.keys():
self.amp_grad_scaler.load_state_dict(checkpoint['amp_grad_scaler'])
self.network.load_state_dict(new_state_dict)
self.epoch = checkpoint['epoch']
if train:
optimizer_state_dict = checkpoint['optimizer_state_dict']
if optimizer_state_dict is not None:
self.optimizer.load_state_dict(optimizer_state_dict)
if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'load_state_dict') and checkpoint[
'lr_scheduler_state_dict'] is not None:
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
if issubclass(self.lr_scheduler.__class__, _LRScheduler):
self.lr_scheduler.step(self.epoch)
self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics = checkpoint[
'plot_stuff']
# load best loss (if present)
print('best_stuff' in checkpoint.keys())
if 'best_stuff' in checkpoint.keys():
self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA = checkpoint[
'best_stuff']
# after the training is done, the epoch is incremented one more time in my old code. This results in
# self.epoch = 1001 for old trained models when the epoch is actually 1000. This causes issues because
# len(self.all_tr_losses) = 1000 and the plot function will fail. We can easily detect and correct that here
if self.epoch != len(self.all_tr_losses):
self.print_to_log_file("WARNING in loading checkpoint: self.epoch != len(self.all_tr_losses). This is "
"due to an old bug and should only appear when you are loading old models. New "
"models should have this fixed! self.epoch is now set to len(self.all_tr_losses)")
self.epoch = len(self.all_tr_losses)
self.all_tr_losses = self.all_tr_losses[:self.epoch]
self.all_val_losses = self.all_val_losses[:self.epoch]
self.all_val_losses_tr_mode = self.all_val_losses_tr_mode[:self.epoch]
self.all_val_eval_metrics = self.all_val_eval_metrics[:self.epoch]
self._maybe_init_amp()
def _maybe_init_amp(self):
if self.fp16 and self.amp_grad_scaler is None:
self.amp_grad_scaler = GradScaler()
    def plot_network_architecture(self):
        """
        Optional hook to save a visualization of self.network to the output folder.
        Can be implemented (see nnFormerTrainer) but does not have to. Not implemented here because it imposes stronger
        assumptions on the presence of class variables
        :return:
        """
        pass
    def run_training(self):
        """
        Main training loop: up to self.max_num_epochs epochs of
        num_batches_per_epoch training iterations followed by
        num_val_batches_per_epoch validation iterations, then end-of-epoch
        bookkeeping (checkpointing, lr schedule, early stopping). Writes
        model_final_checkpoint.model when training ends and removes the then
        redundant model_latest files.
        """
        if not torch.cuda.is_available():
            self.print_to_log_file("WARNING!!! You are attempting to run training on a CPU (torch.cuda.is_available() is False). This can be VERY slow!")
        # warm up the generators; NOTE(review): .next() assumes batchgenerators-style
        # augmenter objects (a plain Python generator would need next(gen)) — confirm
        _ = self.tr_gen.next()
        _ = self.val_gen.next()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        self._maybe_init_amp()
        maybe_mkdir_p(self.output_folder)
        self.plot_network_architecture()
        if cudnn.benchmark and cudnn.deterministic:
            warn("torch.backends.cudnn.deterministic is True indicating a deterministic training is desired. "
                 "But torch.backends.cudnn.benchmark is True as well and this will prevent deterministic training! "
                 "If you want deterministic then set benchmark=False")
        if not self.was_initialized:
            self.initialize(True)
        while self.epoch < self.max_num_epochs:
            self.print_to_log_file("\nepoch: ", self.epoch)
            epoch_start_time = time()
            train_losses_epoch = []
            # train one epoch
            self.network.train()
            if self.use_progress_bar:
                with trange(self.num_batches_per_epoch) as tbar:
                    for b in tbar:
                        tbar.set_description("Epoch {}/{}".format(self.epoch+1, self.max_num_epochs))
                        l = self.run_iteration(self.tr_gen, True)
                        tbar.set_postfix(loss=l)
                        train_losses_epoch.append(l)
            else:
                for _ in range(self.num_batches_per_epoch):
                    l = self.run_iteration(self.tr_gen, True)
                    train_losses_epoch.append(l)
            self.all_tr_losses.append(np.mean(train_losses_epoch))
            self.print_to_log_file("train loss : %.4f" % self.all_tr_losses[-1])
            with torch.no_grad():
                # validation with train=False
                self.network.eval()
                val_losses = []
                for b in range(self.num_val_batches_per_epoch):
                    l = self.run_iteration(self.val_gen, False, True)
                    val_losses.append(l)
                self.all_val_losses.append(np.mean(val_losses))
                self.print_to_log_file("validation loss: %.4f" % self.all_val_losses[-1])
                if self.also_val_in_tr_mode:
                    self.network.train()
                    # validation with train=True
                    val_losses = []
                    for b in range(self.num_val_batches_per_epoch):
                        l = self.run_iteration(self.val_gen, False)
                        val_losses.append(l)
                    self.all_val_losses_tr_mode.append(np.mean(val_losses))
                    self.print_to_log_file("validation loss (train=True): %.4f" % self.all_val_losses_tr_mode[-1])
            self.update_train_loss_MA()  # needed for lr scheduler and stopping of training
            continue_training = self.on_epoch_end()
            epoch_end_time = time()
            if not continue_training:
                # allows for early stopping
                break
            self.epoch += 1
            self.print_to_log_file("This epoch took %f s\n" % (epoch_end_time - epoch_start_time))
        self.epoch -= 1  # if we don't do this we can get a problem with loading model_final_checkpoint.
        if self.save_final_checkpoint: self.save_checkpoint(join(self.output_folder, "model_final_checkpoint.model"))
        # now we can delete latest as it will be identical with final
        if isfile(join(self.output_folder, "model_latest.model")):
            os.remove(join(self.output_folder, "model_latest.model"))
        if isfile(join(self.output_folder, "model_latest.model.pkl")):
            os.remove(join(self.output_folder, "model_latest.model.pkl"))
def maybe_update_lr(self):
# maybe update learning rate
if self.lr_scheduler is not None:
assert isinstance(self.lr_scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))
if isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
# lr scheduler is updated with moving average val loss. should be more robust
self.lr_scheduler.step(self.val_eval_criterion_MA)
else:
self.lr_scheduler.step(self.epoch + 1)
self.print_to_log_file("lr is now (scheduler) %s" % str(self.optimizer.param_groups[0]['lr']))
def maybe_save_checkpoint(self):
"""
Saves a checkpoint every save_ever epochs.
:return:
"""
if self.save_intermediate_checkpoints and (self.epoch % self.save_every == (self.save_every - 1)):
self.print_to_log_file("saving scheduled checkpoint file...")
if not self.save_latest_only:
self.save_checkpoint(join(self.output_folder, "model_ep_%03.0d.model" % (self.epoch + 1)))
self.save_checkpoint(join(self.output_folder, "model_latest.model"))
self.print_to_log_file("done")
def update_eval_criterion_MA(self):
"""
If self.all_val_eval_metrics is unused (len=0) then we fall back to using -self.all_val_losses for the MA to determine early stopping
(not a minimization, but a maximization of a metric and therefore the - in the latter case)
:return:
"""
if self.val_eval_criterion_MA is None:
if len(self.all_val_eval_metrics) == 0:
self.val_eval_criterion_MA = - self.all_val_losses[-1]
else:
self.val_eval_criterion_MA = self.all_val_eval_metrics[-1]
else:
if len(self.all_val_eval_metrics) == 0:
"""
We here use alpha * old - (1 - alpha) * new because new in this case is the vlaidation loss and lower
is better, so we need to negate it.
"""
self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA - (
1 - self.val_eval_criterion_alpha) * \
self.all_val_losses[-1]
else:
self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA + (
1 - self.val_eval_criterion_alpha) * \
self.all_val_eval_metrics[-1]
def manage_patience(self):
# update patience
continue_training = True
if self.patience is not None:
# if best_MA_tr_loss_for_patience and best_epoch_based_on_MA_tr_loss were not yet initialized,
# initialize them
if self.best_MA_tr_loss_for_patience is None:
self.best_MA_tr_loss_for_patience = self.train_loss_MA
if self.best_epoch_based_on_MA_tr_loss is None:
self.best_epoch_based_on_MA_tr_loss = self.epoch
if self.best_val_eval_criterion_MA is None:
self.best_val_eval_criterion_MA = self.val_eval_criterion_MA
# check if the current epoch is the best one according to moving average of validation criterion. If so
# then save 'best' model
# Do not use this for validation. This is intended for test set prediction only.
self.print_to_log_file("current best_val_eval_criterion_MA is %.4f0" % self.best_val_eval_criterion_MA)
self.print_to_log_file("current val_eval_criterion_MA is %.4f" % self.val_eval_criterion_MA)
if self.val_eval_criterion_MA > self.best_val_eval_criterion_MA:
self.best_val_eval_criterion_MA = self.val_eval_criterion_MA
#self.print_to_log_file("saving best epoch checkpoint...")
if self.save_best_checkpoint: self.save_checkpoint(join(self.output_folder, "model_best.model"))
# Now see if the moving average of the train loss has improved. If yes then reset patience, else
# increase patience
if self.train_loss_MA + self.train_loss_MA_eps < self.best_MA_tr_loss_for_patience:
self.best_MA_tr_loss_for_patience = self.train_loss_MA
self.best_epoch_based_on_MA_tr_loss = self.epoch
#self.print_to_log_file("New best epoch (train loss MA): %03.4f" % self.best_MA_tr_loss_for_patience)
else:
pass
#self.print_to_log_file("No improvement: current train MA %03.4f, best: %03.4f, eps is %03.4f" %
# (self.train_loss_MA, self.best_MA_tr_loss_for_patience, self.train_loss_MA_eps))
# if patience has reached its maximum then finish training (provided lr is low enough)
if self.epoch - self.best_epoch_based_on_MA_tr_loss > self.patience:
if self.optimizer.param_groups[0]['lr'] > self.lr_threshold:
#self.print_to_log_file("My patience ended, but I believe I need more time (lr > 1e-6)")
self.best_epoch_based_on_MA_tr_loss = self.epoch - self.patience // 2
else:
#self.print_to_log_file("My patience ended")
continue_training = False
else:
pass
#self.print_to_log_file(
# "Patience: %d/%d" % (self.epoch - self.best_epoch_based_on_MA_tr_loss, self.patience))
return continue_training
def on_epoch_end(self):
self.finish_online_evaluation() # does not have to do anything, but can be used to update self.all_val_eval_
# metrics
self.plot_progress()
self.maybe_save_checkpoint()
self.update_eval_criterion_MA()
self.maybe_update_lr()
continue_training = self.manage_patience()
return continue_training
def update_train_loss_MA(self):
if self.train_loss_MA is None:
self.train_loss_MA = self.all_tr_losses[-1]
else:
self.train_loss_MA = self.train_loss_MA_alpha * self.train_loss_MA + (1 - self.train_loss_MA_alpha) * \
self.all_tr_losses[-1]
    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        """
        Run a single forward (and optionally backward + optimizer) pass on one batch.

        :param data_generator: iterator yielding dicts with 'data' and 'target' entries
        :param do_backprop: if True, backpropagate the loss and step the optimizer
        :param run_online_evaluation: if True, pass output/target to self.run_online_evaluation
        :return: the detached loss value as a numpy scalar
        """
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']
        data = maybe_to_torch(data)
        target = maybe_to_torch(target)
        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)
        self.optimizer.zero_grad()
        if self.fp16:
            # mixed precision path: forward/loss under autocast, gradients via GradScaler
            with autocast():
                output = self.network(data)
                del data  # drop the input reference early to reduce GPU memory pressure
                l = self.loss(output, target)
            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            output = self.network(data)
            del data
            l = self.loss(output, target)
            if do_backprop:
                l.backward()
                self.optimizer.step()
        if run_online_evaluation:
            self.run_online_evaluation(output, target)
        del target
        return l.detach().cpu().numpy()
    def run_online_evaluation(self, *args, **kwargs):
        """
        Per-batch hook for accumulating validation statistics (e.g. Dice).
        Can be implemented by subclasses, does not have to be.
        :param output_torch:
        :param target_npy:
        :return:
        """
        pass
    def finish_online_evaluation(self):
        """
        Per-epoch hook called from on_epoch_end; can aggregate the values collected
        by run_online_evaluation into self.all_val_eval_metrics.
        Can be implemented by subclasses, does not have to be.
        :return:
        """
        pass
    @abstractmethod
    def validate(self, *args, **kwargs):
        """Run full validation/prediction on the validation set; must be implemented by subclasses."""
        pass
    def find_lr(self, num_iters=1000, init_value=1e-6, final_value=10., beta=0.98):
        """
        LR range test: exponentially sweep the learning rate from init_value to
        final_value over num_iters training iterations, track the smoothed loss,
        and save a loss-vs-lr plot (lr_finder.png) to the output folder.

        stolen and adapted from here: https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
        :param num_iters: number of training iterations used for the sweep
        :param init_value: starting learning rate
        :param final_value: final learning rate
        :param beta: smoothing factor for the exponential moving average of the loss
        :return: (log10 learning rates, smoothed losses)
        """
        import math
        self._maybe_init_amp()
        mult = (final_value / init_value) ** (1 / num_iters)
        lr = init_value
        self.optimizer.param_groups[0]['lr'] = lr
        avg_loss = 0.
        best_loss = 0.
        losses = []
        log_lrs = []
        for batch_num in range(1, num_iters + 1):
            # +1 because this one here is not designed to have negative loss...
            # NOTE(review): run_iteration returns a numpy array (detach().cpu().numpy());
            # `.data.item()` looks like it expects a torch tensor — verify this path works
            loss = self.run_iteration(self.tr_gen, do_backprop=True, run_online_evaluation=False).data.item() + 1
            # Compute the smoothed loss
            avg_loss = beta * avg_loss + (1 - beta) * loss
            smoothed_loss = avg_loss / (1 - beta ** batch_num)
            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > 4 * best_loss:
                break
            # Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            # Update the lr for the next step
            lr *= mult
            self.optimizer.param_groups[0]['lr'] = lr
        import matplotlib.pyplot as plt
        lrs = [10 ** i for i in log_lrs]
        fig = plt.figure()
        plt.xscale('log')
        # drop the first/last few points, which are dominated by warmup/divergence
        plt.plot(lrs[10:-5], losses[10:-5])
        plt.savefig(join(self.output_folder, "lr_finder.png"))
        plt.close()
        return log_lrs, losses
================================================
FILE: unetr_pp/training/network_training/unetr_pp_trainer_acdc.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import Tuple
import numpy as np
import torch
from unetr_pp.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation
from unetr_pp.training.loss_functions.deep_supervision import MultipleOutputLoss2
from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda
from unetr_pp.network_architecture.initialization import InitWeights_He
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.training.data_augmentation.default_data_augmentation import default_2D_augmentation_params, \
get_patch_size, default_3D_augmentation_params
from unetr_pp.training.dataloading.dataset_loading import unpack_dataset
from unetr_pp.training.network_training.Trainer_acdc import Trainer_acdc
from unetr_pp.utilities.nd_softmax import softmax_helper
from sklearn.model_selection import KFold
from torch import nn
from torch.cuda.amp import autocast
from unetr_pp.training.learning_rate.poly_lr import poly_lr
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.network_architecture.acdc.unetr_pp_acdc import UNETR_PP
from fvcore.nn import FlopCountAnalysis
class unetr_pp_trainer_acdc(Trainer_acdc):
    """
    UNETR++ trainer for the ACDC cardiac segmentation task.

    Info for Fabian: same as internal nnUNetTrainerV2_2
    """

    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, fp16=False):
        """Configure ACDC-specific hyperparameters on top of the generic trainer."""
        super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                         deterministic, fp16)
        self.max_num_epochs = 1000
        self.initial_lr = 1e-2
        self.deep_supervision_scales = None
        self.ds_loss_weights = None
        self.pin_memory = True
        self.load_pretrain_weight = False
        self.load_plans_file()
        # pick the highest-resolution stage present in the experiment plans
        if len(self.plans['plans_per_stage']) == 2:
            Stage = 1
        else:
            Stage = 0
        self.crop_size = self.plans['plans_per_stage'][Stage]['patch_size']
        self.input_channels = self.plans['num_modalities']
        self.num_classes = self.plans['num_classes'] + 1  # +1 for the background class
        self.conv_op = nn.Conv3d
        # UNETR++ encoder hyperparameters
        self.embedding_dim = 96
        self.depths = [2, 2, 2, 2]
        self.num_heads = [3, 6, 12, 24]
        self.embedding_patch_size = [1, 4, 4]
        self.window_size = [[3, 5, 5], [3, 5, 5], [7, 10, 10], [3, 5, 5]]
        self.down_stride = [[1, 4, 4], [1, 8, 8], [2, 16, 16], [4, 32, 32]]
        self.deep_supervision = True

    def initialize(self, training=True, force_load_plans=False):
        """
        - replaced get_default_augmentation with get_moreDA_augmentation
        - enforce to only run this code once
        - loss function wrapper for deep supervision

        :param training: if True, also build data loaders and augmentation pipelines
        :param force_load_plans: reload the plans file even if it was loaded before
        :return: None
        """
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()
            # enforce the ACDC patch size and pooling schedule
            self.plans['plans_per_stage'][0]['patch_size'] = np.array([16, 160, 160])
            self.crop_size = np.array([16, 160, 160])
            self.plans['plans_per_stage'][self.stage]['pool_op_kernel_sizes'] = [[1, 4, 4], [2, 2, 2], [2, 2, 2]]
            self.process_plans(self.plans)

            self.setup_DA_params()
            if self.deep_supervision:
                ################# Here we wrap the loss for deep supervision ############
                # we need to know the number of outputs of the network
                net_numpool = len(self.net_num_pool_op_kernel_sizes)
                # we give each output a weight which decreases exponentially (division by 2) as the resolution
                # decreases. this gives higher resolution outputs more weight in the loss
                weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
                # normalize weights so that they sum to 1
                weights = weights / weights.sum()
                print(weights)
                self.ds_loss_weights = weights
                # now wrap the loss
                self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
                ################# END ###################

            self.folder_with_preprocessed_data = join(self.dataset_directory,
                                                      self.plans['data_identifier'] + "_stage%d" % self.stage)
            # np.random.random_integers (inclusive upper bound) was deprecated in NumPy 1.11 and
            # removed in recent releases; randint with an exclusive upper bound of 100000 draws
            # from exactly the same range [0, 99999]
            seeds_train = np.random.randint(0, 100000, self.data_aug_params.get('num_threads'))
            seeds_val = np.random.randint(0, 100000, max(self.data_aug_params.get('num_threads') // 2, 1))
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr, self.dl_val,
                    self.data_aug_params[
                        'patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales if self.deep_supervision else None,
                    pin_memory=self.pin_memory,
                    use_nondetMultiThreadedAugmenter=False,
                    seeds_train=seeds_train,
                    seeds_val=seeds_val
                )
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()

            assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
        else:
            self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
        self.was_initialized = True

    def initialize_network(self):
        """
        Build the UNETR++ network for ACDC and report its size.

        - momentum 0.99
        - SGD instead of Adam
        - self.lr_scheduler = None because we do poly_lr
        - deep supervision = True

        Known issue: forgot to set neg_slope=0 in InitWeights_He; should not make a difference though
        :return: None
        """
        self.network = UNETR_PP(in_channels=self.input_channels,
                                out_channels=self.num_classes,
                                feature_size=16,
                                num_heads=4,
                                depths=[3, 3, 3, 3],
                                dims=[32, 64, 128, 256],
                                do_ds=True,
                                )

        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper

        # Print the network parameters & Flops (renamed the dummy tensor so the
        # builtin "input" is not shadowed)
        n_parameters = sum(p.numel() for p in self.network.parameters() if p.requires_grad)
        input_res = (1, 16, 160, 160)
        dummy_input = torch.ones(()).new_empty((1, *input_res), dtype=next(self.network.parameters()).dtype,
                                               device=next(self.network.parameters()).device)
        flops = FlopCountAnalysis(self.network, dummy_input)
        model_flops = flops.total()
        print(f"Total trainable parameters: {round(n_parameters * 1e-6, 2)} M")
        print(f"MAdds: {round(model_flops * 1e-9, 2)} G")

    def initialize_optimizer_and_scheduler(self):
        """SGD with Nesterov momentum; LR is driven manually by poly_lr, so no scheduler."""
        assert self.network is not None, "self.initialize_network must be called first"
        self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
                                         momentum=0.99, nesterov=True)
        self.lr_scheduler = None

    def run_online_evaluation(self, output, target):
        """
        due to deep supervision the return value and the reference are now lists of tensors. We only need the full
        resolution output because this is what we are interested in in the end. The others are ignored

        :param output: network output (list of tensors when deep supervision is on)
        :param target: reference segmentation (list of tensors when deep supervision is on)
        :return: whatever the base class evaluation returns
        """
        if self.deep_supervision:
            # only the highest-resolution head is evaluated
            target = target[0]
            output = output[0]
        return super().run_online_evaluation(output, target)

    def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
                 step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
                 segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
        """
        We need to wrap this because we need to enforce self.network.do_ds = False for prediction.
        do_ds is restored in a finally block so an exception during validation cannot leave the
        network stuck in inference mode.
        """
        ds = self.network.do_ds
        self.network.do_ds = False
        try:
            ret = super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window,
                                   step_size=step_size,
                                   save_softmax=save_softmax, use_gaussian=use_gaussian,
                                   overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug,
                                   all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs,
                                   run_postprocessing_on_folds=run_postprocessing_on_folds)
        finally:
            self.network.do_ds = ds
        return ret

    def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
                                                         mirror_axes: Tuple[int] = None,
                                                         use_sliding_window: bool = True, step_size: float = 0.5,
                                                         use_gaussian: bool = True, pad_border_mode: str = 'constant',
                                                         pad_kwargs: dict = None, all_in_gpu: bool = False,
                                                         verbose: bool = True, mixed_precision=True) -> Tuple[
        np.ndarray, np.ndarray]:
        """
        We need to wrap this because we need to enforce self.network.do_ds = False for prediction.
        do_ds is restored in a finally block so an exception during prediction cannot leave the
        network stuck in inference mode.
        """
        ds = self.network.do_ds
        self.network.do_ds = False
        try:
            ret = super().predict_preprocessed_data_return_seg_and_softmax(data,
                                                                           do_mirroring=do_mirroring,
                                                                           mirror_axes=mirror_axes,
                                                                           use_sliding_window=use_sliding_window,
                                                                           step_size=step_size,
                                                                           use_gaussian=use_gaussian,
                                                                           pad_border_mode=pad_border_mode,
                                                                           pad_kwargs=pad_kwargs,
                                                                           all_in_gpu=all_in_gpu,
                                                                           verbose=verbose,
                                                                           mixed_precision=mixed_precision)
        finally:
            self.network.do_ds = ds
        return ret

    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        """
        gradient clipping improves training stability

        :param data_generator: iterator yielding dicts with 'data' and 'target'
        :param do_backprop: if True, perform backward pass and optimizer step
        :param run_online_evaluation: if True, feed output/target to the online evaluation hook
        :return: detached loss value as a numpy scalar
        """
        data_dict = next(data_generator)
        data = maybe_to_torch(data_dict['data'])
        target = maybe_to_torch(data_dict['target'])

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        if self.fp16:
            with autocast():
                output = self.network(data)
                del data
                loss = self.loss(output, target)

            if do_backprop:
                self.amp_grad_scaler.scale(loss).backward()
                # unscale before clipping so the clip threshold applies to true gradients
                self.amp_grad_scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            output = self.network(data)
            del data
            loss = self.loss(output, target)

            if do_backprop:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.optimizer.step()

        if run_online_evaluation:
            self.run_online_evaluation(output, target)

        del target
        return loss.detach().cpu().numpy()

    def do_split(self):
        """
        The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded,
        so always the same) and save it as splits_final.pkl file in the preprocessed data directory.
        Sometimes you may want to create your own split for various reasons. For this you will need to create your own
        splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in
        it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3)
        and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to
        use a random 80:20 data split.

        Note: for ACDC the train/val keys of the requested fold are deliberately overwritten below
        with the fixed split used in the UNETR++ experiments.
        :return: None (sets self.dataset_tr and self.dataset_val)
        """
        if self.fold == "all":
            # if fold==all then we use all images for training and validation
            tr_keys = val_keys = list(self.dataset.keys())
        else:
            splits_file = join(self.dataset_directory, "splits_final.pkl")

            # if the split file does not exist we need to create it
            if not isfile(splits_file):
                self.print_to_log_file("Creating new 5-fold cross-validation split...")
                splits = []
                all_keys_sorted = np.sort(list(self.dataset.keys()))
                kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
                for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
                    train_keys = np.array(all_keys_sorted)[train_idx]
                    test_keys = np.array(all_keys_sorted)[test_idx]
                    splits.append(OrderedDict())
                    splits[-1]['train'] = train_keys
                    splits[-1]['val'] = test_keys
                save_pickle(splits, splits_file)
            else:
                self.print_to_log_file("Using splits from existing split file:", splits_file)
                splits = load_pickle(splits_file)
                self.print_to_log_file("The split file contains %d splits." % len(splits))

            self.print_to_log_file("Desired fold for training: %d" % self.fold)
            # fixed ACDC split used for the UNETR++ experiments (overrides the file content)
            splits[self.fold]['train'] = np.array(['patient001_frame01', 'patient001_frame12', 'patient004_frame01',
                                                   'patient004_frame15', 'patient005_frame01', 'patient005_frame13',
                                                   'patient006_frame01', 'patient006_frame16', 'patient007_frame01',
                                                   'patient007_frame07', 'patient010_frame01', 'patient010_frame13',
                                                   'patient011_frame01', 'patient011_frame08', 'patient013_frame01',
                                                   'patient013_frame14', 'patient015_frame01', 'patient015_frame10',
                                                   'patient016_frame01', 'patient016_frame12', 'patient018_frame01',
                                                   'patient018_frame10', 'patient019_frame01', 'patient019_frame11',
                                                   'patient020_frame01', 'patient020_frame11', 'patient021_frame01',
                                                   'patient021_frame13', 'patient022_frame01', 'patient022_frame11',
                                                   'patient023_frame01', 'patient023_frame09', 'patient025_frame01',
                                                   'patient025_frame09', 'patient026_frame01', 'patient026_frame12',
                                                   'patient027_frame01', 'patient027_frame11', 'patient028_frame01',
                                                   'patient028_frame09', 'patient029_frame01', 'patient029_frame12',
                                                   'patient030_frame01', 'patient030_frame12', 'patient031_frame01',
                                                   'patient031_frame10', 'patient032_frame01', 'patient032_frame12',
                                                   'patient033_frame01', 'patient033_frame14', 'patient034_frame01',
                                                   'patient034_frame16', 'patient035_frame01', 'patient035_frame11',
                                                   'patient036_frame01', 'patient036_frame12', 'patient037_frame01',
                                                   'patient037_frame12', 'patient038_frame01', 'patient038_frame11',
                                                   'patient039_frame01', 'patient039_frame10', 'patient040_frame01',
                                                   'patient040_frame13', 'patient041_frame01', 'patient041_frame11',
                                                   'patient043_frame01', 'patient043_frame07', 'patient044_frame01',
                                                   'patient044_frame11', 'patient045_frame01', 'patient045_frame13',
                                                   'patient046_frame01', 'patient046_frame10', 'patient047_frame01',
                                                   'patient047_frame09', 'patient050_frame01', 'patient050_frame12',
                                                   'patient051_frame01', 'patient051_frame11', 'patient052_frame01',
                                                   'patient052_frame09', 'patient054_frame01', 'patient054_frame12',
                                                   'patient056_frame01', 'patient056_frame12', 'patient057_frame01',
                                                   'patient057_frame09', 'patient058_frame01', 'patient058_frame14',
                                                   'patient059_frame01', 'patient059_frame09', 'patient060_frame01',
                                                   'patient060_frame14', 'patient061_frame01', 'patient061_frame10',
                                                   'patient062_frame01', 'patient062_frame09', 'patient063_frame01',
                                                   'patient063_frame16', 'patient065_frame01', 'patient065_frame14',
                                                   'patient066_frame01', 'patient066_frame11', 'patient068_frame01',
                                                   'patient068_frame12', 'patient069_frame01', 'patient069_frame12',
                                                   'patient070_frame01', 'patient070_frame10', 'patient071_frame01',
                                                   'patient071_frame09', 'patient072_frame01', 'patient072_frame11',
                                                   'patient073_frame01', 'patient073_frame10', 'patient074_frame01',
                                                   'patient074_frame12', 'patient075_frame01', 'patient075_frame06',
                                                   'patient076_frame01', 'patient076_frame12', 'patient077_frame01',
                                                   'patient077_frame09', 'patient078_frame01', 'patient078_frame09',
                                                   'patient080_frame01', 'patient080_frame10', 'patient082_frame01',
                                                   'patient082_frame07', 'patient083_frame01', 'patient083_frame08',
                                                   'patient084_frame01', 'patient084_frame10', 'patient085_frame01',
                                                   'patient085_frame09', 'patient086_frame01', 'patient086_frame08',
                                                   'patient087_frame01', 'patient087_frame10', 'patient089_frame01',
                                                   'patient089_frame10', 'patient090_frame04', 'patient090_frame11',
                                                   'patient091_frame01', 'patient091_frame09', 'patient093_frame01',
                                                   'patient093_frame14', 'patient094_frame01', 'patient094_frame07',
                                                   'patient096_frame01', 'patient096_frame08', 'patient097_frame01',
                                                   'patient097_frame11', 'patient098_frame01', 'patient098_frame09',
                                                   'patient099_frame01', 'patient099_frame09', 'patient100_frame01',
                                                   'patient100_frame13'])
            splits[self.fold]['val'] = np.array(['patient002_frame01', 'patient002_frame12', 'patient003_frame01',
                                                 'patient003_frame15', 'patient008_frame01', 'patient008_frame13',
                                                 'patient009_frame01', 'patient009_frame13', 'patient012_frame01',
                                                 'patient012_frame13', 'patient014_frame01', 'patient014_frame13',
                                                 'patient017_frame01', 'patient017_frame09', 'patient024_frame01',
                                                 'patient024_frame09', 'patient042_frame01', 'patient042_frame16',
                                                 'patient048_frame01', 'patient048_frame08', 'patient049_frame01',
                                                 'patient049_frame11', 'patient053_frame01', 'patient053_frame12',
                                                 'patient055_frame01', 'patient055_frame10', 'patient064_frame01',
                                                 'patient064_frame12', 'patient067_frame01', 'patient067_frame10',
                                                 'patient079_frame01', 'patient079_frame11', 'patient081_frame01',
                                                 'patient081_frame07', 'patient088_frame01', 'patient088_frame12',
                                                 'patient092_frame01', 'patient092_frame06', 'patient095_frame01',
                                                 'patient095_frame12'])
            if self.fold < len(splits):
                tr_keys = splits[self.fold]['train']
                val_keys = splits[self.fold]['val']
                self.print_to_log_file("This split has %d training and %d validation cases."
                                       % (len(tr_keys), len(val_keys)))
            else:
                self.print_to_log_file("INFO: You requested fold %d for training but splits "
                                       "contain only %d folds. I am now creating a "
                                       "random (but seeded) 80:20 split!" % (self.fold, len(splits)))
                # if we request a fold that is not in the split file, create a random 80:20 split
                rnd = np.random.RandomState(seed=12345 + self.fold)
                keys = np.sort(list(self.dataset.keys()))
                idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False)
                idx_val = [i for i in range(len(keys)) if i not in idx_tr]
                tr_keys = [keys[i] for i in idx_tr]
                val_keys = [keys[i] for i in idx_val]
                self.print_to_log_file("This random 80:20 split has %d training and %d validation cases."
                                       % (len(tr_keys), len(val_keys)))

        tr_keys.sort()
        val_keys.sort()
        self.dataset_tr = OrderedDict()
        for i in tr_keys:
            self.dataset_tr[i] = self.dataset[i]
        self.dataset_val = OrderedDict()
        for i in val_keys:
            self.dataset_val[i] = self.dataset[i]

    def setup_DA_params(self):
        """
        - we increase rotation angle from [-15, 15] to [-30, 30]
        - scale range is now (0.7, 1.4), was (0.85, 1.25)
        - we don't do elastic deformation anymore

        :return: None (fills self.data_aug_params and related attributes)
        """
        # downsampling factors for the deep-supervision targets, derived from the pooling schedule
        self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(
            np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1]

        if self.threeD:
            self.data_aug_params = default_3D_augmentation_params
            self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            if self.do_dummy_2D_aug:
                self.data_aug_params["dummy_2D"] = True
                self.print_to_log_file("Using dummy2d data augmentation")
                self.data_aug_params["elastic_deform_alpha"] = \
                    default_2D_augmentation_params["elastic_deform_alpha"]
                self.data_aug_params["elastic_deform_sigma"] = \
                    default_2D_augmentation_params["elastic_deform_sigma"]
                self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
        else:
            self.do_dummy_2D_aug = False
            if max(self.patch_size) / min(self.patch_size) > 1.5:
                default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
            self.data_aug_params = default_2D_augmentation_params
        self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm

        if self.do_dummy_2D_aug:
            self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
                                                             self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
            patch_size_for_spatialtransform = self.patch_size[1:]
        else:
            self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            patch_size_for_spatialtransform = self.patch_size

        self.data_aug_params["scale_range"] = (0.7, 1.4)
        self.data_aug_params["do_elastic"] = False
        self.data_aug_params['selected_seg_channels'] = [0]
        self.data_aug_params['patch_size_for_spatialtransform'] = patch_size_for_spatialtransform
        self.data_aug_params["num_cached_per_thread"] = 2

    def maybe_update_lr(self, epoch=None):
        """
        if epoch is not None we overwrite epoch. Else we use epoch = self.epoch + 1

        (maybe_update_lr is called in on_epoch_end which is called before epoch is incremented.
        therefore we need to do +1 here)

        :param epoch: explicit epoch to use, or None for self.epoch + 1
        :return: None
        """
        if epoch is None:
            ep = self.epoch + 1
        else:
            ep = epoch
        self.optimizer.param_groups[0]['lr'] = poly_lr(ep, self.max_num_epochs, self.initial_lr, 0.9)
        self.print_to_log_file("lr:", np.round(self.optimizer.param_groups[0]['lr'], decimals=6))

    def on_epoch_end(self):
        """
        overwrite patient-based early stopping. Always run to 1000 epochs

        :return: bool, whether training should continue
        """
        super().on_epoch_end()
        continue_training = self.epoch < self.max_num_epochs

        # it can rarely happen that the momentum is too high for some dataset. If at epoch 100 the
        # estimated validation Dice is still 0 then we reduce the momentum from 0.99 to 0.95
        if self.epoch == 100:
            if self.all_val_eval_metrics[-1] == 0:
                self.optimizer.param_groups[0]["momentum"] = 0.95
                self.network.apply(InitWeights_He(1e-2))
                self.print_to_log_file("At epoch 100, the mean foreground Dice was 0. This can be caused by a too "
                                       "high momentum. High momentum (0.99) is good for datasets where it works, but "
                                       "sometimes causes issues such as this one. Momentum has now been reduced to "
                                       "0.95 and network weights have been reinitialized")
        return continue_training

    def run_training(self):
        """
        if we run with -c then we need to set the correct lr for the first epoch, otherwise it will run the first
        continued epoch with self.initial_lr

        we also need to make sure deep supervision in the network is enabled for training, thus the wrapper
        :return: whatever the base class run_training returns
        """
        self.maybe_update_lr(self.epoch)  # if we dont overwrite epoch then self.epoch+1 is used which is not what we
        # want at the start of the training
        ds = self.network.do_ds
        self.network.do_ds = bool(self.deep_supervision)
        try:
            ret = super().run_training()
        finally:
            # restore the previous deep-supervision flag even if training aborts
            self.network.do_ds = ds
        return ret
================================================
FILE: unetr_pp/training/network_training/unetr_pp_trainer_lung.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import Tuple
import numpy as np
import torch
from unetr_pp.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation
from unetr_pp.training.loss_functions.deep_supervision import MultipleOutputLoss2
from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda
from unetr_pp.network_architecture.initialization import InitWeights_He
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.training.data_augmentation.default_data_augmentation import default_2D_augmentation_params, \
get_patch_size, default_3D_augmentation_params
from unetr_pp.training.dataloading.dataset_loading import unpack_dataset
from unetr_pp.training.network_training.Trainer_lung import Trainer_lung
from unetr_pp.utilities.nd_softmax import softmax_helper
from sklearn.model_selection import KFold
from torch import nn
from torch.cuda.amp import autocast
from unetr_pp.training.learning_rate.poly_lr import poly_lr
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.network_architecture.lung.unetr_pp_lung import UNETR_PP
from fvcore.nn import FlopCountAnalysis
class unetr_pp_trainer_lung(Trainer_lung):
"""
Info for Fabian: same as internal nnUNetTrainerV2_2
"""
def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
             unpack_data=True, deterministic=True, fp16=False):
    """
    Configure lung-tumor-specific hyperparameters on top of the generic trainer.

    :param plans_file: path to the nnU-Net plans pickle for this task
    :param fold: fold index (or "all") for cross-validation
    :param output_folder: where logs/checkpoints are written
    :param dataset_directory: preprocessed data location
    :param batch_dice: whether the Dice loss is computed over the whole batch
    :param stage: plans stage to use (may be None; resolved by the base class)
    :param unpack_data: unpack the compressed dataset before training
    :param deterministic: enable deterministic cudnn behavior
    :param fp16: enable mixed-precision training
    """
    super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                     deterministic, fp16)
    self.max_num_epochs = 1000
    self.initial_lr = 1e-2
    self.deep_supervision_scales = None
    self.ds_loss_weights = None
    self.pin_memory = True
    self.load_pretrain_weight = False
    self.load_plans_file()
    # select the higher-resolution stage when the plans contain two stages
    if len(self.plans['plans_per_stage']) == 2:
        Stage = 1
    else:
        Stage = 0
    # NOTE(review): crop size is hard-coded here and Stage is left unused as a result;
    # the plans-derived patch size is kept in the trailing comment for reference
    self.crop_size = (32,192,192) #self.plans['plans_per_stage'][Stage]['patch_size']
    self.input_channels = self.plans['num_modalities']
    self.num_classes = self.plans['num_classes'] + 1  # +1 for the background class
    #print("#############classes", self.num_classes)
    self.conv_op = nn.Conv3d
    # UNETR++ encoder hyperparameters
    self.embedding_dim = 96
    self.depths = [2, 2, 2, 2]
    self.num_heads = [3, 6, 12, 24]
    self.embedding_patch_size = [1, 4, 4]
    self.window_size = [[3, 5, 5], [3, 5, 5], [7, 10, 10], [3, 5, 5]]
    self.down_stride = [[1, 4, 4], [1, 8, 8], [2, 16, 16], [4, 32, 32]]
    self.deep_supervision = True
def initialize(self, training=True, force_load_plans=False):
    """
    - replaced get_default_augmentation with get_moreDA_augmentation
    - enforce to only run this code once
    - loss function wrapper for deep supervision

    :param training: if True, also build data loaders and augmentation pipelines
    :param force_load_plans: reload the plans file even if it was loaded before
    :return: None
    """
    if not self.was_initialized:
        maybe_mkdir_p(self.output_folder)

        if force_load_plans or (self.plans is None):
            self.load_plans_file()
        # enforce the lung patch size and pooling schedule
        self.plans['plans_per_stage'][0]['patch_size'] = np.array([32, 192, 192])
        self.crop_size = np.array([32, 192, 192])
        self.plans['plans_per_stage'][self.stage]['pool_op_kernel_sizes'] = [[1, 4, 4], [2, 2, 2], [2, 2, 2]]
        self.process_plans(self.plans)

        self.setup_DA_params()
        if self.deep_supervision:
            ################# Here we wrap the loss for deep supervision ############
            # we need to know the number of outputs of the network
            net_numpool = len(self.net_num_pool_op_kernel_sizes)
            # we give each output a weight which decreases exponentially (division by 2) as the resolution
            # decreases. this gives higher resolution outputs more weight in the loss
            weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
            # normalize weights so that they sum to 1
            weights = weights / weights.sum()
            print(weights)
            self.ds_loss_weights = weights
            # now wrap the loss
            self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
            ################# END ###################

        self.folder_with_preprocessed_data = join(self.dataset_directory,
                                                  self.plans['data_identifier'] + "_stage%d" % self.stage)
        # np.random.random_integers (inclusive upper bound) was deprecated in NumPy 1.11 and
        # removed in recent releases; randint with an exclusive upper bound of 100000 draws
        # from exactly the same range [0, 99999]
        seeds_train = np.random.randint(0, 100000, self.data_aug_params.get('num_threads'))
        seeds_val = np.random.randint(0, 100000, max(self.data_aug_params.get('num_threads') // 2, 1))
        if training:
            self.dl_tr, self.dl_val = self.get_basic_generators()
            if self.unpack_data:
                print("unpacking dataset")
                unpack_dataset(self.folder_with_preprocessed_data)
                print("done")
            else:
                print(
                    "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                    "will wait all winter for your model to finish!")

            self.tr_gen, self.val_gen = get_moreDA_augmentation(
                self.dl_tr, self.dl_val,
                self.data_aug_params[
                    'patch_size_for_spatialtransform'],
                self.data_aug_params,
                deep_supervision_scales=self.deep_supervision_scales if self.deep_supervision else None,
                pin_memory=self.pin_memory,
                use_nondetMultiThreadedAugmenter=False,
                seeds_train=seeds_train,
                seeds_val=seeds_val
            )
            self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                   also_print_to_console=False)
            self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                   also_print_to_console=False)
        else:
            pass

        self.initialize_network()
        self.initialize_optimizer_and_scheduler()

        assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
    else:
        self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
    self.was_initialized = True
def initialize_network(self):
    """
    Build the UNETR++ network for the lung task and report its size.

    - momentum 0.99
    - SGD instead of Adam
    - self.lr_scheduler = None because we do poly_lr
    - deep supervision = True

    Known issue: forgot to set neg_slope=0 in InitWeights_He; should not make a difference though
    :return: None
    """
    self.network = UNETR_PP(in_channels=self.input_channels,
                            out_channels=self.num_classes,
                            feature_size=16,
                            num_heads=4,
                            depths=[3, 3, 3, 3],
                            dims=[32, 64, 128, 256],
                            do_ds=True,
                            )

    if torch.cuda.is_available():
        self.network.cuda()
    self.network.inference_apply_nonlin = softmax_helper

    # Print the network parameters & Flops (renamed the dummy tensor so the
    # builtin "input" is not shadowed)
    n_parameters = sum(p.numel() for p in self.network.parameters() if p.requires_grad)
    input_res = (1, 32, 192, 192)
    dummy_input = torch.ones(()).new_empty((1, *input_res), dtype=next(self.network.parameters()).dtype,
                                           device=next(self.network.parameters()).device)
    flops = FlopCountAnalysis(self.network, dummy_input)
    model_flops = flops.total()
    print(f"Total trainable parameters: {round(n_parameters * 1e-6, 2)} M")
    print(f"MAdds: {round(model_flops * 1e-9, 2)} G")
def initialize_optimizer_and_scheduler(self):
    """
    Create the SGD optimizer with Nesterov momentum.

    No scheduler object is installed because the learning rate is driven
    manually via poly_lr in maybe_update_lr.
    """
    assert self.network is not None, "self.initialize_network must be called first"
    self.optimizer = torch.optim.SGD(
        self.network.parameters(),
        self.initial_lr,
        weight_decay=self.weight_decay,
        momentum=0.99,
        nesterov=True,
    )
    self.lr_scheduler = None
def run_online_evaluation(self, output, target):
    """
    due to deep supervision the return value and the reference are now lists of tensors. We only need the full
    resolution output because this is what we are interested in in the end. The others are ignored

    :param output: network output (list of tensors when deep supervision is on)
    :param target: reference segmentation (list of tensors when deep supervision is on)
    :return: whatever the base class evaluation returns
    """
    if self.deep_supervision:
        # only the highest-resolution head is evaluated; the redundant
        # no-op else branch of the original implementation was removed
        target = target[0]
        output = output[0]
    return super().run_online_evaluation(output, target)
def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
             step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
             validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
             segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
    """
    We need to wrap this because we need to enforce self.network.do_ds = False for prediction.

    The previous deep-supervision setting is restored in a finally block so that
    an exception during validation cannot leave the network with do_ds = False.
    """
    ds = self.network.do_ds
    self.network.do_ds = False
    try:
        ret = super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window, step_size=step_size,
                               save_softmax=save_softmax, use_gaussian=use_gaussian,
                               overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug,
                               all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs,
                               run_postprocessing_on_folds=run_postprocessing_on_folds)
    finally:
        self.network.do_ds = ds
    return ret
def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
                                                     mirror_axes: Tuple[int] = None,
                                                     use_sliding_window: bool = True, step_size: float = 0.5,
                                                     use_gaussian: bool = True, pad_border_mode: str = 'constant',
                                                     pad_kwargs: dict = None, all_in_gpu: bool = False,
                                                     verbose: bool = True, mixed_precision=True) -> Tuple[
    np.ndarray, np.ndarray]:
    """
    We need to wrap this because we need to enforce self.network.do_ds = False for prediction.

    The previous deep-supervision setting is restored in a finally block so that
    an exception during prediction cannot leave the network with do_ds = False.
    """
    ds = self.network.do_ds
    self.network.do_ds = False
    try:
        ret = super().predict_preprocessed_data_return_seg_and_softmax(data,
                                                                       do_mirroring=do_mirroring,
                                                                       mirror_axes=mirror_axes,
                                                                       use_sliding_window=use_sliding_window,
                                                                       step_size=step_size, use_gaussian=use_gaussian,
                                                                       pad_border_mode=pad_border_mode,
                                                                       pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu,
                                                                       verbose=verbose,
                                                                       mixed_precision=mixed_precision)
    finally:
        self.network.do_ds = ds
    return ret
def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
    """
    Run a single training/validation iteration.

    Gradient clipping (max norm 12) improves training stability.
    :param data_generator: iterator yielding dicts with 'data' and 'target' entries
    :param do_backprop: if True, backpropagate and apply an optimizer step
    :param run_online_evaluation: if True, feed output/target to run_online_evaluation
    :return: the loss value, detached and converted to a numpy array
    """
    data_dict = next(data_generator)
    data = data_dict['data']
    target = data_dict['target']
    data = maybe_to_torch(data)
    target = maybe_to_torch(target)
    if torch.cuda.is_available():
        data = to_cuda(data)
        target = to_cuda(target)
    self.optimizer.zero_grad()
    if self.fp16:
        # mixed precision path: the loss is scaled, then gradients are unscaled
        # *before* clipping so the clip threshold applies to true gradient norms
        with autocast():
            output = self.network(data)
            del data  # free the input as soon as the forward pass is done
            l = self.loss(output, target)
        if do_backprop:
            self.amp_grad_scaler.scale(l).backward()
            self.amp_grad_scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
            self.amp_grad_scaler.step(self.optimizer)
            self.amp_grad_scaler.update()
    else:
        # full precision path
        output = self.network(data)
        del data
        l = self.loss(output, target)
        if do_backprop:
            l.backward()
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()
    if run_online_evaluation:
        self.run_online_evaluation(output, target)
    del target
    return l.detach().cpu().numpy()
def do_split(self):
    """
    The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded,
    so always the same) and save it as splits_final.pkl file in the preprocessed data directory.
    Sometimes you may want to create your own split for various reasons. For this you will need to create your own
    splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in
    it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3)
    and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to
    use a random 80:20 data split.

    NOTE(review): after loading/creating the split file, this trainer OVERWRITES
    the requested fold with a hard-coded Lung (Task006) train/val case list below,
    so the content of splits_final.pkl is effectively ignored for that fold.
    :return:
    """
    if self.fold == "all":
        # if fold==all then we use all images for training and validation
        tr_keys = val_keys = list(self.dataset.keys())
    else:
        splits_file = join(self.dataset_directory, "splits_final.pkl")
        # if the split file does not exist we need to create it
        if not isfile(splits_file):
            self.print_to_log_file("Creating new 5-fold cross-validation split...")
            splits = []
            all_keys_sorted = np.sort(list(self.dataset.keys()))
            # seeded KFold => the generated split is reproducible
            kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
            for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
                train_keys = np.array(all_keys_sorted)[train_idx]
                test_keys = np.array(all_keys_sorted)[test_idx]
                splits.append(OrderedDict())
                splits[-1]['train'] = train_keys
                splits[-1]['val'] = test_keys
            save_pickle(splits, splits_file)
        else:
            self.print_to_log_file("Using splits from existing split file:", splits_file)
            splits = load_pickle(splits_file)
            self.print_to_log_file("The split file contains %d splits." % len(splits))
        self.print_to_log_file("Desired fold for training: %d" % self.fold)
        # hard-coded Lung split (50 train / 13 val cases) — replaces whatever the
        # split file defined for this fold
        splits[self.fold]['train'] = np.array([ 'lung_053', 'lung_022', 'lung_041', 'lung_069', 'lung_014',
                                                'lung_006', 'lung_065', 'lung_018', 'lung_096', 'lung_084',
                                                'lung_086', 'lung_043', 'lung_020', 'lung_051', 'lung_079',
                                                'lung_004', 'lung_075', 'lung_016', 'lung_071', 'lung_028',
                                                'lung_055', 'lung_036', 'lung_047', 'lung_059', 'lung_061',
                                                'lung_010', 'lung_073', 'lung_026', 'lung_038', 'lung_045',
                                                'lung_034', 'lung_049', 'lung_057', 'lung_080', 'lung_092',
                                                'lung_015', 'lung_064', 'lung_031', 'lung_023', 'lung_005',
                                                'lung_078', 'lung_066', 'lung_009', 'lung_074', 'lung_042',
                                                'lung_033', 'lung_095', 'lung_037', 'lung_054', 'lung_029'
                                                ])
        splits[self.fold]['val'] = np.array([ 'lung_058', 'lung_025', 'lung_046', 'lung_070', 'lung_001',
                                              'lung_062', 'lung_083', 'lung_081', 'lung_093', 'lung_044',
                                              'lung_027', 'lung_048', 'lung_003' ])
        if self.fold < len(splits):
            tr_keys = splits[self.fold]['train']
            val_keys = splits[self.fold]['val']
            self.print_to_log_file("This split has %d training and %d validation cases."
                                   % (len(tr_keys), len(val_keys)))
        else:
            self.print_to_log_file("INFO: You requested fold %d for training but splits "
                                   "contain only %d folds. I am now creating a "
                                   "random (but seeded) 80:20 split!" % (self.fold, len(splits)))
            # if we request a fold that is not in the split file, create a random 80:20 split
            rnd = np.random.RandomState(seed=12345 + self.fold)
            keys = np.sort(list(self.dataset.keys()))
            idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False)
            idx_val = [i for i in range(len(keys)) if i not in idx_tr]
            tr_keys = [keys[i] for i in idx_tr]
            val_keys = [keys[i] for i in idx_val]
            self.print_to_log_file("This random 80:20 split has %d training and %d validation cases."
                                   % (len(tr_keys), len(val_keys)))
    # materialize the per-case dataset dicts for the chosen keys (sorted for determinism)
    tr_keys.sort()
    val_keys.sort()
    self.dataset_tr = OrderedDict()
    for i in tr_keys:
        self.dataset_tr[i] = self.dataset[i]
    self.dataset_val = OrderedDict()
    for i in val_keys:
        self.dataset_val[i] = self.dataset[i]
def setup_DA_params(self):
    """
    Configure data-augmentation parameters.

    - we increase rotation angle from [-15, 15] to [-30, 30]
    - scale range is now (0.7, 1.4), was (0.85, 1.25)
    - we don't do elastic deformation anymore
    :return:
    """
    # one downsampling scale per deep-supervision output: cumulative product of the
    # pooling kernel sizes, with the lowest-resolution entry dropped ([:-1])
    self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(
        np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1]
    if self.threeD:
        # NOTE(review): this binds (does not copy) the module-level default dict,
        # so the mutations below also modify default_3D_augmentation_params
        self.data_aug_params = default_3D_augmentation_params
        self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
        self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
        self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
        if self.do_dummy_2D_aug:
            # dummy 2D augmentation: treat each slice as 2D, so use the 2D elastic
            # and rotation defaults for the in-plane transform
            self.data_aug_params["dummy_2D"] = True
            self.print_to_log_file("Using dummy2d data augmentation")
            self.data_aug_params["elastic_deform_alpha"] = \
                default_2D_augmentation_params["elastic_deform_alpha"]
            self.data_aug_params["elastic_deform_sigma"] = \
                default_2D_augmentation_params["elastic_deform_sigma"]
            self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
    else:
        self.do_dummy_2D_aug = False
        # strongly anisotropic 2D patches get a reduced in-plane rotation range
        if max(self.patch_size) / min(self.patch_size) > 1.5:
            default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
        self.data_aug_params = default_2D_augmentation_params
    self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm
    if self.do_dummy_2D_aug:
        # the generator patch size must be large enough that rotation/scaling of the
        # in-plane dimensions cannot introduce border artifacts; the through-plane
        # dimension is kept unchanged
        self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
                                                         self.data_aug_params['rotation_x'],
                                                         self.data_aug_params['rotation_y'],
                                                         self.data_aug_params['rotation_z'],
                                                         self.data_aug_params['scale_range'])
        self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
        patch_size_for_spatialtransform = self.patch_size[1:]
    else:
        self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
                                                         self.data_aug_params['rotation_y'],
                                                         self.data_aug_params['rotation_z'],
                                                         self.data_aug_params['scale_range'])
        patch_size_for_spatialtransform = self.patch_size
    self.data_aug_params["scale_range"] = (0.7, 1.4)
    self.data_aug_params["do_elastic"] = False
    self.data_aug_params['selected_seg_channels'] = [0]
    self.data_aug_params['patch_size_for_spatialtransform'] = patch_size_for_spatialtransform
    self.data_aug_params["num_cached_per_thread"] = 2
def maybe_update_lr(self, epoch=None):
    """
    Apply the poly learning-rate schedule to the optimizer.

    If `epoch` is None we use self.epoch + 1: maybe_update_lr is called in
    on_epoch_end, which runs before self.epoch is incremented, therefore
    we need to do +1 here.
    :param epoch: explicit epoch to compute the lr for, or None
    :return:
    """
    ep = self.epoch + 1 if epoch is None else epoch
    new_lr = poly_lr(ep, self.max_num_epochs, self.initial_lr, 0.9)
    self.optimizer.param_groups[0]['lr'] = new_lr
    self.print_to_log_file("lr:", np.round(new_lr, decimals=6))
def on_epoch_end(self):
    """
    Overwrite patient-based early stopping. Always run to max_num_epochs (1000).
    :return: True while training should continue, False once max_num_epochs is reached
    """
    super().on_epoch_end()
    continue_training = self.epoch < self.max_num_epochs
    # it can rarely happen that the momentum of nnUNetTrainerV2 is too high for some dataset. If at epoch 100 the
    # estimated validation Dice is still 0 then we reduce the momentum from 0.99 to 0.95
    if self.epoch == 100:
        if self.all_val_eval_metrics[-1] == 0:
            self.optimizer.param_groups[0]["momentum"] = 0.95
            # restart from fresh weights: the old ones are likely degenerate by now
            self.network.apply(InitWeights_He(1e-2))
            self.print_to_log_file("At epoch 100, the mean foreground Dice was 0. This can be caused by a too "
                                   "high momentum. High momentum (0.99) is good for datasets where it works, but "
                                   "sometimes causes issues such as this one. Momentum has now been reduced to "
                                   "0.95 and network weights have been reinitialized")
    return continue_training
def run_training(self):
    """
    If we run with -c then we need to set the correct lr for the first epoch, otherwise it will run the first
    continued epoch with self.initial_lr.

    We also need to make sure deep supervision in the network is enabled for training, thus the wrapper.
    The previous do_ds setting is restored in a finally block so an exception
    during training cannot leave the network in the wrong state.
    :return:
    """
    self.maybe_update_lr(self.epoch)  # if we dont overwrite epoch then self.epoch+1 is used which is not what we
    # want at the start of the training
    ds = self.network.do_ds
    # enable deep supervision outputs only if this trainer uses them
    self.network.do_ds = self.deep_supervision
    try:
        ret = super().run_training()
    finally:
        self.network.do_ds = ds
    return ret
================================================
FILE: unetr_pp/training/network_training/unetr_pp_trainer_synapse.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import Tuple
import numpy as np
import torch
from unetr_pp.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation
from unetr_pp.training.loss_functions.deep_supervision import MultipleOutputLoss2
from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda
from unetr_pp.network_architecture.synapse.unetr_pp_synapse import UNETR_PP
from unetr_pp.network_architecture.initialization import InitWeights_He
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.training.data_augmentation.default_data_augmentation import default_2D_augmentation_params, \
get_patch_size, default_3D_augmentation_params
from unetr_pp.training.dataloading.dataset_loading import unpack_dataset
from unetr_pp.training.network_training.Trainer_synapse import Trainer_synapse
from unetr_pp.utilities.nd_softmax import softmax_helper
from sklearn.model_selection import KFold
from torch import nn
from torch.cuda.amp import autocast
from unetr_pp.training.learning_rate.poly_lr import poly_lr
from batchgenerators.utilities.file_and_folder_operations import *
from fvcore.nn import FlopCountAnalysis
class unetr_pp_trainer_synapse(Trainer_synapse):
    """
    UNETR++ trainer for the Synapse multi-organ dataset.

    Same structure as the internal nnFormerTrainerV2 / nnUNetTrainerV2_2 trainers:
    SGD (momentum 0.99) with a poly lr schedule, deep supervision, and the
    moreDA augmentation pipeline.
    """

    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, fp16=False):
        super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                         deterministic, fp16)
        # fixed training schedule and poly-lr starting point
        self.max_num_epochs = 1000
        self.initial_lr = 1e-2
        # filled in by setup_DA_params / initialize
        self.deep_supervision_scales = None
        self.ds_loss_weights = None
        self.pin_memory = True
        self.load_pretrain_weight = False
        self.load_plans_file()
        # Synapse-specific input geometry and architecture hyperparameters
        self.crop_size = [64, 128, 128]
        self.input_channels = self.plans['num_modalities']
        self.num_classes = self.plans['num_classes'] + 1  # +1 for background
        self.conv_op = nn.Conv3d
        self.embedding_dim = 192
        self.depths = [2, 2, 2, 2]
        self.num_heads = [6, 12, 24, 48]
        self.embedding_patch_size = [2, 4, 4]
        self.window_size = [4, 4, 8, 4]
        self.deep_supervision = True

    def initialize(self, training=True, force_load_plans=False):
        """
        - replaced get_default_augmentation with get_moreDA_augmentation
        - enforce to only run this code once
        - loss function wrapper for deep supervision
        :param training: if True, build data loaders and augmenters as well
        :param force_load_plans: reload the plans file even if already loaded
        :return:
        """
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)
            if force_load_plans or (self.plans is None):
                self.load_plans_file()
            # enforce the pooling schedule expected by the Synapse UNETR++ architecture
            self.plans['plans_per_stage'][self.stage]['pool_op_kernel_sizes'] = [[2, 4, 4], [2, 2, 2], [2, 2, 2]]
            self.process_plans(self.plans)
            self.setup_DA_params()
            if self.deep_supervision:
                ################# Here we wrap the loss for deep supervision ############
                # we need to know the number of outputs of the network
                net_numpool = len(self.net_num_pool_op_kernel_sizes)
                # we give each output a weight which decreases exponentially (division by 2) as the resolution
                # decreases; this gives higher resolution outputs more weight in the loss
                weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
                # normalize weights so that they sum to 1
                weights = weights / weights.sum()
                print(weights)
                self.ds_loss_weights = weights
                # now wrap the loss
                self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
                ################# END ###################
            self.folder_with_preprocessed_data = join(self.dataset_directory,
                                                      self.plans['data_identifier'] + "_stage%d" % self.stage)
            # np.random.random_integers was removed in NumPy >= 1.25; randint with an
            # exclusive high of 100000 samples the same inclusive [0, 99999] range
            seeds_train = np.random.randint(0, 100000, self.data_aug_params.get('num_threads'))
            seeds_val = np.random.randint(0, 100000, max(self.data_aug_params.get('num_threads') // 2, 1))
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")
                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr, self.dl_val,
                    self.data_aug_params[
                        'patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales if self.deep_supervision else None,
                    pin_memory=self.pin_memory,
                    use_nondetMultiThreadedAugmenter=False,
                    seeds_train=seeds_train,
                    seeds_val=seeds_val
                )
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass
            self.initialize_network()
            self.initialize_optimizer_and_scheduler()
            assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
        else:
            self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
        self.was_initialized = True

    def initialize_network(self):
        """
        Build the UNETR++ network for Synapse, move it to GPU when available, and
        report trainable-parameter and FLOP counts.

        Inherited nnU-Net trainer notes:
        - momentum 0.99, SGD instead of Adam
        - self.lr_scheduler = None because we do poly_lr
        - deep supervision = True
        Known issue: forgot to set neg_slope=0 in InitWeights_He; should not make a difference though
        :return:
        """
        self.network = UNETR_PP(in_channels=self.input_channels,
                                out_channels=self.num_classes,
                                img_size=self.crop_size,
                                feature_size=16,
                                num_heads=4,
                                depths=[3, 3, 3, 3],
                                dims=[32, 64, 128, 256],
                                do_ds=True,
                                )
        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper
        # Print the network parameters & FLOPs
        n_parameters = sum(p.numel() for p in self.network.parameters() if p.requires_grad)
        input_res = (1, 64, 128, 128)  # (C, D, H, W) dummy resolution used only for FLOP counting
        # use a reference parameter for dtype/device; 'dummy_input' avoids shadowing the builtin `input`
        ref_param = next(self.network.parameters())
        dummy_input = torch.empty((1, *input_res), dtype=ref_param.dtype, device=ref_param.device)
        flops = FlopCountAnalysis(self.network, dummy_input)
        model_flops = flops.total()
        print(f"Total trainable parameters: {round(n_parameters * 1e-6, 2)} M")
        print(f"MAdds: {round(model_flops * 1e-9, 2)} G")

    def initialize_optimizer_and_scheduler(self):
        """Create the SGD optimizer; no torch scheduler is used (poly lr is applied manually)."""
        assert self.network is not None, "self.initialize_network must be called first"
        self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
                                         momentum=0.99, nesterov=True)
        self.lr_scheduler = None

    def run_online_evaluation(self, output, target):
        """
        With deep supervision, `output` and `target` are lists of tensors ordered from
        full resolution downwards; only the full-resolution pair (index 0) matters for
        the online evaluation metric, the rest are ignored.
        :param output: network output (list of tensors when deep supervision is on)
        :param target: reference segmentation (list of tensors when deep supervision is on)
        :return: result of the base-class online evaluation
        """
        if self.deep_supervision:
            # evaluate only the highest-resolution output
            target = target[0]
            output = output[0]
        return super().run_online_evaluation(output, target)

    def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
                 step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
                 segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
        """
        We need to wrap this because we need to enforce self.network.do_ds = False for prediction.
        The previous setting is restored in a finally block even if validation raises.
        """
        ds = self.network.do_ds
        self.network.do_ds = False
        try:
            ret = super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window,
                                   step_size=step_size,
                                   save_softmax=save_softmax, use_gaussian=use_gaussian,
                                   overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug,
                                   all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs,
                                   run_postprocessing_on_folds=run_postprocessing_on_folds)
        finally:
            self.network.do_ds = ds
        return ret

    def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
                                                         mirror_axes: Tuple[int] = None,
                                                         use_sliding_window: bool = True, step_size: float = 0.5,
                                                         use_gaussian: bool = True, pad_border_mode: str = 'constant',
                                                         pad_kwargs: dict = None, all_in_gpu: bool = False,
                                                         verbose: bool = True, mixed_precision=True) -> Tuple[
        np.ndarray, np.ndarray]:
        """
        We need to wrap this because we need to enforce self.network.do_ds = False for prediction.
        The previous setting is restored in a finally block even if prediction raises.
        """
        ds = self.network.do_ds
        self.network.do_ds = False
        try:
            ret = super().predict_preprocessed_data_return_seg_and_softmax(data,
                                                                           do_mirroring=do_mirroring,
                                                                           mirror_axes=mirror_axes,
                                                                           use_sliding_window=use_sliding_window,
                                                                           step_size=step_size,
                                                                           use_gaussian=use_gaussian,
                                                                           pad_border_mode=pad_border_mode,
                                                                           pad_kwargs=pad_kwargs,
                                                                           all_in_gpu=all_in_gpu,
                                                                           verbose=verbose,
                                                                           mixed_precision=mixed_precision)
        finally:
            self.network.do_ds = ds
        return ret

    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        """
        Run a single training/validation iteration.

        Gradient clipping (max norm 12) improves training stability.
        :param data_generator: iterator yielding dicts with 'data' and 'target' entries
        :param do_backprop: if True, backpropagate and apply an optimizer step
        :param run_online_evaluation: if True, feed output/target to run_online_evaluation
        :return: the loss value, detached and converted to a numpy array
        """
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']
        data = maybe_to_torch(data)
        target = maybe_to_torch(target)
        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)
        self.optimizer.zero_grad()
        if self.fp16:
            # mixed precision: gradients are unscaled before clipping so the clip
            # threshold applies to true gradient norms
            with autocast():
                output = self.network(data)
                del data
                l = self.loss(output, target)
            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            output = self.network(data)
            del data
            l = self.loss(output, target)
            if do_backprop:
                l.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.optimizer.step()
        if run_online_evaluation:
            self.run_online_evaluation(output, target)
        del target
        return l.detach().cpu().numpy()

    def do_split(self):
        """
        The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded,
        so always the same) and save it as splits_final.pkl file in the preprocessed data directory.
        Sometimes you may want to create your own split for various reasons. For this you will need to create your own
        splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in
        it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3)
        and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to
        use a random 80:20 data split.

        NOTE(review): after loading/creating the split file, this trainer OVERWRITES
        the requested fold with the hard-coded Synapse train/val case list below, so
        the content of splits_final.pkl is effectively ignored for that fold.
        :return:
        """
        if self.fold == "all":
            # if fold==all then we use all images for training and validation
            tr_keys = val_keys = list(self.dataset.keys())
        else:
            splits_file = join(self.dataset_directory, "splits_final.pkl")
            # if the split file does not exist we need to create it
            if not isfile(splits_file):
                self.print_to_log_file("Creating new 5-fold cross-validation split...")
                splits = []
                all_keys_sorted = np.sort(list(self.dataset.keys()))
                # seeded KFold => the generated split is reproducible
                kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
                for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
                    train_keys = np.array(all_keys_sorted)[train_idx]
                    test_keys = np.array(all_keys_sorted)[test_idx]
                    splits.append(OrderedDict())
                    splits[-1]['train'] = train_keys
                    splits[-1]['val'] = test_keys
                save_pickle(splits, splits_file)
            else:
                self.print_to_log_file("Using splits from existing split file:", splits_file)
                splits = load_pickle(splits_file)
                self.print_to_log_file("The split file contains %d splits." % len(splits))
            self.print_to_log_file("Desired fold for training: %d" % self.fold)
            # hard-coded Synapse split (18 train / 12 val cases)
            splits[self.fold]['train'] = np.array(
                ['img0006', 'img0007', 'img0009', 'img0010', 'img0021', 'img0023', 'img0024', 'img0026', 'img0027',
                 'img0031', 'img0033', 'img0034', 'img0039', 'img0040', 'img0005', 'img0028', 'img0030', 'img0037'])
            splits[self.fold]['val'] = np.array(
                ['img0001', 'img0002', 'img0003', 'img0004', 'img0008', 'img0022', 'img0025', 'img0029', 'img0032',
                 'img0035', 'img0036', 'img0038'])
            if self.fold < len(splits):
                tr_keys = splits[self.fold]['train']
                val_keys = splits[self.fold]['val']
                self.print_to_log_file("This split has %d training and %d validation cases."
                                       % (len(tr_keys), len(val_keys)))
            else:
                self.print_to_log_file("INFO: You requested fold %d for training but splits "
                                       "contain only %d folds. I am now creating a "
                                       "random (but seeded) 80:20 split!" % (self.fold, len(splits)))
                # if we request a fold that is not in the split file, create a random 80:20 split
                rnd = np.random.RandomState(seed=12345 + self.fold)
                keys = np.sort(list(self.dataset.keys()))
                idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False)
                idx_val = [i for i in range(len(keys)) if i not in idx_tr]
                tr_keys = [keys[i] for i in idx_tr]
                val_keys = [keys[i] for i in idx_val]
                self.print_to_log_file("This random 80:20 split has %d training and %d validation cases."
                                       % (len(tr_keys), len(val_keys)))
        # materialize the per-case dataset dicts for the chosen keys (sorted for determinism)
        tr_keys.sort()
        val_keys.sort()
        self.dataset_tr = OrderedDict()
        for i in tr_keys:
            self.dataset_tr[i] = self.dataset[i]
        self.dataset_val = OrderedDict()
        for i in val_keys:
            self.dataset_val[i] = self.dataset[i]

    def setup_DA_params(self):
        """
        Configure data-augmentation parameters.

        - we increase rotation angle from [-15, 15] to [-30, 30]
        - scale range is now (0.7, 1.4), was (0.85, 1.25)
        - we don't do elastic deformation anymore
        :return:
        """
        # one downsampling scale per deep-supervision output: cumulative product of
        # pooling kernel sizes, lowest-resolution entry dropped ([:-1])
        self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(
            np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1]
        if self.threeD:
            # NOTE(review): this binds (does not copy) the module-level default dict,
            # so the mutations below also modify default_3D_augmentation_params
            self.data_aug_params = default_3D_augmentation_params
            self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            if self.do_dummy_2D_aug:
                self.data_aug_params["dummy_2D"] = True
                self.print_to_log_file("Using dummy2d data augmentation")
                self.data_aug_params["elastic_deform_alpha"] = \
                    default_2D_augmentation_params["elastic_deform_alpha"]
                self.data_aug_params["elastic_deform_sigma"] = \
                    default_2D_augmentation_params["elastic_deform_sigma"]
                self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
        else:
            self.do_dummy_2D_aug = False
            # strongly anisotropic 2D patches get a reduced in-plane rotation range
            if max(self.patch_size) / min(self.patch_size) > 1.5:
                default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
            self.data_aug_params = default_2D_augmentation_params
        self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm
        if self.do_dummy_2D_aug:
            # generator patch size must accommodate in-plane rotation/scaling without
            # border artifacts; the through-plane dimension stays unchanged
            self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
                                                             self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
            patch_size_for_spatialtransform = self.patch_size[1:]
        else:
            self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            patch_size_for_spatialtransform = self.patch_size
        self.data_aug_params["scale_range"] = (0.7, 1.4)
        self.data_aug_params["do_elastic"] = False
        self.data_aug_params['selected_seg_channels'] = [0]
        self.data_aug_params['patch_size_for_spatialtransform'] = patch_size_for_spatialtransform
        self.data_aug_params["num_cached_per_thread"] = 2

    def maybe_update_lr(self, epoch=None):
        """
        Apply the poly learning-rate schedule.

        If epoch is not None we overwrite epoch. Else we use epoch = self.epoch + 1
        (maybe_update_lr is called in on_epoch_end which is called before epoch is
        incremented, therefore we need to do +1 here).
        :param epoch:
        :return:
        """
        ep = self.epoch + 1 if epoch is None else epoch
        self.optimizer.param_groups[0]['lr'] = poly_lr(ep, self.max_num_epochs, self.initial_lr, 0.9)
        self.print_to_log_file("lr:", np.round(self.optimizer.param_groups[0]['lr'], decimals=6))

    def on_epoch_end(self):
        """
        Overwrite patient-based early stopping. Always run to max_num_epochs (1000).
        :return: True while training should continue
        """
        super().on_epoch_end()
        continue_training = self.epoch < self.max_num_epochs
        # it can rarely happen that the momentum of nnUNetTrainerV2 is too high for some dataset. If at epoch 100 the
        # estimated validation Dice is still 0 then we reduce the momentum from 0.99 to 0.95
        if self.epoch == 100:
            if self.all_val_eval_metrics[-1] == 0:
                self.optimizer.param_groups[0]["momentum"] = 0.95
                self.network.apply(InitWeights_He(1e-2))
                self.print_to_log_file("At epoch 100, the mean foreground Dice was 0. This can be caused by a too "
                                       "high momentum. High momentum (0.99) is good for datasets where it works, but "
                                       "sometimes causes issues such as this one. Momentum has now been reduced to "
                                       "0.95 and network weights have been reinitialized")
        return continue_training

    def run_training(self):
        """
        If we run with -c then we need to set the correct lr for the first epoch, otherwise it will run the first
        continued epoch with self.initial_lr.

        We also need to make sure deep supervision in the network is enabled for training, thus the wrapper.
        The previous do_ds setting is restored in a finally block even if training raises.
        :return:
        """
        self.maybe_update_lr(self.epoch)  # if we dont overwrite epoch then self.epoch+1 is used which is not what we
        # want at the start of the training
        ds = self.network.do_ds
        self.network.do_ds = self.deep_supervision
        try:
            ret = super().run_training()
        finally:
            self.network.do_ds = ds
        return ret
================================================
FILE: unetr_pp/training/network_training/unetr_pp_trainer_tumor.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import Tuple
import numpy as np
import torch
from unetr_pp.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation
from unetr_pp.training.loss_functions.deep_supervision import MultipleOutputLoss2
from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda
from unetr_pp.network_architecture.tumor.unetr_pp_tumor import UNETR_PP
from unetr_pp.network_architecture.initialization import InitWeights_He
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.training.data_augmentation.default_data_augmentation import default_2D_augmentation_params, \
get_patch_size, default_3D_augmentation_params
from unetr_pp.training.dataloading.dataset_loading import unpack_dataset
from unetr_pp.training.network_training.Trainer_tumor import Trainer_tumor
from unetr_pp.utilities.nd_softmax import softmax_helper
from sklearn.model_selection import KFold
from torch import nn
from torch.cuda.amp import autocast
from unetr_pp.training.learning_rate.poly_lr import poly_lr
from batchgenerators.utilities.file_and_folder_operations import *
from fvcore.nn import FlopCountAnalysis
class unetr_pp_trainer_tumor(Trainer_tumor):
"""
same as internal nnFromerTrinerV2 and nnUNetTrainerV2_2
"""
def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
             unpack_data=True, deterministic=True, fp16=False):
    """Set up the tumor trainer: fixed schedule/lr plus UNETR++ architecture hyperparameters."""
    super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                     deterministic, fp16)
    # fixed training schedule and poly-lr starting point
    self.max_num_epochs = 1000
    self.initial_lr = 1e-2
    # deep-supervision bookkeeping, filled in later by setup_DA_params / initialize
    self.deep_supervision_scales = None
    self.ds_loss_weights = None
    self.pin_memory = True
    self.load_pretrain_weight = False
    self.load_plans_file()
    # pick stage 1 when the plans contain two stages, otherwise stage 0
    stage_idx = 1 if len(self.plans['plans_per_stage']) == 2 else 0
    self.crop_size = self.plans['plans_per_stage'][stage_idx]['patch_size']
    print("self.crop_size in init train:", self.crop_size)
    self.input_channels = self.plans['num_modalities']
    self.num_classes = self.plans['num_classes'] + 1
    self.conv_op = nn.Conv3d
    # UNETR++ architecture hyperparameters for the tumor task
    self.embedding_dim = 96
    self.depths = [2, 2, 2, 2]
    self.num_heads = [3, 6, 12, 24]
    self.embedding_patch_size = [4, 4, 4]
    self.window_size = [4, 4, 8, 4]
    self.deep_supervision = True
def initialize(self, training=True, force_load_plans=False):
"""
- replaced get_default_augmentation with get_moreDA_augmentation
- enforce to only run this code once
- loss function wrapper for deep supervision
:param training:
:param force_load_plans:
:return:
"""
if not self.was_initialized:
maybe_mkdir_p(self.output_folder)
if force_load_plans or (self.plans is None):
self.load_plans_file()
self.plans['plans_per_stage'][self.stage]['pool_op_kernel_sizes'] = [[4, 4, 4], [2, 2, 2], [2, 2, 2]]
self.process_plans(self.plans)
self.setup_DA_params()
if self.deep_supervision:
################# Here we wrap the loss for deep supervision ############
# we need to know the number of outputs of the network
net_numpool = len(self.net_num_pool_op_kernel_sizes)
# we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
# this gives higher resolution outputs more weight in the loss
weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
# we don't use the lowest 2 outputs. Normalize weights so that they sum to 1
# mask = np.array([True] + [True if i < net_numpool - 1 else False for i in range(1, net_numpool)])
# weights[~mask] = 0
weights = weights / weights.sum()
print(weights)
self.ds_loss_weights = weights
# now wrap the loss
self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
################# END ###################
self.folder_with_preprocessed_data = join(self.dataset_directory,
self.plans['data_identifier'] + "_stage%d" % self.stage)
seeds_train = np.random.random_integers(0, 99999, self.data_aug_params.get('num_threads'))
seeds_val = np.random.random_integers(0, 99999, max(self.data_aug_params.get('num_threads') // 2, 1))
if training:
self.dl_tr, self.dl_val = self.get_basic_generators()
if self.unpack_data:
print("unpacking dataset")
unpack_dataset(self.folder_with_preprocessed_data)
print("done")
else:
print(
"INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
"will wait all winter for your model to finish!")
self.tr_gen, self.val_gen = get_moreDA_augmentation(
self.dl_tr, self.dl_val,
self.data_aug_params[
'patch_size_for_spatialtransform'],
self.data_aug_params,
deep_supervision_scales=self.deep_supervision_scales if self.deep_supervision else None,
pin_memory=self.pin_memory,
use_nondetMultiThreadedAugmenter=False,
seeds_train=seeds_train,
seeds_val=seeds_val
)
self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
also_print_to_console=False)
self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
also_print_to_console=False)
else:
pass
self.initialize_network()
self.initialize_optimizer_and_scheduler()
assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
else:
self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
self.was_initialized = True
def initialize_network(self):
"""
- momentum 0.99
- SGD instead of Adam
- self.lr_scheduler = None because we do poly_lr
- deep supervision = True
- i am sure I forgot something here
Known issue: forgot to set neg_slope=0 in InitWeights_He; should not make a difference though
:return:
"""
self.network = UNETR_PP(in_channels=self.input_channels,
out_channels=self.num_classes,
feature_size=16,
num_heads=4,
depths=[3, 3, 3, 3],
dims=[32, 64, 128, 256],
do_ds=True,
)
#print("self.input_channels", self.input_channels)
#print("self.num_classes", self.num_classes)
if torch.cuda.is_available():
self.network.cuda()
self.network.inference_apply_nonlin = softmax_helper
# Print the network parameters & Flops
n_parameters = sum(p.numel() for p in self.network.parameters() if p.requires_grad)
input_res = (4, 128, 128, 128)
input = torch.ones(()).new_empty((1, *input_res), dtype=next(self.network.parameters()).dtype,
device=next(self.network.parameters()).device)
flops = FlopCountAnalysis(self.network, input)
model_flops = flops.total()
print(f"Total trainable parameters: {round(n_parameters * 1e-6, 2)} M")
print(f"MAdds: {round(model_flops * 1e-9, 2)} G")
def initialize_optimizer_and_scheduler(self):
assert self.network is not None, "self.initialize_network must be called first"
self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
momentum=0.99, nesterov=True)
self.lr_scheduler = None
def run_online_evaluation(self, output, target):
"""
due to deep supervision the return value and the reference are now lists of tensors. We only need the full
resolution output because this is what we are interested in in the end. The others are ignored
:param output:
:param target:
:return:
"""
if self.deep_supervision:
target = target[0]
output = output[0]
else:
target = target
output = output
return super().run_online_evaluation(output, target)
def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
"""
We need to wrap this because we need to enforce self.network.do_ds = False for prediction
"""
ds = self.network.do_ds
self.network.do_ds = False
ret = super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window, step_size=step_size,
save_softmax=save_softmax, use_gaussian=use_gaussian,
overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug,
all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs,
run_postprocessing_on_folds=run_postprocessing_on_folds)
self.network.do_ds = ds
return ret
    def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
                                                         mirror_axes: Tuple[int] = None,
                                                         use_sliding_window: bool = True, step_size: float = 0.5,
                                                         use_gaussian: bool = True, pad_border_mode: str = 'constant',
                                                         pad_kwargs: dict = None, all_in_gpu: bool = False,
                                                         verbose: bool = True, mixed_precision=True) -> Tuple[
        np.ndarray, np.ndarray]:
        """
        We need to wrap this because we need to enforce self.network.do_ds = False for prediction
        (inference needs a single full-resolution output, not the deep-supervision list).
        All arguments are forwarded to the parent implementation unchanged.
        """
        # Remember the current deep-supervision flag, disable it for inference,
        # then restore it no matter what the parent returned.
        ds = self.network.do_ds
        self.network.do_ds = False
        ret = super().predict_preprocessed_data_return_seg_and_softmax(data,
                                                                       do_mirroring=do_mirroring,
                                                                       mirror_axes=mirror_axes,
                                                                       use_sliding_window=use_sliding_window,
                                                                       step_size=step_size, use_gaussian=use_gaussian,
                                                                       pad_border_mode=pad_border_mode,
                                                                       pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu,
                                                                       verbose=verbose,
                                                                       mixed_precision=mixed_precision)
        self.network.do_ds = ds
        return ret
    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        """
        Run a single forward/backward pass on one batch.

        gradient clipping improves training stability
        :param data_generator: iterator yielding dicts with 'data' and 'target'
        :param do_backprop: if True, backpropagate and take an optimizer step
        :param run_online_evaluation: if True, accumulate online eval metrics
        :return: the detached loss value as a numpy array
        """
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']
        data = maybe_to_torch(data)
        target = maybe_to_torch(target)
        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)
        self.optimizer.zero_grad()
        if self.fp16:
            # Mixed-precision path: forward under autocast, backward/step through
            # the GradScaler for numerical safety.
            with autocast():
                output = self.network(data)
                del data  # free the input as early as possible
                l = self.loss(output, target)
            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                # Unscale before clipping so the threshold of 12 applies to the
                # true (unscaled) gradient magnitudes.
                self.amp_grad_scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            # Full-precision path: same sequence without scaler/autocast.
            output = self.network(data)
            del data
            l = self.loss(output, target)
            if do_backprop:
                l.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.optimizer.step()
        if run_online_evaluation:
            self.run_online_evaluation(output, target)
        del target
        return l.detach().cpu().numpy()
def do_split(self):
"""
The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded,
so always the same) and save it as splits_final.pkl file in the preprocessed data directory.
Sometimes you may want to create your own split for various reasons. For this you will need to create your own
splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in
it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3)
and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to
use a random 80:20 data split.
:return:
"""
if self.fold == "all":
# if fold==all then we use all images for training and validation
tr_keys = val_keys = list(self.dataset.keys())
else:
splits_file = join(self.dataset_directory, "splits_final.pkl")
# if the split file does not exist we need to create it
if not isfile(splits_file):
self.print_to_log_file("Creating new 5-fold cross-validation split...")
splits = []
all_keys_sorted = np.sort(list(self.dataset.keys()))
kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
train_keys = np.array(all_keys_sorted)[train_idx]
test_keys = np.array(all_keys_sorted)[test_idx]
splits.append(OrderedDict())
splits[-1]['train'] = train_keys
splits[-1]['val'] = test_keys
save_pickle(splits, splits_file)
else:
self.print_to_log_file("Using splits from existing split file:", splits_file)
splits = load_pickle(splits_file)
self.print_to_log_file("The split file contains %d splits." % len(splits))
self.print_to_log_file("Desired fold for training: %d" % self.fold)
splits[self.fold]['train'] = np.array([
'BRATS_001', 'BRATS_002', 'BRATS_003', 'BRATS_004', 'BRATS_005',
'BRATS_006', 'BRATS_007', 'BRATS_008', 'BRATS_009', 'BRATS_010',
'BRATS_013', 'BRATS_014', 'BRATS_015', 'BRATS_016', 'BRATS_017',
'BRATS_019', 'BRATS_022', 'BRATS_023', 'BRATS_024', 'BRATS_025',
'BRATS_026', 'BRATS_027', 'BRATS_030', 'BRATS_031', 'BRATS_033',
'BRATS_035', 'BRATS_037', 'BRATS_038', 'BRATS_039', 'BRATS_040',
'BRATS_042', 'BRATS_043', 'BRATS_044', 'BRATS_045', 'BRATS_046',
'BRATS_048', 'BRATS_050', 'BRATS_051', 'BRATS_052', 'BRATS_054',
'BRATS_055', 'BRATS_060', 'BRATS_061', 'BRATS_062', 'BRATS_063',
'BRATS_064', 'BRATS_065', 'BRATS_066', 'BRATS_067', 'BRATS_068',
'BRATS_070', 'BRATS_072', 'BRATS_073', 'BRATS_074', 'BRATS_075',
'BRATS_078', 'BRATS_079', 'BRATS_080', 'BRATS_081', 'BRATS_082',
'BRATS_083', 'BRATS_084', 'BRATS_085', 'BRATS_086', 'BRATS_087',
'BRATS_088', 'BRATS_091', 'BRATS_093', 'BRATS_094', 'BRATS_096',
'BRATS_097', 'BRATS_098', 'BRATS_100', 'BRATS_101', 'BRATS_102',
'BRATS_104', 'BRATS_108', 'BRATS_110', 'BRATS_111', 'BRATS_112',
'BRATS_115', 'BRATS_116', 'BRATS_117', 'BRATS_119', 'BRATS_120',
'BRATS_121', 'BRATS_122', 'BRATS_123', 'BRATS_125', 'BRATS_126',
'BRATS_127', 'BRATS_128', 'BRATS_129', 'BRATS_130', 'BRATS_131',
'BRATS_132', 'BRATS_133', 'BRATS_134', 'BRATS_135', 'BRATS_136',
'BRATS_137', 'BRATS_138', 'BRATS_140', 'BRATS_141', 'BRATS_142',
'BRATS_143', 'BRATS_144', 'BRATS_146', 'BRATS_148', 'BRATS_149',
'BRATS_150', 'BRATS_153', 'BRATS_154', 'BRATS_155', 'BRATS_158',
'BRATS_159', 'BRATS_160', 'BRATS_162', 'BRATS_163', 'BRATS_164',
'BRATS_165', 'BRATS_166', 'BRATS_167', 'BRATS_168', 'BRATS_169',
'BRATS_170', 'BRATS_171', 'BRATS_173', 'BRATS_174', 'BRATS_175',
'BRATS_177', 'BRATS_178', 'BRATS_179', 'BRATS_180', 'BRATS_182',
'BRATS_183', 'BRATS_184', 'BRATS_185', 'BRATS_186', 'BRATS_187',
'BRATS_188', 'BRATS_189', 'BRATS_191', 'BRATS_192', 'BRATS_193',
'BRATS_195', 'BRATS_197', 'BRATS_199', 'BRATS_200', 'BRATS_201',
'BRATS_202', 'BRATS_203', 'BRATS_206', 'BRATS_207', 'BRATS_208',
'BRATS_210', 'BRATS_211', 'BRATS_212', 'BRATS_213', 'BRATS_214',
'BRATS_215', 'BRATS_216', 'BRATS_217', 'BRATS_218', 'BRATS_219',
'BRATS_222', 'BRATS_223', 'BRATS_224', 'BRATS_225', 'BRATS_226',
'BRATS_228', 'BRATS_229', 'BRATS_230', 'BRATS_231', 'BRATS_232',
'BRATS_233', 'BRATS_236', 'BRATS_237', 'BRATS_238', 'BRATS_239',
'BRATS_241', 'BRATS_243', 'BRATS_244', 'BRATS_246', 'BRATS_247',
'BRATS_248', 'BRATS_249', 'BRATS_251', 'BRATS_252', 'BRATS_253',
'BRATS_254', 'BRATS_255', 'BRATS_258', 'BRATS_259', 'BRATS_261',
'BRATS_262', 'BRATS_263', 'BRATS_264', 'BRATS_265', 'BRATS_266',
'BRATS_267', 'BRATS_268', 'BRATS_272', 'BRATS_273', 'BRATS_274',
'BRATS_275', 'BRATS_276', 'BRATS_277', 'BRATS_278', 'BRATS_279',
'BRATS_280', 'BRATS_283', 'BRATS_284', 'BRATS_285', 'BRATS_286',
'BRATS_288', 'BRATS_290', 'BRATS_293', 'BRATS_294', 'BRATS_296',
'BRATS_297', 'BRATS_298', 'BRATS_299', 'BRATS_300', 'BRATS_301',
'BRATS_302', 'BRATS_303', 'BRATS_304', 'BRATS_306', 'BRATS_307',
'BRATS_308', 'BRATS_309', 'BRATS_311', 'BRATS_312', 'BRATS_313',
'BRATS_315', 'BRATS_316', 'BRATS_317', 'BRATS_318', 'BRATS_319',
'BRATS_320', 'BRATS_321', 'BRATS_322', 'BRATS_324', 'BRATS_326',
'BRATS_328', 'BRATS_329', 'BRATS_332', 'BRATS_334', 'BRATS_335',
'BRATS_336', 'BRATS_338', 'BRATS_339', 'BRATS_340', 'BRATS_341',
'BRATS_342', 'BRATS_343', 'BRATS_344', 'BRATS_345', 'BRATS_347',
'BRATS_348', 'BRATS_349', 'BRATS_351', 'BRATS_353', 'BRATS_354',
'BRATS_355', 'BRATS_356', 'BRATS_357', 'BRATS_358', 'BRATS_359',
'BRATS_360', 'BRATS_363', 'BRATS_364', 'BRATS_365', 'BRATS_366',
'BRATS_367', 'BRATS_368', 'BRATS_369', 'BRATS_370', 'BRATS_371',
'BRATS_372', 'BRATS_373', 'BRATS_374', 'BRATS_375', 'BRATS_376',
'BRATS_377', 'BRATS_378', 'BRATS_379', 'BRATS_380', 'BRATS_381',
'BRATS_383', 'BRATS_384', 'BRATS_385', 'BRATS_386', 'BRATS_387',
'BRATS_388', 'BRATS_390', 'BRATS_391', 'BRATS_392', 'BRATS_393',
'BRATS_394', 'BRATS_395', 'BRATS_396', 'BRATS_398', 'BRATS_399',
'BRATS_401', 'BRATS_403', 'BRATS_404', 'BRATS_405', 'BRATS_407',
'BRATS_408', 'BRATS_409', 'BRATS_410', 'BRATS_411', 'BRATS_412',
'BRATS_413', 'BRATS_414', 'BRATS_415', 'BRATS_417', 'BRATS_418',
'BRATS_419', 'BRATS_420', 'BRATS_421', 'BRATS_422', 'BRATS_423',
'BRATS_424', 'BRATS_426', 'BRATS_428', 'BRATS_429', 'BRATS_430',
'BRATS_431', 'BRATS_433', 'BRATS_434', 'BRATS_435', 'BRATS_436',
'BRATS_437', 'BRATS_438', 'BRATS_439', 'BRATS_441', 'BRATS_442',
'BRATS_443', 'BRATS_444', 'BRATS_445', 'BRATS_446', 'BRATS_449',
'BRATS_451', 'BRATS_452', 'BRATS_453', 'BRATS_454', 'BRATS_455',
'BRATS_457', 'BRATS_458', 'BRATS_459', 'BRATS_460', 'BRATS_463',
'BRATS_464', 'BRATS_466', 'BRATS_467', 'BRATS_468', 'BRATS_469',
'BRATS_470', 'BRATS_472', 'BRATS_475', 'BRATS_477', 'BRATS_478',
'BRATS_481', 'BRATS_482', 'BRATS_483','BRATS_400', 'BRATS_402',
'BRATS_406', 'BRATS_416', 'BRATS_427', 'BRATS_440', 'BRATS_447',
'BRATS_448', 'BRATS_456', 'BRATS_461', 'BRATS_462', 'BRATS_465',
'BRATS_471', 'BRATS_473', 'BRATS_474', 'BRATS_476', 'BRATS_479',
'BRATS_480', 'BRATS_484'])
splits[self.fold]['val'] = np.array(['BRATS_011', 'BRATS_012', 'BRATS_018', 'BRATS_020', 'BRATS_021',
'BRATS_028', 'BRATS_029', 'BRATS_032', 'BRATS_034', 'BRATS_036',
'BRATS_041', 'BRATS_047', 'BRATS_049', 'BRATS_053', 'BRATS_056',
'BRATS_057', 'BRATS_069', 'BRATS_071', 'BRATS_089', 'BRATS_090',
'BRATS_092', 'BRATS_095', 'BRATS_103', 'BRATS_105', 'BRATS_106',
'BRATS_107', 'BRATS_109', 'BRATS_118', 'BRATS_145', 'BRATS_147',
'BRATS_156', 'BRATS_161', 'BRATS_172', 'BRATS_176', 'BRATS_181',
'BRATS_194', 'BRATS_196', 'BRATS_198', 'BRATS_204', 'BRATS_205',
'BRATS_209', 'BRATS_220', 'BRATS_221', 'BRATS_227', 'BRATS_234',
'BRATS_235', 'BRATS_245', 'BRATS_250', 'BRATS_256', 'BRATS_257',
'BRATS_260', 'BRATS_269', 'BRATS_270', 'BRATS_271', 'BRATS_281',
'BRATS_282', 'BRATS_287', 'BRATS_289', 'BRATS_291', 'BRATS_292',
'BRATS_310', 'BRATS_314', 'BRATS_323', 'BRATS_327', 'BRATS_330',
'BRATS_333', 'BRATS_337', 'BRATS_346', 'BRATS_350', 'BRATS_352',
'BRATS_361', 'BRATS_382', 'BRATS_397'])
if self.fold < len(splits):
tr_keys = splits[self.fold]['train']
val_keys = splits[self.fold]['val']
self.print_to_log_file("This split has %d training and %d validation cases."
% (len(tr_keys), len(val_keys)))
else:
self.print_to_log_file("INFO: You requested fold %d for training but splits "
"contain only %d folds. I am now creating a "
"random (but seeded) 80:20 split!" % (self.fold, len(splits)))
# if we request a fold that is not in the split file, create a random 80:20 split
rnd = np.random.RandomState(seed=12345 + self.fold)
keys = np.sort(list(self.dataset.keys()))
idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False)
idx_val = [i for i in range(len(keys)) if i not in idx_tr]
tr_keys = [keys[i] for i in idx_tr]
val_keys = [keys[i] for i in idx_val]
self.print_to_log_file("This random 80:20 split has %d training and %d validation cases."
% (len(tr_keys), len(val_keys)))
tr_keys.sort()
val_keys.sort()
self.dataset_tr = OrderedDict()
for i in tr_keys:
self.dataset_tr[i] = self.dataset[i]
self.dataset_val = OrderedDict()
for i in val_keys:
self.dataset_val[i] = self.dataset[i]
    def setup_DA_params(self):
        """
        Configure the data-augmentation parameters.

        - we increase rotation angle from [-15, 15] to [-30, 30]
        - scale range is now (0.7, 1.4), was (0.85, 1.25)
        - we don't do elastic deformation anymore
        :return: None
        """
        # One scale per deep-supervision output: full resolution first, then the
        # cumulative downsampling factors of the pooling kernels; the lowest
        # resolution entry is dropped ([:-1]).
        self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(
            np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1]
        if self.threeD:
            # NOTE: this assigns the shared default dict and then mutates it,
            # so the module-level defaults are modified as a side effect.
            self.data_aug_params = default_3D_augmentation_params
            self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            if self.do_dummy_2D_aug:
                # Dummy-2D: spatial transforms run slice-wise, so elastic and
                # in-plane rotation parameters come from the 2D defaults.
                self.data_aug_params["dummy_2D"] = True
                self.print_to_log_file("Using dummy2d data augmentation")
                self.data_aug_params["elastic_deform_alpha"] = \
                    default_2D_augmentation_params["elastic_deform_alpha"]
                self.data_aug_params["elastic_deform_sigma"] = \
                    default_2D_augmentation_params["elastic_deform_sigma"]
                self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
        else:
            self.do_dummy_2D_aug = False
            # Strongly anisotropic patches get the smaller rotation range.
            if max(self.patch_size) / min(self.patch_size) > 1.5:
                default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
            self.data_aug_params = default_2D_augmentation_params
        self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm
        if self.do_dummy_2D_aug:
            # The generator patch size is enlarged so that rotations/scalings do
            # not introduce border artifacts; the through-plane axis keeps its
            # original size in dummy-2D mode.
            self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
                                                             self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
            patch_size_for_spatialtransform = self.patch_size[1:]
        else:
            self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            patch_size_for_spatialtransform = self.patch_size
        # Final overrides applied after get_patch_size: note the enlarged
        # generator patch size above was computed with the *previous* scale_range.
        self.data_aug_params["scale_range"] = (0.7, 1.4)
        self.data_aug_params["do_elastic"] = False
        self.data_aug_params['selected_seg_channels'] = [0]
        self.data_aug_params['patch_size_for_spatialtransform'] = patch_size_for_spatialtransform
        self.data_aug_params["num_cached_per_thread"] = 2
def maybe_update_lr(self, epoch=None):
"""
if epoch is not None we overwrite epoch. Else we use epoch = self.epoch + 1
(maybe_update_lr is called in on_epoch_end which is called before epoch is incremented.
herefore we need to do +1 here)
:param epoch:
:return:
"""
if epoch is None:
ep = self.epoch + 1
else:
ep = epoch
self.optimizer.param_groups[0]['lr'] = poly_lr(ep, self.max_num_epochs, self.initial_lr, 0.9)
self.print_to_log_file("lr:", np.round(self.optimizer.param_groups[0]['lr'], decimals=6))
    def on_epoch_end(self):
        """
        overwrite patient-based early stopping. Always run to 1000 epochs
        :return: True while self.epoch is below max_num_epochs, else False
        """
        super().on_epoch_end()
        continue_training = self.epoch < self.max_num_epochs
        # it can rarely happen that the momentum of nnUNetTrainerV2 is too high for some dataset. If at epoch 100 the
        # estimated validation Dice is still 0 then we reduce the momentum from 0.99 to 0.95
        if self.epoch == 100:
            if self.all_val_eval_metrics[-1] == 0:
                self.optimizer.param_groups[0]["momentum"] = 0.95
                # Training never took off: re-initialize the weights and restart
                # learning with the lower momentum.
                self.network.apply(InitWeights_He(1e-2))
                self.print_to_log_file("At epoch 100, the mean foreground Dice was 0. This can be caused by a too "
                                       "high momentum. High momentum (0.99) is good for datasets where it works, but "
                                       "sometimes causes issues such as this one. Momentum has now been reduced to "
                                       "0.95 and network weights have been reinitialized")
        return continue_training
def run_training(self):
"""
if we run with -c then we need to set the correct lr for the first epoch, otherwise it will run the first
continued epoch with self.initial_lr
we also need to make sure deep supervision in the network is enabled for training, thus the wrapper
:return:
"""
self.maybe_update_lr(self.epoch) # if we dont overwrite epoch then self.epoch+1 is used which is not what we
# want at the start of the training
ds = self.network.do_ds
if self.deep_supervision:
self.network.do_ds = True
else:
self.network.do_ds = False
ret = super().run_training()
self.network.do_ds = ds
return ret
================================================
FILE: unetr_pp/training/optimizer/ranger.py
================================================
############
# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
# This code was taken from the repo above and was not created by me (Fabian)! Full credit goes to the original authors
############
import math
import torch
from torch.optim.optimizer import Optimizer
class Ranger(Optimizer):
    """
    Ranger optimizer: RAdam (rectified Adam) combined with Lookahead.

    Taken from https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
    (full credit to the original authors). Updated here to use the keyword
    forms of the in-place tensor ops (add_/addcmul_/addcdiv_ with alpha=/value=)
    because the legacy positional-scalar overloads are deprecated and have been
    removed from recent PyTorch releases.
    """

    def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95, 0.999), eps=1e-5,
                 weight_decay=0):
        """
        :param params: iterable of parameters (or param groups) to optimize
        :param lr: learning rate
        :param alpha: lookahead slow-weights interpolation factor in [0, 1]
        :param k: lookahead update interval (every k steps)
        :param N_sma_threshhold: rectification threshold of RAdam (5 tested
            better than 4 by the original authors)
        :param betas: Adam-style moment coefficients; beta1=.95 works better
            than .90 according to the original authors
        :param eps: denominator fuzz factor
        :param weight_decay: L2 penalty applied directly to the weights
        """
        # parameter checks
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        if not lr > 0:
            raise ValueError(f'Invalid Learning Rate: {lr}')
        if not eps > 0:
            raise ValueError(f'Invalid eps: {eps}')
        defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold,
                        eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)
        # threshold for switching between the rectified (Adam-like) and the
        # unrectified (momentum-SGD-like) update
        self.N_sma_threshhold = N_sma_threshhold
        # lookahead parameters
        self.alpha = alpha
        self.k = k
        # per-(step % 10) cache of [step, N_sma, step_size] for RAdam
        self.radam_buffer = [[None, None, None] for ind in range(10)]

    def __setstate__(self, state):
        print("set state called")
        super(Ranger, self).__setstate__(state)

    def step(self, closure=None):
        """Perform one optimization step. `closure` is accepted for API
        compatibility but deliberately not evaluated (the original authors
        pass the loss back as a float, not a callable); always returns None."""
        loss = None
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Ranger optimizer does not support sparse gradients')
                p_data_fp32 = p.data.float()
                state = self.state[p]  # get state dict for this param
                if len(state) == 0:  # first step for this param: init state
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                    # lookahead's slow weights live in the state dict so that
                    # save/load round-trips handle them correctly
                    state['slow_buffer'] = torch.empty_like(p.data)
                    state['slow_buffer'].copy_(p.data)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']
                # second-moment moving average; keyword form replaces the
                # removed addcmul_(scalar, tensor, tensor) overload
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # first-moment moving average
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                state['step'] += 1
                buffered = self.radam_buffer[int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma
                    if N_sma > self.N_sma_threshhold:
                        # variance is tractable: rectified Adam step size
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                    N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        # early steps: fall back to an un-rectified update
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = step_size
                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                if N_sma > self.N_sma_threshhold:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
                p.data.copy_(p_data_fp32)
                # integrated lookahead, applied per parameter every k steps:
                # slow += alpha * (fast - slow); fast <- slow
                if state['step'] % group['k'] == 0:
                    slow_p = state['slow_buffer']  # get access to slow param tensor
                    slow_p.add_(p.data - slow_p, alpha=self.alpha)
                    p.data.copy_(slow_p)
        return loss
================================================
FILE: unetr_pp/utilities/__init__.py
================================================
from __future__ import absolute_import
from . import *
================================================
FILE: unetr_pp/utilities/distributed.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import distributed
from torch import autograd
from torch.nn.parallel import DistributedDataParallel as DDP
def print_if_rank0(*args):
    """Print only on the rank-0 process to avoid duplicated output across workers."""
    if distributed.get_rank() != 0:
        return
    print(*args)
class awesome_allgather_function(autograd.Function):
    """
    Differentiable all_gather: forward concatenates every rank's input along
    dim 0; backward routes each rank the gradient slice that corresponds to
    its own contribution (mimicking DataParallel).
    """

    @staticmethod
    def forward(ctx, input):
        world_size = distributed.get_world_size()
        # one destination buffer per participating rank
        gathered = [torch.empty_like(input) for _ in range(world_size)]
        distributed.all_gather(gathered, input)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output covers the entire concatenated forward output; return only
        # the slice belonging to this process's input
        world_size = distributed.get_world_size()
        grads_per_rank = grad_output.shape[0] // world_size
        rank = distributed.get_rank()
        start = rank * grads_per_rank
        return grad_output[start:start + grads_per_rank]
if __name__ == "__main__":
    # Manual smoke test for awesome_allgather_function. Launch with torch
    # distributed (one process per GPU); requires CUDA and the nccl backend,
    # so this cannot run as a regular unit test.
    import torch.distributed as dist
    import argparse
    from torch import nn
    from torch.optim import Adam
    argumentparser = argparse.ArgumentParser()
    argumentparser.add_argument("--local_rank", type=int)
    args = argumentparser.parse_args()
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    # 1) gather plain random tensors
    rnd = torch.rand((5, 2)).cuda()
    rnd_gathered = awesome_allgather_function.apply(rnd)
    print("gathering random tensors\nbefore\b", rnd, "\nafter\n", rnd_gathered)
    # so far this works as expected
    # 2) gather the output of a DDP-wrapped model and backprop through it
    print("now running a DDP model")
    c = nn.Conv2d(2, 3, 3, 1, 1, 1, 1, True).cuda()
    c = DDP(c)
    opt = Adam(c.parameters())
    # NOTE(review): rank 0 deliberately uses a different batch size here;
    # all_gather generally expects equally shaped tensors across ranks —
    # confirm this is intended to exercise the uneven case.
    bs = 5
    if dist.get_rank() == 0:
        bs = 4
    inp = torch.rand((bs, 2, 5, 5)).cuda()
    out = c(inp)
    print("output_shape", out.shape)
    out_gathered = awesome_allgather_function.apply(out)
    print("output_shape_after_gather", out_gathered.shape)
    # this also works
    loss = out_gathered.sum()
    loss.backward()
    opt.step()
================================================
FILE: unetr_pp/utilities/file_conversions.py
================================================
from typing import Tuple, List, Union
from skimage import io
import SimpleITK as sitk
import numpy as np
import tifffile
def convert_2d_image_to_nifti(input_filename: str, output_filename_truncated: str, spacing=(999, 1, 1),
                              transform=None, is_seg: bool = False) -> None:
    """
    Reads an image (must be a format that it recognized by skimage.io.imread) and converts it into a series of niftis.
    The image can have an arbitrary number of input channels which will be exported separately (_0000.nii.gz,
    _0001.nii.gz, etc for images and only .nii.gz for seg).
    Spacing can be ignored most of the time.
    !!!2D images are often natural images which do not have a voxel spacing that could be used for resampling. These images
    must be resampled by you prior to converting them to nifti!!!
    Datasets converted with this utility can only be used with the 2d U-Net configuration of nnU-Net
    If Transform is not None it will be applied to the image after loading.
    Segmentations will be converted to np.uint32!
    :param is_seg:
    :param transform:
    :param input_filename:
    :param output_filename_truncated: do not use a file ending for this one! Example: output_name='./converted/image1'. This
    function will add the suffix (_0000) and file ending (.nii.gz) for you.
    :param spacing: 3-element spacing; default 999 for the (fake) through-plane axis
    :return:
    """
    img = io.imread(input_filename)
    if transform is not None:
        img = transform(img)
    if len(img.shape) == 2:  # 2d image with no color channels
        img = img[None, None]  # add dimensions
    else:
        assert len(img.shape) == 3, "image should be 3d with color channel last but has shape %s" % str(img.shape)
        # we assume that the color channel is the last dimension. Transpose it to be in first
        img = img.transpose((2, 0, 1))
        # add third dimension
        img = img[:, None]
    # image is now (c, x, x, z) where x=1 since it's 2d
    if is_seg:
        assert img.shape[0] == 1, 'segmentations can only have one color channel, not sure what happened here'
    # one nifti per channel: _0000.nii.gz, _0001.nii.gz, ... for images,
    # a single .nii.gz (no suffix) for the segmentation
    for j, i in enumerate(img):
        if is_seg:
            i = i.astype(np.uint32)
        itk_img = sitk.GetImageFromArray(i)
        # spacing is reversed because sitk uses x,y,z order while numpy arrays
        # are indexed z,y,x
        itk_img.SetSpacing(list(spacing)[::-1])
        if not is_seg:
            sitk.WriteImage(itk_img, output_filename_truncated + "_%04.0d.nii.gz" % j)
        else:
            sitk.WriteImage(itk_img, output_filename_truncated + ".nii.gz")
def convert_3d_tiff_to_nifti(filenames: List[str], output_name: str, spacing: Union[tuple, list], transform=None, is_seg=False) -> None:
    """
    Converts one or more 3D tiff files (one per modality) into niftis.

    Each entry of filenames must point to a separate 3d tiff file readable by tifffile.imread. For a single
    modality pass a single-element list. Files are read one at a time on purpose (tifffile supports reading
    several at once, but that behavior is deliberately not used here).

    :param filenames: list of paths, one per modality
    :param output_name: output path without file ending; _0000 etc. and .nii.gz are appended
    :param spacing: spacing written into the nifti header (reversed for sitk's x, y, z order)
    :param transform: optional callable applied to each image right after loading
    :param is_seg: whether the (single) input file is a segmentation
    :return: None
    """
    if is_seg:
        assert len(filenames) == 1

    for modality_idx, filename in enumerate(filenames):
        img = tifffile.imread(filename)

        if transform is not None:
            img = transform(img)

        itk_img = sitk.GetImageFromArray(img)
        itk_img.SetSpacing(list(spacing)[::-1])
        if is_seg:
            sitk.WriteImage(itk_img, output_name + ".nii.gz")
        else:
            sitk.WriteImage(itk_img, output_name + "_%04.0d.nii.gz" % modality_idx)
def convert_2d_segmentation_nifti_to_img(nifti_file: str, output_filename: str, transform=None, export_dtype=np.uint8):
    """
    Exports a 2D segmentation nifti (shape (1, y, x)) back to a regular image file.

    :param nifti_file: path to the input nifti
    :param output_filename: path of the exported image (format chosen by skimage from the extension)
    :param transform: optional callable applied to the 2D array before export
    :param export_dtype: dtype the array is cast to before writing
    """
    seg = sitk.GetArrayFromImage(sitk.ReadImage(nifti_file))
    assert seg.shape[0] == 1, "This function can only export 2D segmentations!"
    seg = seg[0]

    if transform is not None:
        seg = transform(seg)

    # check_contrast=False silences skimage's low-contrast warning (expected for label maps)
    io.imsave(output_filename, seg.astype(export_dtype), check_contrast=False)
def convert_3d_segmentation_nifti_to_tiff(nifti_file: str, output_filename: str, transform=None, export_dtype=np.uint8):
    """
    Exports a 3D segmentation nifti as a 3D tiff file.

    :param nifti_file: path to the input nifti
    :param output_filename: path of the exported tiff
    :param transform: optional callable applied to the 3D array before export
    :param export_dtype: dtype the array is cast to before writing
    """
    seg = sitk.GetArrayFromImage(sitk.ReadImage(nifti_file))
    assert len(seg.shape) == 3, "This function can only export 3D segmentations!"

    if transform is not None:
        seg = transform(seg)

    tifffile.imsave(output_filename, seg.astype(export_dtype))
================================================
FILE: unetr_pp/utilities/file_endings.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.utilities.file_and_folder_operations import *
def remove_trailing_slash(filename: str):
    """Return *filename* with all trailing '/' characters stripped."""
    return filename.rstrip('/')
def maybe_add_0000_to_all_niigz(folder):
    """Rename every .nii.gz file in *folder* that lacks the _0000 modality suffix so that it carries one."""
    for nii_file in subfiles(folder, suffix='.nii.gz'):
        nii_file = remove_trailing_slash(nii_file)
        if nii_file.endswith('_0000.nii.gz'):
            continue
        # strip the '.nii.gz' ending (7 chars) and re-append it with the modality suffix
        os.rename(nii_file, nii_file[:-7] + '_0000.nii.gz')
================================================
FILE: unetr_pp/utilities/folder_names.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.paths import network_training_output_dir
def get_output_folder_name(model: str, task: str = None, trainer: str = None, plans: str = None, fold: int = None,
                           overwrite_training_output_dir: str = None):
    """
    Retrieves the correct output directory for the nnU-Net model described by the input parameters.

    Each optional parameter appends one more path component; deeper components are only added when the
    shallower ones are present (task -> trainer__plans -> fold_X).

    :param model: one of '2d', '3d_cascade_fullres', '3d_fullres', '3d_lowres'
    :param task: task name, e.g. 'Task002_Synapse'
    :param trainer: trainer class name
    :param plans: plans identifier (joined with trainer via '__')
    :param fold: fold number
    :param overwrite_training_output_dir: use this base directory instead of network_training_output_dir
    :return: assembled path
    """
    assert model in ["2d", "3d_cascade_fullres", '3d_fullres', '3d_lowres']

    base_dir = overwrite_training_output_dir if overwrite_training_output_dir is not None \
        else network_training_output_dir

    path = join(base_dir, model)
    if task is not None:
        path = join(path, task)
        if trainer is not None and plans is not None:
            path = join(path, trainer + "__" + plans)
            if fold is not None:
                path = join(path, "fold_%d" % fold)
    return path
================================================
FILE: unetr_pp/utilities/nd_softmax.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
import torch.nn.functional as F
# softmax over the channel dimension (dim 1) of an N-dimensional tensor, e.g. (b, c, x, y, z)
softmax_helper = lambda x: F.softmax(x, dim=1)
================================================
FILE: unetr_pp/utilities/one_hot_encoding.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
def to_one_hot(seg, all_seg_labels=None):
    """
    One-hot encode a label map.

    :param seg: numpy array of integer labels (arbitrary shape)
    :param all_seg_labels: labels to encode, one output channel each. Defaults to the sorted unique values
        present in seg.
    :return: array of shape (len(all_seg_labels), *seg.shape) and dtype seg.dtype, with 1 where seg equals
        the channel's label and 0 elsewhere
    """
    if all_seg_labels is None:
        all_seg_labels = np.unique(seg)

    result = np.zeros((len(all_seg_labels), *seg.shape), dtype=seg.dtype)
    for channel, label in enumerate(all_seg_labels):
        # boolean mask is cast to seg.dtype on assignment (True -> 1, False -> 0)
        result[channel] = (seg == label)
    return result
================================================
FILE: unetr_pp/utilities/overlay_plots.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from multiprocessing.pool import Pool
import numpy as np
import SimpleITK as sitk
from unetr_pp.utilities.task_name_id_conversion import convert_task_name_to_id, convert_id_to_task_name
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.paths import *
# Palette of 6-digit hex RGB colors used to colorize segmentation labels in the overlay plots below.
# Entry 0 is black, which maps to the background label when a default label->color mapping is built.
color_cycle = (
    "000000",
    "4363d8",
    "f58231",
    "3cb44b",
    "e6194B",
    "911eb4",
    "ffe119",
    "bfef45",
    "42d4f4",
    "f032e6",
    "000075",
    "9A6324",
    "808000",
    "800000",
    "469990",
)
def hex_to_rgb(hex: str):
    """Convert a 6-digit hex color string (e.g. '4363d8') into an (r, g, b) tuple of ints in [0, 255]."""
    # NOTE: the parameter name shadows the builtin hex(); kept to preserve the public interface
    assert len(hex) == 6
    red = int(hex[0:2], 16)
    green = int(hex[2:4], 16)
    blue = int(hex[4:6], 16)
    return red, green, blue
def generate_overlay(input_image: np.ndarray, segmentation: np.ndarray, mapping: dict = None, color_cycle=color_cycle,
                     overlay_intensity=0.6):
    """
    Blend a segmentation on top of a 2D image and return the result scaled to [0, 255] as uint8.

    :param input_image: 2D array (H, W) or color image (H, W, 3); grayscale inputs are tiled to 3 channels.
        The input is not modified (a copy is made).
    :param segmentation: label map of the same spatial shape as the image (no color channels)
    :param mapping: optional dict label_id -> index into color_cycle. If None, one is built from the labels
        present in the segmentation (sorted ascending, so background 0 maps to color 0).
    :param color_cycle: tuple of 6-digit hex color strings
    :param overlay_intensity: blending weight of the label colors
    :return: uint8 RGB image in [0, 255]
    """
    # work on a float copy: integer inputs would otherwise crash on the in-place float '+=' below,
    # and we divide during rescaling anyway
    image = np.copy(input_image).astype(float)

    if len(image.shape) == 2:
        # grayscale -> replicate into 3 color channels
        image = np.tile(image[:, :, None], (1, 1, 3))
    elif len(image.shape) == 3:
        assert image.shape[2] == 3, 'if 3d image is given the last dimension must be the color channels ' \
                                    '(3 channels). Only 2D images are supported'
    else:
        raise RuntimeError("unexpected image shape. only 2D images and 2D images with color channels (color in "
                           "last dimension) are supported")

    # rescale image to [0, 255]. Guard the division: a constant image previously produced 0/0 -> NaN garbage
    image = image - image.min()
    max_value = image.max()
    if max_value > 0:
        image = image / max_value * 255

    if mapping is None:
        uniques = np.unique(segmentation)
        mapping = {i: c for c, i in enumerate(uniques)}

    for label in mapping.keys():
        image[segmentation == label] += overlay_intensity * np.array(hex_to_rgb(color_cycle[mapping[label]]))

    # rescale result to [0, 255] (blending can push values above 255); same zero-max guard as above
    max_value = image.max()
    if max_value > 0:
        image = image / max_value * 255
    return image.astype(np.uint8)
def plot_overlay(image_file: str, segmentation_file: str, output_file: str, overlay_intensity: float = 0.6):
    """
    Load a 3D image and its segmentation, pick the slice with the most foreground voxels and save a png
    overlay of that slice.

    :param image_file: path to the 3D image (any sitk-readable format)
    :param segmentation_file: path to the matching 3D segmentation
    :param output_file: path of the png written via matplotlib
    :param overlay_intensity: blending weight passed to generate_overlay
    """
    import matplotlib.pyplot as plt

    image = sitk.GetArrayFromImage(sitk.ReadImage(image_file))
    seg = sitk.GetArrayFromImage(sitk.ReadImage(segmentation_file))
    assert all([i == j for i, j in zip(image.shape, seg.shape)]), "image and seg do not have the same shape: %s, %s" % (
        image_file, segmentation_file)
    assert len(image.shape) == 3, 'only 3D images/segs are supported'

    # slice (along the first axis) containing the most non-background voxels
    foreground_per_slice = (seg != 0).sum((1, 2))
    best_slice = np.argmax(foreground_per_slice)

    overlay = generate_overlay(image[best_slice], seg[best_slice], overlay_intensity=overlay_intensity)
    plt.imsave(output_file, overlay)
def plot_overlay_preprocessed(case_file: str, output_file: str, overlay_intensity: float = 0.6, modality_index=0):
    """
    Like plot_overlay, but reads from a preprocessed .npz case file whose 'data' array stacks the image
    modalities followed by the segmentation as the last channel.

    :param case_file: path to the preprocessed .npz file
    :param output_file: path of the png written via matplotlib
    :param overlay_intensity: blending weight passed to generate_overlay
    :param modality_index: which modality channel to plot
    """
    import matplotlib.pyplot as plt

    data = np.load(case_file)['data']
    assert modality_index < (data.shape[0] - 1), 'This dataset only supports modality index up to %d' % (data.shape[0] - 2)

    image = data[modality_index]
    seg = data[-1]
    # negative label values are clamped to 0 and thus treated as background
    seg[seg < 0] = 0

    # slice (along the first axis) containing the most foreground voxels
    foreground_per_slice = (seg > 0).sum((1, 2))
    best_slice = np.argmax(foreground_per_slice)

    overlay = generate_overlay(image[best_slice], seg[best_slice], overlay_intensity=overlay_intensity)
    plt.imsave(output_file, overlay)
def multiprocessing_plot_overlay(list_of_image_files, list_of_seg_files, list_of_output_files, overlay_intensity,
                                 num_processes=8):
    """Run plot_overlay for many (image, seg, output) triples in parallel worker processes."""
    pool = Pool(num_processes)
    num_cases = len(list_of_output_files)
    pool.starmap(plot_overlay, zip(list_of_image_files, list_of_seg_files, list_of_output_files,
                                   [overlay_intensity] * num_cases))
    pool.close()
    pool.join()
def multiprocessing_plot_overlay_preprocessed(list_of_case_files, list_of_output_files, overlay_intensity,
                                              num_processes=8, modality_index=0):
    """Run plot_overlay_preprocessed for many preprocessed case files in parallel worker processes."""
    pool = Pool(num_processes)
    num_cases = len(list_of_output_files)
    pool.starmap(plot_overlay_preprocessed, zip(list_of_case_files, list_of_output_files,
                                                [overlay_intensity] * num_cases,
                                                [modality_index] * num_cases))
    pool.close()
    pool.join()
def generate_overlays_for_task(task_name_or_id, output_folder, num_processes=8, modality_idx=0, use_preprocessed=True,
                               data_identifier=default_data_identifier):
    """
    Generate png overlays (most-foreground slice) for every training case of a task.

    :param task_name_or_id: task name ('TaskXXX_...'), or a task id as int/str
    :param output_folder: where the pngs are written (created if missing)
    :param num_processes: worker processes for the multiprocessing helpers
    :param modality_idx: which modality to plot
    :param use_preprocessed: if True, read from the preprocessed npz files; otherwise from the raw niftis
    :param data_identifier: prefix of the preprocessed data folders (only used with use_preprocessed)
    :raises RuntimeError: if the preprocessed data for the task cannot be found
    """
    # resolve the task name from whatever was passed in
    if isinstance(task_name_or_id, str):
        if not task_name_or_id.startswith("Task"):
            task_name_or_id = int(task_name_or_id)
            task_name = convert_id_to_task_name(task_name_or_id)
        else:
            task_name = task_name_or_id
    else:
        task_name = convert_id_to_task_name(int(task_name_or_id))

    if not use_preprocessed:
        folder = join(nnFormer_raw_data, task_name)

        identifiers = [i[:-7] for i in subfiles(join(folder, 'labelsTr'), suffix='.nii.gz', join=False)]

        image_files = [join(folder, 'imagesTr', i + "_%04.0d.nii.gz" % modality_idx) for i in identifiers]
        seg_files = [join(folder, 'labelsTr', i + ".nii.gz") for i in identifiers]

        assert all([isfile(i) for i in image_files])
        assert all([isfile(i) for i in seg_files])

        maybe_mkdir_p(output_folder)
        output_files = [join(output_folder, i + '.png') for i in identifiers]

        multiprocessing_plot_overlay(image_files, seg_files, output_files, 0.6, num_processes)
    else:
        folder = join(preprocessing_output_dir, task_name)
        if not isdir(folder):
            raise RuntimeError("run preprocessing for that task first")
        matching_folders = subdirs(folder, prefix=data_identifier + "_stage")
        # bug fix: this message used to be a bare string expression (a no-op) instead of an actual error
        if len(matching_folders) == 0:
            raise RuntimeError("run preprocessing for that task first (use default experiment planner!)")
        matching_folders.sort()
        # use the highest stage (= full resolution) available
        folder = matching_folders[-1]

        identifiers = [i[:-4] for i in subfiles(folder, suffix='.npz', join=False)]

        maybe_mkdir_p(output_folder)
        output_files = [join(output_folder, i + '.png') for i in identifiers]
        image_files = [join(folder, i + ".npz") for i in identifiers]

        multiprocessing_plot_overlay_preprocessed(image_files, output_files, overlay_intensity=0.6,
                                                  num_processes=num_processes, modality_index=modality_idx)
def entry_point_generate_overlay():
    """Command line entry point: parse arguments and call generate_overlays_for_task."""
    import argparse

    parser = argparse.ArgumentParser("Plots png overlays of the slice with the most foreground. Note that this "
                                     "disregards spacing information!")
    parser.add_argument('-t', type=str, help="task name or task ID", required=True)
    parser.add_argument('-o', type=str, help="output folder", required=True)
    parser.add_argument('-num_processes', type=int, default=8, required=False, help="number of processes used. Default: 8")
    parser.add_argument('-modality_idx', type=int, default=0, required=False,
                        help="modality index used (0 = _0000.nii.gz). Default: 0")
    parser.add_argument('--use_raw', action='store_true', required=False, help="if set then we use raw data. else "
                                                                               "we use preprocessed")
    parsed = parser.parse_args()

    generate_overlays_for_task(parsed.t, parsed.o, parsed.num_processes, parsed.modality_idx,
                               use_preprocessed=not parsed.use_raw)
================================================
FILE: unetr_pp/utilities/random_stuff.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class no_op(object):
    """Context manager that does nothing; a stand-in where a real context manager is optional."""

    def __enter__(self):
        return None

    def __exit__(self, *args):
        # returning None (falsy) lets any exception propagate
        return None
================================================
FILE: unetr_pp/utilities/recursive_delete_npz.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.utilities.file_and_folder_operations import *
import argparse
import os
def recursive_delete_npz(current_directory: str):
    """
    Recursively delete all .npz files below *current_directory*, sparing segFromPrevStage.npz files and
    skipping 'pred_next_stage' subdirectories entirely.
    """
    for npz_file in subfiles(current_directory, join=True, suffix=".npz"):
        if npz_file.endswith("segFromPrevStage.npz"):  # to be extra safe
            continue
        os.remove(npz_file)

    for subdir_name in subdirs(current_directory, join=False):
        if subdir_name != "pred_next_stage":
            recursive_delete_npz(join(current_directory, subdir_name))
if __name__ == "__main__":
    # CLI wrapper around recursive_delete_npz
    arg_parser = argparse.ArgumentParser(usage="USE THIS RESPONSIBLY! DANGEROUS! I (Fabian) use this to remove npz files "
                                               "after I ran figure_out_what_to_submit")
    arg_parser.add_argument("-f", help="folder", required=True)
    parsed_args = arg_parser.parse_args()

    recursive_delete_npz(parsed_args.f)
================================================
FILE: unetr_pp/utilities/recursive_rename_taskXX_to_taskXXX.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.utilities.file_and_folder_operations import *
import os
def recursive_rename(folder):
    """
    Recursively rename 'TaskXX_name' directories (2-digit id, underscore at index 6) to the 3-digit
    'Task0XX_name' convention.
    """
    for entry in subdirs(folder, join=False):
        if entry.startswith("Task") and entry.find("_") == 6:
            task_id = int(entry[4:6])
            task_suffix = entry[7:]
            os.rename(join(folder, entry), join(folder, "Task%03.0d_" % task_id + task_suffix))

    for subdir_path in subdirs(folder, join=True):
        recursive_rename(subdir_path)
if __name__ == "__main__":
    # one-off migration over the author's local data folders
    for _folder in ("/media/fabian/Results/nnFormer",
                    "/media/fabian/unetr_pp",
                    "/media/fabian/My Book/MedicalDecathlon",
                    "/home/fabian/drives/datasets/nnFormer_raw",
                    "/home/fabian/drives/datasets/nnFormer_preprocessed",
                    "/home/fabian/drives/datasets/nnFormer_testSets",
                    "/home/fabian/drives/datasets/results/nnFormer",
                    "/home/fabian/drives/e230-dgx2-1-data_fabian/Decathlon_raw",
                    "/home/fabian/drives/e230-dgx2-1-data_fabian/nnFormer_preprocessed"):
        recursive_rename(_folder)
================================================
FILE: unetr_pp/utilities/sitk_stuff.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import SimpleITK as sitk
def copy_geometry(image: sitk.Image, ref: sitk.Image):
    """Copy origin, direction and spacing from *ref* onto *image* (modified in place) and return it."""
    for read_value, write_value in ((ref.GetOrigin, image.SetOrigin),
                                    (ref.GetDirection, image.SetDirection),
                                    (ref.GetSpacing, image.SetSpacing)):
        write_value(read_value())
    return image
================================================
FILE: unetr_pp/utilities/task_name_id_conversion.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unetr_pp.paths import nnFormer_raw_data, preprocessing_output_dir, nnFormer_cropped_data, network_training_output_dir
from batchgenerators.utilities.file_and_folder_operations import *
import numpy as np
def convert_id_to_task_name(task_id: int):
    """
    Resolve a numeric task id to its full 'TaskXXX_name' directory name by scanning the preprocessed, raw,
    cropped and trained-model folders.

    :param task_id: numeric task id
    :return: the unique matching task name
    :raises RuntimeError: if zero or more than one matching task name is found
    """
    prefix = "Task%03.0d" % task_id

    def _matching_dirs(base_folder):
        # helper: candidate subdirectories of base_folder, or none if that folder is not configured
        if base_folder is None:
            return []
        return subdirs(base_folder, prefix=prefix, join=False)

    all_candidates = _matching_dirs(nnFormer_cropped_data) + _matching_dirs(preprocessing_output_dir) + \
                     _matching_dirs(nnFormer_raw_data)

    if network_training_output_dir is not None:
        for model in ['2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres']:
            model_folder = join(network_training_output_dir, model)
            if isdir(model_folder):
                all_candidates += subdirs(model_folder, prefix=prefix, join=False)

    unique_candidates = np.unique(all_candidates)
    if len(unique_candidates) > 1:
        raise RuntimeError("More than one task name found for task id %d. Please correct that. (I looked in the "
                           "following folders:\n%s\n%s\n%s" % (task_id, nnFormer_raw_data, preprocessing_output_dir,
                                                               nnFormer_cropped_data))
    if len(unique_candidates) == 0:
        raise RuntimeError("Could not find a task with the ID %d. Make sure the requested task ID exists and that "
                           "nnU-Net knows where raw and preprocessed data are located (see Documentation - "
                           "Installation). Here are your currently defined folders:\nunetr_pp_preprocessed=%s\nRESULTS_"
                           "FOLDER=%s\nunetr_pp_raw_data_base=%s\nIf something is not right, adapt your environemnt "
                           "variables." %
                           (task_id,
                            os.environ.get('unetr_pp_preprocessed', 'None'),
                            os.environ.get('RESULTS_FOLDER', 'None'),
                            os.environ.get('unetr_pp_raw_data_base', 'None'),
                            ))
    return unique_candidates[0]
def convert_task_name_to_id(task_name: str):
    """Extract the numeric id from a 'TaskXXX_name' task name (3-digit id directly after 'Task')."""
    assert task_name.startswith("Task")
    return int(task_name[4:7])
================================================
FILE: unetr_pp/utilities/tensor_utilities.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch import nn
def sum_tensor(inp, axes, keepdim=False):
    """
    Sum a tensor over several axes, one at a time.

    :param inp: torch tensor
    :param axes: iterable of axis indices (duplicates are ignored)
    :param keepdim: keep the reduced dimensions as size 1
    :return: the reduced tensor
    """
    axes = np.unique(axes).astype(int)  # sorted, duplicate-free
    if keepdim:
        for axis in axes:
            inp = inp.sum(int(axis), keepdim=True)
        return inp
    # reduce back-to-front so the remaining axis indices stay valid as dimensions disappear
    for axis in axes[::-1]:
        inp = inp.sum(int(axis))
    return inp
def mean_tensor(inp, axes, keepdim=False):
    """
    Average a tensor over several axes, one at a time.

    :param inp: torch tensor
    :param axes: iterable of axis indices (duplicates are ignored)
    :param keepdim: keep the reduced dimensions as size 1
    :return: the reduced tensor
    """
    axes = np.unique(axes).astype(int)  # sorted, duplicate-free
    if keepdim:
        for axis in axes:
            inp = inp.mean(int(axis), keepdim=True)
        return inp
    # reduce back-to-front so the remaining axis indices stay valid as dimensions disappear
    for axis in axes[::-1]:
        inp = inp.mean(int(axis))
    return inp
def flip(x, dim):
    """
    Flips the tensor at dimension dim (mirroring!). Returns a flipped copy, like the original
    advanced-indexing implementation, but via the builtin torch.flip (also supports negative dims).

    :param x: torch tensor
    :param dim: dimension along which to mirror
    :return: flipped copy of x
    """
    return torch.flip(x, (dim,))
================================================
FILE: unetr_pp/utilities/to_torch.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def maybe_to_torch(d):
    """
    Convert a numpy array (or a list of arrays/tensors, recursively) to float32 torch tensors.
    Existing tensors are passed through unchanged.
    """
    if isinstance(d, torch.Tensor):
        return d
    if isinstance(d, list):
        return [item if isinstance(item, torch.Tensor) else maybe_to_torch(item) for item in d]
    return torch.from_numpy(d).float()
def to_cuda(data, non_blocking=True, gpu_id=0):
    """Move a tensor, or each tensor in a list, to the given GPU."""
    if isinstance(data, list):
        return [tensor.cuda(gpu_id, non_blocking=non_blocking) for tensor in data]
    return data.cuda(gpu_id, non_blocking=non_blocking)