Repository: Amshaker/unetr_plus_plus Branch: main Commit: dcba1c6e8179 Files: 155 Total size: 1.2 MB Directory structure: gitextract_3co428a0/ ├── LICENSE ├── README.md ├── evaluation_scripts/ │ ├── run_evaluation_acdc.sh │ ├── run_evaluation_lung.sh │ ├── run_evaluation_synapse.sh │ └── run_evaluation_tumor.sh ├── requirements.txt ├── training_scripts/ │ ├── run_training_acdc.sh │ ├── run_training_lung.sh │ ├── run_training_synapse.sh │ └── run_training_tumor.sh └── unetr_pp/ ├── __init__.py ├── configuration.py ├── evaluation/ │ ├── __init__.py │ ├── add_dummy_task_with_mean_over_all_tasks.py │ ├── add_mean_dice_to_json.py │ ├── collect_results_files.py │ ├── evaluator.py │ ├── metrics.py │ ├── model_selection/ │ │ ├── __init__.py │ │ ├── collect_all_fold0_results_and_summarize_in_one_csv.py │ │ ├── ensemble.py │ │ ├── figure_out_what_to_submit.py │ │ ├── rank_candidates.py │ │ ├── rank_candidates_StructSeg.py │ │ ├── rank_candidates_cascade.py │ │ ├── summarize_results_in_one_json.py │ │ └── summarize_results_with_plans.py │ ├── region_based_evaluation.py │ ├── surface_dice.py │ ├── unetr_pp_acdc_checkpoint/ │ │ └── unetr_pp/ │ │ └── 3d_fullres/ │ │ └── Task001_ACDC/ │ │ └── unetr_pp_trainer_acdc__unetr_pp_Plansv2.1/ │ │ └── fold_0/ │ │ └── .gitignore │ ├── unetr_pp_lung_checkpoint/ │ │ └── unetr_pp/ │ │ └── 3d_fullres/ │ │ └── Task006_Lung/ │ │ └── unetr_pp_trainer_lung__unetr_pp_Plansv2.1/ │ │ └── fold_0/ │ │ └── .gitignore │ ├── unetr_pp_synapse_checkpoint/ │ │ └── unetr_pp/ │ │ └── 3d_fullres/ │ │ └── Task002_Synapse/ │ │ └── unetr_pp_trainer_synapse__unetr_pp_Plansv2.1/ │ │ └── fold_0/ │ │ └── .gitignore │ └── unetr_pp_tumor_checkpoint/ │ └── unetr_pp/ │ └── 3d_fullres/ │ └── Task003_tumor/ │ └── unetr_pp_trainer_tumor__unetr_pp_Plansv2.1/ │ └── fold_0/ │ └── .gitignore ├── experiment_planning/ │ ├── DatasetAnalyzer.py │ ├── __init__.py │ ├── alternative_experiment_planning/ │ │ ├── experiment_planner_baseline_3DUNet_v21_11GB.py │ │ ├── 
experiment_planner_baseline_3DUNet_v21_16GB.py │ │ ├── experiment_planner_baseline_3DUNet_v21_32GB.py │ │ ├── experiment_planner_baseline_3DUNet_v21_3convperstage.py │ │ ├── experiment_planner_baseline_3DUNet_v22.py │ │ ├── experiment_planner_baseline_3DUNet_v23.py │ │ ├── experiment_planner_residual_3DUNet_v21.py │ │ ├── normalization/ │ │ │ ├── experiment_planner_2DUNet_v21_RGB_scaleto_0_1.py │ │ │ ├── experiment_planner_3DUNet_CT2.py │ │ │ └── experiment_planner_3DUNet_nonCT.py │ │ ├── patch_size/ │ │ │ ├── experiment_planner_3DUNet_isotropic_in_mm.py │ │ │ └── experiment_planner_3DUNet_isotropic_in_voxels.py │ │ ├── pooling_and_convs/ │ │ │ ├── experiment_planner_baseline_3DUNet_allConv3x3.py │ │ │ └── experiment_planner_baseline_3DUNet_poolBasedOnSpacing.py │ │ └── target_spacing/ │ │ ├── experiment_planner_baseline_3DUNet_targetSpacingForAnisoAxis.py │ │ ├── experiment_planner_baseline_3DUNet_v21_customTargetSpacing_2x2x2.py │ │ └── experiment_planner_baseline_3DUNet_v21_noResampling.py │ ├── change_batch_size.py │ ├── common_utils.py │ ├── experiment_planner_baseline_2DUNet.py │ ├── experiment_planner_baseline_2DUNet_v21.py │ ├── experiment_planner_baseline_3DUNet.py │ ├── experiment_planner_baseline_3DUNet_v21.py │ ├── nnFormer_convert_decathlon_task.py │ ├── nnFormer_plan_and_preprocess.py │ ├── summarize_plans.py │ └── utils.py ├── inference/ │ ├── __init__.py │ ├── inferTs/ │ │ └── swin_nomask_2/ │ │ └── plans.pkl │ ├── predict.py │ ├── predict_simple.py │ └── segmentation_export.py ├── inference_acdc.py ├── inference_synapse.py ├── inference_tumor.py ├── network_architecture/ │ ├── README.md │ ├── __init__.py │ ├── acdc/ │ │ ├── __init__.py │ │ ├── model_components.py │ │ ├── transformerblock.py │ │ └── unetr_pp_acdc.py │ ├── dynunet_block.py │ ├── generic_UNet.py │ ├── initialization.py │ ├── layers.py │ ├── lung/ │ │ ├── __init__.py │ │ ├── model_components.py │ │ ├── transformerblock.py │ │ └── unetr_pp_lung.py │ ├── neural_network.py │ ├── synapse/ 
│ │ ├── __init__.py │ │ ├── model_components.py │ │ ├── transformerblock.py │ │ └── unetr_pp_synapse.py │ └── tumor/ │ ├── __init__.py │ ├── model_components.py │ ├── transformerblock.py │ └── unetr_pp_tumor.py ├── paths.py ├── postprocessing/ │ ├── connected_components.py │ ├── consolidate_all_for_paper.py │ ├── consolidate_postprocessing.py │ └── consolidate_postprocessing_simple.py ├── preprocessing/ │ ├── cropping.py │ ├── custom_preprocessors/ │ │ └── preprocessor_scale_RGB_to_0_1.py │ ├── preprocessing.py │ └── sanity_checks.py ├── run/ │ ├── __init__.py │ ├── default_configuration.py │ └── run_training.py ├── training/ │ ├── __init__.py │ ├── cascade_stuff/ │ │ ├── __init__.py │ │ └── predict_next_stage.py │ ├── data_augmentation/ │ │ ├── __init__.py │ │ ├── custom_transforms.py │ │ ├── data_augmentation_insaneDA.py │ │ ├── data_augmentation_insaneDA2.py │ │ ├── data_augmentation_moreDA.py │ │ ├── data_augmentation_noDA.py │ │ ├── default_data_augmentation.py │ │ ├── downsampling.py │ │ └── pyramid_augmentations.py │ ├── dataloading/ │ │ ├── __init__.py │ │ └── dataset_loading.py │ ├── learning_rate/ │ │ └── poly_lr.py │ ├── loss_functions/ │ │ ├── TopK_loss.py │ │ ├── __init__.py │ │ ├── crossentropy.py │ │ ├── deep_supervision.py │ │ └── dice_loss.py │ ├── model_restore.py │ ├── network_training/ │ │ ├── Trainer_acdc.py │ │ ├── Trainer_lung.py │ │ ├── Trainer_synapse.py │ │ ├── Trainer_tumor.py │ │ ├── network_trainer_acdc.py │ │ ├── network_trainer_lung.py │ │ ├── network_trainer_synapse.py │ │ ├── network_trainer_tumor.py │ │ ├── unetr_pp_trainer_acdc.py │ │ ├── unetr_pp_trainer_lung.py │ │ ├── unetr_pp_trainer_synapse.py │ │ └── unetr_pp_trainer_tumor.py │ └── optimizer/ │ └── ranger.py └── utilities/ ├── __init__.py ├── distributed.py ├── file_conversions.py ├── file_endings.py ├── folder_names.py ├── nd_softmax.py ├── one_hot_encoding.py ├── overlay_plots.py ├── random_stuff.py ├── recursive_delete_npz.py ├── recursive_rename_taskXX_to_taskXXX.py ├── 
sitk_stuff.py ├── task_name_id_conversion.py ├── tensor_utilities.py └── to_torch.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [2024] [Abdelrahman Mohamed Shaker Youssief] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
================================================ FILE: README.md ================================================ # UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation ![](https://i.imgur.com/waxVImv.png) [Abdelrahman Shaker](https://scholar.google.com/citations?hl=en&user=eEz4Wu4AAAAJ)*1, [Muhammad Maaz](https://scholar.google.com/citations?user=vTy9Te8AAAAJ&hl=en&authuser=1&oi=sra)1, [Hanoona Rasheed](https://scholar.google.com/citations?user=yhDdEuEAAAAJ&hl=en&authuser=1&oi=sra)1, [Salman Khan](https://salman-h-khan.github.io/)1, [Ming-Hsuan Yang](https://scholar.google.com/citations?user=p9-ohHsAAAAJ&hl=en)2,3 and [Fahad Shahbaz Khan](https://scholar.google.es/citations?user=zvaeYnUAAAAJ&hl=en)1,4 Mohamed Bin Zayed University of Artificial Intelligence1, University of California Merced2, Google Research3, Linkoping University4 [![paper](https://img.shields.io/badge/Paper-.svg)](https://ieeexplore.ieee.org/document/10526382) [![Website](https://img.shields.io/badge/Project-Website-87CEEB)](https://amshaker.github.io/unetr_plus_plus/) [![slides](https://img.shields.io/badge/Presentation-Slides-B762C1)](https://docs.google.com/presentation/d/e/2PACX-1vRtrxSfA2kU1fBmPxBdMQioLfsvjcjWBoaOVf3aupqajm0mw_C4TEz05Yk4ZF_vqoMyA8iiUJE60ynm/pub?start=true&loop=false&delayms=10000) ## :rocket: News * **(Dec 2025):** UNETR++ is now available in Keras 3 with a simple model initialization as part of [AI Toolkit for Healthcare Imaging](https://github.com/innat/medic-ai/tree/main/medicai/models/unetr_plus_plus)! * **(Oct 2025):** UNETR++ became the **#1 most popular article** of 2025 in [IEEE Transactions on Medical Imaging (TMI)](https://ieeexplore.ieee.org/xpl/topAccessedArticles.jsp?punumber=42)! 🎉 * **(May 04, 2024):** We're thrilled to share that UNETR++ has been accepted to IEEE TMI-2024! 🎊 * **(Jun 01, 2023):** UNETR++ code & weights are released for Decathlon-Lung and BRaTs. 
* **(Dec 15, 2022):** UNETR++ weights are released for Synapse & ACDC datasets. * **(Dec 09, 2022):** UNETR++ training and evaluation codes are released.
![main figure](media/intro_fig.png) > **Abstract:** *Owing to the success of transformer models, recent works study their applicability in 3D medical segmentation tasks. Within the transformer models, the self-attention mechanism is one of the main building blocks that strives to capture long-range dependencies. However, the self-attention operation has quadratic complexity, which proves to be a computational bottleneck, especially in volumetric medical imaging, where the inputs are 3D with numerous slices. In this paper, we propose a 3D medical image segmentation approach, named UNETR++, that offers both high-quality segmentation masks as well as efficiency in terms of parameters, compute cost, and inference speed. The core of our design is the introduction of a novel efficient paired attention (EPA) block that efficiently learns spatial and channel-wise discriminative features using a pair of inter-dependent branches based on spatial and channel attention. Our spatial attention formulation is efficient, having linear complexity with respect to the input sequence length. To enable communication between spatial and channel-focused branches, we share the weights of query and key mapping functions that provide a complementary benefit (paired attention), while also reducing the overall network parameters. Our extensive evaluations on five benchmarks, Synapse, BTCV, ACDC, BRaTs, and Decathlon-Lung, reveal the effectiveness of our contributions in terms of both efficiency and accuracy. On Synapse, our UNETR++ sets a new state-of-the-art with a Dice Score of 87.2%, while being significantly efficient with a reduction of over 71% in terms of both parameters and FLOPs, compared to the best method in the literature.*
## Architecture overview of UNETR++ Overview of our UNETR++ approach with hierarchical encoder-decoder structure. The 3D patches are fed to the encoder, whose outputs are then connected to the decoder via skip connections followed by convolutional blocks to produce the final segmentation mask. The focus of our design is the introduction of an _efficient paired-attention_ (EPA) block. Each EPA block performs two tasks using parallel attention modules with shared keys-queries and different value layers to efficiently learn enriched spatial-channel feature representations. As illustrated in the EPA block diagram (on the right), the first (top) attention module aggregates the spatial features by a weighted sum of the projected features in a linear manner to compute the spatial attention maps, while the second (bottom) attention module emphasizes the dependencies in the channels and computes the channel attention maps. Finally, the outputs of the two attention modules are fused and passed to convolutional blocks to enhance the feature representation, leading to better segmentation masks. ![Architecture overview](media/UNETR++_Block_Diagram.jpg)
## Results ### Synapse Dataset State-of-the-art comparison on the abdominal multi-organ Synapse dataset. We report both the segmentation performance (DSC, HD95) and model complexity (parameters and FLOPs). Our proposed UNETR++ achieves favorable segmentation performance against existing methods, while considerably reducing the model complexity. Abbreviations stand for: Spl: _spleen_, RKid: _right kidney_, LKid: _left kidney_, Gal: _gallbladder_, Liv: _liver_, Sto: _stomach_, Aor: _aorta_, Pan: _pancreas_. Best results are in bold. ![Synapse Results](media/synapse_results.png)
## Qualitative Comparison ### Synapse Dataset Qualitative comparison on multi-organ segmentation task. Here, we compare our UNETR++ with existing methods: UNETR, Swin UNETR, and nnFormer. The different abdominal organs are shown in the legend below the examples. Existing methods struggle to correctly segment different organs (marked in red dashed box). Our UNETR++ achieves promising segmentation performance by accurately segmenting the organs. ![Synapse Qual Results](media/UNETR++_results_fig_synapse.jpg) ### ACDC Dataset Qualitative comparison on the ACDC dataset. We compare our UNETR++ with existing methods: UNETR and nnFormer. It is noticeable that the existing methods struggle to correctly segment different organs (marked in red dashed box). Our UNETR++ achieves favorable segmentation performance by accurately segmenting the organs. ![ACDC Qual Results](media/acdc_vs_unetr_suppl.jpg)
## Installation The code is tested with PyTorch 1.11.0 and CUDA 11.3. After cloning the repository, follow the below steps for installation, 1. Create and activate conda environment ```shell conda create --name unetr_pp python=3.8 conda activate unetr_pp ``` 2. Install PyTorch and torchvision ```shell pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 ``` 3. Install other dependencies ```shell pip install -r requirements.txt ```
## Dataset We follow the same dataset preprocessing as in [nnFormer](https://github.com/282857341/nnFormer). We conducted extensive experiments on five benchmarks: Synapse, BTCV, ACDC, BRaTs, and Decathlon-Lung. The dataset folders for Synapse should be organized as follows: ``` ./DATASET_Synapse/ ├── unetr_pp_raw/ ├── unetr_pp_raw_data/ ├── Task02_Synapse/ ├── imagesTr/ ├── imagesTs/ ├── labelsTr/ ├── labelsTs/ ├── dataset.json ├── Task002_Synapse ├── unetr_pp_cropped_data/ ├── Task002_Synapse ``` The dataset folders for ACDC should be organized as follows: ``` ./DATASET_Acdc/ ├── unetr_pp_raw/ ├── unetr_pp_raw_data/ ├── Task01_ACDC/ ├── imagesTr/ ├── imagesTs/ ├── labelsTr/ ├── labelsTs/ ├── dataset.json ├── Task001_ACDC ├── unetr_pp_cropped_data/ ├── Task001_ACDC ``` The dataset folders for Decathlon-Lung should be organized as follows: ``` ./DATASET_Lungs/ ├── unetr_pp_raw/ ├── unetr_pp_raw_data/ ├── Task06_Lung/ ├── imagesTr/ ├── imagesTs/ ├── labelsTr/ ├── labelsTs/ ├── dataset.json ├── Task006_Lung ├── unetr_pp_cropped_data/ ├── Task006_Lung ``` The dataset folders for BRaTs should be organized as follows: ``` ./DATASET_Tumor/ ├── unetr_pp_raw/ ├── unetr_pp_raw_data/ ├── Task03_tumor/ ├── imagesTr/ ├── imagesTs/ ├── labelsTr/ ├── labelsTs/ ├── dataset.json ├── Task003_tumor ├── unetr_pp_cropped_data/ ├── Task003_tumor ``` Please refer to [Setting up the datasets](https://github.com/282857341/nnFormer) on nnFormer repository for more details. 
Alternatively, you can download the preprocessed dataset for [Synapse](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/abdelrahman_youssief_mbzuai_ac_ae/EbHDhSjkQW5Ak9SMPnGCyb8BOID98wdg3uUvQ0eNvTZ8RA?e=YVhfdg), [ACDC](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/abdelrahman_youssief_mbzuai_ac_ae/EY9qieTkT3JFrhCJQiwZXdsB1hJ4ebVAtNdBNOs2HAo3CQ?e=VwfFHC), [Decathlon-Lung](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/abdelrahman_youssief_mbzuai_ac_ae/EWhU1T7c-mNKgkS2PQjFwP0B810LCiX3D2CvCES2pHDVSg?e=OqcIW3), [BRaTs](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/abdelrahman_youssief_mbzuai_ac_ae/EaQOxpD2yE5Btl-UEBAbQa0BYFBCL4J2Ph-VF_sqZlBPSQ?e=DFY41h), and extract it under the project directory. ## Training The following scripts can be used for training our UNETR++ model on the datasets: ```shell bash training_scripts/run_training_synapse.sh bash training_scripts/run_training_acdc.sh bash training_scripts/run_training_lung.sh bash training_scripts/run_training_tumor.sh ```
## Evaluation To reproduce the results of UNETR++: 1- Download [Synapse weights](https://drive.google.com/file/d/13JuLMeDQRR_a3c3tr2V2oav6I29fJoBa) and paste ```model_final_checkpoint.model``` in the following path: ```shell unetr_pp/evaluation/unetr_pp_synapse_checkpoint/unetr_pp/3d_fullres/Task002_Synapse/unetr_pp_trainer_synapse__unetr_pp_Plansv2.1/fold_0/ ``` Then, run ```shell bash evaluation_scripts/run_evaluation_synapse.sh ``` 2- Download [ACDC weights](https://drive.google.com/file/d/15YXiHai1zLc1ycmXaiSHetYbLGum3tV5) and paste ```model_final_checkpoint.model``` in the following path: ```shell unetr_pp/evaluation/unetr_pp_acdc_checkpoint/unetr_pp/3d_fullres/Task001_ACDC/unetr_pp_trainer_acdc__unetr_pp_Plansv2.1/fold_0/ ``` Then, run ```shell bash evaluation_scripts/run_evaluation_acdc.sh ``` 3- Download [Decathlon-Lung weights](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/abdelrahman_youssief_mbzuai_ac_ae/ETAlc8WTjV1BhZx7zwFpA8UBS4og6upb1qX2UKkypMoTjw?e=KfzAiG) and paste ```model_final_checkpoint.model``` in the following path: ```shell unetr_pp/evaluation/unetr_pp_lung_checkpoint/unetr_pp/3d_fullres/Task006_Lung/unetr_pp_trainer_lung__unetr_pp_Plansv2.1/fold_0/ ``` Then, run ```shell bash evaluation_scripts/run_evaluation_lung.sh ``` 4- Download [BRaTs weights](https://drive.google.com/file/d/1LiqnVKKv3DrDKvo6J0oClhIFirhaz5PG) and paste ```model_final_checkpoint.model``` in the following path: ```shell unetr_pp/evaluation/unetr_pp_tumor_checkpoint/unetr_pp/3d_fullres/Task003_tumor/unetr_pp_trainer_tumor__unetr_pp_Plansv2.1/fold_0/ ``` Then, run ```shell bash evaluation_scripts/run_evaluation_tumor.sh ```
## Acknowledgement This repository is built based on [nnFormer](https://github.com/282857341/nnFormer) repository. ## Citation If you use our work, please consider citing: ```bibtex @ARTICLE{10526382, title={UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation}, author={Shaker, Abdelrahman M. and Maaz, Muhammad and Rasheed, Hanoona and Khan, Salman and Yang, Ming-Hsuan and Khan, Fahad Shahbaz}, journal={IEEE Transactions on Medical Imaging}, year={2024}, doi={10.1109/TMI.2024.3398728}} ``` ## Contact Should you have any question, please create an issue on this repository or contact me at abdelrahman.youssief@mbzuai.ac.ae. ================================================ FILE: evaluation_scripts/run_evaluation_acdc.sh ================================================ #!/bin/sh DATASET_PATH=../DATASET_Acdc CHECKPOINT_PATH=../unetr_pp/evaluation/unetr_pp_acdc_checkpoint export PYTHONPATH=.././ export RESULTS_FOLDER="$CHECKPOINT_PATH" export unetr_pp_preprocessed="$DATASET_PATH"/unetr_pp_raw/unetr_pp_raw_data/Task01_ACDC export unetr_pp_raw_data_base="$DATASET_PATH"/unetr_pp_raw python ../unetr_pp/run/run_training.py 3d_fullres unetr_pp_trainer_acdc 1 0 -val ================================================ FILE: evaluation_scripts/run_evaluation_lung.sh ================================================ #!/bin/sh DATASET_PATH=../DATASET_Lungs CHECKPOINT_PATH=../unetr_pp/evaluation/unetr_pp_lung_checkpoint export PYTHONPATH=.././ export RESULTS_FOLDER="$CHECKPOINT_PATH" export unetr_pp_preprocessed="$DATASET_PATH"/unetr_pp_raw/unetr_pp_raw_data/Task06_Lung export unetr_pp_raw_data_base="$DATASET_PATH"/unetr_pp_raw python ../unetr_pp/run/run_training.py 3d_fullres unetr_pp_trainer_lung 6 0 -val ================================================ FILE: evaluation_scripts/run_evaluation_synapse.sh ================================================ #!/bin/sh DATASET_PATH=../DATASET_Synapse CHECKPOINT_PATH=../unetr_pp/evaluation/unetr_pp_synapse_checkpoint 
export PYTHONPATH=.././ export RESULTS_FOLDER="$CHECKPOINT_PATH" export unetr_pp_preprocessed="$DATASET_PATH"/unetr_pp_raw/unetr_pp_raw_data/Task02_Synapse export unetr_pp_raw_data_base="$DATASET_PATH"/unetr_pp_raw python ../unetr_pp/run/run_training.py 3d_fullres unetr_pp_trainer_synapse 2 0 -val ================================================ FILE: evaluation_scripts/run_evaluation_tumor.sh ================================================ #!/bin/sh DATASET_PATH=../DATASET_Tumor export PYTHONPATH=.././ export RESULTS_FOLDER=../unetr_pp/evaluation/unetr_pp_tumor_checkpoint export unetr_pp_preprocessed="$DATASET_PATH"/unetr_pp_raw/unetr_pp_raw_data/Task03_tumor export unetr_pp_raw_data_base="$DATASET_PATH"/unetr_pp_raw # Only for Tumor, it is recommended to train unetr_plus_plus first, and then use the provided checkpoint to evaluate. It might raise issues regarding the pickle files if you evaluated without training python ../unetr_pp/inference/predict_simple.py -i ../unetr_plus_plus/DATASET_Tumor/unetr_pp_raw/unetr_pp_raw_data/Task003_tumor/imagesTs -o ../unetr_plus_plus/unetr_pp/evaluation/unetr_pp_tumor_checkpoint/inferTs -m 3d_fullres -t 3 -f 0 -chk model_final_checkpoint -tr unetr_pp_trainer_tumor python ../unetr_pp/inference_tumor.py 0 ================================================ FILE: requirements.txt ================================================ argparse==1.4.0 numpy==1.20.1 batchgenerators==0.21 matplotlib==3.5.1 typing==3.7.4.3 sklearn==0.0 scikit-learn==1.0.2 tqdm==4.32.1 fvcore==0.1.5.post20220414 scikit-image==0.18.1 simpleitk==2.2.0 tifffile==2021.3.31 medpy==0.4.0 pandas==1.2.3 scipy==1.6.2 nibabel==4.0.1 timm==0.4.12 monai==0.7.0 einops==0.6.0 tensorboardX==2.2 ================================================ FILE: training_scripts/run_training_acdc.sh ================================================ #!/bin/sh DATASET_PATH=../DATASET_Acdc export PYTHONPATH=.././ export RESULTS_FOLDER=../output_acdc export 
unetr_pp_preprocessed="$DATASET_PATH"/unetr_pp_raw/unetr_pp_raw_data/Task01_ACDC export unetr_pp_raw_data_base="$DATASET_PATH"/unetr_pp_raw python ../unetr_pp/run/run_training.py 3d_fullres unetr_pp_trainer_acdc 1 0 ================================================ FILE: training_scripts/run_training_lung.sh ================================================ #!/bin/sh DATASET_PATH=../DATASET_Lungs export PYTHONPATH=.././ export RESULTS_FOLDER=../output_lung export unetr_pp_preprocessed="$DATASET_PATH"/unetr_pp_raw/unetr_pp_raw_data/Task06_Lung export unetr_pp_raw_data_base="$DATASET_PATH"/unetr_pp_raw python ../unetr_pp/run/run_training.py 3d_fullres unetr_pp_trainer_lung 6 0 ================================================ FILE: training_scripts/run_training_synapse.sh ================================================ #!/bin/sh DATASET_PATH=../DATASET_Synapse export PYTHONPATH=.././ export RESULTS_FOLDER=../output_synapse export unetr_pp_preprocessed="$DATASET_PATH"/unetr_pp_raw/unetr_pp_raw_data/Task02_Synapse export unetr_pp_raw_data_base="$DATASET_PATH"/unetr_pp_raw python ../unetr_pp/run/run_training.py 3d_fullres unetr_pp_trainer_synapse 2 0 ================================================ FILE: training_scripts/run_training_tumor.sh ================================================ #!/bin/sh DATASET_PATH=../DATASET_Tumor export PYTHONPATH=.././ export RESULTS_FOLDER=../output_tumor export unetr_pp_preprocessed="$DATASET_PATH"/unetr_pp_raw/unetr_pp_raw_data/Task03_tumor export unetr_pp_raw_data_base="$DATASET_PATH"/unetr_pp_raw python ../unetr_pp/run/run_training.py 3d_fullres unetr_pp_trainer_tumor 3 0 ================================================ FILE: unetr_pp/__init__.py ================================================ from __future__ import absolute_import from . 
import * ================================================ FILE: unetr_pp/configuration.py ================================================ import os default_num_threads = 8 if 'nnFormer_def_n_proc' not in os.environ else int(os.environ['nnFormer_def_n_proc']) RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD = 3 # determines what threshold to use for resampling the low resolution axis # separately (with NN) ================================================ FILE: unetr_pp/evaluation/__init__.py ================================================ from __future__ import absolute_import from . import * ================================================ FILE: unetr_pp/evaluation/add_dummy_task_with_mean_over_all_tasks.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import json import numpy as np from batchgenerators.utilities.file_and_folder_operations import subfiles import os from collections import OrderedDict folder = "/home/fabian/drives/E132-Projekte/Projects/2018_MedicalDecathlon/Leaderboard" task_descriptors = ['2D final 2', '2D final, less pool, dc and topK, fold0', '2D final pseudo3d 7, fold0', '2D final, less pool, dc and ce, fold0', '3D stage0 final 2, fold0', '3D fullres final 2, fold0'] task_ids_with_no_stage0 = ["Task001_BrainTumour", "Task004_Hippocampus", "Task005_Prostate"] mean_scores = OrderedDict() for t in task_descriptors: mean_scores[t] = OrderedDict() json_files = subfiles(folder, True, None, ".json", True) json_files = [i for i in json_files if not i.split("/")[-1].startswith(".")] # stupid mac for j in json_files: with open(j, 'r') as f: res = json.load(f) task = res['task'] if task != "Task999_ALL": name = res['name'] if name in task_descriptors: if task not in list(mean_scores[name].keys()): mean_scores[name][task] = res['results']['mean']['mean'] else: raise RuntimeError("duplicate task %s for description %s" % (task, name)) for t in task_ids_with_no_stage0: mean_scores["3D stage0 final 2, fold0"][t] = mean_scores["3D fullres final 2, fold0"][t] a = set() for i in mean_scores.keys(): a = a.union(list(mean_scores[i].keys())) for i in mean_scores.keys(): try: for t in list(a): assert t in mean_scores[i].keys(), "did not find task %s for experiment %s" % (t, i) new_res = OrderedDict() new_res['name'] = i new_res['author'] = "Fabian" new_res['task'] = "Task999_ALL" new_res['results'] = OrderedDict() new_res['results']['mean'] = OrderedDict() new_res['results']['mean']['mean'] = OrderedDict() tasks = list(mean_scores[i].keys()) metrics = mean_scores[i][tasks[0]].keys() for m in metrics: foreground_values = [mean_scores[i][n][m] for n in tasks] new_res['results']['mean']["mean"][m] = np.nanmean(foreground_values) output_fname = i.replace(" ", "_") + "_globalMean.json" with open(os.path.join(folder, 
output_fname), 'w') as f: json.dump(new_res, f) except AssertionError: print("could not process experiment %s" % i) print("did not find task %s for experiment %s" % (t, i)) ================================================ FILE: unetr_pp/evaluation/add_mean_dice_to_json.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import numpy as np from batchgenerators.utilities.file_and_folder_operations import subfiles from collections import OrderedDict def foreground_mean(filename): with open(filename, 'r') as f: res = json.load(f) class_ids = np.array([int(i) for i in res['results']['mean'].keys() if (i != 'mean')]) class_ids = class_ids[class_ids != 0] class_ids = class_ids[class_ids != -1] class_ids = class_ids[class_ids != 99] tmp = res['results']['mean'].get('99') if tmp is not None: _ = res['results']['mean'].pop('99') metrics = res['results']['mean']['1'].keys() res['results']['mean']["mean"] = OrderedDict() for m in metrics: foreground_values = [res['results']['mean'][str(i)][m] for i in class_ids] res['results']['mean']["mean"][m] = np.nanmean(foreground_values) with open(filename, 'w') as f: json.dump(res, f, indent=4, sort_keys=True) def run_in_folder(folder): json_files = subfiles(folder, True, None, ".json", True) json_files = [i for i in json_files if not i.split("/")[-1].startswith(".") and not 
i.endswith("_globalMean.json")] # stupid mac for j in json_files: foreground_mean(j) if __name__ == "__main__": folder = "/media/fabian/Results/nnFormerOutput_final/summary_jsons" run_in_folder(folder) ================================================ FILE: unetr_pp/evaluation/collect_results_files.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import shutil from batchgenerators.utilities.file_and_folder_operations import subdirs, subfiles def crawl_and_copy(current_folder, out_folder, prefix="fabian_", suffix="ummary.json"): """ This script will run recursively through all subfolders of current_folder and copy all files that end with suffix with some automatically generated prefix into out_folder :param current_folder: :param out_folder: :param prefix: :return: """ s = subdirs(current_folder, join=False) f = subfiles(current_folder, join=False) f = [i for i in f if i.endswith(suffix)] if current_folder.find("fold0") != -1: for fl in f: shutil.copy(os.path.join(current_folder, fl), os.path.join(out_folder, prefix+fl)) for su in s: if prefix == "": add = su else: add = "__" + su crawl_and_copy(os.path.join(current_folder, su), out_folder, prefix=prefix+add) if __name__ == "__main__": from unetr_pp.paths import network_training_output_dir output_folder = "/home/fabian/PhD/results/nnFormerV2/leaderboard" crawl_and_copy(network_training_output_dir, output_folder) from unetr_pp.evaluation.add_mean_dice_to_json import run_in_folder run_in_folder(output_folder) ================================================ FILE: unetr_pp/evaluation/evaluator.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import collections import inspect import json import hashlib from datetime import datetime from multiprocessing.pool import Pool import numpy as np import pandas as pd import SimpleITK as sitk from unetr_pp.evaluation.metrics import ConfusionMatrix, ALL_METRICS from batchgenerators.utilities.file_and_folder_operations import save_json, subfiles, join from collections import OrderedDict class Evaluator: """Object that holds test and reference segmentations with label information and computes a number of metrics on the two. 'labels' must either be an iterable of numeric values (or tuples thereof) or a dictionary with string names and numeric values. """ default_metrics = [ "False Positive Rate", "Dice", "Jaccard", "Precision", "Recall", "Accuracy", "False Omission Rate", "Negative Predictive Value", "False Negative Rate", "True Negative Rate", "False Discovery Rate", "Total Positives Test", "Total Positives Reference" ] default_advanced_metrics = [ #"Hausdorff Distance", "Hausdorff Distance 95", #"Avg. Surface Distance", #"Avg. 
Symmetric Surface Distance" ] def __init__(self, test=None, reference=None, labels=None, metrics=None, advanced_metrics=None, nan_for_nonexisting=True): self.test = None self.reference = None self.confusion_matrix = ConfusionMatrix() self.labels = None self.nan_for_nonexisting = nan_for_nonexisting self.result = None self.metrics = [] if metrics is None: for m in self.default_metrics: self.metrics.append(m) else: for m in metrics: self.metrics.append(m) self.advanced_metrics = [] if advanced_metrics is None: for m in self.default_advanced_metrics: self.advanced_metrics.append(m) else: for m in advanced_metrics: self.advanced_metrics.append(m) self.set_reference(reference) self.set_test(test) if labels is not None: self.set_labels(labels) else: if test is not None and reference is not None: self.construct_labels() def set_test(self, test): """Set the test segmentation.""" self.test = test def set_reference(self, reference): """Set the reference segmentation.""" self.reference = reference def set_labels(self, labels): """Set the labels. :param labels= may be a dictionary (int->str), a set (of ints), a tuple (of ints) or a list (of ints). 
Labels will only have names if you pass a dictionary""" if isinstance(labels, dict): self.labels = collections.OrderedDict(labels) elif isinstance(labels, set): self.labels = list(labels) elif isinstance(labels, np.ndarray): self.labels = [i for i in labels] elif isinstance(labels, (list, tuple)): self.labels = labels else: raise TypeError("Can only handle dict, list, tuple, set & numpy array, but input is of type {}".format(type(labels))) def construct_labels(self): """Construct label set from unique entries in segmentations.""" if self.test is None and self.reference is None: raise ValueError("No test or reference segmentations.") elif self.test is None: labels = np.unique(self.reference) else: labels = np.union1d(np.unique(self.test), np.unique(self.reference)) self.labels = list(map(lambda x: int(x), labels)) def set_metrics(self, metrics): """Set evaluation metrics""" if isinstance(metrics, set): self.metrics = list(metrics) elif isinstance(metrics, (list, tuple, np.ndarray)): self.metrics = metrics else: raise TypeError("Can only handle list, tuple, set & numpy array, but input is of type {}".format(type(metrics))) def add_metric(self, metric): if metric not in self.metrics: self.metrics.append(metric) def evaluate(self, test=None, reference=None, advanced=False, **metric_kwargs): """Compute metrics for segmentations.""" if test is not None: self.set_test(test) if reference is not None: self.set_reference(reference) if self.test is None or self.reference is None: raise ValueError("Need both test and reference segmentations.") if self.labels is None: self.construct_labels() self.metrics.sort() # get functions for evaluation # somewhat convoluted, but allows users to define additonal metrics # on the fly, e.g. 
inside an IPython console _funcs = {m: ALL_METRICS[m] for m in self.metrics + self.advanced_metrics} frames = inspect.getouterframes(inspect.currentframe()) for metric in self.metrics: for f in frames: if metric in f[0].f_locals: _funcs[metric] = f[0].f_locals[metric] break else: if metric in _funcs: continue else: raise NotImplementedError( "Metric {} not implemented.".format(metric)) # get results self.result = OrderedDict() eval_metrics = self.metrics if advanced: eval_metrics += self.advanced_metrics if isinstance(self.labels, dict): for label, name in self.labels.items(): k = str(name) self.result[k] = OrderedDict() if not hasattr(label, "__iter__"): self.confusion_matrix.set_test(self.test == label) self.confusion_matrix.set_reference(self.reference == label) else: current_test = 0 current_reference = 0 for l in label: current_test += (self.test == l) current_reference += (self.reference == l) self.confusion_matrix.set_test(current_test) self.confusion_matrix.set_reference(current_reference) for metric in eval_metrics: self.result[k][metric] = _funcs[metric](confusion_matrix=self.confusion_matrix, nan_for_nonexisting=self.nan_for_nonexisting, **metric_kwargs) else: for i, l in enumerate(self.labels): k = str(l) self.result[k] = OrderedDict() self.confusion_matrix.set_test(self.test == l) self.confusion_matrix.set_reference(self.reference == l) for metric in eval_metrics: self.result[k][metric] = _funcs[metric](confusion_matrix=self.confusion_matrix, nan_for_nonexisting=self.nan_for_nonexisting, **metric_kwargs) return self.result def to_dict(self): if self.result is None: self.evaluate() return self.result def to_array(self): """Return result as numpy array (labels x metrics).""" if self.result is None: self.evaluate result_metrics = sorted(self.result[list(self.result.keys())[0]].keys()) a = np.zeros((len(self.labels), len(result_metrics)), dtype=np.float32) if isinstance(self.labels, dict): for i, label in enumerate(self.labels.keys()): for j, metric in 
enumerate(result_metrics): a[i][j] = self.result[self.labels[label]][metric] else: for i, label in enumerate(self.labels): for j, metric in enumerate(result_metrics): a[i][j] = self.result[label][metric] return a def to_pandas(self): """Return result as pandas DataFrame.""" a = self.to_array() if isinstance(self.labels, dict): labels = list(self.labels.values()) else: labels = self.labels result_metrics = sorted(self.result[list(self.result.keys())[0]].keys()) return pd.DataFrame(a, index=labels, columns=result_metrics) class NiftiEvaluator(Evaluator): def __init__(self, *args, **kwargs): self.test_nifti = None self.reference_nifti = None super(NiftiEvaluator, self).__init__(*args, **kwargs) def set_test(self, test): """Set the test segmentation.""" if test is not None: self.test_nifti = sitk.ReadImage(test) super(NiftiEvaluator, self).set_test(sitk.GetArrayFromImage(self.test_nifti)) else: self.test_nifti = None super(NiftiEvaluator, self).set_test(test) def set_reference(self, reference): """Set the reference segmentation.""" if reference is not None: self.reference_nifti = sitk.ReadImage(reference) super(NiftiEvaluator, self).set_reference(sitk.GetArrayFromImage(self.reference_nifti)) else: self.reference_nifti = None super(NiftiEvaluator, self).set_reference(reference) def evaluate(self, test=None, reference=None, voxel_spacing=None, **metric_kwargs): if voxel_spacing is None: voxel_spacing = np.array(self.test_nifti.GetSpacing())[::-1] metric_kwargs["voxel_spacing"] = voxel_spacing return super(NiftiEvaluator, self).evaluate(test, reference, **metric_kwargs) def run_evaluation(args): test, ref, evaluator, metric_kwargs = args # evaluate evaluator.set_test(test) evaluator.set_reference(ref) if evaluator.labels is None: evaluator.construct_labels() current_scores = evaluator.evaluate(**metric_kwargs) if type(test) == str: current_scores["test"] = test if type(ref) == str: current_scores["reference"] = ref return current_scores def 
aggregate_scores(test_ref_pairs, evaluator=NiftiEvaluator, labels=None, nanmean=True, json_output_file=None, json_name="", json_description="", json_author="Fabian", json_task="", num_threads=2, **metric_kwargs): """ test = predicted image :param test_ref_pairs: :param evaluator: :param labels: must be a dict of int-> str or a list of int :param nanmean: :param json_output_file: :param json_name: :param json_description: :param json_author: :param json_task: :param metric_kwargs: :return: """ if type(evaluator) == type: evaluator = evaluator() if labels is not None: evaluator.set_labels(labels) all_scores = OrderedDict() all_scores["all"] = [] all_scores["mean"] = OrderedDict() test = [i[0] for i in test_ref_pairs] ref = [i[1] for i in test_ref_pairs] p = Pool(num_threads) all_res = p.map(run_evaluation, zip(test, ref, [evaluator]*len(ref), [metric_kwargs]*len(ref))) p.close() p.join() for i in range(len(all_res)): all_scores["all"].append(all_res[i]) # append score list for mean for label, score_dict in all_res[i].items(): if label in ("test", "reference"): continue if label not in all_scores["mean"]: all_scores["mean"][label] = OrderedDict() for score, value in score_dict.items(): if score not in all_scores["mean"][label]: all_scores["mean"][label][score] = [] all_scores["mean"][label][score].append(value) for label in all_scores["mean"]: for score in all_scores["mean"][label]: if nanmean: all_scores["mean"][label][score] = float(np.nanmean(all_scores["mean"][label][score])) else: all_scores["mean"][label][score] = float(np.mean(all_scores["mean"][label][score])) # save to file if desired # we create a hopefully unique id by hashing the entire output dictionary if json_output_file is not None: json_dict = OrderedDict() json_dict["name"] = json_name json_dict["description"] = json_description timestamp = datetime.today() json_dict["timestamp"] = str(timestamp) json_dict["task"] = json_task json_dict["author"] = json_author json_dict["results"] = all_scores 
json_dict["id"] = hashlib.md5(json.dumps(json_dict).encode("utf-8")).hexdigest()[:12] save_json(json_dict, json_output_file) return all_scores def aggregate_scores_for_experiment(score_file, labels=None, metrics=Evaluator.default_metrics, nanmean=True, json_output_file=None, json_name="", json_description="", json_author="Fabian", json_task=""): scores = np.load(score_file) scores_mean = scores.mean(0) if labels is None: labels = list(map(str, range(scores.shape[1]))) results = [] results_mean = OrderedDict() for i in range(scores.shape[0]): results.append(OrderedDict()) for l, label in enumerate(labels): results[-1][label] = OrderedDict() results_mean[label] = OrderedDict() for m, metric in enumerate(metrics): results[-1][label][metric] = float(scores[i][l][m]) results_mean[label][metric] = float(scores_mean[l][m]) json_dict = OrderedDict() json_dict["name"] = json_name json_dict["description"] = json_description timestamp = datetime.today() json_dict["timestamp"] = str(timestamp) json_dict["task"] = json_task json_dict["author"] = json_author json_dict["results"] = {"all": results, "mean": results_mean} json_dict["id"] = hashlib.md5(json.dumps(json_dict).encode("utf-8")).hexdigest()[:12] if json_output_file is not None: json_output_file = open(json_output_file, "w") json.dump(json_dict, json_output_file, indent=4, separators=(",", ": ")) json_output_file.close() return json_dict def evaluate_folder(folder_with_gts: str, folder_with_predictions: str, labels: tuple, **metric_kwargs): """ writes a summary.json to folder_with_predictions :param folder_with_gts: folder where the ground truth segmentations are saved. Must be nifti files. :param folder_with_predictions: folder where the predicted segmentations are saved. Must be nifti files. :param labels: tuple of int with the labels in the dataset. For example (0, 1, 2, 3) for Task001_BrainTumour. 
:return: """ files_gt = subfiles(folder_with_gts, suffix=".nii.gz", join=False) files_pred = subfiles(folder_with_predictions, suffix=".nii.gz", join=False) assert all([i in files_pred for i in files_gt]), "files missing in folder_with_predictions" assert all([i in files_gt for i in files_pred]), "files missing in folder_with_gts" test_ref_pairs = [(join(folder_with_predictions, i), join(folder_with_gts, i)) for i in files_pred] res = aggregate_scores(test_ref_pairs, json_output_file=join(folder_with_predictions, "summary.json"), num_threads=8, labels=labels, **metric_kwargs) return res def nnformer_evaluate_folder(): import argparse parser = argparse.ArgumentParser("Evaluates the segmentations located in the folder pred. Output of this script is " "a json file. At the very bottom of the json file is going to be a 'mean' " "entry with averages metrics across all cases") parser.add_argument('-ref', required=True, type=str, help="Folder containing the reference segmentations in nifti " "format.") parser.add_argument('-pred', required=True, type=str, help="Folder containing the predicted segmentations in nifti " "format. File names must match between the folders!") parser.add_argument('-l', nargs='+', type=int, required=True, help="List of label IDs (integer values) that should " "be evaluated. Best practice is to use all int " "values present in the dataset, so for example " "for LiTS the labels are 0: background, 1: " "liver, 2: tumor. So this argument " "should be -l 1 2. 
You can if you want also " "evaluate the background label (0) but in " "this case that would not gie any useful " "information.") args = parser.parse_args() return evaluate_folder(args.ref, args.pred, args.l) ================================================ FILE: unetr_pp/evaluation/metrics.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np from medpy import metric def assert_shape(test, reference): assert test.shape == reference.shape, "Shape mismatch: {} and {}".format( test.shape, reference.shape) class ConfusionMatrix: def __init__(self, test=None, reference=None): self.tp = None self.fp = None self.tn = None self.fn = None self.size = None self.reference_empty = None self.reference_full = None self.test_empty = None self.test_full = None self.set_reference(reference) self.set_test(test) def set_test(self, test): self.test = test self.reset() def set_reference(self, reference): self.reference = reference self.reset() def reset(self): self.tp = None self.fp = None self.tn = None self.fn = None self.size = None self.test_empty = None self.test_full = None self.reference_empty = None self.reference_full = None def compute(self): if self.test is None or self.reference is None: raise ValueError("'test' and 'reference' must both be set to compute confusion matrix.") assert_shape(self.test, self.reference) 
self.tp = int(((self.test != 0) * (self.reference != 0)).sum()) self.fp = int(((self.test != 0) * (self.reference == 0)).sum()) self.tn = int(((self.test == 0) * (self.reference == 0)).sum()) self.fn = int(((self.test == 0) * (self.reference != 0)).sum()) self.size = int(np.prod(self.reference.shape, dtype=np.int64)) self.test_empty = not np.any(self.test) self.test_full = np.all(self.test) self.reference_empty = not np.any(self.reference) self.reference_full = np.all(self.reference) def get_matrix(self): for entry in (self.tp, self.fp, self.tn, self.fn): if entry is None: self.compute() break return self.tp, self.fp, self.tn, self.fn def get_size(self): if self.size is None: self.compute() return self.size def get_existence(self): for case in (self.test_empty, self.test_full, self.reference_empty, self.reference_full): if case is None: self.compute() break return self.test_empty, self.test_full, self.reference_empty, self.reference_full def dice(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs): """2TP / (2TP + FP + FN)""" if confusion_matrix is None: confusion_matrix = ConfusionMatrix(test, reference) tp, fp, tn, fn = confusion_matrix.get_matrix() test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence() if test_empty and reference_empty: if nan_for_nonexisting: return float("NaN") else: return 0. return float(2. * tp / (2 * tp + fp + fn)) def jaccard(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs): """TP / (TP + FP + FN)""" if confusion_matrix is None: confusion_matrix = ConfusionMatrix(test, reference) tp, fp, tn, fn = confusion_matrix.get_matrix() test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence() if test_empty and reference_empty: if nan_for_nonexisting: return float("NaN") else: return 0. 
return float(tp / (tp + fp + fn)) def precision(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs): """TP / (TP + FP)""" if confusion_matrix is None: confusion_matrix = ConfusionMatrix(test, reference) tp, fp, tn, fn = confusion_matrix.get_matrix() test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence() if test_empty: if nan_for_nonexisting: return float("NaN") else: return 0. return float(tp / (tp + fp)) def sensitivity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs): """TP / (TP + FN)""" if confusion_matrix is None: confusion_matrix = ConfusionMatrix(test, reference) tp, fp, tn, fn = confusion_matrix.get_matrix() test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence() if reference_empty: if nan_for_nonexisting: return float("NaN") else: return 0. return float(tp / (tp + fn)) def recall(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs): """TP / (TP + FN)""" return sensitivity(test, reference, confusion_matrix, nan_for_nonexisting, **kwargs) def specificity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs): """TN / (TN + FP)""" if confusion_matrix is None: confusion_matrix = ConfusionMatrix(test, reference) tp, fp, tn, fn = confusion_matrix.get_matrix() test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence() if reference_full: if nan_for_nonexisting: return float("NaN") else: return 0. 
return float(tn / (tn + fp)) def accuracy(test=None, reference=None, confusion_matrix=None, **kwargs): """(TP + TN) / (TP + FP + FN + TN)""" if confusion_matrix is None: confusion_matrix = ConfusionMatrix(test, reference) tp, fp, tn, fn = confusion_matrix.get_matrix() return float((tp + tn) / (tp + fp + tn + fn)) def fscore(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, beta=1., **kwargs): """(1 + b^2) * TP / ((1 + b^2) * TP + b^2 * FN + FP)""" precision_ = precision(test, reference, confusion_matrix, nan_for_nonexisting) recall_ = recall(test, reference, confusion_matrix, nan_for_nonexisting) return (1 + beta*beta) * precision_ * recall_ /\ ((beta*beta * precision_) + recall_) def false_positive_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs): """FP / (FP + TN)""" return 1 - specificity(test, reference, confusion_matrix, nan_for_nonexisting) def false_omission_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs): """FN / (TN + FN)""" if confusion_matrix is None: confusion_matrix = ConfusionMatrix(test, reference) tp, fp, tn, fn = confusion_matrix.get_matrix() test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence() if test_full: if nan_for_nonexisting: return float("NaN") else: return 0. 
def false_negative_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FN / (TP + FN)"""

    # Complement of sensitivity (TP / (TP + FN)).
    return 1 - sensitivity(test, reference, confusion_matrix, nan_for_nonexisting)


def true_negative_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TN / (TN + FP)"""

    # Identical to specificity; kept as a separate name for metric-table lookups.
    return specificity(test, reference, confusion_matrix, nan_for_nonexisting)


def false_discovery_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FP / (TP + FP)"""

    # Complement of precision (TP / (TP + FP)).
    return 1 - precision(test, reference, confusion_matrix, nan_for_nonexisting)


def negative_predictive_value(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TN / (TN + FN)"""

    # Complement of the false omission rate (FN / (TN + FN)).
    return 1 - false_omission_rate(test, reference, confusion_matrix, nan_for_nonexisting)


def _confusion_counts(test, reference, confusion_matrix):
    """Return (tp, fp, tn, fn), building a ConfusionMatrix lazily if needed.

    Shared boilerplate for the four total_* counters below.
    """
    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)
    return confusion_matrix.get_matrix()


def total_positives_test(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TP + FP"""

    tp, fp, tn, fn = _confusion_counts(test, reference, confusion_matrix)
    return tp + fp


def total_negatives_test(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TN + FN"""

    tp, fp, tn, fn = _confusion_counts(test, reference, confusion_matrix)
    return tn + fn


def total_positives_reference(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TP + FN"""

    tp, fp, tn, fn = _confusion_counts(test, reference, confusion_matrix)
    return tp + fn


def total_negatives_reference(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TN + FP"""

    tp, fp, tn, fn = _confusion_counts(test, reference, confusion_matrix)
    return tn + fp


def _guarded_distance(distance_fn, test, reference, confusion_matrix, nan_for_nonexisting,
                      voxel_spacing, connectivity):
    """Run a medpy surface-distance metric with the shared existence guard.

    Surface distances are undefined when either segmentation is empty or
    covers the entire volume; in that case return NaN (or 0 when
    nan_for_nonexisting is False) instead of letting medpy raise.
    """
    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)

    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty or test_full or reference_empty or reference_full:
        if nan_for_nonexisting:
            return float("NaN")
        else:
            return 0

    test, reference = confusion_matrix.test, confusion_matrix.reference
    return distance_fn(test, reference, voxel_spacing, connectivity)


def hausdorff_distance(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True,
                       voxel_spacing=None, connectivity=1, **kwargs):
    """Symmetric Hausdorff distance between test and reference (medpy.metric.hd)."""

    return _guarded_distance(metric.hd, test, reference, confusion_matrix, nan_for_nonexisting,
                             voxel_spacing, connectivity)


def hausdorff_distance_95(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True,
                          voxel_spacing=None, connectivity=1, **kwargs):
    """95th percentile Hausdorff distance (medpy.metric.hd95), robust to outliers."""

    return _guarded_distance(metric.hd95, test, reference, confusion_matrix, nan_for_nonexisting,
                             voxel_spacing, connectivity)


def avg_surface_distance(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True,
                         voxel_spacing=None, connectivity=1, **kwargs):
    """Average (one-directional) surface distance (medpy.metric.asd)."""

    return _guarded_distance(metric.asd, test, reference, confusion_matrix, nan_for_nonexisting,
                             voxel_spacing, connectivity)


def avg_surface_distance_symmetric(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True,
                                   voxel_spacing=None, connectivity=1, **kwargs):
    """Average symmetric surface distance (medpy.metric.assd)."""

    return _guarded_distance(metric.assd, test, reference, confusion_matrix, nan_for_nonexisting,
                             voxel_spacing, connectivity)


# Registry mapping display names to metric callables. Keys are looked up by
# exact string elsewhere (evaluator), so they must not be renamed — including
# the historical lowercase "total" in the last entry.
ALL_METRICS = {
    "False Positive Rate": false_positive_rate,
    "Dice": dice,
    "Jaccard": jaccard,
    "Hausdorff Distance": hausdorff_distance,
    "Hausdorff Distance 95": hausdorff_distance_95,
    "Precision": precision,
    "Recall": recall,
    "Avg. Symmetric Surface Distance": avg_surface_distance_symmetric,
    "Avg. Surface Distance": avg_surface_distance,
    "Accuracy": accuracy,
    "False Omission Rate": false_omission_rate,
    "Negative Predictive Value": negative_predictive_value,
    "False Negative Rate": false_negative_rate,
    "True Negative Rate": true_negative_rate,
    "False Discovery Rate": false_discovery_rate,
    "Total Positives Test": total_positives_test,
    "Total Negatives Test": total_negatives_test,
    "Total Positives Reference": total_positives_reference,
    "total Negatives Reference": total_negatives_reference
}
# See the License for the specific language governing permissions and
# limitations under the License.


from unetr_pp.evaluation.model_selection.summarize_results_in_one_json import summarize2
from unetr_pp.paths import network_training_output_dir
from batchgenerators.utilities.file_and_folder_operations import *

if __name__ == "__main__":
    # Pass 1: summarize fold-0-only results into one json per experiment, then
    # flatten them into a csv of mean/median foreground Dice per experiment.
    summary_output_folder = join(network_training_output_dir, "summary_jsons_fold0_new")
    maybe_mkdir_p(summary_output_folder)

    summarize2(['all'], output_dir=summary_output_folder, folds=(0,))

    results_csv = join(network_training_output_dir, "summary_fold0.csv")
    summary_files = subfiles(summary_output_folder, suffix='.json', join=False)
    with open(results_csv, 'w') as f:
        for s in summary_files:
            # Non-ensemble summaries are named TASK__NETWORK__TRAINER__PLANS__VALFOLDER__FOLDS.json
            if s.find("ensemble") == -1:
                task, network, trainer, plans, validation_folder, folds = s.split("__")
            else:
                # Ensemble summaries are "...ensemble_<model1>--<model2>...": split on "--"
                # and treat the two halves as trainer/plans stand-ins.
                n1, n2 = s.split("--")
                n1 = n1[n1.find("ensemble_") + len("ensemble_"):]
                task = s.split("__")[0]
                network = "ensemble"
                trainer = n1
                plans = n2
                validation_folder = "none"
            # Strip the trailing ".json" from the folds token.
            # NOTE(review): in the ensemble branch above, `folds` is never assigned from
            # the filename, so this reuses the value from the previous loop iteration
            # (NameError if the first file is an ensemble) — confirm intended upstream.
            folds = folds[:-len('.json')]
            results = load_json(join(summary_output_folder, s))
            results_mean = results['results']['mean']['mean']['Dice']
            results_median = results['results']['median']['mean']['Dice']
            f.write("%s,%s,%s,%s,%s,%02.4f,%02.4f\n" % (task, network, trainer, validation_folder, plans,
                                                        results_mean, results_median))

    # Pass 2: same procedure, but summarizing over all folds instead of fold 0 only.
    summary_output_folder = join(network_training_output_dir, "summary_jsons_new")
    maybe_mkdir_p(summary_output_folder)

    summarize2(['all'], output_dir=summary_output_folder)

    results_csv = join(network_training_output_dir, "summary_allFolds.csv")
    summary_files = subfiles(summary_output_folder, suffix='.json', join=False)
    with open(results_csv, 'w') as f:
        for s in summary_files:
            if s.find("ensemble") == -1:
                task, network, trainer, plans, validation_folder, folds = s.split("__")
            else:
                n1, n2 = s.split("--")
                n1 = n1[n1.find("ensemble_") + len("ensemble_"):]
                task = s.split("__")[0]
                network = "ensemble"
                trainer = n1
                plans = n2
                validation_folder = "none"
            # Same stale-`folds` caveat as in pass 1 above.
            folds = folds[:-len('.json')]
            results = load_json(join(summary_output_folder, s))
            results_mean = results['results']['mean']['mean']['Dice']
            results_median = results['results']['median']['mean']['Dice']
            f.write("%s,%s,%s,%s,%s,%02.4f,%02.4f\n" % (task, network, trainer, validation_folder, plans,
                                                        results_mean, results_median))
import shutil
from multiprocessing.pool import Pool

import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.configuration import default_num_threads
from unetr_pp.evaluation.evaluator import aggregate_scores
from unetr_pp.inference.segmentation_export import save_segmentation_nifti_from_softmax
from unetr_pp.paths import network_training_output_dir, preprocessing_output_dir
from unetr_pp.postprocessing.connected_components import determine_postprocessing


def merge(args):
    """Average two saved softmax predictions and export the result as a nifti.

    args is a 4-tuple (file1, file2, properties_file, out_file); a tuple (not
    separate params) so this can be used with Pool.map. Skips work if out_file
    already exists (makes reruns cheap/resumable).
    """
    file1, file2, properties_file, out_file = args
    if not isfile(out_file):
        res1 = np.load(file1)['softmax']
        res2 = np.load(file2)['softmax']
        props = load_pickle(properties_file)
        # Element-wise mean of the two models' softmax outputs.
        mn = np.mean((res1, res2), 0)
        # Softmax probabilities are already at target spacing so this will not do any resampling (resampling parameters
        # don't matter here)
        save_segmentation_nifti_from_softmax(mn, out_file, props, 3, None, None, None, force_separate_z=None,
                                             interpolation_order_z=0)


def ensemble(training_output_folder1, training_output_folder2, output_folder, task, validation_folder, folds,
             allow_ensembling: bool = True):
    """Ensemble the cross-validation predictions of two trained configurations.

    For every requested fold, pairs up the .npz softmax files from both
    trainings, averages them (in parallel via merge()), writes the resulting
    niftis to <output_folder>/ensembled_raw, evaluates them against the ground
    truth, and (if allow_ensembling) determines and applies postprocessing.

    Raises AssertionError if either training is missing validation outputs or
    the .npz files needed for ensembling.
    """
    print("\nEnsembling folders\n", training_output_folder1, "\n", training_output_folder2)

    output_folder_base = output_folder
    output_folder = join(output_folder_base, "ensembled_raw")

    # only_keep_largest_connected_component is the same for all stages
    dataset_directory = join(preprocessing_output_dir, task)
    plans = load_pickle(join(training_output_folder1, "plans.pkl"))  # we need this only for the labels

    files1 = []
    files2 = []
    property_files = []
    out_files = []
    gt_segmentations = []

    folder_with_gt_segs = join(dataset_directory, "gt_segmentations")
    # in the correct shape and we need the original geometry to restore the niftis

    for f in folds:
        validation_folder_net1 = join(training_output_folder1, "fold_%d" % f, validation_folder)
        validation_folder_net2 = join(training_output_folder2, "fold_%d" % f, validation_folder)

        # Both trainings must have produced validation outputs for this fold.
        if not isdir(validation_folder_net1):
            raise AssertionError("Validation directory missing: %s. Please rerun validation with `nnFormer_train CONFIG TRAINER TASK FOLD -val --npz`" % validation_folder_net1)
        if not isdir(validation_folder_net2):
            raise AssertionError("Validation directory missing: %s. Please rerun validation with `nnFormer_train CONFIG TRAINER TASK FOLD -val --npz`" % validation_folder_net2)

        # we need to ensure the validation was successful. We can verify this via the presence of the summary.json file
        if not isfile(join(validation_folder_net1, 'summary.json')):
            raise AssertionError("Validation directory incomplete: %s. Please rerun validation with `nnFormer_train CONFIG TRAINER TASK FOLD -val --npz`" % validation_folder_net1)
        if not isfile(join(validation_folder_net2, 'summary.json')):
            raise AssertionError("Validation directory missing: %s. Please rerun validation with `nnFormer_train CONFIG TRAINER TASK FOLD -val --npz`" % validation_folder_net2)

        # Case identifiers from the saved softmax files (strip ".npz").
        patient_identifiers1_npz = [i[:-4] for i in subfiles(validation_folder_net1, False, None, 'npz', True)]
        patient_identifiers2_npz = [i[:-4] for i in subfiles(validation_folder_net2, False, None, 'npz', True)]

        # we don't do postprocessing anymore so there should not be any of that noPostProcess
        # Case identifiers from the exported niftis (strip ".nii.gz").
        patient_identifiers1_nii = [i[:-7] for i in
                                    subfiles(validation_folder_net1, False, None, suffix='nii.gz', sort=True)
                                    if not i.endswith("noPostProcess.nii.gz") and not i.endswith('_postprocessed.nii.gz')]
        patient_identifiers2_nii = [i[:-7] for i in
                                    subfiles(validation_folder_net2, False, None, suffix='nii.gz', sort=True)
                                    if not i.endswith("noPostProcess.nii.gz") and not i.endswith('_postprocessed.nii.gz')]

        # Every exported nifti must have a matching softmax .npz, otherwise the
        # validation was run without --npz and we cannot ensemble.
        if not all([i in patient_identifiers1_npz for i in patient_identifiers1_nii]):
            raise AssertionError("Missing npz files in folder %s. Please run the validation for all models and folds with the '--npz' flag." % (validation_folder_net1))
        if not all([i in patient_identifiers2_npz for i in patient_identifiers2_nii]):
            raise AssertionError("Missing npz files in folder %s. Please run the validation for all models and folds with the '--npz' flag." % (validation_folder_net2))

        patient_identifiers1_npz.sort()
        patient_identifiers2_npz.sort()

        # Both folds must contain exactly the same validation cases.
        assert all([i == j for i, j in zip(patient_identifiers1_npz, patient_identifiers2_npz)]), "npz filenames do not match. This should not happen."

        maybe_mkdir_p(output_folder)

        for p in patient_identifiers1_npz:
            files1.append(join(validation_folder_net1, p + '.npz'))
            files2.append(join(validation_folder_net2, p + '.npz'))
            # Geometry/properties come from net1's validation folder.
            property_files.append(join(validation_folder_net1, p) + ".pkl")
            out_files.append(join(output_folder, p + ".nii.gz"))
            gt_segmentations.append(join(folder_with_gt_segs, p + ".nii.gz"))

    # Merge all case pairs in parallel.
    p = Pool(default_num_threads)
    p.map(merge, zip(files1, files2, property_files, out_files))
    p.close()
    p.join()

    # Evaluate the ensembled raw predictions once (skipped on reruns).
    if not isfile(join(output_folder, "summary.json")) and len(out_files) > 0:
        aggregate_scores(tuple(zip(out_files, gt_segmentations)), labels=plans['all_classes'],
                         json_output_file=join(output_folder, "summary.json"), json_task=task,
                         json_name=task + "__" + output_folder_base.split("/")[-1],
                         num_threads=default_num_threads)

    if allow_ensembling and not isfile(join(output_folder_base, "postprocessing.json")):
        # now lets also look at postprocessing. We cannot just take what we determined in cross-validation and apply it
        # here because things may have changed and may also be too inconsistent between the two networks
        determine_postprocessing(output_folder_base, folder_with_gt_segs, "ensembled_raw", "temp",
                                 "ensembled_postprocessed", default_num_threads, dice_threshold=0)

        out_dir_all_json = join(network_training_output_dir, "summary_jsons")
        json_out = load_json(join(output_folder_base, "ensembled_postprocessed", "summary.json"))

        # Tag the summary with the ensemble's folder name before republishing it.
        json_out["experiment_name"] = output_folder_base.split("/")[-1]
        save_json(json_out, join(output_folder_base, "ensembled_postprocessed", "summary.json"))

        maybe_mkdir_p(out_dir_all_json)
        shutil.copy(join(output_folder_base, "ensembled_postprocessed", "summary.json"),
                    join(out_dir_all_json, "%s__%s.json" % (task, output_folder_base.split("/")[-1])))
import shutil
from itertools import combinations

import unetr_pp
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.evaluation.add_mean_dice_to_json import foreground_mean
from unetr_pp.evaluation.evaluator import evaluate_folder
from unetr_pp.evaluation.model_selection.ensemble import ensemble
from unetr_pp.paths import network_training_output_dir
import numpy as np
from subprocess import call
from unetr_pp.postprocessing.consolidate_postprocessing import consolidate_folds, collect_cv_niftis
from unetr_pp.utilities.folder_names import get_output_folder_name
from unetr_pp.paths import default_cascade_trainer, default_trainer, default_plans_identifier


def find_task_name(folder, task_id):
    """Resolve a numeric task id to its unique "TaskXXX_Name" subfolder of `folder`.

    Raises AssertionError if zero or more than one matching folder exists.
    """
    candidates = subdirs(folder, prefix="Task%03.0d_" % task_id, join=False)
    assert len(candidates) > 0, "no candidate for Task id %d found in folder %s" % (task_id, folder)
    assert len(candidates) == 1, "more than one candidate for Task id %d found in folder %s" % (task_id, folder)
    return candidates[0]


def get_mean_foreground_dice(json_file):
    """Load a summary.json and return its mean foreground Dice."""
    results = load_json(json_file)
    return get_foreground_mean(results)


def get_foreground_mean(results):
    """Mean Dice over all foreground classes (skips key "0" = background and "mean")."""
    results_mean = results['results']['mean']
    dice_scores = [results_mean[i]['Dice'] for i in results_mean.keys() if
                   i != "0" and i != 'mean']
    return np.mean(dice_scores)


def main():
    """Identify the best single model or ensemble per task from 5-fold CV results.

    For each requested task: evaluates every requested configuration on its
    cross-validation predictions (optionally running postprocessing
    consolidation), ensembles all pairs of valid configurations (unless
    disabled), picks the configuration with the highest mean foreground Dice,
    prints the prediction commands to reproduce it, and writes
    prediction_commands.txt plus a per-class Dice summary.csv.
    """
    import argparse
    parser = argparse.ArgumentParser(usage="This is intended to identify the best model based on the five fold "
                                           "cross-validation. Running this script requires all models to have been run "
                                           "already. This script will summarize the results of the five folds of all "
                                           "models in one json each for easy interpretability")

    parser.add_argument("-m", '--models', nargs="+", required=False,
                        default=['2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'])
    parser.add_argument("-t", '--task_ids', nargs="+", required=True)

    parser.add_argument("-tr", type=str, required=False, default=default_trainer,
                        help="nnFormerTrainer class. Default: %s" % default_trainer)
    parser.add_argument("-ctr", type=str, required=False, default=default_cascade_trainer,
                        help="nnFormerTrainer class for cascade model. Default: %s" % default_cascade_trainer)
    parser.add_argument("-pl", type=str, required=False, default=default_plans_identifier,
                        help="plans name, Default: %s" % default_plans_identifier)
    parser.add_argument('-f', '--folds', nargs='+', default=(0, 1, 2, 3, 4),
                        help="Use this if you have non-standard folds. Experienced users only.")
    parser.add_argument('--disable_ensembling', required=False, default=False, action='store_true',
                        help='Set this flag to disable the use of ensembling. This will find the best single '
                             'configuration for each task.')
    parser.add_argument("--disable_postprocessing", required=False, default=False, action="store_true",
                        help="Set this flag if you want to disable the use of postprocessing")

    args = parser.parse_args()
    tasks = [int(i) for i in args.task_ids]

    models = args.models
    tr = args.tr
    trc = args.ctr
    pl = args.pl
    disable_ensembling = args.disable_ensembling
    disable_postprocessing = args.disable_postprocessing
    folds = tuple(int(i) for i in args.folds)

    validation_folder = "validation_raw"

    # this script now acts independently from the summary jsons. That was unnecessary
    id_task_mapping = {}  # cache: numeric task id -> "TaskXXX_Name"

    for t in tasks:
        # first collect pure model performance
        results = {}        # model/ensemble name -> mean foreground Dice
        all_results = {}    # model/ensemble name -> per-class mean results
        valid_models = []
        for m in models:
            # Cascade configurations use the cascade trainer class.
            if m == "3d_cascade_fullres":
                trainer = trc
            else:
                trainer = tr

            if t not in id_task_mapping.keys():
                task_name = find_task_name(get_output_folder_name(m), t)
                id_task_mapping[t] = task_name

            output_folder = get_output_folder_name(m, id_task_mapping[t], trainer, pl)
            if not isdir(output_folder):
                raise RuntimeError("Output folder for model %s is missing, expected: %s" % (m, output_folder))

            if disable_postprocessing:
                # we need to collect the predicted niftis from the 5-fold cv and evaluate them against the ground truth
                cv_niftis_folder = join(output_folder, 'cv_niftis_raw')

                if not isfile(join(cv_niftis_folder, 'summary.json')):
                    print(t, m, ': collecting niftis from 5-fold cv')
                    # Start from a clean folder to avoid mixing stale predictions.
                    if isdir(cv_niftis_folder):
                        shutil.rmtree(cv_niftis_folder)

                    collect_cv_niftis(output_folder, cv_niftis_folder, validation_folder, folds)

                    niftis_gt = subfiles(join(output_folder, "gt_niftis"), suffix='.nii.gz', join=False)
                    niftis_cv = subfiles(cv_niftis_folder, suffix='.nii.gz', join=False)
                    if not all([i in niftis_gt for i in niftis_cv]):
                        raise AssertionError("It does not seem like you trained all the folds! Train "
                                             "all folds first! There are %d gt niftis in %s but only "
                                             "%d predicted niftis in %s" % (len(niftis_gt), niftis_gt,
                                                                            len(niftis_cv), niftis_cv))

                    # load a summary file so that we can know what class labels to expect
                    summary_fold0 = load_json(join(output_folder, "fold_%d" % folds[0], validation_folder,
                                                   "summary.json"))['results']['mean']
                    # read classes from summary.json
                    classes = tuple((int(i) for i in summary_fold0.keys()))

                    # evaluate the cv niftis
                    print(t, m, ': evaluating 5-fold cv results')
                    evaluate_folder(join(output_folder, "gt_niftis"), cv_niftis_folder, classes)
            else:
                postprocessing_json = join(output_folder, "postprocessing.json")
                cv_niftis_folder = join(output_folder, "cv_niftis_raw")

                # we need cv_niftis_postprocessed to know the single model performance. And we need the
                # postprocessing_json. If either of those is missing, rerun consolidate_folds
                if not isfile(postprocessing_json) or not isdir(cv_niftis_folder):
                    print("running missing postprocessing for %s and model %s" % (id_task_mapping[t], m))
                    consolidate_folds(output_folder, folds=folds)

                assert isfile(postprocessing_json), "Postprocessing json missing, expected: %s" % postprocessing_json
                assert isdir(cv_niftis_folder), "Folder with niftis from CV missing, expected: %s" % cv_niftis_folder

            # obtain mean foreground dice
            summary_file = join(cv_niftis_folder, "summary.json")
            results[m] = get_mean_foreground_dice(summary_file)
            foreground_mean(summary_file)
            all_results[m] = load_json(summary_file)['results']['mean']
            valid_models.append(m)

        if not disable_ensembling:
            # now run ensembling and add ensembling to results
            print("\nI will now ensemble combinations of the following models:\n", valid_models)
            if len(valid_models) > 1:
                # Ensemble every unordered pair of valid configurations.
                for m1, m2 in combinations(valid_models, 2):
                    trainer_m1 = trc if m1 == "3d_cascade_fullres" else tr
                    trainer_m2 = trc if m2 == "3d_cascade_fullres" else tr

                    ensemble_name = "ensemble_" + m1 + "__" + trainer_m1 + "__" + pl + "--" + m2 + "__" + trainer_m2 + "__" + pl
                    output_folder_base = join(network_training_output_dir, "ensembles", id_task_mapping[t], ensemble_name)
                    maybe_mkdir_p(output_folder_base)

                    network1_folder = get_output_folder_name(m1, id_task_mapping[t], trainer_m1, pl)
                    network2_folder = get_output_folder_name(m2, id_task_mapping[t], trainer_m2, pl)

                    print("ensembling", network1_folder, network2_folder)
                    ensemble(network1_folder, network2_folder, output_folder_base, id_task_mapping[t],
                             validation_folder, folds, allow_ensembling=not disable_postprocessing)
                    # ensembling will automatically do postprocessing

                    # now get result of ensemble
                    results[ensemble_name] = get_mean_foreground_dice(
                        join(output_folder_base, "ensembled_raw", "summary.json"))
                    summary_file = join(output_folder_base, "ensembled_raw", "summary.json")
                    foreground_mean(summary_file)
                    all_results[ensemble_name] = load_json(summary_file)['results']['mean']

        # now print all mean foreground dice and highlight the best
        foreground_dices = list(results.values())
        best = np.max(foreground_dices)
        for k, v in results.items():
            print(k, v)

        predict_str = ""
        best_model = None
        for k, v in results.items():
            if v == best:
                print("%s submit model %s" % (id_task_mapping[t], k), v)
                best_model = k
                print("\nHere is how you should predict test cases. Run in sequential order and replace all input and output folder names with your personalized ones\n")
                if k.startswith("ensemble"):
                    # Reconstruct the two member configurations from the ensemble name.
                    tmp = k[len("ensemble_"):]
                    model1, model2 = tmp.split("--")
                    m1, t1, pl1 = model1.split("__")
                    m2, t2, pl2 = model2.split("__")
                    predict_str += "nnFormer_predict -i FOLDER_WITH_TEST_CASES -o OUTPUT_FOLDER_MODEL1 -tr " + tr + \
                                   " -ctr " + trc + " -m " + m1 + " -p " + pl + " -t " + \
                                   id_task_mapping[t] + "\n"
                    predict_str += "nnFormer_predict -i FOLDER_WITH_TEST_CASES -o OUTPUT_FOLDER_MODEL2 -tr " + tr + \
                                   " -ctr " + trc + " -m " + m2 + " -p " + pl + " -t " + \
                                   id_task_mapping[t] + "\n"
                    if not disable_postprocessing:
                        predict_str += "nnFormer_ensemble -f OUTPUT_FOLDER_MODEL1 OUTPUT_FOLDER_MODEL2 -o OUTPUT_FOLDER -pp " + \
                                       join(network_training_output_dir, "ensembles", id_task_mapping[t], k,
                                            "postprocessing.json") + "\n"
                    else:
                        predict_str += "nnFormer_ensemble -f OUTPUT_FOLDER_MODEL1 OUTPUT_FOLDER_MODEL2 -o OUTPUT_FOLDER\n"
                else:
                    predict_str += "nnFormer_predict -i FOLDER_WITH_TEST_CASES -o OUTPUT_FOLDER_MODEL1 -tr " + tr + \
                                   " -ctr " + trc + " -m " + k + " -p " + pl + " -t " + \
                                   id_task_mapping[t] + "\n"
                print(predict_str)

        summary_folder = join(network_training_output_dir, "ensembles", id_task_mapping[t])
        maybe_mkdir_p(summary_folder)
        with open(join(summary_folder, "prediction_commands.txt"), 'w') as f:
            f.write(predict_str)

        # Write per-class Dice of every candidate to summary.csv. Class count is
        # taken from the best model's result keys (all but 'mean' and background '0').
        num_classes = len([i for i in all_results[best_model].keys() if i != 'mean' and i != '0'])
        with open(join(summary_folder, "summary.csv"), 'w') as f:
            f.write("model")
            for c in range(1, num_classes + 1):
                f.write(",class%d" % c)
            f.write(",average")
            f.write("\n")
            for m in all_results.keys():
                f.write(m)
                for c in range(1, num_classes + 1):
                    f.write(",%01.4f" % all_results[m][str(c)]["Dice"])
                f.write(",%01.4f" % all_results[m]['mean']["Dice"])
                f.write("\n")


if __name__ == "__main__":
    main()
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.paths import network_training_output_dir

if __name__ == "__main__":
    # run collect_all_fold0_results_and_summarize_in_one_csv.py first
    # Ranks all trainer/plans candidates by their mean rank across datasets
    # (fold 0 only) and writes a per-dataset Dice table to summary.csv.
    summary_files_dir = join(network_training_output_dir, "summary_jsons_fold0_new")
    output_file = join(network_training_output_dir, "summary.csv")

    folds = (0, )
    folds_str = ""
    for f in folds:
        folds_str += str(f)

    plans = "nnFormerPlans"

    # Trainers that were evaluated with plans other than the default; maps
    # trainer name -> list of plans identifiers to look up for it.
    overwrite_plans = {
        'nnFormerTrainerV2_2': ["nnFormerPlans", "nnFormerPlansisoPatchesInVoxels"],  # r
        'nnFormerTrainerV2': ["nnFormerPlansnonCT", "nnFormerPlansCT2", "nnFormerPlansallConv3x3",
                              "nnFormerPlansfixedisoPatchesInVoxels", "nnFormerPlanstargetSpacingForAnisoAxis",
                              "nnFormerPlanspoolBasedOnSpacing", "nnFormerPlansfixedisoPatchesInmm", "nnFormerPlansv2.1"],
        'nnFormerTrainerV2_warmup': ["nnFormerPlans", "nnFormerPlansv2.1", "nnFormerPlansv2.1_big", "nnFormerPlansv2.1_verybig"],
        'nnFormerTrainerV2_cycleAtEnd': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_cycleAtEnd2': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_reduceMomentumDuringTraining': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_graduallyTransitionFromCEToDice': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_independentScalePerAxis': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_Mish': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_Ranger_lr3en4': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_fp32': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_GN': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_momentum098': ["nnFormerPlans", "nnFormerPlansv2.1"],
        'nnFormerTrainerV2_momentum09': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_DP': ["nnFormerPlansv2.1_verybig"],
        'nnFormerTrainerV2_DDP': ["nnFormerPlansv2.1_verybig"],
        'nnFormerTrainerV2_FRN': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_resample33': ["nnFormerPlansv2.3"],
        'nnFormerTrainerV2_O2': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_ResencUNet': ["nnFormerPlans_FabiansResUNet_v2.1"],
        'nnFormerTrainerV2_DA2': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_allConv3x3': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_ForceBD': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_ForceSD': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_LReLU_slope_2en1': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_lReLU_convReLUIN': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_ReLU': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_ReLU_biasInSegOutput': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_ReLU_convReLUIN': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_lReLU_biasInSegOutput': ["nnFormerPlansv2.1"],
        #'nnFormerTrainerV2_Loss_MCC': ["nnFormerPlansv2.1"],
        #'nnFormerTrainerV2_Loss_MCCnoBG': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_Loss_DicewithBG': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_Loss_Dice_LR1en3': ["nnFormerPlansv2.1"],
        'nnFormerTrainerV2_Loss_Dice': ["nnFormerPlans", "nnFormerPlansv2.1"],
        'nnFormerTrainerV2_Loss_DicewithBG_LR1en3': ["nnFormerPlansv2.1"],
        # 'nnFormerTrainerV2_fp32': ["nnFormerPlansv2.1"],
        # 'nnFormerTrainerV2_fp32': ["nnFormerPlansv2.1"],
        # 'nnFormerTrainerV2_fp32': ["nnFormerPlansv2.1"],
        # 'nnFormerTrainerV2_fp32': ["nnFormerPlansv2.1"],
        # 'nnFormerTrainerV2_fp32': ["nnFormerPlansv2.1"],
    }

    # All candidate trainer classes to rank.
    trainers = ['nnFormerTrainer'] + ['nnFormerTrainerNewCandidate%d' % i for i in range(1, 28)] + [
        'nnFormerTrainerNewCandidate24_2',
        'nnFormerTrainerNewCandidate24_3',
        'nnFormerTrainerNewCandidate26_2',
        'nnFormerTrainerNewCandidate27_2',
        'nnFormerTrainerNewCandidate23_always3DDA',
        'nnFormerTrainerNewCandidate23_corrInit',
        'nnFormerTrainerNewCandidate23_noOversampling',
        'nnFormerTrainerNewCandidate23_softDS',
        'nnFormerTrainerNewCandidate23_softDS2',
        'nnFormerTrainerNewCandidate23_softDS3',
        'nnFormerTrainerNewCandidate23_softDS4',
        'nnFormerTrainerNewCandidate23_2_fp16',
        'nnFormerTrainerNewCandidate23_2',
        'nnFormerTrainerVer2',
        'nnFormerTrainerV2_2',
        'nnFormerTrainerV2_3',
        'nnFormerTrainerV2_3_CE_GDL',
        'nnFormerTrainerV2_3_dcTopk10',
        'nnFormerTrainerV2_3_dcTopk20',
        'nnFormerTrainerV2_3_fp16',
        'nnFormerTrainerV2_3_softDS4',
        'nnFormerTrainerV2_3_softDS4_clean',
        'nnFormerTrainerV2_3_softDS4_clean_improvedDA',
        'nnFormerTrainerV2_3_softDS4_clean_improvedDA_newElDef',
        'nnFormerTrainerV2_3_softDS4_radam',
        'nnFormerTrainerV2_3_softDS4_radam_lowerLR',

        'nnFormerTrainerV2_2_schedule',
        'nnFormerTrainerV2_2_schedule2',
        'nnFormerTrainerV2_2_clean',
        'nnFormerTrainerV2_2_clean_improvedDA_newElDef',

        'nnFormerTrainerV2_2_fixes',  # running
        'nnFormerTrainerV2_BN',  # running
        'nnFormerTrainerV2_noDeepSupervision',  # running
        'nnFormerTrainerV2_softDeepSupervision',  # running
        'nnFormerTrainerV2_noDataAugmentation',  # running
        'nnFormerTrainerV2_Loss_CE',  # running
        'nnFormerTrainerV2_Loss_CEGDL',
        'nnFormerTrainerV2_Loss_Dice',
        'nnFormerTrainerV2_Loss_DiceTopK10',
        'nnFormerTrainerV2_Loss_TopK10',
        'nnFormerTrainerV2_Adam',  # running
        'nnFormerTrainerV2_Adam_nnFormerTrainerlr',  # running
        'nnFormerTrainerV2_SGD_ReduceOnPlateau',  # running
        'nnFormerTrainerV2_SGD_lr1en1',  # running
        'nnFormerTrainerV2_SGD_lr1en3',  # running
        'nnFormerTrainerV2_fixedNonlin',  # running
        'nnFormerTrainerV2_GeLU',  # running
        'nnFormerTrainerV2_3ConvPerStage',
        'nnFormerTrainerV2_NoNormalization',
        'nnFormerTrainerV2_Adam_ReduceOnPlateau',
        'nnFormerTrainerV2_fp16',
        'nnFormerTrainerV2',  # see overwrite_plans
        'nnFormerTrainerV2_noMirroring',
        'nnFormerTrainerV2_momentum09',
        'nnFormerTrainerV2_momentum095',
        'nnFormerTrainerV2_momentum098',
        'nnFormerTrainerV2_warmup',
        'nnFormerTrainerV2_Loss_Dice_LR1en3',
        'nnFormerTrainerV2_NoNormalization_lr1en3',
        'nnFormerTrainerV2_Loss_Dice_squared',
        'nnFormerTrainerV2_newElDef',
        'nnFormerTrainerV2_fp32',
        'nnFormerTrainerV2_cycleAtEnd',
        'nnFormerTrainerV2_reduceMomentumDuringTraining',
        'nnFormerTrainerV2_graduallyTransitionFromCEToDice',
        'nnFormerTrainerV2_insaneDA',
        'nnFormerTrainerV2_independentScalePerAxis',
        'nnFormerTrainerV2_Mish',
        'nnFormerTrainerV2_Ranger_lr3en4',
        'nnFormerTrainerV2_cycleAtEnd2',
        'nnFormerTrainerV2_GN',
        'nnFormerTrainerV2_DP',
        'nnFormerTrainerV2_FRN',
        'nnFormerTrainerV2_resample33',
        'nnFormerTrainerV2_O2',
        'nnFormerTrainerV2_ResencUNet',
        'nnFormerTrainerV2_DA2',
        'nnFormerTrainerV2_allConv3x3',
        'nnFormerTrainerV2_ForceBD',
        'nnFormerTrainerV2_ForceSD',
        'nnFormerTrainerV2_ReLU',
        'nnFormerTrainerV2_LReLU_slope_2en1',
        'nnFormerTrainerV2_lReLU_convReLUIN',
        'nnFormerTrainerV2_ReLU_biasInSegOutput',
        'nnFormerTrainerV2_ReLU_convReLUIN',
        'nnFormerTrainerV2_lReLU_biasInSegOutput',
        'nnFormerTrainerV2_Loss_DicewithBG_LR1en3',
        #'nnFormerTrainerV2_Loss_MCCnoBG',
        'nnFormerTrainerV2_Loss_DicewithBG',
        # 'nnFormerTrainerV2_Loss_Dice_LR1en3',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
        # 'nnFormerTrainerV2_Ranger_lr3en4',
    ]

    # Dataset -> configurations evaluated on it (Medical Segmentation Decathlon tasks).
    datasets = \
        {"Task001_BrainTumour": ("3d_fullres", ),
         "Task002_Heart": ("3d_fullres",),
         #"Task024_Promise": ("3d_fullres",),
         #"Task027_ACDC": ("3d_fullres",),
         "Task003_Liver": ("3d_fullres", "3d_lowres"),
         "Task004_Hippocampus": ("3d_fullres",),
         "Task005_Prostate": ("3d_fullres",),
         "Task006_Lung": ("3d_fullres", "3d_lowres"),
         "Task007_Pancreas": ("3d_fullres", "3d_lowres"),
         "Task008_HepaticVessel": ("3d_fullres", "3d_lowres"),
         "Task009_Spleen": ("3d_fullres", "3d_lowres"),
         "Task010_Colon": ("3d_fullres", "3d_lowres"),}

    # Summary jsons may have been produced under different validation folder
    # names depending on when the experiment was run; try each in turn.
    expected_validation_folder = "validation_raw"
    alternative_validation_folder = "validation"
    alternative_alternative_validation_folder = "validation_tiledTrue_doMirror_True"

    interested_in = "mean"

    # dataset -> configuration -> list of Dice values, one per valid trainer.
    result_per_dataset = {}
    for d in datasets:
        result_per_dataset[d] = {}
        for c in datasets[d]:
            result_per_dataset[d][c] = []

    valid_trainers = []
    all_trainers = []

    with open(output_file, 'w') as f:
        # CSV header: short dataset id + configuration initial (e.g. "001_f").
        f.write("trainer,")
        for t in datasets.keys():
            s = t[4:7]
            for c in datasets[t]:
                s1 = s + "_" + c[3]
                f.write("%s," % s1)
        f.write("\n")

        for trainer in trainers:
            trainer_plans = [plans]
            if trainer in overwrite_plans.keys():
                trainer_plans = overwrite_plans[trainer]

            result_per_dataset_here = {}
            for d in datasets:
                result_per_dataset_here[d] = {}

            for p in trainer_plans:
                name = "%s__%s" % (trainer, p)
                all_present = True
                all_trainers.append(name)

                f.write("%s," % name)
                for dataset in datasets.keys():
                    for configuration in datasets[dataset]:
                        # Look for the summary json under each possible validation folder name.
                        summary_file = join(summary_files_dir,
                                            "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p,
                                                                             expected_validation_folder, folds_str))
                        if not isfile(summary_file):
                            summary_file = join(summary_files_dir,
                                                "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p,
                                                                                 alternative_validation_folder, folds_str))
                            if not isfile(summary_file):
                                summary_file = join(summary_files_dir,
                                                    "%s__%s__%s__%s__%s__%s.json" % (
                                                        dataset, configuration, trainer, p,
                                                        alternative_alternative_validation_folder, folds_str))
                            if not isfile(summary_file):
                                all_present = False
                                print(name, dataset, configuration, "has missing summary file")
                        if isfile(summary_file):
                            result = load_json(summary_file)['results'][interested_in]['mean']['Dice']
                            result_per_dataset_here[dataset][configuration] = result
                            f.write("%02.4f," % result)
                        else:
                            # Missing result: write NA and score it as 0 Dice for ranking.
                            f.write("NA,")
                            result_per_dataset_here[dataset][configuration] = 0

                f.write("\n")

                # NOTE(review): this unconditionally accepts every candidate;
                # `all_present` (computed above) is not consulted here — confirm
                # whether `if all_present:` was intended instead of `if True:`.
                if True:
                    valid_trainers.append(name)
                    for d in datasets:
                        for c in datasets[d]:
                            result_per_dataset[d][c].append(result_per_dataset_here[d][c])

    invalid_trainers = [i for i in all_trainers if i not in valid_trainers]

    num_valid = len(valid_trainers)
    num_datasets = len(datasets.keys())
    # create an array that is trainer x dataset. If more than one configuration is there then use the best metric across the two
    all_res = np.zeros((num_valid, num_datasets))
    for j, d in enumerate(datasets.keys()):
        ks = list(result_per_dataset[d].keys())
        tmp = result_per_dataset[d][ks[0]]
        for k in ks[1:]:
            for i in range(len(tmp)):
                tmp[i] = max(tmp[i], result_per_dataset[d][k][i])
        all_res[:, j] = tmp

    # Rank trainers per dataset (rank 0 = best Dice), then average the ranks.
    ranks_arr = np.zeros_like(all_res)
    for d in range(ranks_arr.shape[1]):
        temp = np.argsort(all_res[:, d])[::-1]  # inverse because we want the highest dice to be rank0
        ranks = np.empty_like(temp)
        ranks[temp] = np.arange(len(temp))
        ranks_arr[:, d] = ranks

    mn = np.mean(ranks_arr, 1)
    for i in np.argsort(mn):
        print(mn[i], valid_trainers[i])

    print()
    # Winner: the candidate with the lowest mean rank across datasets.
    print(valid_trainers[np.argmin(mn)])
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.paths import network_training_output_dir

if __name__ == "__main__":
    # Collect the 5-fold cross-validation summary jsons of the StructSeg trainers,
    # write the mean Dice of every trainer/plans combination per dataset+configuration
    # into one CSV, then rank the trainers by their mean rank across datasets.
    # run collect_all_fold0_results_and_summarize_in_one_csv.py first
    summary_files_dir = join(network_training_output_dir, "summary_jsons_new")
    output_file = join(network_training_output_dir, "summary_structseg_5folds.csv")

    folds = (0, 1, 2, 3, 4)
    # folds encoded as one compact string (e.g. "01234"); part of the summary json file names
    folds_str = ""
    for f in folds:
        folds_str += str(f)

    plans = "nnFormerPlans"

    # trainers listed here are evaluated with these plans identifiers instead of the default `plans`
    overwrite_plans = {
        'nnFormerTrainerV2_2': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
        'nnFormerTrainerV2_2_noMirror': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
        'nnFormerTrainerV2_lessMomentum_noMirror': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
        'nnFormerTrainerV2_2_structSeg_noMirror': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
        'nnFormerTrainerV2_2_structSeg': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
        'nnFormerTrainerV2_lessMomentum_noMirror_structSeg': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
        'nnFormerTrainerV2_FabiansResUNet_structSet_NoMirror_leakyDecoder': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
        'nnFormerTrainerV2_FabiansResUNet_structSet_NoMirror': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
        'nnFormerTrainerV2_FabiansResUNet_structSet': ["nnFormerPlans", "nnFormerPlans_customClip"],  # r
    }

    trainers = ['nnFormerTrainer'] + [
        'nnFormerTrainerV2_2',
        'nnFormerTrainerV2_lessMomentum_noMirror',
        'nnFormerTrainerV2_2_noMirror',
        'nnFormerTrainerV2_2_structSeg_noMirror',
        'nnFormerTrainerV2_2_structSeg',
        'nnFormerTrainerV2_lessMomentum_noMirror_structSeg',
        'nnFormerTrainerV2_FabiansResUNet_structSet_NoMirror_leakyDecoder',
        'nnFormerTrainerV2_FabiansResUNet_structSet_NoMirror',
        'nnFormerTrainerV2_FabiansResUNet_structSet',
    ]

    datasets = \
        {"Task049_StructSeg2019_Task1_HaN_OAR": ("3d_fullres", "3d_lowres", "2d"),
         "Task050_StructSeg2019_Task2_Naso_GTV": ("3d_fullres", "3d_lowres", "2d"),
         "Task051_StructSeg2019_Task3_Thoracic_OAR": ("3d_fullres", "3d_lowres", "2d"),
         "Task052_StructSeg2019_Task4_Lung_GTV": ("3d_fullres", "3d_lowres", "2d"),
         }

    # summary jsons may live in one of three differently named validation folders; try them in order
    expected_validation_folder = "validation_raw"
    alternative_validation_folder = "validation"
    alternative_alternative_validation_folder = "validation_tiledTrue_doMirror_True"

    interested_in = "mean"

    # result_per_dataset[dataset][configuration] -> list of mean Dice values, one per valid trainer
    result_per_dataset = {}
    for d in datasets:
        result_per_dataset[d] = {}
        for c in datasets[d]:
            result_per_dataset[d][c] = []

    valid_trainers = []
    all_trainers = []

    with open(output_file, 'w') as f:
        # header row: one column per dataset/configuration, labeled by task number and the
        # 4th character of the configuration name ("2" for the two-letter "2d")
        f.write("trainer,")
        for t in datasets.keys():
            s = t[4:7]
            for c in datasets[t]:
                if len(c) > 3:
                    n = c[3]
                else:
                    n = "2"
                s1 = s + "_" + n
                f.write("%s," % s1)
        f.write("\n")

        for trainer in trainers:
            trainer_plans = [plans]
            if trainer in overwrite_plans.keys():
                trainer_plans = overwrite_plans[trainer]

            result_per_dataset_here = {}
            for d in datasets:
                result_per_dataset_here[d] = {}

            for p in trainer_plans:
                name = "%s__%s" % (trainer, p)
                all_present = True
                all_trainers.append(name)

                f.write("%s," % name)
                for dataset in datasets.keys():
                    for configuration in datasets[dataset]:
                        summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p, expected_validation_folder, folds_str))
                        if not isfile(summary_file):
                            summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p, alternative_validation_folder, folds_str))
                            if not isfile(summary_file):
                                summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (
                                    dataset, configuration, trainer, p, alternative_alternative_validation_folder, folds_str))
                            if not isfile(summary_file):
                                all_present = False
                                print(name, dataset, configuration, "has missing summary file")
                        if isfile(summary_file):
                            result = load_json(summary_file)['results'][interested_in]['mean']['Dice']
                            result_per_dataset_here[dataset][configuration] = result
                            f.write("%02.4f," % result)
                        else:
                            f.write("NA,")
                f.write("\n")

                # only trainers with a summary for every dataset/configuration take part in the ranking
                if all_present:
                    valid_trainers.append(name)
                    for d in datasets:
                        for c in datasets[d]:
                            result_per_dataset[d][c].append(result_per_dataset_here[d][c])

    invalid_trainers = [i for i in all_trainers if i not in valid_trainers]

    num_valid = len(valid_trainers)
    num_datasets = len(datasets.keys())
    # create an array that is trainer x dataset. If more than one configuration is there then use the best metric across the two
    all_res = np.zeros((num_valid, num_datasets))
    for j, d in enumerate(datasets.keys()):
        ks = list(result_per_dataset[d].keys())
        tmp = result_per_dataset[d][ks[0]]
        for k in ks[1:]:
            for i in range(len(tmp)):
                tmp[i] = max(tmp[i], result_per_dataset[d][k][i])
        all_res[:, j] = tmp

    # per-dataset rank of each trainer (rank 0 = highest Dice), then rank trainers by mean rank
    ranks_arr = np.zeros_like(all_res)
    for d in range(ranks_arr.shape[1]):
        temp = np.argsort(all_res[:, d])[::-1]  # inverse because we want the highest dice to be rank0
        ranks = np.empty_like(temp)
        ranks[temp] = np.arange(len(temp))
        ranks_arr[:, d] = ranks

    mn = np.mean(ranks_arr, 1)
    for i in np.argsort(mn):
        print(mn[i], valid_trainers[i])
    print()
    print(valid_trainers[np.argmin(mn)])


================================================
FILE: unetr_pp/evaluation/model_selection/rank_candidates_cascade.py
================================================
#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.paths import network_training_output_dir

if __name__ == "__main__":
    # Collect the fold-0 summary jsons of the cascade trainers, write mean Dice per
    # dataset into one CSV, and rank the trainers by mean rank across datasets.
    # run collect_all_fold0_results_and_summarize_in_one_csv.py first
    summary_files_dir = join(network_training_output_dir, "summary_jsons_fold0_new")
    output_file = join(network_training_output_dir, "summary_cascade.csv")

    folds = (0, )
    # folds encoded as one compact string (e.g. "0"); part of the summary json file names
    folds_str = ""
    for f in folds:
        folds_str += str(f)

    plans = "nnFormerPlansv2.1"

    # trainers listed here are evaluated with these plans identifiers instead of the default `plans`
    overwrite_plans = {
        'nnFormerTrainerCascadeFullRes': ['nnFormerPlans'],
    }

    trainers = [
        'nnFormerTrainerCascadeFullRes',
        'nnFormerTrainerV2CascadeFullRes_EducatedGuess',
        'nnFormerTrainerV2CascadeFullRes_EducatedGuess2',
        'nnFormerTrainerV2CascadeFullRes_EducatedGuess3',
        'nnFormerTrainerV2CascadeFullRes_lowerLR',
        'nnFormerTrainerV2CascadeFullRes',
        'nnFormerTrainerV2CascadeFullRes_noConnComp',
        'nnFormerTrainerV2CascadeFullRes_shorter_lowerLR',
        'nnFormerTrainerV2CascadeFullRes_shorter',
        'nnFormerTrainerV2CascadeFullRes_smallerBinStrel',
        #'',
        #'',
        #'',
        #'',
        #'',
        #'',
    ]

    datasets = \
        {
            "Task003_Liver": ("3d_cascade_fullres", ),
            "Task006_Lung": ("3d_cascade_fullres", ),
            "Task007_Pancreas": ("3d_cascade_fullres", ),
            "Task008_HepaticVessel": ("3d_cascade_fullres", ),
            "Task009_Spleen": ("3d_cascade_fullres", ),
            "Task010_Colon": ("3d_cascade_fullres", ),
            "Task017_AbdominalOrganSegmentation": ("3d_cascade_fullres", ),
            #"Task029_LITS": ("3d_cascade_fullres", ),
            "Task048_KiTS_clean": ("3d_cascade_fullres", ),
            "Task055_SegTHOR": ("3d_cascade_fullres", ),
            "Task056_VerSe": ("3d_cascade_fullres", ),
            #"": ("3d_cascade_fullres", ),
        }

    # summary jsons may live in one of three differently named validation folders; try them in order
    expected_validation_folder = "validation_raw"
    alternative_validation_folder = "validation"
    alternative_alternative_validation_folder = "validation_tiledTrue_doMirror_True"

    interested_in = "mean"

    # result_per_dataset[dataset][configuration] -> list of mean Dice values, one per trainer
    result_per_dataset = {}
    for d in datasets:
        result_per_dataset[d] = {}
        for c in datasets[d]:
            result_per_dataset[d][c] = []

    valid_trainers = []
    all_trainers = []

    with open(output_file, 'w') as f:
        # header row: one column per dataset, labeled by task number and the 4th character
        # of the configuration name
        f.write("trainer,")
        for t in datasets.keys():
            s = t[4:7]
            for c in datasets[t]:
                s1 = s + "_" + c[3]
                f.write("%s," % s1)
        f.write("\n")

        for trainer in trainers:
            trainer_plans = [plans]
            if trainer in overwrite_plans.keys():
                trainer_plans = overwrite_plans[trainer]

            result_per_dataset_here = {}
            for d in datasets:
                result_per_dataset_here[d] = {}

            for p in trainer_plans:
                name = "%s__%s" % (trainer, p)
                all_present = True
                all_trainers.append(name)

                f.write("%s," % name)
                for dataset in datasets.keys():
                    for configuration in datasets[dataset]:
                        summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p, expected_validation_folder, folds_str))
                        if not isfile(summary_file):
                            summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p, alternative_validation_folder, folds_str))
                            if not isfile(summary_file):
                                summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (
                                    dataset, configuration, trainer, p, alternative_alternative_validation_folder, folds_str))
                            if not isfile(summary_file):
                                all_present = False
                                print(name, dataset, configuration, "has missing summary file")
                        if isfile(summary_file):
                            result = load_json(summary_file)['results'][interested_in]['mean']['Dice']
                            result_per_dataset_here[dataset][configuration] = result
                            f.write("%02.4f," % result)
                        else:
                            f.write("NA,")
                            # missing results count as Dice 0 so this trainer can still be ranked
                            result_per_dataset_here[dataset][configuration] = 0
                f.write("\n")

                # NOTE(review): unlike the non-cascade ranking scripts, every trainer is kept
                # here (missing summaries were recorded as 0 above), so `all_present` is unused
                if True:
                    valid_trainers.append(name)
                    for d in datasets:
                        for c in datasets[d]:
                            result_per_dataset[d][c].append(result_per_dataset_here[d][c])

    invalid_trainers = [i for i in all_trainers if i not in valid_trainers]

    num_valid = len(valid_trainers)
    num_datasets = len(datasets.keys())
    # create an array that is trainer x dataset. If more than one configuration is there then use the best metric across the two
    all_res = np.zeros((num_valid, num_datasets))
    for j, d in enumerate(datasets.keys()):
        ks = list(result_per_dataset[d].keys())
        tmp = result_per_dataset[d][ks[0]]
        for k in ks[1:]:
            for i in range(len(tmp)):
                tmp[i] = max(tmp[i], result_per_dataset[d][k][i])
        all_res[:, j] = tmp

    # per-dataset rank of each trainer (rank 0 = highest Dice), then rank trainers by mean rank
    ranks_arr = np.zeros_like(all_res)
    for d in range(ranks_arr.shape[1]):
        temp = np.argsort(all_res[:, d])[::-1]  # inverse because we want the highest dice to be rank0
        ranks = np.empty_like(temp)
        ranks[temp] = np.arange(len(temp))
        ranks_arr[:, d] = ranks

    mn = np.mean(ranks_arr, 1)
    for i in np.argsort(mn):
        print(mn[i], valid_trainers[i])
    print()
    print(valid_trainers[np.argmin(mn)])


================================================
FILE: unetr_pp/evaluation/model_selection/summarize_results_in_one_json.py
================================================
#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
from collections import OrderedDict from unetr_pp.evaluation.add_mean_dice_to_json import foreground_mean from batchgenerators.utilities.file_and_folder_operations import * from unetr_pp.paths import network_training_output_dir import numpy as np def summarize(tasks, models=('2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'), output_dir=join(network_training_output_dir, "summary_jsons"), folds=(0, 1, 2, 3, 4)): maybe_mkdir_p(output_dir) if len(tasks) == 1 and tasks[0] == "all": tasks = list(range(999)) else: tasks = [int(i) for i in tasks] for model in models: for t in tasks: t = int(t) if not isdir(join(network_training_output_dir, model)): continue task_name = subfolders(join(network_training_output_dir, model), prefix="Task%03.0d" % t, join=False) if len(task_name) != 1: print("did not find unique output folder for network %s and task %s" % (model, t)) continue task_name = task_name[0] out_dir_task = join(network_training_output_dir, model, task_name) model_trainers = subdirs(out_dir_task, join=False) for trainer in model_trainers: if trainer.startswith("fold"): continue out_dir = join(out_dir_task, trainer) validation_folders = [] for fld in folds: d = join(out_dir, "fold%d"%fld) if not isdir(d): d = join(out_dir, "fold_%d"%fld) if not isdir(d): break validation_folders += subfolders(d, prefix="validation", join=False) for v in validation_folders: ok = True metrics = OrderedDict() for fld in folds: d = join(out_dir, "fold%d"%fld) if not isdir(d): d = join(out_dir, "fold_%d"%fld) if not isdir(d): ok = False break validation_folder = join(d, v) if not isfile(join(validation_folder, "summary.json")): print("summary.json missing for net %s task %s fold %d" % (model, task_name, fld)) ok = False break metrics_tmp = load_json(join(validation_folder, "summary.json"))["results"]["mean"] for l in metrics_tmp.keys(): if metrics.get(l) is None: metrics[l] = OrderedDict() for m in metrics_tmp[l].keys(): if metrics[l].get(m) is None: metrics[l][m] = [] 
metrics[l][m].append(metrics_tmp[l][m]) if ok: for l in metrics.keys(): for m in metrics[l].keys(): assert len(metrics[l][m]) == len(folds) metrics[l][m] = np.mean(metrics[l][m]) json_out = OrderedDict() json_out["results"] = OrderedDict() json_out["results"]["mean"] = metrics json_out["task"] = task_name json_out["description"] = model + " " + task_name + " all folds summary" json_out["name"] = model + " " + task_name + " all folds summary" json_out["experiment_name"] = model save_json(json_out, join(out_dir, "summary_allFolds__%s.json" % v)) save_json(json_out, join(output_dir, "%s__%s__%s__%s.json" % (task_name, model, trainer, v))) foreground_mean(join(out_dir, "summary_allFolds__%s.json" % v)) foreground_mean(join(output_dir, "%s__%s__%s__%s.json" % (task_name, model, trainer, v))) def summarize2(task_ids, models=('2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'), output_dir=join(network_training_output_dir, "summary_jsons"), folds=(0, 1, 2, 3, 4)): maybe_mkdir_p(output_dir) if len(task_ids) == 1 and task_ids[0] == "all": task_ids = list(range(999)) else: task_ids = [int(i) for i in task_ids] for model in models: for t in task_ids: if not isdir(join(network_training_output_dir, model)): continue task_name = subfolders(join(network_training_output_dir, model), prefix="Task%03.0d" % t, join=False) if len(task_name) != 1: print("did not find unique output folder for network %s and task %s" % (model, t)) continue task_name = task_name[0] out_dir_task = join(network_training_output_dir, model, task_name) model_trainers = subdirs(out_dir_task, join=False) for trainer in model_trainers: if trainer.startswith("fold"): continue out_dir = join(out_dir_task, trainer) validation_folders = [] for fld in folds: fold_output_dir = join(out_dir, "fold_%d"%fld) if not isdir(fold_output_dir): continue validation_folders += subfolders(fold_output_dir, prefix="validation", join=False) validation_folders = np.unique(validation_folders) for v in validation_folders: ok = True 
metrics = OrderedDict() metrics['mean'] = OrderedDict() metrics['median'] = OrderedDict() metrics['all'] = OrderedDict() for fld in folds: fold_output_dir = join(out_dir, "fold_%d"%fld) if not isdir(fold_output_dir): print("fold missing", model, task_name, trainer, fld) ok = False break validation_folder = join(fold_output_dir, v) if not isdir(validation_folder): print("validation folder missing", model, task_name, trainer, fld, v) ok = False break if not isfile(join(validation_folder, "summary.json")): print("summary.json missing", model, task_name, trainer, fld, v) ok = False break all_metrics = load_json(join(validation_folder, "summary.json"))["results"] # we now need to get the mean and median metrics. We use the mean metrics just to get the # names of computed metics, we ignore the precomputed mean and do it ourselfes again mean_metrics = all_metrics["mean"] all_labels = [i for i in list(mean_metrics.keys()) if i != "mean"] if len(all_labels) == 0: print(v, fld); break all_metrics_names = list(mean_metrics[all_labels[0]].keys()) for l in all_labels: # initialize the data structure, no values are copied yet for k in ['mean', 'median', 'all']: if metrics[k].get(l) is None: metrics[k][l] = OrderedDict() for m in all_metrics_names: if metrics['all'][l].get(m) is None: metrics['all'][l][m] = [] for entry in all_metrics['all']: for l in all_labels: for m in all_metrics_names: metrics['all'][l][m].append(entry[l][m]) # now compute mean and median for l in metrics['all'].keys(): for m in metrics['all'][l].keys(): metrics['mean'][l][m] = np.nanmean(metrics['all'][l][m]) metrics['median'][l][m] = np.nanmedian(metrics['all'][l][m]) if ok: fold_string = "" for f in folds: fold_string += str(f) json_out = OrderedDict() json_out["results"] = OrderedDict() json_out["results"]["mean"] = metrics['mean'] json_out["results"]["median"] = metrics['median'] json_out["task"] = task_name json_out["description"] = model + " " + task_name + "summary folds" + str(folds) 
json_out["name"] = model + " " + task_name + "summary folds" + str(folds) json_out["experiment_name"] = model save_json(json_out, join(output_dir, "%s__%s__%s__%s__%s.json" % (task_name, model, trainer, v, fold_string))) foreground_mean2(join(output_dir, "%s__%s__%s__%s__%s.json" % (task_name, model, trainer, v, fold_string))) def foreground_mean2(filename): with open(filename, 'r') as f: res = json.load(f) class_ids = np.array([int(i) for i in res['results']['mean'].keys() if (i != 'mean') and i != '0']) metric_names = res['results']['mean']['1'].keys() res['results']['mean']["mean"] = OrderedDict() res['results']['median']["mean"] = OrderedDict() for m in metric_names: foreground_values = [res['results']['mean'][str(i)][m] for i in class_ids] res['results']['mean']["mean"][m] = np.nanmean(foreground_values) foreground_values = [res['results']['median'][str(i)][m] for i in class_ids] res['results']['median']["mean"][m] = np.nanmean(foreground_values) with open(filename, 'w') as f: json.dump(res, f, indent=4, sort_keys=True) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(usage="This is intended to identify the best model based on the five fold " "cross-validation. Running this script requires alle models to have been run " "already. This script will summarize the results of the five folds of all " "models in one json each for easy interpretability") parser.add_argument("-t", '--task_ids', nargs="+", required=True, help="task id. 
can be 'all'") parser.add_argument("-f", '--folds', nargs="+", required=False, type=int, default=[0, 1, 2, 3, 4]) parser.add_argument("-m", '--models', nargs="+", required=False, default=['2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres']) args = parser.parse_args() tasks = args.task_ids models = args.models folds = args.folds summarize2(tasks, models, folds=folds, output_dir=join(network_training_output_dir, "summary_jsons_new")) ================================================ FILE: unetr_pp/evaluation/model_selection/summarize_results_with_plans.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from batchgenerators.utilities.file_and_folder_operations import *
import os
from unetr_pp.evaluation.model_selection.summarize_results_in_one_json import summarize
from unetr_pp.paths import network_training_output_dir
import numpy as np


def list_to_string(l, delim=","):
    # Format a sequence of numbers as delimiter-separated "%03.3f" strings.
    # Assumes `l` is non-empty (l[0] is accessed unconditionally).
    st = "%03.3f" % l[0]
    for i in l[1:]:
        st += delim + "%03.3f" % i
    return st


def write_plans_to_file(f, plans_file, stage=0, do_linebreak_at_end=True, override_name=None):
    # Write one semicolon-separated CSV row describing the given stage of a plans
    # pickle (batch size, pooling, patch size in voxels and mm, spacings, kernels).
    # :param f: open file handle to append to
    # :param plans_file: path to a plans .pkl containing 'plans_per_stage'
    # :param stage: which plans stage to describe
    # :param do_linebreak_at_end: whether to terminate the row with "\n"
    # :param override_name: row identifier; defaults to "<parent dir>__<file name>"
    a = load_pickle(plans_file)
    stages = list(a['plans_per_stage'].keys())
    stages.sort()
    patch_size_in_mm = [i * j for i, j in zip(a['plans_per_stage'][stages[stage]]['patch_size'],
                                              a['plans_per_stage'][stages[stage]]['current_spacing'])]
    median_patient_size_in_mm = [i * j for i, j in zip(a['plans_per_stage'][stages[stage]]['median_patient_size_in_voxels'],
                                                       a['plans_per_stage'][stages[stage]]['current_spacing'])]
    if override_name is None:
        f.write(plans_file.split("/")[-2] + "__" + plans_file.split("/")[-1])
    else:
        f.write(override_name)
    f.write(";%d" % stage)
    f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['batch_size']))
    f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['num_pool_per_axis']))
    f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['patch_size']))
    f.write(";%s" % list_to_string(patch_size_in_mm))
    f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['median_patient_size_in_voxels']))
    f.write(";%s" % list_to_string(median_patient_size_in_mm))
    f.write(";%s" % list_to_string(a['plans_per_stage'][stages[stage]]['current_spacing']))
    f.write(";%s" % list_to_string(a['plans_per_stage'][stages[stage]]['original_spacing']))
    f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['pool_op_kernel_sizes']))
    f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['conv_kernel_sizes']))
    if do_linebreak_at_end:
        f.write("\n")


if __name__ == "__main__":
    # Summarize fold-0 results for a fixed set of tasks, then write one CSV row per
    # summary json combining the experiment plans with the resulting Dice scores.
    summarize((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 24, 27),
              output_dir=join(network_training_output_dir, "summary_fold0"), folds=(0,))
    base_dir = os.environ['RESULTS_FOLDER']
    nnformers = ['nnFormerV2', 'nnFormerV2_zspacing']
    task_ids = list(range(99))
    with open("summary.csv", 'w') as f:
        f.write("identifier;stage;batch_size;num_pool_per_axis;patch_size;patch_size(mm);median_patient_size_in_voxels;median_patient_size_in_mm;current_spacing;original_spacing;pool_op_kernel_sizes;conv_kernel_sizes;patient_dc;global_dc\n")
        for i in task_ids:
            for nnformer in nnformers:
                # best effort: any task/variant whose files are missing or malformed is
                # skipped with a printed exception instead of aborting the whole run
                try:
                    summary_folder = join(base_dir, nnformer, "summary_fold0")
                    if isdir(summary_folder):
                        summary_files = subfiles(summary_folder, join=False, prefix="Task%03.0d_" % i, suffix=".json", sort=True)
                        for s in summary_files:
                            # summary file names look like "<task>__<config>__<trainer...>.json"
                            tmp = s.split("__")
                            trainer = tmp[2]
                            expected_output_folder = join(base_dir, nnformer, tmp[1], tmp[0], tmp[2].split(".")[0])
                            name = tmp[0] + "__" + nnformer + "__" + tmp[1] + "__" + tmp[2].split(".")[0]
                            global_dice_json = join(base_dir, nnformer, tmp[1], tmp[0], tmp[2].split(".")[0], "fold_0",
                                                    "validation_tiledTrue_doMirror_True", "global_dice.json")
                            if not isdir(expected_output_folder) or len(tmp) > 3:
                                # trainer name itself contains "__": re-join the extra part
                                if len(tmp) == 2:
                                    continue
                                expected_output_folder = join(base_dir, nnformer, tmp[1], tmp[0], tmp[2] + "__" + tmp[3].split(".")[0])
                                name = tmp[0] + "__" + nnformer + "__" + tmp[1] + "__" + tmp[2] + "__" + tmp[3].split(".")[0]
                                global_dice_json = join(base_dir, nnformer, tmp[1], tmp[0], tmp[2] + "__" + tmp[3].split(".")[0],
                                                        "fold_0", "validation_tiledTrue_doMirror_True", "global_dice.json")
                            assert isdir(expected_output_folder), "expected output dir not found"

                            plans_file = join(expected_output_folder, "plans.pkl")
                            assert isfile(plans_file)

                            plans = load_pickle(plans_file)
                            num_stages = len(plans['plans_per_stage'])
                            # fullres of a cascade is stage 1, otherwise stage 0
                            if num_stages > 1 and tmp[1] == "3d_fullres":
                                stage = 1
                            elif (num_stages == 1 and tmp[1] == "3d_fullres") or tmp[1] == "3d_lowres":
                                stage = 0
                            else:
                                print("skipping", s)
                                continue

                            g_dc = load_json(global_dice_json)
                            mn_glob_dc = np.mean(list(g_dc.values()))
                            write_plans_to_file(f, plans_file, stage, False, name)
                            # now read and add result to end of line
                            results = load_json(join(summary_folder, s))
                            mean_dc = results['results']['mean']['mean']['Dice']
                            f.write(";%03.3f" % mean_dc)
                            f.write(";%03.3f\n" % mn_glob_dc)
                            print(name, mean_dc)
                except Exception as e:
                    print(e)


================================================
FILE: unetr_pp/evaluation/region_based_evaluation.py
================================================
from copy import deepcopy
from multiprocessing.pool import Pool

from batchgenerators.utilities.file_and_folder_operations import *
from medpy import metric
import SimpleITK as sitk
import numpy as np
from unetr_pp.configuration import default_num_threads
from unetr_pp.postprocessing.consolidate_postprocessing import collect_cv_niftis


def get_brats_regions():
    """
    this is only valid for the brats data in here where the labels are 1, 2, and 3. The original brats data have a
    different labeling convention!
    :return: dict mapping region name -> tuple of labels that make up the region
    """
    regions = {
        "whole tumor": (1, 2, 3),
        "tumor core": (2, 3),
        "enhancing tumor": (3,)
    }
    return regions


def get_KiTS_regions():
    # KiTS evaluation regions: region name -> tuple of labels that make up the region
    regions = {
        "kidney incl tumor": (1, 2),
        "tumor": (2,)
    }
    return regions


def create_region_from_mask(mask, join_labels: tuple):
    # Binary mask that is 1 wherever `mask` holds any of the labels in `join_labels`.
    mask_new = np.zeros_like(mask, dtype=np.uint8)
    for l in join_labels:
        mask_new[mask == l] = 1
    return mask_new


def evaluate_case(file_pred: str, file_gt: str, regions):
    # Dice per region for one prediction/ground-truth pair of nifti files.
    # A region that is empty in both images yields NaN (undefined Dice).
    image_gt = sitk.GetArrayFromImage(sitk.ReadImage(file_gt))
    image_pred = sitk.GetArrayFromImage(sitk.ReadImage(file_pred))
    results = []
    for r in regions:
        mask_pred = create_region_from_mask(image_pred, r)
        mask_gt = create_region_from_mask(image_gt, r)
        dc = np.nan if np.sum(mask_gt) == 0 and np.sum(mask_pred) == 0 else metric.dc(mask_pred, mask_gt)
        results.append(dc)
    return results


def evaluate_regions(folder_predicted: str, folder_gt: str, regions: dict, processes=default_num_threads):
    """
    Compute region-based Dice for every predicted .nii.gz (in parallel) and write
    per-case results plus mean/median aggregates to summary.csv in folder_predicted.

    :param folder_predicted: folder with predicted segmentations (.nii.gz)
    :param folder_gt: folder with ground truth segmentations of the same file names
    :param regions: dict region name -> tuple of labels (see get_brats_regions)
    :param processes: number of worker processes
    """
    region_names = list(regions.keys())
    files_in_pred = subfiles(folder_predicted, suffix='.nii.gz', join=False)
    files_in_gt = subfiles(folder_gt, suffix='.nii.gz', join=False)
    have_no_gt = [i for i in files_in_pred if i not in files_in_gt]
    assert len(have_no_gt) == 0, "Some files in folder_predicted have not ground truth in folder_gt"
    have_no_pred = [i for i in files_in_gt if i not in files_in_pred]
    if len(have_no_pred) > 0:
        print("WARNING! Some files in folder_gt were not predicted (not present in folder_predicted)!")

    files_in_gt.sort()
    files_in_pred.sort()

    # run for all cases
    full_filenames_gt = [join(folder_gt, i) for i in files_in_pred]
    full_filenames_pred = [join(folder_predicted, i) for i in files_in_pred]

    # NOTE(review): the regions argument list is sized with len(files_in_gt) while the
    # file lists are built from files_in_pred; these differ when predictions are
    # missing (zip truncates to the shortest) — confirm this is intended
    p = Pool(processes)
    res = p.starmap(evaluate_case, zip(full_filenames_pred, full_filenames_gt, [list(regions.values())] * len(files_in_gt)))
    p.close()
    p.join()

    all_results = {r: [] for r in region_names}
    with open(join(folder_predicted, 'summary.csv'), 'w') as f:
        # one row per case ("casename" without the ".nii.gz" suffix), one column per region
        f.write("casename")
        for r in region_names:
            f.write(",%s" % r)
        f.write("\n")
        for i in range(len(files_in_pred)):
            f.write(files_in_pred[i][:-7])
            result_here = res[i]
            for k, r in enumerate(region_names):
                dc = result_here[k]
                f.write(",%02.4f" % dc)
                all_results[r].append(dc)
            f.write("\n")

        # aggregate rows: NaN-aware mean/median, and variants where NaN counts as Dice 1
        f.write('mean')
        for r in region_names:
            f.write(",%02.4f" % np.nanmean(all_results[r]))
        f.write("\n")
        f.write('median')
        for r in region_names:
            f.write(",%02.4f" % np.nanmedian(all_results[r]))
        f.write("\n")

        f.write('mean (nan is 1)')
        for r in region_names:
            tmp = np.array(all_results[r])
            tmp[np.isnan(tmp)] = 1
            f.write(",%02.4f" % np.mean(tmp))
        f.write("\n")
        f.write('median (nan is 1)')
        for r in region_names:
            tmp = np.array(all_results[r])
            tmp[np.isnan(tmp)] = 1
            f.write(",%02.4f" % np.median(tmp))
        f.write("\n")


if __name__ == '__main__':
    collect_cv_niftis('./', './cv_niftis')
    evaluate_regions('./cv_niftis/', './gt_niftis/', get_brats_regions())


================================================
FILE: unetr_pp/evaluation/surface_dice.py
================================================
#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License,
Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import numpy as np
from medpy.metric.binary import __surface_distances


def normalized_surface_dice(a: np.ndarray, b: np.ndarray, threshold: float, spacing: tuple = None, connectivity=1):
    """
    This implementation differs from the official surface dice implementation! These two are not comparable!!!!!

    The normalized surface dice is symmetric, so it should not matter whether a or b is the reference image

    This implementation natively supports 2D and 3D images. Whether other dimensions are supported depends on the
    __surface_distances implementation in medpy

    :param a: image 1, must have the same shape as b
    :param b: image 2, must have the same shape as a
    :param threshold: distances below this threshold will be counted as true positives. Threshold is in mm, not voxels!
    (if spacing = (1, 1(, 1)) then one voxel=1mm so the threshold is effectively in voxels)
    must be a tuple of len dimension(a)
    :param spacing: how many mm is one voxel in reality? Can be left at None, we then assume an isotropic spacing of 1mm
    :param connectivity: see scipy.ndimage.generate_binary_structure for more information. I suggest you leave that one
    alone
    :return:
    """
    assert all([i == j for i, j in zip(a.shape, b.shape)]), "a and b must have the same shape. a.shape= %s, " \
                                                            "b.shape= %s" % (str(a.shape), str(b.shape))
    if spacing is None:
        spacing = tuple([1 for _ in range(len(a.shape))])
    a_to_b = __surface_distances(a, b, spacing, connectivity)
    b_to_a = __surface_distances(b, a, spacing, connectivity)

    # NOTE(review): numel_a / numel_b are zero when an input has no surface voxels
    # (empty segmentation), which makes the divisions below raise/return NaN —
    # callers must ensure both masks are non-empty
    numel_a = len(a_to_b)
    numel_b = len(b_to_a)

    # fraction of each surface that lies within `threshold` mm of the other surface
    tp_a = np.sum(a_to_b <= threshold) / numel_a
    tp_b = np.sum(b_to_a <= threshold) / numel_b

    fp = np.sum(a_to_b > threshold) / numel_a
    fn = np.sum(b_to_a > threshold) / numel_b

    dc = (tp_a + tp_b) / (tp_a + tp_b + fp + fn + 1e-8)  # 1e-8 just so that we don't get div by 0
    return dc


================================================
FILE: unetr_pp/evaluation/unetr_pp_acdc_checkpoint/unetr_pp/3d_fullres/Task001_ACDC/unetr_pp_trainer_acdc__unetr_pp_Plansv2.1/fold_0/.gitignore
================================================
# Ignore everything in this directory
*
# Except this file
!.gitignore


================================================
FILE: unetr_pp/evaluation/unetr_pp_lung_checkpoint/unetr_pp/3d_fullres/Task006_Lung/unetr_pp_trainer_lung__unetr_pp_Plansv2.1/fold_0/.gitignore
================================================
# Ignore everything in this directory
*
# Except this file
!.gitignore


================================================
FILE: unetr_pp/evaluation/unetr_pp_synapse_checkpoint/unetr_pp/3d_fullres/Task002_Synapse/unetr_pp_trainer_synapse__unetr_pp_Plansv2.1/fold_0/.gitignore
================================================
# Ignore everything in this directory
*
# Except this file
!.gitignore


================================================
FILE: unetr_pp/evaluation/unetr_pp_tumor_checkpoint/unetr_pp/3d_fullres/Task003_tumor/unetr_pp_trainer_tumor__unetr_pp_Plansv2.1/fold_0/.gitignore
================================================
# Ignore everything in this directory
*
# Except this file
!.gitignore


================================================
FILE: unetr_pp/experiment_planning/DatasetAnalyzer.py
================================================
#    Copyright
2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from batchgenerators.utilities.file_and_folder_operations import *
from multiprocessing import Pool

from unetr_pp.configuration import default_num_threads
from unetr_pp.paths import nnFormer_raw_data, nnFormer_cropped_data
import numpy as np
import pickle
from unetr_pp.preprocessing.cropping import get_patient_identifiers_from_cropped_files
from skimage.morphology import label
from collections import OrderedDict


class DatasetAnalyzer(object):
    """
    Analyzes a folder of cropped training cases and collects the dataset-level
    properties (sizes, spacings, classes, intensity statistics, crop-induced size
    reduction) that the experiment planners consume. Results are cached as pickle
    files inside ``folder_with_cropped_data``.
    """
    def __init__(self, folder_with_cropped_data, overwrite=True, num_processes=default_num_threads):
        """
        :param folder_with_cropped_data: folder containing the cropped .npz/.pkl case files
            plus a dataset.json describing labels and modalities
        :param overwrite: If True then precomputed values will not be used and instead recomputed from the data.
        False will allow loading of precomputed values. This may be dangerous though if some of the code of this
        class was changed, therefore the default is True.
        :param num_processes: size of the multiprocessing.Pool used for per-case work
        """
        self.num_processes = num_processes
        self.overwrite = overwrite
        self.folder_with_cropped_data = folder_with_cropped_data
        # populated lazily elsewhere; kept as attributes for callers that expect them
        self.sizes = self.spacings = None
        self.patient_identifiers = get_patient_identifiers_from_cropped_files(self.folder_with_cropped_data)
        assert isfile(join(self.folder_with_cropped_data, "dataset.json")), \
            "dataset.json needs to be in folder_with_cropped_data"
        # cache file locations for the expensive per-case / intensity computations
        self.props_per_case_file = join(self.folder_with_cropped_data, "props_per_case.pkl")
        self.intensityproperties_file = join(self.folder_with_cropped_data, "intensityproperties.pkl")

    def load_properties_of_cropped(self, case_identifier):
        """Load the per-case properties pickle written during cropping."""
        with open(join(self.folder_with_cropped_data, "%s.pkl" % case_identifier), 'rb') as f:
            properties = pickle.load(f)
        return properties

    @staticmethod
    def _check_if_all_in_one_region(seg, regions):
        """
        For each region (an iterable of class labels), check whether the union of those
        labels forms exactly one connected component in ``seg``.

        :return: OrderedDict mapping tuple(region) -> bool (True if a single component)
        """
        res = OrderedDict()
        for r in regions:
            new_seg = np.zeros(seg.shape)
            for c in r:
                new_seg[seg == c] = 1
            labelmap, numlabels = label(new_seg, return_num=True)
            if numlabels != 1:
                res[tuple(r)] = False
            else:
                res[tuple(r)] = True
        return res

    @staticmethod
    def _collect_class_and_region_sizes(seg, all_classes, vol_per_voxel):
        """
        Compute, for each class, the total volume (in the units of ``vol_per_voxel``)
        and the volume of every connected component of that class.

        :return: (volume_per_class, region_volume_per_class) where the latter maps
            class -> list of per-component volumes
        """
        volume_per_class = OrderedDict()
        region_volume_per_class = OrderedDict()
        for c in all_classes:
            region_volume_per_class[c] = []
            volume_per_class[c] = np.sum(seg == c) * vol_per_voxel
            labelmap, numregions = label(seg == c, return_num=True)
            for l in range(1, numregions + 1):
                region_volume_per_class[c].append(np.sum(labelmap == l) * vol_per_voxel)
        return volume_per_class, region_volume_per_class

    def _get_unique_labels(self, patient_identifier):
        """Return the unique label values present in a case's segmentation (last channel of the npz)."""
        seg = np.load(join(self.folder_with_cropped_data, patient_identifier) + ".npz")['data'][-1]
        unique_classes = np.unique(seg)
        return unique_classes

    def _load_seg_analyze_classes(self, patient_identifier, all_classes):
        """
        1) what class is in this training case?
        2) what is the size distribution for each class?
        3) what is the region size of each class?
        4) check if all in one region
        :return: (unique_classes, all_in_one_region, volume_per_class, region_sizes)
        """
        seg = np.load(join(self.folder_with_cropped_data, patient_identifier) + ".npz")['data'][-1]
        pkl = load_pickle(join(self.folder_with_cropped_data, patient_identifier) + ".pkl")
        vol_per_voxel = np.prod(pkl['itk_spacing'])

        # ad 1)
        unique_classes = np.unique(seg)

        # 4) check if all in one region
        regions = list()
        regions.append(list(all_classes))  # all foreground classes combined
        for c in all_classes:
            regions.append((c, ))  # plus each class on its own

        all_in_one_region = self._check_if_all_in_one_region(seg, regions)

        # 2 & 3) region sizes
        volume_per_class, region_sizes = self._collect_class_and_region_sizes(seg, all_classes, vol_per_voxel)

        return unique_classes, all_in_one_region, volume_per_class, region_sizes

    def get_classes(self):
        """Return the 'labels' dict from dataset.json."""
        datasetjson = load_json(join(self.folder_with_cropped_data, "dataset.json"))
        return datasetjson['labels']

    def analyse_segmentations(self):
        """
        Determine which classes are present in each training case (parallelized),
        caching the result in ``self.props_per_case_file`` unless a cached version
        may be reused (overwrite=False).

        :return: (class_dct, props_per_patient)
        """
        class_dct = self.get_classes()

        if self.overwrite or not isfile(self.props_per_case_file):
            p = Pool(self.num_processes)
            res = p.map(self._get_unique_labels, self.patient_identifiers)
            p.close()
            p.join()

            props_per_patient = OrderedDict()
            # NOTE(review): the loop variable ``p`` below shadows the (already closed)
            # Pool above; harmless, but easy to misread.
            for p, unique_classes in \
                    zip(self.patient_identifiers, res):
                props = dict()
                props['has_classes'] = unique_classes
                props_per_patient[p] = props

            save_pickle(props_per_patient, self.props_per_case_file)
        else:
            props_per_patient = load_pickle(self.props_per_case_file)
        return class_dct, props_per_patient

    def get_sizes_and_spacings_after_cropping(self):
        """Collect per-case post-crop sizes and original spacings from the cropped .pkl files."""
        sizes = []
        spacings = []
        # for c in case_identifiers:
        for c in self.patient_identifiers:
            properties = self.load_properties_of_cropped(c)
            sizes.append(properties["size_after_cropping"])
            spacings.append(properties["original_spacing"])

        return sizes, spacings

    def get_modalities(self):
        """Return {modality_index (int): modality_name} from dataset.json."""
        datasetjson = load_json(join(self.folder_with_cropped_data, "dataset.json"))
        modalities = datasetjson["modality"]
        modalities = {int(k): modalities[k] for k in modalities.keys()}
        return modalities

    def get_size_reduction_by_cropping(self):
        """Return {patient_id: voxel-count ratio after/before cropping} for every case."""
        size_reduction = OrderedDict()
        for p in self.patient_identifiers:
            props = self.load_properties_of_cropped(p)
            shape_before_crop = props["original_size_of_raw_data"]
            shape_after_crop = props['size_after_cropping']
            size_red = np.prod(shape_after_crop) / np.prod(shape_before_crop)
            size_reduction[p] = size_red
        return size_reduction

    def _get_voxels_in_foreground(self, patient_identifier, modality_id):
        """
        Return a subsampled list of intensity values of one modality inside the
        foreground (segmentation label > 0) of one case.
        """
        all_data = np.load(join(self.folder_with_cropped_data, patient_identifier) + ".npz")['data']
        modality = all_data[modality_id]
        mask = all_data[-1] > 0
        voxels = list(modality[mask][::10])  # no need to take every voxel
        return voxels

    @staticmethod
    def _compute_stats(voxels):
        """
        Compute (median, mean, sd, min, max, 99.5th percentile, 0.5th percentile)
        of ``voxels``; all NaN if the input is empty.
        """
        if len(voxels) == 0:
            return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
        median = np.median(voxels)
        mean = np.mean(voxels)
        sd = np.std(voxels)
        mn = np.min(voxels)
        mx = np.max(voxels)
        percentile_99_5 = np.percentile(voxels, 99.5)
        percentile_00_5 = np.percentile(voxels, 00.5)
        return median, mean, sd, mn, mx, percentile_99_5, percentile_00_5

    def collect_intensity_properties(self, num_modalities):
        """
        Compute global and per-case foreground intensity statistics for every modality
        (parallelized), caching the result in ``self.intensityproperties_file``.

        :param num_modalities: number of image modalities/channels to analyze
        :return: OrderedDict {modality_id: {'local_props', 'median', 'mean', 'sd',
            'mn', 'mx', 'percentile_99_5', 'percentile_00_5'}}
        """
        if self.overwrite or not isfile(self.intensityproperties_file):
            p = Pool(self.num_processes)

            results = OrderedDict()
            for mod_id in range(num_modalities):
                results[mod_id] = OrderedDict()
                v = p.starmap(self._get_voxels_in_foreground, zip(self.patient_identifiers,
                                                                  [mod_id] * len(self.patient_identifiers)))

                # flatten per-case voxel lists into one list for the global statistics
                w = []
                for iv in v:
                    w += iv

                median, mean, sd, mn, mx, percentile_99_5, percentile_00_5 = self._compute_stats(w)

                # per-case statistics computed in parallel on each case's voxel list
                local_props = p.map(self._compute_stats, v)
                props_per_case = OrderedDict()
                for i, pat in enumerate(self.patient_identifiers):
                    props_per_case[pat] = OrderedDict()
                    props_per_case[pat]['median'] = local_props[i][0]
                    props_per_case[pat]['mean'] = local_props[i][1]
                    props_per_case[pat]['sd'] = local_props[i][2]
                    props_per_case[pat]['mn'] = local_props[i][3]
                    props_per_case[pat]['mx'] = local_props[i][4]
                    props_per_case[pat]['percentile_99_5'] = local_props[i][5]
                    props_per_case[pat]['percentile_00_5'] = local_props[i][6]

                results[mod_id]['local_props'] = props_per_case
                results[mod_id]['median'] = median
                results[mod_id]['mean'] = mean
                results[mod_id]['sd'] = sd
                results[mod_id]['mn'] = mn
                results[mod_id]['mx'] = mx
                results[mod_id]['percentile_99_5'] = percentile_99_5
                results[mod_id]['percentile_00_5'] = percentile_00_5

            p.close()
            p.join()
            save_pickle(results, self.intensityproperties_file)
        else:
            results = load_pickle(self.intensityproperties_file)
        return results

    def analyze_dataset(self, collect_intensityproperties=True):
        """
        Run the full dataset analysis and write ``dataset_properties.pkl`` into the
        cropped-data folder.

        :param collect_intensityproperties: if False, skip the (expensive) intensity
            statistics and store None instead
        :return: the dataset_properties dict that was saved
        """
        # get all spacings and sizes
        sizes, spacings = self.get_sizes_and_spacings_after_cropping()

        # get all classes and what classes are in what patients
        # class min size
        # region size per class
        classes = self.get_classes()
        all_classes = [int(i) for i in classes.keys() if int(i) > 0]  # foreground labels only

        # modalities
        modalities = self.get_modalities()

        # collect intensity information
        if collect_intensityproperties:
            intensityproperties = self.collect_intensity_properties(len(modalities))
        else:
            intensityproperties = None

        # size reduction by cropping
        size_reductions = self.get_size_reduction_by_cropping()

        dataset_properties = dict()
        dataset_properties['all_sizes'] = sizes
        dataset_properties['all_spacings'] = spacings
        dataset_properties['all_classes'] = all_classes
        dataset_properties['modalities'] = modalities  # {idx: modality name}
        dataset_properties['intensityproperties'] = intensityproperties
        dataset_properties['size_reductions'] = size_reductions  # {patient_id: size_reduction}

        save_pickle(dataset_properties, join(self.folder_with_cropped_data, "dataset_properties.pkl"))
        return dataset_properties


================================================
FILE: unetr_pp/experiment_planning/__init__.py
================================================
from __future__ import absolute_import
from .
import *


================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_11GB.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

import numpy as np

from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
    ExperimentPlanner3D_v21
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *


class ExperimentPlanner3D_v21_11GB(ExperimentPlanner3D_v21):
    """
    Same as ExperimentPlanner3D_v21, but designed to fill a RTX2080 ti (11GB) in fp16
    """
    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner3D_v21_11GB, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.data_identifier = "nnFormerData_plans_v2.1_big"
        self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlansv2.1_big_plans_3D.pkl")

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        Compute patch size, batch size and network topology for one resolution stage.
        Identical to the parent implementation except that the VRAM reference budget
        (``ref``) is scaled up for an 11GB card.

        :param current_spacing: target voxel spacing of this stage
        :param original_spacing: median original spacing of the dataset
        :param original_shape: median original shape of the dataset
        :param num_cases: number of training cases
        :param num_modalities: number of input channels
        :param num_classes: number of segmentation classes
        :return: plan dict with batch_size, patch_size, pooling/conv kernels etc.
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # the next line is what we had before as a default. The patch size had the same aspect ratio as the median
        # shape of a patient. We swapped t
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)

        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()

        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)

        # clip it to the median shape of the dataset because patches larger than that make not much sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
                                                             self.unet_featuremap_min_edge_length,
                                                             self.unet_max_numpool)

        # use_this_for_batch_size_computation_3D = 520000000 # 505789440
        # typical ExperimentPlanner3D_v21 configurations use 7.5GB, but on a 2080ti we have 11. Allow for more space
        # to be used
        ref = Generic_UNet.use_this_for_batch_size_computation_3D * 11 / 8
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        # shrink the patch along the axis that is largest relative to the median shape
        # until the estimated VRAM consumption fits within the budget
        while here > ref:
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]

            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props(current_spacing, tmp,
                                        self.unet_featuremap_min_edge_length,
                                        self.unet_max_numpool,
                                        )
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
                                                                 self.unet_featuremap_min_edge_length,
                                                                 self.unet_max_numpool,
                                                                 )

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            # print(new_shp)

        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))

        # check if batch size is too large
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
            0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan


================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_16GB.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy

import numpy as np

from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
    ExperimentPlanner3D_v21
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *


class ExperimentPlanner3D_v21_16GB(ExperimentPlanner3D_v21):
    """
    Same as ExperimentPlanner3D_v21, but designed to fill 16GB in fp16
    """
    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner3D_v21_16GB, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.data_identifier = "nnFormerData_plans_v2.1_16GB"
        self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlansv2.1_16GB_plans_3D.pkl")

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        Compute patch size, batch size and network topology for one resolution stage.
        Identical to the parent implementation except that the VRAM reference budget
        (``ref``) is scaled for a 16GB card.

        :param current_spacing: target voxel spacing of this stage
        :param original_spacing: median original spacing of the dataset
        :param original_shape: median original shape of the dataset
        :param num_cases: number of training cases
        :param num_modalities: number of input channels
        :param num_classes: number of segmentation classes
        :return: plan dict with batch_size, patch_size, pooling/conv kernels etc.
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # the next line is what we had before as a default. The patch size had the same aspect ratio as the median
        # shape of a patient. We swapped t
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)

        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()

        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)

        # clip it to the median shape of the dataset because patches larger than that make not much sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
                                                             self.unet_featuremap_min_edge_length,
                                                             self.unet_max_numpool)

        # use_this_for_batch_size_computation_3D = 520000000 # 505789440
        # typical ExperimentPlanner3D_v21 configurations use 7.5GB, here we target 16GB. Allow for more space
        # to be used
        ref = Generic_UNet.use_this_for_batch_size_computation_3D * 16 / 8.5
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        # shrink the patch along the axis that is largest relative to the median shape
        # until the estimated VRAM consumption fits within the budget
        while here > ref:
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]

            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props(current_spacing, tmp,
                                        self.unet_featuremap_min_edge_length,
                                        self.unet_max_numpool,
                                        )
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
                                                                 self.unet_featuremap_min_edge_length,
                                                                 self.unet_max_numpool,
                                                                 )

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            # print(new_shp)

        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))

        # check if batch size is too large
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
            0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan


================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_32GB.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

import numpy as np

from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
    ExperimentPlanner3D_v21
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *


class ExperimentPlanner3D_v21_32GB(ExperimentPlanner3D_v21):
    """
    Same as ExperimentPlanner3D_v21, but designed to fill a V100 (32GB) in fp16
    """
    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner3D_v21_32GB, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.data_identifier = "nnFormerData_plans_v2.1_verybig"
        self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlansv2.1_verybig_plans_3D.pkl")

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        Compute patch size, batch size and network topology for one resolution stage.
        Identical to the parent implementation except that the VRAM reference budget
        (``ref``) is scaled up for a 32GB card.

        :param current_spacing: target voxel spacing of this stage
        :param original_spacing: median original spacing of the dataset
        :param original_shape: median original shape of the dataset
        :param num_cases: number of training cases
        :param num_modalities: number of input channels
        :param num_classes: number of segmentation classes
        :return: plan dict with batch_size, patch_size, pooling/conv kernels etc.
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # the next line is what we had before as a default. The patch size had the same aspect ratio as the median
        # shape of a patient. We swapped t
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)

        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()

        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)

        # clip it to the median shape of the dataset because patches larger than that make not much sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
                                                             self.unet_featuremap_min_edge_length,
                                                             self.unet_max_numpool)

        # use_this_for_batch_size_computation_3D = 520000000 # 505789440
        # typical ExperimentPlanner3D_v21 configurations use 7.5GB, but on a V100 we have 32. Allow for more space
        # to be used
        ref = Generic_UNet.use_this_for_batch_size_computation_3D * 32 / 8
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        # shrink the patch along the axis that is largest relative to the median shape
        # until the estimated VRAM consumption fits within the budget
        while here > ref:
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]

            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props(current_spacing, tmp,
                                        self.unet_featuremap_min_edge_length,
                                        self.unet_max_numpool,
                                        )
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
                                                                 self.unet_featuremap_min_edge_length,
                                                                 self.unet_max_numpool,
                                                                 )

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            # print(new_shp)

        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))

        # check if batch size is too large
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
            0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan


================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_3convperstage.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy import numpy as np from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import ExperimentPlanner3D_v21 from unetr_pp.network_architecture.generic_UNet import Generic_UNet from unetr_pp.paths import * class ExperimentPlanner3D_v21_3cps(ExperimentPlanner3D_v21): """ have 3x conv-in-lrelu per resolution instead of 2 while remaining in the same memory budget This only works with 3d fullres because we use the same data as ExperimentPlanner3D_v21. Lowres would require to rerun preprocesing (different patch size = different 3d lowres target spacing) """ def __init__(self, folder_with_cropped_data, preprocessed_output_folder): super(ExperimentPlanner3D_v21_3cps, self).__init__(folder_with_cropped_data, preprocessed_output_folder) self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlansv2.1_3cps_plans_3D.pkl") self.unet_base_num_features = 32 self.conv_per_stage = 3 def run_preprocessing(self, num_threads): pass ================================================ FILE: unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v22.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
    ExperimentPlanner3D_v21
from unetr_pp.paths import *


class ExperimentPlanner3D_v22(ExperimentPlanner3D_v21):
    """
    Same as ExperimentPlanner3D_v21, but with a target-spacing heuristic that
    treats strongly anisotropic datasets specially (see get_target_spacing).
    """
    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super().__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.data_identifier = "nnFormerData_plans_v2.2"
        self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlansv2.2_plans_3D.pkl")

    def get_target_spacing(self):
        """
        Determine the resampling target spacing. Starts from the percentile-based
        spacing of the parent planner; if one axis is strongly anisotropic (much
        coarser spacing AND much fewer voxels than the others), its target spacing
        is lowered towards the 10th percentile of that axis' spacings (clamped so
        it stays within anisotropy_threshold x the other axes).

        :return: per-axis target spacing (numpy array)
        """
        spacings = self.dataset_properties['all_spacings']
        sizes = self.dataset_properties['all_sizes']

        target = np.percentile(np.vstack(spacings), self.target_spacing_percentile, 0)
        target_size = np.percentile(np.vstack(sizes), self.target_spacing_percentile, 0)
        target_size_mm = np.array(target) * np.array(target_size)
        # we need to identify datasets for which a different target spacing could be beneficial. These datasets have
        # the following properties:
        # - one axis which much lower resolution than the others
        # - the lowres axis has much less voxels than the others
        # - (the size in mm of the lowres axis is also reduced)
        worst_spacing_axis = np.argmax(target)
        other_axes = [i for i in range(len(target)) if i != worst_spacing_axis]
        other_spacings = [target[i] for i in other_axes]
        other_sizes = [target_size[i] for i in other_axes]

        has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * max(other_spacings))
        has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes)
        # we don't use the last one for now
        #median_size_in_mm = target[target_size_mm] * RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD < max(target_size_mm)

        if has_aniso_spacing and has_aniso_voxels:
            spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis]
            target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10)
            # don't let the spacing of that axis get higher than self.anisotropy_threshold x the_other_axes
            target_spacing_of_that_axis = max(max(other_spacings) * self.anisotropy_threshold, target_spacing_of_that_axis)
            target[worst_spacing_axis] = target_spacing_of_that_axis

        return target


================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v23.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \ ExperimentPlanner3D_v21 from unetr_pp.paths import * class ExperimentPlanner3D_v23(ExperimentPlanner3D_v21): """ """ def __init__(self, folder_with_cropped_data, preprocessed_output_folder): super(ExperimentPlanner3D_v23, self).__init__(folder_with_cropped_data, preprocessed_output_folder) self.data_identifier = "nnFormerData_plans_v2.3" self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlansv2.3_plans_3D.pkl") self.preprocessor_name = "Preprocessor3DDifferentResampling" ================================================ FILE: unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_residual_3DUNet_v21.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from copy import deepcopy

import numpy as np
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
    ExperimentPlanner3D_v21
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
from unetr_pp.paths import *
from unetr_pp.network_architecture.generic_modular_residual_UNet import FabiansUNet


class ExperimentPlanner3DFabiansResUNet_v21(ExperimentPlanner3D_v21):
    """
    Same planning as ExperimentPlanner3D_v21 but sizes the network/patch using
    FabiansUNet's VRAM estimate (residual U-Net) instead of Generic_UNet's.
    """
    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner3DFabiansResUNet_v21, self).__init__(folder_with_cropped_data,
                                                                    preprocessed_output_folder)
        # reuses the v2.1 preprocessed data on purpose (see run_preprocessing below)
        self.data_identifier = "nnFormerData_plans_v2.1"  # "nnFormerData_FabiansResUNet_v2.1"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlans_FabiansResUNet_v2.1_plans_3D.pkl")

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        We use FabiansUNet instead of Generic_UNet

        Derives patch size, batch size, pooling/conv kernels and block counts for
        one resolution stage. Starts from an isotropic (in mm) patch and shrinks
        the largest-relative axis until the estimated VRAM fits the reference.
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)

        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # the next line is what we had before as a default. The patch size had the same aspect ratio as the
        # median shape of a patient. We swapped to deriving it from an isotropic box (in mm) instead, see below.
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)

        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()

        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)

        # clip it to the median shape of the dataset because patches larger than that do not make much sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
                                                             self.unet_featuremap_min_edge_length,
                                                             self.unet_max_numpool)

        # FabiansUNet expects an explicit no-pool first stage
        pool_op_kernel_sizes = [[1, 1, 1]] + pool_op_kernel_sizes
        blocks_per_stage_encoder = FabiansUNet.default_blocks_per_stage_encoder[:len(pool_op_kernel_sizes)]
        blocks_per_stage_decoder = FabiansUNet.default_blocks_per_stage_decoder[:len(pool_op_kernel_sizes) - 1]

        ref = FabiansUNet.use_this_for_3D_configuration
        here = FabiansUNet.compute_approx_vram_consumption(input_patch_size,
                                                           self.unet_base_num_features,
                                                           self.unet_max_num_filters, num_modalities,
                                                           num_classes, pool_op_kernel_sizes,
                                                           blocks_per_stage_encoder,
                                                           blocks_per_stage_decoder, 2,
                                                           self.unet_min_batch_size,)
        # shrink the patch until the VRAM estimate fits the reference budget
        while here > ref:
            # reduce the axis that is largest relative to the median patient shape
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]

            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props(current_spacing, tmp,
                                        self.unet_featuremap_min_edge_length,
                                        self.unet_max_numpool,
                                        )
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
                                                                 self.unet_featuremap_min_edge_length,
                                                                 self.unet_max_numpool,
                                                                 )
            pool_op_kernel_sizes = [[1, 1, 1]] + pool_op_kernel_sizes
            blocks_per_stage_encoder = FabiansUNet.default_blocks_per_stage_encoder[:len(pool_op_kernel_sizes)]
            blocks_per_stage_decoder = FabiansUNet.default_blocks_per_stage_decoder[:len(pool_op_kernel_sizes) - 1]
            here = FabiansUNet.compute_approx_vram_consumption(new_shp, self.unet_base_num_features,
                                                               self.unet_max_num_filters, num_modalities,
                                                               num_classes, pool_op_kernel_sizes,
                                                               blocks_per_stage_encoder,
                                                               blocks_per_stage_decoder, 2,
                                                               self.unet_min_batch_size)

        input_patch_size = new_shp

        # leftover VRAM budget is converted into a larger batch size
        batch_size = FabiansUNet.default_min_batch_size
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))

        # check if batch size is too large
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        # axis 0 is assumed to be the (potentially) low-resolution axis here
        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
            'num_blocks_encoder': blocks_per_stage_encoder,
            'num_blocks_decoder': blocks_per_stage_decoder
        }
        return plan

    def run_preprocessing(self, num_threads):
        """
        On all datasets except 3d fullres on spleen the preprocessed data would look identical to
        ExperimentPlanner3D_v21 (I tested decathlon data only). Therefore we just reuse the preprocessed
        data of that other planner
        :param num_threads:
        :return:
        """
        pass


================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_2DUNet_v21_RGB_scaleto_0_1.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from unetr_pp.experiment_planning.experiment_planner_baseline_2DUNet_v21 import ExperimentPlanner2D_v21
from unetr_pp.paths import *


class ExperimentPlanner2D_v21_RGB_scaleTo_0_1(ExperimentPlanner2D_v21):
    """
    used by tutorial unetr_pp.tutorials.custom_preprocessing

    2D planner that only swaps in a custom preprocessor (uint8 RGB scaled to [0, 1]).
    """
    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super().__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.data_identifier = "nnFormer_RGB_scaleTo_0_1"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormer_RGB_scaleTo_0_1" + "_plans_2D.pkl")

        # The custom preprocessor class we intend to use is GenericPreprocessor_scale_uint8_to_0_1. It must be located
        # in unetr_pp.preprocessing (any file and submodule) and will be found by its name. Make sure to always define
        # unique names!
        self.preprocessor_name = 'GenericPreprocessor_scale_uint8_to_0_1'


================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_3DUNet_CT2.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import OrderedDict

from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.paths import *


class ExperimentPlannerCT2(ExperimentPlanner):
    """
    preprocesses CT data with the "CT2" normalization.

    (clip range comes from training set and is the 0.5 and 99.5 percentile of intensities in foreground)
    CT = clip to range, then normalize with global mn and sd (computed on foreground in training set)
    CT2 = clip to range, normalize each case separately with its own mn and std (computed within the area
    that was in clip_range)
    """
    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlannerCT2, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.data_identifier = "nnFormer_CT2"
        self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlans" + "CT2_plans_3D.pkl")

    def determine_normalization_scheme(self):
        # returns one normalization scheme name per modality index: "CT2" for CT, "nonCT" otherwise
        schemes = OrderedDict()
        modalities = self.dataset_properties['modalities']
        num_modalities = len(list(modalities.keys()))

        for i in range(num_modalities):
            if modalities[i] == "CT":
                schemes[i] = "CT2"
            else:
                schemes[i] = "nonCT"
        return schemes


================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_3DUNet_nonCT.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict

from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.paths import *


class ExperimentPlannernonCT(ExperimentPlanner):
    """
    Preprocesses all data in nonCT mode (this is what we use for MRI per default, but here it is applied to CT
    images as well)
    """
    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlannernonCT, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.data_identifier = "nnFormer_nonCT"
        self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlans" + "nonCT_plans_3D.pkl")

    def determine_normalization_scheme(self):
        # NOTE: both branches assign "nonCT" on purpose — this planner forces nonCT
        # normalization for every modality, including CT (the if/else mirrors the
        # structure of the other normalization planners for easy comparison).
        schemes = OrderedDict()
        modalities = self.dataset_properties['modalities']
        num_modalities = len(list(modalities.keys()))

        for i in range(num_modalities):
            if modalities[i] == "CT":
                schemes[i] = "nonCT"
            else:
                schemes[i] = "nonCT"
        return schemes


================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/patch_size/experiment_planner_3DUNet_isotropic_in_mm.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy

import numpy as np
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props_poolLateV2
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *


class ExperimentPlannerIso(ExperimentPlanner):
    """
    attempts to create patches that have an isotropic size (in mm, not voxels)

    CAREFUL! this one does not support transpose_forward and transpose_backward
    """
    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super().__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlans" + "fixedisoPatchesInmm_plans_3D.pkl")
        self.data_identifier = "nnFormer_isoPatchesInmm"

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        Like ExperimentPlanner.get_properties_for_stage, but when the patch must be
        shrunk to fit VRAM, the axis with the largest physical extent (mm) is reduced
        so the patch stays as isotropic in mm as possible.
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # the next line is what we had before as a default. The patch size had the same aspect ratio as the
        # median shape of a patient. We swapped to deriving it from an isotropic box (in mm) instead, see below.
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)

        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()

        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)

        # clip it to the median shape of the dataset because patches larger than that do not make much sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(input_patch_size,
                                                                        self.unet_featuremap_min_edge_length,
                                                                        self.unet_max_numpool,
                                                                        current_spacing)

        ref = Generic_UNet.use_this_for_batch_size_computation_3D
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        while here > ref:
            # here is the difference to ExperimentPlanner. In the old version we made the aspect ratio match
            # between patch and new_median_shape, regardless of spacing. It could be better to enforce isotropy
            # (in mm) instead
            current_patch_in_mm = new_shp * current_spacing
            axis_to_be_reduced = np.argsort(current_patch_in_mm)[-1]

            # from here on it's the same as before
            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props_poolLateV2(tmp,
                                                   self.unet_featuremap_min_edge_length,
                                                   self.unet_max_numpool,
                                                   current_spacing)
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(new_shp,
                                                                            self.unet_featuremap_min_edge_length,
                                                                            self.unet_max_numpool,
                                                                            current_spacing)

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            print(new_shp)

        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))

        # check if batch size is too large
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan


================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/patch_size/experiment_planner_3DUNet_isotropic_in_voxels.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from copy import deepcopy

import numpy as np
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props_poolLateV2
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *


class ExperimentPlanner3D_IsoPatchesInVoxels(ExperimentPlanner):
    """
    patches that are isotropic in the number of voxels (not mm), such as 128x128x128 allow more voxels to be
    processed at once because we don't have to do annoying pooling stuff

    CAREFUL! this one does not support transpose_forward and transpose_backward
    """
    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner3D_IsoPatchesInVoxels, self).__init__(folder_with_cropped_data,
                                                                     preprocessed_output_folder)
        self.data_identifier = "nnFormerData_isoPatchesInVoxels"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlans" + "fixedisoPatchesInVoxels_plans_3D.pkl")

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        Starts from the median patient shape (in voxels) and shrinks the largest axis
        (largest-spacing axis when already isotropic) until the VRAM estimate fits.
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        input_patch_size = new_median_shape

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(input_patch_size,
                                                                        self.unet_featuremap_min_edge_length,
                                                                        self.unet_max_numpool,
                                                                        current_spacing)

        ref = Generic_UNet.use_this_for_batch_size_computation_3D
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        while here > ref:
            # find the largest axis. If patch is isotropic, pick the axis with the largest spacing
            if len(np.unique(new_shp)) == 1:
                axis_to_be_reduced = np.argsort(current_spacing)[-1]
            else:
                axis_to_be_reduced = np.argsort(new_shp)[-1]

            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props_poolLateV2(tmp,
                                                   self.unet_featuremap_min_edge_length,
                                                   self.unet_max_numpool,
                                                   current_spacing)
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(new_shp,
                                                                            self.unet_featuremap_min_edge_length,
                                                                            self.unet_max_numpool,
                                                                            current_spacing)

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            print(new_shp)

        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))

        # check if batch size is too large
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan


================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/pooling_and_convs/experiment_planner_baseline_3DUNet_allConv3x3.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from copy import deepcopy

import numpy as np
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props_poolLateV2
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *


class ExperimentPlannerAllConv3x3(ExperimentPlanner):
    """
    Baseline planner variant that forces every convolution kernel to 3x3x3
    (no 1x3x3 kernels for anisotropic data); see end of get_properties_for_stage.
    """
    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlannerAllConv3x3, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlans" + "allConv3x3_plans_3D.pkl")

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        Computation of input patch size starts out with the new median shape (in voxels) of a dataset. This is
        opposed to prior experiments where I based it on the median size in mm. The rationale behind this is that
        for some organ of interest the acquisition method will most likely be chosen such that the field of view
        and voxel resolution go hand in hand to show the doctor what they need to see. This assumption may be
        violated for some modalities with anisotropy (cine MRI) but we will have t live with that. In future
        experiments I will try to 1) base input patch size match aspect ratio of input size in mm (instead of
        voxels) and 2) to try to enforce that we see the same 'distance' in all directions (try to maintain equal
        size in mm of patch)

        The patches created here attempt keep the aspect ratio of the new_median_shape

        :param current_spacing:
        :param original_spacing:
        :param original_shape:
        :param num_cases:
        :return:
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # the next line is what we had before as a default. The patch size had the same aspect ratio as the
        # median shape of a patient. We swapped to deriving it from an isotropic box (in mm) instead, see below.
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)

        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()

        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)

        # clip it to the median shape of the dataset because patches larger than that do not make much sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(input_patch_size,
                                                                        self.unet_featuremap_min_edge_length,
                                                                        self.unet_max_numpool,
                                                                        current_spacing)

        ref = Generic_UNet.use_this_for_batch_size_computation_3D
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        while here > ref:
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]

            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props_poolLateV2(tmp,
                                                   self.unet_featuremap_min_edge_length,
                                                   self.unet_max_numpool,
                                                   current_spacing)
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(new_shp,
                                                                            self.unet_featuremap_min_edge_length,
                                                                            self.unet_max_numpool,
                                                                            current_spacing)

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            print(new_shp)

        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))

        # check if batch size is too large
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[0]) > self.anisotropy_threshold

        # this is the only difference to the baseline planner: all conv kernels become 3 (e.g. 3x3x3)
        for s in range(len(conv_kernel_sizes)):
            conv_kernel_sizes[s] = [3 for _ in conv_kernel_sizes[s]]

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan

    def run_preprocessing(self, num_threads):
        # preprocessed data is identical to the baseline planner's, so nothing to do here
        pass


================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/pooling_and_convs/experiment_planner_baseline_3DUNet_poolBasedOnSpacing.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from copy import deepcopy

import numpy as np
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *


class ExperimentPlannerPoolBasedOnSpacing(ExperimentPlanner):
    """
    Baseline planner variant that uses get_pool_and_conv_props (spacing-driven pooling)
    instead of the pool-late scheme; see get_properties_for_stage docstring.
    """
    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlannerPoolBasedOnSpacing, self).__init__(folder_with_cropped_data,
                                                                  preprocessed_output_folder)
        self.data_identifier = "nnFormerData_poolBasedOnSpacing"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlans" + "poolBasedOnSpacing_plans_3D.pkl")

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        ExperimentPlanner configures pooling so that we pool late. Meaning that if the number of pooling per
        axis is (2, 3, 3), then the first pooling operation will always pool axes 1 and 2 and not 0, irrespective
        of spacing. This can cause a larger memory footprint, so it can be beneficial to revise this.

        Here we are pooling based on the spacing of the data.
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # the next line is what we had before as a default. The patch size had the same aspect ratio as the
        # median shape of a patient. We swapped to deriving it from an isotropic box (in mm) instead, see below.
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)

        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()

        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)

        # clip it to the median shape of the dataset because patches larger than that do not make much sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
                                                             self.unet_featuremap_min_edge_length,
                                                             self.unet_max_numpool)

        ref = Generic_UNet.use_this_for_batch_size_computation_3D
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        while here > ref:
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]

            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props(current_spacing, tmp,
                                        self.unet_featuremap_min_edge_length,
                                        self.unet_max_numpool,
                                        )
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
                                                                 self.unet_featuremap_min_edge_length,
                                                                 self.unet_max_numpool,
                                                                 )

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            print(new_shp)

        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))

        # check if batch size is too large
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan


================================================
FILE: unetr_pp/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_targetSpacingForAnisoAxis.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner from unetr_pp.paths import * class ExperimentPlannerTargetSpacingForAnisoAxis(ExperimentPlanner): def __init__(self, folder_with_cropped_data, preprocessed_output_folder): super().__init__(folder_with_cropped_data, preprocessed_output_folder) self.data_identifier = "nnFormerData_targetSpacingForAnisoAxis" self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlans" + "targetSpacingForAnisoAxis_plans_3D.pkl") def get_target_spacing(self): """ per default we use the 50th percentile=median for the target spacing. Higher spacing results in smaller data and thus faster and easier training. Smaller spacing results in larger data and thus longer and harder training For some datasets the median is not a good choice. Those are the datasets where the spacing is very anisotropic (for example ACDC with (10, 1.5, 1.5)). These datasets still have examples with a pacing of 5 or 6 mm in the low resolution axis. Choosing the median here will result in bad interpolation artifacts that can substantially impact performance (due to the low number of slices). """ spacings = self.dataset_properties['all_spacings'] sizes = self.dataset_properties['all_sizes'] target = np.percentile(np.vstack(spacings), self.target_spacing_percentile, 0) target_size = np.percentile(np.vstack(sizes), self.target_spacing_percentile, 0) target_size_mm = np.array(target) * np.array(target_size) # we need to identify datasets for which a different target spacing could be beneficial. 
These datasets have # the following properties: # - one axis which much lower resolution than the others # - the lowres axis has much less voxels than the others # - (the size in mm of the lowres axis is also reduced) worst_spacing_axis = np.argmax(target) other_axes = [i for i in range(len(target)) if i != worst_spacing_axis] other_spacings = [target[i] for i in other_axes] other_sizes = [target_size[i] for i in other_axes] has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * max(other_spacings)) has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < max(other_sizes) # we don't use the last one for now #median_size_in_mm = target[target_size_mm] * RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD < max(target_size_mm) if has_aniso_spacing and has_aniso_voxels: spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis] target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10) target[worst_spacing_axis] = target_spacing_of_that_axis return target ================================================ FILE: unetr_pp/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_v21_customTargetSpacing_2x2x2.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import numpy as np from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import ExperimentPlanner3D_v21 from unetr_pp.paths import * class ExperimentPlanner3D_v21_customTargetSpacing_2x2x2(ExperimentPlanner3D_v21): def __init__(self, folder_with_cropped_data, preprocessed_output_folder): super(ExperimentPlanner3D_v21, self).__init__(folder_with_cropped_data, preprocessed_output_folder) # we change the data identifier and plans_fname. This will make this experiment planner save the preprocessed # data in a different folder so that they can co-exist with the default (ExperimentPlanner3D_v21). We also # create a custom plans file that will be linked to this data self.data_identifier = "nnFormerData_plans_v2.1_trgSp_2x2x2" self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlansv2.1_trgSp_2x2x2_plans_3D.pkl") def get_target_spacing(self): # simply return the desired spacing as np.array return np.array([2., 2., 2.]) # make sure this is float!!!! Not int! ================================================ FILE: unetr_pp/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_v21_noResampling.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import numpy as np
from unetr_pp.experiment_planning.alternative_experiment_planning.experiment_planner_baseline_3DUNet_v21_16GB import \
    ExperimentPlanner3D_v21_16GB
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
    ExperimentPlanner3D_v21
from unetr_pp.paths import *


class ExperimentPlanner3D_v21_noResampling(ExperimentPlanner3D_v21):
    """
    Planner variant that keeps images at their original spacing (no resampling).
    It selects the PreprocessorFor3D_NoResampling preprocessor and writes its plans
    and preprocessed data under dedicated identifiers so they can coexist with the
    default v2.1 planner output.
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner3D_v21_noResampling, self).__init__(folder_with_cropped_data,
                                                                   preprocessed_output_folder)
        # dedicated identifiers so output does not clash with the default planner
        self.data_identifier = "nnFormerData_noRes_plans_v2.1"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlansv2.1_noRes_plans_3D.pkl")
        self.preprocessor_name = "PreprocessorFor3D_NoResampling"

    def plan_experiment(self):
        """
        DIFFERENCE TO ExperimentPlanner3D_v21: no 3d lowres
        :return:
        """
        use_nonzero_mask_for_normalization = self.determine_whether_to_use_mask_for_norm()
        print("Are we using the nonzero mask for normalizaion?", use_nonzero_mask_for_normalization)
        spacings = self.dataset_properties['all_spacings']
        sizes = self.dataset_properties['all_sizes']

        all_classes = self.dataset_properties['all_classes']
        modalities = self.dataset_properties['modalities']
        num_modalities = len(list(modalities.keys()))

        target_spacing = self.get_target_spacing()
        # shape each case would have after (hypothetical) resampling to the target spacing
        new_shapes = [np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)]

        # transpose convention: axis with the largest target spacing goes first
        max_spacing_axis = np.argmax(target_spacing)
        remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis]
        self.transpose_forward = [max_spacing_axis] + remaining_axes
        self.transpose_backward = [np.argwhere(np.array(self.transpose_forward) == i)[0][0] for i in range(3)]

        # we base our calculations on the median shape of the datasets
        median_shape = np.median(np.vstack(new_shapes), 0)
        print("the median shape of the dataset is ", median_shape)

        max_shape = np.max(np.vstack(new_shapes), 0)
        print("the max shape in the dataset is ", max_shape)
        min_shape = np.min(np.vstack(new_shapes), 0)
        print("the min shape in the dataset is ", min_shape)

        print("we don't want feature maps smaller than ", self.unet_featuremap_min_edge_length,
              " in the bottleneck")

        # how many stages will the image pyramid have?
        self.plans_per_stage = list()

        target_spacing_transposed = np.array(target_spacing)[self.transpose_forward]
        median_shape_transposed = np.array(median_shape)[self.transpose_forward]
        print("the transposed median shape of the dataset is ", median_shape_transposed)

        print("generating configuration for 3d_fullres")
        self.plans_per_stage.append(self.get_properties_for_stage(target_spacing_transposed,
                                                                  target_spacing_transposed,
                                                                  median_shape_transposed,
                                                                  len(self.list_of_cropped_npz_files),
                                                                  num_modalities, len(all_classes) + 1))

        # thanks Zakiyi (https://github.com/MIC-DKFZ/nnFormer/issues/61) for spotting this bug :-)
        # if np.prod(self.plans_per_stage[-1]['median_patient_size_in_voxels'], dtype=np.int64) / \
        #         architecture_input_voxels < HOW_MUCH_OF_A_PATIENT_MUST_THE_NETWORK_SEE_AT_STAGE0:
        architecture_input_voxels_here = np.prod(self.plans_per_stage[-1]['patch_size'], dtype=np.int64)
        if np.prod(self.plans_per_stage[-1]['median_patient_size_in_voxels'], dtype=np.int64) / \
                architecture_input_voxels_here < self.how_much_of_a_patient_must_the_network_see_at_stage0:
            more = False
        else:
            more = True

        # NOTE: 'more' is a leftover from the removed 3d_lowres stage; this branch is
        # intentionally a no-op in the noResampling planner.
        if more:
            pass

        self.plans_per_stage = self.plans_per_stage[::-1]
        self.plans_per_stage = {i: self.plans_per_stage[i] for i in range(len(self.plans_per_stage))}  # convert to dict

        print(self.plans_per_stage)
        print("transpose forward", self.transpose_forward)
        print("transpose backward", self.transpose_backward)

        normalization_schemes = self.determine_normalization_scheme()
        only_keep_largest_connected_component, min_size_per_class, min_region_size_per_class = None, None, None
        # removed training data based postprocessing. This is deprecated

        # these are independent of the stage
        plans = {'num_stages': len(list(self.plans_per_stage.keys())),
                 'num_modalities': num_modalities,
                 'modalities': modalities,
                 'normalization_schemes': normalization_schemes,
                 'dataset_properties': self.dataset_properties,
                 'list_of_npz_files': self.list_of_cropped_npz_files,
                 'original_spacings': spacings,
                 'original_sizes': sizes,
                 'preprocessed_data_folder': self.preprocessed_output_folder,
                 'num_classes': len(all_classes),
                 'all_classes': all_classes,
                 'base_num_features': self.unet_base_num_features,
                 'use_mask_for_norm': use_nonzero_mask_for_normalization,
                 'keep_only_largest_region': only_keep_largest_connected_component,
                 'min_region_size_per_class': min_region_size_per_class,
                 'min_size_per_class': min_size_per_class,
                 'transpose_forward': self.transpose_forward,
                 'transpose_backward': self.transpose_backward,
                 'data_identifier': self.data_identifier,
                 'plans_per_stage': self.plans_per_stage,
                 'preprocessor_name': self.preprocessor_name,
                 'conv_per_stage': self.conv_per_stage,
                 }

        self.plans = plans
        self.save_my_plans()


class ExperimentPlanner3D_v21_noResampling_16GB(ExperimentPlanner3D_v21_16GB):
    """
    Same no-resampling behavior as ExperimentPlanner3D_v21_noResampling, but based on
    the 16GB-GPU planner. NOTE(review): plan_experiment below is a verbatim duplicate of
    the one above (only the class hierarchy differs); keep the two in sync when editing.
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner3D_v21_noResampling_16GB, self).__init__(folder_with_cropped_data,
                                                                        preprocessed_output_folder)
        # dedicated identifiers so output does not clash with the default 16GB planner
        self.data_identifier = "nnFormerData_noRes_plans_16GB_v2.1"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlansv2.1_noRes_16GB_plans_3D.pkl")
        self.preprocessor_name = "PreprocessorFor3D_NoResampling"

    def plan_experiment(self):
        """
        DIFFERENCE TO ExperimentPlanner3D_v21: no 3d lowres
        :return:
        """
        use_nonzero_mask_for_normalization = self.determine_whether_to_use_mask_for_norm()
        print("Are we using the nonzero mask for normalizaion?", use_nonzero_mask_for_normalization)
        spacings = self.dataset_properties['all_spacings']
        sizes = self.dataset_properties['all_sizes']

        all_classes = self.dataset_properties['all_classes']
        modalities = self.dataset_properties['modalities']
        num_modalities = len(list(modalities.keys()))

        target_spacing = self.get_target_spacing()
        # shape each case would have after (hypothetical) resampling to the target spacing
        new_shapes = [np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)]

        # transpose convention: axis with the largest target spacing goes first
        max_spacing_axis = np.argmax(target_spacing)
        remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis]
        self.transpose_forward = [max_spacing_axis] + remaining_axes
        self.transpose_backward = [np.argwhere(np.array(self.transpose_forward) == i)[0][0] for i in range(3)]

        # we base our calculations on the median shape of the datasets
        median_shape = np.median(np.vstack(new_shapes), 0)
        print("the median shape of the dataset is ", median_shape)

        max_shape = np.max(np.vstack(new_shapes), 0)
        print("the max shape in the dataset is ", max_shape)
        min_shape = np.min(np.vstack(new_shapes), 0)
        print("the min shape in the dataset is ", min_shape)

        print("we don't want feature maps smaller than ", self.unet_featuremap_min_edge_length,
              " in the bottleneck")

        # how many stages will the image pyramid have?
        self.plans_per_stage = list()

        target_spacing_transposed = np.array(target_spacing)[self.transpose_forward]
        median_shape_transposed = np.array(median_shape)[self.transpose_forward]
        print("the transposed median shape of the dataset is ", median_shape_transposed)

        print("generating configuration for 3d_fullres")
        self.plans_per_stage.append(self.get_properties_for_stage(target_spacing_transposed,
                                                                  target_spacing_transposed,
                                                                  median_shape_transposed,
                                                                  len(self.list_of_cropped_npz_files),
                                                                  num_modalities, len(all_classes) + 1))

        # thanks Zakiyi (https://github.com/MIC-DKFZ/nnFormer/issues/61) for spotting this bug :-)
        # if np.prod(self.plans_per_stage[-1]['median_patient_size_in_voxels'], dtype=np.int64) / \
        #         architecture_input_voxels < HOW_MUCH_OF_A_PATIENT_MUST_THE_NETWORK_SEE_AT_STAGE0:
        architecture_input_voxels_here = np.prod(self.plans_per_stage[-1]['patch_size'], dtype=np.int64)
        if np.prod(self.plans_per_stage[-1]['median_patient_size_in_voxels'], dtype=np.int64) / \
                architecture_input_voxels_here < self.how_much_of_a_patient_must_the_network_see_at_stage0:
            more = False
        else:
            more = True

        # NOTE: 'more' is a leftover from the removed 3d_lowres stage; this branch is
        # intentionally a no-op in the noResampling planner.
        if more:
            pass

        self.plans_per_stage = self.plans_per_stage[::-1]
        self.plans_per_stage = {i: self.plans_per_stage[i] for i in range(len(self.plans_per_stage))}  # convert to dict

        print(self.plans_per_stage)
        print("transpose forward", self.transpose_forward)
        print("transpose backward", self.transpose_backward)

        normalization_schemes = self.determine_normalization_scheme()
        only_keep_largest_connected_component, min_size_per_class, min_region_size_per_class = None, None, None
        # removed training data based postprocessing. This is deprecated

        # these are independent of the stage
        plans = {'num_stages': len(list(self.plans_per_stage.keys())),
                 'num_modalities': num_modalities,
                 'modalities': modalities,
                 'normalization_schemes': normalization_schemes,
                 'dataset_properties': self.dataset_properties,
                 'list_of_npz_files': self.list_of_cropped_npz_files,
                 'original_spacings': spacings,
                 'original_sizes': sizes,
                 'preprocessed_data_folder': self.preprocessed_output_folder,
                 'num_classes': len(all_classes),
                 'all_classes': all_classes,
                 'base_num_features': self.unet_base_num_features,
                 'use_mask_for_norm': use_nonzero_mask_for_normalization,
                 'keep_only_largest_region': only_keep_largest_connected_component,
                 'min_region_size_per_class': min_region_size_per_class,
                 'min_size_per_class': min_size_per_class,
                 'transpose_forward': self.transpose_forward,
                 'transpose_backward': self.transpose_backward,
                 'data_identifier': self.data_identifier,
                 'plans_per_stage': self.plans_per_stage,
                 'preprocessor_name': self.preprocessor_name,
                 'conv_per_stage': self.conv_per_stage,
                 }

        self.plans = plans
        self.save_my_plans()


================================================
FILE: unetr_pp/experiment_planning/change_batch_size.py
================================================
from batchgenerators.utilities.file_and_folder_operations import *
import numpy as np

# One-off maintenance script: overwrite the stage-0 batch size in an existing plans file.
# NOTE(review): input_file and output_file are identical hard-coded absolute paths, so the
# plans file is modified in place on the author's machine; adapt before reuse.
if __name__ == '__main__':
    input_file = '/home/xychen/new_transformer/nnFormerFrame/DATASET/nnFormer_preprocessed/Task008_Verse1/nnFormerPlansv2.1_plans_3D.pkl'
    output_file = '/home/xychen/new_transformer/nnFormerFrame/DATASET/nnFormer_preprocessed/Task008_Verse1/nnFormerPlansv2.1_plans_3D.pkl'
    a = load_pickle(input_file)
    #a['plans_per_stage'][0]['batch_size'] = int(np.floor(6 / 9 * a['plans_per_stage'][0]['batch_size']))
    a['plans_per_stage'][0]['batch_size'] = 4
    save_pickle(a, output_file)


================================================
FILE: unetr_pp/experiment_planning/common_utils.py
================================================
# Copyright 2020 Division of Medical
# Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
from copy import deepcopy
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
import SimpleITK as sitk
import shutil
from batchgenerators.utilities.file_and_folder_operations import join


def split_4d_nifti(filename, output_folder):
    """
    Split a 4D NIfTI into one 3D volume per first-axis entry, written to output_folder
    with suffixes _0000, _0001, ... A 3D input is simply copied as _0000. Raises
    RuntimeError for any other dimensionality.
    """
    img_itk = sitk.ReadImage(filename)
    dim = img_itk.GetDimension()
    file_base = filename.split("/")[-1]
    if dim == 3:
        # already 3D: copy as the first (only) modality/timepoint
        shutil.copy(filename, join(output_folder, file_base[:-7] + "_0000.nii.gz"))
        return
    elif dim != 4:
        raise RuntimeError("Unexpected dimensionality: %d of file %s, cannot split" % (dim, filename))
    else:
        img_npy = sitk.GetArrayFromImage(img_itk)
        spacing = img_itk.GetSpacing()
        origin = img_itk.GetOrigin()
        direction = np.array(img_itk.GetDirection()).reshape(4,4)
        # now modify these to remove the fourth dimension
        spacing = tuple(list(spacing[:-1]))
        origin = tuple(list(origin[:-1]))
        direction = tuple(direction[:-1, :-1].reshape(-1))
        for i, t in enumerate(range(img_npy.shape[0])):
            img = img_npy[t]
            img_itk_new = sitk.GetImageFromArray(img)
            img_itk_new.SetSpacing(spacing)
            img_itk_new.SetOrigin(origin)
            img_itk_new.SetDirection(direction)
            sitk.WriteImage(img_itk_new, join(output_folder, file_base[:-7] + "_%04.0d.nii.gz" % i))


def get_pool_and_conv_props_poolLateV2(patch_size, min_feature_map_size, max_numpool, spacing):
    """
    Compute per-axis pooling counts, pool/conv kernel sizes and the padded patch size.
    Variant that delays pooling of coarse axes until their spacing is within a factor of
    ~2 of the coarsest initial spacing ('reach').

    :param spacing:
    :param patch_size:
    :param min_feature_map_size: min edge length of feature maps in bottleneck
    :return: num_pool_per_axis, pool kernel sizes per stage, conv kernel sizes per stage,
             padded patch_size, must_be_divisible_by
    """
    initial_spacing = deepcopy(spacing)
    reach = max(initial_spacing)
    dim = len(patch_size)

    num_pool_per_axis = get_network_numpool(patch_size, max_numpool, min_feature_map_size)

    net_num_pool_op_kernel_sizes = []
    net_conv_kernel_sizes = []
    net_numpool = max(num_pool_per_axis)

    current_spacing = spacing
    for p in range(net_numpool):
        # an axis is 'reached' once its current spacing is within ~2x of the coarsest axis
        reached = [current_spacing[i] / reach > 0.5 for i in range(dim)]
        # pool an axis only in the last num_pool_per_axis[i] stages (pool late)
        pool = [2 if num_pool_per_axis[i] + p >= net_numpool else 1 for i in range(dim)]
        if all(reached):
            conv = [3] * dim
        else:
            # use 1-sized kernels along axes that are still too coarse
            conv = [3 if not reached[i] else 1 for i in range(dim)]
        net_num_pool_op_kernel_sizes.append(pool)
        net_conv_kernel_sizes.append(conv)
        current_spacing = [i * j for i, j in zip(current_spacing, pool)]

    net_conv_kernel_sizes.append([3] * dim)

    must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
    patch_size = pad_shape(patch_size, must_be_divisible_by)

    # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here
    return num_pool_per_axis, net_num_pool_op_kernel_sizes, net_conv_kernel_sizes, patch_size, must_be_divisible_by


def get_pool_and_conv_props(spacing, patch_size, min_feature_map_size, max_numpool):
    """
    Compute per-axis pooling counts, pool/conv kernel sizes and the padded patch size by
    greedily pooling all axes whose spacing is within a factor of 2 of the finest axis.

    :param spacing:
    :param patch_size:
    :param min_feature_map_size: min edge length of feature maps in bottleneck
    :return: num_pool_per_axis, pool kernel sizes per stage, conv kernel sizes per stage,
             padded patch_size, must_be_divisible_by
    """
    dim = len(spacing)

    current_spacing = deepcopy(list(spacing))
    current_size = deepcopy(list(patch_size))

    pool_op_kernel_sizes = []
    conv_kernel_sizes = []

    num_pool_per_axis = [0] * dim

    while True:
        # This is a problem because sometimes we have spacing 20, 50, 50 and we want to still keep pooling.
        # Here we would stop however. This is not what we want! Fixed in get_pool_and_conv_propsv2
        min_spacing = min(current_spacing)
        valid_axes_for_pool = [i for i in range(dim) if current_spacing[i] / min_spacing < 2]
        axes = []
        for a in range(dim):
            # largest group of axes whose spacings are mutually within a factor of 2
            my_spacing = current_spacing[a]
            partners = [i for i in range(dim) if
                        current_spacing[i] / my_spacing < 2 and my_spacing / current_spacing[i] < 2]
            if len(partners) > len(axes):
                axes = partners
        conv_kernel_size = [3 if i in axes else 1 for i in range(dim)]

        # exclude axes that we cannot pool further because of min_feature_map_size constraint
        #before = len(valid_axes_for_pool)
        valid_axes_for_pool = [i for i in valid_axes_for_pool if current_size[i] >= 2*min_feature_map_size]
        #after = len(valid_axes_for_pool)
        #if after == 1 and before > 1:
        #    break

        valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool]

        if len(valid_axes_for_pool) == 0:
            break

        #print(current_spacing, current_size)

        other_axes = [i for i in range(dim) if i not in valid_axes_for_pool]

        pool_kernel_sizes = [0] * dim
        for v in valid_axes_for_pool:
            pool_kernel_sizes[v] = 2
            num_pool_per_axis[v] += 1
            current_spacing[v] *= 2
            current_size[v] = np.ceil(current_size[v] / 2)
        for nv in other_axes:
            pool_kernel_sizes[nv] = 1

        pool_op_kernel_sizes.append(pool_kernel_sizes)
        conv_kernel_sizes.append(conv_kernel_size)
        #print(conv_kernel_sizes)

    must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
    patch_size = pad_shape(patch_size, must_be_divisible_by)

    # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here
    conv_kernel_sizes.append([3]*dim)
    return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by


def get_pool_and_conv_props_v2(spacing, patch_size, min_feature_map_size, max_numpool):
    """
    Like get_pool_and_conv_props, but keeps pooling even when the remaining axes have very
    different spacings, and lets kernel sizes grow to 3 per axis once that axis' spacing
    comes within a factor of 2 of the finest spacing (then they stay 3).

    :param spacing:
    :param patch_size:
    :param min_feature_map_size: min edge length of feature maps in bottleneck
    :return: num_pool_per_axis, pool kernel sizes per stage, conv kernel sizes per stage,
             padded patch_size, must_be_divisible_by
    """
    dim = len(spacing)

    current_spacing = deepcopy(list(spacing))
    current_size = deepcopy(list(patch_size))

    pool_op_kernel_sizes = []
    conv_kernel_sizes = []

    num_pool_per_axis = [0] * dim
    kernel_size = [1] * dim

    while True:
        # exclude axes that we cannot pool further because of min_feature_map_size constraint
        valid_axes_for_pool = [i for i in range(dim) if current_size[i] >= 2*min_feature_map_size]
        if len(valid_axes_for_pool) < 1:
            break

        spacings_of_axes = [current_spacing[i] for i in valid_axes_for_pool]

        # find axis that are within factor of 2 within smallest spacing
        min_spacing_of_valid = min(spacings_of_axes)
        valid_axes_for_pool = [i for i in valid_axes_for_pool if current_spacing[i] / min_spacing_of_valid < 2]

        # max_numpool constraint
        valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool]

        if len(valid_axes_for_pool) == 1:
            # a single poolable axis needs extra headroom (3x instead of 2x) to continue
            if current_size[valid_axes_for_pool[0]] >= 3 * min_feature_map_size:
                pass
            else:
                break
        if len(valid_axes_for_pool) < 1:
            break

        # now we need to find kernel sizes
        # kernel sizes are initialized to 1. They are successively set to 3 when their associated axis becomes within
        # factor 2 of min_spacing. Once they are 3 they remain 3
        for d in range(dim):
            if kernel_size[d] == 3:
                continue
            else:
                # NOTE(review): spacings_of_axes was built over valid_axes_for_pool only, yet is
                # indexed here with the full-dimension index d — this looks like it should be
                # current_spacing[d] (possible IndexError / wrong value when axes were filtered).
                # Behavior matches upstream nnFormer/nnU-Net; confirm there before changing.
                if spacings_of_axes[d] / min(current_spacing) < 2:
                    kernel_size[d] = 3

        other_axes = [i for i in range(dim) if i not in valid_axes_for_pool]

        pool_kernel_sizes = [0] * dim
        for v in valid_axes_for_pool:
            pool_kernel_sizes[v] = 2
            num_pool_per_axis[v] += 1
            current_spacing[v] *= 2
            current_size[v] = np.ceil(current_size[v] / 2)
        for nv in other_axes:
            pool_kernel_sizes[nv] = 1

        pool_op_kernel_sizes.append(pool_kernel_sizes)
        conv_kernel_sizes.append(deepcopy(kernel_size))
        #print(conv_kernel_sizes)

    must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
    patch_size = pad_shape(patch_size, must_be_divisible_by)

    # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here
    conv_kernel_sizes.append([3]*dim)
    return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by


def get_shape_must_be_divisible_by(net_numpool_per_axis):
    # each pooling halves an axis, so the input must be divisible by 2**num_pools per axis
    return 2 ** np.array(net_numpool_per_axis)


def pad_shape(shape, must_be_divisible_by):
    """
    pads shape so that it is divisibly by must_be_divisible_by
    :param shape:
    :param must_be_divisible_by:
    :return:
    """
    if not isinstance(must_be_divisible_by, (tuple, list, np.ndarray)):
        must_be_divisible_by = [must_be_divisible_by] * len(shape)
    else:
        assert len(must_be_divisible_by) == len(shape)

    new_shp = [shape[i] + must_be_divisible_by[i] - shape[i] % must_be_divisible_by[i] for i in range(len(shape))]

    # axes that were already divisible must not be padded by a full extra multiple
    for i in range(len(shape)):
        if shape[i] % must_be_divisible_by[i] == 0:
            new_shp[i] -= must_be_divisible_by[i]
    new_shp = np.array(new_shp).astype(int)
    return new_shp


def get_network_numpool(patch_size, maxpool_cap=999, min_feature_map_size=4):
    # number of poolings per axis: halve until the edge would fall below min_feature_map_size,
    # capped at maxpool_cap
    network_numpool_per_axis = np.floor([np.log(i / min_feature_map_size) / np.log(2) for i in patch_size]).astype(int)
    network_numpool_per_axis = [min(i, maxpool_cap) for i in network_numpool_per_axis]
    return network_numpool_per_axis


if __name__ == '__main__':
    # trying to fix https://github.com/MIC-DKFZ/nnFormer/issues/261
    median_shape = [24, 504, 512]
    spacing = [5.9999094, 0.50781202, 0.50781202]
    num_pool_per_axis, net_num_pool_op_kernel_sizes, net_conv_kernel_sizes, patch_size, must_be_divisible_by = get_pool_and_conv_props_poolLateV2(median_shape, min_feature_map_size=4, max_numpool=999, spacing=spacing)


================================================
FILE: unetr_pp/experiment_planning/experiment_planner_baseline_2DUNet.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil

import unetr_pp
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import load_pickle, subfiles
from multiprocessing.pool import Pool
from unetr_pp.configuration import default_num_threads
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
from unetr_pp.experiment_planning.utils import add_classes_in_slice_info
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *
from unetr_pp.preprocessing.preprocessing import PreprocessorFor2D
from unetr_pp.training.model_restore import recursive_find_python_class


class ExperimentPlanner2D(ExperimentPlanner):
    """
    2D variant of the baseline experiment planner: plans a single 2D stage using in-plane
    spacing/shape only (the first axis is treated as the slice axis and dropped from the
    patch-size computation). Uses PreprocessorFor2D.
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner2D, self).__init__(folder_with_cropped_data,
                                                  preprocessed_output_folder)
        self.data_identifier = default_data_identifier + "_2D"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlans" + "_plans_2D.pkl")

        # 2D networks use 30 base features (legacy V1 configuration) and a larger filter cap
        self.unet_base_num_features = 30
        self.unet_max_num_filters = 512
        self.unet_max_numpool = 999

        self.preprocessor_name = "PreprocessorFor2D"

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        Build the per-stage plan (batch size, patch size, pooling/conv kernels) for 2D
        training. Only the in-plane axes (index 1:) enter the patch-size computation.
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)

        dataset_num_voxels = np.prod(new_median_shape, dtype=np.int64) * num_cases
        input_patch_size = new_median_shape[1:]

        network_numpool, net_pool_kernel_sizes, net_conv_kernel_sizes, input_patch_size, \
        shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing[1:], input_patch_size,
                                                             self.unet_featuremap_min_edge_length,
                                                             self.unet_max_numpool)

        # scale the reference batch size by estimated VRAM usage of this configuration
        estimated_gpu_ram_consumption = Generic_UNet.compute_approx_vram_consumption(input_patch_size,
                                                                                     network_numpool,
                                                                                     self.unet_base_num_features,
                                                                                     self.unet_max_num_filters,
                                                                                     num_modalities, num_classes,
                                                                                     net_pool_kernel_sizes,
                                                                                     conv_per_stage=self.conv_per_stage)

        batch_size = int(np.floor(Generic_UNet.use_this_for_batch_size_computation_2D /
                                  estimated_gpu_ram_consumption * Generic_UNet.DEFAULT_BATCH_SIZE_2D))
        if batch_size < self.unet_min_batch_size:
            raise RuntimeError("This framework is not made to process patches this large. We will add patch-based "
                               "2D networks later. Sorry for the inconvenience")

        # check if batch size is too large (more than 5 % of dataset)
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        batch_size = max(1, min(batch_size, max_batch_size))

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_numpool,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'pool_op_kernel_sizes': net_pool_kernel_sizes,
            'conv_kernel_sizes': net_conv_kernel_sizes,
            'do_dummy_2D_data_aug': False
        }
        return plan

    def plan_experiment(self):
        """
        Create the full 2D plans dict (single stage) and save it via save_my_plans().
        """
        use_nonzero_mask_for_normalization = self.determine_whether_to_use_mask_for_norm()
        print("Are we using the nonzero maks for normalizaion?", use_nonzero_mask_for_normalization)

        spacings = self.dataset_properties['all_spacings']
        sizes = self.dataset_properties['all_sizes']
        all_classes = self.dataset_properties['all_classes']
        modalities = self.dataset_properties['modalities']
        num_modalities = len(list(modalities.keys()))

        target_spacing = self.get_target_spacing()
        # shape each case would have after resampling to the target spacing
        new_shapes = np.array([np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)])

        # transpose convention: axis with the largest target spacing (slice axis) goes first
        max_spacing_axis = np.argmax(target_spacing)
        remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis]
        self.transpose_forward = [max_spacing_axis] + remaining_axes
        self.transpose_backward = [np.argwhere(np.array(self.transpose_forward) == i)[0][0] for i in range(3)]

        # we base our calculations on the median shape of the datasets
        median_shape = np.median(np.vstack(new_shapes), 0)
        print("the median shape of the dataset is ", median_shape)

        max_shape = np.max(np.vstack(new_shapes), 0)
        print("the max shape in the dataset is ", max_shape)
        min_shape = np.min(np.vstack(new_shapes), 0)
        print("the min shape in the dataset is ", min_shape)

        print("we don't want feature maps smaller than ", self.unet_featuremap_min_edge_length,
              " in the bottleneck")

        # how many stages will the image pyramid have?
        self.plans_per_stage = []

        target_spacing_transposed = np.array(target_spacing)[self.transpose_forward]
        median_shape_transposed = np.array(median_shape)[self.transpose_forward]
        print("the transposed median shape of the dataset is ", median_shape_transposed)
        self.plans_per_stage.append(
            self.get_properties_for_stage(target_spacing_transposed, target_spacing_transposed,
                                          median_shape_transposed,
                                          num_cases=len(self.list_of_cropped_npz_files),
                                          num_modalities=num_modalities,
                                          num_classes=len(all_classes) + 1),
        )

        print(self.plans_per_stage)

        self.plans_per_stage = self.plans_per_stage[::-1]
        self.plans_per_stage = {i: self.plans_per_stage[i] for i in range(len(self.plans_per_stage))}  # convert to dict

        normalization_schemes = self.determine_normalization_scheme()
        # deprecated
        only_keep_largest_connected_component, min_size_per_class, min_region_size_per_class = None, None, None

        # these are independent of the stage
        plans = {'num_stages': len(list(self.plans_per_stage.keys())),
                 'num_modalities': num_modalities,
                 'modalities': modalities,
                 'normalization_schemes': normalization_schemes,
                 'dataset_properties': self.dataset_properties,
                 'list_of_npz_files': self.list_of_cropped_npz_files,
                 'original_spacings': spacings,
                 'original_sizes': sizes,
                 'preprocessed_data_folder': self.preprocessed_output_folder,
                 'num_classes': len(all_classes),
                 'all_classes': all_classes,
                 'base_num_features': self.unet_base_num_features,
                 'use_mask_for_norm': use_nonzero_mask_for_normalization,
                 'keep_only_largest_region': only_keep_largest_connected_component,
                 'min_region_size_per_class': min_region_size_per_class,
                 'min_size_per_class': min_size_per_class,
                 'transpose_forward': self.transpose_forward,
                 'transpose_backward': self.transpose_backward,
                 'data_identifier': self.data_identifier,
                 'plans_per_stage': self.plans_per_stage,
                 'preprocessor_name': self.preprocessor_name,
                 }

        self.plans = plans
        self.save_my_plans()


================================================
FILE: unetr_pp/experiment_planning/experiment_planner_baseline_2DUNet_v21.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy

from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
from unetr_pp.experiment_planning.experiment_planner_baseline_2DUNet import ExperimentPlanner2D
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *
import numpy as np


class ExperimentPlanner2D_v21(ExperimentPlanner2D):
    """v2.1 of the 2D experiment planner.

    Same overall planning logic as ExperimentPlanner2D, but with its own data
    identifier / plans filename and base_num_features raised from 30 to 32
    (multiples of 8 are faster under fp16/amp training).
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        super(ExperimentPlanner2D_v21, self).__init__(folder_with_cropped_data,
                                                      preprocessed_output_folder)
        # identifier used to name the preprocessed npz files for this plan version
        self.data_identifier = "nnFormerData_plans_v2.1_2D"
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlansv2.1_plans_2D.pkl")
        # 32 instead of 30: larger memory footprint is more than offset by fp16 training
        self.unet_base_num_features = 32

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """Configure one 2D stage: patch size, batch size, pooling and conv kernels.

        Starts from the in-plane median shape (axis 0 is treated as the
        out-of-plane axis and dropped), then shrinks the patch along the axis
        that is largest relative to the median shape until the estimated VRAM
        consumption fits the reference budget.

        :param current_spacing: target voxel spacing of this stage
        :param original_spacing: spacing the median shape was computed at
        :param original_shape: median shape of the dataset at original_spacing
        :param num_cases: number of training cases (for the batch-size cap)
        :param num_modalities: number of input channels
        :param num_classes: number of segmentation outputs
        :return: dict with batch_size, patch_size, pooling/conv kernel sizes, etc.
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape, dtype=np.int64) * num_cases

        # 2D planning: operate on the in-plane axes only
        input_patch_size = new_median_shape[1:]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing[1:], input_patch_size,
                                                             self.unet_featuremap_min_edge_length,
                                                             self.unet_max_numpool)

        # we pretend to use 30 feature maps here. This yields the same configuration as in V1. The larger
        # memory footprint of 32 vs 30 is more than offset by fp16 training (which we make the default).
        # Reason for 32 vs 30 feature maps is that 32 is faster in fp16 training (because multiple of 8)
        ref = Generic_UNet.use_this_for_batch_size_computation_2D * Generic_UNet.DEFAULT_BATCH_SIZE_2D / 2  # for batch size 2
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis, 30,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes, pool_op_kernel_sizes,
                                                            conv_per_stage=self.conv_per_stage)
        # shrink the patch one divisibility step at a time until the VRAM estimate fits
        while here > ref:
            # reduce the axis that is largest relative to the (in-plane) median shape
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape[1:])[-1]

            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props(current_spacing[1:], tmp,
                                        self.unet_featuremap_min_edge_length,
                                        self.unet_max_numpool)
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing[1:], new_shp,
                                                                 self.unet_featuremap_min_edge_length,
                                                                 self.unet_max_numpool)

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            # print(new_shp)

        # leftover VRAM budget is spent on a larger batch (in steps of 2)
        batch_size = int(np.floor(ref / here) * 2)
        input_patch_size = new_shp

        if batch_size < self.unet_min_batch_size:
            raise RuntimeError("This should not happen")

        # check if batch size is too large (one batch must not cover more than 5 % of the dataset)
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        batch_size = max(1, min(batch_size, max_batch_size))

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
            'do_dummy_2D_data_aug': False
        }
        return plan


================================================
FILE: unetr_pp/experiment_planning/experiment_planner_baseline_3DUNet.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from collections import OrderedDict
from copy import deepcopy

import unetr_pp
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.configuration import default_num_threads
from unetr_pp.experiment_planning.DatasetAnalyzer import DatasetAnalyzer
from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props_poolLateV2
from unetr_pp.experiment_planning.utils import create_lists_from_splitted_dataset
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.paths import *
from unetr_pp.preprocessing.cropping import get_case_identifier_from_npz
from unetr_pp.training.model_restore import recursive_find_python_class


class ExperimentPlanner(object):
    """Base 3D experiment planner.

    Reads the dataset fingerprint produced by DatasetAnalyzer (from
    dataset_properties.pkl of the cropped data), derives target spacing,
    transpose order, normalization scheme and per-stage network/patch/batch
    configurations, then writes everything into a plans pickle and (optionally)
    runs the preprocessing.
    """

    def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
        self.folder_with_cropped_data = folder_with_cropped_data
        self.preprocessed_output_folder = preprocessed_output_folder
        self.list_of_cropped_npz_files = subfiles(self.folder_with_cropped_data, True, None, ".npz", True)
        # name of the preprocessor class looked up dynamically in run_preprocessing
        self.preprocessor_name = "GenericPreprocessor"

        assert isfile(join(self.folder_with_cropped_data, "dataset_properties.pkl")), \
            "folder_with_cropped_data must contain dataset_properties.pkl"
        self.dataset_properties = load_pickle(join(self.folder_with_cropped_data, "dataset_properties.pkl"))

        self.plans_per_stage = OrderedDict()
        self.plans = OrderedDict()
        self.plans_fname = join(self.preprocessed_output_folder,
                                "nnFormerPlans" + "fixed_plans_3D.pkl")
        self.data_identifier = default_data_identifier

        # identity transpose by default; overwritten in plan_experiment based on spacing
        self.transpose_forward = [0, 1, 2]
        self.transpose_backward = [0, 1, 2]

        # network/planning hyperparameters
        self.unet_base_num_features = Generic_UNet.BASE_NUM_FEATURES_3D
        self.unet_max_num_filters = 320
        self.unet_max_numpool = 999
        self.unet_min_batch_size = 2
        self.unet_featuremap_min_edge_length = 4  # bottleneck feature maps may not get smaller than this

        self.target_spacing_percentile = 50
        self.anisotropy_threshold = 3
        self.how_much_of_a_patient_must_the_network_see_at_stage0 = 4  # 1/4 of a patient
        self.batch_size_covers_max_percent_of_dataset = 0.05  # all samples in the batch together cannot cover more
        # than 5% of the entire dataset
        self.conv_per_stage = 2

    def get_target_spacing(self):
        """Return the target spacing: the per-axis percentile (default: median)
        over all training-case spacings."""
        spacings = self.dataset_properties['all_spacings']

        # target = np.median(np.vstack(spacings), 0)
        # if target spacing is very anisotropic we may want to not downsample the axis with the worst spacing
        # uncomment after mystery task submission
        """worst_spacing_axis = np.argmax(target)
        if max(target) > (2.5 * min(target)):
            spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis]
            target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 5)
            target[worst_spacing_axis] = target_spacing_of_that_axis"""

        target = np.percentile(np.vstack(spacings), self.target_spacing_percentile, 0)
        return target

    def save_my_plans(self):
        # serialize the plans dict to self.plans_fname
        with open(self.plans_fname, 'wb') as f:
            pickle.dump(self.plans, f)

    def load_my_plans(self):
        # restore plans from disk and re-populate the attributes derived from them
        self.plans = load_pickle(self.plans_fname)

        self.plans_per_stage = self.plans['plans_per_stage']
        self.dataset_properties = self.plans['dataset_properties']

        self.transpose_forward = self.plans['transpose_forward']
        self.transpose_backward = self.plans['transpose_backward']

    def determine_postprocessing(self):
        # Intentionally a no-op; see the note in the string below.
        pass
        """
        Spoiler: This is unused, postprocessing was removed. Ignore it.
        :return:
        print("determining postprocessing...")

        props_per_patient = self.dataset_properties['segmentation_props_per_patient']

        all_region_keys = [i for k in props_per_patient.keys() for i in props_per_patient[k]['only_one_region'].keys()]
        all_region_keys = list(set(all_region_keys))

        only_keep_largest_connected_component = OrderedDict()

        for r in all_region_keys:
            all_results = [props_per_patient[k]['only_one_region'][r] for k in props_per_patient.keys()]
            only_keep_largest_connected_component[tuple(r)] = all(all_results)

        print("Postprocessing: only_keep_largest_connected_component", only_keep_largest_connected_component)

        all_classes = self.dataset_properties['all_classes']
        classes = [i for i in all_classes if i > 0]

        props_per_patient = self.dataset_properties['segmentation_props_per_patient']

        min_size_per_class = OrderedDict()
        for c in classes:
            all_num_voxels = []
            for k in props_per_patient.keys():
                all_num_voxels.append(props_per_patient[k]['volume_per_class'][c])
            if len(all_num_voxels) > 0:
                min_size_per_class[c] = np.percentile(all_num_voxels, 1) * MIN_SIZE_PER_CLASS_FACTOR
            else:
                min_size_per_class[c] = np.inf

        min_region_size_per_class = OrderedDict()
        for c in classes:
            region_sizes = [l for k in props_per_patient for l in props_per_patient[k]['region_volume_per_class'][c]]
            if len(region_sizes) > 0:
                min_region_size_per_class[c] = min(region_sizes)
                # we don't need that line but better safe than sorry, right?
                min_region_size_per_class[c] = min(min_region_size_per_class[c], min_size_per_class[c])
            else:
                min_region_size_per_class[c] = 0

        print("Postprocessing: min_size_per_class", min_size_per_class)
        print("Postprocessing: min_region_size_per_class", min_region_size_per_class)
        return only_keep_largest_connected_component, min_size_per_class, min_region_size_per_class
        """

    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
                                 num_modalities, num_classes):
        """
        Computation of input patch size starts out with the new median shape (in voxels) of a dataset. This is
        opposed to prior experiments where it was based on the median size in mm. The rationale behind this is that
        for some organ of interest the acquisition method will most likely be chosen such that the field of view and
        voxel resolution go hand in hand to show the doctor what they need to see. This assumption may be violated
        for some modalities with anisotropy (cine MRI) but we will have to live with that. In future experiments one
        could try to 1) base input patch size match aspect ratio of input size in mm (instead of voxels) and 2) try
        to enforce that we see the same 'distance' in all directions (try to maintain equal size in mm of patch).

        The patches created here attempt to keep the aspect ratio of the new_median_shape.

        :param current_spacing: target voxel spacing of this stage
        :param original_spacing: spacing the median shape was computed at
        :param original_shape: median shape of the dataset at original_spacing
        :param num_cases: number of training cases
        :param num_modalities: number of input channels
        :param num_classes: number of segmentation outputs
        :return: plan dict for this stage
        """
        new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
        dataset_num_voxels = np.prod(new_median_shape) * num_cases

        # the next line is what we had before as a default. The patch size had the same aspect ratio as the
        # median shape of a patient. We swapped that out.
        # input_patch_size = new_median_shape

        # compute how many voxels are one mm
        input_patch_size = 1 / np.array(current_spacing)

        # normalize voxels per mm
        input_patch_size /= input_patch_size.mean()

        # create an isotropic patch of size 512x512x512mm
        input_patch_size *= 1 / min(input_patch_size) * 512  # to get a starting value
        input_patch_size = np.round(input_patch_size).astype(int)

        # clip it to the median shape of the dataset because patches larger than that make not much sense
        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]

        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
        shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(input_patch_size,
                                                                        self.unet_featuremap_min_edge_length,
                                                                        self.unet_max_numpool,
                                                                        current_spacing)

        ref = Generic_UNet.use_this_for_batch_size_computation_3D
        here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                            self.unet_base_num_features,
                                                            self.unet_max_num_filters, num_modalities,
                                                            num_classes,
                                                            pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
        # shrink the patch one divisibility step at a time until the VRAM estimate fits the budget
        while here > ref:
            # reduce the axis that is largest relative to the median shape
            axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]

            tmp = deepcopy(new_shp)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by_new = \
                get_pool_and_conv_props_poolLateV2(tmp,
                                                   self.unet_featuremap_min_edge_length,
                                                   self.unet_max_numpool,
                                                   current_spacing)
            new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]

            # we have to recompute numpool now:
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
            shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(new_shp,
                                                                            self.unet_featuremap_min_edge_length,
                                                                            self.unet_max_numpool,
                                                                            current_spacing)

            here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
                                                                self.unet_base_num_features,
                                                                self.unet_max_num_filters, num_modalities,
                                                                num_classes, pool_op_kernel_sizes,
                                                                conv_per_stage=self.conv_per_stage)
            # print(new_shp)

        input_patch_size = new_shp

        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
        batch_size = int(np.floor(max(ref / here, 1) * batch_size))

        # check if batch size is too large
        max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
                                  np.prod(input_patch_size, dtype=np.int64)).astype(int)
        max_batch_size = max(max_batch_size, self.unet_min_batch_size)
        batch_size = max(1, min(batch_size, max_batch_size))

        # enable 2D augmentation when the patch is strongly anisotropic along axis 0
        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
            0]) > self.anisotropy_threshold

        plan = {
            'batch_size': batch_size,
            'num_pool_per_axis': network_num_pool_per_axis,
            'patch_size': input_patch_size,
            'median_patient_size_in_voxels': new_median_shape,
            'current_spacing': current_spacing,
            'original_spacing': original_spacing,
            'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
            'pool_op_kernel_sizes': pool_op_kernel_sizes,
            'conv_kernel_sizes': conv_kernel_sizes,
        }
        return plan

    def plan_experiment(self):
        """Derive the full plans dict (transpose order, per-stage configs,
        normalization, bookkeeping) and write it to disk."""
        use_nonzero_mask_for_normalization = self.determine_whether_to_use_mask_for_norm()
        print("Are we using the nonzero mask for normalizaion?", use_nonzero_mask_for_normalization)
        spacings = self.dataset_properties['all_spacings']
        sizes = self.dataset_properties['all_sizes']

        all_classes = self.dataset_properties['all_classes']
        modalities = self.dataset_properties['modalities']
        num_modalities = len(list(modalities.keys()))

        target_spacing = self.get_target_spacing()
        new_shapes = [np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)]

        # the axis with the worst (largest) spacing goes first after transposing
        max_spacing_axis = np.argmax(target_spacing)
        remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis]
        self.transpose_forward = [max_spacing_axis] + remaining_axes
        self.transpose_backward = [np.argwhere(np.array(self.transpose_forward) == i)[0][0] for i in range(3)]

        # we base our calculations on the median shape of the datasets
        median_shape = np.median(np.vstack(new_shapes), 0)
        print("the median shape of the dataset is ", median_shape)

        max_shape = np.max(np.vstack(new_shapes), 0)
        print("the max shape in the dataset is ", max_shape)
        min_shape = np.min(np.vstack(new_shapes), 0)
        print("the min shape in the dataset is ", min_shape)

        print("we don't want feature maps smaller than ", self.unet_featuremap_min_edge_length, " in the bottleneck")

        # how many stages will the image pyramid have?
        self.plans_per_stage = list()

        target_spacing_transposed = np.array(target_spacing)[self.transpose_forward]
        median_shape_transposed = np.array(median_shape)[self.transpose_forward]
        print("the transposed median shape of the dataset is ", median_shape_transposed)

        print("generating configuration for 3d_fullres")
        self.plans_per_stage.append(self.get_properties_for_stage(target_spacing_transposed,
                                                                  target_spacing_transposed,
                                                                  median_shape_transposed,
                                                                  len(self.list_of_cropped_npz_files),
                                                                  num_modalities, len(all_classes) + 1))

        # thanks Zakiyi (https://github.com/MIC-DKFZ/nnFormer/issues/61) for spotting this bug :-)
        # if np.prod(self.plans_per_stage[-1]['median_patient_size_in_voxels'], dtype=np.int64) / \
        #        architecture_input_voxels < HOW_MUCH_OF_A_PATIENT_MUST_THE_NETWORK_SEE_AT_STAGE0:
        architecture_input_voxels_here = np.prod(self.plans_per_stage[-1]['patch_size'], dtype=np.int64)
        if np.prod(median_shape) / architecture_input_voxels_here < \
                self.how_much_of_a_patient_must_the_network_see_at_stage0:
            more = False
        else:
            more = True

        if more:
            print("generating configuration for 3d_lowres")
            # if we are doing more than one stage then we want the lowest stage to have exactly
            # HOW_MUCH_OF_A_PATIENT_MUST_THE_NETWORK_SEE_AT_STAGE0 (this is 4 by default so the number of voxels in the
            # median shape of the lowest stage must be 4 times as much as the network can process at once (128x128x128 by
            # default). Problem is that we are downsampling higher resolution axes before we start downsampling the
            # out-of-plane axis. We could probably/maybe do this analytically but I am lazy, so here
            # we do it the dumb way
            lowres_stage_spacing = deepcopy(target_spacing)
            num_voxels = np.prod(median_shape, dtype=np.float64)
            # coarsen the spacing (preferring already-coarse axes) until the median
            # volume is small enough relative to one network input
            while num_voxels > self.how_much_of_a_patient_must_the_network_see_at_stage0 * architecture_input_voxels_here:
                max_spacing = max(lowres_stage_spacing)
                if np.any((max_spacing / lowres_stage_spacing) > 2):
                    lowres_stage_spacing[(max_spacing / lowres_stage_spacing) > 2] \
                        *= 1.01
                else:
                    lowres_stage_spacing *= 1.01
                num_voxels = np.prod(target_spacing / lowres_stage_spacing * median_shape, dtype=np.float64)

                lowres_stage_spacing_transposed = np.array(lowres_stage_spacing)[self.transpose_forward]
                new = self.get_properties_for_stage(lowres_stage_spacing_transposed, target_spacing_transposed,
                                                    median_shape_transposed,
                                                    len(self.list_of_cropped_npz_files),
                                                    num_modalities, len(all_classes) + 1)
                architecture_input_voxels_here = np.prod(new['patch_size'], dtype=np.int64)
            # only keep the lowres stage if it actually halves the median volume
            if 2 * np.prod(new['median_patient_size_in_voxels'], dtype=np.int64) < np.prod(
                    self.plans_per_stage[0]['median_patient_size_in_voxels'], dtype=np.int64):
                self.plans_per_stage.append(new)

        self.plans_per_stage = self.plans_per_stage[::-1]
        self.plans_per_stage = {i: self.plans_per_stage[i] for i in range(len(self.plans_per_stage))}  # convert to dict
        print(self.plans_per_stage)
        print("transpose forward", self.transpose_forward)
        print("transpose backward", self.transpose_backward)

        normalization_schemes = self.determine_normalization_scheme()
        only_keep_largest_connected_component, min_size_per_class, min_region_size_per_class = None, None, None
        # removed training data based postprocessing. This is deprecated

        # these are independent of the stage
        plans = {'num_stages': len(list(self.plans_per_stage.keys())), 'num_modalities': num_modalities,
                 'modalities': modalities, 'normalization_schemes': normalization_schemes,
                 'dataset_properties': self.dataset_properties, 'list_of_npz_files': self.list_of_cropped_npz_files,
                 'original_spacings': spacings, 'original_sizes': sizes,
                 'preprocessed_data_folder': self.preprocessed_output_folder, 'num_classes': len(all_classes),
                 'all_classes': all_classes, 'base_num_features': self.unet_base_num_features,
                 'use_mask_for_norm': use_nonzero_mask_for_normalization,
                 'keep_only_largest_region': only_keep_largest_connected_component,
                 'min_region_size_per_class': min_region_size_per_class, 'min_size_per_class': min_size_per_class,
                 'transpose_forward': self.transpose_forward, 'transpose_backward': self.transpose_backward,
                 'data_identifier': self.data_identifier, 'plans_per_stage': self.plans_per_stage,
                 'preprocessor_name': self.preprocessor_name,
                 'conv_per_stage': self.conv_per_stage,
                 }

        self.plans = plans
        self.save_my_plans()

    def determine_normalization_scheme(self):
        """Map each modality index to "CT" or "nonCT" normalization."""
        schemes = OrderedDict()
        modalities = self.dataset_properties['modalities']
        num_modalities = len(list(modalities.keys()))

        for i in range(num_modalities):
            if modalities[i] == "CT" or modalities[i] == 'ct':
                schemes[i] = "CT"
            else:
                schemes[i] = "nonCT"
        return schemes

    def save_properties_of_cropped(self, case_identifier, properties):
        # persist the per-case properties pickle next to the cropped data
        with open(join(self.folder_with_cropped_data, "%s.pkl" % case_identifier), 'wb') as f:
            pickle.dump(properties, f)

    def load_properties_of_cropped(self, case_identifier):
        with open(join(self.folder_with_cropped_data, "%s.pkl" % case_identifier), 'rb') as f:
            properties = pickle.load(f)
        return properties

    def determine_whether_to_use_mask_for_norm(self):
        # only use the nonzero mask for normalization if the cropping based on it resulted in a decrease in
        # image size (this is an indication that the data is something like brats/isles and then we want to
        # normalize in the brain region only)
        modalities = self.dataset_properties['modalities']
        num_modalities = len(list(modalities.keys()))
        use_nonzero_mask_for_norm = OrderedDict()

        for i in range(num_modalities):
            if "CT" in modalities[i]:
                use_nonzero_mask_for_norm[i] = False
            else:
                all_size_reductions = []
                for k in self.dataset_properties['size_reductions'].keys():
                    all_size_reductions.append(self.dataset_properties['size_reductions'][k])

                if np.median(all_size_reductions) < 3 / 4.:
                    print("using nonzero mask for normalization")
                    use_nonzero_mask_for_norm[i] = True
                else:
                    print("not using nonzero mask for normalization")
                    use_nonzero_mask_for_norm[i] = False

        # write the decision into every per-case properties file
        for c in self.list_of_cropped_npz_files:
            case_identifier = get_case_identifier_from_npz(c)
            properties = self.load_properties_of_cropped(case_identifier)
            properties['use_nonzero_mask_for_norm'] = use_nonzero_mask_for_norm
            self.save_properties_of_cropped(case_identifier, properties)
        use_nonzero_mask_for_normalization = use_nonzero_mask_for_norm
        return use_nonzero_mask_for_normalization

    def write_normalization_scheme_to_patients(self):
        """
        This is used for test set preprocessing
        :return:
        """
        for c in self.list_of_cropped_npz_files:
            case_identifier = get_case_identifier_from_npz(c)
            properties = self.load_properties_of_cropped(case_identifier)
            properties['use_nonzero_mask_for_norm'] = self.plans['use_mask_for_norm']
            self.save_properties_of_cropped(case_identifier, properties)

    def run_preprocessing(self, num_threads):
        """Copy ground-truth segmentations and run the configured preprocessor
        for every planned stage.

        :param num_threads: int, or (lowres_threads, fullres_threads) tuple
        """
        if os.path.isdir(join(self.preprocessed_output_folder, "gt_segmentations")):
            shutil.rmtree(join(self.preprocessed_output_folder, "gt_segmentations"))
        shutil.copytree(join(self.folder_with_cropped_data, "gt_segmentations"),
                        join(self.preprocessed_output_folder, "gt_segmentations"))
        normalization_schemes = self.plans['normalization_schemes']
        use_nonzero_mask_for_normalization = self.plans['use_mask_for_norm']
        intensityproperties = self.plans['dataset_properties']['intensityproperties']
        # resolve the preprocessor class by name from unetr_pp.preprocessing
        preprocessor_class = recursive_find_python_class([join(unetr_pp.__path__[0], "preprocessing")],
                                                         self.preprocessor_name,
                                                         current_module="unetr_pp.preprocessing")
        assert preprocessor_class is not None
        preprocessor = preprocessor_class(normalization_schemes, use_nonzero_mask_for_normalization,
                                          self.transpose_forward, intensityproperties)
        target_spacings = [i["current_spacing"] for i in self.plans_per_stage.values()]
        # normalize num_threads to the shape the preprocessor expects for the stage count
        if self.plans['num_stages'] > 1 and not isinstance(num_threads, (list, tuple)):
            num_threads = (default_num_threads, num_threads)
        elif self.plans['num_stages'] == 1 and isinstance(num_threads, (list, tuple)):
            num_threads = num_threads[-1]
        preprocessor.run(target_spacings, self.folder_with_cropped_data, self.preprocessed_output_folder,
                         self.plans['data_identifier'], num_threads)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--task_ids", nargs="+", help="list of int")
    parser.add_argument("-p", action="store_true", help="set this if you actually want to run the preprocessing. If "
                                                        "this is not set then this script will only create the plans file")
    parser.add_argument("-tl", type=int, required=False, default=8, help="num_threads_lowres")
    parser.add_argument("-tf", type=int, required=False, default=8, help="num_threads_fullres")

    args = parser.parse_args()
    task_ids = args.task_ids
    run_preprocessing = args.p
    tl = args.tl
    tf = args.tf

    # resolve task ids (e.g. 2) to the unique TaskXXX_* folder name
    tasks = []
    for i in task_ids:
        i = int(i)
        candidates = subdirs(nnFormer_cropped_data, prefix="Task%03.0d" % i, join=False)
        assert len(candidates) == 1
        tasks.append(candidates[0])

    for t in tasks:
        try:
            print("\n\n\n", t)
            cropped_out_dir = os.path.join(nnFormer_cropped_data, t)
            preprocessing_output_dir_this_task = os.path.join(preprocessing_output_dir, t)
            splitted_4d_output_dir_task = os.path.join(nnFormer_raw_data, t)
            lists, modalities = create_lists_from_splitted_dataset(splitted_4d_output_dir_task)

            dataset_analyzer = DatasetAnalyzer(cropped_out_dir, overwrite=False)
            _ = dataset_analyzer.analyze_dataset()  # this will write output files that will be used by the ExperimentPlanner

            maybe_mkdir_p(preprocessing_output_dir_this_task)
            shutil.copy(join(cropped_out_dir, "dataset_properties.pkl"), preprocessing_output_dir_this_task)
            shutil.copy(join(nnFormer_raw_data, t, "dataset.json"), preprocessing_output_dir_this_task)

            threads = (tl, tf)

            print("number of threads: ", threads, "\n")

            exp_planner = ExperimentPlanner(cropped_out_dir, preprocessing_output_dir_this_task)
            exp_planner.plan_experiment()
            if run_preprocessing:
                exp_planner.run_preprocessing(threads)
        except Exception as e:
            # best-effort batch processing: report the failure and continue with the next task
            print(e)


================================================
FILE: unetr_pp/experiment_planning/experiment_planner_baseline_3DUNet_v21.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy import numpy as np from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner from unetr_pp.network_architecture.generic_UNet import Generic_UNet from unetr_pp.paths import * class ExperimentPlanner3D_v21(ExperimentPlanner): """ Combines ExperimentPlannerPoolBasedOnSpacing and ExperimentPlannerTargetSpacingForAnisoAxis We also increase the base_num_features to 32. This is solely because mixed precision training with 3D convs and amp is A LOT faster if the number of filters is divisible by 8 """ def __init__(self, folder_with_cropped_data, preprocessed_output_folder): super(ExperimentPlanner3D_v21, self).__init__(folder_with_cropped_data, preprocessed_output_folder) self.data_identifier = "unetr_pp_Data_plans_v2.1" self.plans_fname = join(self.preprocessed_output_folder, "unetr_pp_Plansv2.1_plans_3D.pkl") self.unet_base_num_features = 32 def get_target_spacing(self): """ per default we use the 50th percentile=median for the target spacing. Higher spacing results in smaller data and thus faster and easier training. Smaller spacing results in larger data and thus longer and harder training For some datasets the median is not a good choice. Those are the datasets where the spacing is very anisotropic (for example ACDC with (10, 1.5, 1.5)). These datasets still have examples with a spacing of 5 or 6 mm in the low resolution axis. 
Choosing the median here will result in bad interpolation artifacts that can substantially impact performance (due to the low number of slices). """ spacings = self.dataset_properties['all_spacings'] sizes = self.dataset_properties['all_sizes'] target = np.percentile(np.vstack(spacings), self.target_spacing_percentile, 0) # This should be used to determine the new median shape. The old implementation is not 100% correct. # Fixed in 2.4 # sizes = [np.array(i) / target * np.array(j) for i, j in zip(spacings, sizes)] target_size = np.percentile(np.vstack(sizes), self.target_spacing_percentile, 0) target_size_mm = np.array(target) * np.array(target_size) # we need to identify datasets for which a different target spacing could be beneficial. These datasets have # the following properties: # - one axis which much lower resolution than the others # - the lowres axis has much less voxels than the others # - (the size in mm of the lowres axis is also reduced) worst_spacing_axis = np.argmax(target) other_axes = [i for i in range(len(target)) if i != worst_spacing_axis] other_spacings = [target[i] for i in other_axes] other_sizes = [target_size[i] for i in other_axes] has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * max(other_spacings)) has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes) # we don't use the last one for now #median_size_in_mm = target[target_size_mm] * RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD < max(target_size_mm) if has_aniso_spacing and has_aniso_voxels: spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis] target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10) # don't let the spacing of that axis get higher than the other axes if target_spacing_of_that_axis < max(other_spacings): target_spacing_of_that_axis = max(max(other_spacings), target_spacing_of_that_axis) + 1e-5 target[worst_spacing_axis] = target_spacing_of_that_axis return target def 
get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases, num_modalities, num_classes): """ ExperimentPlanner configures pooling so that we pool late. Meaning that if the number of pooling per axis is (2, 3, 3), then the first pooling operation will always pool axes 1 and 2 and not 0, irrespective of spacing. This can cause a larger memory footprint, so it can be beneficial to revise this. Here we are pooling based on the spacing of the data. """ new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int) dataset_num_voxels = np.prod(new_median_shape) * num_cases # the next line is what we had before as a default. The patch size had the same aspect ratio as the median shape of a patient. We swapped t # input_patch_size = new_median_shape # compute how many voxels are one mm input_patch_size = 1 / np.array(current_spacing) # normalize voxels per mm input_patch_size /= input_patch_size.mean() # create an isotropic patch of size 512x512x512mm input_patch_size *= 1 / min(input_patch_size) * 512 # to get a starting value input_patch_size = np.round(input_patch_size).astype(int) # clip it to the median shape of the dataset because patches larger then that make not much sense input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)] network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size, self.unet_featuremap_min_edge_length, self.unet_max_numpool) # we compute as if we were using only 30 feature maps. We can do that because fp16 training is the standard # now. That frees up some space. 
The decision to go with 32 is solely due to the speedup we get (non-multiples # of 8 are not supported in nvidia amp) ref = Generic_UNet.use_this_for_batch_size_computation_3D * self.unet_base_num_features / \ Generic_UNet.BASE_NUM_FEATURES_3D here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis, self.unet_base_num_features, self.unet_max_num_filters, num_modalities, num_classes, pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage) while here > ref: axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1] tmp = deepcopy(new_shp) tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] _, _, _, _, shape_must_be_divisible_by_new = \ get_pool_and_conv_props(current_spacing, tmp, self.unet_featuremap_min_edge_length, self.unet_max_numpool, ) new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced] # we have to recompute numpool now: network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp, self.unet_featuremap_min_edge_length, self.unet_max_numpool, ) here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis, self.unet_base_num_features, self.unet_max_num_filters, num_modalities, num_classes, pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage) #print(new_shp) #print(here, ref) input_patch_size = new_shp batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D # This is what wirks with 128**3 batch_size = int(np.floor(max(ref / here, 1) * batch_size)) # check if batch size is too large max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels / np.prod(input_patch_size, dtype=np.int64)).astype(int) max_batch_size = max(max_batch_size, self.unet_min_batch_size) batch_size = max(1, min(batch_size, max_batch_size)) do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[ 0]) > self.anisotropy_threshold plan = { 
def crawl_and_remove_hidden_from_decathlon(folder):
    """Validate a Medical Segmentation Decathlon task folder and delete hidden files.

    Checks that ``folder`` is named ``TaskXX...`` and contains the ``imagesTr``,
    ``imagesTs`` and ``labelsTr`` subfolders, then removes hidden files (names
    starting with ``.``, e.g. ``.DS_Store`` / AppleDouble ``._*`` files) from the
    task folder and each of those subfolders.

    :param folder: path to a TaskXX_TASKNAME folder as downloaded from the MSD website
    :raises AssertionError: if the folder does not look like a decathlon task folder
    """
    folder = remove_trailing_slash(folder)
    # One shared message instead of the same string literal repeated four times.
    err = "This does not seem to be a decathlon folder. Please give me a " \
          "folder that starts with TaskXX and has the subfolders imagesTr, " \
          "labelsTr and imagesTs"
    assert folder.split('/')[-1].startswith("Task"), err
    subf = subfolders(folder, join=False)
    # same check order as before: imagesTr, imagesTs, labelsTr
    for required in ('imagesTr', 'imagesTs', 'labelsTr'):
        assert required in subf, err
    # Plain loops instead of throwaway list comprehensions built only for side effects.
    for subdir in (folder, join(folder, 'imagesTr'), join(folder, 'labelsTr'), join(folder, 'imagesTs')):
        for hidden_file in subfiles(subdir, prefix="."):
            os.remove(hidden_file)
def main():
    """CLI entry point: run experiment planning (and optionally preprocessing) for one or more tasks.

    For every requested task id this (in order): optionally verifies dataset integrity,
    crops the raw data, analyzes the cropped data to produce the dataset fingerprint,
    copies the fingerprint and dataset.json into the preprocessing output folder, and
    finally runs the configured 3D and/or 2D experiment planners (plus preprocessing
    unless -no_pp is given). The ordering of these steps matters: cropping must happen
    before analysis, and the fingerprint files must exist before the planners run.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--task_ids", nargs="+", help="List of integers belonging to the task ids you wish to run"
                                                           " experiment planning and preprocessing for. Each of these "
                                                           "ids must, have a matching folder 'TaskXXX_' in the raw "
                                                           "data folder")
    parser.add_argument("-pl3d", "--planner3d", type=str, default="ExperimentPlanner3D_v21",
                        help="Name of the ExperimentPlanner class for the full resolution 3D U-Net and U-Net cascade. "
                             "Default is ExperimentPlanner3D_v21. Can be 'None', in which case these U-Nets will not be "
                             "configured")
    parser.add_argument("-pl2d", "--planner2d", type=str, default="ExperimentPlanner2D_v21",
                        help="Name of the ExperimentPlanner class for the 2D U-Net. Default is ExperimentPlanner2D_v21. "
                             "Can be 'None', in which case this U-Net will not be configured")
    parser.add_argument("-no_pp", action="store_true",
                        help="Set this flag if you dont want to run the preprocessing. If this is set then this script "
                             "will only run the experiment planning and create the plans file")
    parser.add_argument("-tl", type=int, required=False, default=8,
                        help="Number of processes used for preprocessing the low resolution data for the 3D low "
                             "resolution U-Net. This can be larger than -tf. Don't overdo it or you will run out of "
                             "RAM")
    parser.add_argument("-tf", type=int, required=False, default=8,
                        help="Number of processes used for preprocessing the full resolution data of the 2D U-Net and "
                             "3D U-Net. Don't overdo it or you will run out of RAM")
    parser.add_argument("--verify_dataset_integrity", required=False, default=False, action="store_true",
                        help="set this flag to check the dataset integrity. This is useful and should be done once for "
                             "each dataset!")

    args = parser.parse_args()
    task_ids = args.task_ids
    dont_run_preprocessing = args.no_pp
    tl = args.tl
    tf = args.tf
    planner_name3d = args.planner3d
    planner_name2d = args.planner2d

    # the literal string "None" on the command line disables the corresponding planner
    if planner_name3d == "None":
        planner_name3d = None
    if planner_name2d == "None":
        planner_name2d = None

    # we need raw data: resolve task ids to task names and crop each task first
    tasks = []
    for i in task_ids:
        i = int(i)
        task_name = convert_id_to_task_name(i)

        if args.verify_dataset_integrity:
            verify_dataset_integrity(join(nnFormer_raw_data, task_name))

        crop(task_name, False, tf)

        tasks.append(task_name)

    # planner classes are located by name anywhere inside the experiment_planning package
    search_in = join(unetr_pp.__path__[0], "experiment_planning")

    if planner_name3d is not None:
        planner_3d = recursive_find_python_class([search_in], planner_name3d, current_module="unetr_pp.experiment_planning")
        if planner_3d is None:
            raise RuntimeError("Could not find the Planner class %s. Make sure it is located somewhere in "
                               "unetr_pp.experiment_planning" % planner_name3d)
    else:
        planner_3d = None

    if planner_name2d is not None:
        planner_2d = recursive_find_python_class([search_in], planner_name2d, current_module="unetr_pp.experiment_planning")
        if planner_2d is None:
            raise RuntimeError("Could not find the Planner class %s. Make sure it is located somewhere in "
                               "unetr_pp.experiment_planning" % planner_name2d)
    else:
        planner_2d = None

    for t in tasks:
        print("\n\n\n", t)
        cropped_out_dir = os.path.join(nnFormer_cropped_data, t)
        preprocessing_output_dir_this_task = os.path.join(preprocessing_output_dir, t)
        #splitted_4d_output_dir_task = os.path.join(nnFormer_raw_data, t)
        #lists, modalities = create_lists_from_splitted_dataset(splitted_4d_output_dir_task)

        # we need to figure out if we need the intensity properties. We collect them only if one of the
        # modalities is CT
        dataset_json = load_json(join(cropped_out_dir, 'dataset.json'))
        modalities = list(dataset_json["modality"].values())
        collect_intensityproperties = True if (("CT" in modalities) or ("ct" in modalities)) else False
        dataset_analyzer = DatasetAnalyzer(cropped_out_dir, overwrite=False, num_processes=tf)  # this class creates the fingerprint
        _ = dataset_analyzer.analyze_dataset(collect_intensityproperties)  # this will write output files that will be used by the ExperimentPlanner

        # the planners read the fingerprint + dataset.json from the preprocessing output folder
        maybe_mkdir_p(preprocessing_output_dir_this_task)
        shutil.copy(join(cropped_out_dir, "dataset_properties.pkl"), preprocessing_output_dir_this_task)
        shutil.copy(join(nnFormer_raw_data, t, "dataset.json"), preprocessing_output_dir_this_task)

        threads = (tl, tf)

        print("number of threads: ", threads, "\n")

        if planner_3d is not None:
            exp_planner = planner_3d(cropped_out_dir, preprocessing_output_dir_this_task)
            exp_planner.plan_experiment()
            if not dont_run_preprocessing:  # double negative, yooo
                exp_planner.run_preprocessing(threads)
        if planner_2d is not None:
            exp_planner = planner_2d(cropped_out_dir, preprocessing_output_dir_this_task)
            exp_planner.plan_experiment()
            if not dont_run_preprocessing:  # double negative, yooo
                exp_planner.run_preprocessing(threads)


if __name__ == "__main__":
    main()
def summarize_plans(file):
    """Print a human-readable summary of an nnFormer plans pickle file (development helper)."""
    plans = load_pickle(file)
    print("num_classes: ", plans['num_classes'])
    print("modalities: ", plans['modalities'])
    print("use_mask_for_norm", plans['use_mask_for_norm'])
    print("keep_only_largest_region", plans['keep_only_largest_region'])
    print("min_region_size_per_class", plans['min_region_size_per_class'])
    print("min_size_per_class", plans['min_size_per_class'])
    print("normalization_schemes", plans['normalization_schemes'])
    print("stages...\n")
    for i in range(len(plans['plans_per_stage'])):
        print("stage: ", i)
        print(plans['plans_per_stage'][i])
        print("")


def write_plans_to_file(f, plans_file):
    """Append one semicolon-separated CSV row per stage of a plans file to the open file object ``f``.

    :param f: writable text file-like object (rows are appended, no header is written)
    :param plans_file: path to a ``*_plans_*.pkl`` file; the last two path components
        (task folder name and file name) are written into the first two columns
    """
    print(plans_file)
    # stdlib pickle instead of batchgenerators' load_pickle (identical behavior)
    import pickle
    with open(plans_file, 'rb') as handle:
        a = pickle.load(handle)
    stages = list(a['plans_per_stage'].keys())
    stages.sort()
    for stage in stages:
        # BUGFIX: the original indexed a['plans_per_stage'][stages[stage]], i.e. it used the
        # stage *key* as a list index into the sorted key list. That only works when the keys
        # happen to be exactly 0..n-1 and raises IndexError otherwise. Indexing by the key
        # itself is equivalent in the 0..n-1 case and correct for any key set.
        stage_plan = a['plans_per_stage'][stage]
        patch_size_in_mm = [i * j for i, j in zip(stage_plan['patch_size'],
                                                  stage_plan['current_spacing'])]
        median_patient_size_in_mm = [i * j for i, j in zip(stage_plan['median_patient_size_in_voxels'],
                                                           stage_plan['current_spacing'])]
        f.write(plans_file.split("/")[-2])
        f.write(";%s" % plans_file.split("/")[-1])
        f.write(";%d" % stage)
        f.write(";%s" % str(stage_plan['batch_size']))
        f.write(";%s" % str(stage_plan['num_pool_per_axis']))
        f.write(";%s" % str(stage_plan['patch_size']))
        f.write(";%s" % str([str("%03.2f" % i) for i in patch_size_in_mm]))
        f.write(";%s" % str(stage_plan['median_patient_size_in_voxels']))
        f.write(";%s" % str([str("%03.2f" % i) for i in median_patient_size_in_mm]))
        f.write(";%s" % str([str("%03.2f" % i) for i in stage_plan['current_spacing']]))
        f.write(";%s" % str([str("%03.2f" % i) for i in stage_plan['original_spacing']]))
        f.write(";%s" % str(stage_plan['pool_op_kernel_sizes']))
        f.write(";%s" % str(stage_plan['conv_kernel_sizes']))
        f.write(";%s" % str(a['data_identifier']))
        f.write("\n")


if __name__ == "__main__":
    base_dir = './'  # preprocessing_output_dir
    task_dirs = [i for i in subdirs(base_dir, join=False, prefix="Task") if i.find("BrainTumor") == -1 and
                 i.find("MSSeg") == -1]
    print("found %d tasks" % len(task_dirs))
    with open("2019_02_06_plans_summary.csv", 'w') as f:
        f.write("task;plans_file;stage;batch_size;num_pool_per_axis;patch_size;patch_size(mm);median_patient_size_in_voxels;median_patient_size_in_mm;current_spacing;original_spacing;pool_op_kernel_sizes;conv_kernel_sizes\n")
        for t in task_dirs:
            print(t)
            tmp = join(base_dir, t)
            plans_files = [i for i in subfiles(tmp, suffix=".pkl", join=False) if i.find("_plans_") != -1 and
                           i.find("Dgx2") == -1]
            for p in plans_files:
                write_plans_to_file(f, join(tmp, p))
            f.write("\n")
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import os import pickle import shutil from collections import OrderedDict from multiprocessing import Pool import numpy as np from batchgenerators.utilities.file_and_folder_operations import join, isdir, maybe_mkdir_p, subfiles, subdirs, isfile from unetr_pp.configuration import default_num_threads from unetr_pp.experiment_planning.DatasetAnalyzer import DatasetAnalyzer from unetr_pp.experiment_planning.common_utils import split_4d_nifti from unetr_pp.paths import nnFormer_raw_data, nnFormer_cropped_data, preprocessing_output_dir from unetr_pp.preprocessing.cropping import ImageCropper def split_4d(input_folder, num_processes=default_num_threads, overwrite_task_output_id=None): assert isdir(join(input_folder, "imagesTr")) and isdir(join(input_folder, "labelsTr")) and \ isfile(join(input_folder, "dataset.json")), \ "The input folder must be a valid Task folder from the Medical Segmentation Decathlon with at least the " \ "imagesTr and labelsTr subfolders and the dataset.json file" while input_folder.endswith("/"): input_folder = input_folder[:-1] full_task_name = input_folder.split("/")[-1] assert full_task_name.startswith("Task"), "The input folder must point to a folder that starts with TaskXX_" first_underscore = full_task_name.find("_") assert first_underscore == 6, "Input folder start with TaskXX with XX being a 3-digit id: 00, 01, 02 etc" input_task_id = int(full_task_name[4:6]) if overwrite_task_output_id is None: overwrite_task_output_id = input_task_id task_name = full_task_name[7:] output_folder = join(nnFormer_raw_data, "Task%03.0d_" % 
overwrite_task_output_id + task_name) if isdir(output_folder): shutil.rmtree(output_folder) files = [] output_dirs = [] maybe_mkdir_p(output_folder) for subdir in ["imagesTr", "imagesTs"]: curr_out_dir = join(output_folder, subdir) if not isdir(curr_out_dir): os.mkdir(curr_out_dir) curr_dir = join(input_folder, subdir) nii_files = [join(curr_dir, i) for i in os.listdir(curr_dir) if i.endswith(".nii.gz")] nii_files.sort() for n in nii_files: files.append(n) output_dirs.append(curr_out_dir) shutil.copytree(join(input_folder, "labelsTr"), join(output_folder, "labelsTr")) p = Pool(num_processes) p.starmap(split_4d_nifti, zip(files, output_dirs)) p.close() p.join() shutil.copy(join(input_folder, "dataset.json"), output_folder) def create_lists_from_splitted_dataset(base_folder_splitted): lists = [] json_file = join(base_folder_splitted, "dataset.json") with open(json_file) as jsn: d = json.load(jsn) training_files = d['training'] num_modalities = len(d['modality'].keys()) for tr in training_files: cur_pat = [] for mod in range(num_modalities): cur_pat.append(join(base_folder_splitted, "imagesTr", tr['image'].split("/")[-1][:-7] + "_%04.0d.nii.gz" % mod)) cur_pat.append(join(base_folder_splitted, "labelsTr", tr['label'].split("/")[-1])) lists.append(cur_pat) return lists, {int(i): d['modality'][str(i)] for i in d['modality'].keys()} def create_lists_from_splitted_dataset_folder(folder): """ does not rely on dataset.json :param folder: :return: """ caseIDs = get_caseIDs_from_splitted_dataset_folder(folder) list_of_lists = [] for f in caseIDs: list_of_lists.append(subfiles(folder, prefix=f, suffix=".nii.gz", join=True, sort=True)) return list_of_lists def get_caseIDs_from_splitted_dataset_folder(folder): files = subfiles(folder, suffix=".nii.gz", join=False) # all files must be .nii.gz and have 4 digit modality index files = [i[:-12] for i in files] # only unique patient ids files = np.unique(files) return files def crop(task_string, override=False, 
num_threads=default_num_threads): cropped_out_dir = join(nnFormer_cropped_data, task_string) maybe_mkdir_p(cropped_out_dir) if override and isdir(cropped_out_dir): shutil.rmtree(cropped_out_dir) maybe_mkdir_p(cropped_out_dir) splitted_4d_output_dir_task = join(nnFormer_raw_data, task_string) lists, _ = create_lists_from_splitted_dataset(splitted_4d_output_dir_task) imgcrop = ImageCropper(num_threads, cropped_out_dir) imgcrop.run_cropping(lists, overwrite_existing=override) shutil.copy(join(nnFormer_raw_data, task_string, "dataset.json"), cropped_out_dir) def analyze_dataset(task_string, override=False, collect_intensityproperties=True, num_processes=default_num_threads): cropped_out_dir = join(nnFormer_cropped_data, task_string) dataset_analyzer = DatasetAnalyzer(cropped_out_dir, overwrite=override, num_processes=num_processes) _ = dataset_analyzer.analyze_dataset(collect_intensityproperties) def plan_and_preprocess(task_string, processes_lowres=default_num_threads, processes_fullres=3, no_preprocessing=False): from unetr_pp.experiment_planning.experiment_planner_baseline_2DUNet import ExperimentPlanner2D from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner preprocessing_output_dir_this_task_train = join(preprocessing_output_dir, task_string) cropped_out_dir = join(nnFormer_cropped_data, task_string) maybe_mkdir_p(preprocessing_output_dir_this_task_train) shutil.copy(join(cropped_out_dir, "dataset_properties.pkl"), preprocessing_output_dir_this_task_train) shutil.copy(join(nnFormer_raw_data, task_string, "dataset.json"), preprocessing_output_dir_this_task_train) exp_planner = ExperimentPlanner(cropped_out_dir, preprocessing_output_dir_this_task_train) exp_planner.plan_experiment() if not no_preprocessing: exp_planner.run_preprocessing((processes_lowres, processes_fullres)) exp_planner = ExperimentPlanner2D(cropped_out_dir, preprocessing_output_dir_this_task_train) exp_planner.plan_experiment() if not no_preprocessing: 
def add_classes_in_slice_info(args):
    """
    We need this for 2D dataloader with oversampling. As of now it will detect slices that contain specific classes
    at run time, meaning it needs to iterate over an entire patient just to extract one slice. That is obviously bad,
    so we are doing this once beforehand and just give the dataloader the info it needs in the patients pkl file.

    :param args: tuple of (npz_file, pkl_file, all_classes); the npz file's 'data' array has the
        segmentation as its last channel, and the pkl file holds the case properties dict that is
        updated in place on disk.
    """
    npz_file, pkl_file, all_classes = args
    # BUGFIX: np.load on an .npz returns a lazily-read NpzFile holding the file open; the
    # original never closed it. The context manager releases the handle deterministically.
    with np.load(npz_file) as npz:
        seg_map = npz['data'][-1]
    with open(pkl_file, 'rb') as f:
        props = pickle.load(f)
    # if props.get('classes_in_slice_per_axis') is not None:
    print(pkl_file)
    # this will be a dict of dict where the first dict encodes the axis along which a slice is extracted in its keys.
    # The second dict (value of first dict) will have all classes as key and as values a list of all slice ids that
    # contain this class
    classes_in_slice = OrderedDict((axis, OrderedDict()) for axis in range(3))
    number_of_voxels_per_class = OrderedDict()
    for c in all_classes:
        # compute the boolean mask once per class instead of once per (class, axis) as before:
        # same results, one full-array comparison instead of four
        class_mask = seg_map == c
        number_of_voxels_per_class[c] = np.sum(class_mask)
        for axis in range(3):
            other_axes = tuple(i for i in range(3) if i != axis)
            valid_slices = np.where(np.sum(class_mask, axis=other_axes) > 0)[0]
            classes_in_slice[axis][c] = valid_slices

    props['classes_in_slice_per_axis'] = classes_in_slice
    props['number_of_voxels_per_class'] = number_of_voxels_per_class

    with open(pkl_file, 'wb') as f:
        pickle.dump(props, f)
import argparse from copy import deepcopy from typing import Tuple, Union, List import numpy as np from batchgenerators.augmentations.utils import resize_segmentation from unetr_pp.inference.segmentation_export import save_segmentation_nifti_from_softmax, save_segmentation_nifti from batchgenerators.utilities.file_and_folder_operations import * from multiprocessing import Process, Queue import torch import SimpleITK as sitk import shutil from multiprocessing import Pool from unetr_pp.postprocessing.connected_components import load_remove_save, load_postprocessing from unetr_pp.training.model_restore import load_model_and_checkpoint_files from unetr_pp.training.network_training.Trainer_acdc import Trainer_acdc from unetr_pp.training.network_training.Trainer_synapse import Trainer_synapse from unetr_pp.utilities.one_hot_encoding import to_one_hot def preprocess_save_to_queue(preprocess_fn, q, list_of_lists, output_files, segs_from_prev_stage, classes, transpose_forward): # suppress output # sys.stdout = open(os.devnull, 'w') errors_in = [] for i, l in enumerate(list_of_lists): try: output_file = output_files[i] print("preprocessing", output_file) d, _, dct = preprocess_fn(l) # print(output_file, dct) if segs_from_prev_stage[i] is not None: assert isfile(segs_from_prev_stage[i]) and segs_from_prev_stage[i].endswith( ".nii.gz"), "segs_from_prev_stage" \ " must point to a " \ "segmentation file" seg_prev = sitk.GetArrayFromImage(sitk.ReadImage(segs_from_prev_stage[i])) # check to see if shapes match img = sitk.GetArrayFromImage(sitk.ReadImage(l[0])) assert all([i == j for i, j in zip(seg_prev.shape, img.shape)]), "image and segmentation from previous " \ "stage don't have the same pixel array " \ "shape! 
image: %s, seg_prev: %s" % \ (l[0], segs_from_prev_stage[i]) seg_prev = seg_prev.transpose(transpose_forward) seg_reshaped = resize_segmentation(seg_prev, d.shape[1:], order=1, cval=0) seg_reshaped = to_one_hot(seg_reshaped, classes) d = np.vstack((d, seg_reshaped)).astype(np.float32) """There is a problem with python process communication that prevents us from communicating obejcts larger than 2 GB between processes (basically when the length of the pickle string that will be sent is communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either filename or np.ndarray and will handle this automatically""" print(d.shape) if np.prod(d.shape) > (2e9 / 4 * 0.85): # *0.85 just to be save, 4 because float32 is 4 bytes print( "This output is too large for python process-process communication. 
" "Saving output temporarily to disk") np.save(output_file[:-7] + ".npy", d) d = output_file[:-7] + ".npy" q.put((output_file, (d, dct))) except KeyboardInterrupt: raise KeyboardInterrupt except Exception as e: print("error in", l) print(e) q.put("end") if len(errors_in) > 0: print("There were some errors in the following cases:", errors_in) print("These cases were ignored.") else: print("This worker has ended successfully, no errors to report") # restore output # sys.stdout = sys.__stdout__ def preprocess_multithreaded(trainer, list_of_lists, output_files, num_processes=2, segs_from_prev_stage=None): if segs_from_prev_stage is None: segs_from_prev_stage = [None] * len(list_of_lists) num_processes = min(len(list_of_lists), num_processes) classes = list(range(1, trainer.num_classes)) assert isinstance(trainer, Trainer_acdc) or isinstance(trainer, Trainer_synapse) q = Queue(1) processes = [] for i in range(num_processes): pr = Process(target=preprocess_save_to_queue, args=(trainer.preprocess_patient, q, list_of_lists[i::num_processes], output_files[i::num_processes], segs_from_prev_stage[i::num_processes], classes, trainer.plans['transpose_forward'])) pr.start() processes.append(pr) try: end_ctr = 0 while end_ctr != num_processes: item = q.get() if item == "end": end_ctr += 1 continue else: yield item finally: for p in processes: if p.is_alive(): p.terminate() # this should not happen but better safe than sorry right p.join() q.close() def predict_cases(model, list_of_lists, output_filenames, folds, save_npz, num_threads_preprocessing, num_threads_nifti_save, segs_from_prev_stage=None, do_tta=True, mixed_precision=True, overwrite_existing=False, all_in_gpu=False, step_size=0.5, checkpoint_name="model_final_checkpoint", segmentation_export_kwargs: dict = None): """ :param segmentation_export_kwargs: :param model: folder where the model is saved, must contain fold_x subfolders :param list_of_lists: [[case0_0000.nii.gz, case0_0001.nii.gz], [case1_0000.nii.gz, 
case1_0001.nii.gz], ...] :param output_filenames: [output_file_case0.nii.gz, output_file_case1.nii.gz, ...] :param folds: default: (0, 1, 2, 3, 4) (but can also be 'all' or a subset of the five folds, for example use (0, ) for using only fold_0 :param save_npz: default: False :param num_threads_preprocessing: :param num_threads_nifti_save: :param segs_from_prev_stage: :param do_tta: default: True, can be set to False for a 8x speedup at the cost of a reduced segmentation quality :param overwrite_existing: default: True :param mixed_precision: if None then we take no action. If True/False we overwrite what the model has in its init :return: """ assert len(list_of_lists) == len(output_filenames) if segs_from_prev_stage is not None: assert len(segs_from_prev_stage) == len(output_filenames) pool = Pool(num_threads_nifti_save) results = [] cleaned_output_files = [] for o in output_filenames: dr, f = os.path.split(o) if len(dr) > 0: maybe_mkdir_p(dr) if not f.endswith(".nii.gz"): f, _ = os.path.splitext(f) f = f + ".nii.gz" cleaned_output_files.append(join(dr, f)) if not overwrite_existing: print("number of cases:", len(list_of_lists)) # if save_npz=True then we should also check for missing npz files not_done_idx = [i for i, j in enumerate(cleaned_output_files) if (not isfile(j)) or (save_npz and not isfile(j[:-7] + '.npz'))] cleaned_output_files = [cleaned_output_files[i] for i in not_done_idx] list_of_lists = [list_of_lists[i] for i in not_done_idx] if segs_from_prev_stage is not None: segs_from_prev_stage = [segs_from_prev_stage[i] for i in not_done_idx] print("number of cases that still need to be predicted:", len(cleaned_output_files)) print("emptying cuda cache") torch.cuda.empty_cache() print("loading parameters for folds,", folds) trainer, params = load_model_and_checkpoint_files(model, folds, mixed_precision=mixed_precision, checkpoint_name=checkpoint_name) if segmentation_export_kwargs is None: if 'segmentation_export_params' in trainer.plans.keys(): 
force_separate_z = trainer.plans['segmentation_export_params']['force_separate_z'] interpolation_order = trainer.plans['segmentation_export_params']['interpolation_order'] interpolation_order_z = trainer.plans['segmentation_export_params']['interpolation_order_z'] else: force_separate_z = None interpolation_order = 1 interpolation_order_z = 0 else: force_separate_z = segmentation_export_kwargs['force_separate_z'] interpolation_order = segmentation_export_kwargs['interpolation_order'] interpolation_order_z = segmentation_export_kwargs['interpolation_order_z'] print("starting preprocessing generator") preprocessing = preprocess_multithreaded(trainer, list_of_lists, cleaned_output_files, num_threads_preprocessing, segs_from_prev_stage) print("starting prediction...") all_output_files = [] for preprocessed in preprocessing: output_filename, (d, dct) = preprocessed all_output_files.append(all_output_files) if isinstance(d, str): data = np.load(d) os.remove(d) d = data print("predicting", output_filename) softmax = [] for p in params: trainer.load_checkpoint_ram(p, False) softmax.append(trainer.predict_preprocessed_data_return_seg_and_softmax( d, do_mirroring=do_tta, mirror_axes=trainer.data_aug_params['mirror_axes'], use_sliding_window=True, step_size=step_size, use_gaussian=True, all_in_gpu=all_in_gpu, mixed_precision=mixed_precision)[1][None]) softmax = np.vstack(softmax) softmax_mean = np.mean(softmax, 0) transpose_forward = trainer.plans.get('transpose_forward') if transpose_forward is not None: transpose_backward = trainer.plans.get('transpose_backward') softmax_mean = softmax_mean.transpose([0] + [i + 1 for i in transpose_backward]) if save_npz: npz_file = output_filename[:-7] + ".npz" else: npz_file = None if hasattr(trainer, 'regions_class_order'): region_class_order = trainer.regions_class_order else: region_class_order = None """There is a problem with python process communication that prevents us from communicating obejcts larger than 2 GB between processes 
(basically when the length of the pickle string that will be sent is communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either filename or np.ndarray and will handle this automatically""" bytes_per_voxel = 4 if all_in_gpu: bytes_per_voxel = 2 # if all_in_gpu then the return value is half (float16) if np.prod(softmax_mean.shape) > (2e9 / bytes_per_voxel * 0.85): # * 0.85 just to be save print( "This output is too large for python process-process communication. Saving output temporarily to disk") np.save(output_filename[:-7] + ".npy", softmax_mean) softmax_mean = output_filename[:-7] + ".npy" results.append(pool.starmap_async(save_segmentation_nifti_from_softmax, ((softmax_mean, output_filename, dct, interpolation_order, region_class_order, None, None, npz_file, None, force_separate_z, interpolation_order_z),) )) print("inference done. Now waiting for the segmentation export to finish...") _ = [i.get() for i in results] # now apply postprocessing # first load the postprocessing properties if they are present. 
Else raise a well visible warning
    results = []
    pp_file = join(model, "postprocessing.json")
    if isfile(pp_file):
        print("postprocessing...")
        # keep a copy of the postprocessing config next to the predictions for reproducibility
        shutil.copy(pp_file, os.path.abspath(os.path.dirname(output_filenames[0])))
        # for_which_classes stores for which of the classes everything but the largest connected component needs to be
        # removed
        for_which_classes, min_valid_obj_size = load_postprocessing(pp_file)
        # in-place postprocessing: source and destination filename are identical
        results.append(pool.starmap_async(load_remove_save,
                                          zip(output_filenames, output_filenames,
                                              [for_which_classes] * len(output_filenames),
                                              [min_valid_obj_size] * len(output_filenames))))
        _ = [i.get() for i in results]
    else:
        print("WARNING! Cannot run postprocessing because the postprocessing file is missing. Make sure to run "
              "consolidate_folds in the output folder of the model first!\nThe folder you need to run this in is "
              "%s" % model)

    pool.close()
    pool.join()


def predict_cases_fast(model, list_of_lists, output_filenames, folds, num_threads_preprocessing,
                       num_threads_nifti_save, segs_from_prev_stage=None, do_tta=True, mixed_precision=True,
                       overwrite_existing=False, all_in_gpu=False, step_size=0.5,
                       checkpoint_name="model_final_checkpoint",
                       segmentation_export_kwargs: dict = None):
    """
    Faster variant of predict_cases: instead of keeping every fold's full softmax in memory it
    accumulates the softmax sum across folds and only exports the argmax segmentation map
    (no npz/softmax export is supported here). Interface mirrors predict_cases minus save_npz.
    """
    assert len(list_of_lists) == len(output_filenames)
    if segs_from_prev_stage is not None:
        assert len(segs_from_prev_stage) == len(output_filenames)

    # worker pool used for asynchronous nifti export and postprocessing
    pool = Pool(num_threads_nifti_save)
    results = []

    # normalize every output filename to end in .nii.gz and make sure its directory exists
    cleaned_output_files = []
    for o in output_filenames:
        dr, f = os.path.split(o)
        if len(dr) > 0:
            maybe_mkdir_p(dr)
        if not f.endswith(".nii.gz"):
            f, _ = os.path.splitext(f)
            f = f + ".nii.gz"
        cleaned_output_files.append(join(dr, f))

    if not overwrite_existing:
        # resume support: only predict cases whose output file does not exist yet
        print("number of cases:", len(list_of_lists))
        not_done_idx = [i for i, j in enumerate(cleaned_output_files) if not isfile(j)]
        cleaned_output_files = [cleaned_output_files[i] for i in not_done_idx]
        list_of_lists = [list_of_lists[i] for i in not_done_idx]
        if segs_from_prev_stage is not None:
            segs_from_prev_stage = [segs_from_prev_stage[i] for i in not_done_idx]
        print("number of cases that still need to be predicted:", len(cleaned_output_files))

    print("emptying cuda cache")
    torch.cuda.empty_cache()

    print("loading parameters for folds,", folds)
    trainer, params = load_model_and_checkpoint_files(model, folds, mixed_precision=mixed_precision,
                                                      checkpoint_name=checkpoint_name)

    # resolve export parameters: explicit kwargs > plans file > hard-coded defaults
    if segmentation_export_kwargs is None:
        if 'segmentation_export_params' in trainer.plans.keys():
            force_separate_z = trainer.plans['segmentation_export_params']['force_separate_z']
            interpolation_order = trainer.plans['segmentation_export_params']['interpolation_order']
            interpolation_order_z = trainer.plans['segmentation_export_params']['interpolation_order_z']
        else:
            force_separate_z = None
            interpolation_order = 1
            interpolation_order_z = 0
    else:
        force_separate_z = segmentation_export_kwargs['force_separate_z']
        interpolation_order = segmentation_export_kwargs['interpolation_order']
        interpolation_order_z = segmentation_export_kwargs['interpolation_order_z']

    print("starting preprocessing generator")
    preprocessing = preprocess_multithreaded(trainer, list_of_lists, cleaned_output_files, num_threads_preprocessing,
                                             segs_from_prev_stage)
    print("starting prediction...")
    for preprocessed in preprocessing:
        print("getting data from preprocessor")
        output_filename, (d, dct) = preprocessed
        print("got something")
        if isinstance(d, str):
            print("what I got is a string, so I need to load a file")
            data = np.load(d)
            os.remove(d)
            d = data

        # preallocate the output arrays
        # same dtype as the return value in predict_preprocessed_data_return_seg_and_softmax (saves time)
        softmax_aggr = None  # np.zeros((trainer.num_classes, *d.shape[1:]), dtype=np.float16)
        all_seg_outputs = np.zeros((len(params), *d.shape[1:]), dtype=int)
        print("predicting", output_filename)
        for i, p in enumerate(params):
            trainer.load_checkpoint_ram(p, False)
            res = trainer.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=do_tta,
                                                                           mirror_axes=trainer.data_aug_params['mirror_axes'],
                                                                           use_sliding_window=True,
                                                                           step_size=step_size, use_gaussian=True,
                                                                           all_in_gpu=all_in_gpu,
                                                                           mixed_precision=mixed_precision)
            if len(params) > 1:
                # otherwise we dont need this and we can save ourselves the time it takes to copy that
                print("aggregating softmax")
                if softmax_aggr is None:
                    softmax_aggr = res[1]
                else:
                    softmax_aggr += res[1]
            all_seg_outputs[i] = res[0]

        print("obtaining segmentation map")
        if len(params) > 1:
            # we dont need to normalize the softmax by 1 / len(params) because this would not change
            # the outcome of the argmax
            seg = softmax_aggr.argmax(0)
        else:
            seg = all_seg_outputs[0]

        print("applying transpose_backward")
        transpose_forward = trainer.plans.get('transpose_forward')
        if transpose_forward is not None:
            transpose_backward = trainer.plans.get('transpose_backward')
            seg = seg.transpose([i for i in transpose_backward])

        print("initializing segmentation export")
        results.append(pool.starmap_async(save_segmentation_nifti,
                                          ((seg, output_filename, dct, interpolation_order, force_separate_z,
                                            interpolation_order_z),)
                                          ))
        print("done")

    print("inference done. Now waiting for the segmentation export to finish...")
    _ = [i.get() for i in results]
    # now apply postprocessing
    # first load the postprocessing properties if they are present. Else raise a well visible warning
    results = []
    pp_file = join(model, "postprocessing.json")
    if isfile(pp_file):
        print("postprocessing...")
        shutil.copy(pp_file, os.path.dirname(output_filenames[0]))
        # for_which_classes stores for which of the classes everything but the largest connected component needs to be
        # removed
        for_which_classes, min_valid_obj_size = load_postprocessing(pp_file)
        results.append(pool.starmap_async(load_remove_save,
                                          zip(output_filenames, output_filenames,
                                              [for_which_classes] * len(output_filenames),
                                              [min_valid_obj_size] * len(output_filenames))))
        _ = [i.get() for i in results]
    else:
        print("WARNING! Cannot run postprocessing because the postprocessing file is missing. Make sure to run "
              "consolidate_folds in the output folder of the model first!\nThe folder you need to run this in is "
              "%s" % model)

    pool.close()
    pool.join()


def predict_cases_fastest(model, list_of_lists, output_filenames, folds, num_threads_preprocessing,
                          num_threads_nifti_save, segs_from_prev_stage=None, do_tta=True, mixed_precision=True,
                          overwrite_existing=False, all_in_gpu=True, step_size=0.5,
                          checkpoint_name="model_final_checkpoint"):
    """
    Fastest variant: preallocates float16 softmax buffers for all folds, averages them, and exports
    the argmax segmentation. Ignores the plans' segmentation_export_params entirely; the export call
    passes fixed values (0, None) as its final arguments (presumably interpolation order /
    force_separate_z — confirm against save_segmentation_nifti). Defaults to all_in_gpu=True.
    """
    assert len(list_of_lists) == len(output_filenames)
    if segs_from_prev_stage is not None:
        assert len(segs_from_prev_stage) == len(output_filenames)

    pool = Pool(num_threads_nifti_save)
    results = []

    # normalize output filenames and create directories (same scheme as predict_cases_fast)
    cleaned_output_files = []
    for o in output_filenames:
        dr, f = os.path.split(o)
        if len(dr) > 0:
            maybe_mkdir_p(dr)
        if not f.endswith(".nii.gz"):
            f, _ = os.path.splitext(f)
            f = f + ".nii.gz"
        cleaned_output_files.append(join(dr, f))

    if not overwrite_existing:
        print("number of cases:", len(list_of_lists))
        not_done_idx = [i for i, j in enumerate(cleaned_output_files) if not isfile(j)]
        cleaned_output_files = [cleaned_output_files[i] for i in not_done_idx]
        list_of_lists = [list_of_lists[i] for i in not_done_idx]
        if segs_from_prev_stage is not None:
            segs_from_prev_stage = [segs_from_prev_stage[i] for i in not_done_idx]
        print("number of cases that still need to be predicted:", len(cleaned_output_files))

    print("emptying cuda cache")
    torch.cuda.empty_cache()

    print("loading parameters for folds,", folds)
    trainer, params = load_model_and_checkpoint_files(model, folds, mixed_precision=mixed_precision,
                                                      checkpoint_name=checkpoint_name)

    print("starting preprocessing generator")
    preprocessing = preprocess_multithreaded(trainer, list_of_lists, cleaned_output_files, num_threads_preprocessing,
                                             segs_from_prev_stage)
    print("starting prediction...")
    for preprocessed in preprocessing:
        print("getting data from preprocessor")
        output_filename, (d, dct) = preprocessed
        print("got something")
        if isinstance(d, str):
            print("what I got is a string, so I need to load a file")
            data = np.load(d)
            os.remove(d)
            d = data

        # preallocate the output arrays
        # same dtype as the return value in predict_preprocessed_data_return_seg_and_softmax (saves time)
        all_softmax_outputs = np.zeros((len(params), trainer.num_classes, *d.shape[1:]), dtype=np.float16)
        all_seg_outputs = np.zeros((len(params), *d.shape[1:]), dtype=int)
        print("predicting", output_filename)
        for i, p in enumerate(params):
            trainer.load_checkpoint_ram(p, False)
            res = trainer.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=do_tta,
                                                                           mirror_axes=trainer.data_aug_params['mirror_axes'],
                                                                           use_sliding_window=True,
                                                                           step_size=step_size, use_gaussian=True,
                                                                           all_in_gpu=all_in_gpu,
                                                                           mixed_precision=mixed_precision)
            if len(params) > 1:
                # otherwise we dont need this and we can save ourselves the time it takes to copy that
                all_softmax_outputs[i] = res[1]
            all_seg_outputs[i] = res[0]

        print("aggregating predictions")
        if len(params) > 1:
            softmax_mean = np.mean(all_softmax_outputs, 0)
            seg = softmax_mean.argmax(0)
        else:
            seg = all_seg_outputs[0]

        print("applying transpose_backward")
        transpose_forward = trainer.plans.get('transpose_forward')
        if transpose_forward is not None:
            transpose_backward = trainer.plans.get('transpose_backward')
            seg = seg.transpose([i for i in transpose_backward])

        print("initializing segmentation export")
        results.append(pool.starmap_async(save_segmentation_nifti, ((seg, output_filename, dct, 0, None),)
                                          ))
        print("done")

    print("inference done. Now waiting for the segmentation export to finish...")
    _ = [i.get() for i in results]
    # now apply postprocessing
    # first load the postprocessing properties if they are present. Else raise a well visible warning
    results = []
    pp_file = join(model, "postprocessing.json")
    if isfile(pp_file):
        print("postprocessing...")
        shutil.copy(pp_file, os.path.dirname(output_filenames[0]))
        # for_which_classes stores for which of the classes everything but the largest connected component needs to be
        # removed
        for_which_classes, min_valid_obj_size = load_postprocessing(pp_file)
        results.append(pool.starmap_async(load_remove_save,
                                          zip(output_filenames, output_filenames,
                                              [for_which_classes] * len(output_filenames),
                                              [min_valid_obj_size] * len(output_filenames))))
        _ = [i.get() for i in results]
    else:
        print("WARNING! Cannot run postprocessing because the postprocessing file is missing. Make sure to run "
              "consolidate_folds in the output folder of the model first!\nThe folder you need to run this in is "
              "%s" % model)

    pool.close()
    pool.join()


def check_input_folder_and_return_caseIDs(input_folder, expected_num_modalities):
    """
    Validate that input_folder contains, for every case id, exactly one CASENAME_XXXX.nii.gz file per
    expected modality. Returns the unique case ids (filenames with the trailing _XXXX.nii.gz stripped).
    """
    print("This model expects %d input modalities for each image" % expected_num_modalities)
    files = subfiles(input_folder, suffix=".nii.gz", join=False, sort=True)

    # strip the 12-char suffix "_XXXX.nii.gz" to obtain candidate case ids
    maybe_case_ids = np.unique([i[:-12] for i in files])

    remaining = deepcopy(files)
    missing = []

    assert len(files) > 0, "input folder did not contain any images (expected to find .nii.gz file endings)"

    # now check if all required files are present and that no unexpected files are remaining
    for c in maybe_case_ids:
        for n in range(expected_num_modalities):
            expected_output_file = c + "_%04.0d.nii.gz" % n
            if not isfile(join(input_folder, expected_output_file)):
                missing.append(expected_output_file)
            else:
                remaining.remove(expected_output_file)

    print("Found %d unique case ids, here are some examples:" % len(maybe_case_ids),
          np.random.choice(maybe_case_ids, min(len(maybe_case_ids), 10)))
    print("If they don't look right, make sure to double check your filenames. They must end with _0000.nii.gz etc")

    if len(remaining) > 0:
        print("found %d unexpected remaining files in the folder. 
Here are some examples:" % len(remaining), np.random.choice(remaining, min(len(remaining), 10))) if len(missing) > 0: print("Some files are missing:") print(missing) raise RuntimeError("missing files in input_folder") return maybe_case_ids def predict_from_folder(model: str, input_folder: str, output_folder: str, folds: Union[Tuple[int], List[int]], save_npz: bool, num_threads_preprocessing: int, num_threads_nifti_save: int, lowres_segmentations: Union[str, None], part_id: int, num_parts: int, tta: bool, mixed_precision: bool = True, overwrite_existing: bool = True, mode: str = 'normal', overwrite_all_in_gpu: bool = None, step_size: float = 0.5, checkpoint_name: str = "model_final_checkpoint", segmentation_export_kwargs: dict = None): """ here we use the standard naming scheme to generate list_of_lists and output_files needed by predict_cases :param model: :param input_folder: :param output_folder: :param folds: :param save_npz: :param num_threads_preprocessing: :param num_threads_nifti_save: :param lowres_segmentations: :param part_id: :param num_parts: :param tta: :param mixed_precision: :param overwrite_existing: if not None then it will be overwritten with whatever is in there. 
None is default (no overwrite) :return: """ maybe_mkdir_p(output_folder) shutil.copy(join(model, 'plans.pkl'), output_folder) assert isfile(join(model, "plans.pkl")), "Folder with saved model weights must contain a plans.pkl file" expected_num_modalities = load_pickle(join(model, "plans.pkl"))['num_modalities'] # check input folder integrity case_ids = check_input_folder_and_return_caseIDs(input_folder, expected_num_modalities) output_files = [join(output_folder, i + ".nii.gz") for i in case_ids] all_files = subfiles(input_folder, suffix=".nii.gz", join=False, sort=True) list_of_lists = [[join(input_folder, i) for i in all_files if i[:len(j)].startswith(j) and len(i) == (len(j) + 12)] for j in case_ids] if lowres_segmentations is not None: assert isdir(lowres_segmentations), "if lowres_segmentations is not None then it must point to a directory" lowres_segmentations = [join(lowres_segmentations, i + ".nii.gz") for i in case_ids] assert all([isfile(i) for i in lowres_segmentations]), "not all lowres_segmentations files are present. 
" \ "(I was searching for case_id.nii.gz in that folder)" lowres_segmentations = lowres_segmentations[part_id::num_parts] else: lowres_segmentations = None if mode == "normal": if overwrite_all_in_gpu is None: all_in_gpu = False else: all_in_gpu = overwrite_all_in_gpu return predict_cases(model, list_of_lists[part_id::num_parts], output_files[part_id::num_parts], folds, save_npz, num_threads_preprocessing, num_threads_nifti_save, lowres_segmentations, tta, mixed_precision=mixed_precision, overwrite_existing=overwrite_existing, all_in_gpu=all_in_gpu, step_size=step_size, checkpoint_name=checkpoint_name, segmentation_export_kwargs=segmentation_export_kwargs) elif mode == "fast": if overwrite_all_in_gpu is None: all_in_gpu = True else: all_in_gpu = overwrite_all_in_gpu assert save_npz is False return predict_cases_fast(model, list_of_lists[part_id::num_parts], output_files[part_id::num_parts], folds, num_threads_preprocessing, num_threads_nifti_save, lowres_segmentations, tta, mixed_precision=mixed_precision, overwrite_existing=overwrite_existing, all_in_gpu=all_in_gpu, step_size=step_size, checkpoint_name=checkpoint_name, segmentation_export_kwargs=segmentation_export_kwargs) elif mode == "fastest": if overwrite_all_in_gpu is None: all_in_gpu = True else: all_in_gpu = overwrite_all_in_gpu assert save_npz is False return predict_cases_fastest(model, list_of_lists[part_id::num_parts], output_files[part_id::num_parts], folds, num_threads_preprocessing, num_threads_nifti_save, lowres_segmentations, tta, mixed_precision=mixed_precision, overwrite_existing=overwrite_existing, all_in_gpu=all_in_gpu, step_size=step_size, checkpoint_name=checkpoint_name) else: raise ValueError("unrecognized mode. Must be normal, fast or fastest") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-i", '--input_folder', help="Must contain all modalities for each patient in the correct" " order (same as training). 
Files must be named "
                                                     "CASENAME_XXXX.nii.gz where XXXX is the modality "
                                                     "identifier (0000, 0001, etc)", required=True)
    parser.add_argument('-o', "--output_folder", required=True, help="folder for saving predictions")
    parser.add_argument('-m', '--model_output_folder',
                        help='model output folder. Will automatically discover the folds '
                             'that were '
                             'run and use those as an ensemble', required=True)
    parser.add_argument('-f', '--folds', nargs='+', default='None',
                        help="folds to use for prediction. Default is None "
                             "which means that folds will be detected "
                             "automatically in the model output folder")
    parser.add_argument('-z', '--save_npz', required=False, action='store_true',
                        help="use this if you want to ensemble"
                             " these predictions with those of"
                             " other models. Softmax "
                             "probabilities will be saved as "
                             "compresed numpy arrays in "
                             "output_folder and can be merged "
                             "between output_folders with "
                             "merge_predictions.py")
    parser.add_argument('-l', '--lowres_segmentations', required=False, default='None',
                        help="if model is the highres "
                             "stage of the cascade then you need to use -l to specify where the segmentations of the "
                             "corresponding lowres unet are. Here they are required to do a prediction")
    parser.add_argument("--part_id", type=int, required=False, default=0,
                        help="Used to parallelize the prediction of "
                             "the folder over several GPUs. If you "
                             "want to use n GPUs to predict this "
                             "folder you need to run this command "
                             "n times with --part_id=0, ... n-1 and "
                             "--num_parts=n (each with a different "
                             "GPU (for example via "
                             "CUDA_VISIBLE_DEVICES=X)")
    parser.add_argument("--num_parts", type=int, required=False, default=1,
                        help="Used to parallelize the prediction of "
                             "the folder over several GPUs. If you "
                             "want to use n GPUs to predict this "
                             "folder you need to run this command "
                             "n times with --part_id=0, ... n-1 and "
                             "--num_parts=n (each with a different "
                             "GPU (via "
                             "CUDA_VISIBLE_DEVICES=X)")
    parser.add_argument("--num_threads_preprocessing", required=False, default=6, type=int, help=
    "Determines many background processes will be used for data preprocessing. Reduce this if you "
    "run into out of memory (RAM) problems. Default: 6")
    parser.add_argument("--num_threads_nifti_save", required=False, default=2, type=int, help=
    "Determines many background processes will be used for segmentation export. Reduce this if you "
    "run into out of memory (RAM) problems. Default: 2")
    parser.add_argument("--tta", required=False, type=int, default=1,
                        help="Set to 0 to disable test time data "
                             "augmentation (speedup of factor "
                             "4(2D)/8(3D)), "
                             "lower quality segmentations")
    parser.add_argument("--overwrite_existing", required=False, type=int, default=1,
                        help="Set this to 0 if you need "
                             "to resume a previous "
                             "prediction. Default: 1 "
                             "(=existing segmentations "
                             "in output_folder will be "
                             "overwritten)")
    parser.add_argument("--mode", type=str, default="normal", required=False)
    parser.add_argument("--all_in_gpu", type=str, default="None", required=False, help="can be None, False or True")
    parser.add_argument("--step_size", type=float, default=0.5, required=False, help="don't touch")
    # parser.add_argument("--interp_order", required=False, default=3, type=int,
    #                     help="order of interpolation for segmentations, has no effect if mode=fastest")
    # parser.add_argument("--interp_order_z", required=False, default=0, type=int,
    #                     help="order of interpolation along z is z is done differently")
    # parser.add_argument("--force_separate_z", required=False, default="None", type=str,
    #                     help="force_separate_z resampling. Can be None, True or False, has no effect if mode=fastest")
    parser.add_argument('--disable_mixed_precision', default=False, action='store_true', required=False,
                        help='Predictions are done with mixed precision by default. This improves speed and reduces '
                             'the required vram. If you want to disable mixed precision you can set this flag. Note '
                             'that yhis is not recommended (mixed precision is ~2x faster!)')

    args = parser.parse_args()
    input_folder = args.input_folder
    output_folder = args.output_folder
    part_id = args.part_id
    num_parts = args.num_parts
    model = args.model_output_folder
    folds = args.folds
    save_npz = args.save_npz
    lowres_segmentations = args.lowres_segmentations
    num_threads_preprocessing = args.num_threads_preprocessing
    num_threads_nifti_save = args.num_threads_nifti_save
    tta = args.tta
    step_size = args.step_size
    # interp_order = args.interp_order
    # interp_order_z = args.interp_order_z
    # force_separate_z = args.force_separate_z

    # if force_separate_z == "None":
    #     force_separate_z = None
    # elif force_separate_z == "False":
    #     force_separate_z = False
    # elif force_separate_z == "True":
    #     force_separate_z = True
    # else:
    #     raise ValueError("force_separate_z must be None, True or False. Given: %s" % force_separate_z)

    overwrite = args.overwrite_existing
    mode = args.mode
    all_in_gpu = args.all_in_gpu

    # convert the string-valued CLI flags into the Python values predict_from_folder expects
    if lowres_segmentations == "None":
        lowres_segmentations = None

    if isinstance(folds, list):
        if folds[0] == 'all' and len(folds) == 1:
            # 'all' is passed through as-is; load_model_and_checkpoint_files handles it downstream
            pass
        else:
            folds = [int(i) for i in folds]
    elif folds == "None":
        folds = None
    else:
        raise ValueError("Unexpected value for argument folds")

    if tta == 0:
        tta = False
    elif tta == 1:
        tta = True
    else:
        raise ValueError("Unexpected value for tta, Use 1 or 0")

    if overwrite == 0:
        overwrite = False
    elif overwrite == 1:
        overwrite = True
    else:
        raise ValueError("Unexpected value for overwrite, Use 1 or 0")

    assert all_in_gpu in ['None', 'False', 'True']
    if all_in_gpu == "None":
        all_in_gpu = None
    elif all_in_gpu == "True":
        all_in_gpu = True
    elif all_in_gpu == "False":
        all_in_gpu = False

    predict_from_folder(model, input_folder, output_folder, folds, save_npz, num_threads_preprocessing,
                        num_threads_nifti_save, lowres_segmentations, part_id, num_parts, tta,
                        mixed_precision=not args.disable_mixed_precision, 
overwrite_existing=overwrite, mode=mode, overwrite_all_in_gpu=all_in_gpu,
                        step_size=step_size)



================================================
FILE: unetr_pp/inference/predict_simple.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import torch
from unetr_pp.inference.predict import predict_from_folder
from unetr_pp.paths import default_plans_identifier, network_training_output_dir, default_trainer
from batchgenerators.utilities.file_and_folder_operations import join, isdir
from unetr_pp.utilities.task_name_id_conversion import convert_id_to_task_name
from unetr_pp.paths import network_training_output_dir, preprocessing_output_dir, default_plans_identifier
from unetr_pp.run.default_configuration import get_default_configuration


def main():
    """
    User-facing CLI around predict_from_folder. Resolves the task name/id and the model folder,
    optionally runs the 3d_lowres stage first for the cascade, and ensures a plans.pkl is present
    in the model folder before prediction.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", '--input_folder', help="Must contain all modalities for each patient in the correct"
                                                     " order (same as training). Files must be named "
                                                     "CASENAME_XXXX.nii.gz where XXXX is the modality "
                                                     "identifier (0000, 0001, etc)", required=True)
    parser.add_argument('-o', "--output_folder", required=True, help="folder for saving predictions")
    parser.add_argument('-t', '--task_name', help='task name or task ID, required.',
                        default=default_plans_identifier, required=True)
    parser.add_argument('-tr', '--trainer_class_name',
                        help='Name of the nnFormerTrainer used for 2D U-Net, full resolution 3D U-Net and low resolution '
                             'U-Net. The default is %s. If you are running inference with the cascade and the folder '
                             'pointed to by --lowres_segmentations does not contain the segmentation maps generated by '
                             'the low resolution U-Net then the low resolution segmentation maps will be automatically '
                             'generated. For this case, make sure to set the trainer class here that matches your '
                             '--cascade_trainer_class_name (this part can be ignored if defaults are used).'
                             % default_trainer,
                        required=False,
                        default=default_trainer)
    parser.add_argument('-ctr', '--cascade_trainer_class_name',
                        help="Trainer class name used for predicting the 3D full resolution U-Net part of the cascade."
                             "Default is %s" % default_trainer, required=False,
                        default=default_trainer)
    parser.add_argument('-m', '--model', help="2d, 3d_lowres, 3d_fullres or 3d_cascade_fullres. Default: 3d_fullres",
                        default="3d_fullres", required=False)
    parser.add_argument('-p', '--plans_identifier', help='do not touch this unless you know what you are doing',
                        default=default_plans_identifier, required=False)
    parser.add_argument('-f', '--folds', nargs='+', default='None',
                        help="folds to use for prediction. Default is None which means that folds will be detected "
                             "automatically in the model output folder")
    parser.add_argument('-z', '--save_npz', required=False, action='store_true',
                        help="use this if you want to ensemble these predictions with those of other models. Softmax "
                             "probabilities will be saved as compressed numpy arrays in output_folder and can be "
                             "merged between output_folders with nnFormer_ensemble_predictions")
    parser.add_argument('-l', '--lowres_segmentations', required=False, default='None',
                        help="if model is the highres stage of the cascade then you can use this folder to provide "
                             "predictions from the low resolution 3D U-Net. If this is left at default, the "
                             "predictions will be generated automatically (provided that the 3D low resolution U-Net "
                             "network weights are present")
    parser.add_argument("--part_id", type=int, required=False, default=0, help="Used to parallelize the prediction of "
                                                                               "the folder over several GPUs. If you "
                                                                               "want to use n GPUs to predict this "
                                                                               "folder you need to run this command "
                                                                               "n times with --part_id=0, ... n-1 and "
                                                                               "--num_parts=n (each with a different "
                                                                               "GPU (for example via "
                                                                               "CUDA_VISIBLE_DEVICES=X)")
    parser.add_argument("--num_parts", type=int, required=False, default=1, help="Used to parallelize the prediction of "
                                                                                 "the folder over several GPUs. If you "
                                                                                 "want to use n GPUs to predict this "
                                                                                 "folder you need to run this command "
                                                                                 "n times with --part_id=0, ... n-1 and "
                                                                                 "--num_parts=n (each with a different "
                                                                                 "GPU (via "
                                                                                 "CUDA_VISIBLE_DEVICES=X)")
    parser.add_argument("--num_threads_preprocessing", required=False, default=6, type=int, help=
    "Determines many background processes will be used for data preprocessing. Reduce this if you "
    "run into out of memory (RAM) problems. Default: 6")
    parser.add_argument("--num_threads_nifti_save", required=False, default=2, type=int, help=
    "Determines many background processes will be used for segmentation export. Reduce this if you "
    "run into out of memory (RAM) problems. Default: 2")
    parser.add_argument("--disable_tta", required=False, default=False, action="store_true",
                        help="set this flag to disable test time data augmentation via mirroring. Speeds up inference "
                             "by roughly factor 4 (2D) or 8 (3D)")
    parser.add_argument("--overwrite_existing", required=False, default=False, action="store_true",
                        help="Set this flag if the target folder contains predictions that you would like to overwrite")
    parser.add_argument("--mode", type=str, default="normal", required=False, help="Hands off!")
    parser.add_argument("--all_in_gpu", type=str, default="None", required=False, help="can be None, False or True. "
                                                                                       "Do not touch.")
    parser.add_argument("--step_size", type=float, default=0.5, required=False, help="don't touch")
    # parser.add_argument("--interp_order", required=False, default=3, type=int,
    #                     help="order of interpolation for segmentations, has no effect if mode=fastest. Do not touch this.")
    # parser.add_argument("--interp_order_z", required=False, default=0, type=int,
    #                     help="order of interpolation along z is z is done differently. Do not touch this.")
    # parser.add_argument("--force_separate_z", required=False, default="None", type=str,
    #                     help="force_separate_z resampling. Can be None, True or False, has no effect if mode=fastest. "
    #                          "Do not touch this.")
    parser.add_argument('-chk',
                        help='checkpoint name, default: model_final_checkpoint',
                        required=False,
                        default='model_final_checkpoint')
    parser.add_argument('--disable_mixed_precision', default=False, action='store_true', required=False,
                        help='Predictions are done with mixed precision by default. This improves speed and reduces '
                             'the required vram. If you want to disable mixed precision you can set this flag. Note '
                             'that yhis is not recommended (mixed precision is ~2x faster!)')

    args = parser.parse_args()
    input_folder = args.input_folder
    output_folder = args.output_folder
    part_id = args.part_id
    num_parts = args.num_parts
    folds = args.folds
    save_npz = args.save_npz
    lowres_segmentations = args.lowres_segmentations
    num_threads_preprocessing = args.num_threads_preprocessing
    num_threads_nifti_save = args.num_threads_nifti_save
    disable_tta = args.disable_tta
    step_size = args.step_size
    # interp_order = args.interp_order
    # interp_order_z = args.interp_order_z
    # force_separate_z = args.force_separate_z
    overwrite_existing = args.overwrite_existing
    mode = args.mode
    all_in_gpu = args.all_in_gpu
    model = args.model
    trainer_class_name = args.trainer_class_name
    cascade_trainer_class_name = args.cascade_trainer_class_name

    task_name = args.task_name

    # accept a numeric task id and resolve it to the full TaskXXX_* name
    if not task_name.startswith("Task"):
        task_id = int(task_name)
        task_name = convert_id_to_task_name(task_id)

    assert model in ["2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"], "-m must be 2d, 3d_lowres, 3d_fullres or " \
                                                                             "3d_cascade_fullres"

    # if force_separate_z == "None":
    #     force_separate_z = None
    # elif force_separate_z == "False":
    #     force_separate_z = False
    # elif force_separate_z == "True":
    #     force_separate_z = True
    # else:
    #     raise ValueError("force_separate_z must be None, True or False. Given: %s" % force_separate_z)

    if lowres_segmentations == "None":
        lowres_segmentations = None

    if isinstance(folds, list):
        if folds[0] == 'all' and len(folds) == 1:
            pass
        else:
            folds = [int(i) for i in folds]
    elif folds == "None":
        folds = None
    else:
        raise ValueError("Unexpected value for argument folds")

    assert all_in_gpu in ['None', 'False', 'True']
    if all_in_gpu == "None":
        all_in_gpu = None
    elif all_in_gpu == "True":
        all_in_gpu = True
    elif all_in_gpu == "False":
        all_in_gpu = False

    # we need to catch the case where model is 3d cascade fullres and the low resolution folder has not been set.
    # In that case we need to try and predict with 3d low res first
    if model == "3d_cascade_fullres" and lowres_segmentations is None:
        print("lowres_segmentations is None. Attempting to predict 3d_lowres first...")
        assert part_id == 0 and num_parts == 1, "if you don't specify a --lowres_segmentations folder for the " \
                                                "inference of the cascade, custom values for part_id and num_parts " \
                                                "are not supported. If you wish to have multiple parts, please " \
                                                "run the 3d_lowres inference first (separately)"
        model_folder_name = join(network_training_output_dir, "3d_lowres", task_name, trainer_class_name + "__" +
                                 args.plans_identifier)
        assert isdir(model_folder_name), "model output folder not found. Expected: %s" % model_folder_name
        lowres_output_folder = join(output_folder, "3d_lowres_predictions")
        predict_from_folder(model_folder_name, input_folder, lowres_output_folder, folds, False,
                            num_threads_preprocessing, num_threads_nifti_save, None, part_id, num_parts, not disable_tta,
                            overwrite_existing=overwrite_existing, mode=mode, overwrite_all_in_gpu=all_in_gpu,
                            mixed_precision=not args.disable_mixed_precision,
                            step_size=step_size)
        lowres_segmentations = lowres_output_folder
        torch.cuda.empty_cache()
        print("3d_lowres done")

    # the cascade's fullres stage may use a different trainer class than the other configurations
    if model == "3d_cascade_fullres":
        trainer = cascade_trainer_class_name
    else:
        trainer = trainer_class_name

    model_folder_name = join(network_training_output_dir, model, task_name, trainer + "__" +
                             args.plans_identifier)
    print("using model stored in ", model_folder_name)
    assert isdir(model_folder_name), "model output folder not found. Expected: %s" % model_folder_name

    import os
    import shutil
    # if the checkpoint folder has no plans.pkl, copy it over from the preprocessed data folder
    if not os.path.exists(join(model_folder_name, "plans.pkl")):
        plans_file = join(preprocessing_output_dir, task_name, args.plans_identifier + "_plans_3D.pkl")
        shutil.copy(plans_file, join(model_folder_name, "plans.pkl"))
        # NOTE(review): in the collapsed dump it is ambiguous whether this call belongs inside this
        # if-block or after it — kept inside; confirm against the repository's original formatting
        _=get_default_configuration(model, task_name, trainer_class_name, args.plans_identifier)

    predict_from_folder(model_folder_name, input_folder, output_folder, folds, save_npz, num_threads_preprocessing,
                        num_threads_nifti_save, lowres_segmentations, part_id, num_parts, not disable_tta,
                        overwrite_existing=overwrite_existing, mode=mode, overwrite_all_in_gpu=all_in_gpu,
                        mixed_precision=not args.disable_mixed_precision,
                        step_size=step_size, checkpoint_name=args.chk)


if __name__ == "__main__":
    main()



================================================
FILE: unetr_pp/inference/segmentation_export.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from copy import deepcopy
from typing import Union, Tuple

import numpy as np
import SimpleITK as sitk
from batchgenerators.augmentations.utils import resize_segmentation
from unetr_pp.preprocessing.preprocessing import get_lowres_axis, get_do_separate_z, resample_data_or_seg
from batchgenerators.utilities.file_and_folder_operations import *  # provides os, isfile, save_pickle, ...


def save_segmentation_nifti_from_softmax(segmentation_softmax: Union[str, np.ndarray], out_fname: str,
                                         properties_dict: dict, order: int = 1,
                                         region_class_order: Tuple[Tuple[int]] = None,
                                         seg_postprogess_fn: callable = None, seg_postprocess_args: tuple = None,
                                         resampled_npz_fname: str = None,
                                         non_postprocessed_fname: str = None, force_separate_z: bool = None,
                                         interpolation_order_z: int = 0, verbose: bool = True):
    """
    Resample a (C, x, y, z) softmax prediction back to the geometry of the original image and write it
    as nifti (and optionally npz). Requires data preprocessed by GenericPreprocessor because it relies
    on the properties dict (size_after_cropping, original_size_of_raw_data, crop_bbox, itk_* metadata).

    segmentation_softmax may be a filename of a .npy file instead of an array; in that case the file is
    loaded and then DELETED. This works around the ~2 GB multiprocessing.Pipe payload limit: large
    softmax arrays are exchanged between processes via temporary npy files rather than pickles.

    :param segmentation_softmax: softmax prediction, channel first, or path to a .npy holding it
    :param out_fname: output nifti filename
    :param properties_dict: geometry properties produced during preprocessing
    :param order: interpolation order for resampling
    :param region_class_order: if not None, channels are thresholded at 0.5 and written in this label
           order instead of argmax (sigmoid/region-based training)
    :param seg_postprogess_fn: optional callable applied to the segmentation before export
    :param seg_postprocess_args: positional args for seg_postprogess_fn
    :param resampled_npz_fname: if not None, also save the resampled softmax (float16) + pkl here
    :param non_postprocessed_fname: if not None (and a postprocess fn is given), also export the
           un-postprocessed segmentation here
    :param force_separate_z: None = decide dynamically whether to resample along z separately;
           True/False = always/never. Do not touch unless you know what you are doing
    :param interpolation_order_z: interpolation order along z when separate-z resampling is used
    :param verbose: print progress information
    :return: None
    """
    if verbose:
        print("force_separate_z:", force_separate_z, "interpolation order:", order)

    if isinstance(segmentation_softmax, str):
        assert isfile(segmentation_softmax), "If isinstance(segmentation_softmax, str) then " \
                                             "isfile(segmentation_softmax) must be True"
        del_file = deepcopy(segmentation_softmax)
        segmentation_softmax = np.load(segmentation_softmax)
        os.remove(del_file)  # temp file was only a cross-process hand-off; see docstring

    # first resample, then put result into bbox of cropping, then save
    current_shape = segmentation_softmax.shape
    shape_original_after_cropping = properties_dict.get('size_after_cropping')
    shape_original_before_cropping = properties_dict.get('original_size_of_raw_data')

    # shape[1:] skips the class/channel axis when comparing spatial sizes
    if np.any([i != j for i, j in zip(np.array(current_shape[1:]), np.array(shape_original_after_cropping))]):
        if force_separate_z is None:
            # decide from the anisotropy of either the original or the resampled spacing
            if get_do_separate_z(properties_dict.get('original_spacing')):
                do_separate_z = True
                lowres_axis = get_lowres_axis(properties_dict.get('original_spacing'))
            elif get_do_separate_z(properties_dict.get('spacing_after_resampling')):
                do_separate_z = True
                lowres_axis = get_lowres_axis(properties_dict.get('spacing_after_resampling'))
            else:
                do_separate_z = False
                lowres_axis = None
        else:
            do_separate_z = force_separate_z
            if do_separate_z:
                lowres_axis = get_lowres_axis(properties_dict.get('original_spacing'))
            else:
                lowres_axis = None

        if lowres_axis is not None and len(lowres_axis) != 1:
            # happens for spacings like (0.24, 1.25, 1.25): there is no single out-of-plane
            # axis, so do not resample separately along z
            do_separate_z = False

        if verbose:
            print("separate z:", do_separate_z, "lowres axis", lowres_axis)
        seg_old_spacing = resample_data_or_seg(segmentation_softmax, shape_original_after_cropping, is_seg=False,
                                               axis=lowres_axis, order=order, do_separate_z=do_separate_z, cval=0,
                                               order_z=interpolation_order_z)
    else:
        if verbose:
            print("no resampling necessary")
        seg_old_spacing = segmentation_softmax

    if resampled_npz_fname is not None:
        np.savez_compressed(resampled_npz_fname, softmax=seg_old_spacing.astype(np.float16))
        # region_class_order is needed for ensembling if the nonlinearity is sigmoid
        if region_class_order is not None:
            properties_dict['regions_class_order'] = region_class_order
        save_pickle(properties_dict, resampled_npz_fname[:-4] + ".pkl")

    if region_class_order is None:
        seg_old_spacing = seg_old_spacing.argmax(0)
    else:
        # sigmoid outputs: threshold each channel, later labels overwrite earlier ones
        seg_old_spacing_final = np.zeros(seg_old_spacing.shape[1:])
        for i, c in enumerate(region_class_order):
            seg_old_spacing_final[seg_old_spacing[i] > 0.5] = c
        seg_old_spacing = seg_old_spacing_final

    bbox = properties_dict.get('crop_bbox')
    if bbox is not None:
        # paste the segmentation back into the uncropped canvas; clip the upper bound so the
        # bbox never reaches past the original image size
        seg_old_size = np.zeros(shape_original_before_cropping)
        for c in range(3):
            bbox[c][1] = np.min((bbox[c][0] + seg_old_spacing.shape[c], shape_original_before_cropping[c]))
        seg_old_size[bbox[0][0]:bbox[0][1],
                     bbox[1][0]:bbox[1][1],
                     bbox[2][0]:bbox[2][1]] = seg_old_spacing
    else:
        seg_old_size = seg_old_spacing

    if seg_postprogess_fn is not None:
        seg_old_size_postprocessed = seg_postprogess_fn(np.copy(seg_old_size), *seg_postprocess_args)
    else:
        seg_old_size_postprocessed = seg_old_size

    seg_resized_itk = sitk.GetImageFromArray(seg_old_size_postprocessed.astype(np.uint8))
    seg_resized_itk.SetSpacing(properties_dict['itk_spacing'])
    seg_resized_itk.SetOrigin(properties_dict['itk_origin'])
    seg_resized_itk.SetDirection(properties_dict['itk_direction'])
    sitk.WriteImage(seg_resized_itk, out_fname)

    if (non_postprocessed_fname is not None) and (seg_postprogess_fn is not None):
        seg_resized_itk = sitk.GetImageFromArray(seg_old_size.astype(np.uint8))
        seg_resized_itk.SetSpacing(properties_dict['itk_spacing'])
        seg_resized_itk.SetOrigin(properties_dict['itk_origin'])
        seg_resized_itk.SetDirection(properties_dict['itk_direction'])
        sitk.WriteImage(seg_resized_itk, non_postprocessed_fname)


def save_segmentation_nifti(segmentation, out_fname, dct, order=1, force_separate_z=None, order_z=0):
    """
    Faster and uses less RAM than save_segmentation_nifti_from_softmax, but maybe less precise and
    does not support softmax export (which is needed for ensembling). Niche function.

    :param segmentation: integer segmentation map, or path to a .npy holding it (file gets deleted)
    :param out_fname: output nifti filename
    :param dct: geometry properties produced during preprocessing
    :param order: interpolation order; 0 uses resize_segmentation directly
    :param force_separate_z: see save_segmentation_nifti_from_softmax
    :param order_z: interpolation order along z when separate-z resampling is used
    :return: None
    """
    print("force_separate_z:", force_separate_z, "interpolation order:", order)
    # BUGFIX: the original redirected stdout to an *unclosed* devnull handle and only restored
    # it at the very end — an exception anywhere below left the whole process silenced and the
    # file descriptor leaked. try/finally guarantees restoration and closes the handle.
    devnull = open(os.devnull, 'w')
    sys.stdout = devnull
    try:
        if isinstance(segmentation, str):
            assert isfile(segmentation), "If isinstance(segmentation_softmax, str) then " \
                                         "isfile(segmentation_softmax) must be True"
            del_file = deepcopy(segmentation)
            segmentation = np.load(segmentation)
            os.remove(del_file)

        # first resample, then put result into bbox of cropping, then save
        current_shape = segmentation.shape
        shape_original_after_cropping = dct.get('size_after_cropping')
        shape_original_before_cropping = dct.get('original_size_of_raw_data')

        if np.any(np.array(current_shape) != np.array(shape_original_after_cropping)):
            if order == 0:
                seg_old_spacing = resize_segmentation(segmentation, shape_original_after_cropping, 0, 0)
            else:
                if force_separate_z is None:
                    if get_do_separate_z(dct.get('original_spacing')):
                        do_separate_z = True
                        lowres_axis = get_lowres_axis(dct.get('original_spacing'))
                    elif get_do_separate_z(dct.get('spacing_after_resampling')):
                        do_separate_z = True
                        lowres_axis = get_lowres_axis(dct.get('spacing_after_resampling'))
                    else:
                        do_separate_z = False
                        lowres_axis = None
                else:
                    do_separate_z = force_separate_z
                    if do_separate_z:
                        lowres_axis = get_lowres_axis(dct.get('original_spacing'))
                    else:
                        lowres_axis = None

                print("separate z:", do_separate_z, "lowres axis", lowres_axis)
                # [None] adds a channel axis for resample_data_or_seg; [0] strips it again
                seg_old_spacing = resample_data_or_seg(segmentation[None], shape_original_after_cropping,
                                                       is_seg=True, axis=lowres_axis, order=order,
                                                       do_separate_z=do_separate_z, cval=0, order_z=order_z)[0]
        else:
            seg_old_spacing = segmentation

        bbox = dct.get('crop_bbox')
        if bbox is not None:
            seg_old_size = np.zeros(shape_original_before_cropping)
            for c in range(3):
                bbox[c][1] = np.min((bbox[c][0] + seg_old_spacing.shape[c], shape_original_before_cropping[c]))
            seg_old_size[bbox[0][0]:bbox[0][1],
                         bbox[1][0]:bbox[1][1],
                         bbox[2][0]:bbox[2][1]] = seg_old_spacing
        else:
            seg_old_size = seg_old_spacing

        seg_resized_itk = sitk.GetImageFromArray(seg_old_size.astype(np.uint8))
        seg_resized_itk.SetSpacing(dct['itk_spacing'])
        seg_resized_itk.SetOrigin(dct['itk_origin'])
        seg_resized_itk.SetDirection(dct['itk_direction'])
        sitk.WriteImage(seg_resized_itk, out_fname)
    finally:
        sys.stdout = sys.__stdout__
        devnull.close()


# ================================================
# FILE: unetr_pp/inference_acdc.py
# ================================================
import glob
import os
import SimpleITK as sitk
import numpy as np
from medpy.metric import binary
from sklearn.neighbors import KDTree
from scipy import ndimage
import argparse


def read_nii(path):
    """Read a nifti file; return (array, spacing)."""
    itk_img = sitk.ReadImage(path)
    spacing = np.array(itk_img.GetSpacing())
    return sitk.GetArrayFromImage(itk_img), spacing


def dice(pred, label):
    """Dice coefficient between two binary masks; 1 when both are empty."""
    if (pred.sum() + label.sum()) == 0:
        return 1
    else:
        return 2. * np.logical_and(pred, label).sum() / (pred.sum() + label.sum())


def process_label(label):
    """Split an ACDC label map into its three structures: RV (1), myocardium (2), LV (3)."""
    rv = label == 1
    myo = label == 2
    lv = label == 3
    return rv, myo, lv


def hd(pred, gt):
    """95th-percentile Hausdorff distance; 0 when either mask is empty."""
    # (a dead, commented-out variant of this function was removed here)
    if pred.sum() > 0 and gt.sum() > 0:
        # debug print removed for consistency with inference_synapse.py's hd()
        return binary.hd95(pred, gt)
    else:
        return 0


def test(fold):
    # NOTE(review): './' and '/' look like placeholders — point these at the ACDC
    # ground-truth labels and the prediction folder before running.
    label_path = './'
    pred_path = '/'
    label_list = sorted(glob.glob(os.path.join(label_path, '*nii.gz')))
    infer_list = sorted(glob.glob(os.path.join(pred_path, '*nii.gz')))
print("loading success...") print(label_list) print(infer_list) Dice_rv=[] Dice_myo=[] Dice_lv=[] hd_rv=[] hd_myo=[] hd_lv=[] file=path + 'inferTs/'+fold if not os.path.exists(file): os.makedirs(file) fw = open(file+'/dice_pre.txt', 'w') for label_path,infer_path in zip(label_list,infer_list): print(label_path.split('/')[-1]) print(infer_path.split('/')[-1]) label,spacing= read_nii(label_path) infer,spacing= read_nii(infer_path) label_rv,label_myo,label_lv=process_label(label) infer_rv,infer_myo,infer_lv=process_label(infer) Dice_rv.append(dice(infer_rv,label_rv)) Dice_myo.append(dice(infer_myo,label_myo)) Dice_lv.append(dice(infer_lv,label_lv)) hd_rv.append(hd(infer_rv,label_rv)) hd_myo.append(hd(infer_myo,label_myo)) hd_lv.append(hd(infer_lv,label_lv)) fw.write('*'*20+'\n',) fw.write(infer_path.split('/')[-1]+'\n') fw.write('hd_rv: {:.4f}\n'.format(hd_rv[-1])) fw.write('hd_myo: {:.4f}\n'.format(hd_myo[-1])) fw.write('hd_lv: {:.4f}\n'.format(hd_lv[-1])) #fw.write('*'*20+'\n') fw.write('*'*20+'\n',) fw.write(infer_path.split('/')[-1]+'\n') fw.write('Dice_rv: {:.4f}\n'.format(Dice_rv[-1])) fw.write('Dice_myo: {:.4f}\n'.format(Dice_myo[-1])) fw.write('Dice_lv: {:.4f}\n'.format(Dice_lv[-1])) fw.write('hd_rv: {:.4f}\n'.format(hd_rv[-1])) fw.write('hd_myo: {:.4f}\n'.format(hd_myo[-1])) fw.write('hd_lv: {:.4f}\n'.format(hd_lv[-1])) fw.write('*'*20+'\n') #fw.write('*'*20+'\n') #fw.write('Mean_hd\n') #fw.write('hd_rv'+str(np.mean(hd_rv))+'\n') #fw.write('hd_myo'+str(np.mean(hd_myo))+'\n') #fw.write('hd_lv'+str(np.mean(hd_lv))+'\n') #fw.write('*'*20+'\n') fw.write('*'*20+'\n') fw.write('Mean_Dice\n') fw.write('Dice_rv'+str(np.mean(Dice_rv))+'\n') fw.write('Dice_myo'+str(np.mean(Dice_myo))+'\n') fw.write('Dice_lv'+str(np.mean(Dice_lv))+'\n') fw.write('Mean_HD\n') fw.write('HD_rv'+str(np.mean(hd_rv))+'\n') fw.write('HD_myo'+str(np.mean(hd_myo))+'\n') fw.write('HD_lv'+str(np.mean(hd_lv))+'\n') fw.write('*'*20+'\n') dsc=[] dsc.append(np.mean(Dice_rv)) 
dsc.append(np.mean(Dice_myo)) dsc.append(np.mean(Dice_lv)) avg_hd=[] avg_hd.append(np.mean(hd_rv)) avg_hd.append(np.mean(hd_myo)) avg_hd.append(np.mean(hd_lv)) fw.write('avg_hd:'+str(np.mean(avg_hd))+'\n') fw.write('DSC:'+str(np.mean(dsc))+'\n') fw.write('HD:'+str(np.mean(avg_hd))+'\n') print('done') if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("fold", help="fold name") args = parser.parse_args() fold = args.fold test(fold) ================================================ FILE: unetr_pp/inference_synapse.py ================================================ import glob import os import SimpleITK as sitk import numpy as np import argparse from medpy import metric def read_nii(path): return sitk.GetArrayFromImage(sitk.ReadImage(path)) def dice(pred, label): if (pred.sum() + label.sum()) == 0: return 1 else: return 2. * np.logical_and(pred, label).sum() / (pred.sum() + label.sum()) def hd(pred,gt): if pred.sum() > 0 and gt.sum()>0: hd95 = metric.binary.hd95(pred, gt) return hd95 else: return 0 def process_label(label): spleen = label == 1 right_kidney = label == 2 left_kidney = label == 3 gallbladder = label == 4 liver = label == 6 stomach = label == 7 aorta = label == 8 pancreas = label == 11 return spleen,right_kidney,left_kidney,gallbladder,liver,stomach,aorta,pancreas def test(fold): label_path= None # Replace None by full path of "DATASET/unetr_pp_raw/unetr_pp_raw_data/Task002_Synapse/" infer_path = None # Replace None by full path of "output_synapse" label_list=sorted(glob.glob(os.path.join(label_path,'labelsTs','*nii.gz'))) infer_list=sorted(glob.glob(os.path.join(infer_path,'inferTs','*nii.gz'))) print("loading success...") print(label_list) print(infer_list) Dice_spleen=[] Dice_right_kidney=[] Dice_left_kidney=[] Dice_gallbladder=[] Dice_liver=[] Dice_stomach=[] Dice_aorta=[] Dice_pancreas=[] hd_spleen=[] hd_right_kidney=[] hd_left_kidney=[] hd_gallbladder=[] hd_liver=[] hd_stomach=[] hd_aorta=[] hd_pancreas=[] 
file=infer_path + 'inferTs/'+fold if not os.path.exists(file): os.makedirs(file) fw = open(file+'/dice_pre.txt', 'a') for label_path,infer_path in zip(label_list,infer_list): print(label_path.split('/')[-1]) print(infer_path.split('/')[-1]) label,infer = read_nii(label_path),read_nii(infer_path) label_spleen,label_right_kidney,label_left_kidney,label_gallbladder,label_liver,label_stomach,label_aorta,label_pancreas=process_label(label) infer_spleen,infer_right_kidney,infer_left_kidney,infer_gallbladder,infer_liver,infer_stomach,infer_aorta,infer_pancreas=process_label(infer) Dice_spleen.append(dice(infer_spleen,label_spleen)) Dice_right_kidney.append(dice(infer_right_kidney,label_right_kidney)) Dice_left_kidney.append(dice(infer_left_kidney,label_left_kidney)) Dice_gallbladder.append(dice(infer_gallbladder,label_gallbladder)) Dice_liver.append(dice(infer_liver,label_liver)) Dice_stomach.append(dice(infer_stomach,label_stomach)) Dice_aorta.append(dice(infer_aorta,label_aorta)) Dice_pancreas.append(dice(infer_pancreas,label_pancreas)) hd_spleen.append(hd(infer_spleen,label_spleen)) hd_right_kidney.append(hd(infer_right_kidney,label_right_kidney)) hd_left_kidney.append(hd(infer_left_kidney,label_left_kidney)) hd_gallbladder.append(hd(infer_gallbladder,label_gallbladder)) hd_liver.append(hd(infer_liver,label_liver)) hd_stomach.append(hd(infer_stomach,label_stomach)) hd_aorta.append(hd(infer_aorta,label_aorta)) hd_pancreas.append(hd(infer_pancreas,label_pancreas)) fw.write('*'*20+'\n',) fw.write(infer_path.split('/')[-1]+'\n') fw.write('Dice_spleen: {:.4f}\n'.format(Dice_spleen[-1])) fw.write('Dice_right_kidney: {:.4f}\n'.format(Dice_right_kidney[-1])) fw.write('Dice_left_kidney: {:.4f}\n'.format(Dice_left_kidney[-1])) fw.write('Dice_gallbladder: {:.4f}\n'.format(Dice_gallbladder[-1])) fw.write('Dice_liver: {:.4f}\n'.format(Dice_liver[-1])) fw.write('Dice_stomach: {:.4f}\n'.format(Dice_stomach[-1])) fw.write('Dice_aorta: {:.4f}\n'.format(Dice_aorta[-1])) 
fw.write('Dice_pancreas: {:.4f}\n'.format(Dice_pancreas[-1])) fw.write('hd_spleen: {:.4f}\n'.format(hd_spleen[-1])) fw.write('hd_right_kidney: {:.4f}\n'.format(hd_right_kidney[-1])) fw.write('hd_left_kidney: {:.4f}\n'.format(hd_left_kidney[-1])) fw.write('hd_gallbladder: {:.4f}\n'.format(hd_gallbladder[-1])) fw.write('hd_liver: {:.4f}\n'.format(hd_liver[-1])) fw.write('hd_stomach: {:.4f}\n'.format(hd_stomach[-1])) fw.write('hd_aorta: {:.4f}\n'.format(hd_aorta[-1])) fw.write('hd_pancreas: {:.4f}\n'.format(hd_pancreas[-1])) dsc=[] HD=[] dsc.append(Dice_spleen[-1]) dsc.append((Dice_right_kidney[-1])) dsc.append(Dice_left_kidney[-1]) dsc.append(np.mean(Dice_gallbladder[-1])) dsc.append(np.mean(Dice_liver[-1])) dsc.append(np.mean(Dice_stomach[-1])) dsc.append(np.mean(Dice_aorta[-1])) dsc.append(np.mean(Dice_pancreas[-1])) fw.write('DSC:'+str(np.mean(dsc))+'\n') HD.append(hd_spleen[-1]) HD.append(hd_right_kidney[-1]) HD.append(hd_left_kidney[-1]) HD.append(hd_gallbladder[-1]) HD.append(hd_liver[-1]) HD.append(hd_stomach[-1]) HD.append(hd_aorta[-1]) HD.append(hd_pancreas[-1]) fw.write('hd:'+str(np.mean(HD))+'\n') fw.write('*'*20+'\n') fw.write('Mean_Dice\n') fw.write('Dice_spleen'+str(np.mean(Dice_spleen))+'\n') fw.write('Dice_right_kidney'+str(np.mean(Dice_right_kidney))+'\n') fw.write('Dice_left_kidney'+str(np.mean(Dice_left_kidney))+'\n') fw.write('Dice_gallbladder'+str(np.mean(Dice_gallbladder))+'\n') fw.write('Dice_liver'+str(np.mean(Dice_liver))+'\n') fw.write('Dice_stomach'+str(np.mean(Dice_stomach))+'\n') fw.write('Dice_aorta'+str(np.mean(Dice_aorta))+'\n') fw.write('Dice_pancreas'+str(np.mean(Dice_pancreas))+'\n') fw.write('Mean_hd\n') fw.write('hd_spleen'+str(np.mean(hd_spleen))+'\n') fw.write('hd_right_kidney'+str(np.mean(hd_right_kidney))+'\n') fw.write('hd_left_kidney'+str(np.mean(hd_left_kidney))+'\n') fw.write('hd_gallbladder'+str(np.mean(hd_gallbladder))+'\n') fw.write('hd_liver'+str(np.mean(hd_liver))+'\n') 
fw.write('hd_stomach'+str(np.mean(hd_stomach))+'\n') fw.write('hd_aorta'+str(np.mean(hd_aorta))+'\n') fw.write('hd_pancreas'+str(np.mean(hd_pancreas))+'\n') fw.write('*'*20+'\n') dsc=[] dsc.append(np.mean(Dice_spleen)) dsc.append(np.mean(Dice_right_kidney)) dsc.append(np.mean(Dice_left_kidney)) dsc.append(np.mean(Dice_gallbladder)) dsc.append(np.mean(Dice_liver)) dsc.append(np.mean(Dice_stomach)) dsc.append(np.mean(Dice_aorta)) dsc.append(np.mean(Dice_pancreas)) fw.write('dsc:'+str(np.mean(dsc))+'\n') HD=[] HD.append(np.mean(hd_spleen)) HD.append(np.mean(hd_right_kidney)) HD.append(np.mean(hd_left_kidney)) HD.append(np.mean(hd_gallbladder)) HD.append(np.mean(hd_liver)) HD.append(np.mean(hd_stomach)) HD.append(np.mean(hd_aorta)) HD.append(np.mean(hd_pancreas)) fw.write('hd:'+str(np.mean(HD))+'\n') print('done') if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("fold", help="fold name") args = parser.parse_args() fold=args.fold test(fold) ================================================ FILE: unetr_pp/inference_tumor.py ================================================ import glob import os import SimpleITK as sitk import numpy as np import argparse from medpy.metric import binary def read_nii(path): return sitk.GetArrayFromImage(sitk.ReadImage(path)) def new_dice(pred,label): tp_hard = np.sum((pred == 1).astype(np.float) * (label == 1).astype(np.float)) fp_hard = np.sum((pred == 1).astype(np.float) * (label != 1).astype(np.float)) fn_hard = np.sum((pred != 1).astype(np.float) * (label == 1).astype(np.float)) return 2*tp_hard/(2*tp_hard+fp_hard+fn_hard) def dice(pred, label): if (pred.sum() + label.sum()) == 0: return 1 else: return 2. 
* np.logical_and(pred, label).sum() / (pred.sum() + label.sum()) def hd(pred,gt): if pred.sum() > 0 and gt.sum()>0: hd95 = binary.hd95(pred, gt) return hd95 else: return 0 def process_label(label): net = label == 2 ed = label == 1 et = label == 3 ET=et TC=net+et WT=net+et+ed ED= ed NET=net return ET,TC,WT,ED,NET def test(fold): #path='./' path = None # Replace None by the full path of : unetr_plus_plus/DATASET_Tumor/unetr_pp_raw/unetr_pp_raw_data/Task03_tumor/" label_list=sorted(glob.glob(os.path.join(path,'labelsTs','*nii.gz'))) infer_path = None # Replace None by the full path of : unetr_plus_plus/unetr_pp/evaluation/unetr_pp_tumor_checkpoint/" infer_list=sorted(glob.glob(os.path.join(infer_path,'inferTs','*nii.gz'))) print("loading success...") Dice_et=[] Dice_tc=[] Dice_wt=[] Dice_ed=[] Dice_net=[] HD_et=[] HD_tc=[] HD_wt=[] HD_ed=[] HD_net=[] file=infer_path + 'inferTs/'+fold if not os.path.exists(file): os.makedirs(file) fw = open(file+'/dice_five.txt', 'w') for label_path,infer_path in zip(label_list,infer_list): print(label_path.split('/')[-1]) print(infer_path.split('/')[-1]) label,infer = read_nii(label_path),read_nii(infer_path) label_et,label_tc,label_wt,label_ed,label_net=process_label(label) infer_et,infer_tc,infer_wt,infer_ed,infer_net=process_label(infer) Dice_et.append(dice(infer_et,label_et)) Dice_tc.append(dice(infer_tc,label_tc)) Dice_wt.append(dice(infer_wt,label_wt)) Dice_ed.append(dice(infer_ed,label_ed)) Dice_net.append(dice(infer_net,label_net)) HD_et.append(hd(infer_et,label_et)) HD_tc.append(hd(infer_tc,label_tc)) HD_wt.append(hd(infer_wt,label_wt)) HD_ed.append(hd(infer_ed,label_ed)) HD_net.append(hd(infer_net,label_net)) fw.write('*'*20+'\n',) fw.write(infer_path.split('/')[-1]+'\n') fw.write('hd_et: {:.4f}\n'.format(HD_et[-1])) fw.write('hd_tc: {:.4f}\n'.format(HD_tc[-1])) fw.write('hd_wt: {:.4f}\n'.format(HD_wt[-1])) fw.write('hd_ed: {:.4f}\n'.format(HD_ed[-1])) fw.write('hd_net: {:.4f}\n'.format(HD_net[-1])) fw.write('*'*20+'\n',) 
fw.write('Dice_et: {:.4f}\n'.format(Dice_et[-1])) fw.write('Dice_tc: {:.4f}\n'.format(Dice_tc[-1])) fw.write('Dice_wt: {:.4f}\n'.format(Dice_wt[-1])) fw.write('Dice_ed: {:.4f}\n'.format(Dice_ed[-1])) fw.write('Dice_net: {:.4f}\n'.format(Dice_net[-1])) #print('dice_et: {:.4f}'.format(np.mean(Dice_et))) #print('dice_tc: {:.4f}'.format(np.mean(Dice_tc))) #print('dice_wt: {:.4f}'.format(np.mean(Dice_wt))) dsc=[] avg_hd=[] dsc.append(np.mean(Dice_et)) dsc.append(np.mean(Dice_tc)) dsc.append(np.mean(Dice_wt)) dsc.append(np.mean(Dice_ed)) dsc.append(np.mean(Dice_net)) avg_hd.append(np.mean(HD_et)) avg_hd.append(np.mean(HD_tc)) avg_hd.append(np.mean(HD_wt)) avg_hd.append(np.mean(HD_ed)) avg_hd.append(np.mean(HD_net)) fw.write('Dice_et'+str(np.mean(Dice_et))+' '+'\n') fw.write('Dice_tc'+str(np.mean(Dice_tc))+' '+'\n') fw.write('Dice_wt'+str(np.mean(Dice_wt))+' '+'\n') fw.write('Dice_ed'+str(np.mean(Dice_ed))+' '+'\n') fw.write('Dice_net'+str(np.mean(Dice_net))+' '+'\n') fw.write('HD_et'+str(np.mean(HD_et))+' '+'\n') fw.write('HD_tc'+str(np.mean(HD_tc))+' '+'\n') fw.write('HD_wt'+str(np.mean(HD_wt))+' '+'\n') fw.write('HD_ed'+str(np.mean(HD_ed))+' '+'\n') fw.write('HD_net'+str(np.mean(HD_net))+' '+'\n') fw.write('Dice'+str(np.mean(dsc))+' '+'\n') fw.write('HD'+str(np.mean(avg_hd))+' '+'\n') #print('Dice'+str(np.mean(dsc))+' '+'\n') #print('HD'+str(np.mean(avg_hd))+' '+'\n') if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("fold", help="fold name") args = parser.parse_args() fold=args.fold test(fold) ================================================ FILE: unetr_pp/network_architecture/README.md ================================================ You can change batch size, input data size ``` https://github.com/282857341/nnFormer/blob/6e36d76f9b7d0bea522e1cd05adf502ba85480e6/nnformer/run/default_configuration.py#L49-L68 ``` ================================================ FILE: unetr_pp/network_architecture/__init__.py 
================================================
from __future__ import absolute_import
from . import *


================================================
FILE: unetr_pp/network_architecture/acdc/__init__.py
================================================


================================================
FILE: unetr_pp/network_architecture/acdc/model_components.py
================================================
from torch import nn
from timm.models.layers import trunc_normal_
from typing import Sequence, Tuple, Union
from monai.networks.layers.utils import get_norm_layer
from monai.utils import optional_import
from unetr_pp.network_architecture.layers import LayerNorm
from unetr_pp.network_architecture.acdc.transformerblock import TransformerBlock
from unetr_pp.network_architecture.dynunet_block import get_conv_layer, UnetResBlock

einops, _ = optional_import("einops")


class UnetrPPEncoder(nn.Module):
    # Hierarchical encoder for UNETR++ (ACDC variant): a conv stem + 3 strided-conv
    # downsampling layers, each followed by a stack of TransformerBlock stages.
    def __init__(self, input_size=[16 * 40 * 40, 8 * 20 * 20, 4 * 10 * 10, 2 * 5 * 5], dims=[32, 64, 128, 256],
                 proj_size=[64, 64, 64, 32], depths=[3, 3, 3, 3], num_heads=4,
                 spatial_dims=3, in_channels=1, dropout=0.0, transformer_dropout_rate=0.1, **kwargs):
        # input_size: flattened spatial size (H*W*D) seen by each of the 4 stages
        # dims: channel width per stage; proj_size: key/value projection size per stage
        # depths: number of transformer blocks per stage
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        # stem downsamples only in-plane (stride (1, 4, 4)) — z resolution is kept
        stem_layer = nn.Sequential(
            get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=(1, 4, 4), stride=(1, 4, 4),
                           dropout=dropout, conv_only=True, ),
            get_norm_layer(name=("group", {"num_groups": in_channels}), channels=dims[0]),
        )
        self.downsample_layers.append(stem_layer)
        for i in range(3):
            # isotropic 2x downsampling between subsequent stages
            downsample_layer = nn.Sequential(
                get_conv_layer(spatial_dims, dims[i], dims[i + 1], kernel_size=(2, 2, 2), stride=(2, 2, 2),
                               dropout=dropout, conv_only=True, ),
                get_norm_layer(name=("group", {"num_groups": dims[i]}), channels=dims[i + 1]),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple Transformer blocks
        for i in range(4):
            stage_blocks = []
            for j in range(depths[i]):
                stage_blocks.append(TransformerBlock(input_size=input_size[i], hidden_size=dims[i],
                                                     proj_size=proj_size[i], num_heads=num_heads,
                                                     dropout_rate=transformer_dropout_rate, pos_embed=True))
            self.stages.append(nn.Sequential(*stage_blocks))
        # NOTE(review): self.hidden_states is never used again (forward_features builds a
        # local list) — looks like leftover state; confirm before removing.
        self.hidden_states = []
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # NOTE(review): only Conv2d/Linear get trunc-normal init here; the (3D) conv layers
        # created above are nn.Conv3d and therefore keep their default init — confirm intended.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (LayerNorm, nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        # Returns (last-stage output, list of the 4 per-stage feature maps for skip connections).
        hidden_states = []

        x = self.downsample_layers[0](x)
        x = self.stages[0](x)

        hidden_states.append(x)

        for i in range(1, 4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i == 3:  # Reshape the output of the last stage to (B, tokens, C)
                x = einops.rearrange(x, "b c h w d -> b (h w d) c")
            hidden_states.append(x)
        return x, hidden_states

    def forward(self, x):
        x, hidden_states = self.forward_features(x)
        return x, hidden_states


class UnetrUpBlock(nn.Module):
    # One decoder stage: transposed-conv upsampling, additive skip fusion, then either
    # transformer blocks or (for the last stage) a plain conv residual block.
    def __init__(
            self,
            spatial_dims: int,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[Sequence[int], int],
            upsample_kernel_size: Union[Sequence[int], int],
            norm_name: Union[Tuple, str],
            proj_size: int = 64,
            num_heads: int = 4,
            out_size: int = 0,
            depth: int = 3,
            conv_decoder: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
            norm_name: feature normalization type and arguments.
            proj_size: projection size for keys and values in the spatial attention module.
            num_heads: number of heads inside each EPA module.
            out_size: spatial size for each decoder.
            depth: number of blocks for the current decoder stage.
        """
        super().__init__()
        upsample_stride = upsample_kernel_size  # stride == kernel size -> exact inverse of the encoder downsampling
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )

        # 4 feature resolution stages, each consisting of multiple residual blocks
        self.decoder_block = nn.ModuleList()

        # If this is the last decoder, use ConvBlock(UnetResBlock) instead of EPA_Block
        # (see suppl. material in the paper)
        if conv_decoder == True:
            self.decoder_block.append(
                UnetResBlock(spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1,
                             norm_name=norm_name, ))
        else:
            stage_blocks = []
            for j in range(depth):
                stage_blocks.append(TransformerBlock(input_size=out_size, hidden_size=out_channels,
                                                     proj_size=proj_size, num_heads=num_heads,
                                                     dropout_rate=0.1, pos_embed=True))
            self.decoder_block.append(nn.Sequential(*stage_blocks))

    def _init_weights(self, m):
        # NOTE(review): unlike UnetrPPEncoder, self.apply(self._init_weights) is never
        # called in __init__, so this initializer appears to be dead code — confirm.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, inp, skip):
        # upsample, fuse the encoder skip by addition, then refine
        out = self.transp_conv(inp)
        out = out + skip
        out = self.decoder_block[0](out)

        return out


================================================
FILE: unetr_pp/network_architecture/acdc/transformerblock.py
================================================
import torch.nn as nn
import torch
from unetr_pp.network_architecture.dynunet_block import UnetResBlock


class TransformerBlock(nn.Module):
    """
    A transformer block, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            proj_size: int,
            num_heads: int,
            dropout_rate: float = 0.0,
            pos_embed=False,
    ) -> None:
        """
        Args:
            input_size: the size of the input for each stage.
            hidden_size: dimension of hidden layer.
            proj_size: projection size for keys and values in the spatial attention module.
            num_heads: number of attention heads.
            dropout_rate: faction of the input units to drop.
            pos_embed: bool argument to determine if positional embedding is used.
        """
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            print("Hidden size is ", hidden_size)
            print("Num heads is ", num_heads)
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.norm = nn.LayerNorm(hidden_size)
        # gamma: learnable per-channel residual scaling, initialized small (1e-6)
        self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True)
        self.epa_block = EPA(input_size=input_size, hidden_size=hidden_size, proj_size=proj_size,
                             num_heads=num_heads, channel_attn_drop=dropout_rate, spatial_attn_drop=dropout_rate)
        self.conv51 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch")
        self.conv52 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch")
        self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1))

        self.pos_embed = None
        if pos_embed:
            # learnable absolute positional embedding over the flattened spatial tokens
            self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size))

    def forward(self, x):
        B, C, H, W, D = x.shape

        # flatten spatial dims to tokens: (B, C, H, W, D) -> (B, H*W*D, C)
        x = x.reshape(B, C, H * W * D).permute(0, 2, 1)

        if self.pos_embed is not None:
            x = x + self.pos_embed
        # pre-norm attention with gamma-scaled residual
        attn = x + self.gamma * self.epa_block(self.norm(x))

        attn_skip = attn.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3)  # (B, C, H, W, D)
        # conv refinement with an outer residual around conv8
        attn = self.conv51(attn_skip)
        attn = self.conv52(attn)
        x = attn_skip + self.conv8(attn)

        return x


class EPA(nn.Module):
    """
        Efficient Paired Attention Block, based on: "Shaker et al.,
        UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
        """

    def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False,
                 channel_attn_drop=0.1, spatial_attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        # learnable softmax temperatures, one per head (channel / spatial branch)
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))

        # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel)
        self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias)

        # E and F are projection matrices used in spatial attention module to project
        # keys and values from HWD-dimension to P-dimension
        self.E = nn.Linear(input_size, proj_size)
        self.F = nn.Linear(input_size, proj_size)

        self.attn_drop = nn.Dropout(channel_attn_drop)
        self.attn_drop_2 = nn.Dropout(spatial_attn_drop)

        # each branch is projected to hidden_size // 2 so the concatenation is hidden_size
        self.out_proj = nn.Linear(hidden_size, int(hidden_size // 2))
        self.out_proj2 = nn.Linear(hidden_size, int(hidden_size // 2))

    def forward(self, x):
        B, N, C = x.shape

        # (B, N, 4, heads, C/heads) -> (4, B, heads, N, C/heads)
        qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads)

        qkvv = qkvv.permute(2, 0, 3, 1, 4)

        q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3]

        # transpose so the last dim is the token axis N (needed for channel attention
        # and for the E/F projections below, which act on the token dimension)
        q_shared = q_shared.transpose(-2, -1)
        k_shared = k_shared.transpose(-2, -1)
        v_CA = v_CA.transpose(-2, -1)
        v_SA = v_SA.transpose(-2, -1)

        k_shared_projected = self.E(k_shared)

        v_SA_projected = self.F(v_SA)

        q_shared = torch.nn.functional.normalize(q_shared, dim=-1)
        k_shared = torch.nn.functional.normalize(k_shared, dim=-1)

        # channel attention: (C/heads x C/heads) similarity over normalized q/k
        attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature

        attn_CA = attn_CA.softmax(dim=-1)
        attn_CA = self.attn_drop(attn_CA)

        x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C)

        # spatial attention: tokens attend to the proj_size-compressed keys
        attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2

        attn_SA = attn_SA.softmax(dim=-1)
        attn_SA = self.attn_drop_2(attn_SA)

        x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C)

        # Concat fusion
        x_SA = self.out_proj(x_SA)
        x_CA = self.out_proj2(x_CA)
        x = torch.cat((x_SA, x_CA), dim=-1)
        return x

    @torch.jit.ignore
    def no_weight_decay(self):
        # temperatures are scale parameters and should not be weight-decayed
        return {'temperature', 'temperature2'}


================================================
FILE: unetr_pp/network_architecture/acdc/unetr_pp_acdc.py
================================================
from torch import
class UNETR_PP(SegmentationNetwork):
    """
    UNETR++ for ACDC, based on: "Shaker et al., UNETR++: Delving into Efficient
    and Accurate 3D Medical Image Segmentation".
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            feature_size: int = 16,
            hidden_size: int = 256,
            num_heads: int = 4,
            pos_embed: str = "perceptron",
            norm_name: Union[Tuple, str] = "instance",
            dropout_rate: float = 0.0,
            depths=None,
            dims=None,
            conv_op=nn.Conv3d,
            do_ds=True,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels (number of classes).
            feature_size: dimension of network feature size.
            hidden_size: dimension of the last encoder stage.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type ("conv" or "perceptron").
            norm_name: feature normalization type and arguments.
            dropout_rate: fraction of the input units to drop.
            depths: number of blocks per encoder stage (defaults to [3, 3, 3, 3]).
            dims: number of channel maps for the stages.
            conv_op: type of convolution operation.
            do_ds: use deep supervision to compute the loss.
        """
        super().__init__()
        if depths is None:
            depths = [3, 3, 3, 3]
        self.do_ds = do_ds
        self.conv_op = conv_op
        self.num_classes = out_channels
        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")
        if pos_embed not in ["conv", "perceptron"]:
            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

        # Hard-coded bottleneck resolution for the ACDC patch size (16, 160, 160).
        self.feat_size = (2, 5, 5,)
        self.hidden_size = hidden_size

        # NOTE(review): dims=None is forwarded to UnetrPPEncoder unchanged —
        # verify the encoder supplies its own channel defaults in that case.
        self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads)

        # Stem residual block at full resolution (used as the last decoder skip).
        self.encoder1 = UnetResBlock(
            spatial_dims=3, in_channels=in_channels, out_channels=feature_size,
            kernel_size=3, stride=1, norm_name=norm_name,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=3, in_channels=feature_size * 16, out_channels=feature_size * 8,
            kernel_size=3, upsample_kernel_size=2, norm_name=norm_name, out_size=4 * 10 * 10,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=3, in_channels=feature_size * 8, out_channels=feature_size * 4,
            kernel_size=3, upsample_kernel_size=2, norm_name=norm_name, out_size=8 * 20 * 20,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=3, in_channels=feature_size * 4, out_channels=feature_size * 2,
            kernel_size=3, upsample_kernel_size=2, norm_name=norm_name, out_size=16 * 40 * 40,
        )
        # Final decoder is purely convolutional and undoes the (1, 4, 4) stem stride.
        self.decoder2 = UnetrUpBlock(
            spatial_dims=3, in_channels=feature_size * 2, out_channels=feature_size,
            kernel_size=3, upsample_kernel_size=(1, 4, 4), norm_name=norm_name,
            out_size=16 * 160 * 160, conv_decoder=True,
        )
        self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels)
        if self.do_ds:
            # auxiliary heads at lower resolutions for deep supervision
            self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels)
            self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels)

    def proj_feat(self, x, hidden_size, feat_size):
        # (B, tokens, hidden) -> (B, hidden, D, H, W)
        d, h, w = feat_size
        x = x.view(x.size(0), d, h, w, hidden_size)
        return x.permute(0, 4, 1, 2, 3).contiguous()

    def forward(self, x_in):
        _, hidden_states = self.unetr_pp_encoder(x_in)
        conv_skip = self.encoder1(x_in)

        enc1, enc2, enc3, enc4 = hidden_states[:4]

        # Decode from the bottleneck upwards, consuming encoder skips in reverse.
        dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size)
        dec3 = self.decoder5(dec4, enc3)
        dec2 = self.decoder4(dec3, enc2)
        dec1 = self.decoder3(dec2, enc1)
        out = self.decoder2(dec1, conv_skip)

        if self.do_ds:
            return [self.out1(out), self.out2(dec1), self.out3(dec2)]
        return self.out1(out)
class UnetResBlock(nn.Module):
    """
    Residual block used by DynUNet/nnU-Net: two conv-norm-act stages plus an
    identity (or 1x1-projected) skip connection.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        norm_name: feature normalization type and arguments.
        act_name: activation layer type and arguments.
        dropout: dropout probability.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
        dropout: Optional[Union[Tuple, str, float]] = None,
    ):
        super().__init__()
        self.conv1 = get_conv_layer(
            spatial_dims, in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, dropout=dropout, conv_only=True,
        )
        self.conv2 = get_conv_layer(
            spatial_dims, out_channels, out_channels,
            kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True,
        )
        self.lrelu = get_act_layer(name=act_name)
        self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
        self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)

        # a 1x1 projection on the skip path is needed whenever the channel count
        # or the spatial resolution changes
        self.downsample = in_channels != out_channels or not np.all(np.atleast_1d(stride) == 1)
        if self.downsample:
            self.conv3 = get_conv_layer(
                spatial_dims, in_channels, out_channels,
                kernel_size=1, stride=stride, dropout=dropout, conv_only=True,
            )
            self.norm3 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)

    def forward(self, inp):
        out = self.lrelu(self.norm1(self.conv1(inp)))
        out = self.norm2(self.conv2(out))
        residual = inp
        # apply the skip-path projection only when it was built in __init__
        if hasattr(self, "conv3"):
            residual = self.conv3(residual)
        if hasattr(self, "norm3"):
            residual = self.norm3(residual)
        out = self.lrelu(out + residual)
        return out
class UnetBasicBlock(nn.Module):
    """
    Plain (non-residual) DynUNet block: two consecutive conv-norm-act stages.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride (first conv only; second conv is stride 1).
        norm_name: feature normalization type and arguments.
        act_name: activation layer type and arguments.
        dropout: dropout probability.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
        dropout: Optional[Union[Tuple, str, float]] = None,
    ):
        super().__init__()
        self.conv1 = get_conv_layer(
            spatial_dims, in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, dropout=dropout, conv_only=True,
        )
        self.conv2 = get_conv_layer(
            spatial_dims, out_channels, out_channels,
            kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True,
        )
        self.lrelu = get_act_layer(name=act_name)
        self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
        self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)

    def forward(self, inp):
        out = self.lrelu(self.norm1(self.conv1(inp)))
        return self.lrelu(self.norm2(self.conv2(out)))
class UnetUpBlock(nn.Module):
    """
    DynUNet upsampling block: transposed conv, concat with the skip feature,
    then a UnetBasicBlock on the concatenated channels.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride (kept for interface compatibility; unused here).
        upsample_kernel_size: kernel size of the transposed convolution.
        norm_name: feature normalization type and arguments.
        act_name: activation layer type and arguments.
        dropout: dropout probability.
        trans_bias: transposed convolution bias.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
        dropout: Optional[Union[Tuple, str, float]] = None,
        trans_bias: bool = False,
    ):
        super().__init__()
        # stride equals kernel size, so the transposed conv performs an exact upscale
        upsample_stride = upsample_kernel_size
        self.transp_conv = get_conv_layer(
            spatial_dims, in_channels, out_channels,
            kernel_size=upsample_kernel_size, stride=upsample_stride,
            dropout=dropout, bias=trans_bias, conv_only=True, is_transposed=True,
        )
        self.conv_block = UnetBasicBlock(
            spatial_dims, out_channels + out_channels, out_channels,
            kernel_size=kernel_size, stride=1, dropout=dropout,
            norm_name=norm_name, act_name=act_name,
        )

    def forward(self, inp, skip):
        # the skip tensor is expected to carry out_channels channels
        upsampled = self.transp_conv(inp)
        return self.conv_block(torch.cat((upsampled, skip), dim=1))
def get_padding(
    kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int]
) -> Union[Tuple[int, ...], int]:
    """Compute per-dimension 'same'-style padding for a conv layer.

    Returns a tuple with one entry per dimension, or a plain int when the
    inputs are scalar (one-dimensional after ``np.atleast_1d``).

    Raises:
        AssertionError: if any computed padding would be negative.
    """
    kernel = np.atleast_1d(kernel_size)
    strides = np.atleast_1d(stride)
    raw = (kernel - strides + 1) / 2
    if np.min(raw) < 0:
        raise AssertionError("padding value should not be negative, please change the kernel size and/or stride.")
    padding = tuple(int(p) for p in raw)
    if len(padding) == 1:
        return padding[0]
    return padding
class ConvDropoutNormNonlin(nn.Module):
    """
    conv -> (optional) dropout -> norm -> nonlin, with nnU-Net's standard
    hyperparameters filled in for any kwargs dict that is not provided.
    """

    def __init__(self, input_channels, output_channels,
                 conv_op=nn.Conv2d, conv_kwargs=None,
                 norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
                 dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None):
        super(ConvDropoutNormNonlin, self).__init__()
        # fill in nnU-Net's defaults for whichever kwargs dicts are missing
        self.nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} if nonlin_kwargs is None else nonlin_kwargs
        self.dropout_op_kwargs = {'p': 0.5, 'inplace': True} if dropout_op_kwargs is None else dropout_op_kwargs
        self.norm_op_kwargs = ({'eps': 1e-5, 'affine': True, 'momentum': 0.1}
                               if norm_op_kwargs is None else norm_op_kwargs)
        self.conv_kwargs = ({'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}
                            if conv_kwargs is None else conv_kwargs)
        self.nonlin = nonlin
        self.dropout_op = dropout_op
        self.conv_op = conv_op
        self.norm_op = norm_op

        self.conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs)
        # instantiate dropout only when it would actually drop something
        drop_p = None if self.dropout_op is None else self.dropout_op_kwargs['p']
        if self.dropout_op is not None and drop_p is not None and drop_p > 0:
            self.dropout = self.dropout_op(**self.dropout_op_kwargs)
        else:
            self.dropout = None
        self.instnorm = self.norm_op(output_channels, **self.norm_op_kwargs)
        self.lrelu = self.nonlin(**self.nonlin_kwargs)

    def forward(self, x):
        out = self.conv(x)
        if self.dropout is not None:
            out = self.dropout(out)
        return self.lrelu(self.instnorm(out))
class StackedConvLayers(nn.Module):
    """
    A stack of ``num_convs`` basic conv blocks. Only the first block may apply
    ``first_stride`` (for downsampling); all remaining blocks keep the shared
    ``conv_kwargs``.
    """

    def __init__(self, input_feature_channels, output_feature_channels, num_convs,
                 conv_op=nn.Conv2d, conv_kwargs=None,
                 norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
                 dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None, first_stride=None,
                 basic_block=ConvDropoutNormNonlin):
        super(StackedConvLayers, self).__init__()
        self.input_channels = input_feature_channels
        self.output_channels = output_feature_channels

        # fill in nnU-Net's defaults for whichever kwargs dicts are missing
        self.nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} if nonlin_kwargs is None else nonlin_kwargs
        self.dropout_op_kwargs = {'p': 0.5, 'inplace': True} if dropout_op_kwargs is None else dropout_op_kwargs
        self.norm_op_kwargs = ({'eps': 1e-5, 'affine': True, 'momentum': 0.1}
                               if norm_op_kwargs is None else norm_op_kwargs)
        self.conv_kwargs = ({'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}
                            if conv_kwargs is None else conv_kwargs)
        self.nonlin = nonlin
        self.dropout_op = dropout_op
        self.conv_op = conv_op
        self.norm_op = norm_op

        # the first block gets its own kwargs copy so the stride override does
        # not leak into the shared dict
        if first_stride is not None:
            self.conv_kwargs_first_conv = deepcopy(self.conv_kwargs)
            self.conv_kwargs_first_conv['stride'] = first_stride
        else:
            self.conv_kwargs_first_conv = self.conv_kwargs

        layers = [basic_block(input_feature_channels, output_feature_channels, self.conv_op,
                              self.conv_kwargs_first_conv, self.norm_op, self.norm_op_kwargs,
                              self.dropout_op, self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs)]
        layers += [basic_block(output_feature_channels, output_feature_channels, self.conv_op,
                               self.conv_kwargs, self.norm_op, self.norm_op_kwargs,
                               self.dropout_op, self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs)
                   for _ in range(num_convs - 1)]
        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        return self.blocks(x)
nonlin_kwargs=None, deep_supervision=True, dropout_in_localization=False, final_nonlin=softmax_helper, weightInitializer=InitWeights_He(1e-2), pool_op_kernel_sizes=None, conv_kernel_sizes=None, upscale_logits=False, convolutional_pooling=False, convolutional_upsampling=False, max_num_features=None, basic_block=ConvDropoutNormNonlin, seg_output_use_bias=False): """ basically more flexible than v1, architecture is the same Does this look complicated? Nah bro. Functionality > usability This does everything you need, including world peace. Questions? -> f.isensee@dkfz.de """ super(Generic_UNet, self).__init__() self.convolutional_upsampling = convolutional_upsampling self.convolutional_pooling = convolutional_pooling self.upscale_logits = upscale_logits if nonlin_kwargs is None: nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} if dropout_op_kwargs is None: dropout_op_kwargs = {'p': 0.5, 'inplace': True} if norm_op_kwargs is None: norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1} self.conv_kwargs = {'stride': 1, 'dilation': 1, 'bias': True} self.nonlin = nonlin self.nonlin_kwargs = nonlin_kwargs self.dropout_op_kwargs = dropout_op_kwargs self.norm_op_kwargs = norm_op_kwargs self.weightInitializer = weightInitializer self.conv_op = conv_op self.norm_op = norm_op self.dropout_op = dropout_op self.num_classes = num_classes self.final_nonlin = final_nonlin self._deep_supervision = deep_supervision self.do_ds = deep_supervision if conv_op == nn.Conv2d: upsample_mode = 'bilinear' pool_op = nn.MaxPool2d transpconv = nn.ConvTranspose2d if pool_op_kernel_sizes is None: pool_op_kernel_sizes = [(2, 2)] * num_pool if conv_kernel_sizes is None: conv_kernel_sizes = [(3, 3)] * (num_pool + 1) elif conv_op == nn.Conv3d: upsample_mode = 'trilinear' pool_op = nn.MaxPool3d transpconv = nn.ConvTranspose3d if pool_op_kernel_sizes is None: pool_op_kernel_sizes = [(2, 2, 2)] * num_pool if conv_kernel_sizes is None: conv_kernel_sizes = [(3, 3, 3)] * (num_pool + 1) else: 
raise ValueError("unknown convolution dimensionality, conv op: %s" % str(conv_op)) self.input_shape_must_be_divisible_by = np.prod(pool_op_kernel_sizes, 0, dtype=np.int64) self.pool_op_kernel_sizes = pool_op_kernel_sizes self.conv_kernel_sizes = conv_kernel_sizes self.conv_pad_sizes = [] for krnl in self.conv_kernel_sizes: self.conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl]) if max_num_features is None: if self.conv_op == nn.Conv3d: self.max_num_features = self.MAX_NUM_FILTERS_3D else: self.max_num_features = self.MAX_FILTERS_2D else: self.max_num_features = max_num_features self.conv_blocks_context = [] self.conv_blocks_localization = [] self.td = [] self.tu = [] self.seg_outputs = [] output_features = base_num_features input_features = input_channels for d in range(num_pool): # determine the first stride if d != 0 and self.convolutional_pooling: first_stride = pool_op_kernel_sizes[d - 1] else: first_stride = None self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[d] self.conv_kwargs['padding'] = self.conv_pad_sizes[d] # add convolutions self.conv_blocks_context.append(StackedConvLayers(input_features, output_features, num_conv_per_stage, self.conv_op, self.conv_kwargs, self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, first_stride, basic_block=basic_block)) if not self.convolutional_pooling: self.td.append(pool_op(pool_op_kernel_sizes[d])) input_features = output_features output_features = int(np.round(output_features * feat_map_mul_on_downscale)) output_features = min(output_features, self.max_num_features) # now the bottleneck. # determine the first stride if self.convolutional_pooling: first_stride = pool_op_kernel_sizes[-1] else: first_stride = None # the output of the last conv must match the number of features from the skip connection if we are not using # convolutional upsampling. 
If we use convolutional upsampling then the reduction in feature maps will be # done by the transposed conv if self.convolutional_upsampling: final_num_features = output_features else: final_num_features = self.conv_blocks_context[-1].output_channels self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[num_pool] self.conv_kwargs['padding'] = self.conv_pad_sizes[num_pool] self.conv_blocks_context.append(nn.Sequential( StackedConvLayers(input_features, output_features, num_conv_per_stage - 1, self.conv_op, self.conv_kwargs, self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, first_stride, basic_block=basic_block), StackedConvLayers(output_features, final_num_features, 1, self.conv_op, self.conv_kwargs, self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, basic_block=basic_block))) # if we don't want to do dropout in the localization pathway then we set the dropout prob to zero here if not dropout_in_localization: old_dropout_p = self.dropout_op_kwargs['p'] self.dropout_op_kwargs['p'] = 0.0 # now lets build the localization pathway for u in range(num_pool): nfeatures_from_down = final_num_features nfeatures_from_skip = self.conv_blocks_context[ -(2 + u)].output_channels # self.conv_blocks_context[-1] is bottleneck, so start with -2 n_features_after_tu_and_concat = nfeatures_from_skip * 2 # the first conv reduces the number of features to match those of skip # the following convs work on that number of features # if not convolutional upsampling then the final conv reduces the num of features again if u != num_pool - 1 and not self.convolutional_upsampling: final_num_features = self.conv_blocks_context[-(3 + u)].output_channels else: final_num_features = nfeatures_from_skip if not self.convolutional_upsampling: self.tu.append(Upsample(scale_factor=pool_op_kernel_sizes[-(u + 1)], mode=upsample_mode)) else: self.tu.append(transpconv(nfeatures_from_down, 
    def forward(self, x):
        # print(x.shape)
        # Encoder: every context stage except the bottleneck stores a skip feature.
        skips = []
        seg_outputs = []
        for d in range(len(self.conv_blocks_context) - 1):
            x = self.conv_blocks_context[d](x)
            skips.append(x)
            if not self.convolutional_pooling:
                # explicit pooling op; with convolutional pooling the stride is
                # built into the stage's first conv instead
                x = self.td[d](x)

        # bottleneck stage (no skip stored)
        x = self.conv_blocks_context[-1](x)

        # Decoder: upsample, concat the matching skip, convolve, and collect one
        # segmentation head output per resolution (for deep supervision).
        for u in range(len(self.tu)):
            x = self.tu[u](x)
            x = torch.cat((x, skips[-(u + 1)]), dim=1)
            x = self.conv_blocks_localization[u](x)
            seg_outputs.append(self.final_nonlin(self.seg_outputs[u](x)))

        if self._deep_supervision and self.do_ds:
            # highest-resolution output first, followed by the (optionally
            # upscaled) lower-resolution outputs in decreasing resolution order
            return tuple([seg_outputs[-1]] + [i(j) for i, j in
                                              zip(list(self.upscale_logits_ops)[::-1], seg_outputs[:-1][::-1])])
        else:
            return seg_outputs[-1]
class InitWeights_He(object):
    """
    Callable for ``nn.Module.apply``: Kaiming-normal initialization for conv and
    transposed-conv weights, with biases set to zero.
    """

    def __init__(self, neg_slope=1e-2):
        # negative slope of the LeakyReLU the initialization is matched to
        self.neg_slope = neg_slope

    def __call__(self, module):
        conv_types = (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)
        if isinstance(module, conv_types):
            module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
            if module.bias is not None:
                module.bias = nn.init.constant_(module.bias, 0)
import nn import math class LayerNorm(nn.Module): def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): super().__init__() self.weight = nn.Parameter(torch.ones(normalized_shape)) self.bias = nn.Parameter(torch.zeros(normalized_shape)) self.eps = eps self.data_format = data_format if self.data_format not in ["channels_last", "channels_first"]: raise NotImplementedError self.normalized_shape = (normalized_shape,) def forward(self, x): if self.data_format == "channels_last": return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) elif self.data_format == "channels_first": u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x class PositionalEncodingFourier(nn.Module): def __init__(self, hidden_dim=32, dim=768, temperature=10000): super().__init__() self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1) self.scale = 2 * math.pi self.temperature = temperature self.hidden_dim = hidden_dim self.dim = dim def forward(self, B, H, W): mask = torch.zeros(B, H, W).bool().to(self.token_projection.weight.device) not_mask = ~mask y_embed = not_mask.cumsum(1, dtype=torch.float32) x_embed = not_mask.cumsum(2, dtype=torch.float32) eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=mask.device) dim_t = self.temperature ** (2 * (dim_t // 2) / self.hidden_dim) pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) pos = self.token_projection(pos) return pos 
================================================
FILE: unetr_pp/network_architecture/lung/__init__.py
================================================


================================================
FILE: unetr_pp/network_architecture/lung/model_components.py
================================================
from torch import nn
from timm.models.layers import trunc_normal_
from typing import Sequence, Tuple, Union
from monai.networks.layers.utils import get_norm_layer
from monai.utils import optional_import
from unetr_pp.network_architecture.layers import LayerNorm
from unetr_pp.network_architecture.lung.transformerblock import TransformerBlock
from unetr_pp.network_architecture.dynunet_block import get_conv_layer, UnetResBlock

einops, _ = optional_import("einops")

#input_size=[28 * 24 * 20, 14 * 12 * 10, 7 * 6 * 5, ],dims=[32, 64, 128, 256]


class UnetrPPEncoder(nn.Module):
    """UNETR++ hierarchical encoder (Lung task): a conv stem plus three strided
    downsampling convs, each followed by a stage of EPA TransformerBlocks.
    Default input_size entries are the token counts (H*W*D) at each stage for
    the Lung patch size."""

    def __init__(self, input_size=[32*48*48, 16 * 24 * 24, 8 * 12 * 12, 4*6*6], dims=[32, 64, 128, 256],
                 proj_size=[64, 64, 64, 32], depths=[3, 3, 3, 3], num_heads=4, spatial_dims=3, in_channels=1,
                 dropout=0.0, transformer_dropout_rate=0.1, **kwargs):
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        # Stem: anisotropic (1, 4, 4) strided conv — no downsampling along the
        # first spatial axis.
        stem_layer = nn.Sequential(
            get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=(1, 4, 4), stride=(1, 4, 4),
                           dropout=dropout, conv_only=True, ),
            get_norm_layer(name=("group", {"num_groups": in_channels}), channels=dims[0]),
        )
        self.downsample_layers.append(stem_layer)
        for i in range(3):
            # Isotropic 2x downsampling between consecutive stages.
            downsample_layer = nn.Sequential(
                get_conv_layer(spatial_dims, dims[i], dims[i + 1], kernel_size=(2, 2, 2), stride=(2, 2, 2),
                               dropout=dropout, conv_only=True, ),
                get_norm_layer(name=("group", {"num_groups": dims[i]}), channels=dims[i + 1]),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple Transformer blocks
        for i in range(4):
            stage_blocks = []
            for j in range(depths[i]):
                #print(" i ", i)
                #print("inputtt size ",input_size[i])
                stage_blocks.append(TransformerBlock(input_size=input_size[i], hidden_size=dims[i],
                                                     proj_size=proj_size[i], num_heads=num_heads,
                                                     dropout_rate=transformer_dropout_rate, pos_embed=True))
            self.stages.append(nn.Sequential(*stage_blocks))
        self.hidden_states = []
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal init for conv/linear weights; LayerNorm weights get
        # ones and biases zeros.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (LayerNorm, nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        # Returns the last-stage output plus one hidden state per stage; the
        # hidden states are consumed as skip connections by the decoder.
        hidden_states = []

        x = self.downsample_layers[0](x)
        x = self.stages[0](x)

        hidden_states.append(x)

        for i in range(1, 4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i == 3:  # Reshape the output of the last stage
                x = einops.rearrange(x, "b c h w d -> b (h w d) c")
            hidden_states.append(x)
        return x, hidden_states

    def forward(self, x):
        x, hidden_states = self.forward_features(x)
        return x, hidden_states


class UnetrUpBlock(nn.Module):
    def __init__(
            self,
            spatial_dims: int,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[Sequence[int], int],
            upsample_kernel_size: Union[Sequence[int], int],
            norm_name: Union[Tuple, str],
            proj_size: int = 64,
            num_heads: int = 4,
            out_size: int = 0,
            depth: int = 3,
            conv_decoder: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
            norm_name: feature normalization type and arguments.
            proj_size: projection size for keys and values in the spatial attention module.
            num_heads: number of heads inside each EPA module.
            out_size: spatial size for each decoder.
            depth: number of blocks for the current decoder stage.
        """
        super().__init__()
        upsample_stride = upsample_kernel_size
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )

        # 4 feature resolution stages, each consisting of multiple residual blocks
        self.decoder_block = nn.ModuleList()

        # If this is the last decoder, use ConvBlock(UnetResBlock) instead of EPA_Block
        # (see suppl. material in the paper)
        if conv_decoder == True:
            self.decoder_block.append(
                UnetResBlock(spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1,
                             norm_name=norm_name, ))
        else:
            stage_blocks = []
            for j in range(depth):
                stage_blocks.append(TransformerBlock(input_size=out_size, hidden_size=out_channels,
                                                     proj_size=proj_size, num_heads=num_heads,
                                                     dropout_rate=0.1, pos_embed=True))
            self.decoder_block.append(nn.Sequential(*stage_blocks))

    def _init_weights(self, m):
        # NOTE(review): this helper is defined but never passed to self.apply()
        # in this class, so submodules keep their own default initialization.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, inp, skip):
        # Upsample, fuse with the encoder skip by addition, then refine with
        # the (conv or transformer) decoder block.
        out = self.transp_conv(inp)
        out = out + skip
        out = self.decoder_block[0](out)
        return out


================================================
FILE: unetr_pp/network_architecture/lung/transformerblock.py
================================================
import torch.nn as nn
import torch
from unetr_pp.network_architecture.dynunet_block import UnetResBlock


class TransformerBlock(nn.Module):
    """
    A transformer block, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            proj_size: int,
            num_heads: int,
            dropout_rate: float = 0.0,
            pos_embed=False,
    ) -> None:
        """
        Args:
            input_size: the size of the input for each stage.
            hidden_size: dimension of hidden layer.
            proj_size: projection size for keys and values in the spatial attention module.
            num_heads: number of attention heads.
            dropout_rate: faction of the input units to drop.
            pos_embed: bool argument to determine if positional embedding is used.
        """
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            print("Hidden size is ", hidden_size)
            print("Num heads is ", num_heads)
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.norm = nn.LayerNorm(hidden_size)
        # Layer-scale: learnable per-channel scaling of the attention branch,
        # initialized small (1e-6).
        self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True)
        self.epa_block = EPA(input_size=input_size, hidden_size=hidden_size, proj_size=proj_size,
                             num_heads=num_heads, channel_attn_drop=dropout_rate, spatial_attn_drop=dropout_rate)
        self.conv51 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch")
        self.conv52 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch")
        self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1))

        self.pos_embed = None
        if pos_embed:
            #print("input size", input_size)
            # Learnable positional embedding over the flattened token sequence.
            self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size))

    def forward(self, x):
        B, C, H, W, D = x.shape
        #print("XSHAPE ",x.shape)
        # Flatten spatial dims into a token sequence: (B, N, C) with N = H*W*D.
        x = x.reshape(B, C, H * W * D).permute(0, 2, 1)

        if self.pos_embed is not None:
            #print ("x",x.shape)
            #print("pos_embed",self.pos_embed.shape)
            x = x + self.pos_embed
        # Residual EPA attention, scaled by gamma.
        attn = x + self.gamma * self.epa_block(self.norm(x))

        attn_skip = attn.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3)  # (B, C, H, W, D)
        attn = self.conv51(attn_skip)
        attn = self.conv52(attn)
        x = attn_skip + self.conv8(attn)

        return x


class EPA(nn.Module):
    """
    Efficient Paired Attention Block, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
    """

    def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False,
                 channel_attn_drop=0.1, spatial_attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        # Learnable softmax temperatures, one per attention branch.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))

        # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel)
        self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias)

        # E and F are projection matrices used in spatial attention module to project keys and values from HWD-dimension to P-dimension
        self.E = nn.Linear(input_size, proj_size)
        self.F = nn.Linear(input_size, proj_size)

        self.attn_drop = nn.Dropout(channel_attn_drop)
        self.attn_drop_2 = nn.Dropout(spatial_attn_drop)

        # Each branch contributes half of the output channels (concat fusion).
        self.out_proj = nn.Linear(hidden_size, int(hidden_size // 2))
        self.out_proj2 = nn.Linear(hidden_size, int(hidden_size // 2))

    def forward(self, x):
        B, N, C = x.shape
        #print("The shape in EPA ", self.E.shape)
        qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads)
        qkvv = qkvv.permute(2, 0, 3, 1, 4)

        q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3]

        # Move to (B, heads, head_dim, N) so E/F project along the token axis.
        q_shared = q_shared.transpose(-2, -1)
        k_shared = k_shared.transpose(-2, -1)
        v_CA = v_CA.transpose(-2, -1)
        v_SA = v_SA.transpose(-2, -1)

        k_shared_projected = self.E(k_shared)
        v_SA_projected = self.F(v_SA)

        q_shared = torch.nn.functional.normalize(q_shared, dim=-1)
        k_shared = torch.nn.functional.normalize(k_shared, dim=-1)

        # Channel attention: head_dim x head_dim similarity of normalized q/k.
        attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature
        attn_CA = attn_CA.softmax(dim=-1)
        attn_CA = self.attn_drop(attn_CA)
        x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C)

        # Spatial attention: N x P scores against projected keys — linear in N.
        attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2
        attn_SA = attn_SA.softmax(dim=-1)
        attn_SA = self.attn_drop_2(attn_SA)
        x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C)

        # Concat fusion
        x_SA = self.out_proj(x_SA)
        x_CA = self.out_proj2(x_CA)
        x = torch.cat((x_SA, x_CA), dim=-1)
        return x

    @torch.jit.ignore
    def no_weight_decay(self):
        # Exclude the learnable temperatures from weight decay.
        return {'temperature', 'temperature2'}
================================================
FILE: unetr_pp/network_architecture/lung/unetr_pp_lung.py
================================================
from torch import nn
from typing import Tuple, Union
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.network_architecture.dynunet_block import UnetOutBlock, UnetResBlock
from unetr_pp.network_architecture.lung.model_components import UnetrPPEncoder, UnetrUpBlock


class UNETR_PP(SegmentationNetwork):
    """
    UNETR++ based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            feature_size: int = 16,
            hidden_size: int = 256,
            num_heads: int = 4,
            pos_embed: str = "perceptron",
            norm_name: Union[Tuple, str] = "instance",
            dropout_rate: float = 0.0,
            depths=None,
            dims=None,
            conv_op=nn.Conv3d,
            do_ds=True,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            img_size: dimension of input image.
            feature_size: dimension of network feature size.
            hidden_size: dimensions of the last encoder.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type.
            norm_name: feature normalization type and arguments.
            dropout_rate: faction of the input units to drop.
            depths: number of blocks for each stage.
            dims: number of channel maps for the stages.
            conv_op: type of convolution operation.
            do_ds: use deep supervision to compute the loss.
        """
        super().__init__()
        if depths is None:
            depths = [3, 3, 3, 3]
        self.do_ds = do_ds
        self.conv_op = conv_op
        self.num_classes = out_channels
        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")

        if pos_embed not in ["conv", "perceptron"]:
            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

        # Spatial size of the bottleneck feature map; matches the Lung-task
        # encoder's final stage (see out_size values of the decoders below).
        self.feat_size = (4, 6, 6,)
        self.hidden_size = hidden_size

        self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads)

        # Full-resolution conv branch on the raw input, fused by the last decoder.
        self.encoder1 = UnetResBlock(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 16,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=8*12*12,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=16*24*24,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=32*48*48,
        )
        # Last decoder: anisotropic (1, 4, 4) upsampling mirrors the encoder
        # stem; uses a plain conv block instead of EPA (conv_decoder=True).
        self.decoder2 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=(1, 4, 4),
            norm_name=norm_name,
            out_size=32*192*192,
            conv_decoder=True,
        )
        self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels)
        if self.do_ds:
            # Auxiliary heads for deep supervision at two lower resolutions.
            self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels)
            self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels)

    def proj_feat(self, x, hidden_size, feat_size):
        # (B, N, hidden) token sequence -> (B, hidden, *feat_size) volume.
        x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

    def forward(self, x_in):
        #print("#####input_shape:", x_in.shape)
        x_output, hidden_states = self.unetr_pp_encoder(x_in)

        convBlock = self.encoder1(x_in)

        # Four encoders
        enc1 = hidden_states[0]
        #print("ENC1:",enc1.shape)
        enc2 = hidden_states[1]
        #print("ENC2:",enc2.shape)
        enc3 = hidden_states[2]
        #print("ENC3:",enc3.shape)
        enc4 = hidden_states[3]
        #print("ENC4:",enc4.shape)

        # Four decoders
        dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size)
        dec3 = self.decoder5(dec4, enc3)
        dec2 = self.decoder4(dec3, enc2)
        dec1 = self.decoder3(dec2, enc1)

        out = self.decoder2(dec1, convBlock)
        if self.do_ds:
            # Deep supervision: highest resolution first.
            logits = [self.out1(out), self.out2(dec1), self.out3(dec2)]
        else:
            logits = self.out1(out)

        return logits


================================================
FILE: unetr_pp/network_architecture/neural_network.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from batchgenerators.augmentations.utils import pad_nd_image
from unetr_pp.utilities.random_stuff import no_op
from unetr_pp.utilities.to_torch import to_cuda, maybe_to_torch
from torch import nn
import torch
from scipy.ndimage.filters import gaussian_filter
from typing import Union, Tuple, List
from torch.cuda.amp import autocast


class NeuralNetwork(nn.Module):
    # Thin nn.Module wrapper adding a device query/set convenience API.
    def __init__(self):
        super(NeuralNetwork, self).__init__()

    def get_device(self):
        # Returns "cpu" or the CUDA device index of the first parameter.
        if next(self.parameters()).device == "cpu":
            return "cpu"
        else:
            return next(self.parameters()).device.index

    def set_device(self, device):
        if device == "cpu":
            self.cpu()
        else:
            self.cuda(device)

    def forward(self, x):
        raise NotImplementedError


class SegmentationNetwork(NeuralNetwork):
    def __init__(self):
        # NOTE(review): calls super(NeuralNetwork, self).__init__(), i.e. it
        # skips NeuralNetwork.__init__ and goes straight to nn.Module —
        # harmless here since NeuralNetwork.__init__ adds nothing, but unusual.
        super(NeuralNetwork, self).__init__()

        # if we have 5 pooling then our patch size must be divisible by 2**5
        self.input_shape_must_be_divisible_by = None  # for example in a 2d network that does 5 pool in x and 6 pool
        # in y this would be (32, 64)

        # we need to know this because we need to know if we are a 2d or a 3d netowrk
        self.conv_op = None  # nn.Conv2d or nn.Conv3d

        # this tells us how many channely we have in the output. Important for preallocation in inference
        self.num_classes = None  # number of channels in the output

        # depending on the loss, we do not hard code a nonlinearity into the architecture. To aggregate predictions
        # during inference, we need to apply the nonlinearity, however. So it is important to let the newtork know what
        # to apply in inference. For the most part this will be softmax
        self.inference_apply_nonlin = lambda x: x  # softmax_helper

        # This is for saving a gaussian importance map for inference. It weights voxels higher that are closer to the
        # center. Prediction at the borders are often less accurate and are thus downweighted. Creating these Gaussians
        # can be expensive, so it makes sense to save and reuse them.
        self._gaussian_3d = self._patch_size_for_gaussian_3d = None
        self._gaussian_2d = self._patch_size_for_gaussian_2d = None

    def predict_3D(self, x: np.ndarray, do_mirroring: bool, mirror_axes: Tuple[int, ...] = (0, 1, 2),
                   use_sliding_window: bool = False,
                   step_size: float = 0.5, patch_size: Tuple[int, ...] = None, regions_class_order: Tuple[int, ...] = None,
                   use_gaussian: bool = False, pad_border_mode: str = "constant",
                   pad_kwargs: dict = None, all_in_gpu: bool = False,
                   verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """
        Use this function to predict a 3D image. It does not matter whether the network is a 2D or 3D U-Net, it will
        detect that automatically and run the appropriate code.

        When running predictions, you need to specify whether you want to run fully convolutional of sliding window
        based inference. We very strongly recommend you use sliding window with the default settings.

        It is the responsibility of the user to make sure the network is in the proper mode (eval for inference!). If
        the network is not in eval mode it will print a warning.

        :param x: Your input data. Must be a nd.ndarray of shape (c, x, y, z).
        :param do_mirroring: If True, use test time data augmentation in the form of mirroring
        :param mirror_axes: Determines which axes to use for mirroing. Per default, mirroring is done along all three
        axes
        :param use_sliding_window: if True, run sliding window prediction. Heavily recommended! This is also the default
        :param step_size: When running sliding window prediction, the step size determines the distance between adjacent
        predictions. The smaller the step size, the denser the predictions (and the longer it takes!). Step size is given
        as a fraction of the patch_size. 0.5 is the default and means that wen advance by patch_size * 0.5 between
        predictions. step_size cannot be larger than 1!
        :param patch_size: The patch size that was used for training the network. Do not use different patch sizes here,
        this will either crash or give potentially less accurate segmentations
        :param regions_class_order: Fabian only
        :param use_gaussian: (Only applies to sliding window prediction) If True, uses a Gaussian importance weighting
         to weigh predictions closer to the center of the current patch higher than those at the borders. The reason
         behind this is that the segmentation accuracy decreases towards the borders. Default (and recommended): True
        :param pad_border_mode: leave this alone
        :param pad_kwargs: leave this alone
        :param all_in_gpu: experimental. You probably want to leave this as is it
        :param verbose: Do you want a wall of text? If yes then set this to True
        :param mixed_precision: if True, will run inference in mixed precision with autocast()
        :return:
        """
        torch.cuda.empty_cache()

        assert step_size <= 1, 'step_size must be smaller than 1. Otherwise there will be a gap between consecutive ' \
                               'predictions'

        if verbose: print("debug: mirroring", do_mirroring, "mirror_axes", mirror_axes)

        assert self.get_device() != "cpu", "CPU not implemented"

        if pad_kwargs is None:
            pad_kwargs = {'constant_values': 0}

        # A very long time ago the mirror axes were (2, 3, 4) for a 3d network. This is just to intercept any old
        # code that uses this convention
        # NOTE(review): conv_op is force-set to Conv3d here, so the Conv2d
        # branch below can never be taken from this method.
        self.conv_op = nn.Conv3d
        if len(mirror_axes):
            if self.conv_op == nn.Conv2d:
                if max(mirror_axes) > 1:
                    raise ValueError("mirror axes. duh")
            if self.conv_op == nn.Conv3d:
                if max(mirror_axes) > 2:
                    raise ValueError("mirror axes. duh")

        if self.training:
            print('WARNING! Network is in train mode during inference. This may be intended, or not...')

        assert len(x.shape) == 4, "data must have shape (c,x,y,z)"

        if mixed_precision:
            context = autocast
        else:
            context = no_op

        with context():
            with torch.no_grad():
                if self.conv_op == nn.Conv3d:
                    if use_sliding_window:
                        res = self._internal_predict_3D_3Dconv_tiled(x, step_size, do_mirroring, mirror_axes, patch_size,
                                                                     regions_class_order, use_gaussian, pad_border_mode,
                                                                     pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu,
                                                                     verbose=verbose)
                    else:
                        res = self._internal_predict_3D_3Dconv(x, patch_size, do_mirroring, mirror_axes,
                                                               regions_class_order, pad_border_mode,
                                                               pad_kwargs=pad_kwargs, verbose=verbose)
                elif self.conv_op == nn.Conv2d:
                    if use_sliding_window:
                        res = self._internal_predict_3D_2Dconv_tiled(x, patch_size, do_mirroring, mirror_axes, step_size,
                                                                     regions_class_order, use_gaussian, pad_border_mode,
                                                                     pad_kwargs, all_in_gpu, False)
                    else:
                        res = self._internal_predict_3D_2Dconv(x, patch_size, do_mirroring, mirror_axes,
                                                               regions_class_order, pad_border_mode, pad_kwargs,
                                                               all_in_gpu, False)
                else:
                    raise RuntimeError("Invalid conv op, cannot determine what dimensionality (2d/3d) the network is")

        return res

    def predict_2D(self, x, do_mirroring: bool, mirror_axes: tuple = (0, 1, 2),
                   use_sliding_window: bool = False, step_size: float = 0.5, patch_size: tuple = None,
                   regions_class_order: tuple = None, use_gaussian: bool = False,
                   pad_border_mode: str = "constant", pad_kwargs: dict = None, all_in_gpu: bool = False,
                   verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """
        Use this function to predict a 2D image. If this is a 3D U-Net it will crash because you cannot predict a 2D
        image with that (you dummy).

        When running predictions, you need to specify whether you want to run fully convolutional of sliding window
        based inference. We very strongly recommend you use sliding window with the default settings.

        It is the responsibility of the user to make sure the network is in the proper mode (eval for inference!). If
        the network is not in eval mode it will print a warning.

        :param x: Your input data. Must be a nd.ndarray of shape (c, x, y).
        :param do_mirroring: If True, use test time data augmentation in the form of mirroring
        :param mirror_axes: Determines which axes to use for mirroing. Per default, mirroring is done along all three
        axes
        :param use_sliding_window: if True, run sliding window prediction. Heavily recommended! This is also the default
        :param step_size: When running sliding window prediction, the step size determines the distance between adjacent
        predictions. The smaller the step size, the denser the predictions (and the longer it takes!). Step size is given
        as a fraction of the patch_size. 0.5 is the default and means that wen advance by patch_size * 0.5 between
        predictions. step_size cannot be larger than 1!
        :param patch_size: The patch size that was used for training the network. Do not use different patch sizes here,
        this will either crash or give potentially less accurate segmentations
        :param regions_class_order: Fabian only
        :param use_gaussian: (Only applies to sliding window prediction) If True, uses a Gaussian importance weighting
         to weigh predictions closer to the center of the current patch higher than those at the borders. The reason
         behind this is that the segmentation accuracy decreases towards the borders. Default (and recommended): True
        :param pad_border_mode: leave this alone
        :param pad_kwargs: leave this alone
        :param all_in_gpu: experimental. You probably want to leave this as is it
        :param verbose: Do you want a wall of text? If yes then set this to True
        :return:
        """
        torch.cuda.empty_cache()

        assert step_size <= 1, 'step_size must be smaler than 1. Otherwise there will be a gap between consecutive ' \
                               'predictions'

        if self.conv_op == nn.Conv3d:
            raise RuntimeError("Cannot predict 2d if the network is 3d. Dummy.")

        if verbose: print("debug: mirroring", do_mirroring, "mirror_axes", mirror_axes)

        assert self.get_device() != "cpu", "CPU not implemented"

        if pad_kwargs is None:
            pad_kwargs = {'constant_values': 0}

        # A very long time ago the mirror axes were (2, 3) for a 2d network. This is just to intercept any old
        # code that uses this convention
        if len(mirror_axes):
            if max(mirror_axes) > 1:
                raise ValueError("mirror axes. duh")

        if self.training:
            print('WARNING! Network is in train mode during inference. This may be intended, or not...')

        assert len(x.shape) == 3, "data must have shape (c,x,y)"

        if mixed_precision:
            context = autocast
        else:
            context = no_op

        with context():
            with torch.no_grad():
                if self.conv_op == nn.Conv2d:
                    if use_sliding_window:
                        res = self._internal_predict_2D_2Dconv_tiled(x, step_size, do_mirroring, mirror_axes, patch_size,
                                                                     regions_class_order, use_gaussian, pad_border_mode,
                                                                     pad_kwargs, all_in_gpu, verbose)
                    else:
                        res = self._internal_predict_2D_2Dconv(x, patch_size, do_mirroring, mirror_axes,
                                                               regions_class_order, pad_border_mode, pad_kwargs, verbose)
                else:
                    raise RuntimeError("Invalid conv op, cannot determine what dimensionality (2d/3d) the network is")

        return res

    @staticmethod
    def _get_gaussian(patch_size, sigma_scale=1. / 8) -> np.ndarray:
        # Builds a normalized Gaussian importance map centered in the patch:
        # a single impulse at the center is smoothed with sigma proportional to
        # the patch size, then rescaled so the maximum is 1.
        tmp = np.zeros(patch_size)
        center_coords = [i // 2 for i in patch_size]
        sigmas = [i * sigma_scale for i in patch_size]
        tmp[tuple(center_coords)] = 1
        gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
        gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map) * 1
        gaussian_importance_map = gaussian_importance_map.astype(np.float32)

        # gaussian_importance_map cannot be 0, otherwise we may end up with nans!
gaussian_importance_map[gaussian_importance_map == 0] = np.min( gaussian_importance_map[gaussian_importance_map != 0]) return gaussian_importance_map @staticmethod def _compute_steps_for_sliding_window(patch_size: Tuple[int, ...], image_size: Tuple[int, ...], step_size: float) -> List[List[int]]: assert [i >= j for i, j in zip(image_size, patch_size)], "image size must be as large or larger than patch_size" assert 0 < step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1' # our step width is patch_size*step_size at most, but can be narrower. For example if we have image size of # 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps starting at coordinate 0, 23, 46 target_step_sizes_in_voxels = [i * step_size for i in patch_size] num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, patch_size)] steps = [] for dim in range(len(patch_size)): # the highest step value for this dimension is max_step_value = image_size[dim] - patch_size[dim] if num_steps[dim] > 1: actual_step_size = max_step_value / (num_steps[dim] - 1) else: actual_step_size = 99999999999 # does not matter because there is only one step at 0 steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])] steps.append(steps_here) return steps def _internal_predict_3D_3Dconv_tiled(self, x: np.ndarray, step_size: float, do_mirroring: bool, mirror_axes: tuple, patch_size: tuple, regions_class_order: tuple, use_gaussian: bool, pad_border_mode: str, pad_kwargs: dict, all_in_gpu: bool, verbose: bool) -> Tuple[np.ndarray, np.ndarray]: # better safe than sorry assert len(x.shape) == 4, "x must be (c, x, y, z)" assert self.get_device() != "cpu" if verbose: print("step_size:", step_size) if verbose: print("do mirror:", do_mirroring) assert patch_size is not None, "patch_size cannot be None for tiled prediction" # for sliding window inference the image must at least be as large as the patch size. 
It does not matter # whether the shape is divisible by 2**num_pool as long as the patch size is data, slicer = pad_nd_image(x, patch_size, pad_border_mode, pad_kwargs, True, None) data_shape = data.shape # still c, x, y, z # compute the steps for sliding window steps = self._compute_steps_for_sliding_window(patch_size, data_shape[1:], step_size) num_tiles = len(steps[0]) * len(steps[1]) * len(steps[2]) if verbose: print("data shape:", data_shape) print("patch size:", patch_size) print("steps (x, y, and z):", steps) print("number of tiles:", num_tiles) # we only need to compute that once. It can take a while to compute this due to the large sigma in # gaussian_filter if use_gaussian and num_tiles > 1: if self._gaussian_3d is None or not all( [i == j for i, j in zip(patch_size, self._patch_size_for_gaussian_3d)]): if verbose: print('computing Gaussian') gaussian_importance_map = self._get_gaussian(patch_size, sigma_scale=1. / 8) self._gaussian_3d = gaussian_importance_map self._patch_size_for_gaussian_3d = patch_size else: if verbose: print("using precomputed Gaussian") gaussian_importance_map = self._gaussian_3d gaussian_importance_map = torch.from_numpy(gaussian_importance_map).cuda(self.get_device(), non_blocking=True) else: gaussian_importance_map = None if all_in_gpu: # If we run the inference in GPU only (meaning all tensors are allocated on the GPU, this reduces # CPU-GPU communication but required more GPU memory) we need to preallocate a few things on GPU if use_gaussian and num_tiles > 1: # half precision for the outputs should be good enough. 
If the outputs here are half, the # gaussian_importance_map should be as well gaussian_importance_map = gaussian_importance_map.half() # make sure we did not round anything to 0 gaussian_importance_map[gaussian_importance_map == 0] = gaussian_importance_map[ gaussian_importance_map != 0].min() add_for_nb_of_preds = gaussian_importance_map else: add_for_nb_of_preds = torch.ones(data.shape[1:], device=self.get_device()) if verbose: print("initializing result array (on GPU)") aggregated_results = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half, device=self.get_device()) if verbose: print("moving data to GPU") data = torch.from_numpy(data).cuda(self.get_device(), non_blocking=True) if verbose: print("initializing result_numsamples (on GPU)") aggregated_nb_of_predictions = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half, device=self.get_device()) else: if use_gaussian and num_tiles > 1: add_for_nb_of_preds = self._gaussian_3d else: add_for_nb_of_preds = np.ones(data.shape[1:], dtype=np.float32) aggregated_results = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32) aggregated_nb_of_predictions = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32) for x in steps[0]: lb_x = x ub_x = x + patch_size[0] for y in steps[1]: lb_y = y ub_y = y + patch_size[1] for z in steps[2]: lb_z = z ub_z = z + patch_size[2] predicted_patch = self._internal_maybe_mirror_and_pred_3D( data[None, :, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z], mirror_axes, do_mirroring, gaussian_importance_map)[0] if all_in_gpu: predicted_patch = predicted_patch.half() else: predicted_patch = predicted_patch.cpu().numpy() aggregated_results[:, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z] += predicted_patch aggregated_nb_of_predictions[:, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z] += add_for_nb_of_preds # we reverse the padding here (remeber that we padded the input to be at least as large as the patch size slicer = tuple( [slice(0, aggregated_results.shape[i]) 
    def _internal_predict_2D_2Dconv(self, x: np.ndarray, min_size: Tuple[int, int], do_mirroring: bool,
                                    mirror_axes: tuple = (0, 1, 2), regions_class_order: tuple = None,
                                    pad_border_mode: str = "constant", pad_kwargs: dict = None,
                                    verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """
        This one does fully convolutional inference. No sliding window

        The image is padded to at least min_size and to a multiple of
        self.input_shape_must_be_divisible_by, pushed through the network once
        (optionally with test-time mirroring), and the padding is cropped away
        from the prediction again.

        :param x: input image, shape (c, x, y)
        :param min_size: minimum spatial size the padded input must have
        :param do_mirroring: if True, average predictions over mirrored inputs
        :param mirror_axes: spatial axes used for mirroring
        :param regions_class_order: if not None, region-based output is converted
            to a segmentation by thresholding each channel at 0.5
        :param pad_border_mode: np.pad border mode
        :param pad_kwargs: extra kwargs forwarded to np.pad
        :param verbose: print progress information
        :return: (predicted_segmentation, predicted_probabilities) as numpy arrays
        """
        assert len(x.shape) == 3, "x must be (c, x, y)"
        assert self.get_device() != "cpu"
        assert self.input_shape_must_be_divisible_by is not None, 'input_shape_must_be_divisible_by must be set to ' \
                                                                  'run _internal_predict_2D_2Dconv'
        if verbose: print("do mirror:", do_mirroring)

        # pad and remember how to undo it (slicer)
        data, slicer = pad_nd_image(x, min_size, pad_border_mode, pad_kwargs, True,
                                    self.input_shape_must_be_divisible_by)

        # add batch axis, predict, take the only batch element
        predicted_probabilities = self._internal_maybe_mirror_and_pred_2D(data[None], mirror_axes, do_mirroring,
                                                                          None)[0]

        # the prediction has a class axis where the input had color channels, so the
        # crop slicer is rebuilt: keep all leading axes in full, then apply the spatial crop
        slicer = tuple(
            [slice(0, predicted_probabilities.shape[i]) for i in
             range(len(predicted_probabilities.shape) - (len(slicer) - 1))] + slicer[1:])
        predicted_probabilities = predicted_probabilities[slicer]

        if regions_class_order is None:
            # mutually exclusive classes: argmax over the class axis
            predicted_segmentation = predicted_probabilities.argmax(0)
            predicted_segmentation = predicted_segmentation.detach().cpu().numpy()
            predicted_probabilities = predicted_probabilities.detach().cpu().numpy()
        else:
            # region-based training: each channel thresholded independently at 0.5
            predicted_probabilities = predicted_probabilities.detach().cpu().numpy()
            predicted_segmentation = np.zeros(predicted_probabilities.shape[1:], dtype=np.float32)
            for i, c in enumerate(regions_class_order):
                predicted_segmentation[predicted_probabilities[i] > 0.5] = c

        return predicted_segmentation, predicted_probabilities
    def _internal_maybe_mirror_and_pred_3D(self, x: Union[np.ndarray, torch.tensor], mirror_axes: tuple,
                                           do_mirroring: bool = True,
                                           mult: np.ndarray or torch.tensor = None) -> torch.tensor:
        """
        Run the network on one 3D patch, optionally averaging the prediction over
        all mirrored versions of the input (test-time augmentation).

        :param x: input patch, shape (b, c, x, y, z)
        :param mirror_axes: subset of (0, 1, 2); spatial axes to mirror over
        :param do_mirroring: if False, a single forward pass is done
        :param mult: optional weight map (e.g. Gaussian importance map) multiplied
            onto the result
        :return: cuda tensor with the (weighted) averaged softmax prediction
        """
        assert len(x.shape) == 5, 'x must be (b, c, x, y, z)'
        # everything in here takes place on the GPU. If x and mult are not yet on GPU this will be taken care of here
        # we now return a cuda tensor! Not numpy array!
        x = to_cuda(maybe_to_torch(x), gpu_id=self.get_device())
        # NOTE(review): batch size is hard-coded to 1 here while the 2D variant uses
        # x.shape[0]; callers pass a single patch — confirm before reusing with larger batches
        result_torch = torch.zeros([1, self.num_classes] + list(x.shape[2:]),
                                   dtype=torch.float).cuda(self.get_device(), non_blocking=True)

        if mult is not None:
            mult = to_cuda(maybe_to_torch(mult), gpu_id=self.get_device())

        if do_mirroring:
            # 2**3 = 8 flip combinations exist; branches for axes not in mirror_axes
            # are skipped below, so exactly 2**len(mirror_axes) passes are executed
            # and the 1/num_results weights sum to 1
            mirror_idx = 8
            num_results = 2 ** len(mirror_axes)
        else:
            mirror_idx = 1
            num_results = 1

        for m in range(mirror_idx):
            if m == 0:
                pred = self.inference_apply_nonlin(self(x))
                result_torch += 1 / num_results * pred

            if m == 1 and (2 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (4, ))))
                result_torch += 1 / num_results * torch.flip(pred, (4,))

            if m == 2 and (1 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (3, ))))
                result_torch += 1 / num_results * torch.flip(pred, (3,))

            if m == 3 and (2 in mirror_axes) and (1 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (4, 3))))
                result_torch += 1 / num_results * torch.flip(pred, (4, 3))

            if m == 4 and (0 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (2, ))))
                result_torch += 1 / num_results * torch.flip(pred, (2,))

            if m == 5 and (0 in mirror_axes) and (2 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (4, 2))))
                result_torch += 1 / num_results * torch.flip(pred, (4, 2))

            if m == 6 and (0 in mirror_axes) and (1 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (3, 2))))
                result_torch += 1 / num_results * torch.flip(pred, (3, 2))

            if m == 7 and (0 in mirror_axes) and (1 in mirror_axes) and (2 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (4, 3, 2))))
                result_torch += 1 / num_results * torch.flip(pred, (4, 3, 2))

        if mult is not None:
            # weight the averaged prediction (e.g. by the Gaussian importance map)
            result_torch[:, :] *= mult

        return result_torch
    def _internal_predict_2D_2Dconv_tiled(self, x: np.ndarray, step_size: float, do_mirroring: bool,
                                          mirror_axes: tuple, patch_size: tuple, regions_class_order: tuple,
                                          use_gaussian: bool, pad_border_mode: str, pad_kwargs: dict,
                                          all_in_gpu: bool, verbose: bool) -> Tuple[np.ndarray, np.ndarray]:
        """
        Tiled (sliding window) 2D inference on a single image.

        Patches of size patch_size are predicted on a regular grid (spacing
        controlled by step_size), optionally weighted with a Gaussian importance
        map, accumulated and finally normalized by the per-voxel number of
        (weighted) predictions.

        :param x: input image, shape (c, x, y)
        :param step_size: step between tiles as a fraction of patch_size (0 < s <= 1)
        :param all_in_gpu: keep all accumulators on the GPU (half precision)
        :return: (predicted_segmentation, class_probabilities)
        """
        # better safe than sorry
        assert len(x.shape) == 3, "x must be (c, x, y)"
        assert self.get_device() != "cpu"
        if verbose: print("step_size:", step_size)
        if verbose: print("do mirror:", do_mirroring)

        assert patch_size is not None, "patch_size cannot be None for tiled prediction"

        # for sliding window inference the image must at least be as large as the patch size. It does not matter
        # whether the shape is divisible by 2**num_pool as long as the patch size is
        data, slicer = pad_nd_image(x, patch_size, pad_border_mode, pad_kwargs, True, None)
        data_shape = data.shape  # still c, x, y

        # compute the steps for sliding window
        steps = self._compute_steps_for_sliding_window(patch_size, data_shape[1:], step_size)
        num_tiles = len(steps[0]) * len(steps[1])

        if verbose:
            print("data shape:", data_shape)
            print("patch size:", patch_size)
            print("steps (x, y, and z):", steps)
            print("number of tiles:", num_tiles)

        # we only need to compute that once. It can take a while to compute this due to the large sigma in
        # gaussian_filter
        if use_gaussian and num_tiles > 1:
            # reuse the cached Gaussian if the patch size has not changed
            if self._gaussian_2d is None or not all(
                    [i == j for i, j in zip(patch_size, self._patch_size_for_gaussian_2d)]):
                if verbose: print('computing Gaussian')
                gaussian_importance_map = self._get_gaussian(patch_size, sigma_scale=1. / 8)

                self._gaussian_2d = gaussian_importance_map
                self._patch_size_for_gaussian_2d = patch_size
            else:
                if verbose: print("using precomputed Gaussian")
                gaussian_importance_map = self._gaussian_2d

            gaussian_importance_map = torch.from_numpy(gaussian_importance_map).cuda(self.get_device(),
                                                                                    non_blocking=True)
        else:
            gaussian_importance_map = None

        if all_in_gpu:
            # If we run the inference in GPU only (meaning all tensors are allocated on the GPU, this reduces
            # CPU-GPU communication but required more GPU memory) we need to preallocate a few things on GPU

            if use_gaussian and num_tiles > 1:
                # half precision for the outputs should be good enough. If the outputs here are half, the
                # gaussian_importance_map should be as well
                gaussian_importance_map = gaussian_importance_map.half()

                # make sure we did not round anything to 0
                gaussian_importance_map[gaussian_importance_map == 0] = gaussian_importance_map[
                    gaussian_importance_map != 0].min()

                add_for_nb_of_preds = gaussian_importance_map
            else:
                add_for_nb_of_preds = torch.ones(data.shape[1:], device=self.get_device())

            if verbose: print("initializing result array (on GPU)")
            aggregated_results = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
                                             device=self.get_device())

            if verbose: print("moving data to GPU")
            data = torch.from_numpy(data).cuda(self.get_device(), non_blocking=True)

            if verbose: print("initializing result_numsamples (on GPU)")
            aggregated_nb_of_predictions = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
                                                       device=self.get_device())
        else:
            if use_gaussian and num_tiles > 1:
                add_for_nb_of_preds = self._gaussian_2d
            else:
                add_for_nb_of_preds = np.ones(data.shape[1:], dtype=np.float32)
            aggregated_results = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
            aggregated_nb_of_predictions = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)

        # slide the window over the (padded) image and accumulate weighted predictions
        for x in steps[0]:
            lb_x = x
            ub_x = x + patch_size[0]
            for y in steps[1]:
                lb_y = y
                ub_y = y + patch_size[1]

                predicted_patch = self._internal_maybe_mirror_and_pred_2D(
                    data[None, :, lb_x:ub_x, lb_y:ub_y], mirror_axes, do_mirroring,
                    gaussian_importance_map)[0]

                if all_in_gpu:
                    predicted_patch = predicted_patch.half()
                else:
                    predicted_patch = predicted_patch.cpu().numpy()

                aggregated_results[:, lb_x:ub_x, lb_y:ub_y] += predicted_patch
                aggregated_nb_of_predictions[:, lb_x:ub_x, lb_y:ub_y] += add_for_nb_of_preds

        # we reverse the padding here (remeber that we padded the input to be at least as large as the patch size
        slicer = tuple(
            [slice(0, aggregated_results.shape[i]) for i in
             range(len(aggregated_results.shape) - (len(slicer) - 1))] + slicer[1:])
        aggregated_results = aggregated_results[slicer]
        aggregated_nb_of_predictions = aggregated_nb_of_predictions[slicer]

        # computing the class_probabilities by dividing the aggregated result with result_numsamples
        class_probabilities = aggregated_results / aggregated_nb_of_predictions

        if regions_class_order is None:
            predicted_segmentation = class_probabilities.argmax(0)
        else:
            # region-based: threshold each channel independently
            if all_in_gpu:
                class_probabilities_here = class_probabilities.detach().cpu().numpy()
            else:
                class_probabilities_here = class_probabilities
            predicted_segmentation = np.zeros(class_probabilities_here.shape[1:], dtype=np.float32)
            for i, c in enumerate(regions_class_order):
                predicted_segmentation[class_probabilities_here[i] > 0.5] = c

        if all_in_gpu:
            if verbose: print("copying results to CPU")

            if regions_class_order is None:
                predicted_segmentation = predicted_segmentation.detach().cpu().numpy()

            class_probabilities = class_probabilities.detach().cpu().numpy()

        if verbose: print("prediction done")
        return predicted_segmentation, class_probabilities
    def _internal_predict_3D_2Dconv_tiled(self, x: np.ndarray, patch_size: Tuple[int, int], do_mirroring: bool,
                                          mirror_axes: tuple = (0, 1), step_size: float = 0.5,
                                          regions_class_order: tuple = None, use_gaussian: bool = False,
                                          pad_border_mode: str = "edge", pad_kwargs: dict = None,
                                          all_in_gpu: bool = False,
                                          verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """
        Predict a 3D volume slice by slice using tiled 2D inference.

        Each slice along the first spatial axis is predicted with
        _internal_predict_2D_2Dconv_tiled; results are stacked back into a volume.

        :param x: input volume, shape (c, x, y, z)
        :param all_in_gpu: not supported for this code path
        :return: (predicted_segmentation of shape (x, y, z),
                  softmax prediction of shape (c, x, y, z))
        """
        if all_in_gpu:
            raise NotImplementedError

        assert len(x.shape) == 4, "data must be c, x, y, z"

        predicted_segmentation = []
        softmax_pred = []

        # run tiled 2D prediction independently on every slice
        for s in range(x.shape[1]):
            pred_seg, softmax_pres = self._internal_predict_2D_2Dconv_tiled(
                x[:, s], step_size, do_mirroring, mirror_axes, patch_size, regions_class_order, use_gaussian,
                pad_border_mode, pad_kwargs, all_in_gpu, verbose)

            predicted_segmentation.append(pred_seg[None])
            softmax_pred.append(softmax_pres[None])

        # stack slices back into a volume; move the class axis to the front for softmax
        predicted_segmentation = np.vstack(predicted_segmentation)
        softmax_pred = np.vstack(softmax_pred).transpose((1, 0, 2, 3))

        return predicted_segmentation, softmax_pred
class UnetrPPEncoder(nn.Module):
    """
    UNETR++ hierarchical encoder (Synapse configuration).

    A convolutional stem downsamples the input by (2, 4, 4); three further conv
    layers each halve the resolution. Each of the four resolution stages is a
    stack of TransformerBlocks (EPA attention). The hidden states of all four
    stages are returned for the decoder's skip connections.
    """

    # NOTE: the defaults were mutable lists; they are now immutable tuples to avoid
    # the shared-mutable-default pitfall. Values and behavior are unchanged
    # (they are only indexed, never mutated).
    def __init__(self, input_size=(32 * 32 * 32, 16 * 16 * 16, 8 * 8 * 8, 4 * 4 * 4), dims=(32, 64, 128, 256),
                 proj_size=(64, 64, 64, 32), depths=(3, 3, 3, 3), num_heads=4, spatial_dims=3, in_channels=1,
                 dropout=0.0, transformer_dropout_rate=0.15, **kwargs):
        """
        Args:
            input_size: number of tokens (H*W*D) at each of the four stages.
            dims: channel counts for the four stages.
            proj_size: spatial-attention projection size per stage.
            depths: number of TransformerBlocks per stage.
            num_heads: attention heads in every EPA module.
            spatial_dims: number of spatial dimensions (3 for volumes).
            in_channels: number of input image channels.
            dropout: dropout for the downsampling convolutions.
            transformer_dropout_rate: dropout inside the transformer blocks.
        """
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        # stem: non-isotropic (2, 4, 4) patchification followed by group norm
        stem_layer = nn.Sequential(
            get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=(2, 4, 4), stride=(2, 4, 4),
                           dropout=dropout, conv_only=True, ),
            get_norm_layer(name=("group", {"num_groups": in_channels}), channels=dims[0]),
        )
        self.downsample_layers.append(stem_layer)
        for i in range(3):
            # each intermediate layer halves every spatial dimension
            downsample_layer = nn.Sequential(
                get_conv_layer(spatial_dims, dims[i], dims[i + 1], kernel_size=(2, 2, 2), stride=(2, 2, 2),
                               dropout=dropout, conv_only=True, ),
                get_norm_layer(name=("group", {"num_groups": dims[i]}), channels=dims[i + 1]),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple Transformer blocks
        for i in range(4):
            stage_blocks = []
            for j in range(depths[i]):
                stage_blocks.append(TransformerBlock(input_size=input_size[i], hidden_size=dims[i],
                                                     proj_size=proj_size[i], num_heads=num_heads,
                                                     dropout_rate=transformer_dropout_rate, pos_embed=True))
            self.stages.append(nn.Sequential(*stage_blocks))
        self.hidden_states = []
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # truncated-normal init for conv/linear weights, unit-scale LayerNorm init
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (LayerNorm, nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """Run all four stages; return the final features and all stage outputs."""
        hidden_states = []

        x = self.downsample_layers[0](x)
        x = self.stages[0](x)

        hidden_states.append(x)

        for i in range(1, 4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i == 3:  # Reshape the output of the last stage to a token sequence
                x = einops.rearrange(x, "b c h w d -> b (h w d) c")
            hidden_states.append(x)
        return x, hidden_states

    def forward(self, x):
        x, hidden_states = self.forward_features(x)
        return x, hidden_states
class UnetrUpBlock(nn.Module):
    """UNETR++ decoder stage: transposed-conv upsampling, additive skip fusion,
    then either a residual conv block or a stack of EPA transformer blocks."""

    def __init__(
            self,
            spatial_dims: int,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[Sequence[int], int],
            upsample_kernel_size: Union[Sequence[int], int],
            norm_name: Union[Tuple, str],
            proj_size: int = 64,
            num_heads: int = 4,
            out_size: int = 0,
            depth: int = 3,
            conv_decoder: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
            norm_name: feature normalization type and arguments.
            proj_size: projection size for keys and values in the spatial attention module.
            num_heads: number of heads inside each EPA module.
            out_size: spatial size for each decoder.
            depth: number of blocks for the current decoder stage.
            conv_decoder: if True, use a plain residual conv block instead of
                transformer blocks (used for the last, highest-resolution decoder).
        """
        super().__init__()
        # the transposed conv both upsamples (stride == kernel) and reduces channels
        upsample_stride = upsample_kernel_size
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )

        # 4 feature resolution stages, each consisting of multiple residual blocks
        self.decoder_block = nn.ModuleList()

        # If this is the last decoder, use ConvBlock(UnetResBlock) instead of EPA_Block
        # (see suppl. material in the paper)
        if conv_decoder == True:
            self.decoder_block.append(
                UnetResBlock(spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1,
                             norm_name=norm_name, ))
        else:
            stage_blocks = []
            for j in range(depth):
                stage_blocks.append(TransformerBlock(input_size=out_size, hidden_size=out_channels,
                                                     proj_size=proj_size, num_heads=num_heads,
                                                     dropout_rate=0.15, pos_embed=True))
            self.decoder_block.append(nn.Sequential(*stage_blocks))

    def _init_weights(self, m):
        # truncated-normal init for conv/linear weights, unit-scale LayerNorm init
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, inp, skip):
        # upsample, fuse with the encoder skip connection by addition, then refine
        out = self.transp_conv(inp)
        out = out + skip
        out = self.decoder_block[0](out)
        return out
class EPA(nn.Module):
    """
    Efficient Paired Attention block, based on:
    "Shaker et al., UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"

    Channel attention and (projected) spatial attention run on a shared
    query/key pair; their outputs are each reduced to half the channels and
    fused by concatenation.
    """

    def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False,
                 channel_attn_drop=0.1, spatial_attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        # learnable per-head temperatures for the two attention branches
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))

        # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel)
        self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias)

        # E and F are projection matrices with shared weights used in spatial attention module to project
        # keys and values from HWD-dimension to P-dimension
        self.E = self.F = nn.Linear(input_size, proj_size)

        self.attn_drop = nn.Dropout(channel_attn_drop)
        self.attn_drop_2 = nn.Dropout(spatial_attn_drop)

        self.out_proj = nn.Linear(hidden_size, int(hidden_size // 2))
        self.out_proj2 = nn.Linear(hidden_size, int(hidden_size // 2))

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # one projection, then split into the four roles (shared q/k, two values)
        proj = self.qkvv(x).reshape(batch, tokens, 4, self.num_heads, head_dim)
        q_shared, k_shared, v_CA, v_SA = proj.permute(2, 0, 3, 1, 4).unbind(0)

        # put the token axis last so the channel branch attends over channels
        q_shared = q_shared.transpose(-2, -1)
        k_shared = k_shared.transpose(-2, -1)
        v_CA = v_CA.transpose(-2, -1)
        v_SA = v_SA.transpose(-2, -1)

        # compress keys/values along the token axis for the spatial branch
        k_shared_projected = self.E(k_shared)
        v_SA_projected = self.F(v_SA)

        q_shared = torch.nn.functional.normalize(q_shared, dim=-1)
        k_shared = torch.nn.functional.normalize(k_shared, dim=-1)

        # channel attention: head_dim x head_dim affinity with learnable temperature
        attn_CA = self.attn_drop((q_shared @ k_shared.transpose(-2, -1) * self.temperature).softmax(dim=-1))
        x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(batch, tokens, channels)

        # spatial attention: every token attends to the proj_size compressed tokens
        attn_SA = self.attn_drop_2(
            (q_shared.permute(0, 1, 3, 2) @ k_shared_projected * self.temperature2).softmax(dim=-1))
        x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(batch, tokens, channels)

        # Concat fusion: each branch is halved in width, then stitched back together
        return torch.cat((self.out_proj(x_SA), self.out_proj2(x_CA)), dim=-1)

    @torch.jit.ignore
    def no_weight_decay(self):
        # the temperature parameters are excluded from weight decay
        return {'temperature', 'temperature2'}
class UNETR_PP(SegmentationNetwork):
    """
    UNETR++ based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            img_size: Tuple[int, int, int],
            feature_size: int = 16,
            hidden_size: int = 256,
            num_heads: int = 4,
            pos_embed: str = "perceptron",  # TODO: Remove the argument
            norm_name: Union[Tuple, str] = "instance",
            dropout_rate: float = 0.0,
            depths=None,
            dims=None,
            conv_op=nn.Conv3d,
            do_ds=True,

    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            img_size: dimension of input image.
            feature_size: dimension of network feature size.
            hidden_size: dimension of  the last encoder.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type.
            norm_name: feature normalization type and arguments.
            dropout_rate: faction of the input units to drop.
            depths: number of blocks for each stage.
            dims: number of channel maps for the stages.
            conv_op: type of convolution operation.
            do_ds: use deep supervision to compute the loss.

        Examples::

            # for single channel input 4-channel output with patch size of (64, 128, 128), feature size of 16, batch
            # norm and depths of [3, 3, 3, 3] with output channels [32, 64, 128, 256], 4 heads, and 14 classes with
            # deep supervision:
            >>> net = UNETR_PP(in_channels=1, out_channels=14, img_size=(64, 128, 128), feature_size=16, num_heads=4,
            >>>                norm_name='batch', depths=[3, 3, 3, 3], dims=[32, 64, 128, 256], do_ds=True)
        """
        super().__init__()
        if depths is None:
            depths = [3, 3, 3, 3]
        self.do_ds = do_ds
        self.conv_op = conv_op
        self.num_classes = out_channels
        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")

        if pos_embed not in ["conv", "perceptron"]:
            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

        self.patch_size = (2, 4, 4)
        self.feat_size = (
            img_size[0] // self.patch_size[0] // 8,  # 8 is the downsampling happened through the four encoders stages
            img_size[1] // self.patch_size[1] // 8,  # 8 is the downsampling happened through the four encoders stages
            img_size[2] // self.patch_size[2] // 8,  # 8 is the downsampling happened through the four encoders stages
        )
        self.hidden_size = hidden_size

        self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads)

        # parallel stem that keeps the full input resolution for the final skip connection
        self.encoder1 = UnetResBlock(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 16,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=8 * 8 * 8,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=16 * 16 * 16,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=32 * 32 * 32,
        )
        # the last decoder undoes the (2, 4, 4) stem and uses a conv block instead of EPA
        self.decoder2 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=(2, 4, 4),
            norm_name=norm_name,
            out_size=64 * 128 * 128,
            conv_decoder=True,
        )
        self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels)
        if self.do_ds:
            # extra output heads at lower resolutions for deep supervision
            self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels)
            self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels)

    def proj_feat(self, x, hidden_size, feat_size):
        # token sequence (b, h*w*d, c) -> volumetric feature map (b, c, h, w, d)
        x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

    def forward(self, x_in):
        x_output, hidden_states = self.unetr_pp_encoder(x_in)

        convBlock = self.encoder1(x_in)

        # Four encoders
        enc1 = hidden_states[0]
        enc2 = hidden_states[1]
        enc3 = hidden_states[2]
        enc4 = hidden_states[3]

        # Four decoders
        dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size)
        dec3 = self.decoder5(dec4, enc3)
        dec2 = self.decoder4(dec3, enc2)
        dec1 = self.decoder3(dec2, enc1)

        out = self.decoder2(dec1, convBlock)
        if self.do_ds:
            # deep supervision: full-resolution logits plus two downsampled heads
            logits = [self.out1(out), self.out2(dec1), self.out3(dec2)]
        else:
            logits = self.out1(out)

        return logits
einops, _ = optional_import("einops")


class UnetrPPEncoder(nn.Module):
    # Hierarchical UNETR++ encoder: a conv stem plus 3 strided downsampling conv layers,
    # each followed by a stage of EPA-based TransformerBlocks. forward() returns the last
    # stage output together with all four per-stage hidden states (decoder skip inputs).
    def __init__(self, input_size=[32 * 32 * 32, 16 * 16 * 16, 8 * 8 * 8, 4 * 4 * 4], dims=[32, 64, 128, 256],
                 proj_size=[64, 64, 64, 32], depths=[3, 3, 3, 3], num_heads=4, spatial_dims=3, in_channels=4,
                 dropout=0.0, transformer_dropout_rate=0.1, **kwargs):
        # NOTE(review): list default arguments are shared across instances; safe only while
        # they are never mutated (they are not mutated here).
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        # Stem: stride-4 conv, so the first stage already operates at 1/4 resolution per axis.
        stem_layer = nn.Sequential(
            get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=(4, 4, 4), stride=(4, 4, 4),
                           dropout=dropout, conv_only=True, ),
            get_norm_layer(name=("group", {"num_groups": in_channels}), channels=dims[0]),
        )
        self.downsample_layers.append(stem_layer)
        for i in range(3):
            # Each subsequent stage halves every spatial axis (stride-2 conv).
            downsample_layer = nn.Sequential(
                get_conv_layer(spatial_dims, dims[i], dims[i + 1], kernel_size=(2, 2, 2), stride=(2, 2, 2),
                               dropout=dropout, conv_only=True, ),
                get_norm_layer(name=("group", {"num_groups": dims[i]}), channels=dims[i + 1]),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple Transformer blocks
        for i in range(4):
            stage_blocks = []
            for j in range(depths[i]):
                stage_blocks.append(TransformerBlock(input_size=input_size[i], hidden_size=dims[i],
                                                     proj_size=proj_size[i], num_heads=num_heads,
                                                     dropout_rate=transformer_dropout_rate, pos_embed=True))
            self.stages.append(nn.Sequential(*stage_blocks))
        self.hidden_states = []  # unused attribute; hidden states are returned from forward_features instead
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal init for conv/linear weights; standard (0, 1) init for LayerNorm.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (LayerNorm, nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """Run all four encoder stages; return (last output, list of 4 stage outputs)."""
        hidden_states = []

        x = self.downsample_layers[0](x)
        x = self.stages[0](x)

        hidden_states.append(x)

        for i in range(1, 4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i == 3:  # Reshape the output of the last stage to token form (B, H*W*D, C)
                x = einops.rearrange(x, "b c h w d -> b (h w d) c")
            hidden_states.append(x)
        return x, hidden_states

    def forward(self, x):
        x, hidden_states = self.forward_features(x)
        return x, hidden_states


class UnetrUpBlock(nn.Module):
    # One UNETR++ decoder stage: transposed-conv upsampling, additive skip connection,
    # then either EPA TransformerBlocks or (for the last decoder) a plain residual conv block.
    def __init__(
            self,
            spatial_dims: int,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[Sequence[int], int],
            upsample_kernel_size: Union[Sequence[int], int],
            norm_name: Union[Tuple, str],
            proj_size: int = 64,
            num_heads: int = 4,
            out_size: int = 0,
            depth: int = 3,
            conv_decoder: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
            norm_name: feature normalization type and arguments.
            proj_size: projection size for keys and values in the spatial attention module.
            num_heads: number of heads inside each EPA module.
            out_size: spatial size for each decoder.
            depth: number of blocks for the current decoder stage.
            conv_decoder: if True use a UnetResBlock instead of TransformerBlocks (last decoder).
        """
        super().__init__()
        upsample_stride = upsample_kernel_size
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )

        # 4 feature resolution stages, each consisting of multiple residual blocks
        self.decoder_block = nn.ModuleList()

        # If this is the last decoder, use ConvBlock(UnetResBlock) instead of EPA_Block
        # (see suppl. material in the paper)
        if conv_decoder == True:
            self.decoder_block.append(
                UnetResBlock(spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1,
                             norm_name=norm_name, ))
        else:
            stage_blocks = []
            for j in range(depth):
                stage_blocks.append(TransformerBlock(input_size=out_size, hidden_size=out_channels,
                                                     proj_size=proj_size, num_heads=num_heads, dropout_rate=0.1,
                                                     pos_embed=True))
            self.decoder_block.append(nn.Sequential(*stage_blocks))

    def _init_weights(self, m):
        # Same init scheme as the encoder. NOTE(review): not applied automatically here
        # (no self.apply call in __init__) — submodules keep their own default init.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, inp, skip):
        # Upsample, fuse the encoder skip by addition, then refine.
        out = self.transp_conv(inp)
        out = out + skip
        out = self.decoder_block[0](out)
        return out



================================================
FILE: unetr_pp/network_architecture/tumor/transformerblock.py
================================================
import torch.nn as nn
import torch
from unetr_pp.network_architecture.dynunet_block import UnetResBlock
import math


class TransformerBlock(nn.Module):
    """
    A transformer block, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            proj_size: int,
            num_heads: int,
            dropout_rate: float = 0.0,
            pos_embed=False,
    ) -> None:
        """
        Args:
            input_size: the size of the input for each stage.
            hidden_size: dimension of hidden layer.
            proj_size: projection size for keys and values in the spatial attention module.
            num_heads: number of attention heads.
            dropout_rate: fraction of the input units to drop.
            pos_embed: bool argument to determine if positional embedding is used.
        """
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            print("Hidden size is ", hidden_size)
            print("Num heads is ", num_heads)
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.norm = nn.LayerNorm(hidden_size)
        # Layer-scale weight on the attention branch, initialized small (1e-6).
        self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True)
        self.epa_block = EPA(input_size=input_size, hidden_size=hidden_size, proj_size=proj_size,
                             num_heads=num_heads, channel_attn_drop=dropout_rate, spatial_attn_drop=dropout_rate)
        self.conv51 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch")
        self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1))

        self.pos_embed = None
        if pos_embed:
            # Learned additive positional embedding over the flattened voxel tokens.
            self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size))

    def forward(self, x):
        B, C, H, W, D = x.shape

        # Flatten the volume to a token sequence: (B, C, H, W, D) -> (B, H*W*D, C).
        x = x.reshape(B, C, H * W * D).permute(0, 2, 1)

        if self.pos_embed is not None:
            x = x + self.pos_embed
        attn = x + self.gamma * self.epa_block(self.norm(x))

        attn_skip = attn.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3)  # (B, C, H, W, D)
        attn = self.conv51(attn_skip)
        x = attn_skip + self.conv8(attn)

        return x


def init_(tensor):
    # In-place uniform init in [-1/sqrt(dim), 1/sqrt(dim)] where dim is the last dimension.
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


class EPA(nn.Module):
    """
        Efficient Paired Attention Block, based on: "Shaker et al.,
        UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
        """

    def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False,
                 channel_attn_drop=0.1, spatial_attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        # Learned softmax temperatures, one per head, for channel and spatial attention.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))

        # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel)
        self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias)

        # E and F are projection matrices with shared weights used in spatial attention module to project
        # keys and values from HWD-dimension to P-dimension
        self.EF = nn.Parameter(init_(torch.zeros(input_size, proj_size)))

        self.attn_drop = nn.Dropout(channel_attn_drop)
        self.attn_drop_2 = nn.Dropout(spatial_attn_drop)

    def forward(self, x):
        B, N, C = x.shape

        qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads)
        qkvv = qkvv.permute(2, 0, 3, 1, 4)
        q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3]

        # Work in (B, heads, head_dim, N) so channel attention is over the feature axis.
        q_shared = q_shared.transpose(-2, -1)
        k_shared = k_shared.transpose(-2, -1)
        v_CA = v_CA.transpose(-2, -1)
        v_SA = v_SA.transpose(-2, -1)

        # Project keys/values for spatial attention from N tokens down to proj_size tokens.
        proj_e_f = lambda args: torch.einsum('bhdn,nk->bhdk', *args)
        k_shared_projected, v_SA_projected = map(proj_e_f, zip((k_shared, v_SA), (self.EF, self.EF)))

        q_shared = torch.nn.functional.normalize(q_shared, dim=-1)
        k_shared = torch.nn.functional.normalize(k_shared, dim=-1)

        # Channel attention: (head_dim x head_dim) attention map.
        attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature
        attn_CA = attn_CA.softmax(dim=-1)
        attn_CA = self.attn_drop(attn_CA)
        x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C)

        # Spatial attention: (N x proj_size) attention map over projected keys.
        attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2
        attn_SA = attn_SA.softmax(dim=-1)
        attn_SA = self.attn_drop_2(attn_SA)
        x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C)

        # Fuse both attention branches by summation.
        return x_CA + x_SA

    @torch.jit.ignore
    def no_weight_decay(self):
        # Exclude the learned temperatures from weight decay.
        return {'temperature', 'temperature2'}



================================================
FILE: unetr_pp/network_architecture/tumor/unetr_pp_tumor.py
================================================
from torch import nn
from typing import Tuple, Union
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.network_architecture.dynunet_block import UnetOutBlock, UnetResBlock
from unetr_pp.network_architecture.tumor.model_components import UnetrPPEncoder, UnetrUpBlock


class UNETR_PP(SegmentationNetwork):
    """
    UNETR++ based on: "Shaker et al., UNETR++: Delving
into Efficient and Accurate 3D Medical Image Segmentation"
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            feature_size: int = 16,
            hidden_size: int = 256,
            num_heads: int = 4,
            pos_embed: str = "perceptron",
            norm_name: Union[Tuple, str] = "instance",
            dropout_rate: float = 0.0,
            depths=None,
            dims=None,
            conv_op=nn.Conv3d,
            do_ds=True,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            feature_size: dimension of network feature size.
            hidden_size: dimensions of the last encoder.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type.
            norm_name: feature normalization type and arguments.
            dropout_rate: fraction of the input units to drop.
            depths: number of blocks for each stage.
            dims: number of channel maps for the stages.
            conv_op: type of convolution operation.
            do_ds: use deep supervision to compute the loss.
        """
        super().__init__()
        if depths is None:
            depths = [3, 3, 3, 3]
        self.do_ds = do_ds
        self.conv_op = conv_op
        self.num_classes = out_channels
        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")

        if pos_embed not in ["conv", "perceptron"]:
            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

        # Bottleneck spatial size fed to proj_feat. NOTE(review): hard-coded, presumably
        # a 128^3 patch after the encoder's total /32 downsampling — confirm against the plans.
        self.feat_size = (4, 4, 4,)
        self.hidden_size = hidden_size

        self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads)

        # Full-resolution conv branch used as the final skip connection.
        self.encoder1 = UnetResBlock(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 16,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=8 * 8 * 8,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=16 * 16 * 16,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=32 * 32 * 32,
        )
        # Last decoder undoes the stem's stride-4 and uses a plain conv block (conv_decoder=True).
        self.decoder2 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=(4, 4, 4),
            norm_name=norm_name,
            out_size=128 * 128 * 128,
            conv_decoder=True,
        )
        self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels)
        if self.do_ds:
            # Extra output heads at intermediate decoder resolutions for deep supervision.
            self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels)
            self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels)

    def proj_feat(self, x, hidden_size, feat_size):
        # Token sequence (B, N, hidden_size) -> 5D feature map (B, hidden_size, *feat_size).
        x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

    def forward(self, x_in):
        x_output, hidden_states = self.unetr_pp_encoder(x_in)

        convBlock = self.encoder1(x_in)

        # Four encoders
        enc1 = hidden_states[0]
        enc2 = hidden_states[1]
        enc3 = hidden_states[2]
        enc4 = hidden_states[3]

        # Four decoders
        dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size)
        dec3 = self.decoder5(dec4, enc3)
        dec2 = self.decoder4(dec3, enc2)
        dec1 = self.decoder3(dec2, enc1)

        out = self.decoder2(dec1, convBlock)
        if self.do_ds:
            # Deep supervision: list of logits at full, 1/4 and 1/8 resolution.
            logits = [self.out1(out), self.out2(dec1), self.out3(dec2)]
        else:
            logits = self.out1(out)

        return logits



================================================
FILE: unetr_pp/paths.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from batchgenerators.utilities.file_and_folder_operations import maybe_mkdir_p, join

# do not modify these unless you know what you are doing
my_output_identifier = "unetr_pp"
default_plans_identifier = "unetr_pp_Plansv2.1"
default_data_identifier = 'unetr_pp_Data_plans_v2.1'
default_trainer = "unetr_pp_trainer_synapse"

"""
PLEASE READ paths.md FOR INFORMATION TO HOW TO SET THIS UP
"""

# All three locations are configured through environment variables; any missing
# variable degrades the corresponding capability (see the print messages below).
base = os.environ['unetr_pp_raw_data_base'] if "unetr_pp_raw_data_base" in os.environ.keys() else None
preprocessing_output_dir = os.environ['unetr_pp_preprocessed'] if "unetr_pp_preprocessed" in os.environ.keys() else None
network_training_output_dir_base = os.path.join(os.environ['RESULTS_FOLDER']) if "RESULTS_FOLDER" in os.environ.keys() else None

if base is not None:
    nnFormer_raw_data = join(base, "unetr_pp_raw_data")
    nnFormer_cropped_data = join(base, "unetr_pp_cropped_data")
    maybe_mkdir_p(nnFormer_raw_data)
    maybe_mkdir_p(nnFormer_cropped_data)
else:
    print("unetr_pp_raw_data_base is not defined and model can only be used on data for which preprocessed files "
          "are already present on your system. model cannot be used for experiment planning and preprocessing like "
          "this. If this is not intended, please read run_training_synapse.sh/run_training_acdc.sh "
          "for information on how to set this up properly.")
    nnFormer_cropped_data = nnFormer_raw_data = None

if preprocessing_output_dir is not None:
    maybe_mkdir_p(preprocessing_output_dir)
else:
    print("unetr_pp_preprocessed is not defined and model can not be used for preprocessing "
          "or training. If this is not intended, please read documentation/setting_up_paths.md for "
          "information on how to set this up.")
    preprocessing_output_dir = None

if network_training_output_dir_base is not None:
    network_training_output_dir = join(network_training_output_dir_base, my_output_identifier)
    maybe_mkdir_p(network_training_output_dir)
else:
    print("RESULTS_FOLDER is not defined and nnU-Net cannot be used for training or "
          "inference. If this is not intended behavior, please read run_training_synapse.sh/run_training_acdc.sh "
          "for information on how to set this up.")
    network_training_output_dir = None



================================================
FILE: unetr_pp/postprocessing/connected_components.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import ast
from copy import deepcopy
from multiprocessing.pool import Pool

import numpy as np

from unetr_pp.configuration import default_num_threads
from unetr_pp.evaluation.evaluator import aggregate_scores
from scipy.ndimage import label
import SimpleITK as sitk
from unetr_pp.utilities.sitk_stuff import copy_geometry
from batchgenerators.utilities.file_and_folder_operations import *
import shutil


def load_remove_save(input_file: str, output_file: str, for_which_classes: list,
                     minimum_valid_object_size: dict = None):
    """Load a nifti, remove small connected components per class, save, and return sizes."""
    # Only objects larger than minimum_valid_object_size will be removed. Keys in minimum_valid_object_size must
    # match entries in for_which_classes
    img_in = sitk.ReadImage(input_file)
    img_npy = sitk.GetArrayFromImage(img_in)
    volume_per_voxel = float(np.prod(img_in.GetSpacing(), dtype=np.float64))

    image, largest_removed, kept_size = remove_all_but_the_largest_connected_component(img_npy, for_which_classes,
                                                                                      volume_per_voxel,
                                                                                      minimum_valid_object_size)
    # print(input_file, "kept:", kept_size)
    img_out_itk = sitk.GetImageFromArray(image)
    # Carry over origin/spacing/direction from the input image.
    img_out_itk = copy_geometry(img_out_itk, img_in)
    sitk.WriteImage(img_out_itk, output_file)
    return largest_removed, kept_size


def remove_all_but_the_largest_connected_component(image: np.ndarray, for_which_classes: list, volume_per_voxel: float,
                                                   minimum_valid_object_size: dict = None):
    """
    removes all but the largest connected component, individually for each class
    :param image:
    :param for_which_classes: can be None. Should be list of int. Can also be something like [(1, 2), 2, 4].
    Here (1, 2) will be treated as a joint region, not individual classes (example LiTS here we can use (1, 2)
    to use all foreground classes together)
    :param minimum_valid_object_size: Only objects larger than minimum_valid_object_size will be removed.
    Keys in minimum_valid_object_size must match entries in for_which_classes
    :return: (modified image, largest removed object size per class, kept object size per class)
    """
    if for_which_classes is None:
        # Default to every nonzero label present in the image.
        for_which_classes = np.unique(image)
        for_which_classes = for_which_classes[for_which_classes > 0]

    assert 0 not in for_which_classes, "cannot remove background"
    largest_removed = {}
    kept_size = {}
    for c in for_which_classes:
        if isinstance(c, (list, tuple)):
            c = tuple(c)  # otherwise it cant be used as key in the dict
            mask = np.zeros_like(image, dtype=bool)
            for cl in c:
                mask[image == cl] = True
        else:
            mask = image == c
        # get labelmap and number of objects
        lmap, num_objects = label(mask.astype(int))

        # collect object sizes (in physical volume, not voxels)
        object_sizes = {}
        for object_id in range(1, num_objects + 1):
            object_sizes[object_id] = (lmap == object_id).sum() * volume_per_voxel

        largest_removed[c] = None
        kept_size[c] = None

        if num_objects > 0:
            # we always keep the largest object. We could also consider removing the largest object if it is smaller
            # than minimum_valid_object_size in the future but we don't do that now.
            maximum_size = max(object_sizes.values())
            kept_size[c] = maximum_size

            for object_id in range(1, num_objects + 1):
                # we only remove objects that are not the largest
                if object_sizes[object_id] != maximum_size:
                    # we only remove objects that are smaller than minimum_valid_object_size
                    remove = True
                    if minimum_valid_object_size is not None:
                        remove = object_sizes[object_id] < minimum_valid_object_size[c]
                    if remove:
                        image[(lmap == object_id) & mask] = 0
                        if largest_removed[c] is None:
                            largest_removed[c] = object_sizes[object_id]
                        else:
                            largest_removed[c] = max(largest_removed[c], object_sizes[object_id])
    return image, largest_removed, kept_size


def load_postprocessing(json_file):
    '''
    loads the relevant part of the json file that is needed for applying postprocessing
    :param json_file:
    :return: (for_which_classes, min_valid_object_sizes or None)
    '''
    a = load_json(json_file)
    if 'min_valid_object_sizes' in a.keys():
        # Stored as a stringified dict (see determine_postprocessing), hence literal_eval.
        min_valid_object_sizes = ast.literal_eval(a['min_valid_object_sizes'])
    else:
        min_valid_object_sizes = None
    return a['for_which_classes'], min_valid_object_sizes


def determine_postprocessing(base, gt_labels_folder, raw_subfolder_name="validation_raw", temp_folder="temp",
                             final_subf_name="validation_final", processes=default_num_threads,
                             dice_threshold=0, debug=False, advanced_postprocessing=False,
                             pp_filename="postprocessing.json"):
    """
    Decide, per dataset, whether removing all-but-largest connected components (first for all
    foreground classes jointly, then per class) improves the validation Dice, and persist the
    chosen configuration to pp_filename.

    :param base:
    :param gt_labels_folder: subfolder of base with niftis of ground truth labels
    :param raw_subfolder_name: subfolder of base with niftis of predicted (non-postprocessed) segmentations
    :param temp_folder: used to store temporary data, will be deleted after we are done here unless debug=True
    :param final_subf_name: final results will be stored here (subfolder of base)
    :param processes:
    :param dice_threshold: only apply postprocessing if results is better than old_result+dice_threshold
    (can be used as eps)
    :param debug: if True then the temporary files will not be deleted
    :param advanced_postprocessing: additionally derive minimum valid object sizes from the data
    :return:
    """
    # lets see what classes are in the dataset
    classes = [int(i) for i in load_json(join(base, raw_subfolder_name, "summary.json"))['results']['mean'].keys() if
               int(i) != 0]

    folder_all_classes_as_fg = join(base, temp_folder + "_allClasses")
    folder_per_class = join(base, temp_folder + "_perClass")

    if isdir(folder_all_classes_as_fg):
        shutil.rmtree(folder_all_classes_as_fg)
    if isdir(folder_per_class):
        shutil.rmtree(folder_per_class)

    # multiprocessing rules
    p = Pool(processes)

    assert isfile(join(base, raw_subfolder_name, "summary.json")), "join(base, raw_subfolder_name) does not " \
                                                                   "contain a summary.json"

    # these are all the files we will be dealing with
    fnames = subfiles(join(base, raw_subfolder_name), suffix=".nii.gz", join=False)

    # make output and temp dir
    maybe_mkdir_p(folder_all_classes_as_fg)
    maybe_mkdir_p(folder_per_class)
    maybe_mkdir_p(join(base, final_subf_name))

    pp_results = {}
    pp_results['dc_per_class_raw'] = {}
    pp_results['dc_per_class_pp_all'] = {}  # dice scores after treating all foreground classes as one
    pp_results['dc_per_class_pp_per_class'] = {}  # dice scores after removing everything except largest cc
    # independently for each class after we already did dc_per_class_pp_all
    pp_results['for_which_classes'] = []
    pp_results['min_valid_object_sizes'] = {}

    validation_result_raw = load_json(join(base, raw_subfolder_name, "summary.json"))['results']
    pp_results['num_samples'] = len(validation_result_raw['all'])
    validation_result_raw = validation_result_raw['mean']

    if advanced_postprocessing:
        # first treat all foreground classes as one and remove all but the largest foreground connected component
        results = []
        for f in fnames:
            predicted_segmentation = join(base, raw_subfolder_name, f)
            # now remove all but the largest connected component for each class
            output_file = join(folder_all_classes_as_fg, f)
            results.append(p.starmap_async(load_remove_save, ((predicted_segmentation, output_file, (classes,)),)))

        results = [i.get() for i in results]

        # aggregate max_size_removed and min_size_kept
        max_size_removed = {}
        min_size_kept = {}
        for tmp in results:
            mx_rem, min_kept = tmp[0]
            for k in mx_rem:
                if mx_rem[k] is not None:
                    if max_size_removed.get(k) is None:
                        max_size_removed[k] = mx_rem[k]
                    else:
                        max_size_removed[k] = max(max_size_removed[k], mx_rem[k])
            for k in min_kept:
                if min_kept[k] is not None:
                    if min_size_kept.get(k) is None:
                        min_size_kept[k] = min_kept[k]
                    else:
                        min_size_kept[k] = min(min_size_kept[k], min_kept[k])

        print("foreground vs background, smallest valid object size was", min_size_kept[tuple(classes)])
        print("removing only objects smaller than that...")
    else:
        min_size_kept = None

    # we need to rerun the step from above, now with the size constraint
    pred_gt_tuples = []
    results = []
    # first treat all foreground classes as one and remove all but the largest foreground connected component
    for f in fnames:
        predicted_segmentation = join(base, raw_subfolder_name, f)
        # now remove all but the largest connected component for each class
        output_file = join(folder_all_classes_as_fg, f)
        results.append(
            p.starmap_async(load_remove_save, ((predicted_segmentation, output_file, (classes,), min_size_kept),)))
        pred_gt_tuples.append([output_file, join(gt_labels_folder, f)])

    _ = [i.get() for i in results]

    # evaluate postprocessed predictions
    _ = aggregate_scores(pred_gt_tuples, labels=classes,
                         json_output_file=join(folder_all_classes_as_fg, "summary.json"),
                         json_author="Fabian", num_threads=processes)

    # now we need to figure out if doing this improved the dice scores. We will implement that defensively in so far
    # that if a single class got worse as a result we won't do this. We can change this in the future but right now I
    # prefer to do it this way
    validation_result_PP_test = load_json(join(folder_all_classes_as_fg, "summary.json"))['results']['mean']

    for c in classes:
        dc_raw = validation_result_raw[str(c)]['Dice']
        dc_pp = validation_result_PP_test[str(c)]['Dice']
        pp_results['dc_per_class_raw'][str(c)] = dc_raw
        pp_results['dc_per_class_pp_all'][str(c)] = dc_pp

    # true if new is better
    do_fg_cc = False
    comp = [pp_results['dc_per_class_pp_all'][str(cl)] > (pp_results['dc_per_class_raw'][str(cl)] + dice_threshold) for
            cl in classes]
    before = np.mean([pp_results['dc_per_class_raw'][str(cl)] for cl in classes])
    after = np.mean([pp_results['dc_per_class_pp_all'][str(cl)] for cl in classes])
    print("Foreground vs background")
    print("before:", before)
    print("after: ", after)
    if any(comp):
        # at least one class improved - yay!
        # now check if another got worse
        # true if new is worse
        any_worse = any(
            [pp_results['dc_per_class_pp_all'][str(cl)] < pp_results['dc_per_class_raw'][str(cl)] for cl in classes])
        if not any_worse:
            pp_results['for_which_classes'].append(classes)
            if min_size_kept is not None:
                pp_results['min_valid_object_sizes'].update(deepcopy(min_size_kept))
            do_fg_cc = True
            print("Removing all but the largest foreground region improved results!")
            print('for_which_classes', classes)
            print('min_valid_object_sizes', min_size_kept)
    else:
        # did not improve things - don't do it
        pass

    if len(classes) > 1:
        # now depending on whether we do remove all but the largest foreground connected component we define the
        # source dir for the next one to be the raw or the temp dir
        if do_fg_cc:
            source = folder_all_classes_as_fg
        else:
            source = join(base, raw_subfolder_name)

        if advanced_postprocessing:
            # now run this for each class separately
            results = []
            for f in fnames:
                predicted_segmentation = join(source, f)
                output_file = join(folder_per_class, f)
                results.append(p.starmap_async(load_remove_save, ((predicted_segmentation, output_file, classes),)))
            results = [i.get() for i in results]

            # aggregate max_size_removed and min_size_kept
            max_size_removed = {}
            min_size_kept = {}
            for tmp in results:
                mx_rem, min_kept = tmp[0]
                for k in mx_rem:
                    if mx_rem[k] is not None:
                        if max_size_removed.get(k) is None:
                            max_size_removed[k] = mx_rem[k]
                        else:
                            max_size_removed[k] = max(max_size_removed[k], mx_rem[k])
                for k in min_kept:
                    if min_kept[k] is not None:
                        if min_size_kept.get(k) is None:
                            min_size_kept[k] = min_kept[k]
                        else:
                            min_size_kept[k] = min(min_size_kept[k], min_kept[k])

            print("classes treated separately, smallest valid object sizes are")
            print(min_size_kept)
            print("removing only objects smaller than that")
        else:
            min_size_kept = None

        # rerun with the size thresholds from above
        pred_gt_tuples = []
        results = []
        for f in fnames:
            predicted_segmentation = join(source, f)
            output_file = join(folder_per_class, f)
            results.append(p.starmap_async(load_remove_save, ((predicted_segmentation, output_file, classes,
                                                               min_size_kept),)))
            pred_gt_tuples.append([output_file, join(gt_labels_folder, f)])

        _ = [i.get() for i in results]

        # evaluate postprocessed predictions
        _ = aggregate_scores(pred_gt_tuples, labels=classes,
                             json_output_file=join(folder_per_class, "summary.json"),
                             json_author="Fabian", num_threads=processes)

        if do_fg_cc:
            old_res = deepcopy(validation_result_PP_test)
        else:
            old_res = validation_result_raw

        # these are the new dice scores
        validation_result_PP_test = load_json(join(folder_per_class, "summary.json"))['results']['mean']

        for c in classes:
            dc_raw = old_res[str(c)]['Dice']
            dc_pp = validation_result_PP_test[str(c)]['Dice']
            pp_results['dc_per_class_pp_per_class'][str(c)] = dc_pp
            print(c)
            print("before:", dc_raw)
            print("after: ", dc_pp)

            if dc_pp > (dc_raw + dice_threshold):
                pp_results['for_which_classes'].append(int(c))
                if min_size_kept is not None:
                    pp_results['min_valid_object_sizes'].update({c: min_size_kept[c]})
                print("Removing all but the largest region for class %d improved results!" % c)
                print('min_valid_object_sizes', min_size_kept)
    else:
        print("Only one class present, no need to do each class separately as this is covered in fg vs bg")

    if not advanced_postprocessing:
        pp_results['min_valid_object_sizes'] = None

    print("done")
    print("for which classes:")
    print(pp_results['for_which_classes'])
    print("min_object_sizes")
    print(pp_results['min_valid_object_sizes'])

    pp_results['validation_raw'] = raw_subfolder_name
    pp_results['validation_final'] = final_subf_name

    # now that we have a proper for_which_classes, apply that
    pred_gt_tuples = []
    results = []
    for f in fnames:
        predicted_segmentation = join(base, raw_subfolder_name, f)

        # now remove all but the largest connected component for each class
        output_file = join(base, final_subf_name, f)
        results.append(p.starmap_async(load_remove_save, (
            (predicted_segmentation, output_file, pp_results['for_which_classes'],
             pp_results['min_valid_object_sizes']),)))

        pred_gt_tuples.append([output_file, join(gt_labels_folder, f)])

    _ = [i.get() for i in results]

    # evaluate postprocessed predictions
    _ = aggregate_scores(pred_gt_tuples, labels=classes,
                         json_output_file=join(base, final_subf_name, "summary.json"),
                         json_author="Fabian", num_threads=processes)

    # stringify so the dict (which may have tuple keys) survives json serialization
    pp_results['min_valid_object_sizes'] = str(pp_results['min_valid_object_sizes'])

    save_json(pp_results, join(base, pp_filename))

    # delete temp
    if not debug:
        shutil.rmtree(folder_per_class)
        shutil.rmtree(folder_all_classes_as_fg)

    p.close()
    p.join()
    print("done")


def apply_postprocessing_to_folder(input_folder: str, output_folder: str, for_which_classes: list,
                                   min_valid_object_size: dict = None, num_processes=8):
    """
    applies removing of all but the largest connected component to all niftis in a folder
    :param min_valid_object_size:
    :param input_folder:
    :param output_folder:
    :param for_which_classes:
    :param num_processes:
    :return:
    """
    maybe_mkdir_p(output_folder)
    p = Pool(num_processes)
    nii_files = subfiles(input_folder, suffix=".nii.gz", join=False)
    input_files = [join(input_folder, i) for i in nii_files]
    out_files = [join(output_folder, i) for i in nii_files]
    results = p.starmap_async(load_remove_save, zip(input_files, out_files, [for_which_classes] * len(input_files),
                                                    [min_valid_object_size] * len(input_files)))
    res = results.get()
    p.close()
    p.join()


if __name__ == "__main__":
    input_folder = "/media/fabian/DKFZ/predictions_Fabian/Liver_and_LiverTumor"
    output_folder = "/media/fabian/DKFZ/predictions_Fabian/Liver_and_LiverTumor_postprocessed"
    for_which_classes = [(1, 2), ]
    apply_postprocessing_to_folder(input_folder, output_folder, for_which_classes)



================================================
FILE: unetr_pp/postprocessing/consolidate_all_for_paper.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unetr_pp.utilities.folder_names import get_output_folder_name def get_datasets(): configurations_all = { "Task01_BrainTumour": ("3d_fullres", "2d"), "Task02_Heart": ("3d_fullres", "2d",), "Task03_Liver": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"), "Task04_Hippocampus": ("3d_fullres", "2d",), "Task05_Prostate": ("3d_fullres", "2d",), "Task06_Lung": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"), "Task07_Pancreas": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"), "Task08_HepaticVessel": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"), "Task09_Spleen": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"), "Task10_Colon": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"), "Task48_KiTS_clean": ("3d_cascade_fullres", "3d_lowres", "3d_fullres", "2d"), "Task27_ACDC": ("3d_fullres", "2d",), "Task24_Promise": ("3d_fullres", "2d",), "Task35_ISBILesionSegmentation": ("3d_fullres", "2d",), "Task38_CHAOS_Task_3_5_Variant2": ("3d_fullres", "2d",), "Task29_LITS": ("3d_cascade_fullres", "3d_lowres", "2d", "3d_fullres",), "Task17_AbdominalOrganSegmentation": ("3d_cascade_fullres", "3d_lowres", "2d", "3d_fullres",), "Task55_SegTHOR": ("3d_cascade_fullres", "3d_lowres", "3d_fullres", "2d",), "Task56_VerSe": ("3d_cascade_fullres", "3d_lowres", "3d_fullres", "2d",), } return configurations_all def get_commands(configurations, regular_trainer="nnFormerTrainerV2", cascade_trainer="nnFormerTrainerV2CascadeFullRes", plans="nnFormerPlansv2.1"): node_pool = ["hdf18-gpu%02.0d" % i for i in range(1, 21)] + ["hdf19-gpu%02.0d" % i for i in range(1, 8)] + ["hdf19-gpu%02.0d" % i for i in range(11, 16)] ctr = 0 for task in configurations: models = configurations[task] for m in models: if m == "3d_cascade_fullres": trainer = cascade_trainer else: trainer = regular_trainer folder = get_output_folder_name(m, task, trainer, plans, overwrite_training_output_dir="/datasets/datasets_fabian/results/nnFormer") node = node_pool[ctr % len(node_pool)] 
print("bsub -m %s -q gputest -L /bin/bash \"source ~/.bashrc && python postprocessing/" "consolidate_postprocessing.py -f" % node, folder, "\"") ctr += 1 ================================================ FILE: unetr_pp/postprocessing/consolidate_postprocessing.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import shutil from typing import Tuple from batchgenerators.utilities.file_and_folder_operations import * from unetr_pp.configuration import default_num_threads from unetr_pp.evaluation.evaluator import aggregate_scores from unetr_pp.postprocessing.connected_components import determine_postprocessing import argparse def collect_cv_niftis(cv_folder: str, output_folder: str, validation_folder_name: str = 'validation_raw', folds: tuple = (0, 1, 2, 3, 4)): validation_raw_folders = [join(cv_folder, "fold_%d" % i, validation_folder_name) for i in folds] exist = [isdir(i) for i in validation_raw_folders] if not all(exist): raise RuntimeError("some folds are missing. Please run the full 5-fold cross-validation. 
" "The following folds seem to be missing: %s" % [i for j, i in enumerate(folds) if not exist[j]]) # now copy all raw niftis into cv_niftis_raw maybe_mkdir_p(output_folder) for f in folds: niftis = subfiles(validation_raw_folders[f], suffix=".nii.gz") for n in niftis: shutil.copy(n, join(output_folder)) def consolidate_folds(output_folder_base, validation_folder_name: str = 'validation_raw', advanced_postprocessing: bool = False, folds: Tuple[int] = (0, 1, 2, 3, 4)): """ Used to determine the postprocessing for an experiment after all five folds have been completed. In the validation of each fold, the postprocessing can only be determined on the cases within that fold. This can result in different postprocessing decisions for different folds. In the end, we can only decide for one postprocessing per experiment, so we have to rerun it :param folds: :param advanced_postprocessing: :param output_folder_base:experiment output folder (fold_0, fold_1, etc must be subfolders of the given folder) :param validation_folder_name: dont use this :return: """ output_folder_raw = join(output_folder_base, "cv_niftis_raw") if isdir(output_folder_raw): shutil.rmtree(output_folder_raw) output_folder_gt = join(output_folder_base, "gt_niftis") collect_cv_niftis(output_folder_base, output_folder_raw, validation_folder_name, folds) num_niftis_gt = len(subfiles(join(output_folder_base, "gt_niftis"), suffix='.nii.gz')) # count niftis in there num_niftis = len(subfiles(output_folder_raw, suffix='.nii.gz')) if num_niftis != num_niftis_gt: raise AssertionError("If does not seem like you trained all the folds! 
Train all folds first!") # load a summary file so that we can know what class labels to expect summary_fold0 = load_json(join(output_folder_base, "fold_0", validation_folder_name, "summary.json"))['results'][ 'mean'] classes = [int(i) for i in summary_fold0.keys()] niftis = subfiles(output_folder_raw, join=False, suffix=".nii.gz") test_pred_pairs = [(join(output_folder_gt, i), join(output_folder_raw, i)) for i in niftis] # determine_postprocessing needs a summary.json file in the folder where the raw predictions are. We could compute # that from the summary files of the five folds but I am feeling lazy today aggregate_scores(test_pred_pairs, labels=classes, json_output_file=join(output_folder_raw, "summary.json"), num_threads=default_num_threads) determine_postprocessing(output_folder_base, output_folder_gt, 'cv_niftis_raw', final_subf_name="cv_niftis_postprocessed", processes=default_num_threads, advanced_postprocessing=advanced_postprocessing) # determine_postprocessing will create a postprocessing.json file that can be used for inference if __name__ == "__main__": argparser = argparse.ArgumentParser() argparser.add_argument("-f", type=str, required=True, help="experiment output folder (fold_0, fold_1, " "etc must be subfolders of the given folder)") args = argparser.parse_args() folder = args.f consolidate_folds(folder) ================================================ FILE: unetr_pp/postprocessing/consolidate_postprocessing_simple.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse from unetr_pp.postprocessing.consolidate_postprocessing import consolidate_folds from unetr_pp.utilities.folder_names import get_output_folder_name from unetr_pp.utilities.task_name_id_conversion import convert_id_to_task_name from unetr_pp.paths import default_cascade_trainer, default_trainer, default_plans_identifier def main(): argparser = argparse.ArgumentParser(usage="Used to determine the postprocessing for a trained model. Useful for " "when the best configuration (2d, 3d_fullres etc) as selected manually.") argparser.add_argument("-m", type=str, required=True, help="U-Net model (2d, 3d_lowres, 3d_fullres or " "3d_cascade_fullres)") argparser.add_argument("-t", type=str, required=True, help="Task name or id") argparser.add_argument("-tr", type=str, required=False, default=None, help="nnFormerTrainer class. Default: %s, unless 3d_cascade_fullres " "(then it's %s)" % (default_trainer, default_cascade_trainer)) argparser.add_argument("-pl", type=str, required=False, default=default_plans_identifier, help="Plans name, Default=%s" % default_plans_identifier) argparser.add_argument("-val", type=str, required=False, default="validation_raw", help="Validation folder name. 
Default: validation_raw") args = argparser.parse_args() model = args.m task = args.t trainer = args.tr plans = args.pl val = args.val if not task.startswith("Task"): task_id = int(task) task = convert_id_to_task_name(task_id) if trainer is None: if model == "3d_cascade_fullres": trainer = "nnFormerTrainerV2CascadeFullRes" else: trainer = "nnFormerTrainerV2" folder = get_output_folder_name(model, task, trainer, plans, None) consolidate_folds(folder, val) if __name__ == "__main__": main() ================================================ FILE: unetr_pp/preprocessing/cropping.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import SimpleITK as sitk
import numpy as np
import shutil
from batchgenerators.utilities.file_and_folder_operations import *
from multiprocessing import Pool
from collections import OrderedDict

# NOTE(review): os and pickle are used below without an explicit import — presumably they come in via the
# batchgenerators star import above; verify before removing it.


def create_nonzero_mask(data):
    """Return a boolean mask that is True wherever ANY modality of `data` is nonzero, with holes filled."""
    from scipy.ndimage import binary_fill_holes
    assert len(data.shape) == 4 or len(data.shape) == 3, "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
    nonzero_mask = np.zeros(data.shape[1:], dtype=bool)
    for c in range(data.shape[0]):
        # union over channels: a voxel counts as foreground if any modality is nonzero there
        this_mask = data[c] != 0
        nonzero_mask = nonzero_mask | this_mask
    nonzero_mask = binary_fill_holes(nonzero_mask)
    return nonzero_mask


def get_bbox_from_mask(mask, outside_value=0):
    """Return the tight bounding box of all voxels != outside_value as [[zmin, zmax), [xmin, xmax), [ymin, ymax)]."""
    mask_voxel_coords = np.where(mask != outside_value)
    # +1 on the max indices makes the bounds usable directly as exclusive slice ends
    minzidx = int(np.min(mask_voxel_coords[0]))
    maxzidx = int(np.max(mask_voxel_coords[0])) + 1
    minxidx = int(np.min(mask_voxel_coords[1]))
    maxxidx = int(np.max(mask_voxel_coords[1])) + 1
    minyidx = int(np.min(mask_voxel_coords[2]))
    maxyidx = int(np.max(mask_voxel_coords[2])) + 1
    return [[minzidx, maxzidx], [minxidx, maxxidx], [minyidx, maxyidx]]


def crop_to_bbox(image, bbox):
    """Crop a single 3D `image` to the bounding box produced by get_bbox_from_mask."""
    assert len(image.shape) == 3, "only supports 3d images"
    resizer = (slice(bbox[0][0], bbox[0][1]), slice(bbox[1][0], bbox[1][1]), slice(bbox[2][0], bbox[2][1]))
    return image[resizer]


def get_case_identifier(case):
    """Derive the case identifier from the first data file of a case.

    Strips the ".nii.gz" extension and the last 5 characters of the stem
    (presumably the "_0000" modality suffix — confirm against the dataset naming convention).
    """
    case_identifier = case[0].split("/")[-1].split(".nii.gz")[0][:-5]
    return case_identifier


def get_case_identifier_from_npz(case):
    """Derive the case identifier from an .npz path by stripping the 4-character ".npz" extension."""
    case_identifier = case.split("/")[-1][:-4]
    return case_identifier


def load_case_from_list_of_files(data_files, seg_file=None):
    """
    Read all modalities (and optionally the segmentation) of one case with SimpleITK.

    :param data_files: list/tuple of nifti paths, one per modality
    :param seg_file: optional segmentation nifti path
    :return: (data as float32 array (C, Z, X, Y)-ordered per the [2, 1, 0] transposes below,
              seg as float32 array or None, properties OrderedDict with geometry metadata)
    """
    assert isinstance(data_files, list) or isinstance(data_files, tuple), "case must be either a list or a tuple"
    properties = OrderedDict()
    data_itk = [sitk.ReadImage(f) for f in data_files]

    # ITK reports (x, y, z); [2, 1, 0] flips to the (z, y, x) order used by GetArrayFromImage
    properties["original_size_of_raw_data"] = np.array(data_itk[0].GetSize())[[2, 1, 0]]
    properties["original_spacing"] = np.array(data_itk[0].GetSpacing())[[2, 1, 0]]
    properties["list_of_data_files"] = data_files
    properties["seg_file"] = seg_file

    # keep the raw ITK geometry so exported segmentations can be written back in the original frame
    properties["itk_origin"] = data_itk[0].GetOrigin()
    properties["itk_spacing"] = data_itk[0].GetSpacing()
    properties["itk_direction"] = data_itk[0].GetDirection()

    data_npy = np.vstack([sitk.GetArrayFromImage(d)[None] for d in data_itk])
    if seg_file is not None:
        seg_itk = sitk.ReadImage(seg_file)
        seg_npy = sitk.GetArrayFromImage(seg_itk)[None].astype(np.float32)
    else:
        seg_npy = None
    return data_npy.astype(np.float32), seg_npy, properties


def crop_to_nonzero(data, seg=None, nonzero_label=-1):
    """
    Crop `data` (and `seg`, if given) to the bounding box of the nonzero region.

    :param data: (C, X, Y, Z) array
    :param seg: optional (C, X, Y, Z) segmentation; background voxels outside the nonzero mask that are
        labeled 0 get relabeled to `nonzero_label`
    :param nonzero_label: this will be written into the segmentation map
    :return: (cropped data, cropped seg — synthesized from the mask if seg was None, bbox)
    """
    nonzero_mask = create_nonzero_mask(data)
    bbox = get_bbox_from_mask(nonzero_mask, 0)

    cropped_data = []
    for c in range(data.shape[0]):
        cropped = crop_to_bbox(data[c], bbox)
        cropped_data.append(cropped[None])
    data = np.vstack(cropped_data)

    if seg is not None:
        cropped_seg = []
        for c in range(seg.shape[0]):
            cropped = crop_to_bbox(seg[c], bbox)
            cropped_seg.append(cropped[None])
        seg = np.vstack(cropped_seg)

    nonzero_mask = crop_to_bbox(nonzero_mask, bbox)[None]
    if seg is not None:
        # mark "outside the body" background so normalization can distinguish it from labeled background
        seg[(seg == 0) & (nonzero_mask == 0)] = nonzero_label
    else:
        # no seg supplied: build one from the mask (0 inside the nonzero region, nonzero_label outside)
        nonzero_mask = nonzero_mask.astype(int)
        nonzero_mask[nonzero_mask == 0] = nonzero_label
        nonzero_mask[nonzero_mask > 0] = 0
        seg = nonzero_mask
    return data, seg, bbox


def get_patient_identifiers_from_cropped_files(folder):
    """List case identifiers for every .npz in `folder` (path split assumes '/' separators)."""
    return [i.split("/")[-1][:-4] for i in subfiles(folder, join=True, suffix=".npz")]


class ImageCropper(object):
    def __init__(self, num_threads, output_folder=None):
        """
        This one finds a mask of nonzero elements (must be nonzero in all modalities) and crops the image to that
        mask. In the case of BRaTS and ISLES data this results in a significant reduction in image size
        :param num_threads: number of worker processes used by run_cropping
        :param output_folder: whete to store the cropped data
        :param list_of_files:
        """
        self.output_folder = output_folder
        self.num_threads = num_threads

        if self.output_folder is not None:
            maybe_mkdir_p(self.output_folder)

    @staticmethod
    def crop(data, properties, seg=None):
        """Crop one case to its nonzero bounding box and record crop metadata in `properties` (mutated in place)."""
        shape_before = data.shape
        data, seg, bbox = crop_to_nonzero(data, seg, nonzero_label=-1)
        shape_after = data.shape
        print("before crop:", shape_before, "after crop:", shape_after, "spacing:",
              np.array(properties["original_spacing"]), "\n")

        properties["crop_bbox"] = bbox
        properties['classes'] = np.unique(seg)
        # clamp labels below -1 to background (-1 itself is the nonzero-region marker)
        seg[seg < -1] = 0
        properties["size_after_cropping"] = data[0].shape
        return data, seg, properties

    @staticmethod
    def crop_from_list_of_files(data_files, seg_file=None):
        """Load a case from nifti files and crop it; convenience wrapper around load + crop."""
        data, seg, properties = load_case_from_list_of_files(data_files, seg_file)
        return ImageCropper.crop(data, properties, seg)

    def load_crop_save(self, case, case_identifier, overwrite_existing=False):
        """Crop one case and persist it as <id>.npz (data+seg stacked) plus <id>.pkl (properties).

        Skips work if both output files already exist and overwrite_existing is False.
        Exceptions are printed with the case id for easier debugging, then re-raised.
        """
        try:
            print(case_identifier)
            if overwrite_existing \
                    or (not os.path.isfile(os.path.join(self.output_folder, "%s.npz" % case_identifier))
                        or not os.path.isfile(os.path.join(self.output_folder, "%s.pkl" % case_identifier))):

                # case[:-1] are the modality files, case[-1] is the segmentation (may be None)
                data, seg, properties = self.crop_from_list_of_files(case[:-1], case[-1])

                all_data = np.vstack((data, seg))
                np.savez_compressed(os.path.join(self.output_folder, "%s.npz" % case_identifier), data=all_data)
                with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'wb') as f:
                    pickle.dump(properties, f)
        except Exception as e:
            print("Exception in", case_identifier, ":")
            print(e)
            raise e

    def get_list_of_cropped_files(self):
        """Return full paths of all cropped .npz files in the output folder."""
        return subfiles(self.output_folder, join=True, suffix=".npz")

    def get_patient_identifiers_from_cropped_files(self):
        """Return the case identifiers of all cropped files (assumes '/' path separators)."""
        return [i.split("/")[-1][:-4] for i in self.get_list_of_cropped_files()]

    def run_cropping(self, list_of_files, overwrite_existing=False, output_folder=None):
        """
        also copied ground truth nifti segmentation into the preprocessed folder so that we can use them for
        evaluation on the cluster
        :param list_of_files: list of list of files [[PATIENTID_TIMESTEP_0000.nii.gz], [PATIENTID_TIMESTEP_0000.nii.gz]]
        :param overwrite_existing:
        :param output_folder:
        :return:
        """
        if output_folder is not None:
            self.output_folder = output_folder

        output_folder_gt = os.path.join(self.output_folder, "gt_segmentations")
        maybe_mkdir_p(output_folder_gt)
        for j, case in enumerate(list_of_files):
            if case[-1] is not None:
                shutil.copy(case[-1], output_folder_gt)

        list_of_args = []
        for j, case in enumerate(list_of_files):
            case_identifier = get_case_identifier(case)
            list_of_args.append((case, case_identifier, overwrite_existing))

        # crop all cases in parallel
        p = Pool(self.num_threads)
        p.starmap(self.load_crop_save, list_of_args)
        p.close()
        p.join()

    def load_properties(self, case_identifier):
        """Load the pickled properties dict for one cropped case."""
        with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'rb') as f:
            properties = pickle.load(f)
        return properties

    def save_properties(self, case_identifier, properties):
        """Overwrite the pickled properties dict for one cropped case."""
        with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'wb') as f:
            pickle.dump(properties, f)


================================================
FILE: unetr_pp/preprocessing/custom_preprocessors/preprocessor_scale_RGB_to_0_1.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from unetr_pp.preprocessing.preprocessing import PreprocessorFor2D, resample_patient


class GenericPreprocessor_scale_uint8_to_0_1(PreprocessorFor2D):
    """
    For RGB images with a value range of [0, 255]. This preprocessor overwrites the default normalization scheme by
    normalizing intensity values through a simple division by 255 which rescales them to [0, 1]

    NOTE THAT THIS INHERITS FROM PreprocessorFor2D, SO ITS WRITTEN FOR 2D ONLY! WHEN CREATING A PREPROCESSOR FOR 3D
    DATA, USE GenericPreprocessor AS PARENT!
    """

    def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
        """Resample exactly like the parent class, then normalize every channel by dividing by 255.

        :param data: transposed (C, ...) image array
        :param target_spacing: desired spacing; index 0 is overwritten with the original z spacing below
        :param properties: case metadata dict (mutated in place with post-resampling size/spacing)
        :param seg: optional segmentation, resampled alongside the data
        :param force_separate_z: forwarded to resample_patient
        :return: (data, seg, properties)
        """
        ############ THIS PART IS IDENTICAL TO PARENT CLASS ################
        original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
        before = {
            'spacing': properties["original_spacing"],
            'spacing_transposed': original_spacing_transposed,
            'data.shape (data is transposed)': data.shape
        }
        # keep the out-of-plane spacing unchanged — presumably because this is a 2D preprocessor; confirm
        target_spacing[0] = original_spacing_transposed[0]
        data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1,
                                     force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0,
                                     separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
        after = {
            'spacing': target_spacing,
            'data.shape (data is resampled)': data.shape
        }
        print("before:", before, "\nafter: ", after, "\n")

        if seg is not None:  # hippocampus 243 has one voxel with -2 as label. wtf?
            seg[seg < -1] = 0

        properties["size_after_resampling"] = data[0].shape
        properties["spacing_after_resampling"] = target_spacing
        use_nonzero_mask = self.use_nonzero_mask

        assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
                                                                         "must have as many entries as data has " \
                                                                         "modalities"
        assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
                                                        " has modalities"

        print("normalization...")

        ############ HERE IS WHERE WE START CHANGING THINGS!!!!!!!################
        # this is where the normalization takes place. We ignore use_nonzero_mask and normalization_scheme_per_modality
        for c in range(len(data)):
            data[c] = data[c].astype(np.float32) / 255.
        return data, seg, properties


================================================
FILE: unetr_pp/preprocessing/preprocessing.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from copy import deepcopy

from batchgenerators.augmentations.utils import resize_segmentation
from unetr_pp.configuration import default_num_threads, RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD
from unetr_pp.preprocessing.cropping import get_case_identifier_from_npz, ImageCropper
from skimage.transform import resize
from scipy.ndimage.interpolation import map_coordinates
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from multiprocessing.pool import Pool


def get_do_separate_z(spacing, anisotropy_threshold=RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD):
    """Return True if the spacing is anisotropic enough (max/min ratio above threshold) to warrant
    resampling the low-resolution axis separately."""
    do_separate_z = (np.max(spacing) / np.min(spacing)) > anisotropy_threshold
    return do_separate_z


def get_lowres_axis(new_spacing):
    """Return the index/indices of the axis with the largest spacing (lowest resolution)."""
    axis = np.where(max(new_spacing) / np.array(new_spacing) == 1)[0]  # find which axis is anisotropic
    return axis


def resample_patient(data, seg, original_spacing, target_spacing, order_data=3, order_seg=0, force_separate_z=False,
                     cval_data=0, cval_seg=-1, order_z_data=0, order_z_seg=0,
                     separate_z_anisotropy_threshold=RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD):
    """
    :param cval_seg:
    :param cval_data:
    :param data:
    :param seg:
    :param original_spacing:
    :param target_spacing:
    :param order_data:
    :param order_seg:
    :param force_separate_z: if None then we dynamically decide how to resample along z, if True/False then always
    /never resample along z separately
    :param order_z_seg: only applies if do_separate_z is True
    :param order_z_data: only applies if do_separate_z is True
    :param separate_z_anisotropy_threshold: if max_spacing > separate_z_anisotropy_threshold * min_spacing (per axis)
    then resample along lowres axis with order_z_data/order_z_seg instead of order_data/order_seg
    :return:
    """
    assert not ((data is None) and (seg is None))
    if data is not None:
        assert len(data.shape) == 4, "data must be c x y z"
    if seg is not None:
        assert len(seg.shape) == 4, "seg must be c x y z"

    if data is not None:
        shape = np.array(data[0].shape)
    else:
        shape = np.array(seg[0].shape)
    # new voxel count per axis follows from the spacing ratio (keeps physical extent constant)
    new_shape = np.round(((np.array(original_spacing) / np.array(target_spacing)).astype(float) * shape)).astype(int)

    if force_separate_z is not None:
        do_separate_z = force_separate_z
        if force_separate_z:
            axis = get_lowres_axis(original_spacing)
        else:
            axis = None
    else:
        # decide dynamically: separate-z if either the source or target spacing is strongly anisotropic
        if get_do_separate_z(original_spacing, separate_z_anisotropy_threshold):
            do_separate_z = True
            axis = get_lowres_axis(original_spacing)
        elif get_do_separate_z(target_spacing, separate_z_anisotropy_threshold):
            do_separate_z = True
            axis = get_lowres_axis(target_spacing)
        else:
            do_separate_z = False
            axis = None

    if axis is not None:
        if len(axis) == 3:
            # every axis has the spacing, this should never happen, why is this code here?
            do_separate_z = False
        elif len(axis) == 2:
            # this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample
            # separately in the out of plane axis
            do_separate_z = False
        else:
            pass

    if data is not None:
        data_reshaped = resample_data_or_seg(data, new_shape, False, axis, order_data, do_separate_z,
                                             cval=cval_data, order_z=order_z_data)
    else:
        data_reshaped = None
    if seg is not None:
        seg_reshaped = resample_data_or_seg(seg, new_shape, True, axis, order_seg, do_separate_z,
                                            cval=cval_seg, order_z=order_z_seg)
    else:
        seg_reshaped = None
    return data_reshaped, seg_reshaped


def resample_data_or_seg(data, new_shape, is_seg, axis=None, order=3, do_separate_z=False, cval=0, order_z=0):
    """
    separate_z=True will resample with order 0 along z
    :param data:
    :param new_shape:
    :param is_seg:
    :param axis:
    :param order:
    :param do_separate_z:
    :param cval:
    :param order_z: only applies if do_separate_z is True
    :return:
    """
    assert len(data.shape) == 4, "data must be (c, x, y, z)"
    if is_seg:
        # segmentations need label-preserving resizing, not plain interpolation
        resize_fn = resize_segmentation
        kwargs = OrderedDict()
    else:
        resize_fn = resize
        kwargs = {'mode': 'edge', 'anti_aliasing': False}
    dtype_data = data.dtype
    data = data.astype(float)
    shape = np.array(data[0].shape)
    new_shape = np.array(new_shape)
    if np.any(shape != new_shape):
        if do_separate_z:
            # resample each in-plane slice with `order`, then the anisotropic axis with `order_z`
            print("separate z, order in z is", order_z, "order inplane is", order)
            assert len(axis) == 1, "only one anisotropic axis supported"
            axis = axis[0]
            if axis == 0:
                new_shape_2d = new_shape[1:]
            elif axis == 1:
                new_shape_2d = new_shape[[0, 2]]
            else:
                new_shape_2d = new_shape[:-1]

            reshaped_final_data = []
            for c in range(data.shape[0]):
                reshaped_data = []
                for slice_id in range(shape[axis]):
                    if axis == 0:
                        reshaped_data.append(resize_fn(data[c, slice_id], new_shape_2d, order, cval=cval, **kwargs))
                    elif axis == 1:
                        reshaped_data.append(resize_fn(data[c, :, slice_id], new_shape_2d, order, cval=cval, **kwargs))
                    else:
                        reshaped_data.append(resize_fn(data[c, :, :, slice_id], new_shape_2d, order, cval=cval,
                                                       **kwargs))
                reshaped_data = np.stack(reshaped_data, axis)
                if shape[axis] != new_shape[axis]:

                    # The following few lines are blatantly copied and modified from sklearn's resize()
                    rows, cols, dim = new_shape[0], new_shape[1], new_shape[2]
                    orig_rows, orig_cols, orig_dim = reshaped_data.shape

                    row_scale = float(orig_rows) / rows
                    col_scale = float(orig_cols) / cols
                    dim_scale = float(orig_dim) / dim

                    # half-voxel shift maps output voxel centers onto input voxel centers
                    map_rows, map_cols, map_dims = np.mgrid[:rows, :cols, :dim]
                    map_rows = row_scale * (map_rows + 0.5) - 0.5
                    map_cols = col_scale * (map_cols + 0.5) - 0.5
                    map_dims = dim_scale * (map_dims + 0.5) - 0.5

                    coord_map = np.array([map_rows, map_cols, map_dims])
                    if not is_seg or order_z == 0:
                        reshaped_final_data.append(map_coordinates(reshaped_data, coord_map, order=order_z, cval=cval,
                                                                   mode='nearest')[None])
                    else:
                        # interpolating labels directly would blend them; interpolate a binary mask per label
                        # and round instead
                        unique_labels = np.unique(reshaped_data)
                        reshaped = np.zeros(new_shape, dtype=dtype_data)

                        for i, cl in enumerate(unique_labels):
                            reshaped_multihot = np.round(
                                map_coordinates((reshaped_data == cl).astype(float), coord_map, order=order_z,
                                                cval=cval, mode='nearest'))
                            reshaped[reshaped_multihot > 0.5] = cl
                        reshaped_final_data.append(reshaped[None])
                else:
                    reshaped_final_data.append(reshaped_data[None])
            reshaped_final_data = np.vstack(reshaped_final_data)
        else:
            print("no separate z, order", order)
            reshaped = []
            for c in range(data.shape[0]):
                reshaped.append(resize_fn(data[c], new_shape, order, cval=cval, **kwargs)[None])
            reshaped_final_data = np.vstack(reshaped)
        return reshaped_final_data.astype(dtype_data)
    else:
        print("no resampling necessary")
        return data


class GenericPreprocessor(object):
    def __init__(self, normalization_scheme_per_modality, use_nonzero_mask, transpose_forward: (tuple, list),
                 intensityproperties=None):
        """
        :param normalization_scheme_per_modality: dict {0:'nonCT'}
        :param use_nonzero_mask: {0:False}
        :param intensityproperties:
        """
        self.transpose_forward = transpose_forward
        self.intensityproperties = intensityproperties
        self.normalization_scheme_per_modality = normalization_scheme_per_modality
        self.use_nonzero_mask = use_nonzero_mask

        self.resample_separate_z_anisotropy_threshold = RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD

    @staticmethod
    def load_cropped(cropped_output_dir, case_identifier):
        """Load a cropped case: last channel of the stored array is the segmentation, the rest is image data."""
        all_data = np.load(os.path.join(cropped_output_dir, "%s.npz" % case_identifier))['data']
        data = all_data[:-1].astype(np.float32)
        seg = all_data[-1:]
        with open(os.path.join(cropped_output_dir, "%s.pkl" % case_identifier), 'rb') as f:
            properties = pickle.load(f)
        return data, seg, properties

    def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
        """
        data and seg must already have been transposed by transpose_forward. properties are the un-transposed values
        (spacing etc)
        :param data:
        :param target_spacing:
        :param properties:
        :param seg:
        :param force_separate_z:
        :return:
        """
        # target_spacing is already transposed, properties["original_spacing"] is not so we need to transpose it!
        # data, seg are already transposed. Double check this using the properties
        original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
        before = {
            'spacing': properties["original_spacing"],
            'spacing_transposed': original_spacing_transposed,
            'data.shape (data is transposed)': data.shape
        }

        # remove nans
        data[np.isnan(data)] = 0

        data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1,
                                     force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0,
                                     separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
        after = {
            'spacing': target_spacing,
            'data.shape (data is resampled)': data.shape
        }
        print("before:", before, "\nafter: ", after, "\n")

        if seg is not None:  # hippocampus 243 has one voxel with -2 as label. wtf?
            seg[seg < -1] = 0

        properties["size_after_resampling"] = data[0].shape
        properties["spacing_after_resampling"] = target_spacing
        use_nonzero_mask = self.use_nonzero_mask

        assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
                                                                         "must have as many entries as data has " \
                                                                         "modalities"
        assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
                                                        " has modalities"

        for c in range(len(data)):
            scheme = self.normalization_scheme_per_modality[c]
            if scheme == "CT":
                # clip to lb and ub from train data foreground and use foreground mn and sd from training data
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                mean_intensity = self.intensityproperties[c]['mean']
                std_intensity = self.intensityproperties[c]['sd']
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                data[c] = (data[c] - mean_intensity) / std_intensity
                if use_nonzero_mask[c]:
                    # zero out everything outside the nonzero mask (marked -1 in the last seg channel)
                    data[c][seg[-1] < 0] = 0
            elif scheme == "CT2":
                # clip to lb and ub from train data foreground, use mn and sd form each case for normalization
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                mask = (data[c] > lower_bound) & (data[c] < upper_bound)
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                mn = data[c][mask].mean()
                sd = data[c][mask].std()
                data[c] = (data[c] - mn) / sd
                if use_nonzero_mask[c]:
                    data[c][seg[-1] < 0] = 0
            else:
                # per-case z-score normalization, optionally restricted to the nonzero mask
                if use_nonzero_mask[c]:
                    mask = seg[-1] >= 0
                else:
                    mask = np.ones(seg.shape[1:], dtype=bool)
                data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8)
                data[c][mask == 0] = 0
        return data, seg, properties

    def preprocess_test_case(self, data_files, target_spacing, seg_file=None, force_separate_z=None):
        """Crop, transpose, resample and normalize a single (test-time) case given its nifti files."""
        data, seg, properties = ImageCropper.crop_from_list_of_files(data_files, seg_file)

        data = data.transpose((0, *[i + 1 for i in self.transpose_forward]))
        seg = seg.transpose((0, *[i + 1 for i in self.transpose_forward]))

        data, seg, properties = self.resample_and_normalize(data, target_spacing, properties, seg,
                                                            force_separate_z=force_separate_z)
        return data.astype(np.float32), seg, properties

    def _run_internal(self, target_spacing, case_identifier, output_folder_stage, cropped_output_dir,
                      force_separate_z, all_classes):
        """Worker: preprocess one cropped case, sample foreground voxel locations per class, and save npz+pkl."""
        data, seg, properties = self.load_cropped(cropped_output_dir, case_identifier)

        data = data.transpose((0, *[i + 1 for i in self.transpose_forward]))
        seg = seg.transpose((0, *[i + 1 for i in self.transpose_forward]))

        data, seg, properties = self.resample_and_normalize(data, target_spacing, properties, seg, force_separate_z)

        all_data = np.vstack((data, seg)).astype(np.float32)

        # we need to find out where the classes are and sample some random locations
        # let's do 10.000 samples per class
        # seed this for reproducibility!
        num_samples = 10000
        min_percent_coverage = 0.01  # at least 1% of the class voxels need to be selected, otherwise it may be too sparse
        rndst = np.random.RandomState(1234)
        class_locs = {}
        for c in all_classes:
            all_locs = np.argwhere(all_data[-1] == c)
            if len(all_locs) == 0:
                class_locs[c] = []
                continue
            target_num_samples = min(num_samples, len(all_locs))
            target_num_samples = max(target_num_samples, int(np.ceil(len(all_locs) * min_percent_coverage)))

            selected = all_locs[rndst.choice(len(all_locs), target_num_samples, replace=False)]
            class_locs[c] = selected
            print(c, target_num_samples)
        # used by the trainer's dataloader to oversample foreground patches
        properties['class_locations'] = class_locs

        print("saving: ", os.path.join(output_folder_stage, "%s.npz" % case_identifier))
        np.savez_compressed(os.path.join(output_folder_stage, "%s.npz" % case_identifier),
                            data=all_data.astype(np.float32))
        with open(os.path.join(output_folder_stage, "%s.pkl" % case_identifier), 'wb') as f:
            pickle.dump(properties, f)

    def run(self, target_spacings, input_folder_with_cropped_npz, output_folder, data_identifier,
            num_threads=default_num_threads, force_separate_z=None):
        """

        :param target_spacings: list of lists [[1.25, 1.25, 5]]
        :param input_folder_with_cropped_npz: dim: c, x, y, z | npz_file['data'] np.savez_compressed(fname.npz, data=arr)
        :param output_folder:
        :param num_threads:
        :param force_separate_z: None
        :return:
        """
        print("Initializing to run preprocessing")
        print("npz folder:", input_folder_with_cropped_npz)
        print("output_folder:", output_folder)
        list_of_cropped_npz_files = subfiles(input_folder_with_cropped_npz, True, None, ".npz", True)
        maybe_mkdir_p(output_folder)
        num_stages = len(target_spacings)
        # one thread count per stage; a scalar is broadcast to all stages
        if not isinstance(num_threads, (list, tuple, np.ndarray)):
            num_threads = [num_threads] * num_stages
        assert len(num_threads) == num_stages

        # we need to know which classes are present in this dataset so that we can precompute where these classes are
        # located. This is needed for oversampling foreground
        all_classes = load_pickle(join(input_folder_with_cropped_npz, 'dataset_properties.pkl'))['all_classes']

        for i in range(num_stages):
            all_args = []
            output_folder_stage = os.path.join(output_folder, data_identifier + "_stage%d" % i)
            maybe_mkdir_p(output_folder_stage)
            spacing = target_spacings[i]
            for j, case in enumerate(list_of_cropped_npz_files):
                case_identifier = get_case_identifier_from_npz(case)
                args = spacing, case_identifier, output_folder_stage, input_folder_with_cropped_npz, force_separate_z, all_classes
                all_args.append(args)
            p = Pool(num_threads[i])
            p.starmap(self._run_internal, all_args)
            p.close()
            p.join()


class Preprocessor3DDifferentResampling(GenericPreprocessor):
    def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
        """
        data and seg must already have been transposed by transpose_forward. properties are the un-transposed values
        (spacing etc)
        :param data:
        :param target_spacing:
        :param properties:
        :param seg:
        :param force_separate_z:
        :return:
        """
        # target_spacing is already transposed, properties["original_spacing"] is not so we need to transpose it!
        # data, seg are already transposed. Double check this using the properties
        original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
        before = {
            'spacing': properties["original_spacing"],
            'spacing_transposed': original_spacing_transposed,
            'data.shape (data is transposed)': data.shape
        }

        # remove nans
        data[np.isnan(data)] = 0

        data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1,
                                     force_separate_z=force_separate_z, order_z_data=3, order_z_seg=1,
                                     separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
        after = {
            'spacing': target_spacing,
            'data.shape (data is resampled)': data.shape
        }
        print("before:", before, "\nafter: ", after, "\n")

        if seg is not None:  # hippocampus 243 has one voxel with -2 as label. wtf?
seg[seg < -1] = 0 properties["size_after_resampling"] = data[0].shape properties["spacing_after_resampling"] = target_spacing use_nonzero_mask = self.use_nonzero_mask assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \ "must have as many entries as data has " \ "modalities" assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \ " has modalities" for c in range(len(data)): scheme = self.normalization_scheme_per_modality[c] if scheme == "CT": # clip to lb and ub from train data foreground and use foreground mn and sd from training data assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties" mean_intensity = self.intensityproperties[c]['mean'] std_intensity = self.intensityproperties[c]['sd'] lower_bound = self.intensityproperties[c]['percentile_00_5'] upper_bound = self.intensityproperties[c]['percentile_99_5'] data[c] = np.clip(data[c], lower_bound, upper_bound) data[c] = (data[c] - mean_intensity) / std_intensity if use_nonzero_mask[c]: data[c][seg[-1] < 0] = 0 elif scheme == "CT2": # clip to lb and ub from train data foreground, use mn and sd form each case for normalization assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties" lower_bound = self.intensityproperties[c]['percentile_00_5'] upper_bound = self.intensityproperties[c]['percentile_99_5'] mask = (data[c] > lower_bound) & (data[c] < upper_bound) data[c] = np.clip(data[c], lower_bound, upper_bound) mn = data[c][mask].mean() sd = data[c][mask].std() data[c] = (data[c] - mn) / sd if use_nonzero_mask[c]: data[c][seg[-1] < 0] = 0 else: if use_nonzero_mask[c]: mask = seg[-1] >= 0 else: mask = np.ones(seg.shape[1:], dtype=bool) data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8) data[c][mask == 0] = 0 return data, seg, properties class 
Preprocessor3DBetterResampling(GenericPreprocessor):
    """
    This preprocessor always uses force_separate_z=False. It does resampling to the target spacing with third
    order spline for data (just like GenericPreprocessor) and seg (unlike GenericPreprocessor). It never does separate
    resampling in z.
    """

    def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=False):
        """
        data and seg must already have been transposed by transpose_forward. properties are the un-transposed values
        (spacing etc)
        :param data:
        :param target_spacing:
        :param properties:
        :param seg:
        :param force_separate_z: ignored; always forced to False (a warning is printed if anything else is passed)
        :return:
        """
        if force_separate_z is not False:
            print("WARNING: Preprocessor3DBetterResampling always uses force_separate_z=False. "
                  "You specified %s. Your choice is overwritten" % str(force_separate_z))
            force_separate_z = False

        # be safe
        assert force_separate_z is False

        # target_spacing is already transposed, properties["original_spacing"] is not so we need to transpose it!
        # data, seg are already transposed. Double check this using the properties
        original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
        before = {
            'spacing': properties["original_spacing"],
            'spacing_transposed': original_spacing_transposed,
            'data.shape (data is transposed)': data.shape
        }

        # remove nans
        data[np.isnan(data)] = 0

        # third-order spline for BOTH data and seg; order_z_* = 99999 is a sentinel (separate z resampling
        # is never taken because force_separate_z is False)
        data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 3,
                                     force_separate_z=force_separate_z, order_z_data=99999, order_z_seg=99999,
                                     separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
        after = {
            'spacing': target_spacing,
            'data.shape (data is resampled)': data.shape
        }
        print("before:", before, "\nafter: ", after, "\n")

        if seg is not None:  # hippocampus 243 has one voxel with -2 as label. wtf?
            seg[seg < -1] = 0

        properties["size_after_resampling"] = data[0].shape
        properties["spacing_after_resampling"] = target_spacing
        use_nonzero_mask = self.use_nonzero_mask

        assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
                                                                         "must have as many entries as data has " \
                                                                         "modalities"
        assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
                                                        " has modalities"

        # per-modality intensity normalization, same as GenericPreprocessor
        for c in range(len(data)):
            scheme = self.normalization_scheme_per_modality[c]
            if scheme == "CT":
                # clip to lb and ub from train data foreground and use foreground mn and sd from training data
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                mean_intensity = self.intensityproperties[c]['mean']
                std_intensity = self.intensityproperties[c]['sd']
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                data[c] = (data[c] - mean_intensity) / std_intensity
                if use_nonzero_mask[c]:
                    data[c][seg[-1] < 0] = 0
            elif scheme == "CT2":
                # clip to lb and ub from train data foreground, use mn and sd form each case for normalization
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                mask = (data[c] > lower_bound) & (data[c] < upper_bound)
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                mn = data[c][mask].mean()
                sd = data[c][mask].std()
                data[c] = (data[c] - mn) / sd
                if use_nonzero_mask[c]:
                    data[c][seg[-1] < 0] = 0
            else:
                if use_nonzero_mask[c]:
                    mask = seg[-1] >= 0
                else:
                    mask = np.ones(seg.shape[1:], dtype=bool)
                data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8)
                data[c][mask == 0] = 0
        return data, seg, properties


class
PreprocessorFor2D(GenericPreprocessor):
    # 2D variant: keeps the original through-plane (first-axis) spacing so cases are only resampled in-plane.
    def __init__(self, normalization_scheme_per_modality, use_nonzero_mask, transpose_forward: (tuple, list),
                 intensityproperties=None):
        super(PreprocessorFor2D, self).__init__(normalization_scheme_per_modality, use_nonzero_mask,
                                                transpose_forward, intensityproperties)

    def run(self, target_spacings, input_folder_with_cropped_npz, output_folder, data_identifier,
            num_threads=default_num_threads, force_separate_z=None):
        """Preprocess all cropped cases for all stages.

        Unlike GenericPreprocessor.run, args for ALL stages are collected first and processed by a single Pool.
        """
        print("Initializing to run preprocessing")
        print("npz folder:", input_folder_with_cropped_npz)
        print("output_folder:", output_folder)
        list_of_cropped_npz_files = subfiles(input_folder_with_cropped_npz, True, None, ".npz", True)
        assert len(list_of_cropped_npz_files) != 0, "set list of files first"
        maybe_mkdir_p(output_folder)
        all_args = []
        num_stages = len(target_spacings)

        # we need to know which classes are present in this dataset so that we can precompute where these classes are
        # located. This is needed for oversampling foreground
        all_classes = load_pickle(join(input_folder_with_cropped_npz, 'dataset_properties.pkl'))['all_classes']

        for i in range(num_stages):
            output_folder_stage = os.path.join(output_folder, data_identifier + "_stage%d" % i)
            maybe_mkdir_p(output_folder_stage)
            spacing = target_spacings[i]
            for j, case in enumerate(list_of_cropped_npz_files):
                case_identifier = get_case_identifier_from_npz(case)
                args = spacing, case_identifier, output_folder_stage, input_folder_with_cropped_npz, force_separate_z, all_classes
                all_args.append(args)
        p = Pool(num_threads)
        p.starmap(self._run_internal, all_args)
        p.close()
        p.join()

    def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
        """Resample in-plane only (z spacing is forced to the original) and normalize per modality."""
        original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
        before = {
            'spacing': properties["original_spacing"],
            'spacing_transposed': original_spacing_transposed,
            'data.shape (data is transposed)': data.shape
        }
        # keep the original spacing along the first axis -> no resampling across slices
        # NOTE: this mutates the target_spacing list passed in by the caller
        target_spacing[0] = original_spacing_transposed[0]
        data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1,
                                     force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0,
                                     separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
        after = {
            'spacing': target_spacing,
            'data.shape (data is resampled)': data.shape
        }
        print("before:", before, "\nafter: ", after, "\n")

        if seg is not None:  # hippocampus 243 has one voxel with -2 as label. wtf?
            seg[seg < -1] = 0

        properties["size_after_resampling"] = data[0].shape
        properties["spacing_after_resampling"] = target_spacing
        use_nonzero_mask = self.use_nonzero_mask

        assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
                                                                         "must have as many entries as data has " \
                                                                         "modalities"
        assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
                                                        " has modalities"
        print("normalization...")
        for c in range(len(data)):
            scheme = self.normalization_scheme_per_modality[c]
            if scheme == "CT":
                # clip to lb and ub from train data foreground and use foreground mn and sd from training data
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                mean_intensity = self.intensityproperties[c]['mean']
                std_intensity = self.intensityproperties[c]['sd']
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                data[c] = (data[c] - mean_intensity) / std_intensity
                if use_nonzero_mask[c]:
                    data[c][seg[-1] < 0] = 0
            elif scheme == "CT2":
                # clip to lb and ub from train data foreground, use mn and sd form each case for normalization
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                mask = (data[c] > lower_bound) & (data[c] < upper_bound)
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                mn = data[c][mask].mean()
                sd = data[c][mask].std()
                data[c] = (data[c] - mn) / sd
                if use_nonzero_mask[c]:
                    data[c][seg[-1] < 0] = 0
            else:
                if use_nonzero_mask[c]:
                    mask = seg[-1] >= 0
                else:
                    mask = np.ones(seg.shape[1:], dtype=bool)
                data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8)
                data[c][mask == 0] = 0
        print("normalization done")
        return data, seg, properties


class PreprocessorFor3D_NoResampling(GenericPreprocessor):
    # Variant that never changes the spacing: target_spacing is overwritten with the original spacing.
    def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
        """
        if target_spacing[0] is None or nan we use original_spacing_transposed[0] (no resampling along z)
        :param data:
        :param target_spacing: ignored; replaced by the (transposed) original spacing
        :param properties:
        :param seg:
        :param force_separate_z:
        :return:
        """
        original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
        before = {
            'spacing': properties["original_spacing"],
            'spacing_transposed': original_spacing_transposed,
            'data.shape (data is transposed)': data.shape
        }

        # remove nans
        data[np.isnan(data)] = 0

        # no resampling: target spacing == source spacing
        target_spacing = deepcopy(original_spacing_transposed)
        #print(target_spacing, original_spacing_transposed)
        data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1,
                                     force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0,
                                     separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
        after = {
            'spacing': target_spacing,
            'data.shape (data is resampled)': data.shape
        }
        st = "before:" + str(before) + '\nafter' + str(after) + "\n"
        print(st)

        if seg is not None:  # hippocampus 243 has one voxel with -2 as label. wtf?
            seg[seg < -1] = 0

        properties["size_after_resampling"] = data[0].shape
        properties["spacing_after_resampling"] = target_spacing
        use_nonzero_mask = self.use_nonzero_mask

        assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
                                                                         "must have as many entries as data has " \
                                                                         "modalities"
        assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
                                                        " has modalities"

        for c in range(len(data)):
            scheme = self.normalization_scheme_per_modality[c]
            if scheme == "CT":
                # clip to lb and ub from train data foreground and use foreground mn and sd from training data
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                mean_intensity = self.intensityproperties[c]['mean']
                std_intensity = self.intensityproperties[c]['sd']
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                data[c] = (data[c] - mean_intensity) / std_intensity
                if use_nonzero_mask[c]:
                    data[c][seg[-1] < 0] = 0
            elif scheme == "CT2":
                # clip to lb and ub from train data foreground, use mn and sd form each case for normalization
                assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
                lower_bound = self.intensityproperties[c]['percentile_00_5']
                upper_bound = self.intensityproperties[c]['percentile_99_5']
                mask = (data[c] > lower_bound) & (data[c] < upper_bound)
                data[c] = np.clip(data[c], lower_bound, upper_bound)
                mn = data[c][mask].mean()
                sd = data[c][mask].std()
                data[c] = (data[c] - mn) / sd
                if use_nonzero_mask[c]:
                    data[c][seg[-1] < 0] = 0
            else:
                if use_nonzero_mask[c]:
                    mask = seg[-1] >= 0
                else:
                    mask = np.ones(seg.shape[1:], dtype=bool)
                data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8)
                data[c][mask == 0] = 0
        return data, seg, properties


class
PreprocessorFor2D_noNormalization(GenericPreprocessor): def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None): original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward] before = { 'spacing': properties["original_spacing"], 'spacing_transposed': original_spacing_transposed, 'data.shape (data is transposed)': data.shape } target_spacing[0] = original_spacing_transposed[0] data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1, force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0, separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold) after = { 'spacing': target_spacing, 'data.shape (data is resampled)': data.shape } print("before:", before, "\nafter: ", after, "\n") if seg is not None: # hippocampus 243 has one voxel with -2 as label. wtf? seg[seg < -1] = 0 properties["size_after_resampling"] = data[0].shape properties["spacing_after_resampling"] = target_spacing use_nonzero_mask = self.use_nonzero_mask assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \ "must have as many entries as data has " \ "modalities" assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \ " has modalities" return data, seg, properties ================================================ FILE: unetr_pp/preprocessing/sanity_checks.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from multiprocessing import Pool import SimpleITK as sitk import nibabel as nib import numpy as np from batchgenerators.utilities.file_and_folder_operations import * from unetr_pp.configuration import default_num_threads def verify_all_same_orientation(folder): """ This should run after cropping :param folder: :return: """ nii_files = subfiles(folder, suffix=".nii.gz", join=True) orientations = [] for n in nii_files: img = nib.load(n) affine = img.affine orientation = nib.aff2axcodes(affine) orientations.append(orientation) # now we need to check whether they are all the same orientations = np.array(orientations) unique_orientations = np.unique(orientations, axis=0) all_same = len(unique_orientations) == 1 return all_same, unique_orientations def verify_same_geometry(img_1: sitk.Image, img_2: sitk.Image): ori1, spacing1, direction1, size1 = img_1.GetOrigin(), img_1.GetSpacing(), img_1.GetDirection(), img_1.GetSize() ori2, spacing2, direction2, size2 = img_2.GetOrigin(), img_2.GetSpacing(), img_2.GetDirection(), img_2.GetSize() same_ori = np.all(np.isclose(ori1, ori2)) if not same_ori: print("the origin does not match between the images:") print(ori1) print(ori2) same_spac = np.all(np.isclose(spacing1, spacing2)) if not same_spac: print("the spacing does not match between the images") print(spacing1) print(spacing2) same_dir = np.all(np.isclose(direction1, direction2)) if not same_dir: print("the direction does not match between the images") print(direction1) print(direction2) same_size = np.all(np.isclose(size1, size2)) if not same_size: print("the size does not 
match between the images") print(size1) print(size2) if same_ori and same_spac and same_dir and same_size: return True else: return False def verify_contains_only_expected_labels(itk_img: str, valid_labels: (tuple, list)): img_npy = sitk.GetArrayFromImage(sitk.ReadImage(itk_img)) uniques = np.unique(img_npy) invalid_uniques = [i for i in uniques if i not in valid_labels] if len(invalid_uniques) == 0: r = True else: r = False return r, invalid_uniques def verify_dataset_integrity(folder): """ folder needs the imagesTr, imagesTs and labelsTr subfolders. There also needs to be a dataset.json checks if all training cases and labels are present checks if all test cases (if any) are present for each case, checks whether all modalities apre present for each case, checks whether the pixel grids are aligned checks whether the labels really only contain values they should :param folder: :return: """ assert isfile(join(folder, "dataset.json")), "There needs to be a dataset.json file in folder, folder=%s" % folder assert isdir(join(folder, "imagesTr")), "There needs to be a imagesTr subfolder in folder, folder=%s" % folder assert isdir(join(folder, "labelsTr")), "There needs to be a labelsTr subfolder in folder, folder=%s" % folder dataset = load_json(join(folder, "dataset.json")) training_cases = dataset['training'] num_modalities = len(dataset['modality'].keys()) test_cases = dataset['test'] expected_train_identifiers = [i['image'].split("/")[-1][:-7] for i in training_cases] #image00? 
expected_test_identifiers = [i.split("/")[-1][:-7] for i in test_cases] ## check training set nii_files_in_imagesTr = subfiles((join(folder, "imagesTr")), suffix=".nii.gz", join=False) nii_files_in_labelsTr = subfiles((join(folder, "labelsTr")), suffix=".nii.gz", join=False) label_files = [] geometries_OK = True has_nan = False # check all cases #imageTr dont have duplicate image if len(expected_train_identifiers) != len(np.unique(expected_train_identifiers)): raise RuntimeError("found duplicate training cases in dataset.json") print("Verifying training set") for c in expected_train_identifiers: print("checking case", c) # check if all files are present expected_label_file = join(folder, "labelsTr", c + ".nii.gz") label_files.append(expected_label_file) expected_image_files = [join(folder, "imagesTr", c + "_%04.0d.nii.gz" % i) for i in range(num_modalities)] assert isfile(expected_label_file), "could not find label file for case %s. Expected file: \n%s" % ( c, expected_label_file) assert all([isfile(i) for i in expected_image_files]), "some image files are missing for case %s. Expected files:\n %s" % ( c, expected_image_files) # verify that all modalities and the label have the same shape and geometry. label_itk = sitk.ReadImage(expected_label_file) nans_in_seg = np.any(np.isnan(sitk.GetArrayFromImage(label_itk))) has_nan = has_nan | nans_in_seg if nans_in_seg: print("There are NAN values in segmentation %s" % expected_label_file) images_itk = [sitk.ReadImage(i) for i in expected_image_files] for i, img in enumerate(images_itk): nans_in_image = np.any(np.isnan(sitk.GetArrayFromImage(img))) has_nan = has_nan | nans_in_image same_geometry = verify_same_geometry(img, label_itk) if not same_geometry: geometries_OK = False print("The geometry of the image %s does not match the geometry of the label file. The pixel arrays " "will not be aligned and nnU-Net cannot use this data. 
Please make sure your image modalities " "are coregistered and have the same geometry as the label" % expected_image_files[0][:-12]) if nans_in_image: print("There are NAN values in image %s" % expected_image_files[i]) # now remove checked files from the lists nii_files_in_imagesTr and nii_files_in_labelsTr for i in expected_image_files: nii_files_in_imagesTr.remove(os.path.basename(i)) nii_files_in_labelsTr.remove(os.path.basename(expected_label_file)) # check for stragglers assert len( nii_files_in_imagesTr) == 0, "there are training cases in imagesTr that are not listed in dataset.json: %s" % nii_files_in_imagesTr assert len( nii_files_in_labelsTr) == 0, "there are training cases in labelsTr that are not listed in dataset.json: %s" % nii_files_in_labelsTr # verify that only properly declared values are present in the labels print("Verifying label values") expected_labels = list(int(i) for i in dataset['labels'].keys()) # check if labels are in consecutive order assert expected_labels[0] == 0, 'The first label must be 0 and maps to the background' labels_valid_consecutive = np.ediff1d(expected_labels) == 1 assert all(labels_valid_consecutive), f'Labels must be in consecutive order (0, 1, 2, ...). The labels {np.array(expected_labels)[1:][~labels_valid_consecutive]} do not satisfy this restriction' p = Pool(default_num_threads) results = p.starmap(verify_contains_only_expected_labels, zip(label_files, [expected_labels] * len(label_files))) p.close() p.join() fail = False print("Expected label values are", expected_labels) for i, r in enumerate(results): if not r[0]: print("Unexpected labels found in file %s. Found these unexpected values (they should not be there) %s" % ( label_files[i], r[1])) fail = True if fail: raise AssertionError( "Found unexpected labels in the training dataset. 
Please correct that or adjust your dataset.json accordingly") else: print("Labels OK") # check test set, but only if there actually is a test set if len(expected_test_identifiers) > 0: print("Verifying test set") nii_files_in_imagesTs = subfiles((join(folder, "imagesTs")), suffix=".nii.gz", join=False) for c in expected_test_identifiers: # check if all files are present expected_image_files = [join(folder, "imagesTs", c + "_%04.0d.nii.gz" % i) for i in range(num_modalities)] assert all([isfile(i) for i in expected_image_files]), "some image files are missing for case %s. Expected files:\n %s" % ( c, expected_image_files) # verify that all modalities and the label have the same geometry. We use the affine for this if num_modalities > 1: images_itk = [sitk.ReadImage(i) for i in expected_image_files] reference_img = images_itk[0] for i, img in enumerate(images_itk[1:]): assert verify_same_geometry(img, reference_img), "The modalities of the image %s do not seem to be " \ "registered. Please coregister your modalities." % ( expected_image_files[i]) # now remove checked files from the lists nii_files_in_imagesTr and nii_files_in_labelsTr for i in expected_image_files: nii_files_in_imagesTs.remove(os.path.basename(i)) assert len( nii_files_in_imagesTs) == 0, "there are training cases in imagesTs that are not listed in dataset.json: %s" % nii_files_in_imagesTr all_same, unique_orientations = verify_all_same_orientation(join(folder, "imagesTr")) if not all_same: print( "WARNING: Not all images in the dataset have the same axis ordering. We very strongly recommend you correct that by reorienting the data. fslreorient2std should do the trick") # save unique orientations to dataset.json if not geometries_OK: raise Warning("GEOMETRY MISMATCH FOUND! CHECK THE TEXT OUTPUT! This does not cause an error at this point but you should definitely check whether your geometries are alright!") else: print("Dataset OK") if has_nan: raise RuntimeError("Some images have nan values in them. 
This will break the training. See text output above to see which ones") def reorient_to_RAS(img_fname: str, output_fname: str = None): img = nib.load(img_fname) canonical_img = nib.as_closest_canonical(img) if output_fname is None: output_fname = img_fname nib.save(canonical_img, output_fname) if __name__ == "__main__": # investigate geometry issues import SimpleITK as sitk # load image gt_itk = sitk.ReadImage( "/media/fabian/Results/nnFormer/3d_fullres/Task064_KiTS_labelsFixed/nnFormerTrainerV2__nnFormerPlansv2.1/gt_niftis/case_00085.nii.gz") # get numpy array pred_npy = sitk.GetArrayFromImage(gt_itk) # create new image from numpy array prek_itk_new = sitk.GetImageFromArray(pred_npy) # copy geometry prek_itk_new.CopyInformation(gt_itk) # prek_itk_new = copy_geometry(prek_itk_new, gt_itk) # save sitk.WriteImage(prek_itk_new, "test.mnc") # load images in nib gt = nib.load( "/media/fabian/Results/nnFormer/3d_fullres/Task064_KiTS_labelsFixed/nnFormerTrainerV2__nnFormerPlansv2.1/gt_niftis/case_00085.nii.gz") pred_nib = nib.load("test.mnc") new_img_sitk = sitk.ReadImage("test.mnc") np1 = sitk.GetArrayFromImage(gt_itk) np2 = sitk.GetArrayFromImage(prek_itk_new) ================================================ FILE: unetr_pp/run/__init__.py ================================================ from __future__ import absolute_import from . import * ================================================ FILE: unetr_pp/run/default_configuration.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unetr_pp
from unetr_pp.paths import network_training_output_dir, preprocessing_output_dir, default_plans_identifier
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.experiment_planning.summarize_plans import summarize_plans
from unetr_pp.training.model_restore import recursive_find_python_class
import numpy as np
import pickle


def get_configuration_from_output_folder(folder):
    """Invert the output-folder naming scheme: .../<configuration>/<task>/<trainer>__<plans_identifier>.

    :return: (configuration, task, trainer, plans_identifier)
    """
    # split off network_training_output_dir
    folder = folder[len(network_training_output_dir):]
    if folder.startswith("/"):
        folder = folder[1:]

    configuration, task, trainer_and_plans_identifier = folder.split("/")
    trainer, plans_identifier = trainer_and_plans_identifier.split("__")
    return configuration, task, trainer, plans_identifier


def get_default_configuration(network, task, network_trainer, plans_identifier=default_plans_identifier,
                              search_in=(unetr_pp.__path__[0], "training", "network_training"),
                              base_module='unetr_pp.training.network_training'):
    """Resolve plans file, output folder, dataset directory, dice mode, stage and trainer class for a run.

    NOTE: for Task001_ACDC and Task002_Synapse this overwrites batch/patch size (and pool kernels for
    Synapse) in the plans and writes the modified plans back to disk.

    :return: (plans_file, output_folder_name, dataset_directory, batch_dice, stage, trainer_class)
    """
    # BUGFIX: the message used to say '3d' although the accepted value is '2d'
    assert network in ['2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'], \
        "network can only be one of the following: \'2d\', \'3d_lowres\', \'3d_fullres\', \'3d_cascade_fullres\'"

    dataset_directory = join(preprocessing_output_dir, task)

    if network == '2d':
        plans_file = join(preprocessing_output_dir, task, plans_identifier + "_plans_2D.pkl")
    else:
        plans_file = join(preprocessing_output_dir, task, plans_identifier + "_plans_3D.pkl")

    plans = load_pickle(plans_file)

    # Maybe have two kinds of plans, choose the later one
    if len(plans['plans_per_stage']) == 2:
        Stage = 1
    else:
        Stage = 0
    # task-specific overrides, persisted back into the plans file
    if task == 'Task001_ACDC':
        plans['plans_per_stage'][Stage]['batch_size'] = 4
        plans['plans_per_stage'][Stage]['patch_size'] = np.array([16, 160, 160])
        with open(plans_file, 'wb') as pickle_file:  # context manager: no fd leak on error
            pickle.dump(plans, pickle_file)
    elif task == 'Task002_Synapse':
        plans['plans_per_stage'][Stage]['batch_size'] = 2
        plans['plans_per_stage'][Stage]['patch_size'] = np.array([64, 128, 128])
        plans['plans_per_stage'][Stage]['pool_op_kernel_sizes'] = [[2, 2, 2],
                                                                   [2, 2, 2],
                                                                   [2, 2, 2]]  # for deep supervision
        with open(plans_file, 'wb') as pickle_file:
            pickle.dump(plans, pickle_file)

    possible_stages = list(plans['plans_per_stage'].keys())

    if (network == '3d_cascade_fullres' or network == "3d_lowres") and len(possible_stages) == 1:
        raise RuntimeError("3d_lowres/3d_cascade_fullres only applies if there is more than one stage. This task does "
                           "not require the cascade. Run 3d_fullres instead")

    if network == '2d' or network == "3d_lowres":
        stage = 0
    else:
        stage = possible_stages[-1]

    trainer_class = recursive_find_python_class([join(*search_in)], network_trainer,
                                                current_module=base_module)

    output_folder_name = join(network_training_output_dir, network, task, network_trainer + "__" + plans_identifier)

    print("###############################################")
    print("I am running the following nnFormer: %s" % network)
    print("My trainer class is: ", trainer_class)
    print("For that I will be using the following configuration:")
    summarize_plans(plans_file)
    print("I am using stage %d from these plans" % stage)

    if (network == '2d' or len(possible_stages) > 1) and not network == '3d_lowres':
        batch_dice = True
        print("I am using batch dice + CE loss")
    else:
        batch_dice = False
        print("I am using sample dice + CE loss")

    print("\nI am using data from this folder: ", join(dataset_directory, plans['data_identifier']))
    print("###############################################")
    return plans_file, output_folder_name, dataset_directory, batch_dice, stage, trainer_class
================================================ FILE: unetr_pp/run/run_training.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse from batchgenerators.utilities.file_and_folder_operations import * from unetr_pp.run.default_configuration import get_default_configuration from unetr_pp.paths import default_plans_identifier from unetr_pp.utilities.task_name_id_conversion import convert_id_to_task_name import numpy as np import torch seed = 42 np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = False torch.backends.cudnn.benchmark = True torch.backends.cudnn.enabled = True # to improve the efficiency set the last two true def main(): parser = argparse.ArgumentParser() parser.add_argument("network") parser.add_argument("network_trainer") parser.add_argument("task", help="can be task name or task id") parser.add_argument("fold", help='0, 1, ..., 5 or \'all\'') parser.add_argument("-val", "--validation_only", help="use this if you want to only run the validation", action="store_true") parser.add_argument("-c", "--continue_training", help="use this if you want to continue a training", action="store_true") parser.add_argument("-p", help="plans identifier. 
Only change this if you created a custom experiment planner", default=default_plans_identifier, required=False) parser.add_argument("--use_compressed_data", default=False, action="store_true", help="If you set use_compressed_data, the training cases will not be decompressed. " "Reading compressed data is much more CPU and RAM intensive and should only be used if " "you know what you are doing", required=False) parser.add_argument("--deterministic", help="Makes training deterministic, but reduces training speed substantially. I (Fabian) think " "this is not necessary. Deterministic training will make you overfit to some random seed. " "Don't use that.", required=False, default=False, action="store_true") parser.add_argument("--npz", required=False, default=False, action="store_true", help="if set then model will " "export npz files of " "predicted segmentations " "in the validation as well. " "This is needed to run the " "ensembling (if needed)") parser.add_argument("--find_lr", required=False, default=False, action="store_true", help="not used here, just for fun") parser.add_argument("--valbest", required=False, default=False, action="store_true", help="hands off. This is not intended to be used") parser.add_argument("--fp32", required=False, default=True, action="store_true", help="disable mixed precision training and run old school fp32") parser.add_argument("--val_folder", required=False, default="validation_raw", help="name of the validation folder. No need to use this for most people") parser.add_argument("--disable_saving", required=False, action='store_true', help="If set nnU-Net will not save any parameter files (except a temporary checkpoint that " "will be removed at the end of the training). 
Useful for development when you are " "only interested in the results and want to save some disk space") parser.add_argument("--disable_postprocessing_on_folds", required=False, action='store_true', help="Running postprocessing on each fold only makes sense when developing with nnU-Net and " "closely observing the model performance on specific configurations. You do not need it " "when applying nnU-Net because the postprocessing for this will be determined only once " "Usually running postprocessing on each fold is computationally cheap, but some users have " "reported issues with very large images. If your images are large (>600x600x600 voxels) " "you should consider setting this flag.") parser.add_argument('--val_disable_overwrite', action='store_false', default=True, help='Validation does not overwrite existing segmentations') parser.add_argument('--disable_next_stage_pred', action='store_true', default=False, help='do not predict next stage') parser.add_argument('-pretrained_weights', type=str, required=False, default=None, help='path to nnU-Net checkpoint file to be used as pretrained model (use .model ' 'file, for example model_final_checkpoint.model). Will only be used when actually training. ' 'Optional. Beta. 
Use with caution.') args = parser.parse_args() task = args.task fold = args.fold network = args.network network_trainer = args.network_trainer validation_only = args.validation_only plans_identifier = args.p find_lr = args.find_lr disable_postprocessing_on_folds = args.disable_postprocessing_on_folds use_compressed_data = args.use_compressed_data decompress_data = not use_compressed_data deterministic = args.deterministic valbest = args.valbest fp32 = args.fp32 run_mixed_precision = not fp32 val_folder = args.val_folder if not task.startswith("Task"): task_id = int(task) task = convert_id_to_task_name(task_id) if fold == 'all': pass else: fold = int(fold) plans_file, output_folder_name, dataset_directory, batch_dice, stage, \ trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier) if trainer_class is None: raise RuntimeError("Could not find trainer class in unetr_pp.training.network_training") trainer = trainer_class(plans_file, fold, output_folder=output_folder_name, dataset_directory=dataset_directory, batch_dice=batch_dice, stage=stage, unpack_data=decompress_data, deterministic=deterministic, fp16=run_mixed_precision) if args.disable_saving: trainer.save_final_checkpoint = False # whether or not to save the final checkpoint trainer.save_best_checkpoint = False # whether or not to save the best checkpoint according to # self.best_val_eval_criterion_MA trainer.save_intermediate_checkpoints = True # whether or not to save checkpoint_latest. 
We need that in case # the training chashes trainer.save_latest_only = True # if false it will not store/overwrite _latest but separate files each trainer.initialize(not validation_only) if find_lr: trainer.find_lr() else: if not validation_only: if args.continue_training: # -c was set, continue a previous training and ignore pretrained weights trainer.load_latest_checkpoint() else: # new training without pretraine weights, do nothing pass trainer.run_training() else: if valbest: trainer.load_best_checkpoint(train=False) else: trainer.load_final_checkpoint(train=False) trainer.network.eval() # predict validation trainer.validate(save_softmax=args.npz, validation_folder_name=val_folder, run_postprocessing_on_folds=not disable_postprocessing_on_folds, overwrite=args.val_disable_overwrite) if __name__ == "__main__": main() ================================================ FILE: unetr_pp/training/__init__.py ================================================ from __future__ import absolute_import from . import * ================================================ FILE: unetr_pp/training/cascade_stuff/__init__.py ================================================ from __future__ import absolute_import from . import * ================================================ FILE: unetr_pp/training/cascade_stuff/predict_next_stage.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.


from copy import deepcopy
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
import argparse
from unetr_pp.preprocessing.preprocessing import resample_data_or_seg
from batchgenerators.utilities.file_and_folder_operations import maybe_mkdir_p
import unetr_pp
from unetr_pp.run.default_configuration import get_default_configuration
from multiprocessing import Pool
from unetr_pp.training.model_restore import recursive_find_python_class
from unetr_pp.training.network_training.Trainer_acdc import Trainer_acdc


def resample_and_save(predicted, target_shape, output_file, force_separate_z=False,
                      interpolation_order=1, interpolation_order_z=0):
    """Resample a softmax prediction to target_shape, take the argmax and save it compressed.

    :param predicted: softmax array (classes first) or path to an .npy file holding it;
        a path is loaded and then deleted from disk
    :param target_shape: spatial shape to resample to
    :param output_file: destination .npz file (stores the uint8 segmentation under key 'data')
    :param force_separate_z: forwarded to resample_data_or_seg
    :param interpolation_order: spline order for in-plane resampling
    :param interpolation_order_z: spline order along z when resampling separately
    """
    if isinstance(predicted, str):
        assert isfile(predicted), "If isinstance(segmentation_softmax, str) then " \
                                  "isfile(segmentation_softmax) must be True"
        del_file = deepcopy(predicted)
        predicted = np.load(predicted)
        # the temporary .npy was only a spill-to-disk buffer; remove it after loading
        os.remove(del_file)

    predicted_new_shape = resample_data_or_seg(predicted, target_shape, False, order=interpolation_order,
                                               do_separate_z=force_separate_z, cval=0,
                                               order_z=interpolation_order_z)
    # collapse the class dimension to a label map
    seg_new_shape = predicted_new_shape.argmax(0)
    np.savez_compressed(output_file, data=seg_new_shape.astype(np.uint8))


def predict_next_stage(trainer, stage_to_be_predicted_folder):
    """Predict the trainer's validation cases and export them, resampled to the next
    stage's resolution, as *_segFromPrevStage.npz files next to the trainer's output folder.

    :param trainer: an initialized trainer with dataset_val and a trained network loaded
    :param stage_to_be_predicted_folder: preprocessed-data folder of the next stage
        (provides the target shapes)
    """
    # NOTE(review): pardir is called like a function here; os.path.pardir is a string,
    # so presumably a project helper shadows it via the star import above — TODO confirm
    output_folder = join(pardir(trainer.output_folder), "pred_next_stage")
    maybe_mkdir_p(output_folder)

    # export parameters may be pinned in the plans; otherwise fall back to defaults
    if 'segmentation_export_params' in trainer.plans.keys():
        force_separate_z = trainer.plans['segmentation_export_params']['force_separate_z']
        interpolation_order = trainer.plans['segmentation_export_params']['interpolation_order']
        interpolation_order_z = trainer.plans['segmentation_export_params']['interpolation_order_z']
    else:
        force_separate_z = None
        interpolation_order = 1
        interpolation_order_z = 0

    # resampling+saving runs in background workers while the GPU predicts the next case
    export_pool = Pool(2)
    results = []

    for pat in trainer.dataset_val.keys():
        print(pat)
        data_file = trainer.dataset_val[pat]['data_file']
        # last channel of the preprocessed array is the segmentation; drop it
        data_preprocessed = np.load(data_file)['data'][:-1]

        predicted_probabilities = trainer.predict_preprocessed_data_return_seg_and_softmax(
            data_preprocessed, do_mirroring=trainer.data_aug_params["do_mirror"],
            mirror_axes=trainer.data_aug_params['mirror_axes'], mixed_precision=trainer.fp16)[1]

        data_file_nofolder = data_file.split("/")[-1]
        data_file_nextstage = join(stage_to_be_predicted_folder, data_file_nofolder)
        data_nextstage = np.load(data_file_nextstage)['data']
        target_shp = data_nextstage.shape[1:]
        output_file = join(output_folder, data_file_nextstage.split("/")[-1][:-4] + "_segFromPrevStage.npz")

        # very large softmax arrays cannot be pickled to the worker process; spill them
        # to disk and hand the worker the path instead
        if np.prod(predicted_probabilities.shape) > (2e9 / 4 * 0.85):  # *0.85 just to be save
            np.save(output_file[:-4] + ".npy", predicted_probabilities)
            predicted_probabilities = output_file[:-4] + ".npy"

        results.append(export_pool.starmap_async(resample_and_save,
                                                 [(predicted_probabilities, target_shp, output_file,
                                                   force_separate_z, interpolation_order,
                                                   interpolation_order_z)]))

    # wait for all background exports before shutting the pool down
    _ = [i.get() for i in results]
    export_pool.close()
    export_pool.join()


if __name__ == "__main__":
    """
    RUNNING THIS SCRIPT MANUALLY IS USUALLY NOT NECESSARY. USE THE run_training.py FILE!

    This script is intended for predicting all the low resolution predictions of 3d_lowres for the next stage of the
    cascade. It needs to run once for each fold so that the segmentation is only generated for the validation set
    and not on the data the network was trained on. Run it with
    python predict_next_stage TRAINERCLASS TASK FOLD"""

    parser = argparse.ArgumentParser()
    parser.add_argument("network_trainer")
    parser.add_argument("task")
    parser.add_argument("fold", type=int)

    args = parser.parse_args()

    trainerclass = args.network_trainer
    task = args.task
    fold = args.fold

    # NOTE(review): get_default_configuration requires (network, task, network_trainer, ...)
    # and returns (plans_file, output_folder_name, dataset_directory, batch_dice, stage,
    # trainer_class); this two-argument call and 6-name unpacking look stale relative to
    # the current signature — TODO confirm before running this script manually
    plans_file, folder_with_preprocessed_data, output_folder_name, dataset_directory, batch_dice, stage = \
        get_default_configuration("3d_lowres", task)

    trainer_class = recursive_find_python_class([join(unetr_pp.__path__[0], "training", "network_training")],
                                                trainerclass, "unetr_pp.training.network_training")

    if trainer_class is None:
        raise RuntimeError("Could not find trainer class in unetr_pp.training.network_training")
    else:
        assert issubclass(trainer_class,
                          Trainer_acdc), "network_trainer was found but is not derived from nnFormerTrainer"

    trainer = trainer_class(plans_file, fold, folder_with_preprocessed_data, output_folder=output_folder_name,
                            dataset_directory=dataset_directory, batch_dice=batch_dice, stage=stage)

    trainer.initialize(False)
    trainer.load_dataset()
    trainer.do_split()
    trainer.load_best_checkpoint(train=False)

    stage_to_be_predicted_folder = join(dataset_directory, trainer.plans['data_identifier'] + "_stage%d" % 1)
    output_folder = join(pardir(trainer.output_folder), "pred_next_stage")
    maybe_mkdir_p(output_folder)

    predict_next_stage(trainer, stage_to_be_predicted_folder)


================================================
FILE: unetr_pp/training/data_augmentation/__init__.py
================================================
from __future__ import absolute_import
from . import *


================================================
FILE: unetr_pp/training/data_augmentation/custom_transforms.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
from batchgenerators.transforms import AbstractTransform


class RemoveKeyTransform(AbstractTransform):
    """Drops a single key from the data dict (no error if it is absent)."""

    def __init__(self, key_to_remove):
        self.key_to_remove = key_to_remove

    def __call__(self, **data_dict):
        _ = data_dict.pop(self.key_to_remove, None)
        return data_dict


class MaskTransform(AbstractTransform):
    def __init__(self, dct_for_where_it_was_used, mask_idx_in_seg=1, set_outside_to=0, data_key="data",
                 seg_key="seg"):
        """
        data[mask < 0] = 0
        Sets everything outside the mask to 0. CAREFUL! outside is defined as < 0, not =0 (in the Mask)!!!

        :param dct_for_where_it_was_used: per-channel flags; only channels whose flag is truthy are masked
        :param mask_idx_in_seg: channel of seg that holds the mask
        :param set_outside_to: fill value for voxels where mask < 0
        :param data_key: key of the image array in the data dict
        :param seg_key: key of the segmentation array in the data dict
        """
        self.dct_for_where_it_was_used = dct_for_where_it_was_used
        self.seg_key = seg_key
        self.data_key = data_key
        self.set_outside_to = set_outside_to
        self.mask_idx_in_seg = mask_idx_in_seg

    def __call__(self, **data_dict):
        seg = data_dict.get(self.seg_key)
        if seg is None or seg.shape[1] < self.mask_idx_in_seg:
            # NOTE(review): raising the Warning *class* aborts the pipeline rather than
            # emitting a warning (Warning is an Exception subclass) — presumably intentional
            raise Warning("mask not found, seg may be missing or seg[:, mask_idx_in_seg] may not exist")
        data = data_dict.get(self.data_key)
        # mask each selected channel of each sample in the batch
        for b in range(data.shape[0]):
            mask = seg[b, self.mask_idx_in_seg]
            for c in range(data.shape[1]):
                if self.dct_for_where_it_was_used[c]:
                    data[b, c][mask < 0] = self.set_outside_to
        data_dict[self.data_key] = data
        return data_dict


def convert_3d_to_2d_generator(data_dict):
    """Fold the z axis into the channel axis of 'data' and 'seg' (b, c, z, y, x) -> (b, c*z, y, x).

    The original shapes are remembered under 'orig_shape_data'/'orig_shape_seg' so
    convert_2d_to_3d_generator can undo the reshape.
    """
    shp = data_dict['data'].shape
    data_dict['data'] = data_dict['data'].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4]))
    data_dict['orig_shape_data'] = shp
    shp = data_dict['seg'].shape
    data_dict['seg'] = data_dict['seg'].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4]))
    data_dict['orig_shape_seg'] = shp
    return data_dict


def convert_2d_to_3d_generator(data_dict):
    """Inverse of convert_3d_to_2d_generator: restore (b, c, z, y, x) using the stored
    original shapes, keeping the (possibly augmented) current y/x extents."""
    shp = data_dict['orig_shape_data']
    current_shape = data_dict['data'].shape
    data_dict['data'] = data_dict['data'].reshape((shp[0], shp[1], shp[2], current_shape[-2], current_shape[-1]))
    shp = data_dict['orig_shape_seg']
    current_shape_seg = data_dict['seg'].shape
    data_dict['seg'] = data_dict['seg'].reshape((shp[0], shp[1], shp[2], current_shape_seg[-2],
                                                 current_shape_seg[-1]))
    return data_dict


class Convert3DTo2DTransform(AbstractTransform):
    """Transform wrapper around convert_3d_to_2d_generator (dummy-2D augmentation)."""

    def __init__(self):
        pass

    def __call__(self, **data_dict):
        return convert_3d_to_2d_generator(data_dict)


class Convert2DTo3DTransform(AbstractTransform):
    """Transform wrapper around convert_2d_to_3d_generator (undo dummy-2D augmentation)."""

    def __init__(self):
        pass

    def __call__(self, **data_dict):
        return convert_2d_to_3d_generator(data_dict)


class ConvertSegmentationToRegionsTransform(AbstractTransform):
    def __init__(self, regions: dict, seg_key: str = "seg", output_key: str = "seg", seg_channel: int = 0):
        """
        regions are tuple of tuples where each inner tuple holds the class indices that are merged into one region,
        example:
        regions= ((1, 2), (2, )) will result in 2 regions: one covering the region of labels 1&2 and the other just 2

        :param regions: mapping/collection of label groups; iterated via .keys() below,
            so a dict {region_name: (labels,)} is what the code actually expects
        :param seg_key: key of the input segmentation in the data dict
        :param output_key: key under which the one-hot region map is stored
        :param seg_channel: channel of seg that holds the label map
        """
        self.seg_channel = seg_channel
        self.output_key = output_key
        self.seg_key = seg_key
        self.regions = regions

    def __call__(self, **data_dict):
        seg = data_dict.get(self.seg_key)
        num_regions = len(self.regions)
        if seg is not None:
            seg_shp = seg.shape
            output_shape = list(seg_shp)
            output_shape[1] = num_regions
            region_output = np.zeros(output_shape, dtype=seg.dtype)
            # one binary channel per region, set to 1 wherever any of its labels occurs
            for b in range(seg_shp[0]):
                for r, k in enumerate(self.regions.keys()):
                    for l in self.regions[k]:
                        region_output[b, r][seg[b, self.seg_channel] == l] = 1
            data_dict[self.output_key] = region_output
        return data_dict


================================================
FILE: unetr_pp/training/data_augmentation/data_augmentation_insaneDA.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.dataloading import MultiThreadedAugmenter
from batchgenerators.transforms import DataChannelSelectionTransform, SegChannelSelectionTransform, SpatialTransform, \
    GammaTransform, MirrorTransform, Compose
from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
    ContrastAugmentationTransform, BrightnessTransform
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
from unetr_pp.training.data_augmentation.custom_transforms import Convert3DTo2DTransform, Convert2DTo3DTransform, \
    MaskTransform, ConvertSegmentationToRegionsTransform
from unetr_pp.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params
from unetr_pp.training.data_augmentation.downsampling import DownsampleSegForDSTransform3, DownsampleSegForDSTransform2
from unetr_pp.training.data_augmentation.pyramid_augmentations import MoveSegAsOneHotToData, \
    ApplyRandomBinaryOperatorTransform, \
    RemoveRandomConnectedComponentFromOneHotEncodingTransform

# optional import: only available in newer batchgenerators versions
try:
    from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
except ImportError as ie:
    NonDetMultiThreadedAugmenter = None


def get_insaneDA_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
                              border_val_seg=-1,
                              seeds_train=None, seeds_val=None, order_seg=1, order_data=3,
                              deep_supervision_scales=None,
                              soft_ds=False,
                              classes=None, pin_memory=True, regions=None):
    """Build the 'insane' (aggressive) augmentation pipelines and wrap both loaders
    in MultiThreadedAugmenter instances.

    :param dataloader_train: training data loader to be augmented
    :param dataloader_val: validation data loader (gets only the lightweight transforms)
    :param patch_size: target spatial patch size for the spatial transform
    :param params: augmentation parameter dict (see default_3D_augmentation_params)
    :param border_val_seg: constant border value used when padding segmentations
    :param seeds_train: per-worker seeds for the training augmenter
    :param seeds_val: per-worker seeds for the validation augmenter
    :param order_seg: interpolation order for segmentations in the spatial transform
    :param order_data: interpolation order for image data in the spatial transform
    :param deep_supervision_scales: if not None, targets are downsampled for deep supervision
    :param soft_ds: use soft (one-hot) downsampling; requires classes
    :param classes: class labels, only needed when soft_ds is True
    :param pin_memory: forwarded to the augmenters
    :param regions: if not None, convert the target label map to region channels
    :return: (batchgenerator_train, batchgenerator_val)
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        ignore_axes = (0,)
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(SpatialTransform(
        patch_size, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"),
        sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"),
        do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=order_data,
        border_mode_seg="constant", border_cval_seg=border_val_seg,
        order_seg=order_seg, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis"),
        p_independent_scale_per_axis=params.get("p_independent_scale_per_axis")
    ))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
    # channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(GaussianBlurTransform((0.5, 1.5), different_sigma_per_channel=True, p_per_sample=0.2,
                                               p_per_channel=0.5))
    tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3), p_per_sample=0.15))
    tr_transforms.append(ContrastAugmentationTransform(contrast_range=(0.65, 1.5), p_per_sample=0.15))
    tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
                                                        p_per_channel=0.5,
                                                        order_downsample=0, order_upsample=3, p_per_sample=0.25,
                                                        ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"), True, True, retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))  # inverted gamma

    if params.get("do_additive_brightness"):
        tr_transforms.append(BrightnessTransform(params.get("additive_brightness_mu"),
                                                 params.get("additive_brightness_sigma"),
                                                 True, p_per_sample=params.get("additive_brightness_p_per_sample"),
                                                 p_per_channel=params.get("additive_brightness_p_per_channel")))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        # cascade: the previous stage's segmentation rides along as extra one-hot data channels
        tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
        # NOTE(review): "and not None" is always True — presumably "is not None" was intended;
        # the truthiness result happens to be the same, so behavior is unchanged
        if params.get("cascade_do_cascade_augmentations") and not None and params.get(
                "cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(ApplyRandomBinaryOperatorTransform(
                    channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                    p_per_sample=params.get("cascade_random_binary_transform_p"),
                    key="data",
                    strel_size=params.get("cascade_random_binary_transform_size")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                # NOTE(review): the values passed to fill_with_other_class_p and
                # dont_do_if_covers_more_than_X_percent look swapped relative to the parameter
                # names they came from — verify against the batchgenerators API before changing
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        fill_with_other_class_p=params.get("cascade_remove_conn_comp_max_size_percent_threshold"),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p")))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
        else:
            tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target',
                                                              output_key='target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"), seeds=seeds_train,
                                                  pin_memory=pin_memory)

    # validation gets only the cheap, non-destructive transforms
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
        else:
            val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target',
                                                               output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"), seeds=seeds_val,
                                                pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val


================================================
FILE: unetr_pp/training/data_augmentation/data_augmentation_insaneDA2.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.dataloading import MultiThreadedAugmenter from batchgenerators.transforms import DataChannelSelectionTransform, SegChannelSelectionTransform, \ GammaTransform, MirrorTransform, Compose from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \ ContrastAugmentationTransform, BrightnessTransform from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor from unetr_pp.training.data_augmentation.custom_transforms import Convert3DTo2DTransform, Convert2DTo3DTransform, \ MaskTransform, ConvertSegmentationToRegionsTransform from unetr_pp.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params from unetr_pp.training.data_augmentation.downsampling import DownsampleSegForDSTransform3, DownsampleSegForDSTransform2 from unetr_pp.training.data_augmentation.pyramid_augmentations import MoveSegAsOneHotToData, \ ApplyRandomBinaryOperatorTransform, \ RemoveRandomConnectedComponentFromOneHotEncodingTransform from batchgenerators.transforms.spatial_transforms import SpatialTransform_2 try: from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter except ImportError as ie: NonDetMultiThreadedAugmenter = None def get_insaneDA_augmentation2(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params, border_val_seg=-1, seeds_train=None, seeds_val=None, order_seg=1, order_data=3, deep_supervision_scales=None, soft_ds=False, classes=None, pin_memory=True, regions=None): assert params.get('mirror') is None, "old version of params, use new keyword do_mirror" tr_transforms = [] if params.get("selected_data_channels") is not None: 
tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels"))) if params.get("selected_seg_channels") is not None: tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels"))) # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!! if params.get("dummy_2D") is not None and params.get("dummy_2D"): ignore_axes = (0,) tr_transforms.append(Convert3DTo2DTransform()) else: ignore_axes = None tr_transforms.append(SpatialTransform_2( patch_size, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"), deformation_scale=params.get("eldef_deformation_scale"), do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"), angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"), border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=order_data, border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=order_seg, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"), p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"), independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis"), p_independent_scale_per_axis=params.get("p_independent_scale_per_axis") )) if params.get("dummy_2D"): tr_transforms.append(Convert2DTo3DTransform()) # we need to put the color augmentations after the dummy 2d part (if applicable). 
Otherwise the overloaded color # channel gets in the way tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15)) tr_transforms.append(GaussianBlurTransform((0.5, 1.5), different_sigma_per_channel=True, p_per_sample=0.2, p_per_channel=0.5)) tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3), p_per_sample=0.15)) tr_transforms.append(ContrastAugmentationTransform(contrast_range=(0.65, 1.5), p_per_sample=0.15)) tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True, p_per_channel=0.5, order_downsample=0, order_upsample=3, p_per_sample=0.25, ignore_axes=ignore_axes)) tr_transforms.append( GammaTransform(params.get("gamma_range"), True, True, retain_stats=params.get("gamma_retain_stats"), p_per_sample=0.15)) # inverted gamma if params.get("do_additive_brightness"): tr_transforms.append(BrightnessTransform(params.get("additive_brightness_mu"), params.get("additive_brightness_sigma"), True, p_per_sample=params.get("additive_brightness_p_per_sample"), p_per_channel=params.get("additive_brightness_p_per_channel"))) if params.get("do_gamma"): tr_transforms.append( GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"), p_per_sample=params["p_gamma"])) if params.get("do_mirror") or params.get("mirror"): tr_transforms.append(MirrorTransform(params.get("mirror_axes"))) if params.get("mask_was_used_for_normalization") is not None: mask_was_used_for_normalization = params.get("mask_was_used_for_normalization") tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0)) tr_transforms.append(RemoveLabelTransform(-1, 0)) if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"): tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data')) if params.get("cascade_do_cascade_augmentations") and not None and params.get( 
"cascade_do_cascade_augmentations"): if params.get("cascade_random_binary_transform_p") > 0: tr_transforms.append(ApplyRandomBinaryOperatorTransform( channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)), p_per_sample=params.get("cascade_random_binary_transform_p"), key="data", strel_size=params.get("cascade_random_binary_transform_size"))) if params.get("cascade_remove_conn_comp_p") > 0: tr_transforms.append( RemoveRandomConnectedComponentFromOneHotEncodingTransform( channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)), key="data", p_per_sample=params.get("cascade_remove_conn_comp_p"), fill_with_other_class_p=params.get("cascade_remove_conn_comp_max_size_percent_threshold"), dont_do_if_covers_more_than_X_percent=params.get( "cascade_remove_conn_comp_fill_with_other_class_p"))) tr_transforms.append(RenameTransform('seg', 'target', True)) if regions is not None: tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target')) if deep_supervision_scales is not None: if soft_ds: assert classes is not None tr_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes)) else: tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target', output_key='target')) tr_transforms.append(NumpyToTensor(['data', 'target'], 'float')) tr_transforms = Compose(tr_transforms) batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'), params.get("num_cached_per_thread"), seeds=seeds_train, pin_memory=pin_memory) #batchgenerator_train = SingleThreadedAugmenter(dataloader_train, tr_transforms) val_transforms = [] val_transforms.append(RemoveLabelTransform(-1, 0)) if params.get("selected_data_channels") is not None: val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels"))) if params.get("selected_seg_channels") is not None: 
val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels"))) if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"): val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data')) val_transforms.append(RenameTransform('seg', 'target', True)) if regions is not None: val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target')) if deep_supervision_scales is not None: if soft_ds: assert classes is not None val_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes)) else: val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target', output_key='target')) val_transforms.append(NumpyToTensor(['data', 'target'], 'float')) val_transforms = Compose(val_transforms) batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1), params.get("num_cached_per_thread"), seeds=seeds_val, pin_memory=pin_memory) return batchgenerator_train, batchgenerator_val ================================================ FILE: unetr_pp/training/data_augmentation/data_augmentation_moreDA.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from batchgenerators.dataloading import MultiThreadedAugmenter from batchgenerators.transforms import DataChannelSelectionTransform, SegChannelSelectionTransform, SpatialTransform, \ GammaTransform, MirrorTransform, Compose from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \ ContrastAugmentationTransform, BrightnessTransform from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor from unetr_pp.training.data_augmentation.custom_transforms import Convert3DTo2DTransform, Convert2DTo3DTransform, \ MaskTransform, ConvertSegmentationToRegionsTransform from unetr_pp.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params from unetr_pp.training.data_augmentation.downsampling import DownsampleSegForDSTransform3, DownsampleSegForDSTransform2 from unetr_pp.training.data_augmentation.pyramid_augmentations import MoveSegAsOneHotToData, \ ApplyRandomBinaryOperatorTransform, \ RemoveRandomConnectedComponentFromOneHotEncodingTransform try: from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter except ImportError as ie: NonDetMultiThreadedAugmenter = None def get_moreDA_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params, border_val_seg=-1, seeds_train=None, seeds_val=None, order_seg=1, order_data=3, deep_supervision_scales=None, soft_ds=False, classes=None, pin_memory=True, regions=None, use_nondetMultiThreadedAugmenter: bool = False): assert params.get('mirror') is None, "old version of params, use new keyword do_mirror" tr_transforms = [] if params.get("selected_data_channels") is not None: 
tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels"))) if params.get("selected_seg_channels") is not None: tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels"))) # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!! if params.get("dummy_2D") is not None and params.get("dummy_2D"): ignore_axes = (0,) tr_transforms.append(Convert3DTo2DTransform()) else: ignore_axes = None tr_transforms.append(SpatialTransform( patch_size, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"), alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"), do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"), angle_z=params.get("rotation_z"), p_rot_per_axis=params.get("rotation_p_per_axis"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"), border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=order_data, border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=order_seg, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"), p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"), independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis") )) if params.get("dummy_2D"): tr_transforms.append(Convert2DTo3DTransform()) # we need to put the color augmentations after the dummy 2d part (if applicable). 
Otherwise the overloaded color # channel gets in the way tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1)) tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2, p_per_channel=0.5)) tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15)) if params.get("do_additive_brightness"): tr_transforms.append(BrightnessTransform(params.get("additive_brightness_mu"), params.get("additive_brightness_sigma"), True, p_per_sample=params.get("additive_brightness_p_per_sample"), p_per_channel=params.get("additive_brightness_p_per_channel"))) tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15)) tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True, p_per_channel=0.5, order_downsample=0, order_upsample=3, p_per_sample=0.25, ignore_axes=ignore_axes)) tr_transforms.append( GammaTransform(params.get("gamma_range"), True, True, retain_stats=params.get("gamma_retain_stats"), p_per_sample=0.1)) # inverted gamma if params.get("do_gamma"): tr_transforms.append( GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"), p_per_sample=params["p_gamma"])) if params.get("do_mirror") or params.get("mirror"): tr_transforms.append(MirrorTransform(params.get("mirror_axes"))) if params.get("mask_was_used_for_normalization") is not None: mask_was_used_for_normalization = params.get("mask_was_used_for_normalization") tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0)) tr_transforms.append(RemoveLabelTransform(-1, 0)) if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"): tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data')) if params.get("cascade_do_cascade_augmentations") is not None and params.get( "cascade_do_cascade_augmentations"): if 
params.get("cascade_random_binary_transform_p") > 0: tr_transforms.append(ApplyRandomBinaryOperatorTransform( channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)), p_per_sample=params.get("cascade_random_binary_transform_p"), key="data", strel_size=params.get("cascade_random_binary_transform_size"), p_per_label=params.get("cascade_random_binary_transform_p_per_label"))) if params.get("cascade_remove_conn_comp_p") > 0: tr_transforms.append( RemoveRandomConnectedComponentFromOneHotEncodingTransform( channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)), key="data", p_per_sample=params.get("cascade_remove_conn_comp_p"), fill_with_other_class_p=params.get("cascade_remove_conn_comp_max_size_percent_threshold"), dont_do_if_covers_more_than_X_percent=params.get( "cascade_remove_conn_comp_fill_with_other_class_p"))) tr_transforms.append(RenameTransform('seg', 'target', True)) if regions is not None: tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target')) if deep_supervision_scales is not None: if soft_ds: assert classes is not None tr_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes)) else: tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target', output_key='target')) tr_transforms.append(NumpyToTensor(['data', 'target'], 'float')) tr_transforms = Compose(tr_transforms) if use_nondetMultiThreadedAugmenter: if NonDetMultiThreadedAugmenter is None: raise RuntimeError('NonDetMultiThreadedAugmenter is not yet available') batchgenerator_train = NonDetMultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'), params.get("num_cached_per_thread"), seeds=seeds_train, pin_memory=pin_memory) else: batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'), params.get("num_cached_per_thread"), seeds=seeds_train, pin_memory=pin_memory) # 
batchgenerator_train = SingleThreadedAugmenter(dataloader_train, tr_transforms) # import IPython;IPython.embed() val_transforms = [] val_transforms.append(RemoveLabelTransform(-1, 0)) if params.get("selected_data_channels") is not None: val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels"))) if params.get("selected_seg_channels") is not None: val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels"))) if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"): val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data')) val_transforms.append(RenameTransform('seg', 'target', True)) if regions is not None: val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target')) if deep_supervision_scales is not None: if soft_ds: assert classes is not None val_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes)) else: val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target', output_key='target')) val_transforms.append(NumpyToTensor(['data', 'target'], 'float')) val_transforms = Compose(val_transforms) if use_nondetMultiThreadedAugmenter: if NonDetMultiThreadedAugmenter is None: raise RuntimeError('NonDetMultiThreadedAugmenter is not yet available') batchgenerator_val = NonDetMultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1), params.get("num_cached_per_thread"), seeds=seeds_val, pin_memory=pin_memory) else: batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1), params.get("num_cached_per_thread"), seeds=seeds_val, pin_memory=pin_memory) # batchgenerator_val = SingleThreadedAugmenter(dataloader_val, val_transforms) return batchgenerator_train, batchgenerator_val ================================================ 
FILE: unetr_pp/training/data_augmentation/data_augmentation_noDA.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from batchgenerators.dataloading import MultiThreadedAugmenter from batchgenerators.transforms import DataChannelSelectionTransform, SegChannelSelectionTransform, Compose from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor from unetr_pp.training.data_augmentation.custom_transforms import ConvertSegmentationToRegionsTransform from unetr_pp.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params from unetr_pp.training.data_augmentation.downsampling import DownsampleSegForDSTransform3, DownsampleSegForDSTransform2 try: from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter except ImportError as ie: NonDetMultiThreadedAugmenter = None def get_no_augmentation(dataloader_train, dataloader_val, params=default_3D_augmentation_params, deep_supervision_scales=None, soft_ds=False, classes=None, pin_memory=True, regions=None): """ use this instead of get_default_augmentation (drop in replacement) to turn off all data augmentation """ tr_transforms = [] if params.get("selected_data_channels") is not None: 
tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels"))) if params.get("selected_seg_channels") is not None: tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels"))) tr_transforms.append(RemoveLabelTransform(-1, 0)) tr_transforms.append(RenameTransform('seg', 'target', True)) if regions is not None: tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target')) if deep_supervision_scales is not None: if soft_ds: assert classes is not None tr_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes)) else: tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target', output_key='target')) tr_transforms.append(NumpyToTensor(['data', 'target'], 'float')) tr_transforms = Compose(tr_transforms) batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'), params.get("num_cached_per_thread"), seeds=range(params.get('num_threads')), pin_memory=pin_memory) batchgenerator_train.restart() val_transforms = [] val_transforms.append(RemoveLabelTransform(-1, 0)) if params.get("selected_data_channels") is not None: val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels"))) if params.get("selected_seg_channels") is not None: val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels"))) val_transforms.append(RenameTransform('seg', 'target', True)) if regions is not None: val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target')) if deep_supervision_scales is not None: if soft_ds: assert classes is not None val_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes)) else: val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target', output_key='target')) val_transforms.append(NumpyToTensor(['data', 
'target'], 'float')) val_transforms = Compose(val_transforms) batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1), params.get("num_cached_per_thread"), seeds=range(max(params.get('num_threads') // 2, 1)), pin_memory=pin_memory) batchgenerator_val.restart() return batchgenerator_train, batchgenerator_val ================================================ FILE: unetr_pp/training/data_augmentation/default_data_augmentation.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os
from copy import deepcopy
import numpy as np
from batchgenerators.dataloading import MultiThreadedAugmenter
from batchgenerators.transforms import DataChannelSelectionTransform, SegChannelSelectionTransform, SpatialTransform, \
    GammaTransform, MirrorTransform, Compose
from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
from unetr_pp.training.data_augmentation.custom_transforms import Convert3DTo2DTransform, Convert2DTo3DTransform, \
    MaskTransform, ConvertSegmentationToRegionsTransform
from unetr_pp.training.data_augmentation.pyramid_augmentations import MoveSegAsOneHotToData, \
    ApplyRandomBinaryOperatorTransform, \
    RemoveRandomConnectedComponentFromOneHotEncodingTransform

try:
    from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
except ImportError as ie:
    NonDetMultiThreadedAugmenter = None

# Default augmentation hyperparameters for 3D training. Consumed via params.get(...) by the
# get_*_augmentation factories in this package; all probabilities are per-sample unless noted.
default_3D_augmentation_params = {
    "selected_data_channels": None,   # None = keep all data channels
    "selected_seg_channels": None,    # None = keep all seg channels

    # elastic deformation
    "do_elastic": True,
    "elastic_deform_alpha": (0., 900.),
    "elastic_deform_sigma": (9., 13.),
    "p_eldef": 0.2,

    # random scaling
    "do_scaling": True,
    "scale_range": (0.85, 1.25),
    "independent_scale_factor_for_each_axis": False,
    "p_independent_scale_per_axis": 1,
    "p_scale": 0.2,

    # random rotation; ranges are +-15 degrees expressed in radians
    "do_rotation": True,
    "rotation_x": (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
    "rotation_y": (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
    "rotation_z": (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
    "rotation_p_per_axis": 1,
    "p_rot": 0.2,

    "random_crop": False,
    "random_crop_dist_to_border": None,

    # gamma augmentation
    "do_gamma": True,
    "gamma_retain_stats": True,
    "gamma_range": (0.7, 1.5),
    "p_gamma": 0.3,

    # mirroring along all three spatial axes
    "do_mirror": True,
    "mirror_axes": (0, 1, 2),

    "dummy_2D": False,  # if True, augment anisotropic 3D data slice-wise in 2D
    "mask_was_used_for_normalization": None,
    "border_mode_data": "constant",

    "all_segmentation_labels": None,  # used for cascade
    "move_last_seg_chanel_to_data": False,  # used for cascade
    "cascade_do_cascade_augmentations": False,  # used for cascade
    # cascade-only augmentations of the previous stage's one-hot channels
    "cascade_random_binary_transform_p": 0.4,
    "cascade_random_binary_transform_p_per_label": 1,
    "cascade_random_binary_transform_size": (1, 8),
    "cascade_remove_conn_comp_p": 0.2,
    "cascade_remove_conn_comp_max_size_percent_threshold": 0.15,
    "cascade_remove_conn_comp_fill_with_other_class_p": 0.0,

    # additive brightness augmentation
    "do_additive_brightness": False,
    "additive_brightness_p_per_sample": 0.15,
    "additive_brightness_p_per_channel": 0.5,
    "additive_brightness_mu": 0.0,
    "additive_brightness_sigma": 0.1,

    # worker settings for MultiThreadedAugmenter; overridable via the nnFormer_n_proc_DA env var
    "num_threads": 12 if 'nnFormer_n_proc_DA' not in os.environ else int(os.environ['nnFormer_n_proc_DA']),
    "num_cached_per_thread": 1,
}

# 2D defaults: same as 3D except for weaker elastic deformation and in-plane-only rotation
default_2D_augmentation_params = deepcopy(default_3D_augmentation_params)

default_2D_augmentation_params["elastic_deform_alpha"] = (0., 200.)
default_2D_augmentation_params["elastic_deform_sigma"] = (9., 13.)
default_2D_augmentation_params["rotation_x"] = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
default_2D_augmentation_params["rotation_y"] = (-0. / 360 * 2. * np.pi, 0. / 360 * 2. * np.pi)
default_2D_augmentation_params["rotation_z"] = (-0. / 360 * 2. * np.pi, 0. / 360 * 2. * np.pi)

# sometimes you have 3d data and a 3d net but cannot augment them properly in 3d due to anisotropy (which is currently
# not supported in batchgenerators).
In that case you can 'cheat' and transfer your 3d data into 2d data and # transform them back after augmentation default_2D_augmentation_params["dummy_2D"] = False default_2D_augmentation_params["mirror_axes"] = (0, 1) # this can be (0, 1, 2) if dummy_2D=True def get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range): if isinstance(rot_x, (tuple, list)): rot_x = max(np.abs(rot_x)) if isinstance(rot_y, (tuple, list)): rot_y = max(np.abs(rot_y)) if isinstance(rot_z, (tuple, list)): rot_z = max(np.abs(rot_z)) rot_x = min(90 / 360 * 2. * np.pi, rot_x) rot_y = min(90 / 360 * 2. * np.pi, rot_y) rot_z = min(90 / 360 * 2. * np.pi, rot_z) from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d coords = np.array(final_patch_size) final_shape = np.copy(coords) if len(coords) == 3: final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, rot_x, 0, 0)), final_shape)), 0) final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, rot_y, 0)), final_shape)), 0) final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, 0, rot_z)), final_shape)), 0) elif len(coords) == 2: final_shape = np.max(np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0) final_shape /= min(scale_range) return final_shape.astype(int) def get_default_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params, border_val_seg=-1, pin_memory=True, seeds_train=None, seeds_val=None, regions=None): assert params.get('mirror') is None, "old version of params, use new keyword do_mirror" tr_transforms = [] if params.get("selected_data_channels") is not None: tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels"))) if params.get("selected_seg_channels") is not None: tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels"))) # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!! 
if params.get("dummy_2D") is not None and params.get("dummy_2D"): tr_transforms.append(Convert3DTo2DTransform()) tr_transforms.append(SpatialTransform( patch_size, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"), alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"), do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"), angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"), border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=3, border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=1, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"), p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"), independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis") )) if params.get("dummy_2D") is not None and params.get("dummy_2D"): tr_transforms.append(Convert2DTo3DTransform()) if params.get("do_gamma"): tr_transforms.append( GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"), p_per_sample=params["p_gamma"])) if params.get("do_mirror"): tr_transforms.append(MirrorTransform(params.get("mirror_axes"))) if params.get("mask_was_used_for_normalization") is not None: mask_was_used_for_normalization = params.get("mask_was_used_for_normalization") tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0)) tr_transforms.append(RemoveLabelTransform(-1, 0)) if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"): tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data')) if params.get("cascade_do_cascade_augmentations") and not None and params.get( "cascade_do_cascade_augmentations"): 
tr_transforms.append(ApplyRandomBinaryOperatorTransform( channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)), p_per_sample=params.get("cascade_random_binary_transform_p"), key="data", strel_size=params.get("cascade_random_binary_transform_size"))) tr_transforms.append(RemoveRandomConnectedComponentFromOneHotEncodingTransform( channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)), key="data", p_per_sample=params.get("cascade_remove_conn_comp_p"), fill_with_other_class_p=params.get("cascade_remove_conn_comp_max_size_percent_threshold"), dont_do_if_covers_more_than_X_percent=params.get("cascade_remove_conn_comp_fill_with_other_class_p"))) tr_transforms.append(RenameTransform('seg', 'target', True)) if regions is not None: tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target')) tr_transforms.append(NumpyToTensor(['data', 'target'], 'float')) tr_transforms = Compose(tr_transforms) # from batchgenerators.dataloading import SingleThreadedAugmenter # batchgenerator_train = SingleThreadedAugmenter(dataloader_train, tr_transforms) # import IPython;IPython.embed() batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'), params.get("num_cached_per_thread"), seeds=seeds_train, pin_memory=pin_memory) val_transforms = [] val_transforms.append(RemoveLabelTransform(-1, 0)) if params.get("selected_data_channels") is not None: val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels"))) if params.get("selected_seg_channels") is not None: val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels"))) if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"): val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data')) val_transforms.append(RenameTransform('seg', 'target', True)) if regions is not None: 
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half the training worker count (but at least one worker).
    # batchgenerator_val = SingleThreadedAugmenter(dataloader_val, val_transforms)
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms,
                                                max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"), seeds=seeds_val,
                                                pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val


if __name__ == "__main__":
    # Manual smoke test: build a DataLoader3D for a preprocessed Task002_Heart dataset and wire it
    # through get_default_augmentation. Requires the preprocessed data to exist on disk.
    from unetr_pp.training.dataloading.dataset_loading import DataLoader3D, load_dataset
    from unetr_pp.paths import preprocessing_output_dir
    import os
    import pickle

    t = "Task002_Heart"
    p = os.path.join(preprocessing_output_dir, t)
    dataset = load_dataset(p, 0)
    with open(os.path.join(p, "plans.pkl"), 'rb') as f:
        plans = pickle.load(f)

    # NOTE(review): load_dataset's second positional argument is num_cases_properties_loading_threshold;
    # passing 0 means per-case properties are NOT preloaded here — confirm this was intended.
    basic_patch_size = get_patch_size(np.array(plans['stage_properties'][0].patch_size),
                                      default_3D_augmentation_params['rotation_x'],
                                      default_3D_augmentation_params['rotation_y'],
                                      default_3D_augmentation_params['rotation_z'],
                                      default_3D_augmentation_params['scale_range'])

    dl = DataLoader3D(dataset, basic_patch_size, np.array(plans['stage_properties'][0].patch_size).astype(int), 1)
    tr, val = get_default_augmentation(dl, dl, np.array(plans['stage_properties'][0].patch_size).astype(int))


================================================
FILE: unetr_pp/training/data_augmentation/downsampling.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from batchgenerators.augmentations.utils import convert_seg_image_to_one_hot_encoding_batched, resize_segmentation
from batchgenerators.transforms import AbstractTransform
from torch.nn.functional import avg_pool2d, avg_pool3d
import numpy as np


class DownsampleSegForDSTransform3(AbstractTransform):
    '''
    returns one hot encodings of the segmentation maps if downsampling has occured (no one hot for highest resolution)
    downsampled segmentations are smooth, not 0/1
    returns torch tensors, not numpy arrays!
    always uses seg channel 0!!

    you should always give classes! Otherwise weird stuff may happen
    '''
    def __init__(self, ds_scales=(1, 0.5, 0.25), input_key="seg", output_key="seg", classes=None):
        # classes: label values to one-hot encode; required for sensible results (see class docstring)
        self.classes = classes
        self.output_key = output_key
        self.input_key = input_key
        self.ds_scales = ds_scales

    def __call__(self, **data_dict):
        # Only seg channel 0 is used; the result is a list (one entry per deep-supervision scale).
        data_dict[self.output_key] = downsample_seg_for_ds_transform3(data_dict[self.input_key][:, 0],
                                                                      self.ds_scales, self.classes)
        return data_dict


def downsample_seg_for_ds_transform3(seg, ds_scales=((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)), classes=None):
    """
    Downsample a batched label map for deep supervision via average pooling of its one-hot encoding.

    For a scale of all ones the integer label map itself is returned (as a torch tensor); for every
    other scale a smooth (non-binary) one-hot tensor pooled with kernel/stride 1/scale is returned.
    :param seg: numpy array of labels, shape (b, x, y[, z]) — channel 0 already selected by the caller
    :param ds_scales: one scale tuple per deep-supervision output
    :param classes: list of label values to one-hot encode
    :return: list of torch tensors, one per scale
    """
    output = []
    # One-hot is computed once at full resolution and pooled per scale.
    one_hot = torch.from_numpy(convert_seg_image_to_one_hot_encoding_batched(seg, classes))  # b, c,
    for s in ds_scales:
        if all([i == 1 for i in s]):
            output.append(torch.from_numpy(seg))
        else:
            # scale 0.5 -> kernel 2, scale 0.25 -> kernel 4, etc.; assumes 1/scale is integral
            kernel_size = tuple(int(1 / i) for i in s)
            stride = kernel_size
            pad = tuple((i-1) // 2 for i in kernel_size)

            if len(s) == 2:
                pool_op = avg_pool2d
            elif len(s) == 3:
                pool_op = avg_pool3d
            else:
                raise RuntimeError()

            pooled = pool_op(one_hot, kernel_size, stride, pad,
                             count_include_pad=False, ceil_mode=False)

            output.append(pooled)
    return output


class DownsampleSegForDSTransform2(AbstractTransform):
    '''
    data_dict['output_key'] will be a list of segmentations scaled according to ds_scales
    '''
    def __init__(self, ds_scales=(1, 0.5, 0.25), order=0, cval=0, input_key="seg", output_key="seg", axes=None):
        # axes: which axes of seg to rescale; defaults to all spatial axes (2..) when None
        self.axes = axes
        self.output_key = output_key
        self.input_key = input_key
        self.cval = cval
        self.order = order
        self.ds_scales = ds_scales

    def __call__(self, **data_dict):
        data_dict[self.output_key] = downsample_seg_for_ds_transform2(data_dict[self.input_key], self.ds_scales,
                                                                      self.order, self.cval, self.axes)
        return data_dict


def downsample_seg_for_ds_transform2(seg, ds_scales=((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)), order=0, cval=0,
                                     axes=None):
    """
    Downsample a batched label map for deep supervision by per-channel nearest/spline resizing.

    Unlike downsample_seg_for_ds_transform3 this keeps integer labels (numpy arrays, no one-hot).
    :param seg: numpy array of shape (b, c, x, y[, z])
    :param ds_scales: one scale tuple per deep-supervision output
    :param order: interpolation order for resize_segmentation (0 = nearest)
    :param cval: fill value used by resize_segmentation
    :param axes: spatial axes to rescale; defaults to all axes from index 2 on
    :return: list of numpy arrays, one per scale
    """
    if axes is None:
        axes = list(range(2, len(seg.shape)))
    output = []
    for s in ds_scales:
        if all([i == 1 for i in s]):
            # highest resolution: pass through unchanged
            output.append(seg)
        else:
            new_shape = np.array(seg.shape).astype(float)
            for i, a in enumerate(axes):
                new_shape[a] *= s[i]
            new_shape = np.round(new_shape).astype(int)
            out_seg = np.zeros(new_shape, dtype=seg.dtype)
            for b in range(seg.shape[0]):
                for c in range(seg.shape[1]):
                    out_seg[b, c] = resize_segmentation(seg[b, c], new_shape[2:], order, cval)
            output.append(out_seg)
    return output


================================================
FILE: unetr_pp/training/data_augmentation/pyramid_augmentations.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from copy import deepcopy
from skimage.morphology import label, ball
from skimage.morphology.binary import binary_erosion, binary_dilation, binary_closing, binary_opening
import numpy as np
from batchgenerators.transforms import AbstractTransform


class RemoveRandomConnectedComponentFromOneHotEncodingTransform(AbstractTransform):
    """Randomly deletes one connected component from a one-hot encoded channel (cascade augmentation)."""

    def __init__(self, channel_idx, key="data", p_per_sample=0.2, fill_with_other_class_p=0.25,
                 dont_do_if_covers_more_than_X_percent=0.25, p_per_label=1):
        """
        :param dont_do_if_covers_more_than_X_percent: dont_do_if_covers_more_than_X_percent=0.25 is 25\%!
        :param channel_idx: can be list or int
        :param key:
        """
        self.p_per_label = p_per_label
        self.dont_do_if_covers_more_than_X_percent = dont_do_if_covers_more_than_X_percent
        self.fill_with_other_class_p = fill_with_other_class_p
        self.p_per_sample = p_per_sample
        self.key = key
        if not isinstance(channel_idx, (list, tuple)):
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx

    def __call__(self, **data_dict):
        # Mutates data_dict[self.key] in place (the same array object is written back).
        data = data_dict.get(self.key)
        for b in range(data.shape[0]):
            if np.random.uniform() < self.p_per_sample:
                for c in self.channel_idx:
                    if np.random.uniform() < self.p_per_label:
                        workon = np.copy(data[b, c])
                        num_voxels = np.prod(workon.shape, dtype=np.uint64)
                        # connected-component labelling of the (assumed binary) channel
                        lab, num_comp = label(workon, return_num=True)
                        if num_comp > 0:
                            component_ids = []
                            component_sizes = []
                            for i in range(1, num_comp + 1):
                                component_ids.append(i)
                                component_sizes.append(np.sum(lab == i))
                            # only components covering less than X percent of the image are candidates
                            component_ids = [i for i, j in zip(component_ids, component_sizes) if j <
                                             num_voxels*self.dont_do_if_covers_more_than_X_percent]
                            #_ = component_ids.pop(np.argmax(component_sizes))
                        #else:
                        #    component_ids = list(range(1, num_comp + 1))
                            if len(component_ids) > 0:
                                random_component = np.random.choice(component_ids)
                                data[b, c][lab == random_component] = 0
                                # optionally reassign the removed region to another one-hot channel
                                if np.random.uniform() < self.fill_with_other_class_p:
                                    other_ch = [i for i in self.channel_idx if i != c]
                                    if len(other_ch) > 0:
                                        other_class = np.random.choice(other_ch)
                                        data[b, other_class][lab == random_component] = 1
        data_dict[self.key] = data
        return data_dict


class MoveSegAsOneHotToData(AbstractTransform):
    """Takes one segmentation channel, one-hot encodes it and appends it to the data channels."""

    def __init__(self, channel_id, all_seg_labels, key_origin="seg", key_target="data", remove_from_origin=True):
        # channel_id: which channel of data_dict[key_origin] to convert
        # all_seg_labels: label values that become the one-hot channels, in this order
        self.remove_from_origin = remove_from_origin
        self.all_seg_labels = all_seg_labels
        self.key_target = key_target
        self.key_origin = key_origin
        self.channel_id = channel_id

    def __call__(self, **data_dict):
        origin = data_dict.get(self.key_origin)
        target = data_dict.get(self.key_target)
        seg = origin[:, self.channel_id:self.channel_id+1]
        seg_onehot = np.zeros((seg.shape[0], len(self.all_seg_labels), *seg.shape[2:]), dtype=seg.dtype)
        for i, l in enumerate(self.all_seg_labels):
            seg_onehot[:, i][seg[:, 0] == l] = 1
        target = np.concatenate((target, seg_onehot), 1)
        data_dict[self.key_target] = target

        if self.remove_from_origin:
            # drop the consumed channel from the origin array
            remaining_channels = [i for i in range(origin.shape[1]) if i != self.channel_id]
            origin = origin[:, remaining_channels]
            data_dict[self.key_origin] = origin
        return data_dict


class ApplyRandomBinaryOperatorTransform(AbstractTransform):
    """Applies a random binary morphological operation (dilation/erosion/closing/opening) to one-hot channels."""

    def __init__(self, channel_idx, p_per_sample=0.3, any_of_these=(binary_dilation, binary_erosion,
                                                                    binary_closing, binary_opening),
                 key="data", strel_size=(1, 10), p_per_label=1):
        # strel_size: (min, max) radius range for the ball structuring element
        self.p_per_label = p_per_label
        self.strel_size = strel_size
        self.key = key
        self.any_of_these = any_of_these
        self.p_per_sample = p_per_sample

        assert not isinstance(channel_idx, tuple), "bäh"
        if not isinstance(channel_idx, list):
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx

    def __call__(self, **data_dict):
        # Mutates data_dict[self.key] in place.
        data = data_dict.get(self.key)
        for b in range(data.shape[0]):
            if np.random.uniform() < self.p_per_sample:
                ch = deepcopy(self.channel_idx)
                np.random.shuffle(ch)
                for c in ch:
                    if np.random.uniform() < self.p_per_label:
                        operation = np.random.choice(self.any_of_these)
                        selem = ball(np.random.uniform(*self.strel_size))
                        workon = np.copy(data[b, c]).astype(int)
                        res = operation(workon, selem).astype(workon.dtype)
                        data[b, c] = res

                        # if class was added, we need to remove it in ALL other channels to keep one hot encoding
                        # properties
                        # we modify data
                        other_ch = [i for i in ch if i != c]
                        if len(other_ch) > 0:
                            was_added_mask = (res - workon) > 0
                            for oc in other_ch:
                                data[b, oc][was_added_mask] = 0
                            # if class was removed, leave it at background
        data_dict[self.key] = data
        return data_dict


class ApplyRandomBinaryOperatorTransform2(AbstractTransform):
    def __init__(self, channel_idx, p_per_sample=0.3, p_per_label=0.3, any_of_these=(binary_dilation, binary_closing),
                 key="data", strel_size=(1, 10)):
        """
        2019_11_22: I have no idea what the purpose of this was...

        the same as above but here we should use only expanding operations. Expansions will replace other labels
        :param channel_idx: can be list or int
        :param p_per_sample:
        :param any_of_these:
        :param fill_diff_with_other_class:
        :param key:
        :param strel_size:
        """
        self.strel_size = strel_size
        self.key = key
        self.any_of_these = any_of_these
        self.p_per_sample = p_per_sample
        self.p_per_label = p_per_label

        assert not isinstance(channel_idx, tuple), "bäh"
        if not isinstance(channel_idx, list):
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx

    def __call__(self, **data_dict):
        # Same scheme as ApplyRandomBinaryOperatorTransform but restricted to expanding operations
        # (dilation/closing by default). Mutates data_dict[self.key] in place.
        data = data_dict.get(self.key)
        for b in range(data.shape[0]):
            if np.random.uniform() < self.p_per_sample:
                ch = deepcopy(self.channel_idx)
                np.random.shuffle(ch)
                for c in ch:
                    if np.random.uniform() < self.p_per_label:
                        operation = np.random.choice(self.any_of_these)
                        selem = ball(np.random.uniform(*self.strel_size))
                        workon = np.copy(data[b, c]).astype(int)
                        res = operation(workon, selem).astype(workon.dtype)
                        data[b, c] = res

                        # if class was added, we need to remove it in ALL other channels to keep one hot encoding
                        # properties
                        # we modify data
                        other_ch = [i for i in ch if i != c]
                        if len(other_ch) > 0:
                            was_added_mask = (res - workon) > 0
                            for oc in other_ch:
                                data[b, oc][was_added_mask] = 0
                            # if class was removed, leave it at backgound
        data_dict[self.key] = data
        return data_dict


================================================
FILE: unetr_pp/training/dataloading/__init__.py
================================================
from __future__ import absolute_import
from . import *


================================================
FILE: unetr_pp/training/dataloading/dataset_loading.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import OrderedDict
from batchgenerators.augmentations.utils import random_crop_2D_image_batched, pad_nd_image
import numpy as np
from batchgenerators.dataloading import SlimDataLoaderBase
from multiprocessing import Pool
from unetr_pp.configuration import default_num_threads
from unetr_pp.paths import preprocessing_output_dir
from batchgenerators.utilities.file_and_folder_operations import *


def get_case_identifiers(folder):
    # Case id = npz filename without the ".npz" extension; previous-stage segmentations are excluded.
    case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith("npz") and (i.find("segFromPrevStage") == -1)]
    return case_identifiers


def get_case_identifiers_from_raw_folder(folder):
    # Strips the 12-character suffix "_0000.nii.gz" (modality index + extension) from raw image names.
    case_identifiers = np.unique(
        [i[:-12] for i in os.listdir(folder) if i.endswith(".nii.gz") and (i.find("segFromPrevStage") == -1)])
    return case_identifiers


def convert_to_npy(args):
    """Unpack a single npz file to an npy file next to it (no-op if the npy already exists).

    args is either the npz path or a (npz path, key) tuple; key defaults to "data".
    """
    if not isinstance(args, tuple):
        key = "data"
        npz_file = args
    else:
        npz_file, key = args
    if not isfile(npz_file[:-3] + "npy"):
        a = np.load(npz_file)[key]
        np.save(npz_file[:-3] + "npy", a)


def save_as_npz(args):
    """Inverse of convert_to_npy: compress a single npy file back into an npz under the given key."""
    if not isinstance(args, tuple):
        key = "data"
        npy_file = args
    else:
        npy_file, key = args
    d = np.load(npy_file)
    np.savez_compressed(npy_file[:-3] + "npz", **{key: d})


def unpack_dataset(folder, threads=default_num_threads, key="data"):
    """
    unpacks all npz files in a folder to npy (whatever you want to have unpacked must be saved unter key)
    :param folder:
    :param threads:
    :param key:
    :return:
    """
    p = Pool(threads)
    npz_files = subfiles(folder, True, None, ".npz", True)
    p.map(convert_to_npy, zip(npz_files, [key] * len(npz_files)))
    p.close()
    p.join()


def pack_dataset(folder, threads=default_num_threads, key="data"):
    # Inverse of unpack_dataset: recompress all npy files in folder to npz.
    p = Pool(threads)
    npy_files = subfiles(folder, True, None, ".npy", True)
    p.map(save_as_npz, zip(npy_files, [key] * len(npy_files)))
    p.close()
    p.join()


def delete_npy(folder):
    # Remove the unpacked npy files for all cases (the compressed npz files are kept).
    case_identifiers = get_case_identifiers(folder)
    npy_files = [join(folder, i + ".npy") for i in case_identifiers]
    npy_files = [i for i in npy_files if isfile(i)]
    for n in npy_files:
        os.remove(n)


def load_dataset(folder, num_cases_properties_loading_threshold=1000):
    """Build an OrderedDict case_id -> {data_file, properties_file[, properties]} for a preprocessed folder.

    Properties pickles are eagerly loaded only when the number of cases does not exceed
    num_cases_properties_loading_threshold; otherwise they are loaded lazily by the data loaders.
    """
    # we don't load the actual data but instead return the filename to the np file.
    print('loading dataset')
    case_identifiers = get_case_identifiers(folder)
    case_identifiers.sort()
    dataset = OrderedDict()
    for c in case_identifiers:
        dataset[c] = OrderedDict()
        dataset[c]['data_file'] = join(folder, "%s.npz" % c)

        # dataset[c]['properties'] = load_pickle(join(folder, "%s.pkl" % c))
        dataset[c]['properties_file'] = join(folder, "%s.pkl" % c)

        # NOTE(review): dataset[c] was just created above, so this .get() is always None and the
        # branch never fires — 'seg_from_prev_stage_file' is never set here. Confirm whether the
        # cascade path is expected to populate it elsewhere before relying on it.
        if dataset[c].get('seg_from_prev_stage_file') is not None:
            dataset[c]['seg_from_prev_stage_file'] = join(folder, "%s_segs.npz" % c)

    if len(case_identifiers) <= num_cases_properties_loading_threshold:
        print('loading all case properties')
        for i in dataset.keys():
            dataset[i]['properties'] = load_pickle(dataset[i]['properties_file'])

    return dataset


def crop_2D_image_force_fg(img, crop_size, valid_voxels):
    """
    img must be [c, x, y]
    img[-1] must be the segmentation with segmentation>0 being foreground
    :param img:
    :param crop_size:
    :param valid_voxels: voxels belonging to the selected class
    :return:
    """
    assert len(valid_voxels.shape) == 2

    if type(crop_size) not in (tuple, list):
        crop_size = [crop_size] * (len(img.shape) - 1)
    else:
        assert len(crop_size) == (len(
            img.shape) - 1), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)"

    # we need to find the center coords that we can crop to without exceeding the image border
    lb_x = crop_size[0] // 2
    ub_x = img.shape[1] - crop_size[0] // 2 - crop_size[0] % 2
lb_y = crop_size[1] // 2 ub_y = img.shape[2] - crop_size[1] // 2 - crop_size[1] % 2 if len(valid_voxels) == 0: selected_center_voxel = (np.random.random_integers(lb_x, ub_x), np.random.random_integers(lb_y, ub_y)) else: selected_center_voxel = valid_voxels[np.random.choice(valid_voxels.shape[1]), :] selected_center_voxel = np.array(selected_center_voxel) for i in range(2): selected_center_voxel[i] = max(crop_size[i] // 2, selected_center_voxel[i]) selected_center_voxel[i] = min(img.shape[i + 1] - crop_size[i] // 2 - crop_size[i] % 2, selected_center_voxel[i]) result = img[:, (selected_center_voxel[0] - crop_size[0] // 2):( selected_center_voxel[0] + crop_size[0] // 2 + crop_size[0] % 2), (selected_center_voxel[1] - crop_size[1] // 2):( selected_center_voxel[1] + crop_size[1] // 2 + crop_size[1] % 2)] return result class DataLoader3D(SlimDataLoaderBase): def __init__(self, data, patch_size, final_patch_size, batch_size, has_prev_stage=False, oversample_foreground_percent=0.0, memmap_mode="r", pad_mode="edge", pad_kwargs_data=None, pad_sides=None): """ This is the basic data loader for 3D networks. It uses preprocessed data as produced by my (Fabian) preprocessing. You can load the data with load_dataset(folder) where folder is the folder where the npz files are located. If there are only npz files present in that folder, the data loader will unpack them on the fly. This may take a while and increase CPU usage. Therefore, I advise you to call unpack_dataset(folder) first, which will unpack all npz to npy. Don't forget to call delete_npy(folder) after you are done with training? Why all the hassle? Well the decathlon dataset is huge. Using npy for everything will consume >1 TB and that is uncool given that I (Fabian) will have to store that permanently on /datasets and my local computer. With this strategy all data is stored in a compressed format (factor 10 smaller) and only unpacked when needed. :param data: get this with load_dataset(folder, stage=0). 
Plug the return value in here and you are g2g (good to go) :param patch_size: what patch size will this data loader return? it is common practice to first load larger patches so that a central crop after data augmentation can be done to reduce border artifacts. If unsure, use get_patch_size() from data_augmentation.default_data_augmentation :param final_patch_size: what will the patch finally be cropped to (after data augmentation)? this is the patch size that goes into your network. We need this here because we will pad patients in here so that patches at the border of patients are sampled properly :param batch_size: :param num_batches: how many batches will the data loader produce before stopping? None=endless :param seed: :param stage: ignore this (Fabian only) :param random: Sample keys randomly; CAREFUL! non-random sampling requires batch_size=1, otherwise you will iterate batch_size times over the dataset :param oversample_foreground: half the batch will be forced to contain at least some foreground (equal prob for each of the foreground classes) """ super(DataLoader3D, self).__init__(data, batch_size, None) if pad_kwargs_data is None: pad_kwargs_data = OrderedDict() self.pad_kwargs_data = pad_kwargs_data self.pad_mode = pad_mode self.oversample_foreground_percent = oversample_foreground_percent self.final_patch_size = final_patch_size self.has_prev_stage = has_prev_stage self.patch_size = patch_size self.list_of_keys = list(self._data.keys()) # need_to_pad denotes by how much we need to pad the data so that if we sample a patch of size final_patch_size # (which is what the network will get) these patches will also cover the border of the patients self.need_to_pad = (np.array(patch_size) - np.array(final_patch_size)).astype(int) if pad_sides is not None: if not isinstance(pad_sides, np.ndarray): pad_sides = np.array(pad_sides) self.need_to_pad += pad_sides self.memmap_mode = memmap_mode self.num_channels = None self.pad_sides = pad_sides self.data_shape, 
self.seg_shape = self.determine_shapes() def get_do_oversample(self, batch_idx): return not batch_idx < round(self.batch_size * (1 - self.oversample_foreground_percent)) def determine_shapes(self): if self.has_prev_stage: num_seg = 2 else: num_seg = 1 k = list(self._data.keys())[0] if isfile(self._data[k]['data_file'][:-4] + ".npy"): case_all_data = np.load(self._data[k]['data_file'][:-4] + ".npy", self.memmap_mode) else: case_all_data = np.load(self._data[k]['data_file'])['data'] num_color_channels = case_all_data.shape[0] - 1 data_shape = (self.batch_size, num_color_channels, *self.patch_size) seg_shape = (self.batch_size, num_seg, *self.patch_size) return data_shape, seg_shape def generate_train_batch(self): selected_keys = np.random.choice(self.list_of_keys, self.batch_size, True, None) data = np.zeros(self.data_shape, dtype=np.float32) seg = np.zeros(self.seg_shape, dtype=np.float32) case_properties = [] for j, i in enumerate(selected_keys): # oversampling foreground will improve stability of model training, especially if many patches are empty # (Lung for example) if self.get_do_oversample(j): force_fg = True else: force_fg = False if 'properties' in self._data[i].keys(): properties = self._data[i]['properties'] else: properties = load_pickle(self._data[i]['properties_file']) case_properties.append(properties) # cases are stored as npz, but we require unpack_dataset to be run. This will decompress them into npy # which is much faster to access if isfile(self._data[i]['data_file'][:-4] + ".npy"): case_all_data = np.load(self._data[i]['data_file'][:-4] + ".npy", self.memmap_mode) else: case_all_data = np.load(self._data[i]['data_file'])['data'] # If we are doing the cascade then we will also need to load the segmentation of the previous stage and # concatenate it. Here it will be concatenates to the segmentation because the augmentations need to be # applied to it in segmentation mode. 
Later in the data augmentation we move it from the segmentations to # the last channel of the data if self.has_prev_stage: if isfile(self._data[i]['seg_from_prev_stage_file'][:-4] + ".npy"): segs_from_previous_stage = np.load(self._data[i]['seg_from_prev_stage_file'][:-4] + ".npy", mmap_mode=self.memmap_mode)[None] else: segs_from_previous_stage = np.load(self._data[i]['seg_from_prev_stage_file'])['data'][None] # we theoretically support several possible previsous segmentations from which only one is sampled. But # in practice this feature was never used so it's always only one segmentation seg_key = np.random.choice(segs_from_previous_stage.shape[0]) seg_from_previous_stage = segs_from_previous_stage[seg_key:seg_key + 1] assert all([i == j for i, j in zip(seg_from_previous_stage.shape[1:], case_all_data.shape[1:])]), \ "seg_from_previous_stage does not match the shape of case_all_data: %s vs %s" % \ (str(seg_from_previous_stage.shape[1:]), str(case_all_data.shape[1:])) else: seg_from_previous_stage = None # do you trust me? You better do. Otherwise you'll have to go through this mess and honestly there are # better things you could do right now # (above) documentation of the day. Nice. Even myself coming back 1 months later I have not friggin idea # what's going on. I keep the above documentation just for fun but attempt to make things clearer now need_to_pad = self.need_to_pad for d in range(3): # if case_all_data.shape + need_to_pad is still < patch size we need to pad more! We pad on both sides # always if need_to_pad[d] + case_all_data.shape[d + 1] < self.patch_size[d]: need_to_pad[d] = self.patch_size[d] - case_all_data.shape[d + 1] # we can now choose the bbox from -need_to_pad // 2 to shape - patch_size + need_to_pad // 2. 
Here we # define what the upper and lower bound can be to then sample form them with np.random.randint shape = case_all_data.shape[1:] lb_x = - need_to_pad[0] // 2 ub_x = shape[0] + need_to_pad[0] // 2 + need_to_pad[0] % 2 - self.patch_size[0] lb_y = - need_to_pad[1] // 2 ub_y = shape[1] + need_to_pad[1] // 2 + need_to_pad[1] % 2 - self.patch_size[1] lb_z = - need_to_pad[2] // 2 ub_z = shape[2] + need_to_pad[2] // 2 + need_to_pad[2] % 2 - self.patch_size[2] # if not force_fg then we can just sample the bbox randomly from lb and ub. Else we need to make sure we get # at least one of the foreground classes in the patch if not force_fg: bbox_x_lb = np.random.randint(lb_x, ub_x + 1) bbox_y_lb = np.random.randint(lb_y, ub_y + 1) bbox_z_lb = np.random.randint(lb_z, ub_z + 1) else: # these values should have been precomputed if 'class_locations' not in properties.keys(): raise RuntimeError("Please rerun the preprocessing with the newest version of nnU-Net!") # this saves us a np.unique. Preprocessing already did that for all cases. Neat. foreground_classes = np.array( [i for i in properties['class_locations'].keys() if len(properties['class_locations'][i]) != 0]) foreground_classes = foreground_classes[foreground_classes > 0] if len(foreground_classes) == 0: # this only happens if some image does not contain foreground voxels at all selected_class = None voxels_of_that_class = None print('case does not contain any foreground classes', i) else: selected_class = np.random.choice(foreground_classes) voxels_of_that_class = properties['class_locations'][selected_class] if voxels_of_that_class is not None: selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))] # selected voxel is center voxel. Subtract half the patch size to get lower bbox voxel. 
# Make sure it is within the bounds of lb and ub bbox_x_lb = max(lb_x, selected_voxel[0] - self.patch_size[0] // 2) bbox_y_lb = max(lb_y, selected_voxel[1] - self.patch_size[1] // 2) bbox_z_lb = max(lb_z, selected_voxel[2] - self.patch_size[2] // 2) else: # If the image does not contain any foreground classes, we fall back to random cropping bbox_x_lb = np.random.randint(lb_x, ub_x + 1) bbox_y_lb = np.random.randint(lb_y, ub_y + 1) bbox_z_lb = np.random.randint(lb_z, ub_z + 1) bbox_x_ub = bbox_x_lb + self.patch_size[0] bbox_y_ub = bbox_y_lb + self.patch_size[1] bbox_z_ub = bbox_z_lb + self.patch_size[2] # whoever wrote this knew what he was doing (hint: it was me). We first crop the data to the region of the # bbox that actually lies within the data. This will result in a smaller array which is then faster to pad. # valid_bbox is just the coord that lied within the data cube. It will be padded to match the patch size # later valid_bbox_x_lb = max(0, bbox_x_lb) valid_bbox_x_ub = min(shape[0], bbox_x_ub) valid_bbox_y_lb = max(0, bbox_y_lb) valid_bbox_y_ub = min(shape[1], bbox_y_ub) valid_bbox_z_lb = max(0, bbox_z_lb) valid_bbox_z_ub = min(shape[2], bbox_z_ub) # At this point you might ask yourself why we would treat seg differently from seg_from_previous_stage. # Why not just concatenate them here and forget about the if statements? 
Well that's because segneeds to # be padded with -1 constant whereas seg_from_previous_stage needs to be padded with 0s (we could also # remove label -1 in the data augmentation but this way it is less error prone) case_all_data = np.copy(case_all_data[:, valid_bbox_x_lb:valid_bbox_x_ub, valid_bbox_y_lb:valid_bbox_y_ub, valid_bbox_z_lb:valid_bbox_z_ub]) if seg_from_previous_stage is not None: seg_from_previous_stage = seg_from_previous_stage[:, valid_bbox_x_lb:valid_bbox_x_ub, valid_bbox_y_lb:valid_bbox_y_ub, valid_bbox_z_lb:valid_bbox_z_ub] data[j] = np.pad(case_all_data[:-1], ((0, 0), (-min(0, bbox_x_lb), max(bbox_x_ub - shape[0], 0)), (-min(0, bbox_y_lb), max(bbox_y_ub - shape[1], 0)), (-min(0, bbox_z_lb), max(bbox_z_ub - shape[2], 0))), self.pad_mode, **self.pad_kwargs_data) seg[j, 0] = np.pad(case_all_data[-1:], ((0, 0), (-min(0, bbox_x_lb), max(bbox_x_ub - shape[0], 0)), (-min(0, bbox_y_lb), max(bbox_y_ub - shape[1], 0)), (-min(0, bbox_z_lb), max(bbox_z_ub - shape[2], 0))), 'constant', **{'constant_values': -1}) if seg_from_previous_stage is not None: seg[j, 1] = np.pad(seg_from_previous_stage, ((0, 0), (-min(0, bbox_x_lb), max(bbox_x_ub - shape[0], 0)), (-min(0, bbox_y_lb), max(bbox_y_ub - shape[1], 0)), (-min(0, bbox_z_lb), max(bbox_z_ub - shape[2], 0))), 'constant', **{'constant_values': 0}) return {'data': data, 'seg': seg, 'properties': case_properties, 'keys': selected_keys} class DataLoader2D(SlimDataLoaderBase): def __init__(self, data, patch_size, final_patch_size, batch_size, oversample_foreground_percent=0.0, memmap_mode="r", pseudo_3d_slices=1, pad_mode="edge", pad_kwargs_data=None, pad_sides=None): """ This is the basic data loader for 2D networks. It uses preprocessed data as produced by my (Fabian) preprocessing. You can load the data with load_dataset(folder) where folder is the folder where the npz files are located. If there are only npz files present in that folder, the data loader will unpack them on the fly. 
This may take a while and increase CPU usage. Therefore, I advise you to call unpack_dataset(folder) first, which will unpack all npz to npy. Don't forget to call delete_npy(folder) after you are done with training? Why all the hassle? Well the decathlon dataset is huge. Using npy for everything will consume >1 TB and that is uncool given that I (Fabian) will have to store that permanently on /datasets and my local computer. With htis strategy all data is stored in a compressed format (factor 10 smaller) and only unpacked when needed. :param data: get this with load_dataset(folder, stage=0). Plug the return value in here and you are g2g (good to go) :param patch_size: what patch size will this data loader return? it is common practice to first load larger patches so that a central crop after data augmentation can be done to reduce border artifacts. If unsure, use get_patch_size() from data_augmentation.default_data_augmentation :param final_patch_size: what will the patch finally be cropped to (after data augmentation)? this is the patch size that goes into your network. We need this here because we will pad patients in here so that patches at the border of patients are sampled properly :param batch_size: :param num_batches: how many batches will the data loader produce before stopping? None=endless :param seed: :param stage: ignore this (Fabian only) :param transpose: ignore this :param random: sample randomly; CAREFUL! 
        non-random sampling requires batch_size=1, otherwise you will iterate batch_size times over the dataset
        :param pseudo_3d_slices: 7 = 3 below and 3 above the center slice
        """
        super(DataLoader2D, self).__init__(data, batch_size, None)
        if pad_kwargs_data is None:
            pad_kwargs_data = OrderedDict()
        self.pad_kwargs_data = pad_kwargs_data
        self.pad_mode = pad_mode
        self.pseudo_3d_slices = pseudo_3d_slices
        self.oversample_foreground_percent = oversample_foreground_percent
        self.final_patch_size = final_patch_size
        self.patch_size = patch_size
        self.list_of_keys = list(self._data.keys())
        # extra voxels to pad so that the (possibly larger) augmentation patch fits around the final patch
        self.need_to_pad = np.array(patch_size) - np.array(final_patch_size)
        self.memmap_mode = memmap_mode
        if pad_sides is not None:
            if not isinstance(pad_sides, np.ndarray):
                pad_sides = np.array(pad_sides)
            self.need_to_pad += pad_sides
        self.pad_sides = pad_sides
        self.data_shape, self.seg_shape = self.determine_shapes()

    def determine_shapes(self):
        # Infer the (b, c, x, y) batch array shapes by peeking at the first case on disk.
        # The last channel of a case is the segmentation, the rest are image channels.
        num_seg = 1
        k = list(self._data.keys())[0]
        if isfile(self._data[k]['data_file'][:-4] + ".npy"):
            case_all_data = np.load(self._data[k]['data_file'][:-4] + ".npy", self.memmap_mode)
        else:
            case_all_data = np.load(self._data[k]['data_file'])['data']
        num_color_channels = case_all_data.shape[0] - num_seg
        data_shape = (self.batch_size, num_color_channels, *self.patch_size)
        seg_shape = (self.batch_size, num_seg, *self.patch_size)
        return data_shape, seg_shape

    def get_do_oversample(self, batch_idx):
        # The last ceil(batch_size * oversample_foreground_percent) items of each batch
        # are forced to contain foreground.
        return not batch_idx < round(self.batch_size * (1 - self.oversample_foreground_percent))

    def generate_train_batch(self):
        """Sample batch_size random slices (optionally forced to contain foreground),
        crop/pad them to self.patch_size and return them as a batchgenerators dict."""
        selected_keys = np.random.choice(self.list_of_keys, self.batch_size, True, None)

        data = np.zeros(self.data_shape, dtype=np.float32)
        seg = np.zeros(self.seg_shape, dtype=np.float32)

        case_properties = []
        for j, i in enumerate(selected_keys):
            if 'properties' in self._data[i].keys():
                properties = self._data[i]['properties']
            else:
                properties = load_pickle(self._data[i]['properties_file'])
            case_properties.append(properties)

            if self.get_do_oversample(j):
                force_fg = True
            else:
                force_fg = False

            if not isfile(self._data[i]['data_file'][:-4] + ".npy"):
                # lets hope you know what you're doing
                case_all_data = np.load(self._data[i]['data_file'][:-4] + ".npz")['data']
            else:
                case_all_data = np.load(self._data[i]['data_file'][:-4] + ".npy", self.memmap_mode)

            # this is for when there is just a 2d slice in case_all_data (2d support)
            if len(case_all_data.shape) == 3:
                case_all_data = case_all_data[:, None]

            # first select a slice. This can be either random (no force fg) or guaranteed to contain some class
            if not force_fg:
                random_slice = np.random.choice(case_all_data.shape[1])
                selected_class = None
            else:
                # these values should have been precomputed
                if 'class_locations' not in properties.keys():
                    raise RuntimeError("Please rerun the preprocessing with the newest version of nnU-Net!")

                # classes > 0 that actually have voxel coordinates recorded for this case
                foreground_classes = np.array(
                    [i for i in properties['class_locations'].keys() if len(properties['class_locations'][i]) != 0])
                foreground_classes = foreground_classes[foreground_classes > 0]
                if len(foreground_classes) == 0:
                    # fall back to a fully random slice when the case has no foreground at all
                    selected_class = None
                    random_slice = np.random.choice(case_all_data.shape[1])
                    print('case does not contain any foreground classes', i)
                else:
                    selected_class = np.random.choice(foreground_classes)

                    # only keep voxels of the chosen class that lie in the chosen slice (axis 0 = slice index)
                    voxels_of_that_class = properties['class_locations'][selected_class]
                    valid_slices = np.unique(voxels_of_that_class[:, 0])
                    random_slice = np.random.choice(valid_slices)
                    voxels_of_that_class = voxels_of_that_class[voxels_of_that_class[:, 0] == random_slice]
                    voxels_of_that_class = voxels_of_that_class[:, 1:]

            # now crop case_all_data to contain just the slice of interest. If we want additional slice above and
            # below the current slice, here is where we get them. We stack those as additional color channels
            if self.pseudo_3d_slices == 1:
                case_all_data = case_all_data[:, random_slice]
            else:
                # this is very deprecated and will probably not work anymore. If you intend to use this you need to
                # check this!
                mn = random_slice - (self.pseudo_3d_slices - 1) // 2
                mx = random_slice + (self.pseudo_3d_slices - 1) // 2 + 1
                valid_mn = max(mn, 0)
                valid_mx = min(mx, case_all_data.shape[1])
                case_all_seg = case_all_data[-1:]
                case_all_data = case_all_data[:-1]
                case_all_data = case_all_data[:, valid_mn:valid_mx]
                # the segmentation is taken only from the center slice
                case_all_seg = case_all_seg[:, random_slice]
                need_to_pad_below = valid_mn - mn
                need_to_pad_above = mx - valid_mx
                # zero-pad the slice stack where the requested neighbours run past the volume
                if need_to_pad_below > 0:
                    shp_for_pad = np.array(case_all_data.shape)
                    shp_for_pad[1] = need_to_pad_below
                    case_all_data = np.concatenate((np.zeros(shp_for_pad), case_all_data), 1)
                if need_to_pad_above > 0:
                    shp_for_pad = np.array(case_all_data.shape)
                    shp_for_pad[1] = need_to_pad_above
                    case_all_data = np.concatenate((case_all_data, np.zeros(shp_for_pad)), 1)
                # fold the slice dimension into the channel dimension, then re-append the seg channel
                case_all_data = case_all_data.reshape((-1, case_all_data.shape[-2], case_all_data.shape[-1]))
                case_all_data = np.concatenate((case_all_data, case_all_seg), 0)

            # case all data should now be (c, x, y)
            assert len(case_all_data.shape) == 3

            # we can now choose the bbox from -need_to_pad // 2 to shape - patch_size + need_to_pad // 2. Here we
            # define what the upper and lower bound can be to then sample from them with np.random.randint
            # NOTE(review): need_to_pad aliases self.need_to_pad, so the in-place update below mutates the
            # instance attribute as well — presumably intentional (cached for subsequent batches); confirm.
            need_to_pad = self.need_to_pad
            for d in range(2):
                # if case_all_data.shape + need_to_pad is still < patch size we need to pad more! We pad on both sides
                # always
                if need_to_pad[d] + case_all_data.shape[d + 1] < self.patch_size[d]:
                    need_to_pad[d] = self.patch_size[d] - case_all_data.shape[d + 1]

            shape = case_all_data.shape[1:]
            lb_x = - need_to_pad[0] // 2
            ub_x = shape[0] + need_to_pad[0] // 2 + need_to_pad[0] % 2 - self.patch_size[0]
            lb_y = - need_to_pad[1] // 2
            ub_y = shape[1] + need_to_pad[1] // 2 + need_to_pad[1] % 2 - self.patch_size[1]

            # if not force_fg then we can just sample the bbox randomly from lb and ub. Else we need to make sure we get
            # at least one of the foreground classes in the patch
            if not force_fg or selected_class is None:
                bbox_x_lb = np.random.randint(lb_x, ub_x + 1)
                bbox_y_lb = np.random.randint(lb_y, ub_y + 1)
            else:
                # this saves us a np.unique. Preprocessing already did that for all cases. Neat.
                selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))]
                # selected voxel is center voxel. Subtract half the patch size to get lower bbox voxel.
                # Make sure it is within the bounds of lb and ub
                bbox_x_lb = max(lb_x, selected_voxel[0] - self.patch_size[0] // 2)
                bbox_y_lb = max(lb_y, selected_voxel[1] - self.patch_size[1] // 2)

            bbox_x_ub = bbox_x_lb + self.patch_size[0]
            bbox_y_ub = bbox_y_lb + self.patch_size[1]

            # whoever wrote this knew what he was doing (hint: it was me). We first crop the data to the region of the
            # bbox that actually lies within the data. This will result in a smaller array which is then faster to pad.
            # valid_bbox is just the coord that lied within the data cube. It will be padded to match the patch size
            # later
            valid_bbox_x_lb = max(0, bbox_x_lb)
            valid_bbox_x_ub = min(shape[0], bbox_x_ub)
            valid_bbox_y_lb = max(0, bbox_y_lb)
            valid_bbox_y_ub = min(shape[1], bbox_y_ub)

            # At this point you might ask yourself why we would treat seg differently from seg_from_previous_stage.
            # Why not just concatenate them here and forget about the if statements? Well that's because seg needs to
            # be padded with -1 constant whereas seg_from_previous_stage needs to be padded with 0s (we could also
            # remove label -1 in the data augmentation but this way it is less error prone)
            case_all_data = case_all_data[:, valid_bbox_x_lb:valid_bbox_x_ub,
                            valid_bbox_y_lb:valid_bbox_y_ub]

            # image channels: pad with self.pad_mode; seg channel: pad with constant -1 (ignore label)
            case_all_data_donly = np.pad(case_all_data[:-1], ((0, 0),
                                                              (-min(0, bbox_x_lb), max(bbox_x_ub - shape[0], 0)),
                                                              (-min(0, bbox_y_lb), max(bbox_y_ub - shape[1], 0))),
                                         self.pad_mode, **self.pad_kwargs_data)

            case_all_data_segonly = np.pad(case_all_data[-1:], ((0, 0),
                                                                (-min(0, bbox_x_lb), max(bbox_x_ub - shape[0], 0)),
                                                                (-min(0, bbox_y_lb), max(bbox_y_ub - shape[1], 0))),
                                           'constant', **{'constant_values': -1})

            data[j] = case_all_data_donly
            seg[j] = case_all_data_segonly

        keys = selected_keys
        return {'data': data, 'seg': seg, 'properties': case_properties, "keys": keys}


if __name__ == "__main__":
    # quick smoke test of the data loaders on a local preprocessed dataset (developer-only paths)
    t = "Task002_Heart"
    p = join(preprocessing_output_dir, t, "stage1")
    dataset = load_dataset(p)
    with open(join(join(preprocessing_output_dir, t), "plans_stage1.pkl"), 'rb') as f:
        plans = pickle.load(f)
    unpack_dataset(p)
    dl = DataLoader3D(dataset, (32, 32, 32), (32, 32, 32), 2, oversample_foreground_percent=0.33)
    dl = DataLoader3D(dataset, np.array(plans['patch_size']).astype(int), np.array(plans['patch_size']).astype(int), 2,
                      oversample_foreground_percent=0.33)
    dl2d = DataLoader2D(dataset, (64, 64), np.array(plans['patch_size']).astype(int)[1:], 12,
                        oversample_foreground_percent=0.33)


================================================
FILE: unetr_pp/training/learning_rate/poly_lr.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    # Polynomial learning-rate decay: lr = initial_lr * (1 - epoch/max_epochs)**exponent.
    # Reaches 0 at epoch == max_epochs.
    return initial_lr * (1 - epoch / max_epochs)**exponent


================================================
FILE: unetr_pp/training/loss_functions/TopK_loss.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
from unetr_pp.training.loss_functions.crossentropy import RobustCrossEntropyLoss


class TopKLoss(RobustCrossEntropyLoss):
    """
    Cross-entropy averaged over only the k% hardest voxels (highest per-voxel loss).
    Network has to have NO LINEARITY!
    """
    def __init__(self, weight=None, ignore_index=-100, k=10):
        # k is the percentage of voxels to keep.
        self.k = k
        # NOTE(review): positional args map to nn.CrossEntropyLoss(weight, size_average, ignore_index);
        # size_average=False / reduce=False are deprecated — per-element losses are needed for topk below.
        # Verify against the installed torch version.
        super(TopKLoss, self).__init__(weight, False, ignore_index, reduce=False)

    def forward(self, inp, target):
        # target arrives as (b, 1, ...) float; strip the channel dim for CE
        target = target[:, 0].long()
        res = super(TopKLoss, self).forward(inp, target)
        num_voxels = np.prod(res.shape, dtype=np.int64)
        # keep only the k% largest per-voxel losses and average them
        # NOTE(review): int(num_voxels * k / 100) can be 0 for tiny inputs, making mean() NaN — confirm inputs are large enough
        res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False)
        return res.mean()


================================================
FILE: unetr_pp/training/loss_functions/__init__.py
================================================
from __future__ import absolute_import
# NOTE(review): `from . import *` in a package __init__ has no effect without __all__ — presumably a placeholder
from . import *


================================================
FILE: unetr_pp/training/loss_functions/crossentropy.py
================================================
from torch import nn, Tensor


class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
    """
    this is just a compatibility layer because my target tensor is float and has an extra dimension:
    accepts targets of shape (b, 1, ...) (or (b, ...)) and casts them to long before delegating to CrossEntropyLoss
    """
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        if len(target.shape) == len(input.shape):
            # target has a singleton channel dimension -> drop it
            assert target.shape[1] == 1
            target = target[:, 0]
        return super().forward(input, target.long())


================================================
FILE: unetr_pp/training/loss_functions/deep_supervision.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import nn


class MultipleOutputLoss2(nn.Module):
    """Weighted sum of a base loss over matching lists of outputs and targets.

    Intended for deep supervision: output i is compared against target i with
    the supplied base loss, and the per-output losses are combined with the
    given weight factors.
    """

    def __init__(self, loss, weight_factors=None):
        """
        use this if you have several outputs and ground truth (both list of same len) and the loss should be computed
        between them (x[0] and y[0], x[1] and y[1] etc)
        :param loss: base loss applied to each (x[i], y[i]) pair
        :param weight_factors: optional per-output weights; defaults to 1 for every output
        """
        super(MultipleOutputLoss2, self).__init__()
        self.weight_factors = weight_factors
        self.loss = loss

    def forward(self, x, y):
        assert isinstance(x, (tuple, list)), "x must be either tuple or list"
        assert isinstance(y, (tuple, list)), "y must be either tuple or list"
        weights = [1] * len(x) if self.weight_factors is None else self.weight_factors

        # the first pair always contributes (even with weight 0); subsequent
        # pairs are skipped entirely when their weight is zero, so the base
        # loss is never evaluated for them
        total = weights[0] * self.loss(x[0], y[0])
        for idx in range(1, len(x)):
            w = weights[idx]
            if w != 0:
                total = total + w * self.loss(x[idx], y[idx])
        return total


# ================================================
# FILE: unetr_pp/training/loss_functions/dice_loss.py
# ================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from unetr_pp.training.loss_functions.TopK_loss import TopKLoss
from unetr_pp.training.loss_functions.crossentropy import RobustCrossEntropyLoss
from unetr_pp.utilities.nd_softmax import softmax_helper
from unetr_pp.utilities.tensor_utilities import sum_tensor
from torch import nn
import numpy as np


class GDL(nn.Module):
    def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
                 square=False, square_volumes=False):
        """
        Generalized Dice Loss: per-class tp/fp/fn are weighted with 1/V (V = class volume).
        square_volumes will square the weight term. The paper recommends square_volumes=True; I don't (just an intuition)
        """
        super(GDL, self).__init__()
        self.square_volumes = square_volumes
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth

    def forward(self, x, y, loss_mask=None):
        shp_x = x.shape
        shp_y = y.shape

        # batch_dice treats the whole batch as one pseudo volume (sum over axis 0 too)
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if len(shp_x) != len(shp_y):
            # label map without channel dim -> add one
            y = y.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(x.shape, y.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = y
        else:
            # build the one-hot encoding on the fly (on GPU if x lives there)
            gt = y.long()
            y_onehot = torch.zeros(shp_x)
            if x.device.type == "cuda":
                y_onehot = y_onehot.cuda(x.device.index)
            y_onehot.scatter_(1, gt, 1)

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        if not self.do_bg:
            # drop the background channel from prediction and target
            x = x[:, 1:]
            y_onehot = y_onehot[:, 1:]

        tp, fp, fn, _ = get_tp_fp_fn_tn(x, y_onehot, axes, loss_mask, self.square)

        # GDL weight computation, we use 1/V
        volumes = sum_tensor(y_onehot, axes) + 1e-6  # add some eps to prevent div by zero
        if self.square_volumes:
            volumes = volumes ** 2

        # apply weights
        tp = tp / volumes
        fp = fp / volumes
        fn = fn / volumes

        # sum over classes
        if self.batch_dice:
            axis = 0
        else:
            axis = 1

        tp = tp.sum(axis, keepdim=False)
        fp = fp.sum(axis, keepdim=False)
        fn = fn.sum(axis, keepdim=False)

        # compute dice
        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)

        dc = dc.mean()

        # negated so that minimizing the loss maximizes dice
        return -dc


def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
    """
    Compute soft true/false positives/negatives per class, optionally reduced over the given axes.
    net_output must be (b, c, x, y(, z)))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z)))
    :param net_output:
    :param gt:
    :param axes: can be (, ) = no summation
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then fp, tp and fn will be squared before summation
    :return:
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            y_onehot = torch.zeros(shp_x)
            if net_output.device.type == "cuda":
                y_onehot = y_onehot.cuda(net_output.device.index)
            y_onehot.scatter_(1, gt, 1)

    # soft counts (probabilities, not hard decisions)
    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot
    tn = (1 - net_output) * (1 - y_onehot)

    if mask is not None:
        # zero out invalid pixels in every class channel
        tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
        fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
        fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
        tn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tn, dim=1)), dim=1)

    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2
        tn = tn ** 2

    if len(axes) > 0:
        # sum_tensor reduces over the given axes (project helper in unetr_pp.utilities.tensor_utilities)
        tp = sum_tensor(tp, axes, keepdim=False)
        fp = sum_tensor(fp, axes, keepdim=False)
        fn = sum_tensor(fn, axes, keepdim=False)
        tn = sum_tensor(tn, axes, keepdim=False)

    return tp, fp, fn, tn


class SoftDiceLoss(nn.Module):
    def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.):
        """
        Standard soft dice loss based on soft tp/fp/fn counts. Returns the negated mean dice.
        """
        super(SoftDiceLoss, self).__init__()

        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth

    def forward(self, x, y, loss_mask=None):
        shp_x = x.shape

        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        tp, fp, fn, _ = get_tp_fp_fn_tn(x, y, axes, loss_mask, False)

        nominator = 2 * tp + self.smooth
        denominator = 2 * tp + fp + fn + self.smooth

        dc = nominator / (denominator + 1e-8)

        if not self.do_bg:
            # background is channel 0; with batch_dice the class axis is axis 0 after reduction
            if self.batch_dice:
                dc = dc[1:]
            else:
                dc = dc[:, 1:]
        dc = dc.mean()

        return -dc


class MCCLoss(nn.Module):
    def __init__(self, apply_nonlin=None, batch_mcc=False, do_bg=True, smooth=0.0):
        """
        based on matthews correlation coefficient https://en.wikipedia.org/wiki/Matthews_correlation_coefficient

        Does not work. Really unstable. F this.
        """
        super(MCCLoss, self).__init__()

        self.smooth = smooth
        self.do_bg = do_bg
        self.batch_mcc = batch_mcc
        self.apply_nonlin = apply_nonlin

    def forward(self, x, y, loss_mask=None):
        shp_x = x.shape
        # normalize the counts by the number of voxels to keep the products below in a sane range
        voxels = np.prod(shp_x[2:])

        if self.batch_mcc:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        tp, fp, fn, tn = get_tp_fp_fn_tn(x, y, axes, loss_mask, False)
        tp /= voxels
        fp /= voxels
        fn /= voxels
        tn /= voxels

        nominator = tp * tn - fp * fn + self.smooth
        denominator = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5 + self.smooth

        mcc = nominator / denominator

        if not self.do_bg:
            if self.batch_mcc:
                mcc = mcc[1:]
            else:
                mcc = mcc[:, 1:]
        mcc = mcc.mean()

        return -mcc


class SoftDiceLossSquared(nn.Module):
    def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.):
        """
        squares the terms in the denominator as proposed by Milletari et al.
""" super(SoftDiceLossSquared, self).__init__() self.do_bg = do_bg self.batch_dice = batch_dice self.apply_nonlin = apply_nonlin self.smooth = smooth def forward(self, x, y, loss_mask=None): shp_x = x.shape shp_y = y.shape if self.batch_dice: axes = [0] + list(range(2, len(shp_x))) else: axes = list(range(2, len(shp_x))) if self.apply_nonlin is not None: x = self.apply_nonlin(x) with torch.no_grad(): if len(shp_x) != len(shp_y): y = y.view((shp_y[0], 1, *shp_y[1:])) if all([i == j for i, j in zip(x.shape, y.shape)]): # if this is the case then gt is probably already a one hot encoding y_onehot = y else: y = y.long() y_onehot = torch.zeros(shp_x) if x.device.type == "cuda": y_onehot = y_onehot.cuda(x.device.index) y_onehot.scatter_(1, y, 1).float() intersect = x * y_onehot # values in the denominator get smoothed denominator = x ** 2 + y_onehot ** 2 # aggregation was previously done in get_tp_fp_fn, but needs to be done here now (needs to be done after # squaring) intersect = sum_tensor(intersect, axes, False) + self.smooth denominator = sum_tensor(denominator, axes, False) + self.smooth dc = 2 * intersect / denominator if not self.do_bg: if self.batch_dice: dc = dc[1:] else: dc = dc[:, 1:] dc = dc.mean() return -dc class DC_and_CE_loss(nn.Module): def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum", square_dice=False, weight_ce=1, weight_dice=1, log_dice=False, ignore_label=None): """ CAREFUL. Weights for CE and Dice do not need to sum to one. You can set whatever you want. 
:param soft_dice_kwargs: :param ce_kwargs: :param aggregate: :param square_dice: :param weight_ce: :param weight_dice: """ super(DC_and_CE_loss, self).__init__() if ignore_label is not None: assert not square_dice, 'not implemented' ce_kwargs['reduction'] = 'none' self.log_dice = log_dice self.weight_dice = weight_dice self.weight_ce = weight_ce self.aggregate = aggregate self.ce = RobustCrossEntropyLoss(**ce_kwargs) self.ignore_label = ignore_label if not square_dice: self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs) else: self.dc = SoftDiceLossSquared(apply_nonlin=softmax_helper, **soft_dice_kwargs) def forward(self, net_output, target): """ target must be b, c, x, y(, z) with c=1 :param net_output: :param target: :return: """ if self.ignore_label is not None: assert target.shape[1] == 1, 'not implemented for one hot encoding' mask = target != self.ignore_label target[~mask] = 0 mask = mask.float() else: mask = None dc_loss = self.dc(net_output, target, loss_mask=mask) if self.weight_dice != 0 else 0 if self.log_dice: dc_loss = -torch.log(-dc_loss) ce_loss = self.ce(net_output, target[:, 0].long()) if self.weight_ce != 0 else 0 if self.ignore_label is not None: ce_loss *= mask[:, 0] ce_loss = ce_loss.sum() / mask.sum() if self.aggregate == "sum": result = self.weight_ce * ce_loss + self.weight_dice * dc_loss else: raise NotImplementedError("nah son") # reserved for other stuff (later) return result class DC_and_BCE_loss(nn.Module): def __init__(self, bce_kwargs, soft_dice_kwargs, aggregate="sum"): """ DO NOT APPLY NONLINEARITY IN YOUR NETWORK! 
THIS LOSS IS INTENDED TO BE USED FOR BRATS REGIONS ONLY :param soft_dice_kwargs: :param bce_kwargs: :param aggregate: """ super(DC_and_BCE_loss, self).__init__() self.aggregate = aggregate self.ce = nn.BCEWithLogitsLoss(**bce_kwargs) self.dc = SoftDiceLoss(apply_nonlin=torch.sigmoid, **soft_dice_kwargs) def forward(self, net_output, target): ce_loss = self.ce(net_output, target) dc_loss = self.dc(net_output, target) if self.aggregate == "sum": result = ce_loss + dc_loss else: raise NotImplementedError("nah son") # reserved for other stuff (later) return result class GDL_and_CE_loss(nn.Module): def __init__(self, gdl_dice_kwargs, ce_kwargs, aggregate="sum"): super(GDL_and_CE_loss, self).__init__() self.aggregate = aggregate self.ce = RobustCrossEntropyLoss(**ce_kwargs) self.dc = GDL(softmax_helper, **gdl_dice_kwargs) def forward(self, net_output, target): dc_loss = self.dc(net_output, target) ce_loss = self.ce(net_output, target) if self.aggregate == "sum": result = ce_loss + dc_loss else: raise NotImplementedError("nah son") # reserved for other stuff (later) return result class DC_and_topk_loss(nn.Module): def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum", square_dice=False): super(DC_and_topk_loss, self).__init__() self.aggregate = aggregate self.ce = TopKLoss(**ce_kwargs) if not square_dice: self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs) else: self.dc = SoftDiceLossSquared(apply_nonlin=softmax_helper, **soft_dice_kwargs) def forward(self, net_output, target): dc_loss = self.dc(net_output, target) ce_loss = self.ce(net_output, target) if self.aggregate == "sum": result = ce_loss + dc_loss else: raise NotImplementedError("nah son") # reserved for other stuff (later?) 
        return result


================================================
FILE: unetr_pp/training/model_restore.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unetr_pp
import torch
from batchgenerators.utilities.file_and_folder_operations import *
import importlib
import pkgutil
from unetr_pp.paths import network_training_output_dir, preprocessing_output_dir, default_plans_identifier


def recursive_find_python_class(folder, trainer_name, current_module):
    # Depth-first search through the package tree rooted at `folder` for a module
    # that defines an attribute named `trainer_name`; returns the class or None.
    tr = None
    for importer, modname, ispkg in pkgutil.iter_modules(folder):
        # print(modname, ispkg)
        if not ispkg:
            m = importlib.import_module(current_module + "." + modname)
            if hasattr(m, trainer_name):
                tr = getattr(m, trainer_name)
                break

    if tr is None:
        # not found in this package's modules -> recurse into subpackages
        for importer, modname, ispkg in pkgutil.iter_modules(folder):
            if ispkg:
                next_current_module = current_module + "." + modname
                tr = recursive_find_python_class([join(folder[0], modname)], trainer_name, current_module=next_current_module)
            if tr is not None:
                break

    return tr


def restore_model(pkl_file, checkpoint=None, train=False, fp16=None, folder=None):
    """
    This is a utility function to load any nnFormer trainer from a pkl. It will recursively search
    unetr_pp.trainig.network_training for the file that contains the trainer and instantiate it with the arguments saved in the pkl file. If checkpoint
    is specified, it will furthermore load the checkpoint file in train/test mode (as specified by train).
    The pkl file required here is the one that will be saved automatically when calling nnFormerTrainer.save_checkpoint.
    :param pkl_file:
    :param checkpoint:
    :param train:
    :param fp16: if None then we take no action. If True/False we overwrite what the model has in its init
    :return:
    """
    info = load_pickle(pkl_file)
    init = info['init']
    name = info['name']
    '''
    update on 2022.6.23.
    For the model_best.model can be shared in different machines.
    '''
    # Rebuild the plans path from the local preprocessing_output_dir so a checkpoint
    # created on another machine still resolves.
    # NOTE(review): assumes `folder` is .../<network>/<task>/<trainer_dir> with '/' separators —
    # will break on Windows paths or when folder is None; confirm callers always pass it.
    task = folder.split('/')[-2]
    network = folder.split('/')[-3]
    if network == '2d':
        plans_file = join(preprocessing_output_dir, task, default_plans_identifier + "_plans_2D.pkl")
    else:
        plans_file = join(preprocessing_output_dir, task, default_plans_identifier + "_plans_3D.pkl")
    # first init arg is the plans file path -> overwrite with the locally resolved one
    info['init'] = list(info['init'])
    info['init'][0] = plans_file
    info['init'] = tuple(info['init'])

    if 'nnUNet' in name:
        # compatibility with checkpoints written by nnUNet trainers
        name = name.replace('nnUNet', 'nnFormer')
        if len(init) > 10:
            # older checkpoints carry extra init args that the nnFormer trainers do not take
            init = list(init)
            del init[2]
            del init[-2]
    search_in = join(unetr_pp.__path__[0], "training", "network_training")
    tr = recursive_find_python_class([search_in], name, current_module="unetr_pp.training.network_training")

    if tr is None:
        """
        Fabian only. This will trigger searching for trainer classes in other repositories as well
        """
        try:
            import meddec
            search_in = join(meddec.__path__[0], "model_training")
            tr = recursive_find_python_class([search_in], name, current_module="meddec.model_training")
        except ImportError:
            pass

    if tr is None:
        raise RuntimeError("Could not find the model trainer specified in checkpoint in unetr_pp.trainig.network_training. If it "
                           "is not located there, please move it or change the code of restore_model. Your model "
                           "trainer can be located in any directory within unetr_pp.trainig.network_training (search is recursive)."
                           "\nDebug info: \ncheckpoint file: %s\nName of trainer: %s " % (checkpoint, name))

    trainer = tr(*init)

    # We can hack fp16 overwriting into the trainer without changing the init arguments because nothing happens with
    # fp16 in the init, it just saves it to a member variable
    if fp16 is not None:
        trainer.fp16 = fp16

    trainer.process_plans(info['plans'])
    if checkpoint is not None:
        trainer.load_checkpoint(checkpoint, train)
    return trainer


def load_best_model_for_inference(folder):
    # Convenience wrapper: restore the trainer from model_best.model in `folder` (test mode).
    checkpoint = join(folder, "model_best.model")
    pkl_file = checkpoint + ".pkl"
    return restore_model(pkl_file, checkpoint, False)


def load_model_and_checkpoint_files(folder, folds=None, mixed_precision=None, checkpoint_name="model_best"):
    """
    used for if you need to ensemble the five models of a cross-validation. This will restore the model from the
    checkpoint in fold 0, load all parameters of the five folds in ram and return both. This will allow for fast
    switching between parameters (as opposed to loading them form disk each time). This is best used for inference and
    test prediction
    :param folder:
    :param folds:
    :param mixed_precision: if None then we take no action. If True/False we overwrite what the model has in its init
    :return:
    """
    # normalize `folds` into a list of existing fold output directories
    if isinstance(folds, str):
        folds = [join(folder, "all")]
        assert isdir(folds[0]), "no output folder for fold %s found" % folds
    elif isinstance(folds, (list, tuple)):
        if len(folds) == 1 and folds[0] == "all":
            folds = [join(folder, "all")]
        else:
            folds = [join(folder, "fold_%d" % i) for i in folds]
        assert all([isdir(i) for i in folds]), "list of folds specified but not all output folders are present"
    elif isinstance(folds, int):
        folds = [join(folder, "fold_%d" % folds)]
        assert all([isdir(i) for i in folds]), "output folder missing for fold %d" % folds
    elif folds is None:
        print("folds is None so we will automatically look for output folders (not using \'all\'!)")
        folds = subfolders(folder, prefix="fold")
        print("found the following folds: ", folds)
    else:
        raise ValueError("Unknown value for folds. Type: %s. Expected: list of int, int, str or None", str(type(folds)))

    # one trainer instance (from fold 0) + all fold parameter dicts kept in RAM for fast switching
    trainer = restore_model(join(folds[0], "%s.model.pkl" % checkpoint_name), fp16=mixed_precision, folder=folder)
    trainer.output_folder = folder
    trainer.output_folder_base = folder
    trainer.update_fold(0)
    trainer.initialize(False)
    all_best_model_files = [join(i, "%s.model" % checkpoint_name) for i in folds]
    print("using the following model files: ", all_best_model_files)
    all_params = [torch.load(i, map_location=torch.device('cpu')) for i in all_best_model_files]
    return trainer, all_params


if __name__ == "__main__":
    # developer-only smoke test with hard-coded local paths
    pkl = "/home/fabian/PhD/results/nnFormerV2/nnFormerV2_3D_fullres/Task004_Hippocampus/fold0/model_best.model.pkl"
    checkpoint = pkl[:-4]
    train = False
    trainer = restore_model(pkl, checkpoint, train)


================================================
FILE: unetr_pp/training/network_training/Trainer_acdc.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you
# may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import shutil
from collections import OrderedDict
from multiprocessing import Pool
from time import sleep
from typing import Tuple, List

import matplotlib
import unetr_pp
import numpy as np
import torch
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.configuration import default_num_threads
from unetr_pp.evaluation.evaluator import aggregate_scores
from unetr_pp.inference.segmentation_export import save_segmentation_nifti_from_softmax
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.network_architecture.initialization import InitWeights_He
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.postprocessing.connected_components import determine_postprocessing
from unetr_pp.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \
    default_2D_augmentation_params, get_default_augmentation, get_patch_size
from unetr_pp.training.dataloading.dataset_loading import load_dataset, DataLoader3D, DataLoader2D, unpack_dataset
from unetr_pp.training.loss_functions.dice_loss import DC_and_CE_loss
from unetr_pp.training.network_training.network_trainer_acdc import NetworkTrainer_acdc
from unetr_pp.utilities.nd_softmax import softmax_helper
from unetr_pp.utilities.tensor_utilities import sum_tensor
from torch import nn
from torch.optim import lr_scheduler

# headless backend: training runs on servers without a display
matplotlib.use("agg")


class Trainer_acdc(NetworkTrainer_acdc):
    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, fp16=False):
        """
        :param deterministic:
        :param fold: can be either [0 ... 5) for cross-validation, 'all' to train on all available training data or
        None if you wish to load some checkpoint and do inference only
        :param plans_file: the pkl file generated by preprocessing. This file will determine all design choices
        :param subfolder_with_preprocessed_data: must be a subfolder of dataset_directory (just the name of the folder,
        not the entire path). This is where the preprocessed data lies that will be used for network training. We made
        this explicitly available so that differently preprocessed data can coexist and the user can choose what to use.
        Can be None if you are doing inference only.
        :param output_folder: where to store parameters, plot progress and to the validation
        :param dataset_directory: the parent directory in which the preprocessed Task data is stored. This is required
        because the split information is stored in this directory. For running prediction only this input is not
        required and may be set to None
        :param batch_dice: compute dice loss for each sample and average over all samples in the batch or pretend the
        batch is a pseudo volume?
        :param stage: The plans file may contain several stages (used for lowres / highres / pyramid). Stage must be
        specified for training:
        if stage 1 exists then stage 1 is the high resolution stage, otherwise it's 0
        :param unpack_data: if False, npz preprocessed data will not be unpacked to npy. This consumes less space but
        is considerably slower! Running unpack_data=False with 2d should never be done!

        IMPORTANT: If you inherit from nnFormerTrainer and the init args change then you need to redefine self.init_args
        in your init accordingly. Otherwise checkpoints won't load properly!
        """
        super(Trainer_acdc, self).__init__(deterministic, fp16)
        self.unpack_data = unpack_data
        # init_args is pickled next to checkpoints so the trainer can be re-instantiated by restore_model
        self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                          deterministic, fp16)
        # set through arguments from init
        self.stage = stage
        self.experiment_name = self.__class__.__name__
        self.plans_file = plans_file
        self.output_folder = output_folder
        self.dataset_directory = dataset_directory
        self.output_folder_base = self.output_folder
        self.fold = fold

        self.plans = None

        # if we are running inference only then the self.dataset_directory is set (due to checkpoint loading) but it
        # irrelevant
        if self.dataset_directory is not None and isdir(self.dataset_directory):
            self.gt_niftis_folder = join(self.dataset_directory, "gt_segmentations")
        else:
            self.gt_niftis_folder = None

        self.folder_with_preprocessed_data = None

        # set in self.initialize()

        self.dl_tr = self.dl_val = None
        self.num_input_channels = self.num_classes = self.net_pool_per_axis = self.patch_size = self.batch_size = \
            self.threeD = self.base_num_features = self.intensity_properties = self.normalization_schemes = \
            self.net_num_pool_op_kernel_sizes = self.net_conv_kernel_sizes = None  # loaded automatically from plans_file
        self.basic_generator_patch_size = self.data_aug_params = self.transpose_forward = self.transpose_backward = None

        self.batch_dice = batch_dice
        self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {})

        # running online-evaluation accumulators (per-batch foreground dice / tp / fp / fn)
        self.online_eval_foreground_dc = []
        self.online_eval_tp = []
        self.online_eval_fp = []
        self.online_eval_fn = []

        self.classes = self.do_dummy_2D_aug = self.use_mask_for_norm = self.only_keep_largest_connected_component = \
            self.min_region_size_per_class = self.min_size_per_class = None

        self.inference_pad_border_mode = "constant"
        self.inference_pad_kwargs = {'constant_values': 0}

        self.update_fold(fold)
        self.pad_all_sides = None

        self.lr_scheduler_eps = 1e-3
        self.lr_scheduler_patience = 30
        self.initial_lr = 3e-4
        self.weight_decay = 3e-5

        self.oversample_foreground_percent = 0.33

        self.conv_per_stage = None
        self.regions_class_order = None

    def update_fold(self, fold):
        """
        used to swap between folds for inference (ensemble of models from cross-validation)
        DO NOT USE DURING TRAINING AS THIS WILL NOT UPDATE THE DATASET SPLIT AND THE DATA AUGMENTATION GENERATORS
        :param fold:
        :return:
        """
        if fold is not None:
            if isinstance(fold, str):
                assert fold == "all", "if self.fold is a string then it must be \'all\'"
                # strip a previously appended fold suffix before appending the new one
                if self.output_folder.endswith("%s" % str(self.fold)):
                    self.output_folder = self.output_folder_base
                self.output_folder = join(self.output_folder, "%s" % str(fold))
            else:
                if self.output_folder.endswith("fold_%s" % str(self.fold)):
                    self.output_folder = self.output_folder_base
                self.output_folder = join(self.output_folder, "fold_%s" % str(fold))
            self.fold = fold

    def setup_DA_params(self):
        # Select data-augmentation defaults depending on 2D/3D; anisotropic 2D patches
        # additionally get their in-plane rotation range restricted below.
        if self.threeD:
            self.data_aug_params = default_3D_augmentation_params
            if self.do_dummy_2D_aug:
                self.data_aug_params["dummy_2D"] = True
                self.print_to_log_file("Using dummy2d data augmentation")
                self.data_aug_params["elastic_deform_alpha"] = \
                    default_2D_augmentation_params["elastic_deform_alpha"]
                self.data_aug_params["elastic_deform_sigma"] = \
                    default_2D_augmentation_params["elastic_deform_sigma"]
                self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
        else:
            self.do_dummy_2D_aug = False
            if max(self.patch_size) / min(self.patch_size) > 1.5:
                default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2.
* np.pi) self.data_aug_params = default_2D_augmentation_params self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm if self.do_dummy_2D_aug: self.basic_generator_patch_size = get_patch_size(self.patch_size[1:], self.data_aug_params['rotation_x'], self.data_aug_params['rotation_y'], self.data_aug_params['rotation_z'], self.data_aug_params['scale_range']) self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size)) patch_size_for_spatialtransform = self.patch_size[1:] else: self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'], self.data_aug_params['rotation_y'], self.data_aug_params['rotation_z'], self.data_aug_params['scale_range']) patch_size_for_spatialtransform = self.patch_size self.data_aug_params['selected_seg_channels'] = [0] self.data_aug_params['patch_size_for_spatialtransform'] = patch_size_for_spatialtransform def initialize(self, training=True, force_load_plans=False): """ For prediction of test cases just set training=False, this will prevent loading of training data and training batchgenerator initialization :param training: :return: """ maybe_mkdir_p(self.output_folder) if force_load_plans or (self.plans is None): self.load_plans_file() self.process_plans(self.plans) self.setup_DA_params() if training: self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] + "_stage%d" % self.stage) self.dl_tr, self.dl_val = self.get_basic_generators() if self.unpack_data: self.print_to_log_file("unpacking dataset") unpack_dataset(self.folder_with_preprocessed_data) self.print_to_log_file("done") else: self.print_to_log_file( "INFO: Not unpacking data! Training may be slow due to that. 
Pray you are not using 2d or you " "will wait all winter for your model to finish!") self.tr_gen, self.val_gen = get_default_augmentation(self.dl_tr, self.dl_val, self.data_aug_params[ 'patch_size_for_spatialtransform'], self.data_aug_params) self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())), also_print_to_console=False) self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())), also_print_to_console=False) else: pass self.initialize_network() self.initialize_optimizer_and_scheduler() # assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel)) self.was_initialized = True def initialize_network(self): """ This is specific to the U-Net and must be adapted for other network architectures :return: """ # self.print_to_log_file(self.net_num_pool_op_kernel_sizes) # self.print_to_log_file(self.net_conv_kernel_sizes) net_numpool = len(self.net_num_pool_op_kernel_sizes) if self.threeD: conv_op = nn.Conv3d dropout_op = nn.Dropout3d norm_op = nn.InstanceNorm3d else: conv_op = nn.Conv2d dropout_op = nn.Dropout2d norm_op = nn.InstanceNorm2d norm_op_kwargs = {'eps': 1e-5, 'affine': True} dropout_op_kwargs = {'p': 0, 'inplace': True} net_nonlin = nn.LeakyReLU net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes, net_numpool, self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, net_nonlin, net_nonlin_kwargs, False, False, lambda x: x, InitWeights_He(1e-2), self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True) self.network.inference_apply_nonlin = softmax_helper if torch.cuda.is_available(): self.network.cuda() def initialize_optimizer_and_scheduler(self): assert self.network is not None, "self.initialize_network must be called first" self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, amsgrad=True) 
self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2, patience=self.lr_scheduler_patience, verbose=True, threshold=self.lr_scheduler_eps, threshold_mode="abs") def plot_network_architecture(self): try: from batchgenerators.utilities.file_and_folder_operations import join import hiddenlayer as hl if torch.cuda.is_available(): g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)).cuda(), transforms=None) else: g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)), transforms=None) g.save(join(self.output_folder, "network_architecture.pdf")) del g except Exception as e: self.print_to_log_file("Unable to plot network architecture:") self.print_to_log_file(e) self.print_to_log_file("\nprinting the network instead:\n") self.print_to_log_file(self.network) self.print_to_log_file("\n") finally: if torch.cuda.is_available(): torch.cuda.empty_cache() def save_debug_information(self): # saving some debug information dct = OrderedDict() for k in self.__dir__(): if not k.startswith("__"): if not callable(getattr(self, k)): dct[k] = str(getattr(self, k)) del dct['plans'] del dct['intensity_properties'] del dct['dataset'] del dct['dataset_tr'] del dct['dataset_val'] save_json(dct, join(self.output_folder, "debug.json")) import shutil shutil.copy(self.plans_file, join(self.output_folder_base, "plans.pkl")) def run_training(self): self.save_debug_information() super(Trainer_acdc, self).run_training() def load_plans_file(self): """ This is what actually configures the entire experiment. The plans file is generated by experiment planning :return: """ self.plans = load_pickle(self.plans_file) def process_plans(self, plans): if self.stage is None: assert len(list(plans['plans_per_stage'].keys())) == 1, \ "If self.stage is None then there can be only one stage in the plans file. That seems to not be the " \ "case. 
Please specify which stage of the cascade must be trained" self.stage = list(plans['plans_per_stage'].keys())[0] self.plans = plans stage_plans = self.plans['plans_per_stage'][self.stage] self.batch_size = stage_plans['batch_size'] self.net_pool_per_axis = stage_plans['num_pool_per_axis'] self.patch_size = np.array(stage_plans['patch_size']).astype(int) self.do_dummy_2D_aug = stage_plans['do_dummy_2D_data_aug'] if 'pool_op_kernel_sizes' not in stage_plans.keys(): assert 'num_pool_per_axis' in stage_plans.keys() self.print_to_log_file("WARNING! old plans file with missing pool_op_kernel_sizes. Attempting to fix it...") self.net_num_pool_op_kernel_sizes = [] for i in range(max(self.net_pool_per_axis)): curr = [] for j in self.net_pool_per_axis: if (max(self.net_pool_per_axis) - j) <= i: curr.append(2) else: curr.append(1) self.net_num_pool_op_kernel_sizes.append(curr) else: self.net_num_pool_op_kernel_sizes = stage_plans['pool_op_kernel_sizes'] if 'conv_kernel_sizes' not in stage_plans.keys(): self.print_to_log_file("WARNING! old plans file with missing conv_kernel_sizes. Attempting to fix it...") self.net_conv_kernel_sizes = [[3] * len(self.net_pool_per_axis)] * (max(self.net_pool_per_axis) + 1) else: self.net_conv_kernel_sizes = stage_plans['conv_kernel_sizes'] self.pad_all_sides = None # self.patch_size self.intensity_properties = plans['dataset_properties']['intensityproperties'] self.normalization_schemes = plans['normalization_schemes'] self.base_num_features = plans['base_num_features'] self.num_input_channels = plans['num_modalities'] self.num_classes = plans['num_classes'] + 1 # background is no longer in num_classes self.classes = plans['all_classes'] self.use_mask_for_norm = plans['use_mask_for_norm'] self.only_keep_largest_connected_component = plans['keep_only_largest_region'] self.min_region_size_per_class = plans['min_region_size_per_class'] self.min_size_per_class = None # DONT USE THIS. 
plans['min_size_per_class'] if plans.get('transpose_forward') is None or plans.get('transpose_backward') is None: print("WARNING! You seem to have data that was preprocessed with a previous version of nnU-Net. " "You should rerun preprocessing. We will proceed and assume that both transpose_foward " "and transpose_backward are [0, 1, 2]. If that is not correct then weird things will happen!") plans['transpose_forward'] = [0, 1, 2] plans['transpose_backward'] = [0, 1, 2] self.transpose_forward = plans['transpose_forward'] self.transpose_backward = plans['transpose_backward'] if len(self.patch_size) == 2: self.threeD = False elif len(self.patch_size) == 3: self.threeD = True else: raise RuntimeError("invalid patch size in plans file: %s" % str(self.patch_size)) if "conv_per_stage" in plans.keys(): # this ha sbeen added to the plans only recently self.conv_per_stage = plans['conv_per_stage'] else: self.conv_per_stage = 2 def load_dataset(self): self.dataset = load_dataset(self.folder_with_preprocessed_data) def get_basic_generators(self): self.load_dataset() self.do_split() if self.threeD: dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size, False, oversample_foreground_percent=self.oversample_foreground_percent, pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r') dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, False, oversample_foreground_percent=self.oversample_foreground_percent, pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r') else: dl_tr = DataLoader2D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size, oversample_foreground_percent=self.oversample_foreground_percent, pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r') dl_val = DataLoader2D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, oversample_foreground_percent=self.oversample_foreground_percent, 
pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r') return dl_tr, dl_val def preprocess_patient(self, input_files): """ Used to predict new unseen data. Not used for the preprocessing of the training/test data :param input_files: :return: """ from unetr_pp.training.model_restore import recursive_find_python_class preprocessor_name = self.plans.get('preprocessor_name') if preprocessor_name is None: if self.threeD: preprocessor_name = "GenericPreprocessor" else: preprocessor_name = "PreprocessorFor2D" print("using preprocessor", preprocessor_name) preprocessor_class = recursive_find_python_class([join(unetr_pp.__path__[0], "preprocessing")], preprocessor_name, current_module="unetr_pp.preprocessing") assert preprocessor_class is not None, "Could not find preprocessor %s in unetr_pp.preprocessing" % \ preprocessor_name preprocessor = preprocessor_class(self.normalization_schemes, self.use_mask_for_norm, self.transpose_forward, self.intensity_properties) d, s, properties = preprocessor.preprocess_test_case(input_files, self.plans['plans_per_stage'][self.stage][ 'current_spacing']) return d, s, properties def preprocess_predict_nifti(self, input_files: List[str], output_file: str = None, softmax_ouput_file: str = None, mixed_precision: bool = True) -> None: """ Use this to predict new data :param input_files: :param output_file: :param softmax_ouput_file: :param mixed_precision: :return: """ print("preprocessing...") d, s, properties = self.preprocess_patient(input_files) print("predicting...") pred = self.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=self.data_aug_params["do_mirror"], mirror_axes=self.data_aug_params['mirror_axes'], use_sliding_window=True, step_size=0.5, use_gaussian=True, pad_border_mode='constant', pad_kwargs={'constant_values': 0}, verbose=True, all_in_gpu=False, mixed_precision=mixed_precision)[1] pred = pred.transpose([0] + [i + 1 for i in self.transpose_backward]) if 'segmentation_export_params' in 
self.plans.keys(): force_separate_z = self.plans['segmentation_export_params']['force_separate_z'] interpolation_order = self.plans['segmentation_export_params']['interpolation_order'] interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z'] else: force_separate_z = None interpolation_order = 1 interpolation_order_z = 0 print("resampling to original spacing and nifti export...") save_segmentation_nifti_from_softmax(pred, output_file, properties, interpolation_order, self.regions_class_order, None, None, softmax_ouput_file, None, force_separate_z=force_separate_z, interpolation_order_z=interpolation_order_z) print("done") def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True, mirror_axes: Tuple[int] = None, use_sliding_window: bool = True, step_size: float = 0.5, use_gaussian: bool = True, pad_border_mode: str = 'constant', pad_kwargs: dict = None, all_in_gpu: bool = False, verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]: """ :param data: :param do_mirroring: :param mirror_axes: :param use_sliding_window: :param step_size: :param use_gaussian: :param pad_border_mode: :param pad_kwargs: :param all_in_gpu: :param verbose: :return: """ if pad_border_mode == 'constant' and pad_kwargs is None: pad_kwargs = {'constant_values': 0} if do_mirroring and mirror_axes is None: mirror_axes = self.data_aug_params['mirror_axes'] if do_mirroring: assert self.data_aug_params["do_mirror"], "Cannot do mirroring as test time augmentation when training " \ "was done without mirroring" valid = list((SegmentationNetwork, nn.DataParallel)) assert isinstance(self.network, tuple(valid)) current_mode = self.network.training self.network.eval() ret = self.network.predict_3D(data, do_mirroring=do_mirroring, mirror_axes=mirror_axes, use_sliding_window=use_sliding_window, step_size=step_size, patch_size=self.patch_size, regions_class_order=self.regions_class_order, 
use_gaussian=use_gaussian, pad_border_mode=pad_border_mode, pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, verbose=verbose, mixed_precision=mixed_precision) self.network.train(current_mode) return ret def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True, validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False, segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True): """ if debug=True then the temporary files generated for postprocessing determination will be kept """ current_mode = self.network.training self.network.eval() assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)" if self.dataset_val is None: self.load_dataset() self.do_split() if segmentation_export_kwargs is None: if 'segmentation_export_params' in self.plans.keys(): force_separate_z = self.plans['segmentation_export_params']['force_separate_z'] interpolation_order = self.plans['segmentation_export_params']['interpolation_order'] interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z'] else: force_separate_z = None interpolation_order = 1 interpolation_order_z = 0 else: force_separate_z = segmentation_export_kwargs['force_separate_z'] interpolation_order = segmentation_export_kwargs['interpolation_order'] interpolation_order_z = segmentation_export_kwargs['interpolation_order_z'] # predictions as they come from the network go here output_folder = join(self.output_folder, validation_folder_name) maybe_mkdir_p(output_folder) # this is for debug purposes my_input_args = {'do_mirroring': do_mirroring, 'use_sliding_window': use_sliding_window, 'step_size': step_size, 'save_softmax': save_softmax, 'use_gaussian': use_gaussian, 'overwrite': overwrite, 'validation_folder_name': validation_folder_name, 'debug': debug, 'all_in_gpu': all_in_gpu, 
'segmentation_export_kwargs': segmentation_export_kwargs, } save_json(my_input_args, join(output_folder, "validation_args.json")) if do_mirroring: if not self.data_aug_params['do_mirror']: raise RuntimeError("We did not train with mirroring so you cannot do inference with mirroring enabled") mirror_axes = self.data_aug_params['mirror_axes'] else: mirror_axes = () pred_gt_tuples = [] export_pool = Pool(default_num_threads) results = [] for k in self.dataset_val.keys(): properties = load_pickle(self.dataset[k]['properties_file']) fname = properties['list_of_data_files'][0].split("/")[-1][:-12] if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \ (save_softmax and not isfile(join(output_folder, fname + ".npz"))): data = np.load(self.dataset[k]['data_file'])['data'] print(k, data.shape) data[-1][data[-1] == -1] = 0 softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data[:-1], do_mirroring=do_mirroring, mirror_axes=mirror_axes, use_sliding_window=use_sliding_window, step_size=step_size, use_gaussian=use_gaussian, all_in_gpu=all_in_gpu, mixed_precision=self.fp16)[1] softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in self.transpose_backward]) if save_softmax: softmax_fname = join(output_folder, fname + ".npz") else: softmax_fname = None """There is a problem with python process communication that prevents us from communicating obejcts larger than 2 GB between processes (basically when the length of the pickle string that will be sent is communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will then be read (and finally deleted) by the Process. 
save_segmentation_nifti_from_softmax can take either filename or np.ndarray and will handle this automatically""" if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85): # *0.85 just to be save np.save(join(output_folder, fname + ".npy"), softmax_pred) softmax_pred = join(output_folder, fname + ".npy") results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax, ((softmax_pred, join(output_folder, fname + ".nii.gz"), properties, interpolation_order, self.regions_class_order, None, None, softmax_fname, None, force_separate_z, interpolation_order_z), ) ) ) pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"), join(self.gt_niftis_folder, fname + ".nii.gz")]) _ = [i.get() for i in results] self.print_to_log_file("finished prediction") # evaluate raw predictions self.print_to_log_file("evaluation of raw predictions") task = self.dataset_directory.split("/")[-1] job_name = self.experiment_name _ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)), json_output_file=join(output_folder, "summary.json"), json_name=job_name + " val tiled %s" % (str(use_sliding_window)), json_author="Fabian", json_task=task, num_threads=default_num_threads) if run_postprocessing_on_folds: # in the old unetr_pp we would stop here. Now we add a postprocessing. This postprocessing can remove everything # except the largest connected component for each class. To see if this improves results, we do this for all # classes and then rerun the evaluation. 
Those classes for which this resulted in an improved dice score will # have this applied during inference as well self.print_to_log_file("determining postprocessing") determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name, final_subf_name=validation_folder_name + "_postprocessed", debug=debug) # after this the final predictions for the vlaidation set can be found in validation_folder_name_base + "_postprocessed" # They are always in that folder, even if no postprocessing as applied! # detemining postprocesing on a per-fold basis may be OK for this fold but what if another fold finds another # postprocesing to be better? In this case we need to consolidate. At the time the consolidation is going to be # done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a separate folder to # be used later gt_nifti_folder = join(self.output_folder_base, "gt_niftis") maybe_mkdir_p(gt_nifti_folder) for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"): success = False attempts = 0 e = None while not success and attempts < 10: try: shutil.copy(f, gt_nifti_folder) success = True except OSError as e: attempts += 1 sleep(1) if not success: print("Could not copy gt nifti file %s into folder %s" % (f, gt_nifti_folder)) if e is not None: raise e self.network.train(current_mode) def run_online_evaluation(self, output, target): with torch.no_grad(): num_classes = output.shape[1] output_softmax = softmax_helper(output) output_seg = output_softmax.argmax(1) target = target[:, 0] axes = tuple(range(1, len(target.shape))) tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index) fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index) fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index) for c in range(1, num_classes): tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes) fp_hard[:, c - 1] = 
sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes) fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes) tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy() fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy() fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy() self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8))) self.online_eval_tp.append(list(tp_hard)) self.online_eval_fp.append(list(fp_hard)) self.online_eval_fn.append(list(fn_hard)) def finish_online_evaluation(self): self.online_eval_tp = np.sum(self.online_eval_tp, 0) self.online_eval_fp = np.sum(self.online_eval_fp, 0) self.online_eval_fn = np.sum(self.online_eval_fn, 0) global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in zip(self.online_eval_tp, self.online_eval_fp, self.online_eval_fn)] if not np.isnan(i)] self.all_val_eval_metrics.append(np.mean(global_dc_per_class)) self.print_to_log_file("Average global foreground Dice:", str(global_dc_per_class)) self.print_to_log_file("(interpret this as an estimate for the Dice of the different classes. 
This is not " "exact.)") self.online_eval_foreground_dc = [] self.online_eval_tp = [] self.online_eval_fp = [] self.online_eval_fn = [] def save_checkpoint(self, fname, save_optimizer=True): super(Trainer_acdc, self).save_checkpoint(fname, save_optimizer) info = OrderedDict() info['init'] = self.init_args info['name'] = self.__class__.__name__ info['class'] = str(self.__class__) info['plans'] = self.plans write_pickle(info, fname + ".pkl") ================================================ FILE: unetr_pp/training/network_training/Trainer_lung.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import shutil
from collections import OrderedDict
from multiprocessing import Pool
from time import sleep
from typing import Tuple, List

import matplotlib
import unetr_pp
import numpy as np
import torch
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.configuration import default_num_threads
from unetr_pp.evaluation.evaluator import aggregate_scores
from unetr_pp.inference.segmentation_export import save_segmentation_nifti_from_softmax
from unetr_pp.network_architecture.generic_UNet import Generic_UNet
from unetr_pp.network_architecture.initialization import InitWeights_He
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.postprocessing.connected_components import determine_postprocessing
from unetr_pp.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \
    default_2D_augmentation_params, get_default_augmentation, get_patch_size
from unetr_pp.training.dataloading.dataset_loading import load_dataset, DataLoader3D, DataLoader2D, unpack_dataset
from unetr_pp.training.loss_functions.dice_loss import DC_and_CE_loss
from unetr_pp.training.network_training.network_trainer_lung import NetworkTrainer_lung
from unetr_pp.utilities.nd_softmax import softmax_helper
from unetr_pp.utilities.tensor_utilities import sum_tensor
from torch import nn
from torch.optim import lr_scheduler

# non-interactive backend: this trainer runs headless and only writes plot files
matplotlib.use("agg")


class Trainer_lung(NetworkTrainer_lung):
    # nnU-Net-style trainer for the lung task: wires plans-file configuration, data
    # loading/augmentation, the network, online evaluation and validation export together.

    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, fp16=False):
        """
        :param deterministic:
        :param fold: can be either [0 ... 5) for cross-validation, 'all' to train on all available training data or
        None if you wish to load some checkpoint and do inference only
        :param plans_file: the pkl file generated by preprocessing. This file will determine all design choices
        :param subfolder_with_preprocessed_data: must be a subfolder of dataset_directory (just the name of the folder,
        not the entire path). This is where the preprocessed data lies that will be used for network training. We made
        this explicitly available so that differently preprocessed data can coexist and the user can choose what to use.
        Can be None if you are doing inference only.
        :param output_folder: where to store parameters, plot progress and to the validation
        :param dataset_directory: the parent directory in which the preprocessed Task data is stored. This is required
        because the split information is stored in this directory. For running prediction only this input is not
        required and may be set to None
        :param batch_dice: compute dice loss for each sample and average over all samples in the batch or pretend the
        batch is a pseudo volume?
        :param stage: The plans file may contain several stages (used for lowres / highres / pyramid). Stage must be
        specified for training: if stage 1 exists then stage 1 is the high resolution stage, otherwise it's 0
        :param unpack_data: if False, npz preprocessed data will not be unpacked to npy. This consumes less space but
        is considerably slower! Running unpack_data=False with 2d should never be done!

        IMPORTANT: If you inherit from nnFormerTrainer and the init args change then you need to redefine self.init_args
        in your init accordingly. Otherwise checkpoints won't load properly!
        """
        super(Trainer_lung, self).__init__(deterministic, fp16)
        self.unpack_data = unpack_data
        # init_args is pickled next to every checkpoint so the trainer can be re-instantiated
        self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                          deterministic, fp16)
        # set through arguments from init
        self.stage = stage
        self.experiment_name = self.__class__.__name__
        self.plans_file = plans_file
        self.output_folder = output_folder
        self.dataset_directory = dataset_directory
        self.output_folder_base = self.output_folder
        self.fold = fold

        self.plans = None

        # if we are running inference only then the self.dataset_directory is set (due to checkpoint loading) but it
        # irrelevant
        if self.dataset_directory is not None and isdir(self.dataset_directory):
            self.gt_niftis_folder = join(self.dataset_directory, "gt_segmentations")
        else:
            self.gt_niftis_folder = None

        self.folder_with_preprocessed_data = None

        # set in self.initialize()
        self.dl_tr = self.dl_val = None
        self.num_input_channels = self.num_classes = self.net_pool_per_axis = self.patch_size = self.batch_size = \
            self.threeD = self.base_num_features = self.intensity_properties = self.normalization_schemes = \
            self.net_num_pool_op_kernel_sizes = self.net_conv_kernel_sizes = None  # loaded automatically from plans_file
        self.basic_generator_patch_size = self.data_aug_params = self.transpose_forward = self.transpose_backward = None

        self.batch_dice = batch_dice
        self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {})

        # accumulators filled by run_online_evaluation, consumed by finish_online_evaluation
        self.online_eval_foreground_dc = []
        self.online_eval_tp = []
        self.online_eval_fp = []
        self.online_eval_fn = []

        self.classes = self.do_dummy_2D_aug = self.use_mask_for_norm = self.only_keep_largest_connected_component = \
            self.min_region_size_per_class = self.min_size_per_class = None

        self.inference_pad_border_mode = "constant"
        self.inference_pad_kwargs = {'constant_values': 0}

        # appends the fold subdirectory to self.output_folder
        self.update_fold(fold)
        self.pad_all_sides = None

        self.lr_scheduler_eps = 1e-3
        self.lr_scheduler_patience = 30
        self.initial_lr = 3e-4
        self.weight_decay = 3e-5
        self.oversample_foreground_percent = 0.33

        self.conv_per_stage = None
        self.regions_class_order = None

    def update_fold(self, fold):
        """
        used to swap between folds for inference (ensemble of models from cross-validation) DO NOT USE DURING
        TRAINING AS THIS WILL NOT UPDATE THE DATASET SPLIT AND THE DATA AUGMENTATION GENERATORS
        :param fold:
        :return:
        """
        if fold is not None:
            if isinstance(fold, str):
                assert fold == "all", "if self.fold is a string then it must be \'all\'"
                # strip a previously appended fold suffix before appending the new one
                if self.output_folder.endswith("%s" % str(self.fold)):
                    self.output_folder = self.output_folder_base
                self.output_folder = join(self.output_folder, "%s" % str(fold))
            else:
                if self.output_folder.endswith("fold_%s" % str(self.fold)):
                    self.output_folder = self.output_folder_base
                self.output_folder = join(self.output_folder, "fold_%s" % str(fold))
            self.fold = fold

    def setup_DA_params(self):
        # Selects the 3D or 2D default augmentation parameter set and derives the (larger)
        # patch size the generators must load so that rotation/scaling does not introduce
        # border artifacts inside the final training patch.
        if self.threeD:
            self.data_aug_params = default_3D_augmentation_params
            if self.do_dummy_2D_aug:
                # dummy 2D: augment slice-wise, so borrow the 2D elastic/rotation settings
                self.data_aug_params["dummy_2D"] = True
                self.print_to_log_file("Using dummy2d data augmentation")
                self.data_aug_params["elastic_deform_alpha"] = \
                    default_2D_augmentation_params["elastic_deform_alpha"]
                self.data_aug_params["elastic_deform_sigma"] = \
                    default_2D_augmentation_params["elastic_deform_sigma"]
                self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
        else:
            self.do_dummy_2D_aug = False
            # strongly anisotropic 2D patches get a reduced in-plane rotation range (+-15 deg)
            if max(self.patch_size) / min(self.patch_size) > 1.5:
                default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
            self.data_aug_params = default_2D_augmentation_params
        self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm

        if self.do_dummy_2D_aug:
            # spatial transforms act on the in-plane axes only; first axis stays unpadded
            self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
                                                             self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
            patch_size_for_spatialtransform = self.patch_size[1:]
        else:
            self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            patch_size_for_spatialtransform = self.patch_size

        self.data_aug_params['selected_seg_channels'] = [0]
        self.data_aug_params['patch_size_for_spatialtransform'] = patch_size_for_spatialtransform

    def initialize(self, training=True, force_load_plans=False):
        """
        For prediction of test cases just set training=False, this will prevent loading of training data and
        training batchgenerator initialization
        :param training:
        :return:
        """
        maybe_mkdir_p(self.output_folder)

        if force_load_plans or (self.plans is None):
            self.load_plans_file()

        self.process_plans(self.plans)

        self.setup_DA_params()

        if training:
            self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
                                                      "_stage%d" % self.stage)

            self.dl_tr, self.dl_val = self.get_basic_generators()
            if self.unpack_data:
                # decompress npz -> npy once so the dataloaders can memmap instead of unzipping per batch
                self.print_to_log_file("unpacking dataset")
                unpack_dataset(self.folder_with_preprocessed_data)
                self.print_to_log_file("done")
            else:
                self.print_to_log_file(
                    "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                    "will wait all winter for your model to finish!")

            self.tr_gen, self.val_gen = get_default_augmentation(self.dl_tr, self.dl_val,
                                                                 self.data_aug_params[
                                                                     'patch_size_for_spatialtransform'],
                                                                 self.data_aug_params)
            self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                   also_print_to_console=False)
            self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                   also_print_to_console=False)
        else:
            pass
        self.initialize_network()
        self.initialize_optimizer_and_scheduler()
        # assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
        self.was_initialized = True

    def initialize_network(self):
        """
        This is specific to the U-Net and must be adapted for other network architectures
        :return:
        """
        # self.print_to_log_file(self.net_num_pool_op_kernel_sizes)
        # self.print_to_log_file(self.net_conv_kernel_sizes)

        net_numpool = len(self.net_num_pool_op_kernel_sizes)

        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.InstanceNorm3d
        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.InstanceNorm2d

        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes, net_numpool,
                                    self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
                                    dropout_op_kwargs,
                                    net_nonlin, net_nonlin_kwargs, False, False, lambda x: x, InitWeights_He(1e-2),
                                    self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
        # at inference time softmax is applied on top of the raw logits
        self.network.inference_apply_nonlin = softmax_helper

        if torch.cuda.is_available():
            self.network.cuda()

    def initialize_optimizer_and_scheduler(self):
        # Adam + ReduceLROnPlateau; must run after initialize_network so parameters exist.
        assert self.network is not None, "self.initialize_network must be called first"
        self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
                                          amsgrad=True)
        self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2,
                                                           patience=self.lr_scheduler_patience,
                                                           verbose=True, threshold=self.lr_scheduler_eps,
                                                           threshold_mode="abs")

    def plot_network_architecture(self):
        """Best-effort: render the network graph to network_architecture.pdf via hiddenlayer;
        on any failure just log the printed network instead."""
        try:
            from batchgenerators.utilities.file_and_folder_operations import join
            import hiddenlayer as hl
            if torch.cuda.is_available():
                g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)).cuda(),
                                   transforms=None)
            else:
                g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)),
                                   transforms=None)
            g.save(join(self.output_folder, "network_architecture.pdf"))
            del g
        except Exception as e:
            self.print_to_log_file("Unable to plot network architecture:")
            self.print_to_log_file(e)

            self.print_to_log_file("\nprinting the network instead:\n")
            self.print_to_log_file(self.network)
            self.print_to_log_file("\n")
        finally:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def save_debug_information(self):
        # saving some debug information: dump all non-callable, non-dunder attributes as strings
        dct = OrderedDict()
        for k in self.__dir__():
            if not k.startswith("__"):
                if not callable(getattr(self, k)):
                    dct[k] = str(getattr(self, k))
        # these entries are huge and already stored elsewhere
        del dct['plans']
        del dct['intensity_properties']
        del dct['dataset']
        del dct['dataset_tr']
        del dct['dataset_val']
        save_json(dct, join(self.output_folder, "debug.json"))

        import shutil  # NOTE: shadows the module-level import; kept as-is

        shutil.copy(self.plans_file, join(self.output_folder_base, "plans.pkl"))

    def run_training(self):
        # snapshot the trainer state for debugging before handing off to the base-class loop
        self.save_debug_information()
        super(Trainer_lung, self).run_training()

    def load_plans_file(self):
        """
        This is what actually configures the entire experiment. The plans file is generated by experiment planning
        :return:
        """
        self.plans = load_pickle(self.plans_file)

    def process_plans(self, plans):
        """Unpack the plans dict into trainer attributes (patch/batch size, pooling and conv
        kernels, normalization, class info), with fallbacks for older plans formats."""
        if self.stage is None:
            assert len(list(plans['plans_per_stage'].keys())) == 1, \
                "If self.stage is None then there can be only one stage in the plans file. That seems to not be the " \
                "case. Please specify which stage of the cascade must be trained"
            self.stage = list(plans['plans_per_stage'].keys())[0]
        self.plans = plans

        stage_plans = self.plans['plans_per_stage'][self.stage]
        self.batch_size = stage_plans['batch_size']
        self.net_pool_per_axis = stage_plans['num_pool_per_axis']
        self.patch_size = np.array(stage_plans['patch_size']).astype(int)
        self.do_dummy_2D_aug = stage_plans['do_dummy_2D_data_aug']

        if 'pool_op_kernel_sizes' not in stage_plans.keys():
            assert 'num_pool_per_axis' in stage_plans.keys()
            self.print_to_log_file("WARNING! old plans file with missing pool_op_kernel_sizes. Attempting to fix it...")
            # reconstruct per-stage pooling kernels from pools-per-axis: axes with fewer pools
            # keep kernel 1 in the earliest stages
            self.net_num_pool_op_kernel_sizes = []
            for i in range(max(self.net_pool_per_axis)):
                curr = []
                for j in self.net_pool_per_axis:
                    if (max(self.net_pool_per_axis) - j) <= i:
                        curr.append(2)
                    else:
                        curr.append(1)
                self.net_num_pool_op_kernel_sizes.append(curr)
        else:
            self.net_num_pool_op_kernel_sizes = stage_plans['pool_op_kernel_sizes']

        if 'conv_kernel_sizes' not in stage_plans.keys():
            self.print_to_log_file("WARNING! old plans file with missing conv_kernel_sizes. Attempting to fix it...")
            self.net_conv_kernel_sizes = [[3] * len(self.net_pool_per_axis)] * (max(self.net_pool_per_axis) + 1)
        else:
            self.net_conv_kernel_sizes = stage_plans['conv_kernel_sizes']

        self.pad_all_sides = None  # self.patch_size
        self.intensity_properties = plans['dataset_properties']['intensityproperties']
        self.normalization_schemes = plans['normalization_schemes']

        self.base_num_features = plans['base_num_features']
        self.num_input_channels = plans['num_modalities']
        self.num_classes = plans['num_classes'] + 1  # background is no longer in num_classes
        self.classes = plans['all_classes']
        self.use_mask_for_norm = plans['use_mask_for_norm']
        self.only_keep_largest_connected_component = plans['keep_only_largest_region']
        self.min_region_size_per_class = plans['min_region_size_per_class']
        self.min_size_per_class = None  # DONT USE THIS. plans['min_size_per_class']

        if plans.get('transpose_forward') is None or plans.get('transpose_backward') is None:
            print("WARNING! You seem to have data that was preprocessed with a previous version of nnU-Net. "
                  "You should rerun preprocessing. We will proceed and assume that both transpose_foward "
                  "and transpose_backward are [0, 1, 2]. If that is not correct then weird things will happen!")
            plans['transpose_forward'] = [0, 1, 2]
            plans['transpose_backward'] = [0, 1, 2]
        self.transpose_forward = plans['transpose_forward']
        self.transpose_backward = plans['transpose_backward']

        # dimensionality is inferred from the patch size rank
        if len(self.patch_size) == 2:
            self.threeD = False
        elif len(self.patch_size) == 3:
            self.threeD = True
        else:
            raise RuntimeError("invalid patch size in plans file: %s" % str(self.patch_size))

        if "conv_per_stage" in plans.keys():  # this has been added to the plans only recently
            self.conv_per_stage = plans['conv_per_stage']
        else:
            self.conv_per_stage = 2

    def load_dataset(self):
        # builds the case-id -> file-paths dict from the preprocessed data folder
        self.dataset = load_dataset(self.folder_with_preprocessed_data)

    def get_basic_generators(self):
        """Create the (unaugmented) train/val dataloaders for the current split.
        Training uses the enlarged basic_generator_patch_size; validation uses patch_size."""
        self.load_dataset()
        self.do_split()

        if self.threeD:
            dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
                                 False, oversample_foreground_percent=self.oversample_foreground_percent,
                                 pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
            dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, False,
                                  oversample_foreground_percent=self.oversample_foreground_percent,
                                  pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
        else:
            dl_tr = DataLoader2D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
                                 oversample_foreground_percent=self.oversample_foreground_percent,
                                 pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
            dl_val = DataLoader2D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size,
                                  oversample_foreground_percent=self.oversample_foreground_percent,
                                  pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
        return dl_tr, dl_val

    def preprocess_patient(self, input_files):
        """
        Used to predict new unseen data. Not used for the preprocessing of the training/test data
        :param input_files:
        :return:
        """
        from unetr_pp.training.model_restore import recursive_find_python_class
        preprocessor_name = self.plans.get('preprocessor_name')
        if preprocessor_name is None:
            # old plans files do not name a preprocessor; pick by dimensionality
            if self.threeD:
                preprocessor_name = "GenericPreprocessor"
            else:
                preprocessor_name = "PreprocessorFor2D"

        print("using preprocessor", preprocessor_name)
        preprocessor_class = recursive_find_python_class([join(unetr_pp.__path__[0], "preprocessing")],
                                                         preprocessor_name,
                                                         current_module="unetr_pp.preprocessing")
        assert preprocessor_class is not None, "Could not find preprocessor %s in unetr_pp.preprocessing" % \
                                               preprocessor_name
        preprocessor = preprocessor_class(self.normalization_schemes, self.use_mask_for_norm,
                                          self.transpose_forward, self.intensity_properties)

        d, s, properties = preprocessor.preprocess_test_case(input_files,
                                                             self.plans['plans_per_stage'][self.stage][
                                                                 'current_spacing'])
        return d, s, properties

    def preprocess_predict_nifti(self, input_files: List[str], output_file: str = None,
                                 softmax_ouput_file: str = None, mixed_precision: bool = True) -> None:
        """
        Use this to predict new data
        :param input_files:
        :param output_file:
        :param softmax_ouput_file:
        :param mixed_precision:
        :return:
        """
        print("preprocessing...")
        d, s, properties = self.preprocess_patient(input_files)
        print("predicting...")
        # [1] selects the softmax output of the (seg, softmax) tuple
        pred = self.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=self.data_aug_params["do_mirror"],
                                                                     mirror_axes=self.data_aug_params['mirror_axes'],
                                                                     use_sliding_window=True, step_size=0.5,
                                                                     use_gaussian=True, pad_border_mode='constant',
                                                                     pad_kwargs={'constant_values': 0},
                                                                     verbose=True, all_in_gpu=False,
                                                                     mixed_precision=mixed_precision)[1]
        # undo the axis transposition applied during preprocessing (channel axis stays first)
        pred = pred.transpose([0] + [i + 1 for i in self.transpose_backward])

        if 'segmentation_export_params' in self.plans.keys():
            force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
            interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
            interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
        else:
            force_separate_z = None
            interpolation_order = 1
            interpolation_order_z = 0

        print("resampling to original spacing and nifti export...")
        save_segmentation_nifti_from_softmax(pred, output_file, properties, interpolation_order,
                                             self.regions_class_order, None, None, softmax_ouput_file,
                                             None, force_separate_z=force_separate_z,
                                             interpolation_order_z=interpolation_order_z)
        print("done")

    def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
                                                         mirror_axes: Tuple[int] = None,
                                                         use_sliding_window: bool = True, step_size: float = 0.5,
                                                         use_gaussian: bool = True, pad_border_mode: str = 'constant',
                                                         pad_kwargs: dict = None, all_in_gpu: bool = False,
                                                         verbose: bool = True,
                                                         mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """
        :param data:
        :param do_mirroring:
        :param mirror_axes:
        :param use_sliding_window:
        :param step_size:
        :param use_gaussian:
        :param pad_border_mode:
        :param pad_kwargs:
        :param all_in_gpu:
        :param verbose:
        :return:
        """
        if pad_border_mode == 'constant' and pad_kwargs is None:
            pad_kwargs = {'constant_values': 0}

        if do_mirroring and mirror_axes is None:
            mirror_axes = self.data_aug_params['mirror_axes']

        if do_mirroring:
            assert self.data_aug_params["do_mirror"], "Cannot do mirroring as test time augmentation when training " \
                                                      "was done without mirroring"

        valid = list((SegmentationNetwork, nn.DataParallel))
        assert isinstance(self.network, tuple(valid))

        # remember training/eval mode and restore it after prediction
        current_mode = self.network.training
        self.network.eval()
        ret = self.network.predict_3D(data, do_mirroring=do_mirroring, mirror_axes=mirror_axes,
                                      use_sliding_window=use_sliding_window, step_size=step_size,
                                      patch_size=self.patch_size, regions_class_order=self.regions_class_order,
                                      use_gaussian=use_gaussian, pad_border_mode=pad_border_mode,
                                      pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, verbose=verbose,
                                      mixed_precision=mixed_precision)
        self.network.train(current_mode)
        return ret

    def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
                 save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
                 segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
        """
        if debug=True then the temporary files generated for postprocessing determination will be kept
        """
        current_mode = self.network.training
        self.network.eval()

        assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"
        if self.dataset_val is None:
            self.load_dataset()
            self.do_split()

        if segmentation_export_kwargs is None:
            if 'segmentation_export_params' in self.plans.keys():
                force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
                interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
                interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
            else:
                force_separate_z = None
                interpolation_order = 1
                interpolation_order_z = 0
        else:
            force_separate_z = segmentation_export_kwargs['force_separate_z']
            interpolation_order = segmentation_export_kwargs['interpolation_order']
            interpolation_order_z = segmentation_export_kwargs['interpolation_order_z']

        # predictions as they come from the network go here
        output_folder = join(self.output_folder, validation_folder_name)
        maybe_mkdir_p(output_folder)
        # this is for debug purposes
        my_input_args = {'do_mirroring': do_mirroring,
                         'use_sliding_window': use_sliding_window,
                         'step_size': step_size,
                         'save_softmax': save_softmax,
                         'use_gaussian': use_gaussian,
                         'overwrite': overwrite,
                         'validation_folder_name': validation_folder_name,
                         'debug': debug,
                         'all_in_gpu': all_in_gpu,
                         'segmentation_export_kwargs': segmentation_export_kwargs,
                         }
        save_json(my_input_args, join(output_folder, "validation_args.json"))

        if do_mirroring:
            if not self.data_aug_params['do_mirror']:
                raise RuntimeError("We did not train with mirroring so you cannot do inference with mirroring enabled")
            mirror_axes = self.data_aug_params['mirror_axes']
        else:
            mirror_axes = ()

        pred_gt_tuples = []

        # nifti export is parallelized; results are async handles collected below
        export_pool = Pool(default_num_threads)
        results = []

        for k in self.dataset_val.keys():
            properties = load_pickle(self.dataset[k]['properties_file'])
            # NOTE(review): [:-12] strips a fixed-length suffix from the data file name
            # (presumably "_0000.nii.gz") — confirm against the dataset naming convention
            fname = properties['list_of_data_files'][0].split("/")[-1][:-12]
            if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \
                    (save_softmax and not isfile(join(output_folder, fname + ".npz"))):
                data = np.load(self.dataset[k]['data_file'])['data']

                print(k, data.shape)
                # last channel is the segmentation; map the -1 "outside" label to background
                data[-1][data[-1] == -1] = 0

                softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data[:-1],
                                                                                     do_mirroring=do_mirroring,
                                                                                     mirror_axes=mirror_axes,
                                                                                     use_sliding_window=use_sliding_window,
                                                                                     step_size=step_size,
                                                                                     use_gaussian=use_gaussian,
                                                                                     all_in_gpu=all_in_gpu,
                                                                                     mixed_precision=self.fp16)[1]

                softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in self.transpose_backward])

                if save_softmax:
                    softmax_fname = join(output_folder, fname + ".npz")
                else:
                    softmax_fname = None

                """There is a problem with python process communication that prevents us from communicating obejcts
                larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
                communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long
                enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
                patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
                then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
                filename or np.ndarray and will handle this automatically"""
                if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85):  # *0.85 just to be safe
                    np.save(join(output_folder, fname + ".npy"), softmax_pred)
                    softmax_pred = join(output_folder, fname + ".npy")

                results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax,
                                                         ((softmax_pred, join(output_folder, fname + ".nii.gz"),
                                                           properties, interpolation_order, self.regions_class_order,
                                                           None, None,
                                                           softmax_fname, None, force_separate_z,
                                                           interpolation_order_z),
                                                          )
                                                         )
                               )

            pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"),
                                   join(self.gt_niftis_folder, fname + ".nii.gz")])

        # block until all async exports are finished
        _ = [i.get() for i in results]
        self.print_to_log_file("finished prediction")

        # evaluate raw predictions
        self.print_to_log_file("evaluation of raw predictions")
        task = self.dataset_directory.split("/")[-1]
        job_name = self.experiment_name
        _ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)),
                             json_output_file=join(output_folder, "summary.json"),
                             json_name=job_name + " val tiled %s" % (str(use_sliding_window)),
                             json_author="Fabian",
                             json_task=task, num_threads=default_num_threads)

        if run_postprocessing_on_folds:
            # in the old unetr_pp we would stop here. Now we add a postprocessing. This postprocessing can remove everything
            # except the largest connected component for each class. To see if this improves results, we do this for all
            # classes and then rerun the evaluation. Those classes for which this resulted in an improved dice score will
            # have this applied during inference as well
            self.print_to_log_file("determining postprocessing")
            determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name,
                                     final_subf_name=validation_folder_name + "_postprocessed", debug=debug)
            # after this the final predictions for the validation set can be found in validation_folder_name_base + "_postprocessed"
            # They are always in that folder, even if no postprocessing was applied!

        # determining postprocessing on a per-fold basis may be OK for this fold but what if another fold finds another
        # postprocessing to be better? In this case we need to consolidate. At the time the consolidation is going to be
        # done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a separate folder to
        # be used later
        gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
        maybe_mkdir_p(gt_nifti_folder)
        for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
            success = False
            attempts = 0
            e = None
            # retry the copy a few times; network filesystems can fail transiently
            while not success and attempts < 10:
                try:
                    shutil.copy(f, gt_nifti_folder)
                    success = True
                except OSError as e:
                    attempts += 1
                    sleep(1)
            if not success:
                print("Could not copy gt nifti file %s into folder %s" % (f, gt_nifti_folder))
                if e is not None:
                    raise e

        self.network.train(current_mode)

    def run_online_evaluation(self, output, target):
        # Accumulates hard tp/fp/fn per foreground class for the current batch
        # (a cheap online Dice estimate; aggregated in finish_online_evaluation).
        with torch.no_grad():
            num_classes = output.shape[1]
            output_softmax = softmax_helper(output)
            output_seg = output_softmax.argmax(1)
            target = target[:, 0]
            axes = tuple(range(1, len(target.shape)))
            tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
            fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
            fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
            for c in range(1, num_classes):
                tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes)
                fp_hard[:, c - 1] =
sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes)
                fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes)

            # collapse per-sample counts to per-class totals for this batch, move to CPU
            tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy()
            fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy()
            fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy()

            # 1e-8 avoids division by zero when a class is absent from both pred and target
            self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
            self.online_eval_tp.append(list(tp_hard))
            self.online_eval_fp.append(list(fp_hard))
            self.online_eval_fn.append(list(fn_hard))

    def finish_online_evaluation(self):
        """Aggregate the tp/fp/fn accumulated by run_online_evaluation over the epoch into a
        global Dice per foreground class, log it, and reset the accumulators."""
        self.online_eval_tp = np.sum(self.online_eval_tp, 0)
        self.online_eval_fp = np.sum(self.online_eval_fp, 0)
        self.online_eval_fn = np.sum(self.online_eval_fn, 0)

        # classes that yield NaN (no tp/fp/fn at all) are excluded from the mean
        global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in
                                           zip(self.online_eval_tp, self.online_eval_fp, self.online_eval_fn)]
                               if not np.isnan(i)]
        self.all_val_eval_metrics.append(np.mean(global_dc_per_class))

        self.print_to_log_file("Average global foreground Dice:", str(global_dc_per_class))
        self.print_to_log_file("(interpret this as an estimate for the Dice of the different classes. This is not "
                               "exact.)")

        self.online_eval_foreground_dc = []
        self.online_eval_tp = []
        self.online_eval_fp = []
        self.online_eval_fn = []

    def save_checkpoint(self, fname, save_optimizer=True):
        """Save network weights via the parent class, plus a sibling <fname>.pkl holding the
        init args, trainer class and plans needed to restore this trainer for inference."""
        super(Trainer_lung, self).save_checkpoint(fname, save_optimizer)
        info = OrderedDict()
        info['init'] = self.init_args
        info['name'] = self.__class__.__name__
        info['class'] = str(self.__class__)
        info['plans'] = self.plans

        write_pickle(info, fname + ".pkl")


================================================
FILE: unetr_pp/training/network_training/Trainer_synapse.py
================================================
#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
import shutil import time from collections import OrderedDict from multiprocessing import Pool from time import sleep from typing import Tuple, List import matplotlib import unetr_pp import numpy as np import torch from batchgenerators.utilities.file_and_folder_operations import * from unetr_pp.configuration import default_num_threads from unetr_pp.evaluation.evaluator import aggregate_scores from unetr_pp.inference.segmentation_export import save_segmentation_nifti_from_softmax from unetr_pp.network_architecture.generic_UNet import Generic_UNet from unetr_pp.network_architecture.initialization import InitWeights_He from unetr_pp.network_architecture.neural_network import SegmentationNetwork from unetr_pp.postprocessing.connected_components import determine_postprocessing from unetr_pp.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \ default_2D_augmentation_params, get_default_augmentation, get_patch_size from unetr_pp.training.dataloading.dataset_loading import load_dataset, DataLoader3D, DataLoader2D, unpack_dataset from unetr_pp.training.loss_functions.dice_loss import DC_and_CE_loss from unetr_pp.training.network_training.network_trainer_synapse import NetworkTrainer_synapse from unetr_pp.utilities.nd_softmax import softmax_helper from unetr_pp.utilities.tensor_utilities import sum_tensor from torch import nn from torch.optim import lr_scheduler matplotlib.use("agg") class Trainer_synapse(NetworkTrainer_synapse): def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None, unpack_data=True, deterministic=True, fp16=False): """ :param deterministic: :param fold: can be either [0 ... 5) for cross-validation, 'all' to train on all available training data or None if you wish to load some checkpoint and do inference only :param plans_file: the pkl file generated by preprocessing. 
This file will determine all design choices :param subfolder_with_preprocessed_data: must be a subfolder of dataset_directory (just the name of the folder, not the entire path). This is where the preprocessed data lies that will be used for network training. We made this explicitly available so that differently preprocessed data can coexist and the user can choose what to use. Can be None if you are doing inference only. :param output_folder: where to store parameters, plot progress and to the validation :param dataset_directory: the parent directory in which the preprocessed Task data is stored. This is required because the split information is stored in this directory. For running prediction only this input is not required and may be set to None :param batch_dice: compute dice loss for each sample and average over all samples in the batch or pretend the batch is a pseudo volume? :param stage: The plans file may contain several stages (used for lowres / highres / pyramid). Stage must be specified for training: if stage 1 exists then stage 1 is the high resolution stage, otherwise it's 0 :param unpack_data: if False, npz preprocessed data will not be unpacked to npy. This consumes less space but is considerably slower! Running unpack_data=False with 2d should never be done! IMPORTANT: If you inherit from nnFormerTrainer and the init args change then you need to redefine self.init_args in your init accordingly. Otherwise checkpoints won't load properly! 
""" super(Trainer_synapse, self).__init__(deterministic, fp16) self.unpack_data = unpack_data self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data, deterministic, fp16) # set through arguments from init self.stage = stage self.experiment_name = self.__class__.__name__ self.plans_file = plans_file self.output_folder = output_folder self.dataset_directory = dataset_directory self.output_folder_base = self.output_folder self.fold = fold self.plans = None # if we are running inference only then the self.dataset_directory is set (due to checkpoint loading) but it # irrelevant if self.dataset_directory is not None and isdir(self.dataset_directory): self.gt_niftis_folder = join(self.dataset_directory, "gt_segmentations") else: self.gt_niftis_folder = None self.folder_with_preprocessed_data = None # set in self.initialize() self.dl_tr = self.dl_val = None self.num_input_channels = self.num_classes = self.net_pool_per_axis = self.patch_size = self.batch_size = \ self.threeD = self.base_num_features = self.intensity_properties = self.normalization_schemes = \ self.net_num_pool_op_kernel_sizes = self.net_conv_kernel_sizes = None # loaded automatically from plans_file self.basic_generator_patch_size = self.data_aug_params = self.transpose_forward = self.transpose_backward = None self.batch_dice = batch_dice self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {}) self.online_eval_foreground_dc = [] self.online_eval_tp = [] self.online_eval_fp = [] self.online_eval_fn = [] self.classes = self.do_dummy_2D_aug = self.use_mask_for_norm = self.only_keep_largest_connected_component = \ self.min_region_size_per_class = self.min_size_per_class = None self.inference_pad_border_mode = "constant" self.inference_pad_kwargs = {'constant_values': 0} self.update_fold(fold) self.pad_all_sides = None self.lr_scheduler_eps = 1e-3 self.lr_scheduler_patience = 30 self.initial_lr = 3e-4 self.weight_decay = 3e-5 
self.oversample_foreground_percent = 0.33 self.conv_per_stage = None self.regions_class_order = None def update_fold(self, fold): """ used to swap between folds for inference (ensemble of models from cross-validation) DO NOT USE DURING TRAINING AS THIS WILL NOT UPDATE THE DATASET SPLIT AND THE DATA AUGMENTATION GENERATORS :param fold: :return: """ if fold is not None: if isinstance(fold, str): assert fold == "all", "if self.fold is a string then it must be \'all\'" if self.output_folder.endswith("%s" % str(self.fold)): self.output_folder = self.output_folder_base self.output_folder = join(self.output_folder, "%s" % str(fold)) else: if self.output_folder.endswith("fold_%s" % str(self.fold)): self.output_folder = self.output_folder_base self.output_folder = join(self.output_folder, "fold_%s" % str(fold)) self.fold = fold def setup_DA_params(self): if self.threeD: self.data_aug_params = default_3D_augmentation_params if self.do_dummy_2D_aug: self.data_aug_params["dummy_2D"] = True self.print_to_log_file("Using dummy2d data augmentation") self.data_aug_params["elastic_deform_alpha"] = \ default_2D_augmentation_params["elastic_deform_alpha"] self.data_aug_params["elastic_deform_sigma"] = \ default_2D_augmentation_params["elastic_deform_sigma"] self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"] else: self.do_dummy_2D_aug = False if max(self.patch_size) / min(self.patch_size) > 1.5: default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. 
* np.pi) self.data_aug_params = default_2D_augmentation_params self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm if self.do_dummy_2D_aug: self.basic_generator_patch_size = get_patch_size(self.patch_size[1:], self.data_aug_params['rotation_x'], self.data_aug_params['rotation_y'], self.data_aug_params['rotation_z'], self.data_aug_params['scale_range']) self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size)) patch_size_for_spatialtransform = self.patch_size[1:] else: self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'], self.data_aug_params['rotation_y'], self.data_aug_params['rotation_z'], self.data_aug_params['scale_range']) patch_size_for_spatialtransform = self.patch_size self.data_aug_params['selected_seg_channels'] = [0] self.data_aug_params['patch_size_for_spatialtransform'] = patch_size_for_spatialtransform def initialize(self, training=True, force_load_plans=False): """ For prediction of test cases just set training=False, this will prevent loading of training data and training batchgenerator initialization :param training: :return: """ maybe_mkdir_p(self.output_folder) if force_load_plans or (self.plans is None): self.load_plans_file() self.process_plans(self.plans) self.setup_DA_params() if training: self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] + "_stage%d" % self.stage) self.dl_tr, self.dl_val = self.get_basic_generators() if self.unpack_data: self.print_to_log_file("unpacking dataset") unpack_dataset(self.folder_with_preprocessed_data) self.print_to_log_file("done") else: self.print_to_log_file( "INFO: Not unpacking data! Training may be slow due to that. 
Pray you are not using 2d or you " "will wait all winter for your model to finish!") self.tr_gen, self.val_gen = get_default_augmentation(self.dl_tr, self.dl_val, self.data_aug_params[ 'patch_size_for_spatialtransform'], self.data_aug_params) self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())), also_print_to_console=False) self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())), also_print_to_console=False) else: pass self.initialize_network() self.initialize_optimizer_and_scheduler() # assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel)) self.was_initialized = True def initialize_network(self): """ This is specific to the U-Net and must be adapted for other network architectures :return: """ # self.print_to_log_file(self.net_num_pool_op_kernel_sizes) # self.print_to_log_file(self.net_conv_kernel_sizes) net_numpool = len(self.net_num_pool_op_kernel_sizes) if self.threeD: conv_op = nn.Conv3d dropout_op = nn.Dropout3d norm_op = nn.InstanceNorm3d else: conv_op = nn.Conv2d dropout_op = nn.Dropout2d norm_op = nn.InstanceNorm2d norm_op_kwargs = {'eps': 1e-5, 'affine': True} dropout_op_kwargs = {'p': 0, 'inplace': True} net_nonlin = nn.LeakyReLU net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes, net_numpool, self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, net_nonlin, net_nonlin_kwargs, False, False, lambda x: x, InitWeights_He(1e-2), self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True) self.network.inference_apply_nonlin = softmax_helper if torch.cuda.is_available(): self.network.cuda() def initialize_optimizer_and_scheduler(self): assert self.network is not None, "self.initialize_network must be called first" self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, amsgrad=True) 
self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2, patience=self.lr_scheduler_patience, verbose=True, threshold=self.lr_scheduler_eps, threshold_mode="abs") def plot_network_architecture(self): try: from batchgenerators.utilities.file_and_folder_operations import join import hiddenlayer as hl if torch.cuda.is_available(): g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)).cuda(), transforms=None) else: g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)), transforms=None) g.save(join(self.output_folder, "network_architecture.pdf")) del g except Exception as e: self.print_to_log_file("Unable to plot network architecture:") self.print_to_log_file(e) self.print_to_log_file("\nprinting the network instead:\n") self.print_to_log_file(self.network) self.print_to_log_file("\n") finally: if torch.cuda.is_available(): torch.cuda.empty_cache() def save_debug_information(self): # saving some debug information dct = OrderedDict() for k in self.__dir__(): if not k.startswith("__"): if not callable(getattr(self, k)): dct[k] = str(getattr(self, k)) del dct['plans'] del dct['intensity_properties'] del dct['dataset'] del dct['dataset_tr'] del dct['dataset_val'] save_json(dct, join(self.output_folder, "debug.json")) import shutil shutil.copy(self.plans_file, join(self.output_folder_base, "plans.pkl")) def run_training(self): self.save_debug_information() super(Trainer_synapse, self).run_training() def load_plans_file(self): """ This is what actually configures the entire experiment. The plans file is generated by experiment planning :return: """ self.plans = load_pickle(self.plans_file) def process_plans(self, plans): if self.stage is None: assert len(list(plans['plans_per_stage'].keys())) == 1, \ "If self.stage is None then there can be only one stage in the plans file. That seems to not be the " \ "case. 
Please specify which stage of the cascade must be trained" self.stage = list(plans['plans_per_stage'].keys())[0] self.plans = plans stage_plans = self.plans['plans_per_stage'][self.stage] self.batch_size = stage_plans['batch_size'] self.net_pool_per_axis = stage_plans['num_pool_per_axis'] self.patch_size = np.array(stage_plans['patch_size']).astype(int) self.do_dummy_2D_aug = stage_plans['do_dummy_2D_data_aug'] if 'pool_op_kernel_sizes' not in stage_plans.keys(): assert 'num_pool_per_axis' in stage_plans.keys() self.print_to_log_file("WARNING! old plans file with missing pool_op_kernel_sizes. Attempting to fix it...") self.net_num_pool_op_kernel_sizes = [] for i in range(max(self.net_pool_per_axis)): curr = [] for j in self.net_pool_per_axis: if (max(self.net_pool_per_axis) - j) <= i: curr.append(2) else: curr.append(1) self.net_num_pool_op_kernel_sizes.append(curr) else: self.net_num_pool_op_kernel_sizes = stage_plans['pool_op_kernel_sizes'] if 'conv_kernel_sizes' not in stage_plans.keys(): self.print_to_log_file("WARNING! old plans file with missing conv_kernel_sizes. Attempting to fix it...") self.net_conv_kernel_sizes = [[3] * len(self.net_pool_per_axis)] * (max(self.net_pool_per_axis) + 1) else: self.net_conv_kernel_sizes = stage_plans['conv_kernel_sizes'] self.pad_all_sides = None # self.patch_size self.intensity_properties = plans['dataset_properties']['intensityproperties'] self.normalization_schemes = plans['normalization_schemes'] self.base_num_features = plans['base_num_features'] self.num_input_channels = plans['num_modalities'] self.num_classes = plans['num_classes'] + 1 # background is no longer in num_classes self.classes = plans['all_classes'] self.use_mask_for_norm = plans['use_mask_for_norm'] self.only_keep_largest_connected_component = plans['keep_only_largest_region'] self.min_region_size_per_class = plans['min_region_size_per_class'] self.min_size_per_class = None # DONT USE THIS. 
plans['min_size_per_class'] if plans.get('transpose_forward') is None or plans.get('transpose_backward') is None: print("WARNING! You seem to have data that was preprocessed with a previous version of nnU-Net. " "You should rerun preprocessing. We will proceed and assume that both transpose_foward " "and transpose_backward are [0, 1, 2]. If that is not correct then weird things will happen!") plans['transpose_forward'] = [0, 1, 2] plans['transpose_backward'] = [0, 1, 2] self.transpose_forward = plans['transpose_forward'] self.transpose_backward = plans['transpose_backward'] if len(self.patch_size) == 2: self.threeD = False elif len(self.patch_size) == 3: self.threeD = True else: raise RuntimeError("invalid patch size in plans file: %s" % str(self.patch_size)) if "conv_per_stage" in plans.keys(): # this ha sbeen added to the plans only recently self.conv_per_stage = plans['conv_per_stage'] else: self.conv_per_stage = 2 def load_dataset(self): self.dataset = load_dataset(self.folder_with_preprocessed_data) def get_basic_generators(self): self.load_dataset() self.do_split() if self.threeD: dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size, False, oversample_foreground_percent=self.oversample_foreground_percent, pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r') dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, False, oversample_foreground_percent=self.oversample_foreground_percent, pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r') else: dl_tr = DataLoader2D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size, oversample_foreground_percent=self.oversample_foreground_percent, pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r') dl_val = DataLoader2D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, oversample_foreground_percent=self.oversample_foreground_percent, 
pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r') return dl_tr, dl_val def preprocess_patient(self, input_files): """ Used to predict new unseen data. Not used for the preprocessing of the training/test data :param input_files: :return: """ from unetr_pp.training.model_restore import recursive_find_python_class preprocessor_name = self.plans.get('preprocessor_name') if preprocessor_name is None: if self.threeD: preprocessor_name = "GenericPreprocessor" else: preprocessor_name = "PreprocessorFor2D" print("using preprocessor", preprocessor_name) preprocessor_class = recursive_find_python_class([join(unetr_pp.__path__[0], "preprocessing")], preprocessor_name, current_module="unetr_pp.preprocessing") assert preprocessor_class is not None, "Could not find preprocessor %s in unetr_pp.preprocessing" % \ preprocessor_name preprocessor = preprocessor_class(self.normalization_schemes, self.use_mask_for_norm, self.transpose_forward, self.intensity_properties) d, s, properties = preprocessor.preprocess_test_case(input_files, self.plans['plans_per_stage'][self.stage][ 'current_spacing']) return d, s, properties def preprocess_predict_nifti(self, input_files: List[str], output_file: str = None, softmax_ouput_file: str = None, mixed_precision: bool = True) -> None: """ Use this to predict new data :param input_files: :param output_file: :param softmax_ouput_file: :param mixed_precision: :return: """ print("preprocessing...") d, s, properties = self.preprocess_patient(input_files) print("predicting...") pred = self.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=self.data_aug_params["do_mirror"], mirror_axes=self.data_aug_params['mirror_axes'], use_sliding_window=True, step_size=0.5, use_gaussian=True, pad_border_mode='constant', pad_kwargs={'constant_values': 0}, verbose=True, all_in_gpu=False, mixed_precision=mixed_precision)[1] pred = pred.transpose([0] + [i + 1 for i in self.transpose_backward]) if 'segmentation_export_params' in 
self.plans.keys(): force_separate_z = self.plans['segmentation_export_params']['force_separate_z'] interpolation_order = self.plans['segmentation_export_params']['interpolation_order'] interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z'] else: force_separate_z = None interpolation_order = 1 interpolation_order_z = 0 print("resampling to original spacing and nifti export...") save_segmentation_nifti_from_softmax(pred, output_file, properties, interpolation_order, self.regions_class_order, None, None, softmax_ouput_file, None, force_separate_z=force_separate_z, interpolation_order_z=interpolation_order_z) print("done") def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True, mirror_axes: Tuple[int] = None, use_sliding_window: bool = True, step_size: float = 0.5, use_gaussian: bool = True, pad_border_mode: str = 'constant', pad_kwargs: dict = None, all_in_gpu: bool = False, verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]: """ :param data: :param do_mirroring: :param mirror_axes: :param use_sliding_window: :param step_size: :param use_gaussian: :param pad_border_mode: :param pad_kwargs: :param all_in_gpu: :param verbose: :return: """ if pad_border_mode == 'constant' and pad_kwargs is None: pad_kwargs = {'constant_values': 0} if do_mirroring and mirror_axes is None: mirror_axes = self.data_aug_params['mirror_axes'] if do_mirroring: assert self.data_aug_params["do_mirror"], "Cannot do mirroring as test time augmentation when training " \ "was done without mirroring" valid = list((SegmentationNetwork, nn.DataParallel)) assert isinstance(self.network, tuple(valid)) current_mode = self.network.training self.network.eval() ret = self.network.predict_3D(data, do_mirroring=do_mirroring, mirror_axes=mirror_axes, use_sliding_window=use_sliding_window, step_size=step_size, patch_size=self.patch_size, regions_class_order=self.regions_class_order, 
use_gaussian=use_gaussian, pad_border_mode=pad_border_mode, pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, verbose=verbose, mixed_precision=mixed_precision) self.network.train(current_mode) return ret def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True, validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False, segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True): """ if debug=True then the temporary files generated for postprocessing determination will be kept """ current_mode = self.network.training self.network.eval() assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)" if self.dataset_val is None: self.load_dataset() self.do_split() if segmentation_export_kwargs is None: if 'segmentation_export_params' in self.plans.keys(): force_separate_z = self.plans['segmentation_export_params']['force_separate_z'] interpolation_order = self.plans['segmentation_export_params']['interpolation_order'] interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z'] else: force_separate_z = None interpolation_order = 1 interpolation_order_z = 0 else: force_separate_z = segmentation_export_kwargs['force_separate_z'] interpolation_order = segmentation_export_kwargs['interpolation_order'] interpolation_order_z = segmentation_export_kwargs['interpolation_order_z'] # predictions as they come from the network go here output_folder = join(self.output_folder, validation_folder_name) maybe_mkdir_p(output_folder) # this is for debug purposes my_input_args = {'do_mirroring': do_mirroring, 'use_sliding_window': use_sliding_window, 'step_size': step_size, 'save_softmax': save_softmax, 'use_gaussian': use_gaussian, 'overwrite': overwrite, 'validation_folder_name': validation_folder_name, 'debug': debug, 'all_in_gpu': all_in_gpu, 
'segmentation_export_kwargs': segmentation_export_kwargs, } save_json(my_input_args, join(output_folder, "validation_args.json")) if do_mirroring: if not self.data_aug_params['do_mirror']: raise RuntimeError("We did not train with mirroring so you cannot do inference with mirroring enabled") mirror_axes = self.data_aug_params['mirror_axes'] else: mirror_axes = () pred_gt_tuples = [] export_pool = Pool(default_num_threads) results = [] inference_times = [] for k in self.dataset_val.keys(): properties = load_pickle(self.dataset[k]['properties_file']) fname = properties['list_of_data_files'][0].split("/")[-1][:-12] if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \ (save_softmax and not isfile(join(output_folder, fname + ".npz"))): start = time.time() data = np.load(self.dataset[k]['data_file'])['data'] print(k, data.shape) data[-1][data[-1] == -1] = 0 softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data[:-1], do_mirroring=do_mirroring, mirror_axes=mirror_axes, use_sliding_window=use_sliding_window, step_size=step_size, use_gaussian=use_gaussian, all_in_gpu=all_in_gpu, mixed_precision=self.fp16)[1] softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in self.transpose_backward]) end = time.time() print(f"Inference Time (k={k}): {end - start} sec.") inference_times.append(end-start) if save_softmax: softmax_fname = join(output_folder, fname + ".npz") else: softmax_fname = None """There is a problem with python process communication that prevents us from communicating obejcts larger than 2 GB between processes (basically when the length of the pickle string that will be sent is communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually patching system python code. 
We circumvent that problem here by saving softmax_pred to a npy file that will then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either filename or np.ndarray and will handle this automatically""" if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85): # *0.85 just to be save np.save(join(output_folder, fname + ".npy"), softmax_pred) softmax_pred = join(output_folder, fname + ".npy") results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax, ((softmax_pred, join(output_folder, fname + ".nii.gz"), properties, interpolation_order, self.regions_class_order, None, None, softmax_fname, None, force_separate_z, interpolation_order_z), ) ) ) pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"), join(self.gt_niftis_folder, fname + ".nii.gz")]) # Print inference times print(f"Total Images: {len(inference_times)}") print(f"Total Time: {sum(inference_times)} sec.") print(f"Time per Image: {sum(inference_times) / len(inference_times)} sec.") _ = [i.get() for i in results] self.print_to_log_file("finished prediction") # evaluate raw predictions self.print_to_log_file("evaluation of raw predictions") task = self.dataset_directory.split("/")[-1] job_name = self.experiment_name _ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)), json_output_file=join(output_folder, "summary.json"), json_name=job_name + " val tiled %s" % (str(use_sliding_window)), json_author="Fabian", json_task=task, num_threads=default_num_threads) if run_postprocessing_on_folds: # in the old unetr_pp we would stop here. Now we add a postprocessing. This postprocessing can remove everything # except the largest connected component for each class. To see if this improves results, we do this for all # classes and then rerun the evaluation. 
Those classes for which this resulted in an improved dice score will # have this applied during inference as well self.print_to_log_file("determining postprocessing") determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name, final_subf_name=validation_folder_name + "_postprocessed", debug=debug) # after this the final predictions for the vlaidation set can be found in validation_folder_name_base + "_postprocessed" # They are always in that folder, even if no postprocessing as applied! # detemining postprocesing on a per-fold basis may be OK for this fold but what if another fold finds another # postprocesing to be better? In this case we need to consolidate. At the time the consolidation is going to be # done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a separate folder to # be used later gt_nifti_folder = join(self.output_folder_base, "gt_niftis") maybe_mkdir_p(gt_nifti_folder) for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"): success = False attempts = 0 e = None while not success and attempts < 10: try: shutil.copy(f, gt_nifti_folder) success = True except OSError as e: attempts += 1 sleep(1) if not success: print("Could not copy gt nifti file %s into folder %s" % (f, gt_nifti_folder)) if e is not None: raise e self.network.train(current_mode) def run_online_evaluation(self, output, target): with torch.no_grad(): num_classes = output.shape[1] output_softmax = softmax_helper(output) output_seg = output_softmax.argmax(1) target = target[:, 0] axes = tuple(range(1, len(target.shape))) #tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index) #fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index) #fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index) tp_hard = torch.zeros((target.shape[0], 8)).to(output_seg.device.index) fp_hard = torch.zeros((target.shape[0], 8)).to(output_seg.device.index) 
fn_hard = torch.zeros((target.shape[0], 8)).to(output_seg.device.index) i=0 for c in range(1, num_classes): if c in [10,12,13,5,9]: continue tp_hard[:, i] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes) fp_hard[:, i] = sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes) fn_hard[:, i] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes) i+=1 tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy() fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy() fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy() self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8))) self.online_eval_tp.append(list(tp_hard)) self.online_eval_fp.append(list(fp_hard)) self.online_eval_fn.append(list(fn_hard)) def finish_online_evaluation(self): self.online_eval_tp = np.sum(self.online_eval_tp, 0) self.online_eval_fp = np.sum(self.online_eval_fp, 0) self.online_eval_fn = np.sum(self.online_eval_fn, 0) global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in zip(self.online_eval_tp, self.online_eval_fp, self.online_eval_fn)] if not np.isnan(i)] self.all_val_eval_metrics.append(np.mean(global_dc_per_class)) self.print_to_log_file("Average global foreground Dice:", str(global_dc_per_class)) self.print_to_log_file("(interpret this as an estimate for the Dice of the different classes. 
This is not " "exact.)") self.online_eval_foreground_dc = [] self.online_eval_tp = [] self.online_eval_fp = [] self.online_eval_fn = [] def save_checkpoint(self, fname, save_optimizer=True): super(Trainer_synapse, self).save_checkpoint(fname, save_optimizer) info = OrderedDict() info['init'] = self.init_args info['name'] = self.__class__.__name__ info['class'] = str(self.__class__) info['plans'] = self.plans write_pickle(info, fname + ".pkl") ================================================ FILE: unetr_pp/training/network_training/Trainer_tumor.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import shutil import time from collections import OrderedDict from multiprocessing import Pool from time import sleep from typing import Tuple, List import matplotlib import unetr_pp import numpy as np import torch from batchgenerators.utilities.file_and_folder_operations import * from unetr_pp.configuration import default_num_threads from unetr_pp.evaluation.evaluator import aggregate_scores from unetr_pp.inference.segmentation_export import save_segmentation_nifti_from_softmax from unetr_pp.network_architecture.generic_UNet import Generic_UNet from unetr_pp.network_architecture.initialization import InitWeights_He from unetr_pp.network_architecture.neural_network import SegmentationNetwork from unetr_pp.postprocessing.connected_components import determine_postprocessing from unetr_pp.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \ default_2D_augmentation_params, get_default_augmentation, get_patch_size from unetr_pp.training.dataloading.dataset_loading import load_dataset, DataLoader3D, DataLoader2D, unpack_dataset from unetr_pp.training.loss_functions.dice_loss import DC_and_CE_loss from unetr_pp.training.network_training.network_trainer_tumor import NetworkTrainer_tumor from unetr_pp.utilities.nd_softmax import softmax_helper from unetr_pp.utilities.tensor_utilities import sum_tensor from torch import nn from torch.optim import lr_scheduler matplotlib.use("agg") class Trainer_tumor(NetworkTrainer_tumor): def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None, unpack_data=True, deterministic=True, fp16=False): """ :param deterministic: :param fold: can be either [0 ... 5) for cross-validation, 'all' to train on all available training data or None if you wish to load some checkpoint and do inference only :param plans_file: the pkl file generated by preprocessing. 
This file will determine all design choices :param subfolder_with_preprocessed_data: must be a subfolder of dataset_directory (just the name of the folder, not the entire path). This is where the preprocessed data lies that will be used for network training. We made this explicitly available so that differently preprocessed data can coexist and the user can choose what to use. Can be None if you are doing inference only. :param output_folder: where to store parameters, plot progress and to the validation :param dataset_directory: the parent directory in which the preprocessed Task data is stored. This is required because the split information is stored in this directory. For running prediction only this input is not required and may be set to None :param batch_dice: compute dice loss for each sample and average over all samples in the batch or pretend the batch is a pseudo volume? :param stage: The plans file may contain several stages (used for lowres / highres / pyramid). Stage must be specified for training: if stage 1 exists then stage 1 is the high resolution stage, otherwise it's 0 :param unpack_data: if False, npz preprocessed data will not be unpacked to npy. This consumes less space but is considerably slower! Running unpack_data=False with 2d should never be done! IMPORTANT: If you inherit from nnFormerTrainer and the init args change then you need to redefine self.init_args in your init accordingly. Otherwise checkpoints won't load properly! 
""" super(Trainer_tumor, self).__init__(deterministic, fp16) self.unpack_data = unpack_data self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data, deterministic, fp16) # set through arguments from init self.stage = stage self.experiment_name = self.__class__.__name__ self.plans_file = plans_file self.output_folder = output_folder self.dataset_directory = dataset_directory self.output_folder_base = self.output_folder self.fold = fold self.plans = None # if we are running inference only then the self.dataset_directory is set (due to checkpoint loading) but it # irrelevant if self.dataset_directory is not None and isdir(self.dataset_directory): self.gt_niftis_folder = join(self.dataset_directory, "gt_segmentations") else: self.gt_niftis_folder = None self.folder_with_preprocessed_data = None # set in self.initialize() self.dl_tr = self.dl_val = None self.num_input_channels = self.num_classes = self.net_pool_per_axis = self.patch_size = self.batch_size = \ self.threeD = self.base_num_features = self.intensity_properties = self.normalization_schemes = \ self.net_num_pool_op_kernel_sizes = self.net_conv_kernel_sizes = None # loaded automatically from plans_file self.basic_generator_patch_size = self.data_aug_params = self.transpose_forward = self.transpose_backward = None self.batch_dice = batch_dice self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {}) self.online_eval_foreground_dc = [] self.online_eval_tp = [] self.online_eval_fp = [] self.online_eval_fn = [] self.classes = self.do_dummy_2D_aug = self.use_mask_for_norm = self.only_keep_largest_connected_component = \ self.min_region_size_per_class = self.min_size_per_class = None self.inference_pad_border_mode = "constant" self.inference_pad_kwargs = {'constant_values': 0} self.update_fold(fold) self.pad_all_sides = None self.lr_scheduler_eps = 1e-3 self.lr_scheduler_patience = 30 self.initial_lr = 3e-4 self.weight_decay = 3e-5 
self.oversample_foreground_percent = 0.33 self.conv_per_stage = None self.regions_class_order = None def update_fold(self, fold): """ used to swap between folds for inference (ensemble of models from cross-validation) DO NOT USE DURING TRAINING AS THIS WILL NOT UPDATE THE DATASET SPLIT AND THE DATA AUGMENTATION GENERATORS :param fold: :return: """ if fold is not None: if isinstance(fold, str): assert fold == "all", "if self.fold is a string then it must be \'all\'" if self.output_folder.endswith("%s" % str(self.fold)): self.output_folder = self.output_folder_base self.output_folder = join(self.output_folder, "%s" % str(fold)) else: if self.output_folder.endswith("fold_%s" % str(self.fold)): self.output_folder = self.output_folder_base self.output_folder = join(self.output_folder, "fold_%s" % str(fold)) self.fold = fold def setup_DA_params(self): if self.threeD: self.data_aug_params = default_3D_augmentation_params if self.do_dummy_2D_aug: self.data_aug_params["dummy_2D"] = True self.print_to_log_file("Using dummy2d data augmentation") self.data_aug_params["elastic_deform_alpha"] = \ default_2D_augmentation_params["elastic_deform_alpha"] self.data_aug_params["elastic_deform_sigma"] = \ default_2D_augmentation_params["elastic_deform_sigma"] self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"] else: self.do_dummy_2D_aug = False if max(self.patch_size) / min(self.patch_size) > 1.5: default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. 
* np.pi) self.data_aug_params = default_2D_augmentation_params self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm if self.do_dummy_2D_aug: self.basic_generator_patch_size = get_patch_size(self.patch_size[1:], self.data_aug_params['rotation_x'], self.data_aug_params['rotation_y'], self.data_aug_params['rotation_z'], self.data_aug_params['scale_range']) self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size)) patch_size_for_spatialtransform = self.patch_size[1:] else: self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'], self.data_aug_params['rotation_y'], self.data_aug_params['rotation_z'], self.data_aug_params['scale_range']) patch_size_for_spatialtransform = self.patch_size self.data_aug_params['selected_seg_channels'] = [0] self.data_aug_params['patch_size_for_spatialtransform'] = patch_size_for_spatialtransform def initialize(self, training=True, force_load_plans=False): """ For prediction of test cases just set training=False, this will prevent loading of training data and training batchgenerator initialization :param training: :return: """ maybe_mkdir_p(self.output_folder) if force_load_plans or (self.plans is None): self.load_plans_file() self.process_plans(self.plans) self.setup_DA_params() if training: self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] + "_stage%d" % self.stage) self.dl_tr, self.dl_val = self.get_basic_generators() if self.unpack_data: self.print_to_log_file("unpacking dataset") unpack_dataset(self.folder_with_preprocessed_data) self.print_to_log_file("done") else: self.print_to_log_file( "INFO: Not unpacking data! Training may be slow due to that. 
Pray you are not using 2d or you " "will wait all winter for your model to finish!") self.tr_gen, self.val_gen = get_default_augmentation(self.dl_tr, self.dl_val, self.data_aug_params[ 'patch_size_for_spatialtransform'], self.data_aug_params) self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())), also_print_to_console=False) self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())), also_print_to_console=False) else: pass self.initialize_network() self.initialize_optimizer_and_scheduler() # assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel)) self.was_initialized = True def initialize_network(self): """ This is specific to the U-Net and must be adapted for other network architectures :return: """ # self.print_to_log_file(self.net_num_pool_op_kernel_sizes) # self.print_to_log_file(self.net_conv_kernel_sizes) net_numpool = len(self.net_num_pool_op_kernel_sizes) if self.threeD: conv_op = nn.Conv3d dropout_op = nn.Dropout3d norm_op = nn.InstanceNorm3d else: conv_op = nn.Conv2d dropout_op = nn.Dropout2d norm_op = nn.InstanceNorm2d norm_op_kwargs = {'eps': 1e-5, 'affine': True} dropout_op_kwargs = {'p': 0, 'inplace': True} net_nonlin = nn.LeakyReLU net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes, net_numpool, self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, net_nonlin, net_nonlin_kwargs, False, False, lambda x: x, InitWeights_He(1e-2), self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True) self.network.inference_apply_nonlin = softmax_helper if torch.cuda.is_available(): self.network.cuda() def initialize_optimizer_and_scheduler(self): assert self.network is not None, "self.initialize_network must be called first" self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, amsgrad=True) 
self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2, patience=self.lr_scheduler_patience, verbose=True, threshold=self.lr_scheduler_eps, threshold_mode="abs") def plot_network_architecture(self): try: from batchgenerators.utilities.file_and_folder_operations import join import hiddenlayer as hl if torch.cuda.is_available(): g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)).cuda(), transforms=None) else: g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)), transforms=None) g.save(join(self.output_folder, "network_architecture.pdf")) del g except Exception as e: self.print_to_log_file("Unable to plot network architecture:") self.print_to_log_file(e) self.print_to_log_file("\nprinting the network instead:\n") self.print_to_log_file(self.network) self.print_to_log_file("\n") finally: if torch.cuda.is_available(): torch.cuda.empty_cache() def save_debug_information(self): # saving some debug information dct = OrderedDict() for k in self.__dir__(): if not k.startswith("__"): if not callable(getattr(self, k)): dct[k] = str(getattr(self, k)) del dct['plans'] del dct['intensity_properties'] del dct['dataset'] del dct['dataset_tr'] del dct['dataset_val'] save_json(dct, join(self.output_folder, "debug.json")) import shutil shutil.copy(self.plans_file, join(self.output_folder_base, "plans.pkl")) def run_training(self): self.save_debug_information() super(Trainer_tumor, self).run_training() def load_plans_file(self): """ This is what actually configures the entire experiment. 
The plans file is generated by experiment planning :return: """ self.plans = load_pickle(self.plans_file) #self.plans = load_pickle("../../../DATASET_Tumor/unetr_pp_raw/unetr_pp_raw_data/Task03_tumor/Task003_tumor/unetr_pp_Plansv2.1_plans_3D.pkl") def process_plans(self, plans): if self.stage is None: assert len(list(plans['plans_per_stage'].keys())) == 1, \ "If self.stage is None then there can be only one stage in the plans file. That seems to not be the " \ "case. Please specify which stage of the cascade must be trained" self.stage = list(plans['plans_per_stage'].keys())[0] self.plans = plans stage_plans = self.plans['plans_per_stage'][self.stage] self.batch_size = stage_plans['batch_size'] self.net_pool_per_axis = stage_plans['num_pool_per_axis'] self.patch_size = np.array(stage_plans['patch_size']).astype(int) self.do_dummy_2D_aug = stage_plans['do_dummy_2D_data_aug'] if 'pool_op_kernel_sizes' not in stage_plans.keys(): assert 'num_pool_per_axis' in stage_plans.keys() self.print_to_log_file("WARNING! old plans file with missing pool_op_kernel_sizes. Attempting to fix it...") self.net_num_pool_op_kernel_sizes = [] for i in range(max(self.net_pool_per_axis)): curr = [] for j in self.net_pool_per_axis: if (max(self.net_pool_per_axis) - j) <= i: curr.append(2) else: curr.append(1) self.net_num_pool_op_kernel_sizes.append(curr) else: self.net_num_pool_op_kernel_sizes = stage_plans['pool_op_kernel_sizes'] if 'conv_kernel_sizes' not in stage_plans.keys(): self.print_to_log_file("WARNING! old plans file with missing conv_kernel_sizes. 
Attempting to fix it...") self.net_conv_kernel_sizes = [[3] * len(self.net_pool_per_axis)] * (max(self.net_pool_per_axis) + 1) else: self.net_conv_kernel_sizes = stage_plans['conv_kernel_sizes'] self.pad_all_sides = None # self.patch_size self.intensity_properties = plans['dataset_properties']['intensityproperties'] self.normalization_schemes = plans['normalization_schemes'] self.base_num_features = plans['base_num_features'] self.num_input_channels = plans['num_modalities'] self.num_classes = plans['num_classes'] + 1 # background is no longer in num_classes self.classes = plans['all_classes'] self.use_mask_for_norm = plans['use_mask_for_norm'] self.only_keep_largest_connected_component = plans['keep_only_largest_region'] self.min_region_size_per_class = plans['min_region_size_per_class'] self.min_size_per_class = None # DONT USE THIS. plans['min_size_per_class'] if plans.get('transpose_forward') is None or plans.get('transpose_backward') is None: print("WARNING! You seem to have data that was preprocessed with a previous version of nnU-Net. " "You should rerun preprocessing. We will proceed and assume that both transpose_foward " "and transpose_backward are [0, 1, 2]. 
If that is not correct then weird things will happen!") plans['transpose_forward'] = [0, 1, 2] plans['transpose_backward'] = [0, 1, 2] self.transpose_forward = plans['transpose_forward'] self.transpose_backward = plans['transpose_backward'] if len(self.patch_size) == 2: self.threeD = False elif len(self.patch_size) == 3: self.threeD = True else: raise RuntimeError("invalid patch size in plans file: %s" % str(self.patch_size)) if "conv_per_stage" in plans.keys(): # this ha sbeen added to the plans only recently self.conv_per_stage = plans['conv_per_stage'] else: self.conv_per_stage = 2 def load_dataset(self): self.dataset = load_dataset(self.folder_with_preprocessed_data) def get_basic_generators(self): self.load_dataset() self.do_split() if self.threeD: dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size, False, oversample_foreground_percent=self.oversample_foreground_percent, pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r') dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, False, oversample_foreground_percent=self.oversample_foreground_percent, pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r') else: dl_tr = DataLoader2D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size, oversample_foreground_percent=self.oversample_foreground_percent, pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r') dl_val = DataLoader2D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, oversample_foreground_percent=self.oversample_foreground_percent, pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r') return dl_tr, dl_val def preprocess_patient(self, input_files): """ Used to predict new unseen data. 
Not used for the preprocessing of the training/test data :param input_files: :return: """ from unetr_pp.training.model_restore import recursive_find_python_class preprocessor_name = self.plans.get('preprocessor_name') if preprocessor_name is None: if self.threeD: preprocessor_name = "GenericPreprocessor" else: preprocessor_name = "PreprocessorFor2D" print("using preprocessor", preprocessor_name) preprocessor_class = recursive_find_python_class([join(unetr_pp.__path__[0], "preprocessing")], preprocessor_name, current_module="unetr_pp.preprocessing") assert preprocessor_class is not None, "Could not find preprocessor %s in unetr_pp.preprocessing" % \ preprocessor_name preprocessor = preprocessor_class(self.normalization_schemes, self.use_mask_for_norm, self.transpose_forward, self.intensity_properties) d, s, properties = preprocessor.preprocess_test_case(input_files, self.plans['plans_per_stage'][self.stage][ 'current_spacing']) return d, s, properties def preprocess_predict_nifti(self, input_files: List[str], output_file: str = None, softmax_ouput_file: str = None, mixed_precision: bool = True) -> None: """ Use this to predict new data :param input_files: :param output_file: :param softmax_ouput_file: :param mixed_precision: :return: """ print("preprocessing...") d, s, properties = self.preprocess_patient(input_files) print("predicting...") pred = self.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=self.data_aug_params["do_mirror"], mirror_axes=self.data_aug_params['mirror_axes'], use_sliding_window=True, step_size=0.5, use_gaussian=True, pad_border_mode='constant', pad_kwargs={'constant_values': 0}, verbose=True, all_in_gpu=False, mixed_precision=mixed_precision)[1] pred = pred.transpose([0] + [i + 1 for i in self.transpose_backward]) if 'segmentation_export_params' in self.plans.keys(): force_separate_z = self.plans['segmentation_export_params']['force_separate_z'] interpolation_order = 
self.plans['segmentation_export_params']['interpolation_order'] interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z'] else: force_separate_z = None interpolation_order = 1 interpolation_order_z = 0 print("resampling to original spacing and nifti export...") save_segmentation_nifti_from_softmax(pred, output_file, properties, interpolation_order, self.regions_class_order, None, None, softmax_ouput_file, None, force_separate_z=force_separate_z, interpolation_order_z=interpolation_order_z) print("done") def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True, mirror_axes: Tuple[int] = None, use_sliding_window: bool = True, step_size: float = 0.5, use_gaussian: bool = True, pad_border_mode: str = 'constant', pad_kwargs: dict = None, all_in_gpu: bool = False, verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]: """ :param data: :param do_mirroring: :param mirror_axes: :param use_sliding_window: :param step_size: :param use_gaussian: :param pad_border_mode: :param pad_kwargs: :param all_in_gpu: :param verbose: :return: """ if pad_border_mode == 'constant' and pad_kwargs is None: pad_kwargs = {'constant_values': 0} if do_mirroring and mirror_axes is None: mirror_axes = self.data_aug_params['mirror_axes'] if do_mirroring: assert self.data_aug_params["do_mirror"], "Cannot do mirroring as test time augmentation when training " \ "was done without mirroring" valid = list((SegmentationNetwork, nn.DataParallel)) assert isinstance(self.network, tuple(valid)) current_mode = self.network.training self.network.eval() ret = self.network.predict_3D(data, do_mirroring=do_mirroring, mirror_axes=mirror_axes, use_sliding_window=use_sliding_window, step_size=step_size, patch_size=self.patch_size, regions_class_order=self.regions_class_order, use_gaussian=use_gaussian, pad_border_mode=pad_border_mode, pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, verbose=verbose, 
mixed_precision=mixed_precision) self.network.train(current_mode) return ret def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True, validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False, segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True): """ if debug=True then the temporary files generated for postprocessing determination will be kept """ current_mode = self.network.training self.network.eval() assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)" if self.dataset_val is None: self.load_dataset() self.do_split() if segmentation_export_kwargs is None: if 'segmentation_export_params' in self.plans.keys(): force_separate_z = self.plans['segmentation_export_params']['force_separate_z'] interpolation_order = self.plans['segmentation_export_params']['interpolation_order'] interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z'] else: force_separate_z = None interpolation_order = 1 interpolation_order_z = 0 else: force_separate_z = segmentation_export_kwargs['force_separate_z'] interpolation_order = segmentation_export_kwargs['interpolation_order'] interpolation_order_z = segmentation_export_kwargs['interpolation_order_z'] # predictions as they come from the network go here output_folder = join(self.output_folder, validation_folder_name) maybe_mkdir_p(output_folder) # this is for debug purposes my_input_args = {'do_mirroring': do_mirroring, 'use_sliding_window': use_sliding_window, 'step_size': step_size, 'save_softmax': save_softmax, 'use_gaussian': use_gaussian, 'overwrite': overwrite, 'validation_folder_name': validation_folder_name, 'debug': debug, 'all_in_gpu': all_in_gpu, 'segmentation_export_kwargs': segmentation_export_kwargs, } save_json(my_input_args, join(output_folder, "validation_args.json")) if 
do_mirroring: if not self.data_aug_params['do_mirror']: raise RuntimeError("We did not train with mirroring so you cannot do inference with mirroring enabled") mirror_axes = self.data_aug_params['mirror_axes'] else: mirror_axes = () pred_gt_tuples = [] export_pool = Pool(default_num_threads) results = [] inference_times = [] for k in self.dataset_val.keys(): properties = load_pickle(self.dataset[k]['properties_file']) fname = properties['list_of_data_files'][0].split("/")[-1][:-12] if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \ (save_softmax and not isfile(join(output_folder, fname + ".npz"))): start = time.time() data = np.load(self.dataset[k]['data_file'])['data'] print(k, data.shape) data[-1][data[-1] == -1] = 0 softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data[:-1], do_mirroring=do_mirroring, mirror_axes=mirror_axes, use_sliding_window=use_sliding_window, step_size=step_size, use_gaussian=use_gaussian, all_in_gpu=all_in_gpu, mixed_precision=self.fp16)[1] softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in self.transpose_backward]) end = time.time() print(f"Inference Time (k={k}): {end - start} sec.") inference_times.append(end-start) if save_softmax: softmax_fname = join(output_folder, fname + ".npz") else: softmax_fname = None """There is a problem with python process communication that prevents us from communicating obejcts larger than 2 GB between processes (basically when the length of the pickle string that will be sent is communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will then be read (and finally deleted) by the Process. 
save_segmentation_nifti_from_softmax can take either filename or np.ndarray and will handle this automatically""" if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85): # *0.85 just to be save np.save(join(output_folder, fname + ".npy"), softmax_pred) softmax_pred = join(output_folder, fname + ".npy") results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax, ((softmax_pred, join(output_folder, fname + ".nii.gz"), properties, interpolation_order, self.regions_class_order, None, None, softmax_fname, None, force_separate_z, interpolation_order_z), ) ) ) pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"), join(self.gt_niftis_folder, fname + ".nii.gz")]) # Print inference times print(f"Total Images: {len(inference_times)}") print(f"Total Time: {sum(inference_times)} sec.") print(f"Time per Image: {sum(inference_times) / len(inference_times)} sec.") _ = [i.get() for i in results] self.print_to_log_file("finished prediction") # evaluate raw predictions self.print_to_log_file("evaluation of raw predictions") task = self.dataset_directory.split("/")[-1] job_name = self.experiment_name _ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)), json_output_file=join(output_folder, "summary.json"), json_name=job_name + " val tiled %s" % (str(use_sliding_window)), json_author="Fabian", json_task=task, num_threads=default_num_threads) if run_postprocessing_on_folds: # in the old unetr_pp we would stop here. Now we add a postprocessing. This postprocessing can remove everything # except the largest connected component for each class. To see if this improves results, we do this for all # classes and then rerun the evaluation. 
Those classes for which this resulted in an improved dice score will # have this applied during inference as well self.print_to_log_file("determining postprocessing") determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name, final_subf_name=validation_folder_name + "_postprocessed", debug=debug) # after this the final predictions for the vlaidation set can be found in validation_folder_name_base + "_postprocessed" # They are always in that folder, even if no postprocessing as applied! # detemining postprocesing on a per-fold basis may be OK for this fold but what if another fold finds another # postprocesing to be better? In this case we need to consolidate. At the time the consolidation is going to be # done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a separate folder to # be used later gt_nifti_folder = join(self.output_folder_base, "gt_niftis") maybe_mkdir_p(gt_nifti_folder) for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"): success = False attempts = 0 e = None while not success and attempts < 10: try: shutil.copy(f, gt_nifti_folder) success = True except OSError as e: attempts += 1 sleep(1) if not success: print("Could not copy gt nifti file %s into folder %s" % (f, gt_nifti_folder)) if e is not None: raise e self.network.train(current_mode) def run_online_evaluation(self, output, target): with torch.no_grad(): num_classes = output.shape[1] output_softmax = softmax_helper(output) output_seg = output_softmax.argmax(1) target = target[:, 0] axes = tuple(range(1, len(target.shape))) tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index) fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index) fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index) for c in range(1, num_classes): tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes) fp_hard[:, c - 1] = 
sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes) fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes) tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy() fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy() fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy() self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8))) self.online_eval_tp.append(list(tp_hard)) self.online_eval_fp.append(list(fp_hard)) self.online_eval_fn.append(list(fn_hard)) def finish_online_evaluation(self): self.online_eval_tp = np.sum(self.online_eval_tp, 0) self.online_eval_fp = np.sum(self.online_eval_fp, 0) self.online_eval_fn = np.sum(self.online_eval_fn, 0) global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in zip(self.online_eval_tp, self.online_eval_fp, self.online_eval_fn)] if not np.isnan(i)] self.all_val_eval_metrics.append(np.mean(global_dc_per_class)) self.print_to_log_file("Average global foreground Dice:", str(global_dc_per_class)) self.print_to_log_file("(interpret this as an estimate for the Dice of the different classes. 
This is not " "exact.)") self.online_eval_foreground_dc = [] self.online_eval_tp = [] self.online_eval_fp = [] self.online_eval_fn = [] def save_checkpoint(self, fname, save_optimizer=True): super(Trainer_tumor, self).save_checkpoint(fname, save_optimizer) info = OrderedDict() info['init'] = self.init_args info['name'] = self.__class__.__name__ info['class'] = str(self.__class__) info['plans'] = self.plans write_pickle(info, fname + ".pkl") ================================================ FILE: unetr_pp/training/network_training/network_trainer_acdc.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from _warnings import warn
from typing import Tuple
import matplotlib
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from sklearn.model_selection import KFold
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import _LRScheduler

matplotlib.use("agg")
from time import time, sleep
import torch
import numpy as np
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
import sys
from collections import OrderedDict
import torch.backends.cudnn as cudnn
from abc import abstractmethod
from datetime import datetime
from tqdm import trange
from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda


class NetworkTrainer_acdc(object):
    def __init__(self, deterministic=True, fp16=False):
        """
        A generic class that can train almost any neural network (RNNs excluded). It provides basic functionality such
        as the training loop, tracking of training and validation losses (and the target metric if you implement it)
        Training can be terminated early if the validation loss (or the target metric if implemented) do not improve
        anymore. This is based on a moving average (MA) of the loss/metric instead of the raw values to get more smooth
        results.

        What you need to override:
        - __init__
        - initialize
        - run_online_evaluation (optional)
        - finish_online_evaluation (optional)
        - validate
        - predict_test_case

        :param deterministic: if True, seeds numpy/torch and forces deterministic cudnn (slower but reproducible)
        :param fp16: if True, mixed-precision training via torch.cuda.amp is used (see _maybe_init_amp)
        """
        self.fp16 = fp16
        self.amp_grad_scaler = None  # lazily created by _maybe_init_amp when fp16 is on

        if deterministic:
            np.random.seed(12345)
            torch.manual_seed(12345)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(12345)
            cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True

        ################# SET THESE IN self.initialize() ###################################
        self.network: Tuple[SegmentationNetwork, nn.DataParallel] = None
        self.optimizer = None
        self.lr_scheduler = None
        self.tr_gen = self.val_gen = None
        self.was_initialized = False

        ################# SET THESE IN INIT ################################################
        self.output_folder = None
        self.fold = None
        self.loss = None
        self.dataset_directory = None

        ################# SET THESE IN LOAD_DATASET OR DO_SPLIT ############################
        self.dataset = None  # these can be None for inference mode
        self.dataset_tr = self.dataset_val = None  # do not need to be used, they just appear if you are using the suggested load_dataset_and_do_split

        ################# THESE DO NOT NECESSARILY NEED TO BE MODIFIED #####################
        self.patience = 50
        self.val_eval_criterion_alpha = 0.9  # alpha * old + (1-alpha) * new
        # if this is too low then the moving average will be too noisy and the training may terminate early. If it is
        # too high the training will take forever
        self.train_loss_MA_alpha = 0.93  # alpha * old + (1-alpha) * new
        self.train_loss_MA_eps = 5e-4  # new MA must be at least this much better (smaller)
        self.max_num_epochs = 1000
        self.num_batches_per_epoch = 250
        self.num_val_batches_per_epoch = 50
        self.also_val_in_tr_mode = False
        self.lr_threshold = 1e-6  # the network will not terminate training if the lr is still above this threshold

        ################# LEAVE THESE ALONE ################################################
        self.val_eval_criterion_MA = None
        self.train_loss_MA = None
        self.best_val_eval_criterion_MA = None
        self.best_MA_tr_loss_for_patience = None
        self.best_epoch_based_on_MA_tr_loss = None
        self.all_tr_losses = []
        self.all_val_losses = []
        self.all_val_losses_tr_mode = []
        self.all_val_eval_metrics = []  # does not have to be used
        self.epoch = 0
        self.log_file = None
        self.deterministic = deterministic

        self.use_progress_bar = False
        if 'nnformer_use_progress_bar' in os.environ.keys():
            self.use_progress_bar = bool(int(os.environ['nnformer_use_progress_bar']))

        ################# Settings for saving checkpoints ##################################
        self.save_every = 50
        self.save_latest_only = False  # if false it will not store/overwrite _latest but separate files each
        # time an intermediate checkpoint is created
        self.save_intermediate_checkpoints = False  # whether or not to save checkpoint_latest
        self.save_best_checkpoint = True  # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
        self.save_final_checkpoint = True  # whether or not to save the final checkpoint

    @abstractmethod
    def initialize(self, training=True):
        """
        create self.output_folder
        modify self.output_folder if you are doing cross-validation (one folder per fold)
        set self.tr_gen and self.val_gen
        call self.initialize_network and self.initialize_optimizer_and_scheduler (important!)
        finally set self.was_initialized to True
        :param training:
        :return:
        """

    @abstractmethod
    def load_dataset(self):
        pass

    def do_split(self):
        """
        This is a suggestion for if your dataset is a dictionary (my personal standard)

        Creates (or loads) a 5-fold split of self.dataset keys and fills
        self.dataset_tr / self.dataset_val for self.fold. fold == "all" uses
        every case for both training and validation.
        :return:
        """
        splits_file = join(self.dataset_directory, "splits_final.pkl")
        if not isfile(splits_file):
            self.print_to_log_file("Creating new split...")
            splits = []
            all_keys_sorted = np.sort(list(self.dataset.keys()))
            # fixed random_state so every trainer instance reproduces the same split
            kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
            for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
                train_keys = np.array(all_keys_sorted)[train_idx]
                test_keys = np.array(all_keys_sorted)[test_idx]
                splits.append(OrderedDict())
                splits[-1]['train'] = train_keys
                splits[-1]['val'] = test_keys
            save_pickle(splits, splits_file)

        splits = load_pickle(splits_file)

        if self.fold == "all":
            tr_keys = val_keys = list(self.dataset.keys())
        else:
            tr_keys = splits[self.fold]['train']
            val_keys = splits[self.fold]['val']

        tr_keys.sort()
        val_keys.sort()

        self.dataset_tr = OrderedDict()
        for i in tr_keys:
            self.dataset_tr[i] = self.dataset[i]

        self.dataset_val = OrderedDict()
        for i in val_keys:
            self.dataset_val[i] = self.dataset[i]

    def plot_progress(self):
        """
        Should probably be improved

        Writes progress.png (train/val losses on the left axis, evaluation
        metric on the right axis) into self.output_folder.
        :return:
        """
        try:
            font = {'weight': 'normal',
                    'size': 18}

            matplotlib.rc('font', **font)

            fig = plt.figure(figsize=(30, 24))
            ax = fig.add_subplot(111)
            ax2 = ax.twinx()

            x_values = list(range(self.epoch + 1))

            ax.plot(x_values, self.all_tr_losses, color='b', ls='-', label="loss_tr")

            ax.plot(x_values, self.all_val_losses, color='r', ls='-', label="loss_val, train=False")

            if len(self.all_val_losses_tr_mode) > 0:
                ax.plot(x_values, self.all_val_losses_tr_mode, color='g', ls='-', label="loss_val, train=True")
            # metric is only plotted when one value per epoch exists
            if len(self.all_val_eval_metrics) == len(x_values):
                ax2.plot(x_values, self.all_val_eval_metrics, color='g', ls='--', label="evaluation metric")

            ax.set_xlabel("epoch")
            ax.set_ylabel("loss")
            ax2.set_ylabel("evaluation metric")
            ax.legend()
            ax2.legend(loc=9)

            fig.savefig(join(self.output_folder, "progress.png"))
            plt.close()
            '''
            #below is et ..dice
            font = {'weight': 'normal', 'size': 18}
            matplotlib.rc('font', **font)
            fig = plt.figure(figsize=(30, 24))
            ax = fig.add_subplot(111)
            ax2 = ax.twinx()
            x_values = list(range(self.epoch + 1))
            ax.plot(x_values, self.plot_et, color='b', ls='-', label="dice_et")
            ax.plot(x_values, self.plot_tc, color='g', ls='-', label="dice_tc")
            ax.plot(x_values, self.plot_wt, color='r', ls='-', label="dice_wt")
            ax.set_xlabel("epoch")
            ax.set_ylabel("dice")
            ax.legend()
            fig.savefig(join(self.output_folder, "dice_progress.png"))
            plt.close()
            '''
        except IOError:
            self.print_to_log_file("failed to plot: ", sys.exc_info())

    def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True):
        # Appends *args to a per-run log file in self.output_folder (created lazily
        # on first call); retries up to max_attempts times on IOError.
        timestamp = time()
        dt_object = datetime.fromtimestamp(timestamp)

        if add_timestamp:
            args = ("%s:" % dt_object, *args)

        if self.log_file is None:
            maybe_mkdir_p(self.output_folder)
            timestamp = datetime.now()
            self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" %
                                 (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute,
                                  timestamp.second))
            with open(self.log_file, 'w') as f:
                f.write("Starting... \n")
        successful = False
        max_attempts = 5
        ctr = 0
        while not successful and ctr < max_attempts:
            try:
                with open(self.log_file, 'a+') as f:
                    for a in args:
                        f.write(str(a))
                        f.write(" ")
                    f.write("\n")
                successful = True
            except IOError:
                print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info())
                sleep(0.5)
                ctr += 1
        if also_print_to_console:
            print(*args)

    def save_checkpoint(self, fname, save_optimizer=True):
        # Serializes network weights (moved to CPU), optimizer/scheduler state,
        # loss/metric history and AMP scaler state to fname via torch.save.
        start_time = time()
        state_dict = self.network.state_dict()
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].cpu()
        lr_sched_state_dct = None
        if self.lr_scheduler is not None and hasattr(self.lr_scheduler,
                                                     'state_dict'):  # not isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
            lr_sched_state_dct = self.lr_scheduler.state_dict()
            # WTF is this!?
            # for key in lr_sched_state_dct.keys():
            #    lr_sched_state_dct[key] = lr_sched_state_dct[key]
        if save_optimizer:
            optimizer_state_dict = self.optimizer.state_dict()
        else:
            optimizer_state_dict = None

        self.print_to_log_file("saving checkpoint...")
        save_this = {
            'epoch': self.epoch + 1,
            'state_dict': state_dict,
            'optimizer_state_dict': optimizer_state_dict,
            'lr_scheduler_state_dict': lr_sched_state_dct,
            'plot_stuff': (self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode,
                           self.all_val_eval_metrics),
            'best_stuff' : (self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA)}
        if self.amp_grad_scaler is not None:
            save_this['amp_grad_scaler'] = self.amp_grad_scaler.state_dict()

        torch.save(save_this, fname)
        self.print_to_log_file("done, saving took %.2f seconds" % (time() - start_time))

    def load_best_checkpoint(self, train=True):
        # Loads model_best.model if present, otherwise falls back to the latest checkpoint.
        if self.fold is None:
            raise RuntimeError("Cannot load best checkpoint if self.fold is None")
        if isfile(join(self.output_folder, "model_best.model")):
            self.load_checkpoint(join(self.output_folder, "model_best.model"), train=train)
        else:
            self.print_to_log_file("WARNING! model_best.model does not exist! Cannot load best checkpoint. Falling "
                                   "back to load_latest_checkpoint")
            self.load_latest_checkpoint(train)

    def load_latest_checkpoint(self, train=True):
        # Preference order: final checkpoint > latest > best.
        if isfile(join(self.output_folder, "model_final_checkpoint.model")):
            return self.load_checkpoint(join(self.output_folder, "model_final_checkpoint.model"), train=train)
        if isfile(join(self.output_folder, "model_latest.model")):
            return self.load_checkpoint(join(self.output_folder, "model_latest.model"), train=train)
        if isfile(join(self.output_folder, "model_best.model")):
            return self.load_best_checkpoint(train)
        raise RuntimeError("No checkpoint found")

    def load_final_checkpoint(self, train=False):
        filename = join(self.output_folder, "model_final_checkpoint.model")
        if not isfile(filename):
            raise RuntimeError("Final checkpoint not found. Expected: %s. Please finish the training first." % filename)
        return self.load_checkpoint(filename, train=train)

    def load_checkpoint(self, fname, train=True):
        # Reads a checkpoint from disk (always to CPU first) and restores it via load_checkpoint_ram.
        self.print_to_log_file("loading checkpoint", fname, "train=", train)
        if not self.was_initialized:
            self.initialize(train)
        # saved_model = torch.load(fname, map_location=torch.device('cuda', torch.cuda.current_device()))
        saved_model = torch.load(fname, map_location=torch.device('cpu'))
        self.load_checkpoint_ram(saved_model, train)

    @abstractmethod
    def initialize_network(self):
        """
        initialize self.network here
        :return:
        """
        pass

    @abstractmethod
    def initialize_optimizer_and_scheduler(self):
        """
        initialize self.optimizer and self.lr_scheduler (if applicable) here
        :return:
        """
        pass

    def load_checkpoint_ram(self, checkpoint, train=True):
        """
        used for if the checkpoint is already in ram
        :param checkpoint:
        :param train:
        :return:
        """
        print('I am here !!!')  # NOTE(review): leftover debug print
        if not self.was_initialized:
            self.initialize(train)

        new_state_dict = OrderedDict()
        curr_state_dict_keys = list(self.network.state_dict().keys())
        # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not
        # match. Use heuristic to make it match
        for k, value in checkpoint['state_dict'].items():
            key = k
            if key not in curr_state_dict_keys and key.startswith('module.'):
                key = key[7:]  # strip the 'module.' prefix added by DataParallel
            new_state_dict[key] = value

        if self.fp16:
            self._maybe_init_amp()
            if 'amp_grad_scaler' in checkpoint.keys():
                self.amp_grad_scaler.load_state_dict(checkpoint['amp_grad_scaler'])

        self.network.load_state_dict(new_state_dict)
        self.epoch = checkpoint['epoch']
        if train:
            optimizer_state_dict = checkpoint['optimizer_state_dict']
            if optimizer_state_dict is not None:
                self.optimizer.load_state_dict(optimizer_state_dict)

            if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'load_state_dict') and checkpoint[
                'lr_scheduler_state_dict'] is not None:
                self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])

            if issubclass(self.lr_scheduler.__class__, _LRScheduler):
                self.lr_scheduler.step(self.epoch)

        self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics = checkpoint[
            'plot_stuff']

        # load best loss (if present)
        print('best_stuff' in checkpoint.keys())  # NOTE(review): leftover debug print
        if 'best_stuff' in checkpoint.keys():
            self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA = checkpoint[
                'best_stuff']

        # after the training is done, the epoch is incremented one more time in my old code. This results in
        # self.epoch = 1001 for old trained models when the epoch is actually 1000. This causes issues because
        # len(self.all_tr_losses) = 1000 and the plot function will fail. We can easily detect and correct that here
        if self.epoch != len(self.all_tr_losses):
            self.print_to_log_file("WARNING in loading checkpoint: self.epoch != len(self.all_tr_losses). This is "
                                   "due to an old bug and should only appear when you are loading old models. New "
                                   "models should have this fixed! self.epoch is now set to len(self.all_tr_losses)")
            self.epoch = len(self.all_tr_losses)
            self.all_tr_losses = self.all_tr_losses[:self.epoch]
            self.all_val_losses = self.all_val_losses[:self.epoch]
            self.all_val_losses_tr_mode = self.all_val_losses_tr_mode[:self.epoch]
            self.all_val_eval_metrics = self.all_val_eval_metrics[:self.epoch]

        self._maybe_init_amp()

    def _maybe_init_amp(self):
        # Creates the AMP GradScaler once, only when fp16 training is enabled.
        if self.fp16 and self.amp_grad_scaler is None:
            self.amp_grad_scaler = GradScaler()

    def plot_network_architecture(self):
        """
        can be implemented (see nnFormerTrainer) but does not have to. Not implemented here because it imposes stronger
        assumptions on the presence of class variables
        :return:
        """
        pass

    def run_training(self):
        # Main training loop: iterates epochs until max_num_epochs or early
        # stopping (on_epoch_end -> manage_patience), then writes the final
        # checkpoint and removes the now-redundant 'latest' checkpoint.
        if not torch.cuda.is_available():
            self.print_to_log_file("WARNING!!! You are attempting to run training on a CPU (torch.cuda.is_available() is False). This can be VERY slow!")

        # warm up the generators once before timing/training starts
        _ = self.tr_gen.next()
        _ = self.val_gen.next()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self._maybe_init_amp()

        maybe_mkdir_p(self.output_folder)
        self.plot_network_architecture()

        if cudnn.benchmark and cudnn.deterministic:
            warn("torch.backends.cudnn.deterministic is True indicating a deterministic training is desired. "
                 "But torch.backends.cudnn.benchmark is True as well and this will prevent deterministic training! "
                 "If you want deterministic then set benchmark=False")

        if not self.was_initialized:
            self.initialize(True)

        while self.epoch < self.max_num_epochs:
            self.print_to_log_file("\nepoch: ", self.epoch)
            epoch_start_time = time()
            train_losses_epoch = []

            # train one epoch
            self.network.train()

            if self.use_progress_bar:
                with trange(self.num_batches_per_epoch) as tbar:
                    for b in tbar:
                        tbar.set_description("Epoch {}/{}".format(self.epoch+1, self.max_num_epochs))

                        l = self.run_iteration(self.tr_gen, True)

                        tbar.set_postfix(loss=l)
                        train_losses_epoch.append(l)
            else:
                for _ in range(self.num_batches_per_epoch):
                    l = self.run_iteration(self.tr_gen, True)
                    train_losses_epoch.append(l)

            self.all_tr_losses.append(np.mean(train_losses_epoch))
            self.print_to_log_file("train loss : %.4f" % self.all_tr_losses[-1])

            with torch.no_grad():
                # validation with train=False
                self.network.eval()
                val_losses = []
                for b in range(self.num_val_batches_per_epoch):
                    l = self.run_iteration(self.val_gen, False, True)
                    val_losses.append(l)
                self.all_val_losses.append(np.mean(val_losses))
                self.print_to_log_file("validation loss: %.4f" % self.all_val_losses[-1])

                if self.also_val_in_tr_mode:
                    self.network.train()
                    # validation with train=True
                    val_losses = []
                    for b in range(self.num_val_batches_per_epoch):
                        l = self.run_iteration(self.val_gen, False)
                        val_losses.append(l)
                    self.all_val_losses_tr_mode.append(np.mean(val_losses))
                    self.print_to_log_file("validation loss (train=True): %.4f" % self.all_val_losses_tr_mode[-1])

            self.update_train_loss_MA()  # needed for lr scheduler and stopping of training

            continue_training = self.on_epoch_end()

            epoch_end_time = time()

            if not continue_training:
                # allows for early stopping
                break

            self.epoch += 1
            self.print_to_log_file("This epoch took %f s\n" % (epoch_end_time - epoch_start_time))

        self.epoch -= 1  # if we don't do this we can get a problem with loading model_final_checkpoint.

        if self.save_final_checkpoint:
            self.save_checkpoint(join(self.output_folder, "model_final_checkpoint.model"))
        # now we can delete latest as it will be identical with final
        if isfile(join(self.output_folder, "model_latest.model")):
            os.remove(join(self.output_folder, "model_latest.model"))
        if isfile(join(self.output_folder, "model_latest.model.pkl")):
            os.remove(join(self.output_folder, "model_latest.model.pkl"))

    def maybe_update_lr(self):
        # maybe update learning rate
        if self.lr_scheduler is not None:
            assert isinstance(self.lr_scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))

            if isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
                # lr scheduler is updated with moving average val loss. should be more robust
                self.lr_scheduler.step(self.val_eval_criterion_MA)
            else:
                self.lr_scheduler.step(self.epoch + 1)
        self.print_to_log_file("lr is now (scheduler) %s" % str(self.optimizer.param_groups[0]['lr']))

    def maybe_save_checkpoint(self):
        """
        Saves a checkpoint every save_every epochs.
        :return:
        """
        if self.save_intermediate_checkpoints and (self.epoch % self.save_every == (self.save_every - 1)):
            self.print_to_log_file("saving scheduled checkpoint file...")
            if not self.save_latest_only:
                self.save_checkpoint(join(self.output_folder, "model_ep_%03.0d.model" % (self.epoch + 1)))
            self.save_checkpoint(join(self.output_folder, "model_latest.model"))
            self.print_to_log_file("done")

    def update_eval_criterion_MA(self):
        """
        If self.all_val_eval_metrics is unused (len=0) then we fall back to using -self.all_val_losses for the MA to determine early stopping
        (not a minimization, but a maximization of a metric and therefore the - in the latter case)
        :return:
        """
        if self.val_eval_criterion_MA is None:
            if len(self.all_val_eval_metrics) == 0:
                self.val_eval_criterion_MA = - self.all_val_losses[-1]
            else:
                self.val_eval_criterion_MA = self.all_val_eval_metrics[-1]
        else:
            if len(self.all_val_eval_metrics) == 0:
                """
                We here use alpha * old - (1 - alpha) * new because new in this case is the validation loss and lower
                is better, so we need to negate it.
                """
                self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA - (
                        1 - self.val_eval_criterion_alpha) * \
                                             self.all_val_losses[-1]
            else:
                self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA + (
                        1 - self.val_eval_criterion_alpha) * \
                                             self.all_val_eval_metrics[-1]

    def manage_patience(self):
        # update patience
        # Returns False (stop training) only when patience is exhausted AND the
        # learning rate has already dropped to/below self.lr_threshold.
        continue_training = True
        if self.patience is not None:
            # if best_MA_tr_loss_for_patience and best_epoch_based_on_MA_tr_loss were not yet initialized,
            # initialize them
            if self.best_MA_tr_loss_for_patience is None:
                self.best_MA_tr_loss_for_patience = self.train_loss_MA

            if self.best_epoch_based_on_MA_tr_loss is None:
                self.best_epoch_based_on_MA_tr_loss = self.epoch

            if self.best_val_eval_criterion_MA is None:
                self.best_val_eval_criterion_MA = self.val_eval_criterion_MA

            # check if the current epoch is the best one according to moving average of validation criterion. If so
            # then save 'best' model
            # Do not use this for validation. This is intended for test set prediction only.
            self.print_to_log_file("current best_val_eval_criterion_MA is %.4f0" % self.best_val_eval_criterion_MA)
            self.print_to_log_file("current val_eval_criterion_MA is %.4f" % self.val_eval_criterion_MA)

            if self.val_eval_criterion_MA > self.best_val_eval_criterion_MA:
                self.best_val_eval_criterion_MA = self.val_eval_criterion_MA
                #self.print_to_log_file("saving best epoch checkpoint...")
                if self.save_best_checkpoint:
                    self.save_checkpoint(join(self.output_folder, "model_best.model"))

            # Now see if the moving average of the train loss has improved. If yes then reset patience, else
            # increase patience
            if self.train_loss_MA + self.train_loss_MA_eps < self.best_MA_tr_loss_for_patience:
                self.best_MA_tr_loss_for_patience = self.train_loss_MA
                self.best_epoch_based_on_MA_tr_loss = self.epoch
                #self.print_to_log_file("New best epoch (train loss MA): %03.4f" % self.best_MA_tr_loss_for_patience)
            else:
                pass
                #self.print_to_log_file("No improvement: current train MA %03.4f, best: %03.4f, eps is %03.4f" %
                #                       (self.train_loss_MA, self.best_MA_tr_loss_for_patience, self.train_loss_MA_eps))

            # if patience has reached its maximum then finish training (provided lr is low enough)
            if self.epoch - self.best_epoch_based_on_MA_tr_loss > self.patience:
                if self.optimizer.param_groups[0]['lr'] > self.lr_threshold:
                    #self.print_to_log_file("My patience ended, but I believe I need more time (lr > 1e-6)")
                    self.best_epoch_based_on_MA_tr_loss = self.epoch - self.patience // 2
                else:
                    #self.print_to_log_file("My patience ended")
                    continue_training = False
            else:
                pass
                #self.print_to_log_file(
                #    "Patience: %d/%d" % (self.epoch - self.best_epoch_based_on_MA_tr_loss, self.patience))

        return continue_training

    def on_epoch_end(self):
        # Post-epoch bookkeeping; returns whether training should continue.
        self.finish_online_evaluation()  # does not have to do anything, but can be used to update self.all_val_eval_
        # metrics

        self.plot_progress()

        self.maybe_save_checkpoint()

        self.update_eval_criterion_MA()

        self.maybe_update_lr()

        continue_training = self.manage_patience()
        return continue_training

    def update_train_loss_MA(self):
        # Exponential moving average of the train loss (used by manage_patience).
        if self.train_loss_MA is None:
            self.train_loss_MA = self.all_tr_losses[-1]
        else:
            self.train_loss_MA = self.train_loss_MA_alpha * self.train_loss_MA + (1 - self.train_loss_MA_alpha) * \
                                 self.all_tr_losses[-1]

    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        # One forward (and optionally backward) pass on the next batch from
        # data_generator; returns the loss as a numpy scalar.
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        if self.fp16:
            with autocast():
                output = self.network(data)
                del data
                l = self.loss(output, target)

            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            output = self.network(data)
            del data
            l = self.loss(output, target)

            if do_backprop:
                l.backward()
                self.optimizer.step()

        if run_online_evaluation:
            self.run_online_evaluation(output, target)

        del target

        return l.detach().cpu().numpy()

    def run_online_evaluation(self, *args, **kwargs):
        """
        Can be implemented, does not have to
        :param output_torch:
        :param target_npy:
        :return:
        """
        pass

    def finish_online_evaluation(self):
        """
        Can be implemented, does not have to
        :return:
        """
        pass

    @abstractmethod
    def validate(self, *args, **kwargs):
        pass

    def find_lr(self, num_iters=1000, init_value=1e-6, final_value=10., beta=0.98):
        """
        stolen and adapted from here: https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
        :param num_iters:
        :param init_value:
        :param final_value:
        :param beta:
        :return:
        """
        import math
        self._maybe_init_amp()
        mult = (final_value / init_value) ** (1 / num_iters)
        lr = init_value
        self.optimizer.param_groups[0]['lr'] = lr
        avg_loss = 0.
        best_loss = 0.
        losses = []
        log_lrs = []

        for batch_num in range(1, num_iters + 1):
            # +1 because this one here is not designed to have negative loss...
            loss = self.run_iteration(self.tr_gen, do_backprop=True, run_online_evaluation=False).data.item() + 1

            # Compute the smoothed loss
            avg_loss = beta * avg_loss + (1 - beta) * loss
            smoothed_loss = avg_loss / (1 - beta ** batch_num)

            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > 4 * best_loss:
                break

            # Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss

            # Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))

            # Update the lr for the next step
            lr *= mult
            self.optimizer.param_groups[0]['lr'] = lr

        import matplotlib.pyplot as plt
        lrs = [10 ** i for i in log_lrs]
        fig = plt.figure()
        plt.xscale('log')
        plt.plot(lrs[10:-5], losses[10:-5])
        plt.savefig(join(self.output_folder, "lr_finder.png"))
        plt.close()
        return log_lrs, losses



================================================
FILE: unetr_pp/training/network_training/network_trainer_lung.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from _warnings import warn from typing import Tuple import matplotlib from batchgenerators.utilities.file_and_folder_operations import * from unetr_pp.network_architecture.neural_network import SegmentationNetwork from sklearn.model_selection import KFold from torch import nn from torch.cuda.amp import GradScaler, autocast from torch.optim.lr_scheduler import _LRScheduler matplotlib.use("agg") from time import time, sleep import torch import numpy as np from torch.optim import lr_scheduler import matplotlib.pyplot as plt import sys from collections import OrderedDict import torch.backends.cudnn as cudnn from abc import abstractmethod from datetime import datetime from tqdm import trange from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda class NetworkTrainer_lung(object): def __init__(self, deterministic=True, fp16=False): """ A generic class that can train almost any neural network (RNNs excluded). It provides basic functionality such as the training loop, tracking of training and validation losses (and the target metric if you implement it) Training can be terminated early if the validation loss (or the target metric if implemented) do not improve anymore. This is based on a moving average (MA) of the loss/metric instead of the raw values to get more smooth results. 
What you need to override: - __init__ - initialize - run_online_evaluation (optional) - finish_online_evaluation (optional) - validate - predict_test_case """ self.fp16 = fp16 self.amp_grad_scaler = None if deterministic: np.random.seed(12345) torch.manual_seed(12345) if torch.cuda.is_available(): torch.cuda.manual_seed_all(12345) cudnn.deterministic = True torch.backends.cudnn.benchmark = False else: cudnn.deterministic = False torch.backends.cudnn.benchmark = True ################# SET THESE IN self.initialize() ################################### self.network: Tuple[SegmentationNetwork, nn.DataParallel] = None self.optimizer = None self.lr_scheduler = None self.tr_gen = self.val_gen = None self.was_initialized = False ################# SET THESE IN INIT ################################################ self.output_folder = None self.fold = None self.loss = None self.dataset_directory = None ################# SET THESE IN LOAD_DATASET OR DO_SPLIT ############################ self.dataset = None # these can be None for inference mode self.dataset_tr = self.dataset_val = None # do not need to be used, they just appear if you are using the suggested load_dataset_and_do_split ################# THESE DO NOT NECESSARILY NEED TO BE MODIFIED ##################### self.patience = 50 self.val_eval_criterion_alpha = 0.9 # alpha * old + (1-alpha) * new # if this is too low then the moving average will be too noisy and the training may terminate early. 
If it is # too high the training will take forever self.train_loss_MA_alpha = 0.93 # alpha * old + (1-alpha) * new self.train_loss_MA_eps = 5e-4 # new MA must be at least this much better (smaller) self.max_num_epochs = 1000 self.num_batches_per_epoch = 250 self.num_val_batches_per_epoch = 50 self.also_val_in_tr_mode = False self.lr_threshold = 1e-6 # the network will not terminate training if the lr is still above this threshold ################# LEAVE THESE ALONE ################################################ self.val_eval_criterion_MA = None self.train_loss_MA = None self.best_val_eval_criterion_MA = None self.best_MA_tr_loss_for_patience = None self.best_epoch_based_on_MA_tr_loss = None self.all_tr_losses = [] self.all_val_losses = [] self.all_val_losses_tr_mode = [] self.all_val_eval_metrics = [] # does not have to be used self.epoch = 0 self.log_file = None self.deterministic = deterministic self.use_progress_bar = False if 'nnformer_use_progress_bar' in os.environ.keys(): self.use_progress_bar = bool(int(os.environ['nnformer_use_progress_bar'])) ################# Settings for saving checkpoints ################################## self.save_every = 50 self.save_latest_only = False # if false it will not store/overwrite _latest but separate files each # time an intermediate checkpoint is created self.save_intermediate_checkpoints = False # whether or not to save checkpoint_latest self.save_best_checkpoint = True # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA self.save_final_checkpoint = True # whether or not to save the final checkpoint @abstractmethod def initialize(self, training=True): """ create self.output_folder modify self.output_folder if you are doing cross-validation (one folder per fold) set self.tr_gen and self.val_gen call self.initialize_network and self.initialize_optimizer_and_scheduler (important!) 
finally set self.was_initialized to True :param training: :return: """ @abstractmethod def load_dataset(self): pass def do_split(self): """ This is a suggestion for if your dataset is a dictionary (my personal standard) :return: """ splits_file = join(self.dataset_directory, "splits_final.pkl") if not isfile(splits_file): self.print_to_log_file("Creating new split...") splits = [] all_keys_sorted = np.sort(list(self.dataset.keys())) kfold = KFold(n_splits=5, shuffle=True, random_state=12345) for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)): train_keys = np.array(all_keys_sorted)[train_idx] test_keys = np.array(all_keys_sorted)[test_idx] splits.append(OrderedDict()) splits[-1]['train'] = train_keys splits[-1]['val'] = test_keys save_pickle(splits, splits_file) splits = load_pickle(splits_file) if self.fold == "all": tr_keys = val_keys = list(self.dataset.keys()) else: tr_keys = splits[self.fold]['train'] val_keys = splits[self.fold]['val'] tr_keys.sort() val_keys.sort() self.dataset_tr = OrderedDict() for i in tr_keys: self.dataset_tr[i] = self.dataset[i] self.dataset_val = OrderedDict() for i in val_keys: self.dataset_val[i] = self.dataset[i] def plot_progress(self): """ Should probably by improved :return: """ try: font = {'weight': 'normal', 'size': 18} matplotlib.rc('font', **font) fig = plt.figure(figsize=(30, 24)) ax = fig.add_subplot(111) ax2 = ax.twinx() x_values = list(range(self.epoch + 1)) ax.plot(x_values, self.all_tr_losses, color='b', ls='-', label="loss_tr") ax.plot(x_values, self.all_val_losses, color='r', ls='-', label="loss_val, train=False") if len(self.all_val_losses_tr_mode) > 0: ax.plot(x_values, self.all_val_losses_tr_mode, color='g', ls='-', label="loss_val, train=True") if len(self.all_val_eval_metrics) == len(x_values): ax2.plot(x_values, self.all_val_eval_metrics, color='g', ls='--', label="evaluation metric") ax.set_xlabel("epoch") ax.set_ylabel("loss") ax2.set_ylabel("evaluation metric") ax.legend() 
ax2.legend(loc=9) fig.savefig(join(self.output_folder, "progress.png")) plt.close() ''' #below is et ..dice font = {'weight': 'normal', 'size': 18} matplotlib.rc('font', **font) fig = plt.figure(figsize=(30, 24)) ax = fig.add_subplot(111) ax2 = ax.twinx() x_values = list(range(self.epoch + 1)) ax.plot(x_values, self.plot_et, color='b', ls='-', label="dice_et") ax.plot(x_values, self.plot_tc, color='g', ls='-', label="dice_tc") ax.plot(x_values, self.plot_wt, color='r', ls='-', label="dice_wt") ax.set_xlabel("epoch") ax.set_ylabel("dice") ax.legend() fig.savefig(join(self.output_folder, "dice_progress.png")) plt.close() ''' except IOError: self.print_to_log_file("failed to plot: ", sys.exc_info()) def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True): timestamp = time() dt_object = datetime.fromtimestamp(timestamp) if add_timestamp: args = ("%s:" % dt_object, *args) if self.log_file is None: maybe_mkdir_p(self.output_folder) timestamp = datetime.now() self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" % (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute, timestamp.second)) with open(self.log_file, 'w') as f: f.write("Starting... 
\n") successful = False max_attempts = 5 ctr = 0 while not successful and ctr < max_attempts: try: with open(self.log_file, 'a+') as f: for a in args: f.write(str(a)) f.write(" ") f.write("\n") successful = True except IOError: print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info()) sleep(0.5) ctr += 1 if also_print_to_console: print(*args) def save_checkpoint(self, fname, save_optimizer=True): start_time = time() state_dict = self.network.state_dict() for key in state_dict.keys(): state_dict[key] = state_dict[key].cpu() lr_sched_state_dct = None if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'state_dict'): # not isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau): lr_sched_state_dct = self.lr_scheduler.state_dict() # WTF is this!? # for key in lr_sched_state_dct.keys(): # lr_sched_state_dct[key] = lr_sched_state_dct[key] if save_optimizer: optimizer_state_dict = self.optimizer.state_dict() else: optimizer_state_dict = None self.print_to_log_file("saving checkpoint...") save_this = { 'epoch': self.epoch + 1, 'state_dict': state_dict, 'optimizer_state_dict': optimizer_state_dict, 'lr_scheduler_state_dict': lr_sched_state_dct, 'plot_stuff': (self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics), 'best_stuff' : (self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA)} if self.amp_grad_scaler is not None: save_this['amp_grad_scaler'] = self.amp_grad_scaler.state_dict() torch.save(save_this, fname) self.print_to_log_file("done, saving took %.2f seconds" % (time() - start_time)) def load_best_checkpoint(self, train=True): if self.fold is None: raise RuntimeError("Cannot load best checkpoint if self.fold is None") if isfile(join(self.output_folder, "model_best.model")): self.load_checkpoint(join(self.output_folder, "model_best.model"), train=train) else: self.print_to_log_file("WARNING! model_best.model does not exist! 
Cannot load best checkpoint. Falling " "back to load_latest_checkpoint") self.load_latest_checkpoint(train) def load_latest_checkpoint(self, train=True): if isfile(join(self.output_folder, "model_final_checkpoint.model")): return self.load_checkpoint(join(self.output_folder, "model_final_checkpoint.model"), train=train) if isfile(join(self.output_folder, "model_latest.model")): return self.load_checkpoint(join(self.output_folder, "model_latest.model"), train=train) if isfile(join(self.output_folder, "model_best.model")): return self.load_best_checkpoint(train) raise RuntimeError("No checkpoint found") def load_final_checkpoint(self, train=False): filename = join(self.output_folder, "model_final_checkpoint.model") if not isfile(filename): raise RuntimeError("Final checkpoint not found. Expected: %s. Please finish the training first." % filename) return self.load_checkpoint(filename, train=train) def load_checkpoint(self, fname, train=True): self.print_to_log_file("loading checkpoint", fname, "train=", train) if not self.was_initialized: self.initialize(train) # saved_model = torch.load(fname, map_location=torch.device('cuda', torch.cuda.current_device())) saved_model = torch.load(fname, map_location=torch.device('cpu')) self.load_checkpoint_ram(saved_model, train) @abstractmethod def initialize_network(self): """ initialize self.network here :return: """ pass @abstractmethod def initialize_optimizer_and_scheduler(self): """ initialize self.optimizer and self.lr_scheduler (if applicable) here :return: """ pass def load_checkpoint_ram(self, checkpoint, train=True): """ used for if the checkpoint is already in ram :param checkpoint: :param train: :return: """ print('I am here !!!') if not self.was_initialized: self.initialize(train) new_state_dict = OrderedDict() curr_state_dict_keys = list(self.network.state_dict().keys()) # if state dict comes form nn.DataParallel but we use non-parallel model here then the state dict keys do not # match. 
Use heuristic to make it match for k, value in checkpoint['state_dict'].items(): key = k if key not in curr_state_dict_keys and key.startswith('module.'): key = key[7:] new_state_dict[key] = value if self.fp16: self._maybe_init_amp() if 'amp_grad_scaler' in checkpoint.keys(): self.amp_grad_scaler.load_state_dict(checkpoint['amp_grad_scaler']) self.network.load_state_dict(new_state_dict) self.epoch = checkpoint['epoch'] if train: optimizer_state_dict = checkpoint['optimizer_state_dict'] if optimizer_state_dict is not None: self.optimizer.load_state_dict(optimizer_state_dict) if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'load_state_dict') and checkpoint[ 'lr_scheduler_state_dict'] is not None: self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict']) if issubclass(self.lr_scheduler.__class__, _LRScheduler): self.lr_scheduler.step(self.epoch) self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics = checkpoint[ 'plot_stuff'] # load best loss (if present) print('best_stuff' in checkpoint.keys()) if 'best_stuff' in checkpoint.keys(): self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA = checkpoint[ 'best_stuff'] # after the training is done, the epoch is incremented one more time in my old code. This results in # self.epoch = 1001 for old trained models when the epoch is actually 1000. This causes issues because # len(self.all_tr_losses) = 1000 and the plot function will fail. We can easily detect and correct that here if self.epoch != len(self.all_tr_losses): self.print_to_log_file("WARNING in loading checkpoint: self.epoch != len(self.all_tr_losses). This is " "due to an old bug and should only appear when you are loading old models. New " "models should have this fixed! 
self.epoch is now set to len(self.all_tr_losses)") self.epoch = len(self.all_tr_losses) self.all_tr_losses = self.all_tr_losses[:self.epoch] self.all_val_losses = self.all_val_losses[:self.epoch] self.all_val_losses_tr_mode = self.all_val_losses_tr_mode[:self.epoch] self.all_val_eval_metrics = self.all_val_eval_metrics[:self.epoch] self._maybe_init_amp() def _maybe_init_amp(self): if self.fp16 and self.amp_grad_scaler is None: self.amp_grad_scaler = GradScaler() def plot_network_architecture(self): """ can be implemented (see nnFormerTrainer) but does not have to. Not implemented here because it imposes stronger assumptions on the presence of class variables :return: """ pass def run_training(self): if not torch.cuda.is_available(): self.print_to_log_file("WARNING!!! You are attempting to run training on a CPU (torch.cuda.is_available() is False). This can be VERY slow!") _ = self.tr_gen.next() _ = self.val_gen.next() if torch.cuda.is_available(): torch.cuda.empty_cache() self._maybe_init_amp() maybe_mkdir_p(self.output_folder) self.plot_network_architecture() if cudnn.benchmark and cudnn.deterministic: warn("torch.backends.cudnn.deterministic is True indicating a deterministic training is desired. " "But torch.backends.cudnn.benchmark is True as well and this will prevent deterministic training! 
" "If you want deterministic then set benchmark=False") if not self.was_initialized: self.initialize(True) while self.epoch < self.max_num_epochs: self.print_to_log_file("\nepoch: ", self.epoch) epoch_start_time = time() train_losses_epoch = [] # train one epoch self.network.train() if self.use_progress_bar: with trange(self.num_batches_per_epoch) as tbar: for b in tbar: tbar.set_description("Epoch {}/{}".format(self.epoch+1, self.max_num_epochs)) l = self.run_iteration(self.tr_gen, True) tbar.set_postfix(loss=l) train_losses_epoch.append(l) else: for _ in range(self.num_batches_per_epoch): l = self.run_iteration(self.tr_gen, True) train_losses_epoch.append(l) self.all_tr_losses.append(np.mean(train_losses_epoch)) self.print_to_log_file("train loss : %.4f" % self.all_tr_losses[-1]) with torch.no_grad(): # validation with train=False self.network.eval() val_losses = [] for b in range(self.num_val_batches_per_epoch): l = self.run_iteration(self.val_gen, False, True) val_losses.append(l) self.all_val_losses.append(np.mean(val_losses)) self.print_to_log_file("validation loss: %.4f" % self.all_val_losses[-1]) if self.also_val_in_tr_mode: self.network.train() # validation with train=True val_losses = [] for b in range(self.num_val_batches_per_epoch): l = self.run_iteration(self.val_gen, False) val_losses.append(l) self.all_val_losses_tr_mode.append(np.mean(val_losses)) self.print_to_log_file("validation loss (train=True): %.4f" % self.all_val_losses_tr_mode[-1]) self.update_train_loss_MA() # needed for lr scheduler and stopping of training continue_training = self.on_epoch_end() epoch_end_time = time() if not continue_training: # allows for early stopping break self.epoch += 1 self.print_to_log_file("This epoch took %f s\n" % (epoch_end_time - epoch_start_time)) self.epoch -= 1 # if we don't do this we can get a problem with loading model_final_checkpoint. 
if self.save_final_checkpoint: self.save_checkpoint(join(self.output_folder, "model_final_checkpoint.model")) # now we can delete latest as it will be identical with final if isfile(join(self.output_folder, "model_latest.model")): os.remove(join(self.output_folder, "model_latest.model")) if isfile(join(self.output_folder, "model_latest.model.pkl")): os.remove(join(self.output_folder, "model_latest.model.pkl")) def maybe_update_lr(self): # maybe update learning rate if self.lr_scheduler is not None: assert isinstance(self.lr_scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler)) if isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau): # lr scheduler is updated with moving average val loss. should be more robust self.lr_scheduler.step(self.val_eval_criterion_MA) else: self.lr_scheduler.step(self.epoch + 1) self.print_to_log_file("lr is now (scheduler) %s" % str(self.optimizer.param_groups[0]['lr'])) def maybe_save_checkpoint(self): """ Saves a checkpoint every save_ever epochs. 
        :return:
        """
        # fires on epochs save_every-1, 2*save_every-1, ... (i.e. every save_every epochs)
        if self.save_intermediate_checkpoints and (self.epoch % self.save_every == (self.save_every - 1)):
            self.print_to_log_file("saving scheduled checkpoint file...")
            if not self.save_latest_only:
                # keep a separately named snapshot in addition to model_latest
                self.save_checkpoint(join(self.output_folder, "model_ep_%03.0d.model" % (self.epoch + 1)))
            self.save_checkpoint(join(self.output_folder, "model_latest.model"))
            self.print_to_log_file("done")

    def update_eval_criterion_MA(self):
        """
        If self.all_val_eval_metrics is unused (len=0) then we fall back to using -self.all_val_losses for the MA to determine early stopping (not a minimization, but a maximization of a metric and therefore the - in the latter case)
        :return:
        """
        if self.val_eval_criterion_MA is None:
            # first epoch: initialize the moving average directly from the latest value
            if len(self.all_val_eval_metrics) == 0:
                self.val_eval_criterion_MA = - self.all_val_losses[-1]
            else:
                self.val_eval_criterion_MA = self.all_val_eval_metrics[-1]
        else:
            if len(self.all_val_eval_metrics) == 0:
                """
                We here use alpha * old - (1 - alpha) * new because new in this case is the vlaidation loss and lower
                is better, so we need to negate it.
                """
                self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA - (
                        1 - self.val_eval_criterion_alpha) * \
                                             self.all_val_losses[-1]
            else:
                self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA + (
                        1 - self.val_eval_criterion_alpha) * \
                                             self.all_val_eval_metrics[-1]

    def manage_patience(self):
        # update patience
        # Returns False (stop training) once the train-loss MA has not improved for
        # self.patience epochs AND the lr has dropped to/below self.lr_threshold.
        continue_training = True
        if self.patience is not None:
            # if best_MA_tr_loss_for_patience and best_epoch_based_on_MA_tr_loss were not yet initialized,
            # initialize them
            if self.best_MA_tr_loss_for_patience is None:
                self.best_MA_tr_loss_for_patience = self.train_loss_MA

            if self.best_epoch_based_on_MA_tr_loss is None:
                self.best_epoch_based_on_MA_tr_loss = self.epoch

            if self.best_val_eval_criterion_MA is None:
                self.best_val_eval_criterion_MA = self.val_eval_criterion_MA

            # check if the current epoch is the best one according to moving average of validation criterion. If so
            # then save 'best' model
            # Do not use this for validation. This is intended for test set prediction only.
            # NOTE(review): the "%.4f0" format below prints a literal trailing '0' after the
            # number — looks like a typo for "%.4f", but it is a runtime log string so it is
            # left untouched here.
            self.print_to_log_file("current best_val_eval_criterion_MA is %.4f0" % self.best_val_eval_criterion_MA)
            self.print_to_log_file("current val_eval_criterion_MA is %.4f" % self.val_eval_criterion_MA)

            if self.val_eval_criterion_MA > self.best_val_eval_criterion_MA:
                self.best_val_eval_criterion_MA = self.val_eval_criterion_MA
                #self.print_to_log_file("saving best epoch checkpoint...")
                if self.save_best_checkpoint: self.save_checkpoint(join(self.output_folder, "model_best.model"))

            # Now see if the moving average of the train loss has improved. If yes then reset patience, else
            # increase patience
            if self.train_loss_MA + self.train_loss_MA_eps < self.best_MA_tr_loss_for_patience:
                self.best_MA_tr_loss_for_patience = self.train_loss_MA
                self.best_epoch_based_on_MA_tr_loss = self.epoch
                #self.print_to_log_file("New best epoch (train loss MA): %03.4f" % self.best_MA_tr_loss_for_patience)
            else:
                pass
                #self.print_to_log_file("No improvement: current train MA %03.4f, best: %03.4f, eps is %03.4f" %
                #                       (self.train_loss_MA, self.best_MA_tr_loss_for_patience, self.train_loss_MA_eps))

            # if patience has reached its maximum then finish training (provided lr is low enough)
            if self.epoch - self.best_epoch_based_on_MA_tr_loss > self.patience:
                if self.optimizer.param_groups[0]['lr'] > self.lr_threshold:
                    # lr still high: grant half the patience again instead of stopping
                    #self.print_to_log_file("My patience ended, but I believe I need more time (lr > 1e-6)")
                    self.best_epoch_based_on_MA_tr_loss = self.epoch - self.patience // 2
                else:
                    #self.print_to_log_file("My patience ended")
                    continue_training = False
            else:
                pass
                #self.print_to_log_file(
                #    "Patience: %d/%d" % (self.epoch - self.best_epoch_based_on_MA_tr_loss, self.patience))

        return continue_training

    def on_epoch_end(self):
        # End-of-epoch hook: updates plots/checkpoints/MAs and decides whether to continue.
        self.finish_online_evaluation()  # does not have to do anything, but can be used to update self.all_val_eval_
        # metrics

        self.plot_progress()

        self.maybe_save_checkpoint()

        self.update_eval_criterion_MA()

        self.maybe_update_lr()

        continue_training = self.manage_patience()
        return continue_training

    def update_train_loss_MA(self):
        # Exponential moving average of the training loss (alpha = self.train_loss_MA_alpha).
        if self.train_loss_MA is None:
            self.train_loss_MA = self.all_tr_losses[-1]
        else:
            self.train_loss_MA = self.train_loss_MA_alpha * self.train_loss_MA + (1 - self.train_loss_MA_alpha) * \
                                 self.all_tr_losses[-1]

    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        """
        Runs one forward (and optionally backward) pass on a batch drawn from data_generator.
        Returns the loss as a detached numpy scalar.
        :param data_generator: yields dicts with 'data' and 'target'
        :param do_backprop: if True, performs backward + optimizer step
        :param run_online_evaluation: if True, feeds output/target to self.run_online_evaluation
        """
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        if self.fp16:
            # mixed-precision path: forward under autocast, backward through the GradScaler
            with autocast():
                output = self.network(data)
                del data
                l = self.loss(output, target)

            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            output = self.network(data)
            del data
            l = self.loss(output, target)

            if do_backprop:
                l.backward()
                self.optimizer.step()

        if run_online_evaluation:
            self.run_online_evaluation(output, target)

        del target

        return l.detach().cpu().numpy()

    def run_online_evaluation(self, *args, **kwargs):
        """
        Can be implemented, does not have to
        :param output_torch:
        :param target_npy:
        :return:
        """
        pass

    def finish_online_evaluation(self):
        """
        Can be implemented, does not have to
        :return:
        """
        pass

    @abstractmethod
    def validate(self, *args, **kwargs):
        pass

    def find_lr(self, num_iters=1000, init_value=1e-6, final_value=10., beta=0.98):
        """
        stolen and adapted from here: https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html

        Sweeps the lr geometrically from init_value to final_value over num_iters batches and
        records the smoothed loss at each lr; writes lr_finder.png to self.output_folder.
        :param num_iters: number of training iterations for the sweep
        :param init_value: starting learning rate
        :param final_value: final learning rate
        :param beta: smoothing factor for the exponential moving average of the loss
        :return: (log_lrs, losses)
        """
        import math

        self._maybe_init_amp()

        # per-step multiplicative lr increment so that lr reaches final_value after num_iters steps
        mult = (final_value / init_value) ** (1 / num_iters)
        lr = init_value
        self.optimizer.param_groups[0]['lr'] = lr
        avg_loss = 0.
        best_loss = 0.
        losses = []
        log_lrs = []

        for batch_num in range(1, num_iters + 1):
            # +1 because this one here is not designed to have negative loss...
loss = self.run_iteration(self.tr_gen, do_backprop=True, run_online_evaluation=False).data.item() + 1 # Compute the smoothed loss avg_loss = beta * avg_loss + (1 - beta) * loss smoothed_loss = avg_loss / (1 - beta ** batch_num) # Stop if the loss is exploding if batch_num > 1 and smoothed_loss > 4 * best_loss: break # Record the best loss if smoothed_loss < best_loss or batch_num == 1: best_loss = smoothed_loss # Store the values losses.append(smoothed_loss) log_lrs.append(math.log10(lr)) # Update the lr for the next step lr *= mult self.optimizer.param_groups[0]['lr'] = lr import matplotlib.pyplot as plt lrs = [10 ** i for i in log_lrs] fig = plt.figure() plt.xscale('log') plt.plot(lrs[10:-5], losses[10:-5]) plt.savefig(join(self.output_folder, "lr_finder.png")) plt.close() return log_lrs, losses ================================================ FILE: unetr_pp/training/network_training/network_trainer_synapse.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from _warnings import warn
from typing import Tuple

import matplotlib
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from sklearn.model_selection import KFold
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import _LRScheduler
from tqdm import tqdm

# backend must be selected before pyplot is imported (headless training nodes)
matplotlib.use("agg")
from time import time, sleep
import torch
import numpy as np
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
import sys
from collections import OrderedDict
import torch.backends.cudnn as cudnn
from abc import abstractmethod
from datetime import datetime
from tqdm import trange

from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda


class NetworkTrainer_synapse(object):
    def __init__(self, deterministic=True, fp16=False):
        """
        A generic class that can train almost any neural network (RNNs excluded). It provides basic functionality such
        as the training loop, tracking of training and validation losses (and the target metric if you implement it)
        Training can be terminated early if the validation loss (or the target metric if implemented) do not improve
        anymore. This is based on a moving average (MA) of the loss/metric instead of the raw values to get more smooth
        results.
        What you need to override:
        - __init__
        - initialize
        - run_online_evaluation (optional)
        - finish_online_evaluation (optional)
        - validate
        - predict_test_case

        :param deterministic: if True, seeds numpy/torch and disables cudnn.benchmark for reproducibility
        :param fp16: if True, training runs under torch.cuda.amp mixed precision
        """
        self.fp16 = fp16
        self.amp_grad_scaler = None  # created lazily in _maybe_init_amp when fp16 is enabled

        if deterministic:
            np.random.seed(12345)
            torch.manual_seed(12345)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(12345)
            cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True

        ################# SET THESE IN self.initialize() ###################################
        # NOTE(review): the Tuple[...] annotation is misleading — self.network holds ONE
        # network object (a SegmentationNetwork, possibly wrapped in nn.DataParallel),
        # not a tuple. Left unchanged to avoid touching behavior-adjacent typing.
        self.network: Tuple[SegmentationNetwork, nn.DataParallel] = None
        self.optimizer = None
        self.lr_scheduler = None
        self.tr_gen = self.val_gen = None  # training / validation batch generators
        self.was_initialized = False

        ################# SET THESE IN INIT ################################################
        self.output_folder = None
        self.fold = None
        self.loss = None
        self.dataset_directory = None

        ################# SET THESE IN LOAD_DATASET OR DO_SPLIT ############################
        self.dataset = None  # these can be None for inference mode
        self.dataset_tr = self.dataset_val = None  # do not need to be used, they just appear if you are using the suggested load_dataset_and_do_split

        ################# THESE DO NOT NECESSARILY NEED TO BE MODIFIED #####################
        self.patience = 50
        self.val_eval_criterion_alpha = 0.9  # alpha * old + (1-alpha) * new
        # if this is too low then the moving average will be too noisy and the training may terminate early. If it is
        # too high the training will take forever
        self.train_loss_MA_alpha = 0.93  # alpha * old + (1-alpha) * new
        self.train_loss_MA_eps = 5e-4  # new MA must be at least this much better (smaller)
        self.max_num_epochs = 1000
        self.num_batches_per_epoch = 250
        self.num_val_batches_per_epoch = 50
        self.also_val_in_tr_mode = False
        self.lr_threshold = 1e-6  # the network will not terminate training if the lr is still above this threshold

        ################# LEAVE THESE ALONE ################################################
        self.val_eval_criterion_MA = None
        self.train_loss_MA = None
        self.best_val_eval_criterion_MA = None
        self.best_MA_tr_loss_for_patience = None
        self.best_epoch_based_on_MA_tr_loss = None
        self.all_tr_losses = []
        self.all_val_losses = []
        self.all_val_losses_tr_mode = []
        self.all_val_eval_metrics = []  # does not have to be used
        self.epoch = 0
        self.log_file = None
        self.deterministic = deterministic

        # env-var toggle inherited from the nnFormer codebase this trainer derives from
        self.use_progress_bar = False
        if 'nnformer_use_progress_bar' in os.environ.keys():
            self.use_progress_bar = bool(int(os.environ['nnformer_use_progress_bar']))

        ################# Settings for saving checkpoints ##################################
        self.save_every = 50
        self.save_latest_only = False  # if false it will not store/overwrite _latest but separate files each
        # time an intermediate checkpoint is created
        self.save_intermediate_checkpoints = False  # whether or not to save checkpoint_latest
        self.save_best_checkpoint = True  # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
        self.save_final_checkpoint = True  # whether or not to save the final checkpoint

    @abstractmethod
    def initialize(self, training=True):
        """
        create self.output_folder
        modify self.output_folder if you are doing cross-validation (one folder per fold)
        set self.tr_gen and self.val_gen
        call self.initialize_network and self.initialize_optimizer_and_scheduler (important!)
        finally set self.was_initialized to True
        :param training:
        :return:
        """

    @abstractmethod
    def load_dataset(self):
        # Must populate self.dataset (case identifier -> case info dict).
        pass

    def do_split(self):
        """
        This is a suggestion for if your dataset is a dictionary (my personal standard)

        Creates (or loads) a 5-fold split of self.dataset and fills self.dataset_tr /
        self.dataset_val according to self.fold. fold == "all" uses every case for both
        training and validation.
        :return:
        """
        splits_file = join(self.dataset_directory, "splits_final.pkl")
        if not isfile(splits_file):
            # First run for this dataset: create a deterministic 5-fold split and cache it
            # so all trainers/folds agree on the same partitioning.
            self.print_to_log_file("Creating new split...")
            splits = []
            all_keys_sorted = np.sort(list(self.dataset.keys()))
            kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
            for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
                train_keys = np.array(all_keys_sorted)[train_idx]
                test_keys = np.array(all_keys_sorted)[test_idx]
                splits.append(OrderedDict())
                splits[-1]['train'] = train_keys
                splits[-1]['val'] = test_keys
            save_pickle(splits, splits_file)

        splits = load_pickle(splits_file)

        if self.fold == "all":
            # train on everything; validation set == training set
            tr_keys = val_keys = list(self.dataset.keys())
        else:
            # self.fold indexes into the cached 5-fold split
            tr_keys = splits[self.fold]['train']
            val_keys = splits[self.fold]['val']

        tr_keys.sort()
        val_keys.sort()

        self.dataset_tr = OrderedDict()
        for i in tr_keys:
            self.dataset_tr[i] = self.dataset[i]

        self.dataset_val = OrderedDict()
        for i in val_keys:
            self.dataset_val[i] = self.dataset[i]

    def plot_progress(self):
        """
        Should probably be improved

        Writes progress.png (train/val losses and, if available, the evaluation metric)
        into self.output_folder.
        :return:
        """
        try:
            font = {'weight': 'normal',
                    'size': 18}

            matplotlib.rc('font', **font)

            fig = plt.figure(figsize=(30, 24))
            ax = fig.add_subplot(111)
            # second y-axis for the evaluation metric, which lives on a different scale
            ax2 = ax.twinx()

            x_values = list(range(self.epoch + 1))

            ax.plot(x_values, self.all_tr_losses, color='b', ls='-', label="loss_tr")

            ax.plot(x_values, self.all_val_losses, color='r', ls='-', label="loss_val, train=False")

            if len(self.all_val_losses_tr_mode) > 0:
                ax.plot(x_values, self.all_val_losses_tr_mode, color='g', ls='-', label="loss_val, train=True")
            # only plot the metric if it was recorded for every epoch (same length as x)
            if len(self.all_val_eval_metrics) == len(x_values):
                ax2.plot(x_values, self.all_val_eval_metrics, color='g', ls='--', label="evaluation metric")

            ax.set_xlabel("epoch")
            ax.set_ylabel("loss")
            ax2.set_ylabel("evaluation metric")
            ax.legend()
ax2.legend(loc=9) fig.savefig(join(self.output_folder, "progress.png")) plt.close() ''' #below is et ..dice font = {'weight': 'normal', 'size': 18} matplotlib.rc('font', **font) fig = plt.figure(figsize=(30, 24)) ax = fig.add_subplot(111) ax2 = ax.twinx() x_values = list(range(self.epoch + 1)) ax.plot(x_values, self.plot_et, color='b', ls='-', label="dice_et") ax.plot(x_values, self.plot_tc, color='g', ls='-', label="dice_tc") ax.plot(x_values, self.plot_wt, color='r', ls='-', label="dice_wt") ax.set_xlabel("epoch") ax.set_ylabel("dice") ax.legend() fig.savefig(join(self.output_folder, "dice_progress.png")) plt.close() ''' except IOError: self.print_to_log_file("failed to plot: ", sys.exc_info()) def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True): timestamp = time() dt_object = datetime.fromtimestamp(timestamp) if add_timestamp: args = ("%s:" % dt_object, *args) if self.log_file is None: maybe_mkdir_p(self.output_folder) timestamp = datetime.now() self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" % (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute, timestamp.second)) with open(self.log_file, 'w') as f: f.write("Starting... 
\n") successful = False max_attempts = 5 ctr = 0 while not successful and ctr < max_attempts: try: with open(self.log_file, 'a+') as f: for a in args: f.write(str(a)) f.write(" ") f.write("\n") successful = True except IOError: print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info()) sleep(0.5) ctr += 1 if also_print_to_console: print(*args) def save_checkpoint(self, fname, save_optimizer=True): start_time = time() state_dict = self.network.state_dict() for key in state_dict.keys(): state_dict[key] = state_dict[key].cpu() lr_sched_state_dct = None if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'state_dict'): # not isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau): lr_sched_state_dct = self.lr_scheduler.state_dict() # WTF is this!? # for key in lr_sched_state_dct.keys(): # lr_sched_state_dct[key] = lr_sched_state_dct[key] if save_optimizer: optimizer_state_dict = self.optimizer.state_dict() else: optimizer_state_dict = None self.print_to_log_file("saving checkpoint...") save_this = { 'epoch': self.epoch + 1, 'state_dict': state_dict, 'optimizer_state_dict': optimizer_state_dict, 'lr_scheduler_state_dict': lr_sched_state_dct, 'plot_stuff': (self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics), 'best_stuff' : (self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA)} if self.amp_grad_scaler is not None: save_this['amp_grad_scaler'] = self.amp_grad_scaler.state_dict() torch.save(save_this, fname) self.print_to_log_file("done, saving took %.2f seconds" % (time() - start_time)) def load_best_checkpoint(self, train=True): if self.fold is None: raise RuntimeError("Cannot load best checkpoint if self.fold is None") if isfile(join(self.output_folder, "model_best.model")): self.load_checkpoint(join(self.output_folder, "model_best.model"), train=train) else: self.print_to_log_file("WARNING! model_best.model does not exist! 
Cannot load best checkpoint. Falling "
                                   "back to load_latest_checkpoint")
            self.load_latest_checkpoint(train)

    def load_latest_checkpoint(self, train=True):
        """Load the most authoritative checkpoint available.

        Preference order is final > latest > best. Preferring *final* over
        *latest* is intentional: after a finished run the two are identical and
        model_latest.model is deleted (see end of run_training).

        :param train: if True, also restore optimizer/scheduler state
        :raises RuntimeError: if no checkpoint file exists in self.output_folder
        """
        if isfile(join(self.output_folder, "model_final_checkpoint.model")):
            return self.load_checkpoint(join(self.output_folder, "model_final_checkpoint.model"), train=train)
        if isfile(join(self.output_folder, "model_latest.model")):
            return self.load_checkpoint(join(self.output_folder, "model_latest.model"), train=train)
        if isfile(join(self.output_folder, "model_best.model")):
            return self.load_best_checkpoint(train)
        raise RuntimeError("No checkpoint found")

    def load_final_checkpoint(self, train=False):
        """Load model_final_checkpoint.model (written only when training completed).

        :param train: if True, also restore optimizer/scheduler state
        :raises RuntimeError: if the final checkpoint does not exist yet
        """
        filename = join(self.output_folder, "model_final_checkpoint.model")
        if not isfile(filename):
            raise RuntimeError("Final checkpoint not found. Expected: %s. Please finish the training first." % filename)
        return self.load_checkpoint(filename, train=train)

    def load_checkpoint(self, fname, train=True):
        """Load a checkpoint file from disk and restore it via load_checkpoint_ram.

        :param fname: path to a .model file written by save_checkpoint
        :param train: if True, also restore optimizer/scheduler state
        """
        self.print_to_log_file("loading checkpoint", fname, "train=", train)
        if not self.was_initialized:
            self.initialize(train)
        # Load onto CPU first; parameters are moved to the right device when the
        # state dict is applied to the (already placed) network.
        # saved_model = torch.load(fname, map_location=torch.device('cuda', torch.cuda.current_device()))
        saved_model = torch.load(fname, map_location=torch.device('cpu'))
        self.load_checkpoint_ram(saved_model, train)

    @abstractmethod
    def initialize_network(self):
        """
        initialize self.network here
        :return:
        """
        pass

    @abstractmethod
    def initialize_optimizer_and_scheduler(self):
        """
        initialize self.optimizer and self.lr_scheduler (if applicable) here
        :return:
        """
        pass

    def load_checkpoint_ram(self, checkpoint, train=True):
        """
        Restore trainer state from a checkpoint dict that is already in RAM.

        Restores network weights (handling nn.DataParallel 'module.' prefixes),
        AMP grad scaler state, epoch counter, optimizer/scheduler state (if
        train=True) and the loss/metric histories used for plotting.

        :param checkpoint: dict as produced by save_checkpoint / torch.load
        :param train: if True, also restore optimizer and lr scheduler state
        :return:
        """
        # NOTE(review): leftover debug print; writes to stdout on every load.
        print('I am here !!!')
        if not self.was_initialized:
            self.initialize(train)

        new_state_dict = OrderedDict()
        curr_state_dict_keys = list(self.network.state_dict().keys())
        # if state dict comes form nn.DataParallel but we use non-parallel model here then the state dict keys do not
        # match. Use heuristic to make it match
        for k, value in checkpoint['state_dict'].items():
            key = k
            if key not in curr_state_dict_keys and key.startswith('module.'):
                key = key[7:]  # strip the 'module.' prefix added by DataParallel
            new_state_dict[key] = value

        if self.fp16:
            self._maybe_init_amp()
            if 'amp_grad_scaler' in checkpoint.keys():
                self.amp_grad_scaler.load_state_dict(checkpoint['amp_grad_scaler'])

        self.network.load_state_dict(new_state_dict)
        self.epoch = checkpoint['epoch']
        if train:
            optimizer_state_dict = checkpoint['optimizer_state_dict']
            if optimizer_state_dict is not None:
                self.optimizer.load_state_dict(optimizer_state_dict)

            if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'load_state_dict') and checkpoint[
                'lr_scheduler_state_dict'] is not None:
                self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])

            # fast-forward epoch-based schedulers to the restored epoch
            if issubclass(self.lr_scheduler.__class__, _LRScheduler):
                self.lr_scheduler.step(self.epoch)

        self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics = checkpoint[
            'plot_stuff']

        # load best loss (if present)
        # NOTE(review): leftover debug print.
        print('best_stuff' in checkpoint.keys())
        if 'best_stuff' in checkpoint.keys():
            self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA = checkpoint[
                'best_stuff']

        # after the training is done, the epoch is incremented one more time in my old code. This results in
        # self.epoch = 1001 for old trained models when the epoch is actually 1000. This causes issues because
        # len(self.all_tr_losses) = 1000 and the plot function will fail. We can easily detect and correct that here
        if self.epoch != len(self.all_tr_losses):
            self.print_to_log_file("WARNING in loading checkpoint: self.epoch != len(self.all_tr_losses). This is "
                                   "due to an old bug and should only appear when you are loading old models. New "
                                   "models should have this fixed! self.epoch is now set to len(self.all_tr_losses)")
            self.epoch = len(self.all_tr_losses)
            self.all_tr_losses = self.all_tr_losses[:self.epoch]
            self.all_val_losses = self.all_val_losses[:self.epoch]
            self.all_val_losses_tr_mode = self.all_val_losses_tr_mode[:self.epoch]
            self.all_val_eval_metrics = self.all_val_eval_metrics[:self.epoch]

        self._maybe_init_amp()

    def _maybe_init_amp(self):
        # Lazily create the AMP GradScaler; only needed for fp16 training.
        if self.fp16 and self.amp_grad_scaler is None:
            self.amp_grad_scaler = GradScaler()

    def plot_network_architecture(self):
        """
        can be implemented (see nnFormerTrainer) but does not have to. Not implemented here because it imposes stronger
        assumptions on the presence of class variables
        :return:
        """
        pass

    def run_training(self):
        """Main training loop: iterate epochs until max_num_epochs or early stopping.

        Each epoch runs num_batches_per_epoch training iterations, then
        num_val_batches_per_epoch validation iterations (eval mode, no grad),
        updates the moving averages, and delegates checkpointing / lr updates /
        patience handling to on_epoch_end(). Writes model_final_checkpoint.model
        at the end and removes the now-redundant model_latest.model.
        """
        if not torch.cuda.is_available():
            self.print_to_log_file("WARNING!!! You are attempting to run training on a CPU (torch.cuda.is_available() is False). This can be VERY slow!")

        # warm up the (multithreaded) generators so the first epoch does not stall
        # NOTE(review): .next() is the old batchgenerators API, not Python's next() — confirm generator type.
        _ = self.tr_gen.next()
        _ = self.val_gen.next()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self._maybe_init_amp()

        maybe_mkdir_p(self.output_folder)
        self.plot_network_architecture()

        if cudnn.benchmark and cudnn.deterministic:
            warn("torch.backends.cudnn.deterministic is True indicating a deterministic training is desired. "
                 "But torch.backends.cudnn.benchmark is True as well and this will prevent deterministic training! "
                 "If you want deterministic then set benchmark=False")

        if not self.was_initialized:
            self.initialize(True)

        while self.epoch < self.max_num_epochs:
            self.print_to_log_file("\nepoch: ", self.epoch)
            epoch_start_time = time()
            train_losses_epoch = []

            # train one epoch
            self.network.train()

            if self.use_progress_bar:
                with trange(self.num_batches_per_epoch) as tbar:
                    for b in tbar:
                        tbar.set_description("Epoch {}/{}".format(self.epoch+1, self.max_num_epochs))

                        l = self.run_iteration(self.tr_gen, True)

                        tbar.set_postfix(loss=l)
                        train_losses_epoch.append(l)
            else:
                for _ in range(self.num_batches_per_epoch):
                    l = self.run_iteration(self.tr_gen, True)
                    train_losses_epoch.append(l)

            self.all_tr_losses.append(np.mean(train_losses_epoch))
            self.print_to_log_file("train loss : %.4f" % self.all_tr_losses[-1])

            with torch.no_grad():
                # validation with train=False
                self.network.eval()
                val_losses = []
                for b in range(self.num_val_batches_per_epoch):
                    l = self.run_iteration(self.val_gen, False, True)
                    val_losses.append(l)
                self.all_val_losses.append(np.mean(val_losses))
                self.print_to_log_file("validation loss: %.4f" % self.all_val_losses[-1])

                if self.also_val_in_tr_mode:
                    self.network.train()
                    # validation with train=True
                    val_losses = []
                    for b in range(self.num_val_batches_per_epoch):
                        l = self.run_iteration(self.val_gen, False)
                        val_losses.append(l)
                    self.all_val_losses_tr_mode.append(np.mean(val_losses))
                    self.print_to_log_file("validation loss (train=True): %.4f" % self.all_val_losses_tr_mode[-1])

            self.update_train_loss_MA()  # needed for lr scheduler and stopping of training

            continue_training = self.on_epoch_end()

            epoch_end_time = time()

            if not continue_training:
                # allows for early stopping
                break

            self.epoch += 1
            self.print_to_log_file("This epoch took %f s\n" % (epoch_end_time - epoch_start_time))

        self.epoch -= 1  # if we don't do this we can get a problem with loading model_final_checkpoint.

        if self.save_final_checkpoint: self.save_checkpoint(join(self.output_folder, "model_final_checkpoint.model"))
        # now we can delete latest as it will be identical with final
        if isfile(join(self.output_folder, "model_latest.model")):
            os.remove(join(self.output_folder, "model_latest.model"))
        if isfile(join(self.output_folder, "model_latest.model.pkl")):
            os.remove(join(self.output_folder, "model_latest.model.pkl"))

    def maybe_update_lr(self):
        """Step the lr scheduler, if one is configured.

        ReduceLROnPlateau is driven by the (smoother) validation criterion MA;
        epoch-based schedulers are stepped with epoch+1 because this runs before
        self.epoch is incremented at the end of the epoch.
        """
        # maybe update learning rate
        if self.lr_scheduler is not None:
            assert isinstance(self.lr_scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))

            if isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
                # lr scheduler is updated with moving average val loss. should be more robust
                self.lr_scheduler.step(self.val_eval_criterion_MA)
            else:
                self.lr_scheduler.step(self.epoch + 1)
        self.print_to_log_file("lr is now (scheduler) %s" % str(self.optimizer.param_groups[0]['lr']))

    def maybe_save_checkpoint(self):
        """
        Saves a checkpoint every save_every epochs (if intermediate checkpointing
        is enabled). Always refreshes model_latest.model; additionally writes a
        per-epoch model_ep_XXX.model unless save_latest_only is set.
        :return:
        """
        if self.save_intermediate_checkpoints and (self.epoch % self.save_every == (self.save_every - 1)):
            self.print_to_log_file("saving scheduled checkpoint file...")
            if not self.save_latest_only:
                self.save_checkpoint(join(self.output_folder, "model_ep_%03.0d.model" % (self.epoch + 1)))
            self.save_checkpoint(join(self.output_folder, "model_latest.model"))
            self.print_to_log_file("done")

    def update_eval_criterion_MA(self):
        """
        If self.all_val_eval_metrics is unused (len=0) then we fall back to using -self.all_val_losses for the MA to determine early stopping
        (not a minimization, but a maximization of a metric and therefore the - in the latter case)
        :return:
        """
        if self.val_eval_criterion_MA is None:
            if len(self.all_val_eval_metrics) == 0:
                self.val_eval_criterion_MA = - self.all_val_losses[-1]
            else:
                self.val_eval_criterion_MA = self.all_val_eval_metrics[-1]
        else:
            if len(self.all_val_eval_metrics) == 0:
                """
                We use alpha * old - (1 - alpha) * new because new in this case is the validation loss and lower
                is better, so we need to negate it.
                """
                self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA - (
                        1 - self.val_eval_criterion_alpha) * \
                                             self.all_val_losses[-1]
            else:
                self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA + (
                        1 - self.val_eval_criterion_alpha) * \
                                             self.all_val_eval_metrics[-1]

    def manage_patience(self):
        """Early-stopping bookkeeping.

        Saves model_best.model when the validation criterion MA improves.
        Training is stopped only when the train-loss MA has not improved for
        more than self.patience epochs AND the lr is already at/below
        self.lr_threshold; otherwise patience is partially reset so lr decay
        gets a chance to help.

        :return: True to continue training, False to stop
        """
        # update patience
        continue_training = True
        if self.patience is not None:
            # if best_MA_tr_loss_for_patience and best_epoch_based_on_MA_tr_loss were not yet initialized,
            # initialize them
            if self.best_MA_tr_loss_for_patience is None:
                self.best_MA_tr_loss_for_patience = self.train_loss_MA

            if self.best_epoch_based_on_MA_tr_loss is None:
                self.best_epoch_based_on_MA_tr_loss = self.epoch

            if self.best_val_eval_criterion_MA is None:
                self.best_val_eval_criterion_MA = self.val_eval_criterion_MA

            # check if the current epoch is the best one according to moving average of validation criterion. If so
            # then save 'best' model
            # Do not use this for validation. This is intended for test set prediction only.
            self.print_to_log_file("current best_val_eval_criterion_MA is %.4f0" % self.best_val_eval_criterion_MA)
            self.print_to_log_file("current val_eval_criterion_MA is %.4f" % self.val_eval_criterion_MA)

            if self.val_eval_criterion_MA > self.best_val_eval_criterion_MA:
                self.best_val_eval_criterion_MA = self.val_eval_criterion_MA
                #self.print_to_log_file("saving best epoch checkpoint...")
                if self.save_best_checkpoint: self.save_checkpoint(join(self.output_folder, "model_best.model"))

            # Now see if the moving average of the train loss has improved. If yes then reset patience, else
            # increase patience
            if self.train_loss_MA + self.train_loss_MA_eps < self.best_MA_tr_loss_for_patience:
                self.best_MA_tr_loss_for_patience = self.train_loss_MA
                self.best_epoch_based_on_MA_tr_loss = self.epoch
                #self.print_to_log_file("New best epoch (train loss MA): %03.4f" % self.best_MA_tr_loss_for_patience)
            else:
                pass
                #self.print_to_log_file("No improvement: current train MA %03.4f, best: %03.4f, eps is %03.4f" %
                #                       (self.train_loss_MA, self.best_MA_tr_loss_for_patience, self.train_loss_MA_eps))

            # if patience has reached its maximum then finish training (provided lr is low enough)
            if self.epoch - self.best_epoch_based_on_MA_tr_loss > self.patience:
                if self.optimizer.param_groups[0]['lr'] > self.lr_threshold:
                    #self.print_to_log_file("My patience ended, but I believe I need more time (lr > 1e-6)")
                    self.best_epoch_based_on_MA_tr_loss = self.epoch - self.patience // 2
                else:
                    #self.print_to_log_file("My patience ended")
                    continue_training = False
            else:
                pass
                #self.print_to_log_file(
                #    "Patience: %d/%d" % (self.epoch - self.best_epoch_based_on_MA_tr_loss, self.patience))

        return continue_training

    def on_epoch_end(self):
        """Per-epoch housekeeping; returns False if training should stop early."""
        self.finish_online_evaluation()  # does not have to do anything, but can be used to update self.all_val_eval_
        # metrics

        self.plot_progress()

        self.maybe_save_checkpoint()

        self.update_eval_criterion_MA()

        self.maybe_update_lr()

        continue_training = self.manage_patience()
        return continue_training

    def update_train_loss_MA(self):
        # exponential moving average of the train loss (used for patience/early stopping)
        if self.train_loss_MA is None:
            self.train_loss_MA = self.all_tr_losses[-1]
        else:
            self.train_loss_MA = self.train_loss_MA_alpha * self.train_loss_MA + (1 - self.train_loss_MA_alpha) * \
                                 self.all_tr_losses[-1]

    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        """Run a single forward (and optionally backward) pass on one batch.

        :param data_generator: yields dicts with 'data' and 'target' entries
        :param do_backprop: if True, backprop and take an optimizer step
        :param run_online_evaluation: if True, feed output/target to run_online_evaluation
        :return: detached scalar loss as a numpy value
        """
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        if self.fp16:
            with autocast():
                output = self.network(data)
                del data  # free the input as soon as possible to reduce peak memory
                l = self.loss(output, target)

            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            output = self.network(data)
            del data
            l = self.loss(output, target)

            if do_backprop:
                l.backward()
                self.optimizer.step()

        if run_online_evaluation:
            self.run_online_evaluation(output, target)

        del target

        return l.detach().cpu().numpy()

    def run_online_evaluation(self, *args, **kwargs):
        """
        Can be implemented, does not have to
        :param output_torch:
        :param target_npy:
        :return:
        """
        pass

    def finish_online_evaluation(self):
        """
        Can be implemented, does not have to
        :return:
        """
        pass

    @abstractmethod
    def validate(self, *args, **kwargs):
        pass

    def find_lr(self, num_iters=1000, init_value=1e-6, final_value=10., beta=0.98):
        """
        LR range test: exponentially sweep the lr from init_value to final_value,
        track the smoothed (+1 shifted) loss, stop when it explodes, and save a
        loss-vs-lr plot to lr_finder.png.
        stolen and adapted from here: https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
        :param num_iters: number of training iterations in the sweep
        :param init_value: starting learning rate
        :param final_value: final learning rate
        :param beta: smoothing factor for the loss moving average
        :return: (log10 learning rates, smoothed losses)
        """
        import math
        self._maybe_init_amp()
        mult = (final_value / init_value) ** (1 / num_iters)
        lr = init_value
        self.optimizer.param_groups[0]['lr'] = lr
        avg_loss = 0.
        best_loss = 0.
        losses = []
        log_lrs = []

        for batch_num in tqdm(range(1, num_iters + 1)):
            # +1 because this one here is not designed to have negative loss...
            loss = self.run_iteration(self.tr_gen, do_backprop=True, run_online_evaluation=False) + 1

            # Compute the smoothed loss
            avg_loss = beta * avg_loss + (1 - beta) * loss
            smoothed_loss = avg_loss / (1 - beta ** batch_num)

            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > 4 * best_loss:
                break

            # Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss

            # Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))

            # Update the lr for the next step
            lr *= mult
            self.optimizer.param_groups[0]['lr'] = lr

        import matplotlib.pyplot as plt
        lrs = [10 ** i for i in log_lrs]
        fig = plt.figure()
        plt.xscale('log')
        # drop the first 10 and last 5 points: they are dominated by warmup noise / divergence
        plt.plot(lrs[10:-5], losses[10:-5])
        plt.savefig(join(self.output_folder, "lr_finder.png"))
        plt.close()
        import numpy as np
        index = np.argmin(losses)
        print(f"The best LR is {lrs[index]}")
        return log_lrs, losses


================================================
FILE: unetr_pp/training/network_training/network_trainer_tumor.py
================================================
#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
from _warnings import warn
from typing import Tuple

import matplotlib
from batchgenerators.utilities.file_and_folder_operations import *
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from sklearn.model_selection import KFold
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import _LRScheduler

matplotlib.use("agg")
from time import time, sleep
import torch
import numpy as np
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
import sys
from collections import OrderedDict
import torch.backends.cudnn as cudnn
from abc import abstractmethod
from datetime import datetime
from tqdm import trange

from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda


class NetworkTrainer_tumor(object):
    def __init__(self, deterministic=True, fp16=False):
        """
        A generic class that can train almost any neural network (RNNs excluded). It provides basic functionality such
        as the training loop, tracking of training and validation losses (and the target metric if you implement it)
        Training can be terminated early if the validation loss (or the target metric if implemented) do not improve
        anymore. This is based on a moving average (MA) of the loss/metric instead of the raw values to get more smooth
        results.

        What you need to override:
        - __init__
        - initialize
        - run_online_evaluation (optional)
        - finish_online_evaluation (optional)
        - validate
        - predict_test_case

        :param deterministic: seed numpy/torch and force deterministic cudnn (slower but reproducible)
        :param fp16: train with mixed precision (torch.cuda.amp)
        """
        self.fp16 = fp16
        self.amp_grad_scaler = None  # created lazily by _maybe_init_amp when fp16 is on

        if deterministic:
            np.random.seed(12345)
            torch.manual_seed(12345)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(12345)
            cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True

        ################# SET THESE IN self.initialize() ###################################
        self.network: Tuple[SegmentationNetwork, nn.DataParallel] = None
        self.optimizer = None
        self.lr_scheduler = None
        self.tr_gen = self.val_gen = None
        self.was_initialized = False

        ################# SET THESE IN INIT ################################################
        self.output_folder = None
        self.fold = None
        self.loss = None
        self.dataset_directory = None

        ################# SET THESE IN LOAD_DATASET OR DO_SPLIT ############################
        self.dataset = None  # these can be None for inference mode
        self.dataset_tr = self.dataset_val = None  # do not need to be used, they just appear if you are using the suggested load_dataset_and_do_split

        ################# THESE DO NOT NECESSARILY NEED TO BE MODIFIED #####################
        self.patience = 50
        self.val_eval_criterion_alpha = 0.9  # alpha * old + (1-alpha) * new
        # if this is too low then the moving average will be too noisy and the training may terminate early. If it is
        # too high the training will take forever
        self.train_loss_MA_alpha = 0.93  # alpha * old + (1-alpha) * new
        self.train_loss_MA_eps = 5e-4  # new MA must be at least this much better (smaller)
        self.max_num_epochs = 1000
        self.num_batches_per_epoch = 250
        self.num_val_batches_per_epoch = 50
        self.also_val_in_tr_mode = False
        self.lr_threshold = 1e-6  # the network will not terminate training if the lr is still above this threshold

        ################# LEAVE THESE ALONE ################################################
        self.val_eval_criterion_MA = None
        self.train_loss_MA = None
        self.best_val_eval_criterion_MA = None
        self.best_MA_tr_loss_for_patience = None
        self.best_epoch_based_on_MA_tr_loss = None
        self.all_tr_losses = []
        self.all_val_losses = []
        self.all_val_losses_tr_mode = []
        self.all_val_eval_metrics = []  # does not have to be used
        self.epoch = 0
        self.log_file = None
        self.deterministic = deterministic

        self.use_progress_bar = False
        if 'nnformer_use_progress_bar' in os.environ.keys():
            self.use_progress_bar = bool(int(os.environ['nnformer_use_progress_bar']))

        ################# Settings for saving checkpoints ##################################
        self.save_every = 50
        self.save_latest_only = False  # if false it will not store/overwrite _latest but separate files each
        # time an intermediate checkpoint is created
        self.save_intermediate_checkpoints = False  # whether or not to save checkpoint_latest
        self.save_best_checkpoint = True  # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
        self.save_final_checkpoint = True  # whether or not to save the final checkpoint

    @abstractmethod
    def initialize(self, training=True):
        """
        create self.output_folder

        modify self.output_folder if you are doing cross-validation (one folder per fold)

        set self.tr_gen and self.val_gen

        call self.initialize_network and self.initialize_optimizer_and_scheduler (important!)

        finally set self.was_initialized to True
        :param training:
        :return:
        """

    @abstractmethod
    def load_dataset(self):
        pass

    def do_split(self):
        """
        This is a suggestion for if your dataset is a dictionary (my personal standard)

        Creates (or reuses) a 5-fold split stored as splits_final.pkl in
        self.dataset_directory and fills self.dataset_tr / self.dataset_val for
        self.fold. fold == "all" uses every case for both train and val.
        :return:
        """
        splits_file = join(self.dataset_directory, "splits_final.pkl")
        if not isfile(splits_file):
            self.print_to_log_file("Creating new split...")
            splits = []
            all_keys_sorted = np.sort(list(self.dataset.keys()))
            # fixed random_state so all trainers of one task share the same split
            kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
            for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
                train_keys = np.array(all_keys_sorted)[train_idx]
                test_keys = np.array(all_keys_sorted)[test_idx]
                splits.append(OrderedDict())
                splits[-1]['train'] = train_keys
                splits[-1]['val'] = test_keys
            save_pickle(splits, splits_file)

        splits = load_pickle(splits_file)

        if self.fold == "all":
            tr_keys = val_keys = list(self.dataset.keys())
        else:
            tr_keys = splits[self.fold]['train']
            val_keys = splits[self.fold]['val']

        tr_keys.sort()
        val_keys.sort()

        self.dataset_tr = OrderedDict()
        for i in tr_keys:
            self.dataset_tr[i] = self.dataset[i]

        self.dataset_val = OrderedDict()
        for i in val_keys:
            self.dataset_val[i] = self.dataset[i]

    def plot_progress(self):
        """
        Plot train/val losses (and the evaluation metric, if tracked) over epochs
        and save the figure to progress.png in self.output_folder.
        Should probably be improved.
        :return:
        """
        try:
            font = {'weight': 'normal',
                    'size': 18}

            matplotlib.rc('font', **font)

            fig = plt.figure(figsize=(30, 24))
            ax = fig.add_subplot(111)
            ax2 = ax.twinx()  # second y axis for the evaluation metric

            x_values = list(range(self.epoch + 1))

            ax.plot(x_values, self.all_tr_losses, color='b', ls='-', label="loss_tr")

            ax.plot(x_values, self.all_val_losses, color='r', ls='-', label="loss_val, train=False")

            if len(self.all_val_losses_tr_mode) > 0:
                ax.plot(x_values, self.all_val_losses_tr_mode, color='g', ls='-', label="loss_val, train=True")
            if len(self.all_val_eval_metrics) == len(x_values):
                ax2.plot(x_values, self.all_val_eval_metrics, color='g', ls='--', label="evaluation metric")

            ax.set_xlabel("epoch")
            ax.set_ylabel("loss")
            ax2.set_ylabel("evaluation metric")
            ax.legend()
            ax2.legend(loc=9)

            fig.savefig(join(self.output_folder, "progress.png"))
            plt.close()
            '''
            #below is et ..dice
            font = {'weight': 'normal',
            'size': 18}
            matplotlib.rc('font', **font)
            fig = plt.figure(figsize=(30, 24))
            ax = fig.add_subplot(111)
            ax2 = ax.twinx()
            x_values = list(range(self.epoch + 1))
            ax.plot(x_values, self.plot_et, color='b', ls='-', label="dice_et")
            ax.plot(x_values, self.plot_tc, color='g', ls='-', label="dice_tc")
            ax.plot(x_values, self.plot_wt, color='r', ls='-', label="dice_wt")
            ax.set_xlabel("epoch")
            ax.set_ylabel("dice")
            ax.legend()
            fig.savefig(join(self.output_folder, "dice_progress.png"))
            plt.close()
            '''
        except IOError:
            self.print_to_log_file("failed to plot: ", sys.exc_info())

    def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True):
        """Append *args to the training log file (created on first use), with retries.

        :param also_print_to_console: echo the message to stdout as well
        :param add_timestamp: prepend the current datetime to the message
        """
        timestamp = time()
        dt_object = datetime.fromtimestamp(timestamp)

        if add_timestamp:
            args = ("%s:" % dt_object, *args)

        if self.log_file is None:
            maybe_mkdir_p(self.output_folder)
            timestamp = datetime.now()
            self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" %
                                 (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute,
                                  timestamp.second))
            with open(self.log_file, 'w') as f:
                f.write("Starting... \n")
        successful = False
        max_attempts = 5
        ctr = 0
        # retry a few times: the log may live on a network filesystem that
        # transiently refuses the write
        while not successful and ctr < max_attempts:
            try:
                with open(self.log_file, 'a+') as f:
                    for a in args:
                        f.write(str(a))
                        f.write(" ")
                    f.write("\n")
                successful = True
            except IOError:
                print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info())
                sleep(0.5)
                ctr += 1
        if also_print_to_console:
            print(*args)

    def save_checkpoint(self, fname, save_optimizer=True):
        """Serialize trainer state (weights on CPU, optimizer/scheduler/AMP state,
        epoch counter, loss/metric histories and best-so-far values) to fname.

        :param fname: destination file path
        :param save_optimizer: if False, 'optimizer_state_dict' is stored as None
        """
        start_time = time()
        state_dict = self.network.state_dict()
        # move weights to CPU so the checkpoint can be loaded on any device
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].cpu()
        lr_sched_state_dct = None
        if self.lr_scheduler is not None and hasattr(self.lr_scheduler,
                                                     'state_dict'):  # not isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
            lr_sched_state_dct = self.lr_scheduler.state_dict()
            # WTF is this!?
            # for key in lr_sched_state_dct.keys():
            #    lr_sched_state_dct[key] = lr_sched_state_dct[key]
        if save_optimizer:
            optimizer_state_dict = self.optimizer.state_dict()
        else:
            optimizer_state_dict = None

        self.print_to_log_file("saving checkpoint...")
        # epoch + 1 so that resuming continues with the *next* epoch
        save_this = {
            'epoch': self.epoch + 1,
            'state_dict': state_dict,
            'optimizer_state_dict': optimizer_state_dict,
            'lr_scheduler_state_dict': lr_sched_state_dct,
            'plot_stuff': (self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode,
                           self.all_val_eval_metrics),
            'best_stuff' : (self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA)}
        if self.amp_grad_scaler is not None:
            save_this['amp_grad_scaler'] = self.amp_grad_scaler.state_dict()

        torch.save(save_this, fname)
        self.print_to_log_file("done, saving took %.2f seconds" % (time() - start_time))

    def load_best_checkpoint(self, train=True):
        """Load model_best.model if it exists, else fall back to the latest checkpoint.

        :param train: if True, also restore optimizer/scheduler state
        :raises RuntimeError: if self.fold is not set (output folder is per fold)
        """
        if self.fold is None:
            raise RuntimeError("Cannot load best checkpoint if self.fold is None")
        if isfile(join(self.output_folder, "model_best.model")):
            self.load_checkpoint(join(self.output_folder, "model_best.model"), train=train)
        else:
            self.print_to_log_file("WARNING! model_best.model does not exist! Cannot load best checkpoint. Falling "
                                   "back to load_latest_checkpoint")
            self.load_latest_checkpoint(train)

    def load_latest_checkpoint(self, train=True):
        """Load the most authoritative checkpoint available (final > latest > best).

        :param train: if True, also restore optimizer/scheduler state
        :raises RuntimeError: if no checkpoint file exists in self.output_folder
        """
        if isfile(join(self.output_folder, "model_final_checkpoint.model")):
            return self.load_checkpoint(join(self.output_folder, "model_final_checkpoint.model"), train=train)
        if isfile(join(self.output_folder, "model_latest.model")):
            return self.load_checkpoint(join(self.output_folder, "model_latest.model"), train=train)
        if isfile(join(self.output_folder, "model_best.model")):
            return self.load_best_checkpoint(train)
        raise RuntimeError("No checkpoint found")

    def load_final_checkpoint(self, train=False):
        """Load model_final_checkpoint.model (only written when training completed)."""
        filename = join(self.output_folder, "model_final_checkpoint.model")
        if not isfile(filename):
            raise RuntimeError("Final checkpoint not found. Expected: %s. Please finish the training first." % filename)
        return self.load_checkpoint(filename, train=train)

    def load_checkpoint(self, fname, train=True):
        """Load a checkpoint file from disk (onto CPU) and restore via load_checkpoint_ram."""
        self.print_to_log_file("loading checkpoint", fname, "train=", train)
        if not self.was_initialized:
            self.initialize(train)
        # saved_model = torch.load(fname, map_location=torch.device('cuda', torch.cuda.current_device()))
        saved_model = torch.load(fname, map_location=torch.device('cpu'))
        self.load_checkpoint_ram(saved_model, train)

    @abstractmethod
    def initialize_network(self):
        """
        initialize self.network here
        :return:
        """
        pass

    @abstractmethod
    def initialize_optimizer_and_scheduler(self):
        """
        initialize self.optimizer and self.lr_scheduler (if applicable) here
        :return:
        """
        pass

    def load_checkpoint_ram(self, checkpoint, train=True):
        """
        Restore trainer state from a checkpoint dict that is already in RAM:
        network weights (handling nn.DataParallel 'module.' prefixes), AMP
        scaler state, epoch counter, optimizer/scheduler state (if train=True)
        and loss/metric histories.

        :param checkpoint: dict as produced by save_checkpoint / torch.load
        :param train: if True, also restore optimizer and lr scheduler state
        :return:
        """
        # NOTE(review): leftover debug print; writes to stdout on every load.
        print('I am here !!!')
        if not self.was_initialized:
            self.initialize(train)

        new_state_dict = OrderedDict()
        curr_state_dict_keys = list(self.network.state_dict().keys())
        # if state dict comes form nn.DataParallel but we use non-parallel model here then the state dict keys do not
        # match. Use heuristic to make it match
        for k, value in checkpoint['state_dict'].items():
            key = k
            if key not in curr_state_dict_keys and key.startswith('module.'):
                key = key[7:]  # strip the 'module.' prefix added by DataParallel
            new_state_dict[key] = value

        if self.fp16:
            self._maybe_init_amp()
            if 'amp_grad_scaler' in checkpoint.keys():
                self.amp_grad_scaler.load_state_dict(checkpoint['amp_grad_scaler'])

        self.network.load_state_dict(new_state_dict)
        self.epoch = checkpoint['epoch']
        if train:
            optimizer_state_dict = checkpoint['optimizer_state_dict']
            if optimizer_state_dict is not None:
                self.optimizer.load_state_dict(optimizer_state_dict)

            if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'load_state_dict') and checkpoint[
                'lr_scheduler_state_dict'] is not None:
                self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])

            # fast-forward epoch-based schedulers to the restored epoch
            if issubclass(self.lr_scheduler.__class__, _LRScheduler):
                self.lr_scheduler.step(self.epoch)

        self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics = checkpoint[
            'plot_stuff']

        # load best loss (if present)
        # NOTE(review): leftover debug print.
        print('best_stuff' in checkpoint.keys())
        if 'best_stuff' in checkpoint.keys():
            self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA = checkpoint[
                'best_stuff']

        # after the training is done, the epoch is incremented one more time in my old code. This results in
        # self.epoch = 1001 for old trained models when the epoch is actually 1000. This causes issues because
        # len(self.all_tr_losses) = 1000 and the plot function will fail. We can easily detect and correct that here
        if self.epoch != len(self.all_tr_losses):
            self.print_to_log_file("WARNING in loading checkpoint: self.epoch != len(self.all_tr_losses). This is "
                                   "due to an old bug and should only appear when you are loading old models. New "
                                   "models should have this fixed! self.epoch is now set to len(self.all_tr_losses)")
            self.epoch = len(self.all_tr_losses)
            self.all_tr_losses = self.all_tr_losses[:self.epoch]
            self.all_val_losses = self.all_val_losses[:self.epoch]
            self.all_val_losses_tr_mode = self.all_val_losses_tr_mode[:self.epoch]
            self.all_val_eval_metrics = self.all_val_eval_metrics[:self.epoch]

        self._maybe_init_amp()

    def _maybe_init_amp(self):
        # Lazily create the AMP GradScaler; only needed for fp16 training.
        if self.fp16 and self.amp_grad_scaler is None:
            self.amp_grad_scaler = GradScaler()

    def plot_network_architecture(self):
        """
        can be implemented (see nnFormerTrainer) but does not have to. Not implemented here because it imposes stronger
        assumptions on the presence of class variables
        :return:
        """
        pass

    def run_training(self):
        """Main training loop: iterate epochs until max_num_epochs or early stopping.

        Each epoch runs num_batches_per_epoch training iterations, then
        num_val_batches_per_epoch validation iterations (eval mode, no grad),
        updates the moving averages, and delegates checkpointing / lr updates /
        patience handling to on_epoch_end(). Writes model_final_checkpoint.model
        at the end and removes the now-redundant model_latest.model.
        """
        if not torch.cuda.is_available():
            self.print_to_log_file("WARNING!!! You are attempting to run training on a CPU (torch.cuda.is_available() is False). This can be VERY slow!")

        # warm up the (multithreaded) generators so the first epoch does not stall
        # NOTE(review): .next() is the old batchgenerators API, not Python's next() — confirm generator type.
        _ = self.tr_gen.next()
        _ = self.val_gen.next()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self._maybe_init_amp()

        maybe_mkdir_p(self.output_folder)
        self.plot_network_architecture()

        if cudnn.benchmark and cudnn.deterministic:
            warn("torch.backends.cudnn.deterministic is True indicating a deterministic training is desired. "
                 "But torch.backends.cudnn.benchmark is True as well and this will prevent deterministic training! "
                 "If you want deterministic then set benchmark=False")

        if not self.was_initialized:
            self.initialize(True)

        while self.epoch < self.max_num_epochs:
            self.print_to_log_file("\nepoch: ", self.epoch)
            epoch_start_time = time()
            train_losses_epoch = []

            # train one epoch
            self.network.train()

            if self.use_progress_bar:
                with trange(self.num_batches_per_epoch) as tbar:
                    for b in tbar:
                        tbar.set_description("Epoch {}/{}".format(self.epoch+1, self.max_num_epochs))

                        l = self.run_iteration(self.tr_gen, True)

                        tbar.set_postfix(loss=l)
                        train_losses_epoch.append(l)
            else:
                for _ in range(self.num_batches_per_epoch):
                    l = self.run_iteration(self.tr_gen, True)
                    train_losses_epoch.append(l)

            self.all_tr_losses.append(np.mean(train_losses_epoch))
            self.print_to_log_file("train loss : %.4f" % self.all_tr_losses[-1])

            with torch.no_grad():
                # validation with train=False
                self.network.eval()
                val_losses = []
                for b in range(self.num_val_batches_per_epoch):
                    l = self.run_iteration(self.val_gen, False, True)
                    val_losses.append(l)
                self.all_val_losses.append(np.mean(val_losses))
                self.print_to_log_file("validation loss: %.4f" % self.all_val_losses[-1])

                if self.also_val_in_tr_mode:
                    self.network.train()
                    # validation with train=True
                    val_losses = []
                    for b in range(self.num_val_batches_per_epoch):
                        l = self.run_iteration(self.val_gen, False)
                        val_losses.append(l)
                    self.all_val_losses_tr_mode.append(np.mean(val_losses))
                    self.print_to_log_file("validation loss (train=True): %.4f" % self.all_val_losses_tr_mode[-1])

            self.update_train_loss_MA()  # needed for lr scheduler and stopping of training

            continue_training = self.on_epoch_end()

            epoch_end_time = time()

            if not continue_training:
                # allows for early stopping
                break

            self.epoch += 1
            self.print_to_log_file("This epoch took %f s\n" % (epoch_end_time - epoch_start_time))

        self.epoch -= 1  # if we don't do this we can get a problem with loading model_final_checkpoint.

        if self.save_final_checkpoint: self.save_checkpoint(join(self.output_folder, "model_final_checkpoint.model"))
        # now we can delete latest as it will be identical with final
        if isfile(join(self.output_folder, "model_latest.model")):
            os.remove(join(self.output_folder, "model_latest.model"))
        if isfile(join(self.output_folder, "model_latest.model.pkl")):
            os.remove(join(self.output_folder, "model_latest.model.pkl"))

    def maybe_update_lr(self):
        """Step the lr scheduler, if one is configured.

        ReduceLROnPlateau is driven by the (smoother) validation criterion MA;
        epoch-based schedulers are stepped with epoch+1 because this runs before
        self.epoch is incremented at the end of the epoch.
        """
        # maybe update learning rate
        if self.lr_scheduler is not None:
            assert isinstance(self.lr_scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))

            if isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
                # lr scheduler is updated with moving average val loss. should be more robust
                self.lr_scheduler.step(self.val_eval_criterion_MA)
            else:
                self.lr_scheduler.step(self.epoch + 1)
        self.print_to_log_file("lr is now (scheduler) %s" % str(self.optimizer.param_groups[0]['lr']))

    def maybe_save_checkpoint(self):
        """
        Saves a checkpoint every save_every epochs (if intermediate checkpointing
        is enabled). Always refreshes model_latest.model; additionally writes a
        per-epoch model_ep_XXX.model unless save_latest_only is set.
        :return:
        """
        if self.save_intermediate_checkpoints and (self.epoch % self.save_every == (self.save_every - 1)):
            self.print_to_log_file("saving scheduled checkpoint file...")
            if not self.save_latest_only:
                self.save_checkpoint(join(self.output_folder, "model_ep_%03.0d.model" % (self.epoch + 1)))
            self.save_checkpoint(join(self.output_folder, "model_latest.model"))
            self.print_to_log_file("done")

    def update_eval_criterion_MA(self):
        """
        If self.all_val_eval_metrics is unused (len=0) then we fall back to using -self.all_val_losses for the MA to determine early stopping
        (not a minimization, but a maximization of a metric and therefore the - in the latter case)
        :return:
        """
        if self.val_eval_criterion_MA is None:
            if len(self.all_val_eval_metrics) == 0:
                self.val_eval_criterion_MA = - self.all_val_losses[-1]
            else:
                self.val_eval_criterion_MA = self.all_val_eval_metrics[-1]
        else:
            if len(self.all_val_eval_metrics) == 0:
                """
                We use alpha * old - (1 - alpha) * new because new in this case is the validation loss and lower
                is better, so we need to negate it.
                """
                self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA - (
                        1 - self.val_eval_criterion_alpha) * \
                                             self.all_val_losses[-1]
            else:
                self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA + (
                        1 - self.val_eval_criterion_alpha) * \
                                             self.all_val_eval_metrics[-1]

    def manage_patience(self):
        """Early-stopping bookkeeping.

        Saves model_best.model when the validation criterion MA improves.
        Training is stopped only when the train-loss MA has not improved for
        more than self.patience epochs AND the lr is already at/below
        self.lr_threshold; otherwise patience is partially reset so lr decay
        gets a chance to help.

        :return: True to continue training, False to stop
        """
        # update patience
        continue_training = True
        if self.patience is not None:
            # if best_MA_tr_loss_for_patience and best_epoch_based_on_MA_tr_loss were not yet initialized,
            # initialize them
            if self.best_MA_tr_loss_for_patience is None:
                self.best_MA_tr_loss_for_patience = self.train_loss_MA

            if self.best_epoch_based_on_MA_tr_loss is None:
                self.best_epoch_based_on_MA_tr_loss = self.epoch

            if self.best_val_eval_criterion_MA is None:
                self.best_val_eval_criterion_MA = self.val_eval_criterion_MA

            # check if the current epoch is the best one according to moving average of validation criterion. If so
            # then save 'best' model
            # Do not use this for validation. This is intended for test set prediction only.
            self.print_to_log_file("current best_val_eval_criterion_MA is %.4f0" % self.best_val_eval_criterion_MA)
            self.print_to_log_file("current val_eval_criterion_MA is %.4f" % self.val_eval_criterion_MA)

            if self.val_eval_criterion_MA > self.best_val_eval_criterion_MA:
                self.best_val_eval_criterion_MA = self.val_eval_criterion_MA
                #self.print_to_log_file("saving best epoch checkpoint...")
                if self.save_best_checkpoint: self.save_checkpoint(join(self.output_folder, "model_best.model"))

            # Now see if the moving average of the train loss has improved. If yes then reset patience, else
            # increase patience
            if self.train_loss_MA + self.train_loss_MA_eps < self.best_MA_tr_loss_for_patience:
                self.best_MA_tr_loss_for_patience = self.train_loss_MA
                self.best_epoch_based_on_MA_tr_loss = self.epoch
                #self.print_to_log_file("New best epoch (train loss MA): %03.4f" % self.best_MA_tr_loss_for_patience)
            else:
                pass
                #self.print_to_log_file("No improvement: current train MA %03.4f, best: %03.4f, eps is %03.4f" %
                #                       (self.train_loss_MA, self.best_MA_tr_loss_for_patience, self.train_loss_MA_eps))

            # if patience has reached its maximum then finish training (provided lr is low enough)
            if self.epoch - self.best_epoch_based_on_MA_tr_loss > self.patience:
                if self.optimizer.param_groups[0]['lr'] > self.lr_threshold:
                    #self.print_to_log_file("My patience ended, but I believe I need more time (lr > 1e-6)")
                    self.best_epoch_based_on_MA_tr_loss = self.epoch - self.patience // 2
                else:
                    #self.print_to_log_file("My patience ended")
                    continue_training = False
            else:
                pass
                #self.print_to_log_file(
                #    "Patience: %d/%d" % (self.epoch - self.best_epoch_based_on_MA_tr_loss, self.patience))

        return continue_training

    def on_epoch_end(self):
        """Per-epoch housekeeping; returns False if training should stop early."""
        self.finish_online_evaluation()  # does not have to do anything, but can be used to update self.all_val_eval_
        # metrics

        self.plot_progress()

        self.maybe_save_checkpoint()

        self.update_eval_criterion_MA()

        self.maybe_update_lr()

        continue_training = self.manage_patience()
        return continue_training

    def update_train_loss_MA(self):
        # exponential moving average of the train loss (used for patience/early stopping)
        if self.train_loss_MA is None:
            self.train_loss_MA = self.all_tr_losses[-1]
        else:
            self.train_loss_MA = self.train_loss_MA_alpha * self.train_loss_MA + (1 - self.train_loss_MA_alpha) * \
                                 self.all_tr_losses[-1]

    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        """Run a single forward (and optionally backward) pass on one batch.

        :param data_generator: yields dicts with 'data' and 'target' entries
        :param do_backprop: if True, backprop and take an optimizer step
        :param run_online_evaluation: if True, feed output/target to run_online_evaluation
        :return: detached scalar loss as a numpy value
        """
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        if self.fp16:
            with autocast():
                output = self.network(data)
                del data  # free the input as soon as possible to reduce peak memory
                l = self.loss(output, target)

            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            output = self.network(data)
            del data
            l = self.loss(output, target)

            if do_backprop:
                l.backward()
                self.optimizer.step()

        if run_online_evaluation:
            self.run_online_evaluation(output, target)

        del target

        return l.detach().cpu().numpy()

    def run_online_evaluation(self, *args, **kwargs):
        """
        Can be implemented, does not have to
        :param output_torch:
        :param target_npy:
        :return:
        """
        pass

    def finish_online_evaluation(self):
        """
        Can be implemented, does not have to
        :return:
        """
        pass

    @abstractmethod
    def validate(self, *args, **kwargs):
        pass

    def find_lr(self, num_iters=1000, init_value=1e-6, final_value=10., beta=0.98):
        """
        LR range test: exponentially sweep the lr from init_value to final_value,
        tracking a smoothed (+1 shifted) loss.
        stolen and adapted from here: https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
        :param num_iters: number of training iterations in the sweep
        :param init_value: starting learning rate
        :param final_value: final learning rate
        :param beta: smoothing factor for the loss moving average
        :return:
        """
        import math
        self._maybe_init_amp()
        mult = (final_value / init_value) ** (1 / num_iters)
        lr = init_value
        self.optimizer.param_groups[0]['lr'] = lr
        avg_loss = 0.
        best_loss = 0.
        losses = []
        log_lrs = []

        for batch_num in range(1, num_iters + 1):
            # +1 because this one here is not designed to have negative loss...
loss = self.run_iteration(self.tr_gen, do_backprop=True, run_online_evaluation=False).data.item() + 1 # Compute the smoothed loss avg_loss = beta * avg_loss + (1 - beta) * loss smoothed_loss = avg_loss / (1 - beta ** batch_num) # Stop if the loss is exploding if batch_num > 1 and smoothed_loss > 4 * best_loss: break # Record the best loss if smoothed_loss < best_loss or batch_num == 1: best_loss = smoothed_loss # Store the values losses.append(smoothed_loss) log_lrs.append(math.log10(lr)) # Update the lr for the next step lr *= mult self.optimizer.param_groups[0]['lr'] = lr import matplotlib.pyplot as plt lrs = [10 ** i for i in log_lrs] fig = plt.figure() plt.xscale('log') plt.plot(lrs[10:-5], losses[10:-5]) plt.savefig(join(self.output_folder, "lr_finder.png")) plt.close() return log_lrs, losses ================================================ FILE: unetr_pp/training/network_training/unetr_pp_trainer_acdc.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from collections import OrderedDict from typing import Tuple import numpy as np import torch from unetr_pp.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation from unetr_pp.training.loss_functions.deep_supervision import MultipleOutputLoss2 from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda from unetr_pp.network_architecture.initialization import InitWeights_He from unetr_pp.network_architecture.neural_network import SegmentationNetwork from unetr_pp.training.data_augmentation.default_data_augmentation import default_2D_augmentation_params, \ get_patch_size, default_3D_augmentation_params from unetr_pp.training.dataloading.dataset_loading import unpack_dataset from unetr_pp.training.network_training.Trainer_acdc import Trainer_acdc from unetr_pp.utilities.nd_softmax import softmax_helper from sklearn.model_selection import KFold from torch import nn from torch.cuda.amp import autocast from unetr_pp.training.learning_rate.poly_lr import poly_lr from batchgenerators.utilities.file_and_folder_operations import * from unetr_pp.network_architecture.acdc.unetr_pp_acdc import UNETR_PP from fvcore.nn import FlopCountAnalysis class unetr_pp_trainer_acdc(Trainer_acdc): """ Info for Fabian: same as internal nnUNetTrainerV2_2 """ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None, unpack_data=True, deterministic=True, fp16=False): super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data, deterministic, fp16) self.max_num_epochs = 1000 self.initial_lr = 1e-2 self.deep_supervision_scales = None self.ds_loss_weights = None self.pin_memory = True self.load_pretrain_weight = False self.load_plans_file() if len(self.plans['plans_per_stage']) == 2: Stage = 1 else: Stage = 0 self.crop_size = self.plans['plans_per_stage'][Stage]['patch_size'] self.input_channels = self.plans['num_modalities'] self.num_classes = self.plans['num_classes'] + 1 
self.conv_op = nn.Conv3d self.embedding_dim = 96 self.depths = [2, 2, 2, 2] self.num_heads = [3, 6, 12, 24] self.embedding_patch_size = [1, 4, 4] self.window_size = [[3, 5, 5], [3, 5, 5], [7, 10, 10], [3, 5, 5]] self.down_stride = [[1, 4, 4], [1, 8, 8], [2, 16, 16], [4, 32, 32]] self.deep_supervision = True def initialize(self, training=True, force_load_plans=False): """ - replaced get_default_augmentation with get_moreDA_augmentation - enforce to only run this code once - loss function wrapper for deep supervision :param training: :param force_load_plans: :return: """ if not self.was_initialized: maybe_mkdir_p(self.output_folder) if force_load_plans or (self.plans is None): self.load_plans_file() self.plans['plans_per_stage'][0]['patch_size'] = np.array([16, 160, 160]) self.crop_size = np.array([16, 160, 160]) self.plans['plans_per_stage'][self.stage]['pool_op_kernel_sizes'] = [[1, 4, 4], [2, 2, 2], [2, 2, 2]] self.process_plans(self.plans) self.setup_DA_params() if self.deep_supervision: ################# Here we wrap the loss for deep supervision ############ # we need to know the number of outputs of the network net_numpool = len(self.net_num_pool_op_kernel_sizes) # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases # this gives higher resolution outputs more weight in the loss weights = np.array([1 / (2 ** i) for i in range(net_numpool)]) # we don't use the lowest 2 outputs. 
Normalize weights so that they sum to 1 # mask = np.array([True] + [True if i < net_numpool - 1 else False for i in range(1, net_numpool)]) # weights[~mask] = 0 weights = weights / weights.sum() print(weights) self.ds_loss_weights = weights # now wrap the loss self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights) ################# END ################### self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] + "_stage%d" % self.stage) seeds_train = np.random.random_integers(0, 99999, self.data_aug_params.get('num_threads')) seeds_val = np.random.random_integers(0, 99999, max(self.data_aug_params.get('num_threads') // 2, 1)) if training: self.dl_tr, self.dl_val = self.get_basic_generators() if self.unpack_data: print("unpacking dataset") unpack_dataset(self.folder_with_preprocessed_data) print("done") else: print( "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you " "will wait all winter for your model to finish!") self.tr_gen, self.val_gen = get_moreDA_augmentation( self.dl_tr, self.dl_val, self.data_aug_params[ 'patch_size_for_spatialtransform'], self.data_aug_params, deep_supervision_scales=self.deep_supervision_scales if self.deep_supervision else None, pin_memory=self.pin_memory, use_nondetMultiThreadedAugmenter=False, seeds_train=seeds_train, seeds_val=seeds_val ) self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())), also_print_to_console=False) self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())), also_print_to_console=False) else: pass self.initialize_network() self.initialize_optimizer_and_scheduler() assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel)) else: self.print_to_log_file('self.was_initialized is True, not running self.initialize again') self.was_initialized = True def initialize_network(self): """ - momentum 0.99 - SGD instead of Adam - self.lr_scheduler = None because we do poly_lr - 
deep supervision = True - i am sure I forgot something here Known issue: forgot to set neg_slope=0 in InitWeights_He; should not make a difference though :return: """ self.network = UNETR_PP(in_channels=self.input_channels, out_channels=self.num_classes, feature_size=16, num_heads=4, depths=[3, 3, 3, 3], dims=[32, 64, 128, 256], do_ds=True, ) if torch.cuda.is_available(): self.network.cuda() self.network.inference_apply_nonlin = softmax_helper # Print the network parameters & Flops n_parameters = sum(p.numel() for p in self.network.parameters() if p.requires_grad) input_res = (1, 16, 160, 160) input = torch.ones(()).new_empty((1, *input_res), dtype=next(self.network.parameters()).dtype, device=next(self.network.parameters()).device) flops = FlopCountAnalysis(self.network, input) model_flops = flops.total() print(f"Total trainable parameters: {round(n_parameters * 1e-6, 2)} M") print(f"MAdds: {round(model_flops * 1e-9, 2)} G") def initialize_optimizer_and_scheduler(self): assert self.network is not None, "self.initialize_network must be called first" self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, momentum=0.99, nesterov=True) self.lr_scheduler = None def run_online_evaluation(self, output, target): """ due to deep supervision the return value and the reference are now lists of tensors. We only need the full resolution output because this is what we are interested in in the end. 
The others are ignored :param output: :param target: :return: """ if self.deep_supervision: target = target[0] output = output[0] else: target = target output = output return super().run_online_evaluation(output, target) def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True, validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False, segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True): """ We need to wrap this because we need to enforce self.network.do_ds = False for prediction """ ds = self.network.do_ds self.network.do_ds = False ret = super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window, step_size=step_size, save_softmax=save_softmax, use_gaussian=use_gaussian, overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug, all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs, run_postprocessing_on_folds=run_postprocessing_on_folds) self.network.do_ds = ds return ret def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True, mirror_axes: Tuple[int] = None, use_sliding_window: bool = True, step_size: float = 0.5, use_gaussian: bool = True, pad_border_mode: str = 'constant', pad_kwargs: dict = None, all_in_gpu: bool = False, verbose: bool = True, mixed_precision=True) -> Tuple[ np.ndarray, np.ndarray]: """ We need to wrap this because we need to enforce self.network.do_ds = False for prediction """ ds = self.network.do_ds self.network.do_ds = False ret = super().predict_preprocessed_data_return_seg_and_softmax(data, do_mirroring=do_mirroring, mirror_axes=mirror_axes, use_sliding_window=use_sliding_window, step_size=step_size, use_gaussian=use_gaussian, pad_border_mode=pad_border_mode, pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, verbose=verbose, 
mixed_precision=mixed_precision) self.network.do_ds = ds return ret def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False): """ gradient clipping improves training stability :param data_generator: :param do_backprop: :param run_online_evaluation: :return: """ data_dict = next(data_generator) data = data_dict['data'] target = data_dict['target'] data = maybe_to_torch(data) target = maybe_to_torch(target) if torch.cuda.is_available(): data = to_cuda(data) target = to_cuda(target) self.optimizer.zero_grad() if self.fp16: with autocast(): output = self.network(data) del data l = self.loss(output, target) if do_backprop: self.amp_grad_scaler.scale(l).backward() self.amp_grad_scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) self.amp_grad_scaler.step(self.optimizer) self.amp_grad_scaler.update() else: output = self.network(data) del data l = self.loss(output, target) if do_backprop: l.backward() torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) self.optimizer.step() if run_online_evaluation: self.run_online_evaluation(output, target) del target return l.detach().cpu().numpy() def do_split(self): """ The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded, so always the same) and save it as splits_final.pkl file in the preprocessed data directory. Sometimes you may want to create your own split for various reasons. For this you will need to create your own splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3) and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to use a random 80:20 data split. 
:return: """ if self.fold == "all": # if fold==all then we use all images for training and validation tr_keys = val_keys = list(self.dataset.keys()) else: splits_file = join(self.dataset_directory, "splits_final.pkl") # if the split file does not exist we need to create it if not isfile(splits_file): self.print_to_log_file("Creating new 5-fold cross-validation split...") splits = [] all_keys_sorted = np.sort(list(self.dataset.keys())) kfold = KFold(n_splits=5, shuffle=True, random_state=12345) for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)): train_keys = np.array(all_keys_sorted)[train_idx] test_keys = np.array(all_keys_sorted)[test_idx] splits.append(OrderedDict()) splits[-1]['train'] = train_keys splits[-1]['val'] = test_keys save_pickle(splits, splits_file) else: self.print_to_log_file("Using splits from existing split file:", splits_file) splits = load_pickle(splits_file) self.print_to_log_file("The split file contains %d splits." % len(splits)) self.print_to_log_file("Desired fold for training: %d" % self.fold) splits[self.fold]['train'] = np.array(['patient001_frame01', 'patient001_frame12', 'patient004_frame01', 'patient004_frame15', 'patient005_frame01', 'patient005_frame13', 'patient006_frame01', 'patient006_frame16', 'patient007_frame01', 'patient007_frame07', 'patient010_frame01', 'patient010_frame13', 'patient011_frame01', 'patient011_frame08', 'patient013_frame01', 'patient013_frame14', 'patient015_frame01', 'patient015_frame10', 'patient016_frame01', 'patient016_frame12', 'patient018_frame01', 'patient018_frame10', 'patient019_frame01', 'patient019_frame11', 'patient020_frame01', 'patient020_frame11', 'patient021_frame01', 'patient021_frame13', 'patient022_frame01', 'patient022_frame11', 'patient023_frame01', 'patient023_frame09', 'patient025_frame01', 'patient025_frame09', 'patient026_frame01', 'patient026_frame12', 'patient027_frame01', 'patient027_frame11', 'patient028_frame01', 'patient028_frame09', 'patient029_frame01', 
'patient029_frame12', 'patient030_frame01', 'patient030_frame12', 'patient031_frame01', 'patient031_frame10', 'patient032_frame01', 'patient032_frame12', 'patient033_frame01', 'patient033_frame14', 'patient034_frame01', 'patient034_frame16', 'patient035_frame01', 'patient035_frame11', 'patient036_frame01', 'patient036_frame12', 'patient037_frame01', 'patient037_frame12', 'patient038_frame01', 'patient038_frame11', 'patient039_frame01', 'patient039_frame10', 'patient040_frame01', 'patient040_frame13', 'patient041_frame01', 'patient041_frame11', 'patient043_frame01', 'patient043_frame07', 'patient044_frame01', 'patient044_frame11', 'patient045_frame01', 'patient045_frame13', 'patient046_frame01', 'patient046_frame10', 'patient047_frame01', 'patient047_frame09', 'patient050_frame01', 'patient050_frame12', 'patient051_frame01', 'patient051_frame11', 'patient052_frame01', 'patient052_frame09', 'patient054_frame01', 'patient054_frame12', 'patient056_frame01', 'patient056_frame12', 'patient057_frame01', 'patient057_frame09', 'patient058_frame01', 'patient058_frame14', 'patient059_frame01', 'patient059_frame09', 'patient060_frame01', 'patient060_frame14', 'patient061_frame01', 'patient061_frame10', 'patient062_frame01', 'patient062_frame09', 'patient063_frame01', 'patient063_frame16', 'patient065_frame01', 'patient065_frame14', 'patient066_frame01', 'patient066_frame11', 'patient068_frame01', 'patient068_frame12', 'patient069_frame01', 'patient069_frame12', 'patient070_frame01', 'patient070_frame10', 'patient071_frame01', 'patient071_frame09', 'patient072_frame01', 'patient072_frame11', 'patient073_frame01', 'patient073_frame10', 'patient074_frame01', 'patient074_frame12', 'patient075_frame01', 'patient075_frame06', 'patient076_frame01', 'patient076_frame12', 'patient077_frame01', 'patient077_frame09', 'patient078_frame01', 'patient078_frame09', 'patient080_frame01', 'patient080_frame10', 'patient082_frame01', 'patient082_frame07', 'patient083_frame01', 
'patient083_frame08', 'patient084_frame01', 'patient084_frame10', 'patient085_frame01', 'patient085_frame09', 'patient086_frame01', 'patient086_frame08', 'patient087_frame01', 'patient087_frame10', 'patient089_frame01', 'patient089_frame10', 'patient090_frame04', 'patient090_frame11', 'patient091_frame01', 'patient091_frame09', 'patient093_frame01', 'patient093_frame14', 'patient094_frame01', 'patient094_frame07', 'patient096_frame01', 'patient096_frame08', 'patient097_frame01', 'patient097_frame11', 'patient098_frame01', 'patient098_frame09', 'patient099_frame01', 'patient099_frame09', 'patient100_frame01', 'patient100_frame13']) splits[self.fold]['val'] = np.array(['patient002_frame01', 'patient002_frame12', 'patient003_frame01','patient003_frame15', 'patient008_frame01', 'patient008_frame13', 'patient009_frame01', 'patient009_frame13', 'patient012_frame01', 'patient012_frame13', 'patient014_frame01', 'patient014_frame13', 'patient017_frame01', 'patient017_frame09', 'patient024_frame01', 'patient024_frame09', 'patient042_frame01', 'patient042_frame16', 'patient048_frame01', 'patient048_frame08', 'patient049_frame01', 'patient049_frame11', 'patient053_frame01', 'patient053_frame12', 'patient055_frame01', 'patient055_frame10', 'patient064_frame01', 'patient064_frame12', 'patient067_frame01', 'patient067_frame10', 'patient079_frame01', 'patient079_frame11', 'patient081_frame01', 'patient081_frame07', 'patient088_frame01', 'patient088_frame12', 'patient092_frame01', 'patient092_frame06', 'patient095_frame01', 'patient095_frame12']) if self.fold < len(splits): tr_keys = splits[self.fold]['train'] val_keys = splits[self.fold]['val'] self.print_to_log_file("This split has %d training and %d validation cases." % (len(tr_keys), len(val_keys))) else: self.print_to_log_file("INFO: You requested fold %d for training but splits " "contain only %d folds. I am now creating a " "random (but seeded) 80:20 split!" 
% (self.fold, len(splits))) # if we request a fold that is not in the split file, create a random 80:20 split rnd = np.random.RandomState(seed=12345 + self.fold) keys = np.sort(list(self.dataset.keys())) idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False) idx_val = [i for i in range(len(keys)) if i not in idx_tr] tr_keys = [keys[i] for i in idx_tr] val_keys = [keys[i] for i in idx_val] self.print_to_log_file("This random 80:20 split has %d training and %d validation cases." % (len(tr_keys), len(val_keys))) tr_keys.sort() val_keys.sort() self.dataset_tr = OrderedDict() for i in tr_keys: self.dataset_tr[i] = self.dataset[i] self.dataset_val = OrderedDict() for i in val_keys: self.dataset_val[i] = self.dataset[i] def setup_DA_params(self): """ - we increase roation angle from [-15, 15] to [-30, 30] - scale range is now (0.7, 1.4), was (0.85, 1.25) - we don't do elastic deformation anymore :return: """ self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod( np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1] if self.threeD: self.data_aug_params = default_3D_augmentation_params self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) if self.do_dummy_2D_aug: self.data_aug_params["dummy_2D"] = True self.print_to_log_file("Using dummy2d data augmentation") self.data_aug_params["elastic_deform_alpha"] = \ default_2D_augmentation_params["elastic_deform_alpha"] self.data_aug_params["elastic_deform_sigma"] = \ default_2D_augmentation_params["elastic_deform_sigma"] self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"] else: self.do_dummy_2D_aug = False if max(self.patch_size) / min(self.patch_size) > 1.5: default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. 
* np.pi) self.data_aug_params = default_2D_augmentation_params self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm if self.do_dummy_2D_aug: self.basic_generator_patch_size = get_patch_size(self.patch_size[1:], self.data_aug_params['rotation_x'], self.data_aug_params['rotation_y'], self.data_aug_params['rotation_z'], self.data_aug_params['scale_range']) self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size)) patch_size_for_spatialtransform = self.patch_size[1:] else: self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'], self.data_aug_params['rotation_y'], self.data_aug_params['rotation_z'], self.data_aug_params['scale_range']) patch_size_for_spatialtransform = self.patch_size self.data_aug_params["scale_range"] = (0.7, 1.4) self.data_aug_params["do_elastic"] = False self.data_aug_params['selected_seg_channels'] = [0] self.data_aug_params['patch_size_for_spatialtransform'] = patch_size_for_spatialtransform self.data_aug_params["num_cached_per_thread"] = 2 def maybe_update_lr(self, epoch=None): """ if epoch is not None we overwrite epoch. Else we use epoch = self.epoch + 1 (maybe_update_lr is called in on_epoch_end which is called before epoch is incremented. herefore we need to do +1 here) :param epoch: :return: """ if epoch is None: ep = self.epoch + 1 else: ep = epoch self.optimizer.param_groups[0]['lr'] = poly_lr(ep, self.max_num_epochs, self.initial_lr, 0.9) self.print_to_log_file("lr:", np.round(self.optimizer.param_groups[0]['lr'], decimals=6)) def on_epoch_end(self): """ overwrite patient-based early stopping. Always run to 1000 epochs :return: """ super().on_epoch_end() continue_training = self.epoch < self.max_num_epochs # it can rarely happen that the momentum of nnUNetTrainerV2 is too high for some dataset. 
If at epoch 100 the # estimated validation Dice is still 0 then we reduce the momentum from 0.99 to 0.95 if self.epoch == 100: if self.all_val_eval_metrics[-1] == 0: self.optimizer.param_groups[0]["momentum"] = 0.95 self.network.apply(InitWeights_He(1e-2)) self.print_to_log_file("At epoch 100, the mean foreground Dice was 0. This can be caused by a too " "high momentum. High momentum (0.99) is good for datasets where it works, but " "sometimes causes issues such as this one. Momentum has now been reduced to " "0.95 and network weights have been reinitialized") return continue_training def run_training(self): """ if we run with -c then we need to set the correct lr for the first epoch, otherwise it will run the first continued epoch with self.initial_lr we also need to make sure deep supervision in the network is enabled for training, thus the wrapper :return: """ self.maybe_update_lr(self.epoch) # if we dont overwrite epoch then self.epoch+1 is used which is not what we # want at the start of the training ds = self.network.do_ds if self.deep_supervision: self.network.do_ds = True else: self.network.do_ds = False ret = super().run_training() self.network.do_ds = ds return ret ================================================ FILE: unetr_pp/training/network_training/unetr_pp_trainer_lung.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict from typing import Tuple import numpy as np import torch from unetr_pp.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation from unetr_pp.training.loss_functions.deep_supervision import MultipleOutputLoss2 from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda from unetr_pp.network_architecture.initialization import InitWeights_He from unetr_pp.network_architecture.neural_network import SegmentationNetwork from unetr_pp.training.data_augmentation.default_data_augmentation import default_2D_augmentation_params, \ get_patch_size, default_3D_augmentation_params from unetr_pp.training.dataloading.dataset_loading import unpack_dataset from unetr_pp.training.network_training.Trainer_lung import Trainer_lung from unetr_pp.utilities.nd_softmax import softmax_helper from sklearn.model_selection import KFold from torch import nn from torch.cuda.amp import autocast from unetr_pp.training.learning_rate.poly_lr import poly_lr from batchgenerators.utilities.file_and_folder_operations import * from unetr_pp.network_architecture.lung.unetr_pp_lung import UNETR_PP from fvcore.nn import FlopCountAnalysis class unetr_pp_trainer_lung(Trainer_lung): """ Info for Fabian: same as internal nnUNetTrainerV2_2 """ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None, unpack_data=True, deterministic=True, fp16=False): super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data, deterministic, fp16) self.max_num_epochs = 1000 self.initial_lr = 1e-2 self.deep_supervision_scales = None self.ds_loss_weights = None self.pin_memory = True self.load_pretrain_weight = False self.load_plans_file() if len(self.plans['plans_per_stage']) == 2: Stage = 1 else: Stage = 0 self.crop_size = (32,192,192) 
#self.plans['plans_per_stage'][Stage]['patch_size'] self.input_channels = self.plans['num_modalities'] self.num_classes = self.plans['num_classes'] + 1 #print("#############classes", self.num_classes) self.conv_op = nn.Conv3d self.embedding_dim = 96 self.depths = [2, 2, 2, 2] self.num_heads = [3, 6, 12, 24] self.embedding_patch_size = [1, 4, 4] self.window_size = [[3, 5, 5], [3, 5, 5], [7, 10, 10], [3, 5, 5]] self.down_stride = [[1, 4, 4], [1, 8, 8], [2, 16, 16], [4, 32, 32]] self.deep_supervision = True def initialize(self, training=True, force_load_plans=False): """ - replaced get_default_augmentation with get_moreDA_augmentation - enforce to only run this code once - loss function wrapper for deep supervision :param training: :param force_load_plans: :return: """ if not self.was_initialized: maybe_mkdir_p(self.output_folder) if force_load_plans or (self.plans is None): self.load_plans_file() self.plans['plans_per_stage'][0]['patch_size'] = np.array([32, 192, 192]) self.crop_size = np.array([32, 192, 192]) self.plans['plans_per_stage'][self.stage]['pool_op_kernel_sizes'] = [[1, 4, 4], [2, 2, 2], [2, 2, 2]] self.process_plans(self.plans) self.setup_DA_params() if self.deep_supervision: ################# Here we wrap the loss for deep supervision ############ # we need to know the number of outputs of the network net_numpool = len(self.net_num_pool_op_kernel_sizes) # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases # this gives higher resolution outputs more weight in the loss weights = np.array([1 / (2 ** i) for i in range(net_numpool)]) # we don't use the lowest 2 outputs. 
Normalize weights so that they sum to 1 # mask = np.array([True] + [True if i < net_numpool - 1 else False for i in range(1, net_numpool)]) # weights[~mask] = 0 weights = weights / weights.sum() print(weights) self.ds_loss_weights = weights # now wrap the loss self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights) ################# END ################### self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] + "_stage%d" % self.stage) seeds_train = np.random.random_integers(0, 99999, self.data_aug_params.get('num_threads')) seeds_val = np.random.random_integers(0, 99999, max(self.data_aug_params.get('num_threads') // 2, 1)) if training: self.dl_tr, self.dl_val = self.get_basic_generators() if self.unpack_data: print("unpacking dataset") unpack_dataset(self.folder_with_preprocessed_data) print("done") else: print( "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you " "will wait all winter for your model to finish!") self.tr_gen, self.val_gen = get_moreDA_augmentation( self.dl_tr, self.dl_val, self.data_aug_params[ 'patch_size_for_spatialtransform'], self.data_aug_params, deep_supervision_scales=self.deep_supervision_scales if self.deep_supervision else None, pin_memory=self.pin_memory, use_nondetMultiThreadedAugmenter=False, seeds_train=seeds_train, seeds_val=seeds_val ) self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())), also_print_to_console=False) self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())), also_print_to_console=False) else: pass self.initialize_network() self.initialize_optimizer_and_scheduler() assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel)) else: self.print_to_log_file('self.was_initialized is True, not running self.initialize again') self.was_initialized = True def initialize_network(self): """ - momentum 0.99 - SGD instead of Adam - self.lr_scheduler = None because we do poly_lr - 
deep supervision = True - i am sure I forgot something here Known issue: forgot to set neg_slope=0 in InitWeights_He; should not make a difference though :return: """ self.network = UNETR_PP(in_channels=self.input_channels, out_channels=self.num_classes, feature_size=16, num_heads=4, depths=[3, 3, 3, 3], dims=[32, 64, 128, 256], do_ds=True, ) if torch.cuda.is_available(): self.network.cuda() self.network.inference_apply_nonlin = softmax_helper # Print the network parameters & Flops n_parameters = sum(p.numel() for p in self.network.parameters() if p.requires_grad) input_res = (1, 32, 192, 192) input = torch.ones(()).new_empty((1, *input_res), dtype=next(self.network.parameters()).dtype, device=next(self.network.parameters()).device) flops = FlopCountAnalysis(self.network, input) model_flops = flops.total() print(f"Total trainable parameters: {round(n_parameters * 1e-6, 2)} M") print(f"MAdds: {round(model_flops * 1e-9, 2)} G") def initialize_optimizer_and_scheduler(self): assert self.network is not None, "self.initialize_network must be called first" self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, momentum=0.99, nesterov=True) self.lr_scheduler = None def run_online_evaluation(self, output, target): """ due to deep supervision the return value and the reference are now lists of tensors. We only need the full resolution output because this is what we are interested in in the end. 
The others are ignored :param output: :param target: :return: """ if self.deep_supervision: target = target[0] output = output[0] else: target = target output = output return super().run_online_evaluation(output, target) def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True, validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False, segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True): """ We need to wrap this because we need to enforce self.network.do_ds = False for prediction """ ds = self.network.do_ds self.network.do_ds = False ret = super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window, step_size=step_size, save_softmax=save_softmax, use_gaussian=use_gaussian, overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug, all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs, run_postprocessing_on_folds=run_postprocessing_on_folds) self.network.do_ds = ds return ret def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True, mirror_axes: Tuple[int] = None, use_sliding_window: bool = True, step_size: float = 0.5, use_gaussian: bool = True, pad_border_mode: str = 'constant', pad_kwargs: dict = None, all_in_gpu: bool = False, verbose: bool = True, mixed_precision=True) -> Tuple[ np.ndarray, np.ndarray]: """ We need to wrap this because we need to enforce self.network.do_ds = False for prediction """ ds = self.network.do_ds self.network.do_ds = False ret = super().predict_preprocessed_data_return_seg_and_softmax(data, do_mirroring=do_mirroring, mirror_axes=mirror_axes, use_sliding_window=use_sliding_window, step_size=step_size, use_gaussian=use_gaussian, pad_border_mode=pad_border_mode, pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, verbose=verbose, 
mixed_precision=mixed_precision) self.network.do_ds = ds return ret def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False): """ gradient clipping improves training stability :param data_generator: :param do_backprop: :param run_online_evaluation: :return: """ data_dict = next(data_generator) data = data_dict['data'] #print("****data:",data.shape) target = data_dict['target'] data = maybe_to_torch(data) target = maybe_to_torch(target) if torch.cuda.is_available(): data = to_cuda(data) target = to_cuda(target) self.optimizer.zero_grad() if self.fp16: with autocast(): output = self.network(data) del data l = self.loss(output, target) if do_backprop: self.amp_grad_scaler.scale(l).backward() self.amp_grad_scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) self.amp_grad_scaler.step(self.optimizer) self.amp_grad_scaler.update() else: output = self.network(data) del data l = self.loss(output, target) if do_backprop: l.backward() torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) self.optimizer.step() if run_online_evaluation: self.run_online_evaluation(output, target) del target return l.detach().cpu().numpy() def do_split(self): """ The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded, so always the same) and save it as splits_final.pkl file in the preprocessed data directory. Sometimes you may want to create your own split for various reasons. For this you will need to create your own splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3) and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to use a random 80:20 data split. 
:return: """ if self.fold == "all": # if fold==all then we use all images for training and validation tr_keys = val_keys = list(self.dataset.keys()) else: splits_file = join(self.dataset_directory, "splits_final.pkl") # if the split file does not exist we need to create it if not isfile(splits_file): self.print_to_log_file("Creating new 5-fold cross-validation split...") splits = [] all_keys_sorted = np.sort(list(self.dataset.keys())) kfold = KFold(n_splits=5, shuffle=True, random_state=12345) for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)): train_keys = np.array(all_keys_sorted)[train_idx] test_keys = np.array(all_keys_sorted)[test_idx] splits.append(OrderedDict()) splits[-1]['train'] = train_keys splits[-1]['val'] = test_keys save_pickle(splits, splits_file) else: self.print_to_log_file("Using splits from existing split file:", splits_file) splits = load_pickle(splits_file) self.print_to_log_file("The split file contains %d splits." % len(splits)) self.print_to_log_file("Desired fold for training: %d" % self.fold) splits[self.fold]['train'] = np.array([ 'lung_053', 'lung_022', 'lung_041', 'lung_069', 'lung_014', 'lung_006', 'lung_065', 'lung_018', 'lung_096', 'lung_084', 'lung_086', 'lung_043', 'lung_020', 'lung_051', 'lung_079', 'lung_004', 'lung_075', 'lung_016', 'lung_071', 'lung_028', 'lung_055', 'lung_036', 'lung_047', 'lung_059', 'lung_061', 'lung_010', 'lung_073', 'lung_026', 'lung_038', 'lung_045', 'lung_034', 'lung_049', 'lung_057', 'lung_080', 'lung_092', 'lung_015', 'lung_064', 'lung_031', 'lung_023', 'lung_005', 'lung_078', 'lung_066', 'lung_009', 'lung_074', 'lung_042', 'lung_033', 'lung_095', 'lung_037', 'lung_054', 'lung_029' ]) splits[self.fold]['val'] = np.array([ 'lung_058', 'lung_025', 'lung_046', 'lung_070', 'lung_001', 'lung_062', 'lung_083', 'lung_081', 'lung_093', 'lung_044', 'lung_027', 'lung_048', 'lung_003' ]) if self.fold < len(splits): tr_keys = splits[self.fold]['train'] val_keys = splits[self.fold]['val'] 
self.print_to_log_file("This split has %d training and %d validation cases." % (len(tr_keys), len(val_keys))) else: self.print_to_log_file("INFO: You requested fold %d for training but splits " "contain only %d folds. I am now creating a " "random (but seeded) 80:20 split!" % (self.fold, len(splits))) # if we request a fold that is not in the split file, create a random 80:20 split rnd = np.random.RandomState(seed=12345 + self.fold) keys = np.sort(list(self.dataset.keys())) idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False) idx_val = [i for i in range(len(keys)) if i not in idx_tr] tr_keys = [keys[i] for i in idx_tr] val_keys = [keys[i] for i in idx_val] self.print_to_log_file("This random 80:20 split has %d training and %d validation cases." % (len(tr_keys), len(val_keys))) tr_keys.sort() val_keys.sort() self.dataset_tr = OrderedDict() for i in tr_keys: self.dataset_tr[i] = self.dataset[i] self.dataset_val = OrderedDict() for i in val_keys: self.dataset_val[i] = self.dataset[i] def setup_DA_params(self): """ - we increase roation angle from [-15, 15] to [-30, 30] - scale range is now (0.7, 1.4), was (0.85, 1.25) - we don't do elastic deformation anymore :return: """ self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod( np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1] if self.threeD: self.data_aug_params = default_3D_augmentation_params self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. 
* np.pi) if self.do_dummy_2D_aug: self.data_aug_params["dummy_2D"] = True self.print_to_log_file("Using dummy2d data augmentation") self.data_aug_params["elastic_deform_alpha"] = \ default_2D_augmentation_params["elastic_deform_alpha"] self.data_aug_params["elastic_deform_sigma"] = \ default_2D_augmentation_params["elastic_deform_sigma"] self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"] else: self.do_dummy_2D_aug = False if max(self.patch_size) / min(self.patch_size) > 1.5: default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi) self.data_aug_params = default_2D_augmentation_params self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm if self.do_dummy_2D_aug: self.basic_generator_patch_size = get_patch_size(self.patch_size[1:], self.data_aug_params['rotation_x'], self.data_aug_params['rotation_y'], self.data_aug_params['rotation_z'], self.data_aug_params['scale_range']) self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size)) patch_size_for_spatialtransform = self.patch_size[1:] else: self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'], self.data_aug_params['rotation_y'], self.data_aug_params['rotation_z'], self.data_aug_params['scale_range']) patch_size_for_spatialtransform = self.patch_size self.data_aug_params["scale_range"] = (0.7, 1.4) self.data_aug_params["do_elastic"] = False self.data_aug_params['selected_seg_channels'] = [0] self.data_aug_params['patch_size_for_spatialtransform'] = patch_size_for_spatialtransform self.data_aug_params["num_cached_per_thread"] = 2 def maybe_update_lr(self, epoch=None): """ if epoch is not None we overwrite epoch. Else we use epoch = self.epoch + 1 (maybe_update_lr is called in on_epoch_end which is called before epoch is incremented. 
herefore we need to do +1 here) :param epoch: :return: """ if epoch is None: ep = self.epoch + 1 else: ep = epoch self.optimizer.param_groups[0]['lr'] = poly_lr(ep, self.max_num_epochs, self.initial_lr, 0.9) self.print_to_log_file("lr:", np.round(self.optimizer.param_groups[0]['lr'], decimals=6)) def on_epoch_end(self): """ overwrite patient-based early stopping. Always run to 1000 epochs :return: """ super().on_epoch_end() continue_training = self.epoch < self.max_num_epochs # it can rarely happen that the momentum of nnUNetTrainerV2 is too high for some dataset. If at epoch 100 the # estimated validation Dice is still 0 then we reduce the momentum from 0.99 to 0.95 if self.epoch == 100: if self.all_val_eval_metrics[-1] == 0: self.optimizer.param_groups[0]["momentum"] = 0.95 self.network.apply(InitWeights_He(1e-2)) self.print_to_log_file("At epoch 100, the mean foreground Dice was 0. This can be caused by a too " "high momentum. High momentum (0.99) is good for datasets where it works, but " "sometimes causes issues such as this one. 
Momentum has now been reduced to " "0.95 and network weights have been reinitialized") return continue_training def run_training(self): """ if we run with -c then we need to set the correct lr for the first epoch, otherwise it will run the first continued epoch with self.initial_lr we also need to make sure deep supervision in the network is enabled for training, thus the wrapper :return: """ self.maybe_update_lr(self.epoch) # if we dont overwrite epoch then self.epoch+1 is used which is not what we # want at the start of the training ds = self.network.do_ds if self.deep_supervision: self.network.do_ds = True else: self.network.do_ds = False ret = super().run_training() self.network.do_ds = ds return ret ================================================ FILE: unetr_pp/training/network_training/unetr_pp_trainer_synapse.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from collections import OrderedDict from typing import Tuple import numpy as np import torch from unetr_pp.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation from unetr_pp.training.loss_functions.deep_supervision import MultipleOutputLoss2 from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda from unetr_pp.network_architecture.synapse.unetr_pp_synapse import UNETR_PP from unetr_pp.network_architecture.initialization import InitWeights_He from unetr_pp.network_architecture.neural_network import SegmentationNetwork from unetr_pp.training.data_augmentation.default_data_augmentation import default_2D_augmentation_params, \ get_patch_size, default_3D_augmentation_params from unetr_pp.training.dataloading.dataset_loading import unpack_dataset from unetr_pp.training.network_training.Trainer_synapse import Trainer_synapse from unetr_pp.utilities.nd_softmax import softmax_helper from sklearn.model_selection import KFold from torch import nn from torch.cuda.amp import autocast from unetr_pp.training.learning_rate.poly_lr import poly_lr from batchgenerators.utilities.file_and_folder_operations import * from fvcore.nn import FlopCountAnalysis class unetr_pp_trainer_synapse(Trainer_synapse): """ same as internal nnFromerTrinerV2 and nnUNetTrainerV2_2 """ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None, unpack_data=True, deterministic=True, fp16=False): super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data, deterministic, fp16) self.max_num_epochs = 1000 self.initial_lr = 1e-2 self.deep_supervision_scales = None self.ds_loss_weights = None self.pin_memory = True self.load_pretrain_weight = False self.load_plans_file() self.crop_size = [64, 128, 128] self.input_channels = self.plans['num_modalities'] self.num_classes = self.plans['num_classes'] + 1 self.conv_op = nn.Conv3d self.embedding_dim = 192 self.depths = [2, 2, 2, 2] 
self.num_heads = [6, 12, 24, 48] self.embedding_patch_size = [2, 4, 4] self.window_size = [4, 4, 8, 4] self.deep_supervision = True def initialize(self, training=True, force_load_plans=False): """ - replaced get_default_augmentation with get_moreDA_augmentation - enforce to only run this code once - loss function wrapper for deep supervision :param training: :param force_load_plans: :return: """ if not self.was_initialized: maybe_mkdir_p(self.output_folder) if force_load_plans or (self.plans is None): self.load_plans_file() self.plans['plans_per_stage'][self.stage]['pool_op_kernel_sizes'] = [[2, 4, 4], [2, 2, 2], [2, 2, 2]] self.process_plans(self.plans) self.setup_DA_params() if self.deep_supervision: ################# Here we wrap the loss for deep supervision ############ # we need to know the number of outputs of the network net_numpool = len(self.net_num_pool_op_kernel_sizes) # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases # this gives higher resolution outputs more weight in the loss weights = np.array([1 / (2 ** i) for i in range(net_numpool)]) # we don't use the lowest 2 outputs. 
Normalize weights so that they sum to 1 # mask = np.array([True] + [True if i < net_numpool - 1 else False for i in range(1, net_numpool)]) # weights[~mask] = 0 weights = weights / weights.sum() print(weights) self.ds_loss_weights = weights # now wrap the loss self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights) ################# END ################### self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] + "_stage%d" % self.stage) seeds_train = np.random.random_integers(0, 99999, self.data_aug_params.get('num_threads')) seeds_val = np.random.random_integers(0, 99999, max(self.data_aug_params.get('num_threads') // 2, 1)) if training: self.dl_tr, self.dl_val = self.get_basic_generators() if self.unpack_data: print("unpacking dataset") unpack_dataset(self.folder_with_preprocessed_data) print("done") else: print( "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you " "will wait all winter for your model to finish!") self.tr_gen, self.val_gen = get_moreDA_augmentation( self.dl_tr, self.dl_val, self.data_aug_params[ 'patch_size_for_spatialtransform'], self.data_aug_params, deep_supervision_scales=self.deep_supervision_scales if self.deep_supervision else None, pin_memory=self.pin_memory, use_nondetMultiThreadedAugmenter=False, seeds_train=seeds_train, seeds_val=seeds_val ) self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())), also_print_to_console=False) self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())), also_print_to_console=False) else: pass self.initialize_network() self.initialize_optimizer_and_scheduler() assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel)) else: self.print_to_log_file('self.was_initialized is True, not running self.initialize again') self.was_initialized = True def initialize_network(self): """ - momentum 0.99 - SGD instead of Adam - self.lr_scheduler = None because we do poly_lr - 
deep supervision = True - i am sure I forgot something here Known issue: forgot to set neg_slope=0 in InitWeights_He; should not make a difference though :return: """ self.network = UNETR_PP(in_channels=self.input_channels, out_channels=self.num_classes, img_size=self.crop_size, feature_size=16, num_heads=4, depths=[3, 3, 3, 3], dims=[32, 64, 128, 256], do_ds=True, ) if torch.cuda.is_available(): self.network.cuda() self.network.inference_apply_nonlin = softmax_helper # Print the network parameters & Flops n_parameters = sum(p.numel() for p in self.network.parameters() if p.requires_grad) input_res = (1, 64, 128, 128) input = torch.ones(()).new_empty((1, *input_res), dtype=next(self.network.parameters()).dtype, device=next(self.network.parameters()).device) flops = FlopCountAnalysis(self.network, input) model_flops = flops.total() print(f"Total trainable parameters: {round(n_parameters * 1e-6, 2)} M") print(f"MAdds: {round(model_flops * 1e-9, 2)} G") def initialize_optimizer_and_scheduler(self): assert self.network is not None, "self.initialize_network must be called first" self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, momentum=0.99, nesterov=True) self.lr_scheduler = None def run_online_evaluation(self, output, target): """ due to deep supervision the return value and the reference are now lists of tensors. We only need the full resolution output because this is what we are interested in in the end. 
The others are ignored :param output: :param target: :return: """ if self.deep_supervision: target = target[0] output = output[0] else: target = target output = output return super().run_online_evaluation(output, target) def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True, validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False, segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True): """ We need to wrap this because we need to enforce self.network.do_ds = False for prediction """ ds = self.network.do_ds self.network.do_ds = False ret = super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window, step_size=step_size, save_softmax=save_softmax, use_gaussian=use_gaussian, overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug, all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs, run_postprocessing_on_folds=run_postprocessing_on_folds) self.network.do_ds = ds return ret def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True, mirror_axes: Tuple[int] = None, use_sliding_window: bool = True, step_size: float = 0.5, use_gaussian: bool = True, pad_border_mode: str = 'constant', pad_kwargs: dict = None, all_in_gpu: bool = False, verbose: bool = True, mixed_precision=True) -> Tuple[ np.ndarray, np.ndarray]: """ We need to wrap this because we need to enforce self.network.do_ds = False for prediction """ ds = self.network.do_ds self.network.do_ds = False ret = super().predict_preprocessed_data_return_seg_and_softmax(data, do_mirroring=do_mirroring, mirror_axes=mirror_axes, use_sliding_window=use_sliding_window, step_size=step_size, use_gaussian=use_gaussian, pad_border_mode=pad_border_mode, pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, verbose=verbose, 
mixed_precision=mixed_precision) self.network.do_ds = ds return ret def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False): """ gradient clipping improves training stability :param data_generator: :param do_backprop: :param run_online_evaluation: :return: """ data_dict = next(data_generator) data = data_dict['data'] target = data_dict['target'] data = maybe_to_torch(data) target = maybe_to_torch(target) if torch.cuda.is_available(): data = to_cuda(data) target = to_cuda(target) self.optimizer.zero_grad() if self.fp16: with autocast(): output = self.network(data) del data l = self.loss(output, target) if do_backprop: self.amp_grad_scaler.scale(l).backward() self.amp_grad_scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) self.amp_grad_scaler.step(self.optimizer) self.amp_grad_scaler.update() else: output = self.network(data) del data l = self.loss(output, target) if do_backprop: l.backward() torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) self.optimizer.step() if run_online_evaluation: self.run_online_evaluation(output, target) del target return l.detach().cpu().numpy() def do_split(self): """ The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded, so always the same) and save it as splits_final.pkl file in the preprocessed data directory. Sometimes you may want to create your own split for various reasons. For this you will need to create your own splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3) and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to use a random 80:20 data split. 
:return: """ if self.fold == "all": # if fold==all then we use all images for training and validation tr_keys = val_keys = list(self.dataset.keys()) else: splits_file = join(self.dataset_directory, "splits_final.pkl") # if the split file does not exist we need to create it if not isfile(splits_file): self.print_to_log_file("Creating new 5-fold cross-validation split...") splits = [] all_keys_sorted = np.sort(list(self.dataset.keys())) kfold = KFold(n_splits=5, shuffle=True, random_state=12345) for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)): train_keys = np.array(all_keys_sorted)[train_idx] test_keys = np.array(all_keys_sorted)[test_idx] splits.append(OrderedDict()) splits[-1]['train'] = train_keys splits[-1]['val'] = test_keys save_pickle(splits, splits_file) else: self.print_to_log_file("Using splits from existing split file:", splits_file) splits = load_pickle(splits_file) self.print_to_log_file("The split file contains %d splits." % len(splits)) self.print_to_log_file("Desired fold for training: %d" % self.fold) splits[self.fold]['train'] = np.array( ['img0006', 'img0007', 'img0009', 'img0010', 'img0021', 'img0023', 'img0024', 'img0026', 'img0027', 'img0031', 'img0033', 'img0034' \ , 'img0039', 'img0040', 'img0005', 'img0028', 'img0030', 'img0037']) splits[self.fold]['val'] = np.array( ['img0001', 'img0002', 'img0003', 'img0004', 'img0008', 'img0022', 'img0025', 'img0029', 'img0032', 'img0035', 'img0036', 'img0038']) if self.fold < len(splits): tr_keys = splits[self.fold]['train'] val_keys = splits[self.fold]['val'] self.print_to_log_file("This split has %d training and %d validation cases." % (len(tr_keys), len(val_keys))) else: self.print_to_log_file("INFO: You requested fold %d for training but splits " "contain only %d folds. I am now creating a " "random (but seeded) 80:20 split!" 
    def setup_DA_params(self):
        """
        Configure data-augmentation parameters on top of the framework defaults.

        - we increase roation angle from [-15, 15] to [-30, 30]
        - scale range is now (0.7, 1.4), was (0.85, 1.25)
        - we don't do elastic deformation anymore
        :return:
        """
        # One scale per deep-supervision output: full resolution, then the cumulative
        # downsampling produced by each pooling stage (last stage dropped).
        self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(
            np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1]

        if self.threeD:
            # NOTE: this aliases the module-level default dict; later key assignments
            # mutate default_3D_augmentation_params in place (shared across trainers).
            self.data_aug_params = default_3D_augmentation_params
            self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            if self.do_dummy_2D_aug:
                # Highly anisotropic data: augment slice-wise with the 2D defaults.
                self.data_aug_params["dummy_2D"] = True
                self.print_to_log_file("Using dummy2d data augmentation")
                self.data_aug_params["elastic_deform_alpha"] = \
                    default_2D_augmentation_params["elastic_deform_alpha"]
                self.data_aug_params["elastic_deform_sigma"] = \
                    default_2D_augmentation_params["elastic_deform_sigma"]
                self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
        else:
            self.do_dummy_2D_aug = False
            # Restrict in-plane rotation for elongated 2D patches.
            if max(self.patch_size) / min(self.patch_size) > 1.5:
                default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
            self.data_aug_params = default_2D_augmentation_params
        self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm

        # The generator must load a patch large enough that rotation/scaling can be
        # cropped back to self.patch_size without padding artifacts.
        if self.do_dummy_2D_aug:
            self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
                                                             self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
            patch_size_for_spatialtransform = self.patch_size[1:]
        else:
            self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            patch_size_for_spatialtransform = self.patch_size

        # NOTE: scale_range is set AFTER basic_generator_patch_size was computed, so the
        # generator patch size was derived from the default scale range, not (0.7, 1.4).
        self.data_aug_params["scale_range"] = (0.7, 1.4)
        self.data_aug_params["do_elastic"] = False
        self.data_aug_params['selected_seg_channels'] = [0]
        self.data_aug_params['patch_size_for_spatialtransform'] = patch_size_for_spatialtransform
        self.data_aug_params["num_cached_per_thread"] = 2

    def maybe_update_lr(self, epoch=None):
        """
        Apply the polynomial learning-rate schedule to param group 0.

        if epoch is not None we overwrite epoch. Else we use epoch = self.epoch + 1
        (maybe_update_lr is called in on_epoch_end which is called before epoch is incremented.
        Therefore we need to do +1 here)
        :param epoch: optional explicit epoch index (used when resuming with -c)
        :return:
        """
        if epoch is None:
            ep = self.epoch + 1
        else:
            ep = epoch
        self.optimizer.param_groups[0]['lr'] = poly_lr(ep, self.max_num_epochs, self.initial_lr, 0.9)
        self.print_to_log_file("lr:", np.round(self.optimizer.param_groups[0]['lr'], decimals=6))

    def on_epoch_end(self):
        """
        overwrite patient-based early stopping. Always run to 1000 epochs
        :return: True while self.epoch < self.max_num_epochs (continue training)
        """
        super().on_epoch_end()
        continue_training = self.epoch < self.max_num_epochs

        # It can rarely happen that the momentum (0.99) is too high for a dataset. If the
        # estimated validation Dice is still 0 at epoch 100, drop momentum to 0.95 and
        # reinitialize the network weights to give training a second chance.
        if self.epoch == 100:
            if self.all_val_eval_metrics[-1] == 0:
                self.optimizer.param_groups[0]["momentum"] = 0.95
                self.network.apply(InitWeights_He(1e-2))
                self.print_to_log_file("At epoch 100, the mean foreground Dice was 0. This can be caused by a too "
                                       "high momentum. High momentum (0.99) is good for datasets where it works, but "
                                       "sometimes causes issues such as this one. Momentum has now been reduced to "
                                       "0.95 and network weights have been reinitialized")
        return continue_training

    def run_training(self):
        """
        if we run with -c then we need to set the correct lr for the first epoch, otherwise it will run the first
        continued epoch with self.initial_lr

        we also need to make sure deep supervision in the network is enabled for training, thus the wrapper
        :return: whatever the parent trainer's run_training returns
        """
        self.maybe_update_lr(self.epoch)  # if we dont overwrite epoch then self.epoch+1 is used which is not what we
        # want at the start of the training
        ds = self.network.do_ds  # remember current setting so we can restore it afterwards
        if self.deep_supervision:
            self.network.do_ds = True
        else:
            self.network.do_ds = False
        ret = super().run_training()
        self.network.do_ds = ds
        return ret
# See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict from typing import Tuple import numpy as np import torch from unetr_pp.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation from unetr_pp.training.loss_functions.deep_supervision import MultipleOutputLoss2 from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda from unetr_pp.network_architecture.tumor.unetr_pp_tumor import UNETR_PP from unetr_pp.network_architecture.initialization import InitWeights_He from unetr_pp.network_architecture.neural_network import SegmentationNetwork from unetr_pp.training.data_augmentation.default_data_augmentation import default_2D_augmentation_params, \ get_patch_size, default_3D_augmentation_params from unetr_pp.training.dataloading.dataset_loading import unpack_dataset from unetr_pp.training.network_training.Trainer_tumor import Trainer_tumor from unetr_pp.utilities.nd_softmax import softmax_helper from sklearn.model_selection import KFold from torch import nn from torch.cuda.amp import autocast from unetr_pp.training.learning_rate.poly_lr import poly_lr from batchgenerators.utilities.file_and_folder_operations import * from fvcore.nn import FlopCountAnalysis class unetr_pp_trainer_tumor(Trainer_tumor): """ same as internal nnFromerTrinerV2 and nnUNetTrainerV2_2 """ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None, unpack_data=True, deterministic=True, fp16=False): super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data, deterministic, fp16) self.max_num_epochs = 1000 self.initial_lr = 1e-2 self.deep_supervision_scales = None self.ds_loss_weights = None self.pin_memory = True self.load_pretrain_weight = False self.load_plans_file() if len(self.plans['plans_per_stage'])==2: Stage=1 else: Stage=0 self.crop_size = 
self.plans['plans_per_stage'][Stage]['patch_size'] print("self.crop_size in init train:",self.crop_size) self.input_channels = self.plans['num_modalities'] #print("SELF.INPUT CHANNELS",self.input_channels) self.num_classes = self.plans['num_classes'] + 1 #print("SELF.NUM_CLASSES",self.num_classes) self.conv_op = nn.Conv3d self.embedding_dim = 96 self.depths = [2, 2, 2, 2] self.num_heads = [3, 6, 12, 24] self.embedding_patch_size = [4, 4, 4] self.window_size = [4, 4, 8, 4] self.deep_supervision = True def initialize(self, training=True, force_load_plans=False): """ - replaced get_default_augmentation with get_moreDA_augmentation - enforce to only run this code once - loss function wrapper for deep supervision :param training: :param force_load_plans: :return: """ if not self.was_initialized: maybe_mkdir_p(self.output_folder) if force_load_plans or (self.plans is None): self.load_plans_file() self.plans['plans_per_stage'][self.stage]['pool_op_kernel_sizes'] = [[4, 4, 4], [2, 2, 2], [2, 2, 2]] self.process_plans(self.plans) self.setup_DA_params() if self.deep_supervision: ################# Here we wrap the loss for deep supervision ############ # we need to know the number of outputs of the network net_numpool = len(self.net_num_pool_op_kernel_sizes) # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases # this gives higher resolution outputs more weight in the loss weights = np.array([1 / (2 ** i) for i in range(net_numpool)]) # we don't use the lowest 2 outputs. 
Normalize weights so that they sum to 1 # mask = np.array([True] + [True if i < net_numpool - 1 else False for i in range(1, net_numpool)]) # weights[~mask] = 0 weights = weights / weights.sum() print(weights) self.ds_loss_weights = weights # now wrap the loss self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights) ################# END ################### self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] + "_stage%d" % self.stage) seeds_train = np.random.random_integers(0, 99999, self.data_aug_params.get('num_threads')) seeds_val = np.random.random_integers(0, 99999, max(self.data_aug_params.get('num_threads') // 2, 1)) if training: self.dl_tr, self.dl_val = self.get_basic_generators() if self.unpack_data: print("unpacking dataset") unpack_dataset(self.folder_with_preprocessed_data) print("done") else: print( "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you " "will wait all winter for your model to finish!") self.tr_gen, self.val_gen = get_moreDA_augmentation( self.dl_tr, self.dl_val, self.data_aug_params[ 'patch_size_for_spatialtransform'], self.data_aug_params, deep_supervision_scales=self.deep_supervision_scales if self.deep_supervision else None, pin_memory=self.pin_memory, use_nondetMultiThreadedAugmenter=False, seeds_train=seeds_train, seeds_val=seeds_val ) self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())), also_print_to_console=False) self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())), also_print_to_console=False) else: pass self.initialize_network() self.initialize_optimizer_and_scheduler() assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel)) else: self.print_to_log_file('self.was_initialized is True, not running self.initialize again') self.was_initialized = True def initialize_network(self): """ - momentum 0.99 - SGD instead of Adam - self.lr_scheduler = None because we do poly_lr - 
deep supervision = True - i am sure I forgot something here Known issue: forgot to set neg_slope=0 in InitWeights_He; should not make a difference though :return: """ self.network = UNETR_PP(in_channels=self.input_channels, out_channels=self.num_classes, feature_size=16, num_heads=4, depths=[3, 3, 3, 3], dims=[32, 64, 128, 256], do_ds=True, ) #print("self.input_channels", self.input_channels) #print("self.num_classes", self.num_classes) if torch.cuda.is_available(): self.network.cuda() self.network.inference_apply_nonlin = softmax_helper # Print the network parameters & Flops n_parameters = sum(p.numel() for p in self.network.parameters() if p.requires_grad) input_res = (4, 128, 128, 128) input = torch.ones(()).new_empty((1, *input_res), dtype=next(self.network.parameters()).dtype, device=next(self.network.parameters()).device) flops = FlopCountAnalysis(self.network, input) model_flops = flops.total() print(f"Total trainable parameters: {round(n_parameters * 1e-6, 2)} M") print(f"MAdds: {round(model_flops * 1e-9, 2)} G") def initialize_optimizer_and_scheduler(self): assert self.network is not None, "self.initialize_network must be called first" self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, momentum=0.99, nesterov=True) self.lr_scheduler = None def run_online_evaluation(self, output, target): """ due to deep supervision the return value and the reference are now lists of tensors. We only need the full resolution output because this is what we are interested in in the end. 
The others are ignored :param output: :param target: :return: """ if self.deep_supervision: target = target[0] output = output[0] else: target = target output = output return super().run_online_evaluation(output, target) def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True, validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False, segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True): """ We need to wrap this because we need to enforce self.network.do_ds = False for prediction """ ds = self.network.do_ds self.network.do_ds = False ret = super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window, step_size=step_size, save_softmax=save_softmax, use_gaussian=use_gaussian, overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug, all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs, run_postprocessing_on_folds=run_postprocessing_on_folds) self.network.do_ds = ds return ret def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True, mirror_axes: Tuple[int] = None, use_sliding_window: bool = True, step_size: float = 0.5, use_gaussian: bool = True, pad_border_mode: str = 'constant', pad_kwargs: dict = None, all_in_gpu: bool = False, verbose: bool = True, mixed_precision=True) -> Tuple[ np.ndarray, np.ndarray]: """ We need to wrap this because we need to enforce self.network.do_ds = False for prediction """ ds = self.network.do_ds self.network.do_ds = False ret = super().predict_preprocessed_data_return_seg_and_softmax(data, do_mirroring=do_mirroring, mirror_axes=mirror_axes, use_sliding_window=use_sliding_window, step_size=step_size, use_gaussian=use_gaussian, pad_border_mode=pad_border_mode, pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, verbose=verbose, 
mixed_precision=mixed_precision) self.network.do_ds = ds return ret def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False): """ gradient clipping improves training stability :param data_generator: :param do_backprop: :param run_online_evaluation: :return: """ data_dict = next(data_generator) data = data_dict['data'] target = data_dict['target'] data = maybe_to_torch(data) target = maybe_to_torch(target) if torch.cuda.is_available(): data = to_cuda(data) target = to_cuda(target) self.optimizer.zero_grad() if self.fp16: with autocast(): output = self.network(data) del data l = self.loss(output, target) if do_backprop: self.amp_grad_scaler.scale(l).backward() self.amp_grad_scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) self.amp_grad_scaler.step(self.optimizer) self.amp_grad_scaler.update() else: output = self.network(data) del data l = self.loss(output, target) if do_backprop: l.backward() torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) self.optimizer.step() if run_online_evaluation: self.run_online_evaluation(output, target) del target return l.detach().cpu().numpy() def do_split(self): """ The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded, so always the same) and save it as splits_final.pkl file in the preprocessed data directory. Sometimes you may want to create your own split for various reasons. For this you will need to create your own splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3) and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to use a random 80:20 data split. 
:return: """ if self.fold == "all": # if fold==all then we use all images for training and validation tr_keys = val_keys = list(self.dataset.keys()) else: splits_file = join(self.dataset_directory, "splits_final.pkl") # if the split file does not exist we need to create it if not isfile(splits_file): self.print_to_log_file("Creating new 5-fold cross-validation split...") splits = [] all_keys_sorted = np.sort(list(self.dataset.keys())) kfold = KFold(n_splits=5, shuffle=True, random_state=12345) for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)): train_keys = np.array(all_keys_sorted)[train_idx] test_keys = np.array(all_keys_sorted)[test_idx] splits.append(OrderedDict()) splits[-1]['train'] = train_keys splits[-1]['val'] = test_keys save_pickle(splits, splits_file) else: self.print_to_log_file("Using splits from existing split file:", splits_file) splits = load_pickle(splits_file) self.print_to_log_file("The split file contains %d splits." % len(splits)) self.print_to_log_file("Desired fold for training: %d" % self.fold) splits[self.fold]['train'] = np.array([ 'BRATS_001', 'BRATS_002', 'BRATS_003', 'BRATS_004', 'BRATS_005', 'BRATS_006', 'BRATS_007', 'BRATS_008', 'BRATS_009', 'BRATS_010', 'BRATS_013', 'BRATS_014', 'BRATS_015', 'BRATS_016', 'BRATS_017', 'BRATS_019', 'BRATS_022', 'BRATS_023', 'BRATS_024', 'BRATS_025', 'BRATS_026', 'BRATS_027', 'BRATS_030', 'BRATS_031', 'BRATS_033', 'BRATS_035', 'BRATS_037', 'BRATS_038', 'BRATS_039', 'BRATS_040', 'BRATS_042', 'BRATS_043', 'BRATS_044', 'BRATS_045', 'BRATS_046', 'BRATS_048', 'BRATS_050', 'BRATS_051', 'BRATS_052', 'BRATS_054', 'BRATS_055', 'BRATS_060', 'BRATS_061', 'BRATS_062', 'BRATS_063', 'BRATS_064', 'BRATS_065', 'BRATS_066', 'BRATS_067', 'BRATS_068', 'BRATS_070', 'BRATS_072', 'BRATS_073', 'BRATS_074', 'BRATS_075', 'BRATS_078', 'BRATS_079', 'BRATS_080', 'BRATS_081', 'BRATS_082', 'BRATS_083', 'BRATS_084', 'BRATS_085', 'BRATS_086', 'BRATS_087', 'BRATS_088', 'BRATS_091', 'BRATS_093', 'BRATS_094', 
'BRATS_096', 'BRATS_097', 'BRATS_098', 'BRATS_100', 'BRATS_101', 'BRATS_102', 'BRATS_104', 'BRATS_108', 'BRATS_110', 'BRATS_111', 'BRATS_112', 'BRATS_115', 'BRATS_116', 'BRATS_117', 'BRATS_119', 'BRATS_120', 'BRATS_121', 'BRATS_122', 'BRATS_123', 'BRATS_125', 'BRATS_126', 'BRATS_127', 'BRATS_128', 'BRATS_129', 'BRATS_130', 'BRATS_131', 'BRATS_132', 'BRATS_133', 'BRATS_134', 'BRATS_135', 'BRATS_136', 'BRATS_137', 'BRATS_138', 'BRATS_140', 'BRATS_141', 'BRATS_142', 'BRATS_143', 'BRATS_144', 'BRATS_146', 'BRATS_148', 'BRATS_149', 'BRATS_150', 'BRATS_153', 'BRATS_154', 'BRATS_155', 'BRATS_158', 'BRATS_159', 'BRATS_160', 'BRATS_162', 'BRATS_163', 'BRATS_164', 'BRATS_165', 'BRATS_166', 'BRATS_167', 'BRATS_168', 'BRATS_169', 'BRATS_170', 'BRATS_171', 'BRATS_173', 'BRATS_174', 'BRATS_175', 'BRATS_177', 'BRATS_178', 'BRATS_179', 'BRATS_180', 'BRATS_182', 'BRATS_183', 'BRATS_184', 'BRATS_185', 'BRATS_186', 'BRATS_187', 'BRATS_188', 'BRATS_189', 'BRATS_191', 'BRATS_192', 'BRATS_193', 'BRATS_195', 'BRATS_197', 'BRATS_199', 'BRATS_200', 'BRATS_201', 'BRATS_202', 'BRATS_203', 'BRATS_206', 'BRATS_207', 'BRATS_208', 'BRATS_210', 'BRATS_211', 'BRATS_212', 'BRATS_213', 'BRATS_214', 'BRATS_215', 'BRATS_216', 'BRATS_217', 'BRATS_218', 'BRATS_219', 'BRATS_222', 'BRATS_223', 'BRATS_224', 'BRATS_225', 'BRATS_226', 'BRATS_228', 'BRATS_229', 'BRATS_230', 'BRATS_231', 'BRATS_232', 'BRATS_233', 'BRATS_236', 'BRATS_237', 'BRATS_238', 'BRATS_239', 'BRATS_241', 'BRATS_243', 'BRATS_244', 'BRATS_246', 'BRATS_247', 'BRATS_248', 'BRATS_249', 'BRATS_251', 'BRATS_252', 'BRATS_253', 'BRATS_254', 'BRATS_255', 'BRATS_258', 'BRATS_259', 'BRATS_261', 'BRATS_262', 'BRATS_263', 'BRATS_264', 'BRATS_265', 'BRATS_266', 'BRATS_267', 'BRATS_268', 'BRATS_272', 'BRATS_273', 'BRATS_274', 'BRATS_275', 'BRATS_276', 'BRATS_277', 'BRATS_278', 'BRATS_279', 'BRATS_280', 'BRATS_283', 'BRATS_284', 'BRATS_285', 'BRATS_286', 'BRATS_288', 'BRATS_290', 'BRATS_293', 'BRATS_294', 'BRATS_296', 'BRATS_297', 'BRATS_298', 
'BRATS_299', 'BRATS_300', 'BRATS_301', 'BRATS_302', 'BRATS_303', 'BRATS_304', 'BRATS_306', 'BRATS_307', 'BRATS_308', 'BRATS_309', 'BRATS_311', 'BRATS_312', 'BRATS_313', 'BRATS_315', 'BRATS_316', 'BRATS_317', 'BRATS_318', 'BRATS_319', 'BRATS_320', 'BRATS_321', 'BRATS_322', 'BRATS_324', 'BRATS_326', 'BRATS_328', 'BRATS_329', 'BRATS_332', 'BRATS_334', 'BRATS_335', 'BRATS_336', 'BRATS_338', 'BRATS_339', 'BRATS_340', 'BRATS_341', 'BRATS_342', 'BRATS_343', 'BRATS_344', 'BRATS_345', 'BRATS_347', 'BRATS_348', 'BRATS_349', 'BRATS_351', 'BRATS_353', 'BRATS_354', 'BRATS_355', 'BRATS_356', 'BRATS_357', 'BRATS_358', 'BRATS_359', 'BRATS_360', 'BRATS_363', 'BRATS_364', 'BRATS_365', 'BRATS_366', 'BRATS_367', 'BRATS_368', 'BRATS_369', 'BRATS_370', 'BRATS_371', 'BRATS_372', 'BRATS_373', 'BRATS_374', 'BRATS_375', 'BRATS_376', 'BRATS_377', 'BRATS_378', 'BRATS_379', 'BRATS_380', 'BRATS_381', 'BRATS_383', 'BRATS_384', 'BRATS_385', 'BRATS_386', 'BRATS_387', 'BRATS_388', 'BRATS_390', 'BRATS_391', 'BRATS_392', 'BRATS_393', 'BRATS_394', 'BRATS_395', 'BRATS_396', 'BRATS_398', 'BRATS_399', 'BRATS_401', 'BRATS_403', 'BRATS_404', 'BRATS_405', 'BRATS_407', 'BRATS_408', 'BRATS_409', 'BRATS_410', 'BRATS_411', 'BRATS_412', 'BRATS_413', 'BRATS_414', 'BRATS_415', 'BRATS_417', 'BRATS_418', 'BRATS_419', 'BRATS_420', 'BRATS_421', 'BRATS_422', 'BRATS_423', 'BRATS_424', 'BRATS_426', 'BRATS_428', 'BRATS_429', 'BRATS_430', 'BRATS_431', 'BRATS_433', 'BRATS_434', 'BRATS_435', 'BRATS_436', 'BRATS_437', 'BRATS_438', 'BRATS_439', 'BRATS_441', 'BRATS_442', 'BRATS_443', 'BRATS_444', 'BRATS_445', 'BRATS_446', 'BRATS_449', 'BRATS_451', 'BRATS_452', 'BRATS_453', 'BRATS_454', 'BRATS_455', 'BRATS_457', 'BRATS_458', 'BRATS_459', 'BRATS_460', 'BRATS_463', 'BRATS_464', 'BRATS_466', 'BRATS_467', 'BRATS_468', 'BRATS_469', 'BRATS_470', 'BRATS_472', 'BRATS_475', 'BRATS_477', 'BRATS_478', 'BRATS_481', 'BRATS_482', 'BRATS_483','BRATS_400', 'BRATS_402', 'BRATS_406', 'BRATS_416', 'BRATS_427', 'BRATS_440', 'BRATS_447', 
'BRATS_448', 'BRATS_456', 'BRATS_461', 'BRATS_462', 'BRATS_465', 'BRATS_471', 'BRATS_473', 'BRATS_474', 'BRATS_476', 'BRATS_479', 'BRATS_480', 'BRATS_484']) splits[self.fold]['val'] = np.array(['BRATS_011', 'BRATS_012', 'BRATS_018', 'BRATS_020', 'BRATS_021', 'BRATS_028', 'BRATS_029', 'BRATS_032', 'BRATS_034', 'BRATS_036', 'BRATS_041', 'BRATS_047', 'BRATS_049', 'BRATS_053', 'BRATS_056', 'BRATS_057', 'BRATS_069', 'BRATS_071', 'BRATS_089', 'BRATS_090', 'BRATS_092', 'BRATS_095', 'BRATS_103', 'BRATS_105', 'BRATS_106', 'BRATS_107', 'BRATS_109', 'BRATS_118', 'BRATS_145', 'BRATS_147', 'BRATS_156', 'BRATS_161', 'BRATS_172', 'BRATS_176', 'BRATS_181', 'BRATS_194', 'BRATS_196', 'BRATS_198', 'BRATS_204', 'BRATS_205', 'BRATS_209', 'BRATS_220', 'BRATS_221', 'BRATS_227', 'BRATS_234', 'BRATS_235', 'BRATS_245', 'BRATS_250', 'BRATS_256', 'BRATS_257', 'BRATS_260', 'BRATS_269', 'BRATS_270', 'BRATS_271', 'BRATS_281', 'BRATS_282', 'BRATS_287', 'BRATS_289', 'BRATS_291', 'BRATS_292', 'BRATS_310', 'BRATS_314', 'BRATS_323', 'BRATS_327', 'BRATS_330', 'BRATS_333', 'BRATS_337', 'BRATS_346', 'BRATS_350', 'BRATS_352', 'BRATS_361', 'BRATS_382', 'BRATS_397']) if self.fold < len(splits): tr_keys = splits[self.fold]['train'] val_keys = splits[self.fold]['val'] self.print_to_log_file("This split has %d training and %d validation cases." % (len(tr_keys), len(val_keys))) else: self.print_to_log_file("INFO: You requested fold %d for training but splits " "contain only %d folds. I am now creating a " "random (but seeded) 80:20 split!" 
% (self.fold, len(splits))) # if we request a fold that is not in the split file, create a random 80:20 split rnd = np.random.RandomState(seed=12345 + self.fold) keys = np.sort(list(self.dataset.keys())) idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False) idx_val = [i for i in range(len(keys)) if i not in idx_tr] tr_keys = [keys[i] for i in idx_tr] val_keys = [keys[i] for i in idx_val] self.print_to_log_file("This random 80:20 split has %d training and %d validation cases." % (len(tr_keys), len(val_keys))) tr_keys.sort() val_keys.sort() self.dataset_tr = OrderedDict() for i in tr_keys: self.dataset_tr[i] = self.dataset[i] self.dataset_val = OrderedDict() for i in val_keys: self.dataset_val[i] = self.dataset[i] def setup_DA_params(self): """ - we increase roation angle from [-15, 15] to [-30, 30] - scale range is now (0.7, 1.4), was (0.85, 1.25) - we don't do elastic deformation anymore :return: """ self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod( np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1] if self.threeD: self.data_aug_params = default_3D_augmentation_params self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) if self.do_dummy_2D_aug: self.data_aug_params["dummy_2D"] = True self.print_to_log_file("Using dummy2d data augmentation") self.data_aug_params["elastic_deform_alpha"] = \ default_2D_augmentation_params["elastic_deform_alpha"] self.data_aug_params["elastic_deform_sigma"] = \ default_2D_augmentation_params["elastic_deform_sigma"] self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"] else: self.do_dummy_2D_aug = False if max(self.patch_size) / min(self.patch_size) > 1.5: default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. 
    def setup_DA_params(self):
        """
        Configure the data-augmentation parameter dict used by get_moreDA_augmentation.

        - we increase roation angle from [-15, 15] to [-30, 30]
        - scale range is now (0.7, 1.4), was (0.85, 1.25)
        - we don't do elastic deformation anymore
        :return:
        """
        # Downsampling factor of each deep-supervision head relative to full resolution.
        self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(
            np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1]

        if self.threeD:
            # NOTE: binds (and below mutates) the shared module-level defaults dict.
            self.data_aug_params = default_3D_augmentation_params
            self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
            if self.do_dummy_2D_aug:
                # Anisotropic data: fall back to slice-wise (2D) spatial augmentation.
                self.data_aug_params["dummy_2D"] = True
                self.print_to_log_file("Using dummy2d data augmentation")
                self.data_aug_params["elastic_deform_alpha"] = \
                    default_2D_augmentation_params["elastic_deform_alpha"]
                self.data_aug_params["elastic_deform_sigma"] = \
                    default_2D_augmentation_params["elastic_deform_sigma"]
                self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
        else:
            self.do_dummy_2D_aug = False
            # Narrow the rotation range for strongly rectangular 2D patches.
            if max(self.patch_size) / min(self.patch_size) > 1.5:
                default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
            self.data_aug_params = default_2D_augmentation_params
        self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm

        # Enlarge the patch the dataloader extracts so rotations/scalings can be cropped
        # back to self.patch_size without border padding.
        if self.do_dummy_2D_aug:
            self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
                                                             self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
            patch_size_for_spatialtransform = self.patch_size[1:]
        else:
            self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
                                                             self.data_aug_params['rotation_y'],
                                                             self.data_aug_params['rotation_z'],
                                                             self.data_aug_params['scale_range'])
            patch_size_for_spatialtransform = self.patch_size

        # NOTE(review): scale_range is overwritten AFTER basic_generator_patch_size was
        # computed from the previous value — presumably intentional upstream behavior.
        self.data_aug_params["scale_range"] = (0.7, 1.4)
        self.data_aug_params["do_elastic"] = False
        self.data_aug_params['selected_seg_channels'] = [0]
        self.data_aug_params['patch_size_for_spatialtransform'] = patch_size_for_spatialtransform
        self.data_aug_params["num_cached_per_thread"] = 2

    def maybe_update_lr(self, epoch=None):
        """
        if epoch is not None we overwrite epoch. Else we use epoch = self.epoch + 1
        (maybe_update_lr is called in on_epoch_end which is called before epoch is incremented.
        Therefore we need to do +1 here)
        :param epoch: optional explicit epoch index
        :return:
        """
        if epoch is None:
            ep = self.epoch + 1
        else:
            ep = epoch
        # Polynomial decay (exponent 0.9) from initial_lr to 0 over max_num_epochs.
        self.optimizer.param_groups[0]['lr'] = poly_lr(ep, self.max_num_epochs, self.initial_lr, 0.9)
        self.print_to_log_file("lr:", np.round(self.optimizer.param_groups[0]['lr'], decimals=6))

    def on_epoch_end(self):
        """
        overwrite patient-based early stopping. Always run to 1000 epochs
        :return: True while training should continue (epoch < max_num_epochs)
        """
        super().on_epoch_end()
        continue_training = self.epoch < self.max_num_epochs

        # Rescue mechanism: if the estimated foreground Dice is still 0 at epoch 100,
        # momentum 0.99 is likely too high for this dataset — lower it and restart the
        # weights once.
        if self.epoch == 100:
            if self.all_val_eval_metrics[-1] == 0:
                self.optimizer.param_groups[0]["momentum"] = 0.95
                self.network.apply(InitWeights_He(1e-2))
                self.print_to_log_file("At epoch 100, the mean foreground Dice was 0. This can be caused by a too "
                                       "high momentum. High momentum (0.99) is good for datasets where it works, but "
                                       "sometimes causes issues such as this one. Momentum has now been reduced to "
                                       "0.95 and network weights have been reinitialized")
        return continue_training

    def run_training(self):
        """
        if we run with -c then we need to set the correct lr for the first epoch, otherwise it will run the first
        continued epoch with self.initial_lr

        we also need to make sure deep supervision in the network is enabled for training, thus the wrapper
        :return: result of the parent trainer's run_training
        """
        self.maybe_update_lr(self.epoch)  # if we dont overwrite epoch then self.epoch+1 is used which is not what we
        # want at the start of the training
        ds = self.network.do_ds  # saved so prediction-time state is restored afterwards
        if self.deep_supervision:
            self.network.do_ds = True
        else:
            self.network.do_ds = False
        ret = super().run_training()
        self.network.do_ds = ds
        return ret
class Ranger(Optimizer):
    """
    Ranger = RAdam (rectified Adam) + LookAhead.

    Taken from https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer — full credit
    goes to the original authors.

    FIX: the original used deprecated positional alpha/value overloads
    (``add_(Number, Tensor)``, ``addcmul_(Number, Tensor, Tensor)``,
    ``addcdiv_(Number, Tensor, Tensor)``) that were removed from PyTorch; they are
    rewritten with the keyword forms (``alpha=``/``value=``), which are numerically
    identical.
    """

    def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95, 0.999), eps=1e-5,
                 weight_decay=0):
        """
        :param params: iterable of parameters or param groups
        :param lr: learning rate
        :param alpha: LookAhead slow-weights interpolation factor
        :param k: LookAhead update frequency (every k steps)
        :param N_sma_threshhold: RAdam variance-rectification threshold (5 works better than 4)
        :param betas: Adam betas; beta1=.95 seems to work better than .90
        :param eps: denominator epsilon
        :param weight_decay: L2 weight decay factor
        """
        # parameter checks
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        if not lr > 0:
            raise ValueError(f'Invalid Learning Rate: {lr}')
        if not eps > 0:
            raise ValueError(f'Invalid eps: {eps}')

        # prep defaults and init torch.optim base
        defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold,
                        eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

        # adjustable threshold
        self.N_sma_threshhold = N_sma_threshhold

        # look ahead params
        self.alpha = alpha
        self.k = k

        # radam buffer for state (cycled by step % 10 so equal-step params share it)
        self.radam_buffer = [[None, None, None] for ind in range(10)]

    def __setstate__(self, state):
        print("set state called")
        super(Ranger, self).__setstate__(state)

    def step(self, closure=None):
        """
        Perform a single optimization step.

        :param closure: ignored (kept for torch.optim API compatibility)
        :return: None (loss is never computed here)
        """
        loss = None

        # Evaluate averages and grad, update param tensors
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Ranger optimizer does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]  # get state dict for this param

                if len(state) == 0:
                    # first time for this param: initialize moments + LookAhead buffer
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                    # look ahead weight storage now in state dict
                    state['slow_buffer'] = torch.empty_like(p.data)
                    state['slow_buffer'].copy_(p.data)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                # begin computations
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # compute variance mov avg (modern keyword signature)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # compute mean moving avg
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1

                # RAdam: compute (or reuse) the rectification term for this step count
                buffered = self.radam_buffer[int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma
                    if N_sma > self.N_sma_threshhold:
                        # variance is tractable -> rectified adaptive step
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                    N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        # fall back to un-adapted (SGD-with-momentum-like) step
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                if N_sma > self.N_sma_threshhold:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])

                p.data.copy_(p_data_fp32)

                # integrated look ahead...
                # applied at the param level instead of group level
                if state['step'] % group['k'] == 0:
                    slow_p = state['slow_buffer']  # get access to slow param tensor
                    slow_p.add_(p.data - slow_p, alpha=self.alpha)  # (fast weights - slow weights) * alpha
                    p.data.copy_(slow_p)  # copy interpolated weights to RAdam param tensor

        return loss
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import distributed
from torch import autograd
from torch.nn.parallel import DistributedDataParallel as DDP


def print_if_rank0(*args):
    # Print only on the rank-0 worker so multi-process runs don't duplicate output.
    if distributed.get_rank() == 0:
        print(*args)


class awesome_allgather_function(autograd.Function):
    """Differentiable all_gather: forward gathers and concatenates tensors from all
    ranks along dim 0; backward returns only the gradient slice belonging to this
    rank (mimicking DataParallel's scatter of gradients)."""

    @staticmethod
    def forward(ctx, input):
        world_size = distributed.get_world_size()
        # create a destination list for the allgather. I'm assuming you're gathering from 3 workers.
        allgather_list = [torch.empty_like(input) for _ in range(world_size)]
        #if distributed.get_rank() == 0:
        #    import IPython;IPython.embed()
        distributed.all_gather(allgather_list, input)
        # Concatenate along the batch dimension so downstream code sees one big batch.
        return torch.cat(allgather_list, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        #print_if_rank0("backward grad_output len", len(grad_output))
        #print_if_rank0("backward grad_output shape", grad_output.shape)
        # NOTE(review): integer division assumes all ranks contributed equally-sized
        # inputs in forward; with unequal batch sizes the slicing below would be
        # misaligned — TODO confirm callers guarantee equal shapes.
        grads_per_rank = grad_output.shape[0] // distributed.get_world_size()
        rank = distributed.get_rank()
        # We'll receive gradients for the entire catted forward output, so to mimic DataParallel,
        # return only the slice that corresponds to this process's input:
        sl = slice(rank * grads_per_rank, (rank + 1) * grads_per_rank)
        #print("worker", rank, "backward slice", sl)
        return grad_output[sl]


if __name__ == "__main__":
    # Manual smoke test: run with torch.distributed.launch / torchrun (NCCL backend).
    import torch.distributed as dist
    import argparse
    from torch import nn
    from torch.optim import Adam

    argumentparser = argparse.ArgumentParser()
    argumentparser.add_argument("--local_rank", type=int)
    args = argumentparser.parse_args()

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')

    rnd = torch.rand((5, 2)).cuda()

    rnd_gathered = awesome_allgather_function.apply(rnd)
    print("gathering random tensors\nbefore\b", rnd, "\nafter\n", rnd_gathered)

    # so far this works as expected
    print("now running a DDP model")
    c = nn.Conv2d(2, 3, 3, 1, 1, 1, 1, True).cuda()
    c = DDP(c)
    opt = Adam(c.parameters())

    # deliberately unequal batch sizes across ranks to exercise the gather
    bs = 5
    if dist.get_rank() == 0:
        bs = 4
    inp = torch.rand((bs, 2, 5, 5)).cuda()

    out = c(inp)
    print("output_shape", out.shape)

    out_gathered = awesome_allgather_function.apply(out)
    print("output_shape_after_gather", out_gathered.shape)
    # this also works
    loss = out_gathered.sum()
    loss.backward()
    opt.step()


================================================
FILE: unetr_pp/utilities/file_conversions.py
================================================
from typing import Tuple, List, Union
from skimage import io
import SimpleITK as sitk
import numpy as np
import tifffile


def convert_2d_image_to_nifti(input_filename: str, output_filename_truncated: str, spacing=(999, 1, 1),
                              transform=None, is_seg: bool = False) -> None:
    """
    Reads an image (must be a format that it recognized by skimage.io.imread) and converts it into a series of niftis.
    The image can have an arbitrary number of input channels which will be exported separately (_0000.nii.gz,
    _0001.nii.gz, etc for images and only .nii.gz for seg).
    Spacing can be ignored most of the time.
    !!!2D images are often natural images which do not have a voxel spacing that could be used for resampling. These images must be resampled by you prior to converting them to nifti!!!

    Datasets converted with this utility can only be used with the 2d U-Net configuration of nnU-Net

    If Transform is not None it will be applied to the image after loading.

    Segmentations will be converted to np.uint32!

    :param is_seg: whether the input is a segmentation (single channel, written without modality suffix)
    :param transform: optional callable applied to the loaded array before export
    :param input_filename: path to an image readable by skimage.io.imread
    :param output_filename_truncated: do not use a file ending for this one! Example: output_name='./converted/image1'.
    This function will add the suffix (_0000) and file ending (.nii.gz) for you.
    :param spacing: voxel spacing written into the nifti (reversed to match sitk's x,y,z order)
    :return:
    """
    img = io.imread(input_filename)

    if transform is not None:
        img = transform(img)

    if len(img.shape) == 2:  # 2d image with no color channels
        img = img[None, None]  # add dimensions
    else:
        assert len(img.shape) == 3, "image should be 3d with color channel last but has shape %s" % str(img.shape)
        # we assume that the color channel is the last dimension. Transpose it to be in first
        img = img.transpose((2, 0, 1))
        # add third dimension
        img = img[:, None]

    # image is now (c, x, x, z) where x=1 since it's 2d
    if is_seg:
        assert img.shape[0] == 1, 'segmentations can only have one color channel, not sure what happened here'

    for j, i in enumerate(img):
        if is_seg:
            i = i.astype(np.uint32)

        itk_img = sitk.GetImageFromArray(i)
        # sitk expects spacing in (x, y, z) order, hence the reversal
        itk_img.SetSpacing(list(spacing)[::-1])
        if not is_seg:
            sitk.WriteImage(itk_img, output_filename_truncated + "_%04.0d.nii.gz" % j)
        else:
            sitk.WriteImage(itk_img, output_filename_truncated + ".nii.gz")


def convert_3d_tiff_to_nifti(filenames: List[str], output_name: str, spacing: Union[tuple, list], transform=None,
                             is_seg=False) -> None:
    """
    filenames must be a list of strings, each pointing to a separate 3d tiff file. One file per modality. If your data
    only has one imaging modality, simply pass a list with only a single entry

    Files in filenames must be readable with

    Note: we always only pass one file into tifffile.imread, not multiple (even though it supports it). This is
    because I am not familiar enough with this functionality and would like to have control over what happens.

    If Transform is not None it will be applied to the image after loading.

    :param transform: optional callable applied to each loaded volume before export
    :param filenames: one 3d tiff per modality (exactly one entry when is_seg=True)
    :param output_name: output path without modality suffix or file ending
    :param spacing: voxel spacing written into the nifti (reversed to match sitk's x,y,z order)
    :return:
    """
    if is_seg:
        # a segmentation has exactly one "modality"
        assert len(filenames) == 1

    for j, i in enumerate(filenames):
        img = tifffile.imread(i)

        if transform is not None:
            img = transform(img)

        itk_img = sitk.GetImageFromArray(img)
        itk_img.SetSpacing(list(spacing)[::-1])

        if not is_seg:
            sitk.WriteImage(itk_img, output_name + "_%04.0d.nii.gz" % j)
        else:
            sitk.WriteImage(itk_img, output_name + ".nii.gz")


def convert_2d_segmentation_nifti_to_img(nifti_file: str, output_filename: str, transform=None, export_dtype=np.uint8):
    # Export a single-slice (pseudo-2D) segmentation nifti back to a plain image file.
    img = sitk.GetArrayFromImage(sitk.ReadImage(nifti_file))
    assert img.shape[0] == 1, "This function can only export 2D segmentations!"
    img = img[0]
    if transform is not None:
        img = transform(img)

    # check_contrast=False: label maps are expected to be low-contrast; suppress the warning
    io.imsave(output_filename, img.astype(export_dtype), check_contrast=False)


def convert_3d_segmentation_nifti_to_tiff(nifti_file: str, output_filename: str, transform=None, export_dtype=np.uint8):
    # Export a 3D segmentation nifti as a 3D tiff stack.
    img = sitk.GetArrayFromImage(sitk.ReadImage(nifti_file))
    assert len(img.shape) == 3, "This function can only export 3D segmentations!"
    if transform is not None:
        img = transform(img)

    tifffile.imsave(output_filename, img.astype(export_dtype))


================================================
FILE: unetr_pp/utilities/file_endings.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from batchgenerators.utilities.file_and_folder_operations import * def remove_trailing_slash(filename: str): while filename.endswith('/'): filename = filename[:-1] return filename def maybe_add_0000_to_all_niigz(folder): nii_gz = subfiles(folder, suffix='.nii.gz') for n in nii_gz: n = remove_trailing_slash(n) if not n.endswith('_0000.nii.gz'): os.rename(n, n[:-7] + '_0000.nii.gz') ================================================ FILE: unetr_pp/utilities/folder_names.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from batchgenerators.utilities.file_and_folder_operations import * from unetr_pp.paths import network_training_output_dir def get_output_folder_name(model: str, task: str = None, trainer: str = None, plans: str = None, fold: int = None, overwrite_training_output_dir: str = None): """ Retrieves the correct output directory for the nnU-Net model described by the input parameters :param model: :param task: :param trainer: :param plans: :param fold: :param overwrite_training_output_dir: :return: """ assert model in ["2d", "3d_cascade_fullres", '3d_fullres', '3d_lowres'] if overwrite_training_output_dir is not None: tr_dir = overwrite_training_output_dir else: tr_dir = network_training_output_dir current = join(tr_dir, model) if task is not None: current = join(current, task) if trainer is not None and plans is not None: current = join(current, trainer + "__" + plans) if fold is not None: current = join(current, "fold_%d" % fold) return current ================================================ FILE: unetr_pp/utilities/nd_softmax.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import torch from torch import nn import torch.nn.functional as F softmax_helper = lambda x: F.softmax(x, 1) ================================================ FILE: unetr_pp/utilities/one_hot_encoding.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np def to_one_hot(seg, all_seg_labels=None): if all_seg_labels is None: all_seg_labels = np.unique(seg) result = np.zeros((len(all_seg_labels), *seg.shape), dtype=seg.dtype) for i, l in enumerate(all_seg_labels): result[i][seg == l] = 1 return result ================================================ FILE: unetr_pp/utilities/overlay_plots.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from multiprocessing.pool import Pool import numpy as np import SimpleITK as sitk from unetr_pp.utilities.task_name_id_conversion import convert_task_name_to_id, convert_id_to_task_name from batchgenerators.utilities.file_and_folder_operations import * from unetr_pp.paths import * color_cycle = ( "000000", "4363d8", "f58231", "3cb44b", "e6194B", "911eb4", "ffe119", "bfef45", "42d4f4", "f032e6", "000075", "9A6324", "808000", "800000", "469990", ) def hex_to_rgb(hex: str): assert len(hex) == 6 return tuple(int(hex[i:i + 2], 16) for i in (0, 2, 4)) def generate_overlay(input_image: np.ndarray, segmentation: np.ndarray, mapping: dict = None, color_cycle=color_cycle, overlay_intensity=0.6): """ image must be a color image, so last dimension must be 3. if image is grayscale, tile it first! Segmentation must be label map of same shape as image (w/o color channels) mapping can be label_id -> idx_in_cycle or None returned image is scaled to [0, 255]!!! """ # assert len(image.shape) == len(segmentation.shape) # assert all([i == j for i, j in zip(image.shape, segmentation.shape)]) # create a copy of image image = np.copy(input_image) if len(image.shape) == 2: image = np.tile(image[:, :, None], (1, 1, 3)) elif len(image.shape) == 3: assert image.shape[2] == 3, 'if 3d image is given the last dimension must be the color channels ' \ '(3 channels). Only 2D images are supported' else: raise RuntimeError("unexpected image shape. 
only 2D images and 2D images with color channels (color in " "last dimension) are supported") # rescale image to [0, 255] image = image - image.min() image = image / image.max() * 255 # create output if mapping is None: uniques = np.unique(segmentation) mapping = {i: c for c, i in enumerate(uniques)} for l in mapping.keys(): image[segmentation == l] += overlay_intensity * np.array(hex_to_rgb(color_cycle[mapping[l]])) # rescale result to [0, 255] image = image / image.max() * 255 return image.astype(np.uint8) def plot_overlay(image_file: str, segmentation_file: str, output_file: str, overlay_intensity: float = 0.6): import matplotlib.pyplot as plt image = sitk.GetArrayFromImage(sitk.ReadImage(image_file)) seg = sitk.GetArrayFromImage(sitk.ReadImage(segmentation_file)) assert all([i == j for i, j in zip(image.shape, seg.shape)]), "image and seg do not have the same shape: %s, %s" % ( image_file, segmentation_file) assert len(image.shape) == 3, 'only 3D images/segs are supported' fg_mask = seg != 0 fg_per_slice = fg_mask.sum((1, 2)) selected_slice = np.argmax(fg_per_slice) overlay = generate_overlay(image[selected_slice], seg[selected_slice], overlay_intensity=overlay_intensity) plt.imsave(output_file, overlay) def plot_overlay_preprocessed(case_file: str, output_file: str, overlay_intensity: float = 0.6, modality_index=0): import matplotlib.pyplot as plt data = np.load(case_file)['data'] assert modality_index < (data.shape[0] - 1), 'This dataset only supports modality index up to %d' % (data.shape[0] - 2) image = data[modality_index] seg = data[-1] seg[seg < 0] = 0 fg_mask = seg > 0 fg_per_slice = fg_mask.sum((1, 2)) selected_slice = np.argmax(fg_per_slice) overlay = generate_overlay(image[selected_slice], seg[selected_slice], overlay_intensity=overlay_intensity) plt.imsave(output_file, overlay) def multiprocessing_plot_overlay(list_of_image_files, list_of_seg_files, list_of_output_files, overlay_intensity, num_processes=8): p = Pool(num_processes) r = 
p.starmap_async(plot_overlay, zip( list_of_image_files, list_of_seg_files, list_of_output_files, [overlay_intensity] * len(list_of_output_files) )) r.get() p.close() p.join() def multiprocessing_plot_overlay_preprocessed(list_of_case_files, list_of_output_files, overlay_intensity, num_processes=8, modality_index=0): p = Pool(num_processes) r = p.starmap_async(plot_overlay_preprocessed, zip( list_of_case_files, list_of_output_files, [overlay_intensity] * len(list_of_output_files), [modality_index] * len(list_of_output_files) )) r.get() p.close() p.join() def generate_overlays_for_task(task_name_or_id, output_folder, num_processes=8, modality_idx=0, use_preprocessed=True, data_identifier=default_data_identifier): if isinstance(task_name_or_id, str): if not task_name_or_id.startswith("Task"): task_name_or_id = int(task_name_or_id) task_name = convert_id_to_task_name(task_name_or_id) else: task_name = task_name_or_id else: task_name = convert_id_to_task_name(int(task_name_or_id)) if not use_preprocessed: folder = join(nnFormer_raw_data, task_name) identifiers = [i[:-7] for i in subfiles(join(folder, 'labelsTr'), suffix='.nii.gz', join=False)] image_files = [join(folder, 'imagesTr', i + "_%04.0d.nii.gz" % modality_idx) for i in identifiers] seg_files = [join(folder, 'labelsTr', i + ".nii.gz") for i in identifiers] assert all([isfile(i) for i in image_files]) assert all([isfile(i) for i in seg_files]) maybe_mkdir_p(output_folder) output_files = [join(output_folder, i + '.png') for i in identifiers] multiprocessing_plot_overlay(image_files, seg_files, output_files, 0.6, num_processes) else: folder = join(preprocessing_output_dir, task_name) if not isdir(folder): raise RuntimeError("run preprocessing for that task first") matching_folders = subdirs(folder, prefix=data_identifier + "_stage") if len(matching_folders) == 0: "run preprocessing for that task first (use default experiment planner!)" matching_folders.sort() folder = matching_folders[-1] identifiers = [i[:-4] for 
i in subfiles(folder, suffix='.npz', join=False)] maybe_mkdir_p(output_folder) output_files = [join(output_folder, i + '.png') for i in identifiers] image_files = [join(folder, i + ".npz") for i in identifiers] maybe_mkdir_p(output_folder) multiprocessing_plot_overlay_preprocessed(image_files, output_files, overlay_intensity=0.6, num_processes=num_processes, modality_index=modality_idx) def entry_point_generate_overlay(): import argparse parser = argparse.ArgumentParser("Plots png overlays of the slice with the most foreground. Note that this " "disregards spacing information!") parser.add_argument('-t', type=str, help="task name or task ID", required=True) parser.add_argument('-o', type=str, help="output folder", required=True) parser.add_argument('-num_processes', type=int, default=8, required=False, help="number of processes used. Default: 8") parser.add_argument('-modality_idx', type=int, default=0, required=False, help="modality index used (0 = _0000.nii.gz). Default: 0") parser.add_argument('--use_raw', action='store_true', required=False, help="if set then we use raw data. else " "we use preprocessed") args = parser.parse_args() generate_overlays_for_task(args.t, args.o, args.num_processes, args.modality_idx, use_preprocessed=not args.use_raw) ================================================ FILE: unetr_pp/utilities/random_stuff.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. class no_op(object): def __enter__(self): pass def __exit__(self, *args): pass ================================================ FILE: unetr_pp/utilities/recursive_delete_npz.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from batchgenerators.utilities.file_and_folder_operations import * import argparse import os def recursive_delete_npz(current_directory: str): npz_files = subfiles(current_directory, join=True, suffix=".npz") npz_files = [i for i in npz_files if not i.endswith("segFromPrevStage.npz")] # to be extra safe _ = [os.remove(i) for i in npz_files] for d in subdirs(current_directory, join=False): if d != "pred_next_stage": recursive_delete_npz(join(current_directory, d)) if __name__ == "__main__": parser = argparse.ArgumentParser(usage="USE THIS RESPONSIBLY! DANGEROUS! 
I (Fabian) use this to remove npz files " "after I ran figure_out_what_to_submit") parser.add_argument("-f", help="folder", required=True) args = parser.parse_args() recursive_delete_npz(args.f) ================================================ FILE: unetr_pp/utilities/recursive_rename_taskXX_to_taskXXX.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from batchgenerators.utilities.file_and_folder_operations import * import os def recursive_rename(folder): s = subdirs(folder, join=False) for ss in s: if ss.startswith("Task") and ss.find("_") == 6: task_id = int(ss[4:6]) name = ss[7:] os.rename(join(folder, ss), join(folder, "Task%03.0d_" % task_id + name)) s = subdirs(folder, join=True) for ss in s: recursive_rename(ss) if __name__ == "__main__": recursive_rename("/media/fabian/Results/nnFormer") recursive_rename("/media/fabian/unetr_pp") recursive_rename("/media/fabian/My Book/MedicalDecathlon") recursive_rename("/home/fabian/drives/datasets/nnFormer_raw") recursive_rename("/home/fabian/drives/datasets/nnFormer_preprocessed") recursive_rename("/home/fabian/drives/datasets/nnFormer_testSets") recursive_rename("/home/fabian/drives/datasets/results/nnFormer") recursive_rename("/home/fabian/drives/e230-dgx2-1-data_fabian/Decathlon_raw") recursive_rename("/home/fabian/drives/e230-dgx2-1-data_fabian/nnFormer_preprocessed") ================================================ FILE: unetr_pp/utilities/sitk_stuff.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import SimpleITK as sitk def copy_geometry(image: sitk.Image, ref: sitk.Image): image.SetOrigin(ref.GetOrigin()) image.SetDirection(ref.GetDirection()) image.SetSpacing(ref.GetSpacing()) return image ================================================ FILE: unetr_pp/utilities/task_name_id_conversion.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from unetr_pp.paths import nnFormer_raw_data, preprocessing_output_dir, nnFormer_cropped_data, network_training_output_dir
from batchgenerators.utilities.file_and_folder_operations import *
import numpy as np


def convert_id_to_task_name(task_id: int):
    """Find the unique 'TaskXXX_name' folder name for a numeric task id by
    searching the preprocessed, raw, cropped and trained-model directories.
    Raises RuntimeError if zero or more than one matching name is found."""
    startswith = "Task%03.0d" % task_id
    # each candidate source may be unset (None) depending on the user's env setup
    if preprocessing_output_dir is not None:
        candidates_preprocessed = subdirs(preprocessing_output_dir, prefix=startswith, join=False)
    else:
        candidates_preprocessed = []

    if nnFormer_raw_data is not None:
        candidates_raw = subdirs(nnFormer_raw_data, prefix=startswith, join=False)
    else:
        candidates_raw = []

    if nnFormer_cropped_data is not None:
        candidates_cropped = subdirs(nnFormer_cropped_data, prefix=startswith, join=False)
    else:
        candidates_cropped = []

    candidates_trained_models = []
    if network_training_output_dir is not None:
        # trained models are organized per configuration, one subtree each
        for m in ['2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres']:
            if isdir(join(network_training_output_dir, m)):
                candidates_trained_models += subdirs(join(network_training_output_dir, m), prefix=startswith, join=False)

    all_candidates = candidates_cropped + candidates_preprocessed + candidates_raw + candidates_trained_models
    # np.unique collapses the same task name found in several locations
    unique_candidates = np.unique(all_candidates)
    if len(unique_candidates) > 1:
        raise RuntimeError("More than one task name found for task id %d. Please correct that. (I looked in the "
                           "following folders:\n%s\n%s\n%s" % (task_id, nnFormer_raw_data, preprocessing_output_dir,
                                                               nnFormer_cropped_data))
    if len(unique_candidates) == 0:
        raise RuntimeError("Could not find a task with the ID %d. Make sure the requested task ID exists and that "
                           "nnU-Net knows where raw and preprocessed data are located (see Documentation - "
                           "Installation). Here are your currently defined folders:\nunetr_pp_preprocessed=%s\nRESULTS_"
                           "FOLDER=%s\nunetr_pp_raw_data_base=%s\nIf something is not right, adapt your environemnt "
                           "variables." %
                           (task_id,
                            os.environ.get('unetr_pp_preprocessed') if os.environ.get('unetr_pp_preprocessed') is not None else 'None',
                            os.environ.get('RESULTS_FOLDER') if os.environ.get('RESULTS_FOLDER') is not None else 'None',
                            os.environ.get('unetr_pp_raw_data_base') if os.environ.get('unetr_pp_raw_data_base') is not None else 'None',
                            ))
    return unique_candidates[0]


def convert_task_name_to_id(task_name: str):
    # Extract the 3-digit numeric id from a 'TaskXXX_name' string.
    assert task_name.startswith("Task")
    task_id = int(task_name[4:7])
    return task_id


================================================
FILE: unetr_pp/utilities/tensor_utilities.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch import nn


def sum_tensor(inp, axes, keepdim=False):
    """Sum a torch tensor over multiple axes.

    When keepdim is False, axes are reduced from the highest index down so the
    remaining axis indices stay valid while reducing.
    """
    axes = np.unique(axes).astype(int)
    if keepdim:
        for ax in axes:
            inp = inp.sum(int(ax), keepdim=True)
    else:
        for ax in sorted(axes, reverse=True):
            inp = inp.sum(int(ax))
    return inp


def mean_tensor(inp, axes, keepdim=False):
    """Mean of a torch tensor over multiple axes; same reduction-order logic as sum_tensor."""
    axes = np.unique(axes).astype(int)
    if keepdim:
        for ax in axes:
            inp = inp.mean(int(ax), keepdim=True)
    else:
        for ax in sorted(axes, reverse=True):
            inp = inp.mean(int(ax))
    return inp


def flip(x, dim):
    """
    flips the tensor at dimension dim (mirroring!)
    :param x: input tensor
    :param dim: dimension to flip
    :return: flipped tensor (a view-style advanced-indexing copy)
    """
    indices = [slice(None)] * x.dim()
    # reversed index along dim, built on the same device as x
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
                                dtype=torch.long, device=x.device)
    return x[tuple(indices)]


================================================
FILE: unetr_pp/utilities/to_torch.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def maybe_to_torch(d):
    # Convert numpy arrays (or lists thereof, recursively) to float torch tensors;
    # tensors are passed through unchanged.
    if isinstance(d, list):
        d = [maybe_to_torch(i) if not isinstance(i, torch.Tensor) else i for i in d]
    elif not isinstance(d, torch.Tensor):
        d = torch.from_numpy(d).float()
    return d


def to_cuda(data, non_blocking=True, gpu_id=0):
    # Move a tensor (or list of tensors) to the given GPU.
    if isinstance(data, list):
        data = [i.cuda(gpu_id, non_blocking=non_blocking) for i in data]
    else:
        data = data.cuda(gpu_id, non_blocking=non_blocking)
    return data