Repository: MIC-DKFZ/nnUNet Branch: master Commit: 4e8361bd9f2a Files: 237 Total size: 1.1 MB Directory structure: gitextract_vkpapda3/ ├── .github/ │ └── workflows/ │ ├── codespell.yml │ └── run_tests_nnunet.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── documentation/ │ ├── __init__.py │ ├── benchmarking.md │ ├── changelog.md │ ├── competitions/ │ │ ├── AortaSeg24.md │ │ ├── AutoPETII.md │ │ ├── FLARE24/ │ │ │ ├── Task_1/ │ │ │ │ ├── __init__.py │ │ │ │ ├── inference_flare_task1.py │ │ │ │ └── readme.md │ │ │ ├── Task_2/ │ │ │ │ ├── __init__.py │ │ │ │ ├── inference_flare_task2.py │ │ │ │ └── readme.md │ │ │ └── __init__.py │ │ ├── Toothfairy2/ │ │ │ ├── __init__.py │ │ │ ├── inference_script_semseg_only_customInf2.py │ │ │ └── readme.md │ │ └── __init__.py │ ├── convert_msd_dataset.md │ ├── dataset_format.md │ ├── dataset_format_inference.md │ ├── explanation_logging.md │ ├── explanation_normalization.md │ ├── explanation_plans_files.md │ ├── extending_nnunet.md │ ├── how_to_use_nnunet.md │ ├── ignore_label.md │ ├── installation_instructions.md │ ├── manual_data_splits.md │ ├── pretraining_and_finetuning.md │ ├── region_based_training.md │ ├── resenc_presets.md │ ├── run_inference_with_pretrained_models.md │ ├── set_environment_variables.md │ ├── setting_up_paths.md │ └── tldr_migration_guide_from_v1.md ├── nnunetv2/ │ ├── __init__.py │ ├── batch_running/ │ │ ├── __init__.py │ │ ├── benchmarking/ │ │ │ ├── __init__.py │ │ │ ├── generate_benchmarking_commands.py │ │ │ └── summarize_benchmark_results.py │ │ ├── collect_results_custom_Decathlon.py │ │ ├── collect_results_custom_Decathlon_2d.py │ │ ├── generate_lsf_runs_customDecathlon.py │ │ ├── jobs.sh │ │ └── release_trainings/ │ │ ├── __init__.py │ │ └── nnunetv2_v1/ │ │ ├── __init__.py │ │ ├── collect_results.py │ │ └── generate_lsf_commands.py │ ├── configuration.py │ ├── dataset_conversion/ │ │ ├── Dataset015_018_RibFrac_RibSeg.py │ │ ├── Dataset021_CTAAorta.py │ │ ├── 
Dataset023_AbdomenAtlas1_1Mini.py │ │ ├── Dataset027_ACDC.py │ │ ├── Dataset042_BraTS18.py │ │ ├── Dataset043_BraTS19.py │ │ ├── Dataset073_Fluo_C3DH_A549_SIM.py │ │ ├── Dataset114_MNMs.py │ │ ├── Dataset115_EMIDEC.py │ │ ├── Dataset119_ToothFairy2_All.py │ │ ├── Dataset120_RoadSegmentation.py │ │ ├── Dataset137_BraTS21.py │ │ ├── Dataset218_Amos2022_task1.py │ │ ├── Dataset219_Amos2022_task2.py │ │ ├── Dataset220_KiTS2023.py │ │ ├── Dataset221_AutoPETII_2023.py │ │ ├── Dataset223_AMOS2022postChallenge.py │ │ ├── Dataset224_AbdomenAtlas1.0.py │ │ ├── Dataset226_BraTS2024-BraTS-GLI.py │ │ ├── Dataset227_TotalSegmentatorMRI.py │ │ ├── Dataset987_dummyDataset4.py │ │ ├── Dataset989_dummyDataset4_2.py │ │ ├── __init__.py │ │ ├── convert_MSD_dataset.py │ │ ├── convert_raw_dataset_from_old_nnunet_format.py │ │ ├── datasets_for_integration_tests/ │ │ │ ├── Dataset996_IntegrationTest_Hippocampus_regions_ignore.py │ │ │ ├── Dataset997_IntegrationTest_Hippocampus_regions.py │ │ │ ├── Dataset998_IntegrationTest_Hippocampus_ignore.py │ │ │ ├── Dataset999_IntegrationTest_Hippocampus.py │ │ │ └── __init__.py │ │ └── generate_dataset_json.py │ ├── ensembling/ │ │ ├── __init__.py │ │ └── ensemble.py │ ├── evaluation/ │ │ ├── __init__.py │ │ ├── accumulate_cv_results.py │ │ ├── evaluate_predictions.py │ │ └── find_best_configuration.py │ ├── experiment_planning/ │ │ ├── __init__.py │ │ ├── dataset_fingerprint/ │ │ │ ├── __init__.py │ │ │ └── fingerprint_extractor.py │ │ ├── experiment_planners/ │ │ │ ├── __init__.py │ │ │ ├── default_experiment_planner.py │ │ │ ├── network_topology.py │ │ │ ├── resampling/ │ │ │ │ ├── __init__.py │ │ │ │ ├── planners_no_resampling.py │ │ │ │ └── resample_with_torch.py │ │ │ ├── resencUNet_planner.py │ │ │ └── residual_unets/ │ │ │ ├── __init__.py │ │ │ └── residual_encoder_unet_planners.py │ │ ├── plan_and_preprocess_api.py │ │ ├── plan_and_preprocess_entrypoints.py │ │ ├── plans_for_pretraining/ │ │ │ ├── __init__.py │ │ │ └── 
move_plans_between_datasets.py │ │ └── verify_dataset_integrity.py │ ├── imageio/ │ │ ├── __init__.py │ │ ├── base_reader_writer.py │ │ ├── natural_image_reader_writer.py │ │ ├── nibabel_reader_writer.py │ │ ├── reader_writer_registry.py │ │ ├── readme.md │ │ ├── simpleitk_reader_writer.py │ │ └── tif_reader_writer.py │ ├── inference/ │ │ ├── JHU_inference.py │ │ ├── __init__.py │ │ ├── data_iterators.py │ │ ├── examples.py │ │ ├── export_prediction.py │ │ ├── predict_from_raw_data.py │ │ ├── readme.md │ │ └── sliding_window_prediction.py │ ├── model_sharing/ │ │ ├── __init__.py │ │ ├── entry_points.py │ │ ├── model_download.py │ │ ├── model_export.py │ │ └── model_import.py │ ├── paths.py │ ├── postprocessing/ │ │ ├── __init__.py │ │ └── remove_connected_components.py │ ├── preprocessing/ │ │ ├── __init__.py │ │ ├── cropping/ │ │ │ ├── __init__.py │ │ │ └── cropping.py │ │ ├── normalization/ │ │ │ ├── __init__.py │ │ │ ├── default_normalization_schemes.py │ │ │ ├── map_channel_name_to_normalization.py │ │ │ └── readme.md │ │ ├── preprocessors/ │ │ │ ├── __init__.py │ │ │ └── default_preprocessor.py │ │ └── resampling/ │ │ ├── __init__.py │ │ ├── default_resampling.py │ │ ├── no_resampling.py │ │ ├── resample_torch.py │ │ └── utils.py │ ├── run/ │ │ ├── __init__.py │ │ ├── load_pretrained_weights.py │ │ └── run_training.py │ ├── tests/ │ │ ├── __init__.py │ │ └── integration_tests/ │ │ ├── __init__.py │ │ ├── add_lowres_and_cascade.py │ │ ├── cleanup_integration_test.py │ │ ├── lsf_commands.sh │ │ ├── prepare_integration_tests.sh │ │ ├── readme.md │ │ ├── run_integration_test.sh │ │ ├── run_integration_test_bestconfig_inference.py │ │ ├── run_integration_test_trainingOnly_DDP.sh │ │ └── run_nnunet_inference.py │ ├── training/ │ │ ├── __init__.py │ │ ├── data_augmentation/ │ │ │ ├── __init__.py │ │ │ ├── compute_initial_patch_size.py │ │ │ └── custom_transforms/ │ │ │ ├── __init__.py │ │ │ ├── cascade_transforms.py │ │ │ ├── deep_supervision_donwsampling.py │ │ │ 
├── masking.py │ │ │ ├── region_based_training.py │ │ │ └── transforms_for_dummy_2d.py │ │ ├── dataloading/ │ │ │ ├── __init__.py │ │ │ ├── data_loader.py │ │ │ ├── nnunet_dataset.py │ │ │ └── utils.py │ │ ├── logging/ │ │ │ ├── __init__.py │ │ │ └── nnunet_logger.py │ │ ├── loss/ │ │ │ ├── __init__.py │ │ │ ├── compound_losses.py │ │ │ ├── deep_supervision.py │ │ │ ├── dice.py │ │ │ └── robust_ce_loss.py │ │ ├── lr_scheduler/ │ │ │ ├── __init__.py │ │ │ ├── polylr.py │ │ │ └── warmup.py │ │ └── nnUNetTrainer/ │ │ ├── __init__.py │ │ ├── nnUNetTrainer.py │ │ ├── primus/ │ │ │ ├── __init__.py │ │ │ └── primus_trainers.py │ │ └── variants/ │ │ ├── __init__.py │ │ ├── benchmarking/ │ │ │ ├── __init__.py │ │ │ ├── nnUNetTrainerBenchmark_5epochs.py │ │ │ └── nnUNetTrainerBenchmark_5epochs_noDataLoading.py │ │ ├── competitions/ │ │ │ ├── __init__.py │ │ │ └── aortaseg24.py │ │ ├── data_augmentation/ │ │ │ ├── __init__.py │ │ │ ├── nnUNetTrainerDA5.py │ │ │ ├── nnUNetTrainerDAOrd0.py │ │ │ ├── nnUNetTrainerNoDA.py │ │ │ ├── nnUNetTrainerNoMirroring.py │ │ │ └── nnUNetTrainer_noDummy2DDA.py │ │ ├── loss/ │ │ │ ├── __init__.py │ │ │ ├── nnUNetTrainerCELoss.py │ │ │ ├── nnUNetTrainerDiceLoss.py │ │ │ └── nnUNetTrainerTopkLoss.py │ │ ├── lr_schedule/ │ │ │ ├── __init__.py │ │ │ ├── nnUNetTrainerCosAnneal.py │ │ │ └── nnUNetTrainer_warmup.py │ │ ├── network_architecture/ │ │ │ ├── __init__.py │ │ │ ├── nnUNetTrainerBN.py │ │ │ └── nnUNetTrainerNoDeepSupervision.py │ │ ├── optimizer/ │ │ │ ├── __init__.py │ │ │ ├── nnUNetTrainerAdam.py │ │ │ └── nnUNetTrainerAdan.py │ │ ├── sampling/ │ │ │ ├── __init__.py │ │ │ └── nnUNetTrainer_probabilisticOversampling.py │ │ └── training_length/ │ │ ├── __init__.py │ │ ├── nnUNetTrainer_Xepochs.py │ │ └── nnUNetTrainer_Xepochs_NoMirroring.py │ └── utilities/ │ ├── __init__.py │ ├── collate_outputs.py │ ├── crossval_split.py │ ├── dataset_name_id_conversion.py │ ├── ddp_allgather.py │ ├── default_n_proc_DA.py │ ├── file_path_utilities.py │ 
├── find_class_by_name.py │ ├── get_network_from_plans.py │ ├── helpers.py │ ├── json_export.py │ ├── label_handling/ │ │ ├── __init__.py │ │ └── label_handling.py │ ├── network_initialization.py │ ├── overlay_plots.py │ ├── plans_handling/ │ │ ├── __init__.py │ │ └── plans_handler.py │ └── utils.py ├── pyproject.toml ├── readme.md └── setup.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/codespell.yml ================================================ --- name: Codespell on: push: branches: [master] pull_request: branches: [master] permissions: contents: read jobs: codespell: name: Check for spelling errors runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v3 - name: Codespell uses: codespell-project/actions-codespell@v2 ================================================ FILE: .github/workflows/run_tests_nnunet.yml ================================================ name: Run nnunet tests on all OS on: push: paths-ignore: - '**.md' jobs: run-tests: strategy: matrix: # os: [ubuntu-latest, windows-latest, macos-latest] # fails on windows until https://github.com/MIC-DKFZ/nnUNet/issues/2396 is resolved os: [ubuntu-latest, macos-latest] python-version: ["3.10"] runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Download and extract weights and download example file # use python instead of curl to avoid the need to install curl (more difficult for macos) run: | mkdir -p $HOME/github_actions_nnunet/results python -c "import urllib.request; urllib.request.urlretrieve('https://github.com/wasserth/TotalSegmentator/releases/download/v2.0.0-weights/Dataset300_body_6mm_1559subj.zip', '$HOME/github_actions_nnunet/results/tmp_download_file.zip')" unzip -o 
$HOME/github_actions_nnunet/results/tmp_download_file.zip -d $HOME/github_actions_nnunet/results rm $HOME/github_actions_nnunet/results/tmp_download_file.zip - name: Install dependencies on Ubuntu if: runner.os == 'Linux' run: | python -m pip install --upgrade pip pip install pytest Cython pip install torch==2.4.0 -f https://download.pytorch.org/whl/cpu pip install . - name: Install dependencies on Windows / MacOS if: runner.os == 'Windows' || runner.os == 'macOS' run: | python -m pip install --upgrade pip pip install pytest Cython pip install torch==2.4.0 pip install . - name: Run test script run: python nnunetv2/tests/integration_tests/run_nnunet_inference.py ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python env/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ *.egg-info/ .installed.cfg *.egg # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *,cover .hypothesis/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # IPython Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # dotenv .env # virtualenv venv/ ENV/ # Spyder project settings .spyderproject # Rope project settings .ropeproject *.memmap *.png *.zip *.npz *.npy *.jpg *.jpeg .idea *.txt .idea/* *.png # *.nii.gz # nifti files needed for example_data for github actions tests *.nii *.tif *.bmp *.pkl *.xml *.pkl *.pdf *.png *.jpg *.jpeg *.model !documentation/assets/scribble_example.png CLAUDE.md AGENTS.md ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to nnU-Net Thank you for your interest in contributing to nnU-Net. nnU-Net is developed and maintained by researchers at DKFZ. There is no dedicated funding or staff for maintaining the repository, and development happens alongside research and teaching responsibilities. Our bandwidth for reviewing external contributions is therefore limited, and review times may be long. ## General principles nnU-Net is intentionally designed to be focused, stable, and generally applicable across datasets and use cases. Contributions should respect this philosophy and should not introduce unnecessary complexity or specialization. New functionality must either be generally valid across datasets and setups or convincingly benefit a large enough portion of the user base. We aim to avoid bloating the framework or increasing its complexity further. ## How to contribute For larger features and refactors, please open a GitHub issue to discuss the idea before starting work. 
Tag @FabianIsensee so that the discussion doesn't get missed. To submit a contribution, fork the repository, make your changes on a branch, and open a pull request. ## Bug reports and bug fixes Bug reports must include a minimal reproducible example. Without a repro, it is usually impossible for us to investigate issues. Pull requests fixing bugs should also include a clear reproduction of the issue and an explanation of how the fix resolves it. ## Performance improvements If a pull request claims performance improvements, it must include benchmarks demonstrating the effect. The benchmark setup must be described clearly enough for us to reproduce the results independently. We may run additional tests ourselves before merging. ## Contributions that are unlikely to be merged To keep the framework maintainable and the workload manageable on our end, we deprioritize: - dataset-specific code - features that only apply to niche setups - narrow custom architectures or training pipelines - large refactorings without prior discussion - small PRs fixing minor typos or formatting issues ## Final note We appreciate the effort people invest in improving nnU-Net! ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [2019] [Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

================================================
FILE: documentation/__init__.py
================================================

================================================
FILE: documentation/benchmarking.md
================================================
# nnU-Netv2 benchmarks

Does your system run like it should? Is your epoch time longer than expected? What epoch times should you expect?

Look no further for we have the solution here!

## What does the nnU-Netv2 benchmark do?

nnU-Net's benchmark trains models for 5 epochs. At the end, the fastest epoch will be noted down, along with the GPU name, torch version and cudnn version. You can find the benchmark output in the corresponding nnUNet_results subfolder (see example below). Don't worry, we also provide scripts to collect your results. Or you just start a benchmark and look at the console output. Everything is possible. Nothing is forbidden.

The benchmark implementation revolves around two trainers:
- `nnUNetTrainerBenchmark_5epochs` runs a regular training for 5 epochs. When completed, it writes a .json file with the fastest epoch time as well as the GPU used and the torch and cudnn versions.
Useful for speed testing the entire pipeline (data loading, augmentation, GPU training).
- `nnUNetTrainerBenchmark_5epochs_noDataLoading` is the same, but it doesn't do any data loading or augmentation. It just presents dummy arrays to the GPU. Useful for checking pure GPU speed.

## How to run the nnU-Netv2 benchmark?

It's quite simple, actually. It looks just like a regular nnU-Net training.

We provide reference numbers for some of the Medical Segmentation Decathlon datasets because they are easily accessible: [download here](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2). If it needs to be quick and dirty, focus on Tasks 2 and 4. Download and extract the data and convert them to the nnU-Net format with `nnUNetv2_convert_MSD_dataset`. Run `nnUNetv2_plan_and_preprocess` for them.

Then, for each dataset, run the following commands (only one per GPU! Or one after the other):

```bash
nnUNetv2_train DATASET_ID 2d 0 -tr nnUNetTrainerBenchmark_5epochs
nnUNetv2_train DATASET_ID 3d_fullres 0 -tr nnUNetTrainerBenchmark_5epochs
nnUNetv2_train DATASET_ID 2d 0 -tr nnUNetTrainerBenchmark_5epochs_noDataLoading
nnUNetv2_train DATASET_ID 3d_fullres 0 -tr nnUNetTrainerBenchmark_5epochs_noDataLoading
```

If you want to inspect the outcome manually, check (for example!) your `nnUNet_results/DATASET_NAME/nnUNetTrainerBenchmark_5epochs__nnUNetPlans__3d_fullres/fold_0/` folder for the `benchmark_result.json` file.

Note that there can be multiple entries in this file if the benchmark was run on different GPU types, torch versions or cudnn versions!

If you want to summarize your results like we did in our [results](#results), check the [summary script](../nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py). There you need to set the torch version, cudnn version and dataset you want to summarize, then execute the script. You can find the exact values you need to put there in one of your `benchmark_result.json` files.
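
If you benchmark several datasets, the four commands above and the location of the resulting `benchmark_result.json` can be generated with a few lines of Python. This is only a sketch: `benchmark_commands` and `benchmark_result_file` are hypothetical helpers that are not part of nnU-Net, and the `nnUNetPlans` plans name and `fold_0` folder are simply taken from the example path above.

```python
from pathlib import Path

# Trainer and configuration names exactly as used in the commands above.
TRAINERS = (
    "nnUNetTrainerBenchmark_5epochs",
    "nnUNetTrainerBenchmark_5epochs_noDataLoading",
)
CONFIGURATIONS = ("2d", "3d_fullres")


def benchmark_commands(dataset_id, fold=0):
    """Return the four nnUNetv2_train benchmark commands for one dataset.

    Hypothetical helper for illustration, not an nnU-Net function.
    """
    return [
        f"nnUNetv2_train {dataset_id} {configuration} {fold} -tr {trainer}"
        for trainer in TRAINERS
        for configuration in CONFIGURATIONS
    ]


def benchmark_result_file(results_dir, dataset_name, trainer, configuration,
                          plans="nnUNetPlans", fold=0):
    """Build the expected path of benchmark_result.json.

    Hypothetical helper; the layout mirrors the example path in the text.
    """
    run_folder = f"{trainer}__{plans}__{configuration}"
    return (Path(results_dir) / dataset_name / run_folder
            / f"fold_{fold}" / "benchmark_result.json")


if __name__ == "__main__":
    # Tasks 2 and 4 are the "quick and dirty" choice mentioned above.
    for dataset_id in (2, 4):
        for command in benchmark_commands(dataset_id):
            print(command)
```

Run the printed commands one per GPU or one after the other, as described above. The script does not start any training itself.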
## Results

We have tested a variety of GPUs and summarized the results in a [spreadsheet](https://docs.google.com/spreadsheets/d/12Cvt_gr8XU2qWaE0XJk5jJlxMEESPxyqW0CWbQhTNNY/edit?usp=sharing). Note that you can select the torch and cudnn versions at the bottom! There may be comments in this spreadsheet. Read them!

## Result interpretation

Results are shown as epoch time in seconds. Lower is better (duh). Epoch times can fluctuate between runs, so as long as you are within like 5-10% of the numbers we report, everything should be dandy.

If not, here is how you can try to find the culprit!

The first thing to do is to compare the performance between the `nnUNetTrainerBenchmark_5epochs_noDataLoading` and `nnUNetTrainerBenchmark_5epochs` trainers. If the difference is about the same as we report in our spreadsheet, but both your numbers are worse, the problem is with your GPU:

- Are you certain you are comparing against the correct GPU? (duh)
- If yes, then you might want to install PyTorch in a different way. Never `pip install torch`! Go to the [PyTorch installation](https://pytorch.org/get-started/locally/) page, select the most recent cuda version your system supports and only then copy and execute the correct command! Either pip or conda should work.
- If the problem is still not fixed, we recommend you try [compiling pytorch from source](https://github.com/pytorch/pytorch#from-source). It's more difficult but that's how we roll here at the DKFZ (at least the cool kids here).
- Another thing to consider is to try exactly the same torch + cudnn version as we did in our spreadsheet. Sometimes newer versions can actually degrade performance and there might be bugs from time to time. Older versions are also often a lot slower!
- Finally, some very basic things that could impact your GPU performance:
  - Is the GPU cooled adequately? Check the temperature with `nvidia-smi`.
Hot GPUs throttle performance in order to not self-destruct.
  - Is your OS using the GPU for displaying your desktop at the same time? If so then you can expect a performance penalty (I dunno like 10% !?). That's expected and OK.
  - Are other users using the GPU as well?

If you see a large performance difference between `nnUNetTrainerBenchmark_5epochs_noDataLoading` (fast) and `nnUNetTrainerBenchmark_5epochs` (slow) then the problem might be related to data loading and augmentation. As a reminder, nnU-Net does not use pre-augmented images (offline augmentation) but instead generates augmented training samples on the fly during training (no, you cannot switch it to offline). This requires that your system can do partial reads of the image files fast enough (SSD storage required!) and that your CPU is powerful enough to run the augmentations.

Check the following:

- [CPU bottleneck] How many CPU threads are running during the training? nnU-Net uses 12 processes for data augmentation by default. If you see those 12 running constantly during training, consider increasing the number of processes used for data augmentation (provided there is headroom on your CPU!). Increase the number until you see fewer active workers than you configured (or just set the number to 32 and forget about it). You can do so by setting the `nnUNet_n_proc_DA` environment variable (Linux: `export nnUNet_n_proc_DA=24`). Read [here](set_environment_variables.md) on how to do this. If your CPU does not support more processes (setting more processes than your CPU has threads makes no sense!) you are out of luck and in desperate need of a system upgrade!
- [I/O bottleneck] If you don't see 12 (or `nnUNet_n_proc_DA` if you set it) processes running but your training times are still slow then open up `top` (sorry, Windows users. I don't know how to do this on Windows) and look at the value left of 'wa' in the row that begins with '%Cpu(s)'.
If this is >1.0 (arbitrarily set threshold here, essentially look for unusually high 'wa'. In a healthy training 'wa' will be almost 0) then your storage cannot keep up with data loading. Make sure to set nnUNet_preprocessed to a folder that is located on an SSD. nvme is preferred over SATA. PCIe3 is enough. 3000MB/s sequential read recommended.
- [funky stuff] Sometimes there is funky stuff going on, especially when batch sizes are large, files are small and patch sizes are small as well. As part of the data loading process, nnU-Net needs to open and close a file for each training sample. Now imagine a dataset like Dataset004_Hippocampus where for the 2d config we have a batch size of 366 and we run 250 iterations in <10s on an A100. That's a lotta files per second (366 * 250 / 10 = 9150 files per second). Oof. If the files are on some network drive (even if it's nvme) then (probably) good night. The good news: nnU-Net has got you covered: add `export nnUNet_keep_files_open=True` to your .bashrc and the problem goes away. The neat part: it causes new problems if you are not allowed to have enough open files. You may have to increase the number of allowed open files. `ulimit -n` gives your current limit (Linux only). It should not be something like 1024. Increasing that to 65535 works well for me. See here for how to change these limits: [Link](https://kupczynski.info/posts/ubuntu-18-10-ulimits/) (works for Ubuntu 18, google for your OS!).

================================================
FILE: documentation/changelog.md
================================================
# What is different in v2?

- We now support **hierarchical labels** (named regions in nnU-Net). For example, instead of training BraTS with the 'edema', 'necrosis' and 'enhancing tumor' labels you can directly train it on the target areas 'whole tumor', 'tumor core' and 'enhancing tumor'.
See [here](region_based_training.md) for a detailed description + also have a look at the [BraTS 2021 conversion script](../nnunetv2/dataset_conversion/Dataset137_BraTS21.py).
- Cross-platform support. Cuda, mps (Apple M1/M2) and of course CPU support! Simply select the device with `-device` in `nnUNetv2_train` and `nnUNetv2_predict`.
- Unified trainer class: nnUNetTrainer. No messing around with cascaded trainer, DDP trainer, region-based trainer, ignore trainer etc. All default functionality is in there!
- Supports more input/output data formats through ImageIO classes.
  - I/O formats can be extended by implementing new Adapters based on `BaseReaderWriter`.
- The nnUNet_raw_cropped folder no longer exists -> saves disk space at no performance penalty. magic! (no jk the saving of cropped npz files was really slow, so it's actually faster to crop on the fly).
- Preprocessed data and segmentation are stored in different files when unpacked. Seg is stored as int8 and thus takes 1/4 of the disk space per pixel (and I/O throughput) compared to v1.
- Native support for multi-GPU (DDP) TRAINING. Multi-GPU INFERENCE should still be run with `CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] -num_parts Y -part_id X`. There is no cross-GPU communication in inference, so it doesn't make sense to add additional complexity with DDP.
- All nnU-Net functionality is now also accessible via API. Check the corresponding entry point in `setup.py` to see what functions you need to call.
- Dataset fingerprint is now explicitly created and saved in a json file (see nnUNet_preprocessed).
- Complete overhaul of plans files (read also [this](explanation_plans_files.md)):
  - Plans are now .json and can be opened and read more easily
  - Configurations are explicitly named ("3d_fullres", ...)
  - Configurations can inherit from each other to make manual experimentation easier
  - A ton of additional functionality is now included in and can be changed through the plans, for example normalization strategy, resampling etc.
  - Stages of the cascade are now explicitly listed in the plans. 3d_lowres has 'next_stage' (which can also be a list of configurations!). 3d_cascade_fullres has a 'previous_stage' entry. By manually editing plans files you can now connect anything you want, for example 2d with 3d_fullres or whatever. Be wild! (But don't create cycles!)
  - Multiple configurations can point to the same preprocessed data folder to save disk space. Careful! Only configurations that use the same spacing, resampling, normalization etc. should share a data source! By default, 3d_fullres and 3d_cascade_fullres share the same data.
  - Any number of configurations can be added to the plans (remember to give them a unique "data_identifier"!)

Folder structures are different and more user-friendly:
- nnUNet_preprocessed
  - By default, preprocessed data is now saved as: `nnUNet_preprocessed/DATASET_NAME/PLANS_IDENTIFIER_CONFIGURATION` to clearly link them to their corresponding plans and configuration
  - Name of the folder containing the preprocessed images can be adapted with the `data_identifier` key.
- nnUNet_results
  - Results are now sorted as follows: DATASET_NAME/TRAINERCLASS__PLANSIDENTIFIER__CONFIGURATION/FOLD

## What other changes are planned and not yet implemented?
- Integration into MONAI (together with our friends at Nvidia)
- New pretrained weights for a large number of datasets (coming very soon)

[//]: # (- nnU-Net now also natively supports an **ignore label**. Pixels with this label will not contribute to the loss. )
[//]: # (Use this to learn from sparsely annotated data, or excluding irrelevant areas from training. Read more [here](ignore_label.md).)
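
## Example: resolving configuration inheritance

The configuration inheritance listed in this changelog can be pictured as a recursive dictionary merge. The sketch below is only an illustration and not nnU-Net's actual implementation: `resolve_configuration` is a hypothetical helper, the configuration values are made up, and it assumes that keys set in a child configuration simply override the keys of its parent.

```python
import copy


def resolve_configuration(configurations: dict, name: str, _seen: tuple = ()) -> dict:
    """Return the settings of `name` with all 'inherits_from' links merged in.

    Hypothetical helper for illustration only: child keys override parent keys.
    """
    if name in _seen:
        chain = " -> ".join(_seen + (name,))
        raise ValueError(f"Circular inheritance: {chain}")
    if name not in configurations:
        raise KeyError(f"Unknown configuration: {name}")

    # Copy so that the plans dictionary itself is never modified.
    own_settings = copy.deepcopy(configurations[name])
    parent_name = own_settings.pop("inherits_from", None)
    if parent_name is None:
        return own_settings

    # Resolve the parent first, then let the child's keys win.
    resolved = resolve_configuration(configurations, parent_name, _seen + (name,))
    resolved.update(own_settings)
    return resolved


# Made-up example values; real plans files contain many more keys.
example_configurations = {
    "3d_fullres": {
        "data_identifier": "nnUNetPlans_3d_fullres",
        "batch_size": 2,
        "patch_size": [128, 128, 128],
    },
    "3d_fullres_bs8": {"inherits_from": "3d_fullres", "batch_size": 8},
}

resolved = resolve_configuration(example_configurations, "3d_fullres_bs8")
print(resolved)
```

With this merge rule a derived configuration only needs to list the keys it changes. In the toy example the derived configuration keeps the parent's `data_identifier`, which is how two configurations could end up pointing at the same preprocessed data.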
================================================
FILE: documentation/competitions/AortaSeg24.md
================================================
Authors: \
Maximilian Rokuss, Michael Baumgartner, Yannick Kirchhoff, Klaus H. Maier-Hein*, Fabian Isensee*

*: equal contribution

Author Affiliations:\
Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg \
Helmholtz Imaging

# Introduction

This document describes our submission to the [AortaSeg24 Challenge](https://aortaseg24.grand-challenge.org/).
Our model is essentially a nnU-Net ResEnc L with modified data augmentation. We disable left/right mirroring and use the heavy data augmentation [DA5 Trainer](../../nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py). Training was performed on an A100 40GB GPU.

# Experiment Planning and Preprocessing
After converting the data into the [nnUNet format](../../../nnUNet/documentation/dataset_format.md) (either keep and just rename the .mha files or convert them to .nii.gz), you can run the preprocessing:

```bash
nnUNetv2_plan_and_preprocess -d 610 -c 3d_fullres -pl nnUNetPlannerResEncL -np 16
```

# Training
We train our model using:

```bash
nnUNetv2_train 610 3d_fullres all -p nnUNetResEncUNetLPlans -tr nnUNetTrainer_onlyMirror01_DA5
```
Models are trained from scratch. We train one model using all the images and a five fold cross validation ensemble for the submission.

We recommend increasing the number of processes used for data augmentation. Otherwise you can run into CPU bottlenecks.
Use `export nnUNet_n_proc_DA=32` or higher (if your system permits!).

# Inference
For inference you can use the default [nnUNet inference functionalities](../../../nnUNet/documentation/how_to_use_nnunet.md).
Specifically, once the training is finished, run:

```bash
nnUNetv2_predict_from_modelfolder -i INPUT_FOLDER -o OUTPUT_FOLDER -m MODEL_FOLDER -f all
```
for the single model trained on all the data and
```bash
nnUNetv2_predict_from_modelfolder -i INPUT_FOLDER -o OUTPUT_FOLDER -m MODEL_FOLDER
```
for the five fold ensemble.

================================================
FILE: documentation/competitions/AutoPETII.md
================================================
# Look Ma, no code: fine tuning nnU-Net for the AutoPET II challenge by only adjusting its JSON plans

Please cite our paper :-*

```text
COMING SOON
```

## Intro

See the [Challenge Website](https://autopet-ii.grand-challenge.org/) for details on the challenge.

Our solution to this challenge requires no code changes at all. All we do is optimize nnU-Net's hyperparameters (architecture, batch size, patch size) through modifying the nnUNetPlans.json file.

## Prerequisites
Use the latest pytorch version!

We recommend you use the latest nnU-Net version as well! We ran our trainings with commit 913705f which you can try in case something doesn't work as expected:
`pip install git+https://github.com/MIC-DKFZ/nnUNet.git@913705f`

## How to reproduce our trainings

### Download and convert the data
1. Download and extract the AutoPET II dataset
2. Convert it to nnU-Net format by running `python nnunetv2/dataset_conversion/Dataset221_AutoPETII_2023.py FOLDER` where folder is the extracted AutoPET II dataset.

### Experiment planning and preprocessing
We deviate a little from the standard nnU-Net procedure because all our experiments are based on just the 3d_fullres configuration.

Run the following commands:
- `nnUNetv2_extract_fingerprint -d 221` extracts the dataset fingerprint
- `nnUNetv2_plan_experiment -d 221` does the planning for the plain unet
- `nnUNetv2_plan_experiment -d 221 -pl ResEncUNetPlanner` does the planning for the residual encoder unet
- `nnUNetv2_preprocess -d 221 -c 3d_fullres` runs all the preprocessing we need

### Modification of plans files
Please read the [information on how to modify plans files](../explanation_plans_files.md) first!!!

It is easier to have everything in one plans file, so the first thing we do is transfer the ResEnc UNet to the default plans file. We use the configuration inheritance feature of nnU-Net to make it use the same data as the 3d_fullres configuration.
Add the following to the 'configurations' dict in 'nnUNetPlans.json':

```json
        "3d_fullres_resenc": {
            "inherits_from": "3d_fullres",
            "network_arch_class_name": "ResidualEncoderUNet",
            "n_conv_per_stage_encoder": [
                1,
                3,
                4,
                6,
                6,
                6
            ],
            "n_conv_per_stage_decoder": [
                1,
                1,
                1,
                1,
                1
            ]
        },
```

(these values are basically just copied from the 'nnUNetResEncUNetPlans.json' file! With everything redundant being omitted thanks to inheritance from 3d_fullres)

Now we crank up the patch and batch sizes. Add the following configurations:
```json
        "3d_fullres_resenc_bs80": {
            "inherits_from": "3d_fullres_resenc",
            "batch_size": 80
        },
        "3d_fullres_resenc_192x192x192_b24": {
            "inherits_from": "3d_fullres_resenc",
            "patch_size": [
                192,
                192,
                192
            ],
            "batch_size": 24
        }
```

Save the file (and check for potential Syntax Errors!)

### Run trainings
Training each model requires 8 Nvidia A100 40GB GPUs. Expect training to run for 5-7 days. You'll need a really good CPU to handle the data augmentation! 128C/256T are a must! If you have fewer threads available, scale down nnUNet_n_proc_DA accordingly.

```bash
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_bs80 0 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_bs80 1 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_bs80 2 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_bs80 3 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_bs80 4 -num_gpus 8

nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_192x192x192_b24 0 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_192x192x192_b24 1 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_192x192x192_b24 2 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_192x192x192_b24 3 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_192x192x192_b24 4 -num_gpus 8
```

Done!

(We also provide pretrained weights in case you don't want to invest the GPU resources, see below)

## How to make predictions with pretrained weights
Our final model is an ensemble of two configurations:
- ResEnc UNet with batch size 80
- ResEnc UNet with patch size 192x192x192 and batch size 24

To run inference with these models, do the following:

1. Download the pretrained model weights from [Zenodo](https://zenodo.org/record/8362371)
2. Install both .zip files using `nnUNetv2_install_pretrained_model_from_zip`
3. Make sure
4. Now you can run inference on new cases with `nnUNetv2_predict`:
   - `nnUNetv2_predict -i INPUT -o OUTPUT1 -d 221 -c 3d_fullres_resenc_bs80 -f 0 1 2 3 4 -step_size 0.6 --save_probabilities`
   - `nnUNetv2_predict -i INPUT -o OUTPUT2 -d 221 -c 3d_fullres_resenc_192x192x192_b24 -f 0 1 2 3 4 --save_probabilities`
   - `nnUNetv2_ensemble -i OUTPUT1 OUTPUT2 -o OUTPUT_ENSEMBLE`

Note that our inference Docker omitted TTA via mirroring along the axial direction during prediction (only sagittal + coronal mirroring). This was done to keep the inference time below 10 minutes per image on a T4 GPU (we actually never tested whether we could have left this enabled). Just leave it on! You can also leave the step_size at default for the 3d_fullres_resenc_bs80.

================================================
FILE: documentation/competitions/FLARE24/Task_1/__init__.py
================================================


================================================
FILE: documentation/competitions/FLARE24/Task_1/inference_flare_task1.py
================================================
from typing import Union, Tuple
import argparse
import numpy as np
import os
from os.path import join
from pathlib import Path
from time import time
import torch
from torch._dynamo import OptimizedModule
from nnunetv2.utilities.label_handling.label_handling import LabelManager
from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice
from batchgenerators.utilities.file_and_folder_operations import load_json

import nnunetv2
from nnunetv2.configuration import default_num_processes
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.imageio.nibabel_reader_writer import NibabelIOWithReorient


class FlarePredictor(nnUNetPredictor):
    def initialize_from_trained_model_folder(self, model_training_output_dir: str,
                                             use_folds: Union[Tuple[Union[int, str]], None],
                                             checkpoint_name: str = 'checkpoint_final.pth'):
        """
        This is used when making predictions with a trained model
        """
        if use_folds is None:
            use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name)

        dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
        plans = load_json(join(model_training_output_dir, 'plans.json'))
        plans_manager = PlansManager(plans)

        if isinstance(use_folds, str):
            use_folds = [use_folds]

        parameters = []
        for i, f in enumerate(use_folds):
            f = int(f) if f != 'all' else f
            checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name),
                                    map_location=torch.device('cpu'))
            if i == 0:
                trainer_name = checkpoint['trainer_name']
                configuration_name = checkpoint['init_args']['configuration']
                inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \
                    'inference_allowed_mirroring_axes' in checkpoint.keys() else None

            parameters.append(checkpoint['network_weights'])

        configuration_manager = plans_manager.get_configuration(configuration_name)
        # restore network
        num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
        trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                    trainer_name, 'nnunetv2.training.nnUNetTrainer')
        if trainer_class is None:
            raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
                               f'Please place it there (in any .py file)!')
        network = trainer_class.build_network_architecture(
            configuration_manager.network_arch_class_name,
            configuration_manager.network_arch_init_kwargs,
            configuration_manager.network_arch_init_kwargs_req_import,
            num_input_channels,
            plans_manager.get_label_manager(dataset_json).num_segmentation_heads,
            enable_deep_supervision=False
        )

        self.plans_manager = plans_manager
        self.configuration_manager = configuration_manager
        self.list_of_parameters = parameters
        self.network = network
        self.dataset_json = dataset_json
        self.trainer_name = trainer_name
        self.allowed_mirroring_axes = inference_allowed_mirroring_axes
        self.label_manager = plans_manager.get_label_manager(dataset_json)
        if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \
                and not isinstance(self.network, OptimizedModule):
            self.network = torch.compile(self.network)


def convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits: Union[torch.Tensor, np.ndarray],
                                                                plans_manager: PlansManager,
                                                                configuration_manager: ConfigurationManager,
                                                                label_manager: LabelManager,
                                                                properties_dict: dict,
                                                                return_probabilities: bool = False,
                                                                num_threads_torch: int = default_num_processes):
    old_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads_torch)

    # resample to original shape
    current_spacing = configuration_manager.spacing if \
        len(configuration_manager.spacing) == \
        len(properties_dict['shape_after_cropping_and_before_resampling']) else \
        [properties_dict['spacing'][0], *configuration_manager.spacing]
    predicted_logits = configuration_manager.resampling_fn_probabilities(
        predicted_logits,
        properties_dict['shape_after_cropping_and_before_resampling'],
        current_spacing,
        properties_dict['spacing'])
    # return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because
    # apply_inference_nonlin will convert to torch
    # predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits)
    segmentation = predicted_logits.argmax(0)
    del predicted_logits
    # segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities)

    # segmentation may be torch.Tensor but we continue with numpy
    if isinstance(segmentation, torch.Tensor):
        segmentation = segmentation.cpu().numpy()

    # put segmentation in bbox (revert cropping)
    segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'],
                                              dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16)
    slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping'])
    segmentation_reverted_cropping[slicer] = segmentation
    del segmentation

    # revert transpose
    segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(plans_manager.transpose_backward)
    torch.set_num_threads(old_threads)
    return segmentation_reverted_cropping


def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, torch.Tensor], properties_dict: dict,
                                  configuration_manager: ConfigurationManager,
                                  plans_manager: PlansManager,
                                  dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str,
                                  save_probabilities: bool = False):
    if isinstance(dataset_json_dict_or_file, str):
        dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)

    label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)
    ret = convert_predicted_logits_to_segmentation_with_correct_shape(
        predicted_array_or_file, plans_manager, configuration_manager, label_manager, properties_dict,
        return_probabilities=save_probabilities
    )
    del predicted_array_or_file

    segmentation_final = ret

    rw = NibabelIOWithReorient()
    rw.write_seg(segmentation_final, output_file_truncated + dataset_json_dict_or_file['file_ending'],
                 properties_dict)


def predict_flare(input_dir, output_dir, model_folder, folds=("all",)):
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    input_files = sorted(input_dir.glob("*.nii.gz"))
    output_files = [str(output_dir / f.name[:-12]) for f in input_files]

    for input_file, output_file in zip(input_files, output_files):
        print(f"Predicting {input_file.name}")
        start = time()
        plans_manager = PlansManager(load_json(join(model_folder, 'plans.json')))
        configuration_manager = plans_manager.get_configuration("3d_fullres")
        dataset_json = load_json(join(model_folder, 'dataset.json'))
        rw = NibabelIOWithReorient()
        image, props = rw.read_images([input_file,])

        with torch.no_grad():
            predictor = FlarePredictor(tile_step_size=0.5, use_mirroring=False)
            predictor.initialize_from_trained_model_folder(model_folder, use_folds=folds)
            preprocessor = configuration_manager.preprocessor_class(verbose=False)
            data, _ = preprocessor.run_case_npy(image, None, props, plans_manager, configuration_manager,
                                                dataset_json)
            data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)
            predicted_logits = predictor.predict_logits_from_preprocessed_data(data).cpu()

        export_prediction_from_logits(predicted_logits, props, configuration_manager, plans_manager,
                                      dataset_json, output_file, False)
        print(f"Prediction time: {time() - start:.2f}s")


if __name__ == '__main__':
    os.environ['nnUNet_compile'] = 'f'
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", default="/workspace/inputs")
    parser.add_argument("-o", "--output", default="/workspace/outputs")
    parser.add_argument("-m", "--model", default="/opt/app/_trained_model")
    parser.add_argument("-f", "--folds", nargs="+", default=["all"])
    args = parser.parse_args()
    predict_flare(args.input, args.output, args.model, args.folds)


================================================
FILE: documentation/competitions/FLARE24/Task_1/readme.md
================================================
Authors: \
Yannick Kirchhoff, Maximilian Rouven Rokuss, Benjamin Hamm, Ashis Ravindran, Constantin Ulrich, Klaus Maier-Hein, Fabian Isensee

†: equal contribution

# Introduction

This document describes our contribution to [Task 1 of the
FLARE24 Challenge](https://www.codabench.org/competitions/2319/).

Our model is basically a default nnU-Net trained with larger batch sizes of 4 and 8, respectively. We submitted the batch size 8 model and an ensemble of the batch size 4 and batch size 8 models to the final test set.

# Experiment Planning and Preprocessing

Bring the downloaded data into the [nnU-Net format](../../../nnUNet/documentation/dataset_format.md) and add the dataset.json file as given here:

```json
{
    "name": "Dataset301_FLARE24Task1_labeled",
    "description": "Pan Cancer Segmentation",
    "labels": {
        "background": 0,
        "lesion": 1
    },
    "file_ending": ".nii.gz",
    "channel_names": {
        "0": "CT"
    },
    "numTraining": 5000
}
```

Afterwards you can run the default nnU-Net planning and preprocessing

```bash
nnUNetv2_plan_and_preprocess -d 301 -c 3d_fullres
```

## Edit the plans files

In the generated `nnUNetPlans.json` file add the following configurations

```json
        "3d_fullres_bs4": {
            "inherits_from": "3d_fullres",
            "batch_size": 4
        },
        "3d_fullres_bs8": {
            "inherits_from": "3d_fullres",
            "batch_size": 8
        },
        "3d_fullres_bs4u8": {
            "inherits_from": "3d_fullres",
            "batch_size": 48
        }
```

Note, the last one is only used for the ensemble model during inference!

# Model training

Run the following commands to train the models with batch size 4 and 8. The large batch size helps stabilize the training despite the partial labels present in the dataset as well as handle the large number of scans in the dataset. We therefore keep the number of epochs at 1000.

```bash
nnUNetv2_train 301 3d_fullres_bs4 all
nnUNetv2_train 301 3d_fullres_bs8 all
```

# Inference

Our inference is optimized for efficient single scan prediction. For best performance, we strongly recommend running inference using the default `nnUNetv2_predict` command!

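
The ensemble folder layout used in this section (metadata files from one training, plus the two `fold_all` folders renamed to `fold_0` and `fold_1`) can also be assembled with a short script. The following is a minimal sketch and not part of nnU-Net: `build_ensemble_folder` is a hypothetical helper, and it assumes that `dataset_results_dir` is the dataset-specific folder inside your results directory that directly contains the `nnUNetTrainer__nnUNetPlans__*` folders.

```python
import shutil
from pathlib import Path


def build_ensemble_folder(dataset_results_dir,
                          source_configurations=("3d_fullres_bs4", "3d_fullres_bs8"),
                          target_configuration="3d_fullres_bs4u8",
                          prefix="nnUNetTrainer__nnUNetPlans__"):
    """Assemble a model folder whose fold_0, fold_1, ... are the fold_all of each source training.

    Hypothetical helper for illustration; folder names follow the text of this readme.
    """
    dataset_results_dir = Path(dataset_results_dir)
    target = dataset_results_dir / f"{prefix}{target_configuration}"
    target.mkdir(parents=True, exist_ok=True)

    # The metadata files are taken from the first source training.
    first_source = dataset_results_dir / f"{prefix}{source_configurations[0]}"
    for file_name in ("dataset.json", "dataset_fingerprint.json", "plans.json"):
        shutil.copy2(first_source / file_name, target / file_name)

    # fold_all of the i-th training becomes fold_i of the ensemble folder.
    for fold_index, configuration in enumerate(source_configurations):
        source_fold = dataset_results_dir / f"{prefix}{configuration}" / "fold_all"
        shutil.copytree(source_fold, target / f"fold_{fold_index}", dirs_exist_ok=True)
    return target
```

The returned path is the folder that would then be passed as `model_folder`, with `folds` set to `0 1`. Note that `dirs_exist_ok` requires Python 3.8 or newer.
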
In order to run inference with the ensemble model you need to create a folder called `nnUNetTrainer__nnUNetPlans__3d_fullres_bs4u8` in the results folder and copy the `dataset.json`, `dataset_fingerprint.json` and `plans.json` from one of the other results folders as well as the `fold_all` from both trainings as `fold_0` and `fold_1`, respectively, into this new folder. This allows for easy ensembling of both models.

To run inference simply run the following commands with `folds` set to `all` for single model inference or `0 1` for the ensemble. `model_folder` is the folder containing the training results, for example `nnUNetTrainer__nnUNetPlans__3d_fullres_bs8`.

```bash
python inference_flare_task1.py -i input_folder -o output_folder -m model_folder -f folds
```


================================================
FILE: documentation/competitions/FLARE24/Task_2/__init__.py
================================================


================================================
FILE: documentation/competitions/FLARE24/Task_2/inference_flare_task2.py
================================================
from typing import Union, List, Tuple
import argparse
import itertools
import multiprocessing
import numpy as np
import os
from os.path import join
from pathlib import Path
from time import time
import torch
from torch._dynamo import OptimizedModule
from tqdm import tqdm

import openvino as ov
import openvino.properties.hint as hints

from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice
from acvl_utils.cropping_and_padding.padding import pad_nd_image
from batchgenerators.utilities.file_and_folder_operations import load_json
from nnunetv2.utilities.label_handling.label_handling import LabelManager

import nnunetv2
from nnunetv2.configuration import default_num_processes
from nnunetv2.inference.data_iterators import PreprocessAdapterFromNpy
from nnunetv2.inference.sliding_window_prediction import compute_gaussian
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.helpers import empty_cache, dummy_context
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

torch.set_num_threads(multiprocessing.cpu_count())


class FlarePredictor(nnUNetPredictor):
    def __init__(self,
                 tile_step_size: float = 0.5,
                 use_gaussian: bool = True,
                 use_mirroring: bool = True,
                 perform_everything_on_device: bool = False,
                 device: torch.device = torch.device('cpu'),
                 verbose: bool = False,
                 verbose_preprocessing: bool = False,
                 allow_tqdm: bool = True,
                 ):
        super().__init__(tile_step_size, use_gaussian, use_mirroring, perform_everything_on_device, device, verbose,
                         verbose_preprocessing, allow_tqdm)
        if self.device == torch.device('cuda') or self.device == 'cuda':
            raise RuntimeError('CUDA is not supported for this task')

    def initialize_from_trained_model_folder(self, model_training_output_dir: str,
                                             use_folds: Union[Tuple[Union[int, str]], None],
                                             checkpoint_name: str = 'checkpoint_final.pth',
                                             save_model: bool = True):
        """
        This is used when making predictions with a trained model
        """
        if use_folds is None:
            use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name)

        dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
        plans = load_json(join(model_training_output_dir, 'plans.json'))
        plans_manager = PlansManager(plans)

        if isinstance(use_folds, str):
            use_folds = [use_folds]
        assert len(use_folds) == 1, 'Only one fold is supported for this task'

        parameters = []
        for i, f in enumerate(use_folds):
            f = int(f) if f != 'all' else f
            checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name),
                                    map_location=torch.device('cpu'))
            if i == 0:
                trainer_name = checkpoint['trainer_name']
                configuration_name = checkpoint['init_args']['configuration']
                inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \
                    'inference_allowed_mirroring_axes' in checkpoint.keys() else None

            if save_model:
                parameters.append(checkpoint['network_weights'])

        configuration_manager = plans_manager.get_configuration(configuration_name)

        if save_model:
            num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
            trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                        trainer_name, 'nnunetv2.training.nnUNetTrainer')
            if trainer_class is None:
                raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
                                   f'Please place it there (in any .py file)!')
            network = trainer_class.build_network_architecture(
                configuration_manager.network_arch_class_name,
                configuration_manager.network_arch_init_kwargs,
                configuration_manager.network_arch_init_kwargs_req_import,
                num_input_channels,
                plans_manager.get_label_manager(dataset_json).num_segmentation_heads,
                enable_deep_supervision=False
            )
            self.network = network

        self.allowed_mirroring_axes = inference_allowed_mirroring_axes
        self.plans_manager = plans_manager
        self.configuration_manager = configuration_manager
        self.dataset_json = dataset_json
        self.label_manager = plans_manager.get_label_manager(dataset_json)

        if save_model:
            if not isinstance(self.network, OptimizedModule):
                self.network.load_state_dict(parameters[0])
            else:
                self.network._orig_mod.load_state_dict(parameters[0])

            self.network.eval()

        config = {hints.performance_mode: hints.PerformanceMode.LATENCY,
                  hints.enable_cpu_pinning(): True,
                  }
        core = ov.Core()
        core.set_property(
            "CPU",
            {hints.execution_mode: hints.ExecutionMode.PERFORMANCE},
        )

        if save_model:
            input_tensor = torch.randn(1, num_input_channels, *configuration_manager.patch_size, requires_grad=False)
            ov_model = ov.convert_model(self.network, example_input=input_tensor)
            ov.save_model(ov_model, f"{model_training_output_dir}/model.xml")
            import sys
            sys.exit(0)

        ov_model = core.read_model(f"{model_training_output_dir}/model.xml")
        self.network = core.compile_model(ov_model, "CPU", config=config)

    def predict_from_files(self,
                           list_of_lists_or_source_folder: Union[str, List[List[str]]],
                           output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]],
                           save_probabilities: bool = False,
                           overwrite: bool = True,
                           num_processes_preprocessing: int = 1,
                           num_processes_segmentation_export: int = 1,
                           folder_with_segs_from_prev_stage: str = None,
                           num_parts: int = 1,
                           part_id: int = 0):
        """
        This is nnU-Net's default function for making predictions. It works best for batch predictions
        (predicting many images at once).
        """
        list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \
            self._manage_input_and_output_lists(list_of_lists_or_source_folder,
                                                output_folder_or_list_of_truncated_output_files,
                                                folder_with_segs_from_prev_stage, overwrite, part_id, num_parts,
                                                save_probabilities)
        if len(list_of_lists_or_source_folder) == 0:
            return

        data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder,
                                                                                 seg_from_prev_stage_files,
                                                                                 output_filename_truncated,
                                                                                 num_processes_preprocessing)

        return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export)

    def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:
        mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None
        if self.use_openvino:
            prediction = torch.from_numpy(self.network(x)[0])
        else:
            prediction = self.network(x)

        if mirror_axes is not None:
            assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'
            mirror_axes = [m + 2 for m in mirror_axes]
            axes_combinations = [
                c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
            ]
            for axes in axes_combinations:
                if not self.use_openvino:
                    prediction += torch.flip(self.network(torch.flip(x, axes)), axes)
                else:
                    temp_pred = torch.from_numpy(self.network(torch.flip(x, axes))[0])
                    prediction += torch.flip(temp_pred, axes)
            prediction /= (len(axes_combinations) + 1)
        return prediction

    def _internal_predict_sliding_window_return_logits(self,
                                                       data: torch.Tensor,
                                                       slicers,
                                                       do_on_device: bool = True,
                                                       ):
        predicted_logits = n_predictions = prediction = gaussian = workon = None
        results_device = self.device if do_on_device else torch.device('cpu')

        try:
            empty_cache(self.device)

            # move data to device
            if self.verbose:
                print(f'move image to device {results_device}')
            data = data.to(results_device)

            # preallocate arrays
            if self.verbose:
                print(f'preallocating results arrays on device {results_device}')
            predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
                                           dtype=torch.half,
                                           device=results_device)
            n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)

            if self.use_gaussian:
                gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
                                            value_scaling_factor=10,
                                            device=results_device)
            else:
                gaussian = 1

            if not self.allow_tqdm and self.verbose:
                print(f'running prediction: {len(slicers)} steps')
            for sl in tqdm(slicers, disable=not self.allow_tqdm):
                workon = data[sl][None]
                workon = workon.to(self.device)

                prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)

                if self.use_gaussian:
                    prediction *= gaussian
                predicted_logits[sl] += prediction
                n_predictions[sl[1:]] += gaussian

            predicted_logits /= n_predictions
            # check for infs
            if torch.any(torch.isinf(predicted_logits)):
                raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
                                   'reduce value_scaling_factor in compute_gaussian or increase the dtype of '
                                   'predicted_logits to fp32')
        except Exception as e:
            del predicted_logits, n_predictions, prediction, gaussian, workon
            empty_cache(self.device)
            empty_cache(results_device)
            raise e
        return predicted_logits

    def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
        """
        IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
        TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE!

        RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE.
        SEE convert_predicted_logits_to_segmentation_with_correct_shape
        """
        prediction = None

        if not self.use_openvino:
            for params in self.list_of_parameters:

                # messing with state dict names...
                if not isinstance(self.network, OptimizedModule):
                    self.network.load_state_dict(params)
                else:
                    self.network._orig_mod.load_state_dict(params)

                # why not leave prediction on device if perform_everything_on_device? Because this may cause the
                # second iteration to crash due to OOM. Grabbing that with try except cause way more bloated code than
                # this actually saves computation time
                if prediction is None:
                    prediction = self.predict_sliding_window_return_logits(data).to('cpu')
                else:
                    prediction += self.predict_sliding_window_return_logits(data).to('cpu')

            if len(self.list_of_parameters) > 1:
                prediction /= len(self.list_of_parameters)
        else:
            if prediction is None:
                prediction = self.predict_sliding_window_return_logits(data)
            else:
                prediction += self.predict_sliding_window_return_logits(data)

        if self.verbose:
            print('Prediction done')
        return prediction

    @torch.inference_mode()
    def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
            -> Union[np.ndarray, torch.Tensor]:
        assert isinstance(input_image, torch.Tensor)
        if self.device not in [torch.device('cpu'), 'cpu']:
            self.network = self.network.to(self.device)
            self.network.eval()

        empty_cache(self.device)

        # Autocast can be annoying
        # If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection)
        # and needs to be disabled.
        # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False
        # is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
        # So autocast will only be active if we have a cuda device.
        with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
            assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'

            if self.verbose:
                print(f'Input shape: {input_image.shape}')
                print("step_size:", self.tile_step_size)
                print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)

            # if input_image is smaller than tile_size we need to pad it to tile_size.
data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size, 'constant', {'value': 0}, True, None) slicers = self._internal_get_sliding_window_slicers(data.shape[1:]) if self.perform_everything_on_device and self.device != 'cpu': # we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device try: predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, self.perform_everything_on_device) except RuntimeError: print( 'Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU') empty_cache(self.device) predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False) else: predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, self.perform_everything_on_device) empty_cache(self.device) # revert padding predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])] return predicted_logits def convert_predicted_logits_to_segmentation_with_correct_shape(self, predicted_logits: Union[torch.Tensor, np.ndarray], plans_manager: PlansManager, configuration_manager: ConfigurationManager, label_manager: LabelManager, properties_dict: dict, return_probabilities: bool = False, num_threads_torch: int = default_num_processes): # resample to original shape current_spacing = configuration_manager.spacing if \ len(configuration_manager.spacing) == \ len(properties_dict['shape_after_cropping_and_before_resampling']) else \ [properties_dict['spacing'][0], *configuration_manager.spacing] predicted_logits = configuration_manager.resampling_fn_probabilities(predicted_logits, properties_dict['shape_after_cropping_and_before_resampling'], current_spacing, properties_dict['spacing']) segmentation = predicted_logits.argmax(0) del predicted_logits # segmentation may be torch.Tensor but we continue with numpy if isinstance(segmentation, torch.Tensor): segmentation = 
segmentation.cpu().numpy() # put segmentation in bbox (revert cropping) segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'], dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16) slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping']) segmentation_reverted_cropping[slicer] = segmentation del segmentation # revert transpose segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(plans_manager.transpose_backward) return segmentation_reverted_cropping def export_prediction_from_logits(self, predicted_array_or_file: Union[np.ndarray, torch.Tensor], properties_dict: dict, configuration_manager: ConfigurationManager, plans_manager: PlansManager, dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str, save_probabilities: bool = False): if isinstance(dataset_json_dict_or_file, str): dataset_json_dict_or_file = load_json(dataset_json_dict_or_file) label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file) ret = self.convert_predicted_logits_to_segmentation_with_correct_shape( predicted_array_or_file, plans_manager, configuration_manager, label_manager, properties_dict, return_probabilities=save_probabilities ) del predicted_array_or_file segmentation_final = ret rw = plans_manager.image_reader_writer_class() rw.write_seg(segmentation_final, output_file_truncated + dataset_json_dict_or_file['file_ending'], properties_dict) def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict, segmentation_previous_stage: np.ndarray = None, output_file_truncated: str = None, save_or_return_probabilities: bool = False): """ WARNING: SLOW. ONLY USE THIS IF YOU CANNOT GIVE NNUNET MULTIPLE IMAGES AT ONCE FOR SOME REASON. input_image: Make sure to load the image in the way nnU-Net expects! nnU-Net is trained on a certain axis ordering which cannot be disturbed in inference, otherwise you will get bad results. 
The easiest way to achieve that is to use the same I/O class for loading images as was used during nnU-Net preprocessing! You can find that class in your plans.json file under the key "image_reader_writer". If you decide to freestyle, know that the default axis ordering for medical images is the one from SimpleITK. If you load with nibabel, you need to transpose your axes AND your spacing from [x,y,z] to [z,y,x]! image_properties must only have a 'spacing' key! """ ppa = PreprocessAdapterFromNpy([input_image], [segmentation_previous_stage], [image_properties], [output_file_truncated], self.plans_manager, self.dataset_json, self.configuration_manager, num_threads_in_multithreaded=1, verbose=self.verbose) if self.verbose: print('preprocessing') dct = next(ppa) if self.verbose: print('predicting') predicted_logits = self.predict_logits_from_preprocessed_data(dct['data']).cpu() if self.verbose: print('resampling to original shape') self.export_prediction_from_logits(predicted_logits, dct['data_properties'], self.configuration_manager, self.plans_manager, self.dataset_json, output_file_truncated, save_or_return_probabilities) def predict_flare(input_dir, output_dir, model_folder, save_model): input_dir = Path(input_dir) output_dir = Path(output_dir) output_dir.mkdir(exist_ok=True, parents=True) input_files = list(input_dir.glob("*.nii.gz")) output_files = [str(output_dir / f.name[:-12]) for f in input_files] for input_file, output_file in zip(input_files, output_files): print(f"Predicting {input_file.name}") start = time() predictor = FlarePredictor(tile_step_size=0.5, use_mirroring=False, device=torch.device("cpu")) predictor.initialize_from_trained_model_folder(model_folder, ("all",), save_model=save_model) rw = predictor.plans_manager.image_reader_writer_class() image, props = rw.read_images([input_file,]) _ = predictor.predict_single_npy_array(image, props, None, output_file, False) print(f"Prediction time: {time() - start:.2f}s") if __name__ == '__main__': parser = 
argparse.ArgumentParser()
    parser.add_argument("-i", "--input", default="/workspace/inputs")
    parser.add_argument("-o", "--output", default="/workspace/outputs")
    parser.add_argument("-m", "--model", default="/opt/app/_trained_model")
    parser.add_argument("-save_model", action="store_true")
    args = parser.parse_args()
    predict_flare(args.input, args.output, args.model, args.save_model)

================================================
FILE: documentation/competitions/FLARE24/Task_2/readme.md
================================================
Authors: \
Yannick Kirchhoff*, Ashis Ravindran*, Maximilian Rouven Rokuss, Benjamin Hamm, Constantin Ulrich, Klaus Maier-Hein, Fabian Isensee

*: equal contribution \
†: equal contribution

# Introduction

This document describes our contribution to [Task 2 of the FLARE24 Challenge](https://www.codabench.org/competitions/2320/).
Our model is basically a default nnU-Net with a custom low resolution setting and OpenVINO optimizations for faster CPU inference.

# Experiment Planning and Preprocessing

Bring the downloaded data into the [nnU-Net format](../../../nnUNet/documentation/dataset_format.md) and add the dataset.json file as given here:

```json
{
    "name": "Dataset311_FLARE24Task2_labeled",
    "description": "Abdominal Organ Segmentation",
    "labels": {
        "background": 0,
        "liver": 1,
        "right kidney": 2,
        "spleen": 3,
        "pancreas": 4,
        "aorta": 5,
        "ivc": 6,
        "rag": 7,
        "lag": 8,
        "gallbladder": 9,
        "esophagus": 10,
        "stomach": 11,
        "duodenum": 12,
        "left kidney": 13
    },
    "file_ending": ".nii.gz",
    "channel_names": {
        "0": "CT"
    },
    "overwrite_image_reader_writer": "NibabelIOWithReorient",
    "numTraining": 50
}
```

Afterwards you can run the default nnU-Net planning and preprocessing:

```bash
nnUNetv2_plan_and_preprocess -d 311 -c 3d_fullres
```

## Edit the plans files

The generated `nnUNetPlans.json` file needs to be edited to incorporate the custom low resolution setting.
```json
"3d_halfres": {
    "inherits_from": "3d_fullres",
    "data_identifier": "nnUNetPlans_3d_halfres",
    "spacing": [
        5,
        1.6,
        1.6
    ]
},
"3d_halfiso": {
    "inherits_from": "3d_fullres",
    "data_identifier": "nnUNetPlans_3d_halfiso",
    "spacing": [
        2.5,
        2.5,
        2.5
    ]
},
```

`3d_halfres` is a configuration with exactly half resolution, used as an ablation of our submission. `3d_halfiso` is the isotropic configuration we submitted as a final solution.

# Model training

Run one of the following commands to train the respective configurations. `3d_halfiso` yielded significantly better results in our experiments as well as on the final test set and is the recommended configuration.

```bash
nnUNetv2_train 311 3d_halfres all
nnUNetv2_train 311 3d_halfiso all
```

# Inference

Our inference is optimized for efficient single scan prediction. For best performance, we strongly recommend running inference using the default `nnUNetv2_predict` command!

Inference using the provided script requires OpenVINO, which can easily be installed via

```bash
pip install openvino
```

To run inference, run the following command. `model_folder` is the folder containing the training results, for example `nnUNetTrainer__nnUNetPlans__3d_halfiso`. `-save_model` needs to be set to precompile the model once using OpenVINO. If no precompiled model exists, the inference script will fail!
```bash
python inference_flare_task2.py -i input_folder -o output_folder -m model_folder [-save_model]
```

================================================
FILE: documentation/competitions/FLARE24/__init__.py
================================================

================================================
FILE: documentation/competitions/Toothfairy2/__init__.py
================================================

================================================
FILE: documentation/competitions/Toothfairy2/inference_script_semseg_only_customInf2.py
================================================
import argparse
import gc
import os
from pathlib import Path
from queue import Queue
from threading import Thread
from typing import Union, Tuple

import nnunetv2
import numpy as np
import torch
from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice
from acvl_utils.cropping_and_padding.padding import pad_nd_image
from batchgenerators.utilities.file_and_folder_operations import load_json, join
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.inference.sliding_window_prediction import compute_gaussian
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.helpers import empty_cache, dummy_context
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
from torch._dynamo import OptimizedModule
from torch.backends import cudnn
from tqdm import tqdm


class CustomPredictor(nnUNetPredictor):
    def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,
                                 segmentation_previous_stage: np.ndarray = None):
        torch.set_num_threads(7)
        with torch.no_grad():
            self.network = self.network.to(self.device)
            self.network.eval()

            if self.verbose:
                print('preprocessing')
            preprocessor = 
self.configuration_manager.preprocessor_class(verbose=self.verbose) data, _, image_properties = preprocessor.run_case_npy(input_image, None, image_properties, self.plans_manager, self.configuration_manager, self.dataset_json) data = torch.from_numpy(data) del input_image if self.verbose: print('predicting') predicted_logits = self.predict_preprocessed_image(data) if self.verbose: print('Prediction done') segmentation = self.convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, image_properties) return segmentation def initialize_from_trained_model_folder(self, model_training_output_dir: str, use_folds: Union[Tuple[Union[int, str]], None], checkpoint_name: str = 'checkpoint_final.pth'): """ This is used when making predictions with a trained model """ if use_folds is None: use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name) dataset_json = load_json(join(model_training_output_dir, 'dataset.json')) plans = load_json(join(model_training_output_dir, 'plans.json')) plans_manager = PlansManager(plans) if isinstance(use_folds, str): use_folds = [use_folds] parameters = [] for i, f in enumerate(use_folds): f = int(f) if f != 'all' else f checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name), map_location=torch.device('cpu'), weights_only=False) if i == 0: trainer_name = checkpoint['trainer_name'] configuration_name = checkpoint['init_args']['configuration'] inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \ 'inference_allowed_mirroring_axes' in checkpoint.keys() else None parameters.append(join(model_training_output_dir, f'fold_{f}', checkpoint_name)) configuration_manager = plans_manager.get_configuration(configuration_name) # restore network num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json) trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", 
"nnUNetTrainer"), trainer_name, 'nnunetv2.training.nnUNetTrainer') if trainer_class is None: raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. ' f'Please place it there (in any .py file)!') network = trainer_class.build_network_architecture( configuration_manager.network_arch_class_name, configuration_manager.network_arch_init_kwargs, configuration_manager.network_arch_init_kwargs_req_import, num_input_channels, plans_manager.get_label_manager(dataset_json).num_segmentation_heads, enable_deep_supervision=False ) self.plans_manager = plans_manager self.configuration_manager = configuration_manager self.list_of_parameters = parameters self.network = network self.dataset_json = dataset_json self.trainer_name = trainer_name self.allowed_mirroring_axes = inference_allowed_mirroring_axes self.label_manager = plans_manager.get_label_manager(dataset_json) if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \ and not isinstance(self.network, OptimizedModule): print('Using torch.compile') self.network = torch.compile(self.network) @torch.inference_mode(mode=True) def predict_preprocessed_image(self, image): empty_cache(self.device) data_device = torch.device('cpu') predicted_logits_device = torch.device('cpu') gaussian_device = torch.device('cpu') compute_device = torch.device('cuda:0') data, slicer_revert_padding = pad_nd_image(image, self.configuration_manager.patch_size, 'constant', {'value': 0}, True, None) del image slicers = self._internal_get_sliding_window_slicers(data.shape[1:]) empty_cache(self.device) data = data.to(data_device) predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]), dtype=torch.half, device=predicted_logits_device) gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. 
/ 8, value_scaling_factor=10, device=gaussian_device, dtype=torch.float16)

        if not self.allow_tqdm and self.verbose:
            print(f'running prediction: {len(slicers)} steps')

        for p in self.list_of_parameters:
            # network weights have to be updated outside autocast!
            # we are loading parameters on demand instead of loading them upfront. This reduces memory footprint a lot.
            # each set of parameters is only used once on the test set (one image) so run time wise this is almost the
            # same
            # weights_only is an argument of torch.load, not of load_state_dict
            self.network.load_state_dict(
                torch.load(p, map_location=compute_device, weights_only=False)['network_weights'])
            with torch.autocast(self.device.type, enabled=True):
                for sl in tqdm(slicers, disable=not self.allow_tqdm):
                    pred = self._internal_maybe_mirror_and_predict(data[sl][None].to(compute_device))[0].to(
                        predicted_logits_device)
                    pred /= (pred.max() / 100)
                    predicted_logits[sl] += (pred * gaussian)
                    del pred
        empty_cache(self.device)
        return predicted_logits

    def convert_predicted_logits_to_segmentation_with_correct_shape(self, predicted_logits, props):
        old = torch.get_num_threads()
        torch.set_num_threads(7)

        # resample to original shape
        spacing_transposed = [props['spacing'][i] for i in self.plans_manager.transpose_forward]
        current_spacing = self.configuration_manager.spacing if \
            len(self.configuration_manager.spacing) == \
            len(props['shape_after_cropping_and_before_resampling']) else \
            [spacing_transposed[0], *self.configuration_manager.spacing]
        predicted_logits = self.configuration_manager.resampling_fn_probabilities(
            predicted_logits,
            props['shape_after_cropping_and_before_resampling'],
            current_spacing,
            [props['spacing'][i] for i in self.plans_manager.transpose_forward])

        segmentation = None
        pp = None
        try:
            with torch.no_grad():
                pp = predicted_logits.to('cuda:0')
                segmentation = pp.argmax(0).cpu()
                del pp
        except RuntimeError:
            del segmentation, pp
            torch.cuda.empty_cache()
            segmentation = predicted_logits.argmax(0)
        del predicted_logits

        # segmentation may be torch.Tensor but we continue with numpy
        if 
isinstance(segmentation, torch.Tensor): segmentation = segmentation.cpu().numpy() # put segmentation in bbox (revert cropping) segmentation_reverted_cropping = np.zeros(props['shape_before_cropping'], dtype=np.uint8 if len( self.label_manager.foreground_labels) < 255 else np.uint16) slicer = bounding_box_to_slice(props['bbox_used_for_cropping']) segmentation_reverted_cropping[slicer] = segmentation del segmentation # revert transpose segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(self.plans_manager.transpose_backward) torch.set_num_threads(old) return segmentation_reverted_cropping def predict_semseg(im, prop, semseg_trained_model, semseg_folds): # initialize predictors pred_semseg = CustomPredictor( tile_step_size=0.5, use_mirroring=True, use_gaussian=True, perform_everything_on_device=False, allow_tqdm=True ) pred_semseg.initialize_from_trained_model_folder( semseg_trained_model, use_folds=semseg_folds, checkpoint_name='checkpoint_final.pth' ) semseg_pred = pred_semseg.predict_single_npy_array( im, prop, None ) torch.cuda.empty_cache() gc.collect() return semseg_pred def map_labels_to_toothfairy(predicted_seg: np.ndarray) -> np.ndarray: # Create an array that maps the labels directly max_label = 42 mapping = np.arange(max_label + 1) # Define the specific remapping remapping = {19: 21, 20: 22, 21: 23, 22: 24, 23: 25, 24: 26, 25: 27, 26: 28, 27: 31, 28: 32, 29: 33, 30: 34, 31: 35, 32: 36, 33: 37, 34: 38, 35: 41, 36: 42, 37: 43, 38: 44, 39: 45, 40: 46, 41: 47, 42: 48} # Apply the remapping for k, v in remapping.items(): mapping[k] = v return mapping[predicted_seg] def postprocess(prediction_npy, vol_per_voxel, verbose: bool = False): cutoffs = {1: 0.0, 2: 78411.5, 3: 0.0, 4: 0.0, 5: 2800.0, 6: 1216.5, 7: 0.0, 8: 6222.0, 9: 1573.0, 10: 946.0, 11: 0.0, 12: 6783.5, 13: 9469.5, 14: 0.0, 15: 2260.0, 16: 3566.0, 17: 6321.0, 18: 4221.5, 19: 5829.0, 20: 0.0, 21: 0.0, 22: 468.0, 23: 1555.0, 24: 1291.5, 25: 2834.5, 26: 584.5, 27: 0.0, 28: 0.0, 29: 
0.0, 30: 0.0, 31: 1935.5, 32: 0.0, 33: 0.0, 34: 6140.0, 35: 0.0, 36: 0.0, 37: 0.0, 38: 2710.0, 39: 0.0, 40: 0.0, 41: 0.0, 42: 970.0} vol_per_voxel_cutoffs = 0.3 * 0.3 * 0.3 for c in cutoffs.keys(): co = cutoffs[c] if co > 0: mask = prediction_npy == c pred_vol = np.sum(mask) * vol_per_voxel if 0 < pred_vol < (co * vol_per_voxel_cutoffs): prediction_npy[mask] = 0 if verbose: print( f'removed label {c} because predicted volume of {pred_vol} is less than the cutoff {co * vol_per_voxel_cutoffs}') return prediction_npy if __name__ == '__main__': os.environ['nnUNet_compile'] = 'f' parser = argparse.ArgumentParser() parser.add_argument('-i', '--input_folder', type=Path, default="/input/images/cbct/") parser.add_argument('-o', '--output_folder', type=Path, default="/output/images/oral-pharyngeal-segmentation/") parser.add_argument('-sem_mod', '--semseg_trained_model', type=str, default="/opt/app/_trained_model/semseg_trained_model") parser.add_argument('--semseg_folds', type=str, nargs='+', default=[0, 1]) args = parser.parse_args() args.output_folder.mkdir(exist_ok=True, parents=True) semseg_folds = [i if i == 'all' else int(i) for i in args.semseg_folds] semseg_trained_model = args.semseg_trained_model rw = SimpleITKIO() input_files = list(args.input_folder.glob('*.nii.gz')) + list(args.input_folder.glob('*.mha')) for input_fname in input_files: output_fname = args.output_folder / input_fname.name # we start with the instance seg because we can then start converting that while semseg is being predicted # load test image im, prop = rw.read_images([input_fname]) with torch.no_grad(): semseg_pred = predict_semseg(im, prop, semseg_trained_model, semseg_folds) torch.cuda.empty_cache() gc.collect() # now postprocess semseg_pred = postprocess(semseg_pred, np.prod(prop['spacing']), True) semseg_pred = map_labels_to_toothfairy(semseg_pred) # now save rw.write_seg(semseg_pred, output_fname, prop) ================================================ FILE: 
documentation/competitions/Toothfairy2/readme.md
================================================
Authors: \
Fabian Isensee*, Yannick Kirchhoff*, Lars Kraemer, Max Rokuss, Constantin Ulrich, Klaus H. Maier-Hein

*: equal contribution

Author Affiliations:\
Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg \
Helmholtz Imaging

# Introduction

This document describes our submission to the [Toothfairy2 Challenge](https://toothfairy2.grand-challenge.org/toothfairy2/).
Our model is essentially an nnU-Net ResEnc L with the patch size upscaled to 160x320x320 pixels. We disable left/right mirroring and train for 1500 instead of the standard 1000 epochs. Training was either done on 2xA100 40GB or one GH200 96GB.

# Dataset Conversion

Adapt and run the [dataset conversion script](../../../nnunetv2/dataset_conversion/Dataset119_ToothFairy2_All.py). This script just converts the mha files to nifti (smaller file size) and removes the unused label ids.

# Experiment Planning and Preprocessing

## Extract fingerprint:
`nnUNetv2_extract_fingerprint -d 119 -np 48`

## Run planning:
`nnUNetv2_plan_experiment -d 119 -pl nnUNetPlannerResEncL_torchres`

This planner not only uses the ResEncL configuration but also replaces the default resampling scheme with one that is faster (but less precise). Since all images in the challenge (train and test) should already have 0.3x0.3x0.3 spacing, resampling is not required. This is just here as a safety measure. The speed is needed at inference time because Grand Challenge imposes a limit of 10 minutes per case.
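
To illustrate why resampling is effectively a no-op for cases that already have the target spacing, the sketch below derives the resampled grid shape from the spacing ratio. It is a minimal illustration, not nnU-Net's resampling code: the function name `compute_resampled_shape` and the example shapes are assumptions made up for this example.

```python
def compute_resampled_shape(shape, spacing, target_spacing):
    """Return the voxel grid shape after resampling to target_spacing.

    Hypothetical helper for illustration (not part of nnU-Net): each axis is
    scaled by old_spacing / new_spacing and rounded to the nearest integer.
    """
    return tuple(
        int(round(n * old / new))
        for n, old, new in zip(shape, spacing, target_spacing)
    )


# A case that already has 0.3 x 0.3 x 0.3 spacing keeps its shape.
same = compute_resampled_shape((170, 272, 345), (0.3, 0.3, 0.3), (0.3, 0.3, 0.3))

# A made-up case with coarser spacing along the first axis would be upsampled there.
coarse = compute_resampled_shape((100, 272, 345), (0.6, 0.3, 0.3), (0.3, 0.3, 0.3))

print(same, coarse)
```

When the shape does not change, the resampling step has nothing to interpolate, so the choice of resampling scheme only matters for the occasional case with deviating spacing.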
## Edit the plans files
Add the following configuration to the generated plans file:

```json
"3d_fullres_torchres_ps160x320x320_bs2": {
    "inherits_from": "3d_fullres",
    "data_identifier": "nnUNetPlans_3d_fullres_torchres_ctnorm",
    "patch_size": [160, 320, 320],
    "normalization_schemes": ["CTNormalization"],
    "architecture": {
        "network_class_name": "dynamic_network_architectures.architectures.unet.ResidualEncoderUNet",
        "arch_kwargs": {
            "n_stages": 7,
            "features_per_stage": [32, 64, 128, 256, 320, 320, 320],
            "conv_op": "torch.nn.modules.conv.Conv3d",
            "kernel_sizes": [
                [3, 3, 3],
                [3, 3, 3],
                [3, 3, 3],
                [3, 3, 3],
                [3, 3, 3],
                [3, 3, 3],
                [3, 3, 3]
            ],
            "strides": [
                [1, 1, 1],
                [2, 2, 2],
                [2, 2, 2],
                [2, 2, 2],
                [2, 2, 2],
                [2, 2, 2],
                [1, 2, 2]
            ],
            "n_blocks_per_stage": [1, 3, 4, 6, 6, 6, 6],
            "n_conv_per_stage_decoder": [1, 1, 1, 1, 1, 1],
            "conv_bias": true,
            "norm_op": "torch.nn.modules.instancenorm.InstanceNorm3d",
            "norm_op_kwargs": {
                "eps": 1e-05,
                "affine": true
            },
            "dropout_op": null,
            "dropout_op_kwargs": null,
            "nonlin": "torch.nn.LeakyReLU",
            "nonlin_kwargs": {
                "inplace": true
            }
        },
        "_kw_requires_import": [
            "conv_op",
            "norm_op",
            "dropout_op",
            "nonlin"
        ]
    }
}
```

Aside from changing the patch size, this makes the architecture one stage deeper (one more pooling + res blocks), enabling it to make effective use of the larger input.

# Preprocessing
`nnUNetv2_preprocess -d 119 -c 3d_fullres_torchres_ps160x320x320_bs2 -plans_name nnUNetResEncUNetLPlans_torchres -np 48`

# Training
We train two models on all training cases:

```bash
nnUNetv2_train 119 3d_fullres_torchres_ps160x320x320_bs2 all -p nnUNetResEncUNetLPlans_torchres -tr nnUNetTrainer_onlyMirror01_1500ep
nnUNet_results=${nnUNet_results}_2 nnUNetv2_train 119 3d_fullres_torchres_ps160x320x320_bs2 all -p nnUNetResEncUNetLPlans_torchres -tr nnUNetTrainer_onlyMirror01_1500ep
```

Models are trained from scratch.
Note how in the second line we overwrite the nnUNet_results variable in order to be able to train the same model twice without overwriting the results.

We recommend increasing the number of processes used for data augmentation. Otherwise you can run into CPU bottlenecks.
Use `export nnUNet_n_proc_DA=32` or higher (if your system permits!).

# Inference
We ensemble the two models from above. On a technical level we copy the two fold_all folders into one training output directory and rename them to fold_0 and fold_1. This lets us use nnU-Net's cross-validation ensembling strategy, which is more computationally efficient (needed for the time limit on grand-challenge.org).

Run inference with the [inference script](inference_script_semseg_only_customInf2.py).

# Postprocessing
If the prediction of a class on some test case is smaller than the corresponding cutoff size, then it is removed (replaced with background).

Cutoff values were optimized using a five-fold cross-validation on the Toothfairy2 training data. We optimize HD95 and Dice separately. The final cutoff for each class is then the smaller of the two values obtained for these metrics. You can find our volume cutoffs in the inference script as part of our `postprocess` function.

================================================
FILE: documentation/competitions/__init__.py
================================================

================================================
FILE: documentation/convert_msd_dataset.md
================================================
Use `nnUNetv2_convert_MSD_dataset`.

Read `nnUNetv2_convert_MSD_dataset -h` for usage instructions.

================================================
FILE: documentation/dataset_format.md
================================================
# nnU-Net dataset format
The only way to bring your data into nnU-Net is by storing it in a specific format.
Due to nnU-Net's roots in the [Medical Segmentation Decathlon](http://medicaldecathlon.com/) (MSD), its dataset format is heavily inspired by the format used in the MSD but has since diverged from it (see also [here](#how-to-use-decathlon-datasets)).

Datasets consist of three components: raw images, corresponding segmentation maps and a dataset.json file specifying some metadata.

If you are migrating from nnU-Net v1, read [this](#how-to-use-nnu-net-v1-tasks) to convert your existing Tasks.

## What do training cases look like?
Each training case is associated with an identifier, i.e. a unique name for that case. This identifier is used by nnU-Net to connect images with the correct segmentation.

A training case consists of images and their corresponding segmentation.

**Images** is plural because nnU-Net supports arbitrarily many input channels. In order to be as flexible as possible, nnU-Net requires each input channel to be stored in a separate image (with the sole exception being RGB natural images). So these images could for example be a T1 and a T2 MRI (or whatever else you want). The different input channels MUST have the same geometry (same shape, spacing (if applicable) etc.) and must be co-registered (if applicable). Input channels are identified by nnU-Net by their channel identifier: a four-digit integer at the end of the filename, directly before the file ending. Image files must therefore follow the following naming convention: {CASE_IDENTIFIER}_{XXXX}.{FILE_ENDING}. Here, XXXX is the 4-digit modality/channel identifier (should be unique for each modality/channel, e.g., “0000” for T1, “0001” for T2 MRI, …) and FILE_ENDING is the file extension used by your image format (.png, .nii.gz, ...). See below for concrete examples. The dataset.json file connects channel names with the channel identifiers in the 'channel_names' key (see below for details).

Side note: Typically, each channel/modality needs to be stored in a separate file and is accessed with the XXXX channel identifier.
The exception is natural images (RGB; .png), where the three color channels can all be stored in one file (see the [road segmentation](../nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py) dataset as an example).

**Segmentations** must share the same geometry with their corresponding images (same shape etc.). Segmentations are integer maps with each value representing a semantic class. The background must be 0. If there is no background, then do not use the label 0 for something else! Integer values of your semantic classes must be consecutive (0, 1, 2, 3, ...). Of course, not all labels have to be present in each training case. Segmentations are saved as {CASE_IDENTIFIER}.{FILE_ENDING}.

Within a training case, all image geometries (input channels, corresponding segmentation) must match. Between training cases, they can of course differ. nnU-Net takes care of that.

Important: The input channels must be consistent! Concretely, **all images need the same input channels in the same order and all input channels have to be present every time**. This is also true for inference!

## Supported file formats
nnU-Net expects the same file format for images and segmentations! These will also be used for inference. For now, it is thus not possible to train on .png and then run inference on .jpg.

One big change in nnU-Net V2 is the support of multiple input file types. Gone are the days of converting everything to .nii.gz! This is implemented by abstracting the input and output of images + segmentations through `BaseReaderWriter`. nnU-Net comes with a broad collection of Readers+Writers and you can even add your own to support your data format! See [here](../nnunetv2/imageio/readme.md).

As a nice bonus, nnU-Net now also natively supports 2D input images and you no longer have to mess around with conversions to pseudo 3D niftis. Yuck. That was disgusting.
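
The requirement that images and segmentations share one file format can be checked before training with a few lines of Python. The following is a minimal sketch under stated assumptions, not nnU-Net's own integrity check: the helper name `find_files_with_wrong_ending` and the file names are invented for illustration.

```python
def find_files_with_wrong_ending(filenames, file_ending):
    """Return all file names that do not end with the expected file ending.

    Hypothetical helper for illustration only (not part of nnU-Net).
    """
    return [name for name in filenames if not name.endswith(file_ending)]


# Made-up directory listings of an imagesTr and a labelsTr folder.
images = ["case_001_0000.png", "case_002_0000.png", "case_003_0000.jpg"]
labels = ["case_001.png", "case_002.png", "case_003.png"]

offenders = find_files_with_wrong_ending(images + labels, ".png")
print(offenders)
```

Any file reported here would have to be converted to the dataset's file format (losslessly) before it can be used for training or inference.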
Note that internally (for storing and accessing preprocessed images) nnU-Net will use its own file format, irrespective of what the raw data was provided in! This is for performance reasons.

By default, the following file formats are supported:

- NaturalImage2DIO: .png, .bmp, .tif
- NibabelIO: .nii.gz, .nrrd, .mha
- NibabelIOWithReorient: .nii.gz, .nrrd, .mha. This reader will reorient images to RAS!
- SimpleITKIO: .nii.gz, .nrrd, .mha
- Tiff3DIO: .tif, .tiff. 3D tif images! Since TIF does not have a standardized way of storing spacing information, nnU-Net expects each TIF file to be accompanied by an identically named .json file that contains this information (see [here](#datasetjson)).

The file extension lists are not exhaustive and depend on what the backend supports. For example, nibabel and SimpleITK support more than the three given here. The file endings given here are just the ones we tested!

IMPORTANT: nnU-Net can only be used with file formats that use lossless (or no) compression! Because the file format is defined for an entire dataset (and not separately for images and segmentations; this could be a todo for the future), we must ensure that there are no compression artifacts that destroy the segmentation maps. So no .jpg and the like!

## Dataset folder structure
Datasets must be located in the `nnUNet_raw` folder (which you either define when installing nnU-Net or export/set every time you intend to run nnU-Net commands!).
Each segmentation dataset is stored as a separate 'Dataset'. Datasets are associated with a dataset ID, a three digit integer, and a dataset name (which you can freely choose): For example, Dataset005_Prostate has 'Prostate' as dataset name and the dataset id is 5. Datasets are stored in the `nnUNet_raw` folder like this:

    nnUNet_raw/
    ├── Dataset001_BrainTumour
    ├── Dataset002_Heart
    ├── Dataset003_Liver
    ├── Dataset004_Hippocampus
    ├── Dataset005_Prostate
    ├── ...
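
The folder naming scheme above (a three-digit ID followed by a freely chosen name) can be expressed as a small parser. This is a sketch for illustration: the regular expression and the helper names `parse_dataset_folder` and `make_dataset_folder` are assumptions of this example, not functions provided by nnU-Net.

```python
import re

# 'Dataset' + three-digit ID + '_' + dataset name (the name itself may contain underscores).
DATASET_PATTERN = re.compile(r"^Dataset(\d{3})_(.+)$")


def parse_dataset_folder(folder_name):
    """Split a folder name such as 'Dataset005_Prostate' into (id, name)."""
    match = DATASET_PATTERN.match(folder_name)
    if match is None:
        raise ValueError(f"not a valid dataset folder name: {folder_name!r}")
    return int(match.group(1)), match.group(2)


def make_dataset_folder(dataset_id, dataset_name):
    """Build the folder name for a dataset ID and name (ID is zero-padded to three digits)."""
    return f"Dataset{dataset_id:03d}_{dataset_name}"


print(parse_dataset_folder("Dataset005_Prostate"))
print(make_dataset_folder(2, "Heart"))
```

A helper like this is handy when scanning `nnUNet_raw` for IDs that are already taken before adding a custom dataset.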
Within each dataset folder, the following structure is expected:

    Dataset001_BrainTumour/
    ├── dataset.json
    ├── imagesTr
    ├── imagesTs  # optional
    └── labelsTr

When adding your custom dataset, take a look at the [dataset_conversion](../nnunetv2/dataset_conversion) folder and pick an id that is not already taken. IDs 001-010 are for the Medical Segmentation Decathlon.

- **imagesTr** contains the images belonging to the training cases. nnU-Net will perform pipeline configuration, training with cross-validation, as well as finding postprocessing and the best ensemble using this data.
- **imagesTs** (optional) contains the images that belong to the test cases. nnU-Net does not use them! This could just be a convenient location for you to store these images. It is a remnant of the Medical Segmentation Decathlon folder structure.
- **labelsTr** contains the images with the ground truth segmentation maps for the training cases.
- **dataset.json** contains metadata of the dataset.

The scheme introduced [above](#what-do-training-cases-look-like) results in the following folder structure. Given is an example for the first Dataset of the MSD: BrainTumour. This dataset has four input channels: FLAIR (0000), T1w (0001), T1gd (0002) and T2w (0003). Note that the imagesTs folder is optional and does not have to be present.

    nnUNet_raw/Dataset001_BrainTumour/
    ├── dataset.json
    ├── imagesTr
    │   ├── BRATS_001_0000.nii.gz
    │   ├── BRATS_001_0001.nii.gz
    │   ├── BRATS_001_0002.nii.gz
    │   ├── BRATS_001_0003.nii.gz
    │   ├── BRATS_002_0000.nii.gz
    │   ├── BRATS_002_0001.nii.gz
    │   ├── BRATS_002_0002.nii.gz
    │   ├── BRATS_002_0003.nii.gz
    │   ├── ...
    ├── imagesTs
    │   ├── BRATS_485_0000.nii.gz
    │   ├── BRATS_485_0001.nii.gz
    │   ├── BRATS_485_0002.nii.gz
    │   ├── BRATS_485_0003.nii.gz
    │   ├── BRATS_486_0000.nii.gz
    │   ├── BRATS_486_0001.nii.gz
    │   ├── BRATS_486_0002.nii.gz
    │   ├── BRATS_486_0003.nii.gz
    │   ├── ...
    └── labelsTr
        ├── BRATS_001.nii.gz
        ├── BRATS_002.nii.gz
        ├── ...
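To make the file naming in the tree above concrete, here is a minimal sketch that groups image files by case identifier and reports cases with missing channels. It assumes the `{CASE_IDENTIFIER}_{XXXX}.{FILE_ENDING}` pattern visible in the example. The function names are hypothetical and this is not the check nnU-Net itself runs (use `--verify_dataset_integrity` for that).

```python
from collections import defaultdict


def group_channels(filenames, file_ending):
    """Map each case identifier to the sorted channel indices found for it."""
    cases = defaultdict(list)
    for fname in filenames:
        if not fname.endswith(file_ending):
            raise ValueError(f"{fname!r} does not end with {file_ending!r}")
        stem = fname[: -len(file_ending)]
        # The last underscore separates the case identifier from the channel.
        case_id, _, channel = stem.rpartition("_")
        if not case_id or len(channel) != 4 or not channel.isdigit():
            raise ValueError(f"{fname!r} lacks a 4-digit channel identifier")
        cases[case_id].append(int(channel))
    return {case: sorted(channels) for case, channels in cases.items()}


def find_incomplete_cases(filenames, file_ending, num_channels):
    """Return case identifiers that do not have exactly channels 0..num_channels-1."""
    expected = list(range(num_channels))
    grouped = group_channels(filenames, file_ending)
    return sorted(case for case, channels in grouped.items() if channels != expected)


files = [
    "BRATS_001_0000.nii.gz", "BRATS_001_0001.nii.gz",
    "BRATS_001_0002.nii.gz", "BRATS_001_0003.nii.gz",
    "BRATS_002_0000.nii.gz", "BRATS_002_0001.nii.gz",  # two channels missing
]
print(find_incomplete_cases(files, ".nii.gz", num_channels=4))
```

In practice the file list would come from listing `imagesTr`, and the number of channels from the `channel_names` entry of the dataset.json.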
Here is another example of the second dataset of the MSD, which has only one input channel:

    nnUNet_raw/Dataset002_Heart/
    ├── dataset.json
    ├── imagesTr
    │   ├── la_003_0000.nii.gz
    │   ├── la_004_0000.nii.gz
    │   ├── ...
    ├── imagesTs
    │   ├── la_001_0000.nii.gz
    │   ├── la_002_0000.nii.gz
    │   ├── ...
    └── labelsTr
        ├── la_003.nii.gz
        ├── la_004.nii.gz
        ├── ...

Remember: For each training case, all images must have the same geometry to ensure that their pixel arrays are aligned. Also make sure that all your data is co-registered!

See also [dataset format inference](dataset_format_inference.md)!!

## dataset.json
The dataset.json contains metadata that nnU-Net needs for training. We have greatly reduced the number of required fields since version 1!

Here is what the dataset.json should look like, using Dataset005_Prostate from the MSD as an example (the `#` comments are explanations only; JSON does not allow comments, so leave them out of the actual file):

    {
     "channel_names": {  # formerly modalities
       "0": "T2",
       "1": "ADC"
     },
     "labels": {  # THIS IS DIFFERENT NOW!
       "background": 0,
       "PZ": 1,
       "TZ": 2
     },
     "numTraining": 32,
     "file_ending": ".nii.gz",
     "overwrite_image_reader_writer": "SimpleITKIO"  # optional! If not provided nnU-Net will automatically determine the ReaderWriter
     }

The channel_names determine the normalization used by nnU-Net. If a channel is marked as 'CT', then a global normalization based on the intensities in the foreground pixels will be used. If it is something else, per-channel z-scoring will be used. Refer to the methods section in [our paper](https://www.nature.com/articles/s41592-020-01008-z) for more details. nnU-Net v2 introduces a few more normalization schemes to choose from and allows you to define your own, see [here](explanation_normalization.md) for more information.

Important changes relative to nnU-Net v1:
- "modality" is now called "channel_names" to remove strong bias to medical images
- labels are structured differently (name -> int instead of int -> name).
This was needed to support [region-based training](region_based_training.md) - "file_ending" is added to support different input file types - "overwrite_image_reader_writer" optional! Can be used to specify a certain (custom) ReaderWriter class that should be used with this dataset. If not provided, nnU-Net will automatically determine the ReaderWriter - "regions_class_order" only used in [region-based training](region_based_training.md) There is a utility with which you can generate the dataset.json automatically. You can find it [here](../nnunetv2/dataset_conversion/generate_dataset_json.py). See our examples in [dataset_conversion](../nnunetv2/dataset_conversion) for how to use it. And read its documentation! As described above, a json file that contains spacing information is required for TIFF files. An example for a 3D TIFF stack with units corresponding to 7.6 in x and y, 80 in z is: ``` { "spacing": [7.6, 7.6, 80.0] } ``` Within the dataset folder, this file (named `cell6.json` in this example) would be placed in the following folders: nnUNet_raw/Dataset123_Foo/ ├── dataset.json ├── imagesTr │   ├── cell6.json │   └── cell6_0000.tif └── labelsTr ├── cell6.json └── cell6.tif ## How to use nnU-Net v1 Tasks If you are migrating from the old nnU-Net, convert your existing datasets with `nnUNetv2_convert_old_nnUNet_dataset`! Example for migrating a nnU-Net v1 Task: ```bash nnUNetv2_convert_old_nnUNet_dataset /media/isensee/raw_data/nnUNet_raw_data_base/nnUNet_raw_data/Task027_ACDC Dataset027_ACDC ``` Use `nnUNetv2_convert_old_nnUNet_dataset -h` for detailed usage instructions. ## How to use decathlon datasets See [convert_msd_dataset.md](convert_msd_dataset.md) ## How to use 2D data with nnU-Net 2D is now natively supported (yay!). See [here](#supported-file-formats) as well as the example dataset in this [script](../nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py). 
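Coming back to the `dataset.json` fields described above, the following sketch builds the Prostate example with the standard library only and checks the label rules (background is 0, integer values are consecutive). The function `build_dataset_json` is a hypothetical helper written for this illustration. It covers plain integer labels only, not region-based training or ignore labels, and it is not the `generate_dataset_json` utility shipped with nnU-Net.

```python
import json


def build_dataset_json(channel_names, labels, num_training, file_ending):
    """Assemble the required dataset.json fields after basic label checks."""
    if labels.get("background") != 0:
        raise ValueError("the 'background' label must be 0")
    values = sorted(labels.values())
    if values != list(range(len(values))):
        raise ValueError("label values must be consecutive integers starting at 0")
    return {
        # Channel indices are stored as string keys: "0", "1", ...
        "channel_names": {str(i): name for i, name in enumerate(channel_names)},
        "labels": dict(labels),
        "numTraining": num_training,
        "file_ending": file_ending,
    }


prostate = build_dataset_json(
    channel_names=["T2", "ADC"],
    labels={"background": 0, "PZ": 1, "TZ": 2},
    num_training=32,
    file_ending=".nii.gz",
)
print(json.dumps(prostate, indent=4))
```

Writing the result to `nnUNet_raw/DatasetXXX_NAME/dataset.json` with `json.dump` would complete the picture; for real datasets, the utility linked above is the better starting point.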
## How to update an existing dataset When updating a dataset it is best practice to remove the preprocessed data in `nnUNet_preprocessed/DatasetXXX_NAME` to ensure a fresh start. Then replace the data in `nnUNet_raw` and rerun `nnUNetv2_plan_and_preprocess`. Optionally, also remove the results from old trainings. # Example dataset conversion scripts In the `dataset_conversion` folder (see [here](../nnunetv2/dataset_conversion)) are multiple example scripts for converting datasets into nnU-Net format. These scripts cannot be run as they are (you need to open them and change some paths) but they are excellent examples for you to learn how to convert your own datasets into nnU-Net format. Just pick the dataset that is closest to yours as a starting point. The list of dataset conversion scripts is continually updated. If you find that some publicly available dataset is missing, feel free to open a PR to add it! ================================================ FILE: documentation/dataset_format_inference.md ================================================ # Data format for Inference Read the documentation on the overall [data format](dataset_format.md) first! The data format for inference must match the one used for the raw data (**specifically, the images must be in exactly the same format as in the imagesTr folder**). As before, the filenames must start with a unique identifier, followed by a 4-digit modality identifier. Here is an example for two different datasets: 1) Task005_Prostate: This task has 2 modalities, so the files in the input folder must look like this: input_folder ├── prostate_03_0000.nii.gz ├── prostate_03_0001.nii.gz ├── prostate_05_0000.nii.gz ├── prostate_05_0001.nii.gz ├── prostate_08_0000.nii.gz ├── prostate_08_0001.nii.gz ├── ... _0000 has to be the T2 image and _0001 has to be the ADC image (as specified by 'channel_names' in the dataset.json), exactly the same as was used for training. 
2) Task002_Heart: imagesTs ├── la_001_0000.nii.gz ├── la_002_0000.nii.gz ├── la_006_0000.nii.gz ├── ... Task002 only has one modality, so each case only has one _0000.nii.gz file. The segmentations in the output folder will be named {CASE_IDENTIFIER}.nii.gz (omitting the modality identifier). Remember that the file format used for inference (.nii.gz in this example) must be the same as was used for training (and as was specified in 'file_ending' in the dataset.json)! ================================================ FILE: documentation/explanation_logging.md ================================================ # Logging in nnU-Net v2 ## Introduction Logging in nnU-Net is intentionally simple and centralized in `nnunetv2/training/logging/nnunet_logger.py`. The trainer talks to one object, `MetaLogger`, and `MetaLogger` fans out logs to: - `LocalLogger` (always enabled): the source of truth for training curves, checkpoint logging state, and `progress.png` - optional external loggers (currently `WandbLogger`) This keeps training code clean while still allowing external tracking backends. ## Default behaviour Without any setup, nnU-Net uses only `LocalLogger`. Per epoch, it stores: - `mean_fg_dice` and `ema_fg_dice` (EMA is computed automatically) - `dice_per_class_or_region` - `train_losses`, `val_losses` - `lrs` - `epoch_start_timestamps`, `epoch_end_timestamps` From these values, `progress.png` is updated in the fold output folder. On checkpoint save/load, the local logging state is also saved/restored, so curves continue correctly after resume. ## How to enable W&B 1. Install W&B: ```bash pip install wandb ``` 2. Enable the backend via environment variables: ```bash export nnUNet_wandb_enabled=1 export nnUNet_wandb_project=nnunet export nnUNet_wandb_mode=online # or offline ``` 3. Run training normally: ```bash nnUNetv2_train DATASET_NAME_OR_ID 3d_fullres 0 ``` Notes: - `nnUNet_wandb_enabled` accepts `0/1` and `false/true` (case-insensitive). 
Other values raise an error. - When resuming (`--c`), W&B resume metadata in `fold_x/wandb/latest-run` is reused and duplicate older steps are skipped. ## How to integrate a custom logger Add a new logger class with the same minimal interface used by `MetaLogger`: - `update_config(self, config: dict)` - `log(self, key, value, step: int)` - `log_summary(self, key, value)` Example skeleton: ```python class MyLogger: def __init__(self, output_folder, resume): self.output_folder = output_folder self.resume = resume def update_config(self, config: dict): ... def log(self, key, value, step: int): ... def log_summary(self, key, value): ... ``` Then register it in `MetaLogger.__init__` (for example behind an env var switch), similar to how `WandbLogger` is added. Important integration detail: - `MetaLogger.log(...)` always writes to `LocalLogger` first. - If you introduce a brand-new per-epoch key, also add that key to `LocalLogger.my_fantastic_logging`, otherwise the local assertion will fail. ================================================ FILE: documentation/explanation_normalization.md ================================================ # Intensity normalization in nnU-Net The type of intensity normalization applied in nnU-Net can be controlled via the `channel_names` (former `modalities`) entry in the dataset.json. Just like the old nnU-Net, per-channel z-scoring as well as dataset-wide z-scoring based on foreground intensities are supported. However, there have been a few additions as well. Reminder: The `channel_names` entry typically looks like this: "channel_names": { "0": "T2", "1": "ADC" }, It has as many entries as there are input channels for the given dataset. To tell you a secret, nnU-Net does not really care what your channels are called. We just use this to determine what normalization scheme will be used for the given dataset. nnU-Net requires you to specify a normalization strategy for each of your input channels! 
If you enter a channel name that is not in the following list, the default (`zscore`) will be used.

Here is a list of currently available normalization schemes:

- `CT`: Perform CT normalization. Specifically, collect intensity values from the foreground classes (all but the background and ignore) from all training cases, compute the mean, standard deviation as well as the 0.5 and 99.5 percentile of the values. Then clip to the percentiles, followed by subtraction of the mean and division by the standard deviation. The normalization that is applied is the same for each training case (for this input channel). The values used by nnU-Net for normalization are stored in the `foreground_intensity_properties_per_channel` entry in the corresponding plans file. This normalization is suitable for modalities presenting physical quantities such as CT images and ADC maps.
- `noNorm`: do not perform any normalization at all
- `rescale_to_0_1`: rescale the intensities to [0, 1]
- `rgb_to_0_1`: assumes uint8 inputs. Divides by 255 to rescale uint8 to [0, 1]
- `zscore`/anything else: perform z-scoring (subtract the mean and divide by the standard deviation) separately for each training case

**Important:** The nnU-Net default is to perform 'CT' normalization for CT images and 'zscore' for everything else! If you deviate from that path, make sure to benchmark whether that actually improves results!

# How to implement custom normalization strategies?
- Head over to nnunetv2/preprocessing/normalization
- implement a new image normalization class by deriving from ImageNormalization
- register it in nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py:channel_name_to_normalization_mapping. This is where you specify a channel name that should be associated with it
- use it by specifying the correct channel_name

Normalization can only be applied to one channel at a time.
There is currently no way of implementing a normalization scheme that gets multiple channels as input to be used jointly!

================================================
FILE: documentation/explanation_plans_files.md
================================================
# Modifying the nnU-Net Configurations

nnU-Net provides unprecedented out-of-the-box segmentation performance for essentially any dataset we have evaluated it on. That said, there is always room for improvements. A fool-proof strategy for squeezing out the last bit of performance is to start with the default nnU-Net, and then further tune it manually to a concrete dataset at hand.
**This guide is about changes to the nnU-Net configuration you can make via the plans files. It does not cover code extensions of nnU-Net. For that, take a look [here](extending_nnunet.md)**

In nnU-Net V2, plans files are SO MUCH MORE powerful than they were in v1. There are a lot more knobs that you can turn without resorting to hacky solutions or even having to touch the nnU-Net code at all! And as an added bonus: plans files are now also .json files and no longer require users to fiddle with pickle. Just open them in your text editor of choice!

If overwhelmed, look at our [Examples](#examples)!

# plans.json structure

Plans have global and local settings. Global settings are applied to all configurations in that plans file while local settings are attached to a specific configuration.

## Global settings

- `foreground_intensity_properties_per_channel`: Intensity statistics of the foreground regions (all labels except background and ignore label), computed over all training cases. Used by [CT normalization scheme](explanation_normalization.md).
- `image_reader_writer`: Name of the image reader/writer class that should be used with this dataset. You might want to change this if, for example, you would like to run inference with files that have a different file format.
The class that is named here must be located in nnunetv2.imageio!
- `label_manager`: The name of the class that does label handling. Take a look at nnunetv2.utilities.label_handling.LabelManager to see what it does. If you decide to change it, place your version in nnunetv2.utilities.label_handling!
- `transpose_forward`: nnU-Net transposes the input data so that the axes with the highest resolution (lowest spacing) come last. This is because the 2D U-Net operates on the trailing dimensions (more efficient slicing due to internal memory layout of arrays). Future work might move this setting to affect only individual configurations. `transpose_forward` is what numpy.transpose gets as new axis ordering.
- `transpose_backward`: the axis ordering that inverts "transpose_forward"
- \[`original_median_shape_after_transp`\]: just here for your information
- \[`original_median_spacing_after_transp`\]: just here for your information
- \[`plans_name`\]: do not change. Used internally
- \[`experiment_planner_used`\]: just here as metadata so that we know what planner originally generated this file
- \[`dataset_name`\]: do not change. This is the dataset these plans are intended for

## Local settings
Plans also have a `configurations` key in which the actual configurations are stored. `configurations` is again a dictionary, where the keys are the configuration names and the values are the local settings for each configuration.

To better understand the components describing the network topology in our plans files, please read section 6.2 in the [supplementary information](https://static-content.springer.com/esm/art%3A10.1038%2Fs41592-020-01008-z/MediaObjects/41592_2020_1008_MOESM1_ESM.pdf) (page 13) of our paper!
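As a side note on the `transpose_forward` / `transpose_backward` pair from the global settings, the sketch below shows how an axis ordering and its inverse relate. It uses plain lists instead of arrays so that it runs without dependencies. The spacing values and the chosen ordering are made up for illustration, and the helpers `permute` and `invert_permutation` are not nnU-Net functions.

```python
def permute(values, order):
    """Reorder `values` the way numpy.transpose reorders axes: out[i] = values[order[i]]."""
    return [values[i] for i in order]


def invert_permutation(forward):
    """Return the ordering that undoes `forward`."""
    backward = [0] * len(forward)
    for new_axis, old_axis in enumerate(forward):
        backward[old_axis] = new_axis
    return backward


# Hypothetical case: the last axis has the coarsest spacing (5.0).
spacing = [0.7, 0.7, 5.0]

# Move the coarse axis to the front so the high-resolution axes come last.
transpose_forward = [2, 0, 1]
transpose_backward = invert_permutation(transpose_forward)

transposed = permute(spacing, transpose_forward)
restored = permute(transposed, transpose_backward)
print(transposed, transpose_backward, restored)
```

With arrays, the same orderings would be passed to `numpy.transpose`, and applying `transpose_backward` after `transpose_forward` returns the original axis layout.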
Local settings: - `spacing`: the target spacing used in this configuration - `patch_size`: the patch size used for training this configuration - `data_identifier`: the preprocessed data for this configuration will be saved in nnUNet_preprocessed/DATASET_NAME/_data_identifier_. If you add a new configuration, remember to set a unique data_identifier in order to not create conflicts with other configurations (unless you plan to reuse the data from another configuration, for example as is done in the cascade) - `batch_size`: batch size used for training - `batch_dice`: whether to use batch dice (pretend all samples in the batch are one image, compute dice loss over that) or not (each sample in the batch is a separate image, compute dice loss for each sample and average over samples) - `preprocessor_name`: Name of the preprocessor class used for running preprocessing. Class must be located in nnunetv2.preprocessing.preprocessors - `use_mask_for_norm`: whether to use the nonzero mask for normalization or not (relevant for BraTS and the like, probably False for all other datasets). Interacts with ImageNormalization class - `normalization_schemes`: mapping of channel identifier to ImageNormalization class name. ImageNormalization classes must be located in nnunetv2.preprocessing.normalization. Also see [here](explanation_normalization.md) - `resampling_fn_data`: name of resampling function to be used for resizing image data. resampling function must be callable(data, current_spacing, new_spacing, **kwargs). It must be located in nnunetv2.preprocessing.resampling - `resampling_fn_data_kwargs`: kwargs for resampling_fn_data - `resampling_fn_probabilities`: name of resampling function to be used for resizing predicted class probabilities/logits. resampling function must be `callable(data: Union[np.ndarray, torch.Tensor], current_spacing, new_spacing, **kwargs)`. 
It must be located in nnunetv2.preprocessing.resampling - `resampling_fn_probabilities_kwargs`: kwargs for resampling_fn_probabilities - `resampling_fn_seg`: name of resampling function to be used for resizing segmentation maps (integer: 0, 1, 2, 3, etc). resampling function must be callable(data, current_spacing, new_spacing, **kwargs). It must be located in nnunetv2.preprocessing.resampling - `resampling_fn_seg_kwargs`: kwargs for resampling_fn_seg - `network_arch_class_name`: UNet class name, can be used to integrate custom dynamic architectures - `UNet_base_num_features`: The number of starting features for the UNet architecture. Default is 32. Default: Features are doubled with each downsampling - `unet_max_num_features`: Maximum number of features (default: capped at 320 for 3D and 512 for 2d). The purpose is to prevent parameters from exploding too much. - `conv_kernel_sizes`: the convolutional kernel sizes used by nnU-Net in each stage of the encoder. The decoder mirrors the encoder and is therefore not explicitly listed here! The list is as long as `n_conv_per_stage_encoder` has entries - `n_conv_per_stage_encoder`: number of convolutions used per stage (=at a feature map resolution in the encoder) in the encoder. Default is 2. The list has as many entries as the encoder has stages - `n_conv_per_stage_decoder`: number of convolutions used per stage in the decoder. Also see `n_conv_per_stage_encoder` - `num_pool_per_axis`: number of times each of the spatial axes is pooled in the network. Needed to know how to pad image sizes during inference (num_pool = 5 means input must be divisible by 2**5=32) - `pool_op_kernel_sizes`: the pooling kernel sizes (and at the same time strides) for each stage of the encoder - \[`median_image_size_in_voxels`\]: the median size of the images of the training set at the current target spacing. Do not modify this as this is not used. It is just here for your information. 
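Two of the settings above involve simple arithmetic that is easy to get wrong by hand: the divisibility constraint implied by `num_pool_per_axis`, and the feature doubling capped by `unet_max_num_features`. The sketch below illustrates both under the assumptions stated in the list (each pooling halves an axis; features double per stage until the cap). The function names are hypothetical and this is not nnU-Net's own planning code.

```python
def pad_shape_to_divisible(shape, num_pool_per_axis):
    """Round each axis up to the next multiple of 2**num_pool for that axis."""
    padded = []
    for size, num_pool in zip(shape, num_pool_per_axis):
        factor = 2 ** num_pool
        padded.append(-(-size // factor) * factor)  # ceiling division, then scale back
    return padded


def features_per_stage(base_num_features, max_num_features, num_stages):
    """Double the feature count per stage, capped at max_num_features."""
    return [min(base_num_features * 2 ** stage, max_num_features)
            for stage in range(num_stages)]


# Three poolings per axis means every axis must be divisible by 2**3 = 8.
print(pad_shape_to_divisible([36, 50, 35], [3, 3, 3]))

# Base of 32 features with the 3D cap of 320 over six stages.
print(features_per_stage(32, 320, 6))
```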
Special local settings: - `inherits_from`: configurations can inherit from each other. This makes it easy to add new configurations that only differ in a few local settings from another. If using this, remember to set a new `data_identifier` (if needed)! - `previous_stage`: if this configuration is part of a cascade, we need to know what the previous stage (for example the low resolution configuration) was. This needs to be specified here. - `next_stage`: if this configuration is part of a cascade, we need to know what possible subsequent stages are! This is because we need to export predictions in the correct spacing when running the validation. `next_stage` can either be a string or a list of strings # Examples ## Increasing the batch size for large datasets If your dataset is large the training can benefit from larger batch_sizes. To do this, simply create a new configuration in the `configurations` dict "configurations": { "3d_fullres_bs40": { "inherits_from": "3d_fullres", "batch_size": 40 } } No need to change the data_identifier. `3d_fullres_bs40` will just use the preprocessed data from `3d_fullres`. No need to rerun `nnUNetv2_preprocess` because we can use already existing data (if available) from `3d_fullres`. ## Using custom preprocessors If you would like to use a different preprocessor class then this can be specified as follows: "configurations": { "3d_fullres_my_preprocesor": { "inherits_from": "3d_fullres", "preprocessor_name": MY_PREPROCESSOR, "data_identifier": "3d_fullres_my_preprocesor" } } You need to run preprocessing for this new configuration: `nnUNetv2_preprocess -d DATASET_ID -c 3d_fullres_my_preprocesor` because it changes the preprocessing. Remember to set a unique `data_identifier` whenever you make modifications to the preprocessed data! 
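The effect of `inherits_from` in the examples above can be pictured as a dictionary merge: start from the parent configuration and overwrite whatever the child specifies. The sketch below illustrates that semantics only. The function `resolve_configuration` and the values in the example `configurations` dict are made up for this illustration, it does no cycle detection, and it is not taken from the nnU-Net code base.

```python
def resolve_configuration(configurations, name):
    """Return the settings of `name` with all inherited settings filled in."""
    config = dict(configurations[name])
    parent_name = config.pop("inherits_from", None)
    if parent_name is None:
        return config
    # Resolve the parent first, then let the child's own settings win.
    resolved = resolve_configuration(configurations, parent_name)
    resolved.update(config)
    return resolved


# Illustrative values only; real plans files contain many more settings.
configurations = {
    "3d_fullres": {
        "data_identifier": "nnUNetPlans_3d_fullres",
        "batch_size": 2,
        "patch_size": [40, 56, 40],
    },
    "3d_fullres_bs40": {
        "inherits_from": "3d_fullres",
        "batch_size": 40,
    },
}

print(resolve_configuration(configurations, "3d_fullres_bs40"))
```

Because `3d_fullres_bs40` does not set its own `data_identifier`, the resolved configuration points at the preprocessed data of `3d_fullres`, which is why no additional preprocessing run is needed in that example.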
## Change target spacing "configurations": { "3d_fullres_my_spacing": { "inherits_from": "3d_fullres", "spacing": [X, Y, Z], "data_identifier": "3d_fullres_my_spacing" } } You need to run preprocessing for this new configuration: `nnUNetv2_preprocess -d DATASET_ID -c 3d_fullres_my_spacing` because it changes the preprocessing. Remember to set a unique `data_identifier` whenever you make modifications to the preprocessed data! ## Adding a cascade to a dataset where it does not exist Hippocampus is small. It doesn't have a cascade. It also doesn't really make sense to add a cascade here but hey for the sake of demonstration we can do that. We change the following things here: - `spacing`: The lowres stage should operate at a lower resolution - we modify the `median_image_size_in_voxels` entry as a guide for what original image sizes we deal with - we set some patch size that is inspired by `median_image_size_in_voxels` - we need to remember that the patch size must be divisible by 2**num_pool in each axis! 
- network parameters such as kernel sizes, pooling operations are changed accordingly
- we need to specify the name of the next stage
- we need to add the highres stage

This is what this would look like (comparisons with 3d_fullres given as reference):

    "configurations": {
        "3d_lowres": {
            "inherits_from": "3d_fullres",
            "data_identifier": "3d_lowres",
            "spacing": [2.0, 2.0, 2.0], # from [1.0, 1.0, 1.0] in 3d_fullres
            "median_image_size_in_voxels": [18, 25, 18], # from [36, 50, 35]
            "patch_size": [20, 28, 20], # from [40, 56, 40]
            "n_conv_per_stage_encoder": [2, 2, 2], # one less entry than 3d_fullres ([2, 2, 2, 2])
            "n_conv_per_stage_decoder": [2, 2], # one less entry than 3d_fullres
            "num_pool_per_axis": [2, 2, 2], # one less pooling than 3d_fullres in each dimension (3d_fullres: [3, 3, 3])
            "pool_op_kernel_sizes": [[1, 1, 1], [2, 2, 2], [2, 2, 2]], # one less [2, 2, 2]
            "conv_kernel_sizes": [[3, 3, 3], [3, 3, 3], [3, 3, 3]], # one less [3, 3, 3]
            "next_stage": "3d_cascade_fullres" # name of the next stage in the cascade
        },
        "3d_cascade_fullres": { # does not need a data_identifier because we can use the data of 3d_fullres
            "inherits_from": "3d_fullres",
            "previous_stage": "3d_lowres" # name of the previous stage
        }
    }

To better understand the components describing the network topology in our plans files, please read section 6.2 in the [supplementary information](https://static-content.springer.com/esm/art%3A10.1038%2Fs41592-020-01008-z/MediaObjects/41592_2020_1008_MOESM1_ESM.pdf) (page 13) of our paper!

================================================
FILE: documentation/extending_nnunet.md
================================================
# Extending nnU-Net
We hope that the new structure of nnU-Net v2 makes it much more intuitive to modify! We cannot give an extensive tutorial on how each and every bit of it can be modified.
It is better for you to search for the position in the repository where the thing you intend to change is implemented and start working your way through the code from there. Setting breakpoints and debugging into nnU-Net really helps in understanding it and thus will help you make the necessary modifications! Here are some things you might want to read before you start: - Editing nnU-Net configurations through plans files is really powerful now and allows you to change a lot of things regarding preprocessing, resampling, network topology etc. Read [this](explanation_plans_files.md)! - [Image normalization](explanation_normalization.md) and [i/o formats](dataset_format.md#supported-file-formats) are easy to extend! - Manual data splits can be defined as described [here](manual_data_splits.md) - You can chain arbitrary configurations together into cascades, see [this again](explanation_plans_files.md) - Read about our support for [region-based training](region_based_training.md) - If you intend to modify the training procedure (loss, sampling, data augmentation, lr scheduler, etc) then you need to implement your own trainer class. Best practice is to create a class that inherits from nnUNetTrainer and implements the necessary changes. Head over to our [trainer classes folder](../nnunetv2/training/nnUNetTrainer) for inspiration! There will be similar trainers for what you intend to change and you can take them as a guide. nnUNetTrainer are structured similarly to PyTorch lightning trainers, this should also make things easier! - Integrating new network architectures can be done in two ways: - Quick and dirty: implement a new nnUNetTrainer class and overwrite its `build_network_architecture` function. Make sure your architecture is compatible with deep supervision (if not, use `nnUNetTrainerNoDeepSupervision` as basis!) and that it can handle the patch sizes that are thrown at it! Your architecture should NOT apply any nonlinearities at the end (softmax, sigmoid etc). 
nnU-Net does that! - The 'proper' (but difficult) way: Build a dynamically configurable architecture such as the `PlainConvUNet` class used by default. It needs to have some sort of GPU memory estimation method that can be used to evaluate whether certain patch sizes and topologies fit into a specified GPU memory target. Build a new `ExperimentPlanner` that can configure your new class and communicate with its memory budget estimation. Run `nnUNetv2_plan_and_preprocess` while specifying your custom `ExperimentPlanner` and a custom `plans_name`. Implement a nnUNetTrainer that can use the plans generated by your `ExperimentPlanner` to instantiate the network architecture. Specify your plans and trainer when running `nnUNetv2_train`. It always pays off to first read and understand the corresponding nnU-Net code and use it as a template for your implementation! - Remember that multi-GPU training, region-based training, ignore label and cascaded training are now simply integrated into one unified nnUNetTrainer class. No separate classes needed (remember that when implementing your own trainer classes and ensure support for all of these features! Or raise `NotImplementedError`) [//]: # (- Read about our support for [ignore label](ignore_label.md) and [region-based training](region_based_training.md)) ================================================ FILE: documentation/how_to_use_nnunet.md ================================================ ## **2024-04-18 UPDATE: New residual encoder UNet presets available!** The recommended nnU-Net presets have changed! See [here](resenc_presets.md) how to unlock them! ## How to run nnU-Net on a new dataset Given some dataset, nnU-Net fully automatically configures an entire segmentation pipeline that matches its properties. nnU-Net covers the entire pipeline, from preprocessing to model configuration, model training, postprocessing all the way to ensembling. 
After running nnU-Net, the trained model(s) can be applied to the test cases for inference.

### Dataset Format
nnU-Net expects datasets in a structured format. This format is inspired by the data structure of the [Medical Segmentation Decathlon](http://medicaldecathlon.com/). Please read [this](dataset_format.md) for information on how to set up datasets to be compatible with nnU-Net.

**Since version 2 we support multiple image file formats (.nii.gz, .png, .tif, ...)! Read the dataset_format documentation to learn more!**

**Datasets from nnU-Net v1 can be converted to V2 by running `nnUNetv2_convert_old_nnUNet_dataset INPUT_FOLDER OUTPUT_DATASET_NAME`.** Remember that v2 calls datasets DatasetXXX_Name (not Task) where XXX is a 3-digit number. Please provide the **path** to the old task, not just the Task name. nnU-Net V2 doesn't know where v1 tasks were!

### Experiment planning and preprocessing
Given a new dataset, nnU-Net will extract a dataset fingerprint (a set of dataset-specific properties such as image sizes, voxel spacings, intensity information etc). This information is used to design three U-Net configurations. Each of these pipelines operates on its own preprocessed version of the dataset.

The easiest way to run fingerprint extraction, experiment planning and preprocessing is to use:

```bash
nnUNetv2_plan_and_preprocess -d DATASET_ID --verify_dataset_integrity
```

Where `DATASET_ID` is the dataset id (duh). We recommend `--verify_dataset_integrity` whenever it's the first time you run this command. This will check for some of the most common error sources!
Fingerprint extraction and preprocessing show a progress bar by default. Use `--no_pbar` if you want to suppress it, for example in non-interactive cluster environments.

You can also process several datasets at once by giving `-d 1 2 3 [...]`. If you already know what U-Net configuration you need you can also specify that with `-c 3d_fullres` (make sure to adapt -np in this case!).
For more information about all the options available to you please run `nnUNetv2_plan_and_preprocess -h`. nnUNetv2_plan_and_preprocess will create a new subfolder in your nnUNet_preprocessed folder named after the dataset. Once the command is completed there will be a dataset_fingerprint.json file as well as a nnUNetPlans.json file for you to look at (in case you are interested!). There will also be subfolders containing the preprocessed data for your UNet configurations. [Optional] If you prefer to keep things separate, you can also use `nnUNetv2_extract_fingerprint`, `nnUNetv2_plan_experiment` and `nnUNetv2_preprocess` (in that order). `nnUNetv2_extract_fingerprint` and `nnUNetv2_preprocess` also support `--no_pbar`. ### Model training #### Overview You pick which configurations (2d, 3d_fullres, 3d_lowres, 3d_cascade_fullres) should be trained! If you have no idea what performs best on your data, just run all of them and let nnU-Net identify the best one. It's up to you! nnU-Net trains all configurations in a 5-fold cross-validation over the training cases. This is 1) needed so that nnU-Net can estimate the performance of each configuration and tell you which one should be used for your segmentation problem and 2) a natural way of obtaining a good model ensemble (average the output of these 5 models for prediction) to boost performance. You can influence the splits nnU-Net uses for 5-fold cross-validation (see [here](manual_data_splits.md)). If you prefer to train a single model on all training cases, this is also possible (see below). **Note that not all U-Net configurations are created for all datasets. In datasets with small image sizes, the U-Net cascade (and with it the 3d_lowres configuration) is omitted because the patch size of the full resolution U-Net already covers a large part of the input images.** Training models is done with the `nnUNetv2_train` command. 
The general structure of the command is:

```bash
nnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD [additional options, see -h]
```

UNET_CONFIGURATION is a string that identifies the requested U-Net configuration (defaults: 2d, 3d_fullres, 3d_lowres, 3d_cascade_fullres). DATASET_NAME_OR_ID specifies what dataset should be trained on and FOLD specifies which fold of the 5-fold-cross-validation is trained.

nnU-Net stores a checkpoint every 50 epochs. If you need to continue a previous training, just add a `--c` to the training command.

IMPORTANT: If you plan to use `nnUNetv2_find_best_configuration` (see below) add the `--npz` flag. This makes nnU-Net save the softmax outputs during the final validation, which `nnUNetv2_find_best_configuration` needs in order to evaluate ensembles. Exported softmax predictions are very large and therefore can take up a lot of disk space, which is why this is not enabled by default. If you ran initially without the `--npz` flag but now require the softmax predictions, simply rerun the validation with:

```bash
nnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD --val --npz
```

You can specify the device nnU-Net should use by using `-device DEVICE`. DEVICE can only be cpu, cuda or mps. If you have multiple GPUs, please select the gpu id using `CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...]` (requires device to be cuda).

See `nnUNetv2_train -h` for additional options.
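If you would rather generate the per-fold commands than type them, the command structure above can be sketched in Python. This is only an illustration: `build_train_commands` is a hypothetical helper (not part of nnU-Net) that assembles command strings from the flags described above and does not start any training.

```python
import shlex


def build_train_commands(dataset, configuration, folds=range(5), npz=True, device=None):
    """Return one `nnUNetv2_train` command line per fold (hypothetical helper)."""
    commands = []
    for fold in folds:
        # Positional arguments: DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD
        argv = ["nnUNetv2_train", str(dataset), configuration, str(fold)]
        if npz:
            # Needed later if nnUNetv2_find_best_configuration should evaluate ensembles
            argv.append("--npz")
        if device is not None:
            # One of: cpu, cuda, mps
            argv += ["-device", device]
        commands.append(shlex.join(argv))
    return commands


# Print the five commands for the 3d_fullres configuration of dataset 2
for command in build_train_commands(2, "3d_fullres"):
    print(command)
```

The printed lines can be pasted into a shell or a cluster job script, one fold per job.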
### 2D U-Net
For FOLD in [0, 1, 2, 3, 4], run:

```bash
nnUNetv2_train DATASET_NAME_OR_ID 2d FOLD [--npz]
```

### 3D full resolution U-Net
For FOLD in [0, 1, 2, 3, 4], run:

```bash
nnUNetv2_train DATASET_NAME_OR_ID 3d_fullres FOLD [--npz]
```

### 3D U-Net cascade
#### 3D low resolution U-Net
For FOLD in [0, 1, 2, 3, 4], run:

```bash
nnUNetv2_train DATASET_NAME_OR_ID 3d_lowres FOLD [--npz]
```

#### 3D full resolution U-Net
For FOLD in [0, 1, 2, 3, 4], run:

```bash
nnUNetv2_train DATASET_NAME_OR_ID 3d_cascade_fullres FOLD [--npz]
```

**Note that the 3D full resolution U-Net of the cascade requires the five folds of the low resolution U-Net to be completed!**

The trained models will be written to the nnUNet_results folder. Each training obtains an automatically generated output folder name:

    nnUNet_results/DatasetXXX_MYNAME/TRAINER_CLASS_NAME__PLANS_NAME__CONFIGURATION/FOLD

For Dataset002_Heart (from the MSD), for example, this looks like this:

    nnUNet_results/
    ├── Dataset002_Heart
        │── nnUNetTrainer__nnUNetPlans__2d
        │    ├── fold_0
        │    ├── fold_1
        │    ├── fold_2
        │    ├── fold_3
        │    ├── fold_4
        │    ├── dataset.json
        │    ├── dataset_fingerprint.json
        │    └── plans.json
        └── nnUNetTrainer__nnUNetPlans__3d_fullres
             ├── fold_0
             ├── fold_1
             ├── fold_2
             ├── fold_3
             ├── fold_4
             ├── dataset.json
             ├── dataset_fingerprint.json
             └── plans.json

Note that 3d_lowres and 3d_cascade_fullres do not exist here because this dataset did not trigger the cascade. In each model training output folder (each of the fold_x folders), the following files will be created:
- debug.json: Contains a summary of blueprint and inferred parameters used for training this model as well as a bunch of additional stuff. Not easy to read, but very useful for debugging ;-)
- checkpoint_best.pth: checkpoint file of the best model identified during training. Not used right now unless you explicitly tell nnU-Net to use it.
- checkpoint_final.pth: checkpoint file of the final model (after training has ended).
This is what is used for both validation and inference.
- network_architecture.pdf (only if hiddenlayer is installed!): a pdf document with a figure of the network architecture in it.
- progress.png: Shows losses, pseudo dice, learning rate and epoch times over the course of the training. At the top is a plot of the training (blue) and validation (red) loss during training. Also shows an approximation of the dice (green) as well as a moving average of it (dotted green line). This approximation is the average Dice score of the foreground classes. **It needs to be taken with a big (!) grain of salt** because it is computed on randomly drawn patches from the validation data at the end of each epoch, and the aggregation of TP, FP and FN for the Dice computation treats the patches as if they all originate from the same volume ('global Dice'; we do not compute a Dice for each validation case and then average over all cases but pretend that there is only one validation case from which we sample patches). The reason for this is that the 'global Dice' is easy to compute during training and is still quite useful to evaluate whether a model is training at all or not. A proper validation takes way too long to be done each epoch. It is run at the end of the training.
- validation: in this folder are the predicted validation cases after the training has finished. The summary.json file in here contains the validation metrics (a mean over all cases is provided at the start of the file). If `--npz` was set then the compressed softmax outputs (saved as .npz files) are in here as well.

During training it is often useful to watch the progress. We therefore recommend that you have a look at the generated progress.png when running the first training. It will be updated after each epoch.

Training times largely depend on the GPU. The smallest GPU we recommend for training is the Nvidia RTX 2080ti. With that all network trainings take less than 2 days.
Refer to our [benchmarks](benchmarking.md) to see if your system is performing as expected. ### Using multiple GPUs for training If multiple GPUs are at your disposal, the best way of using them is to train multiple nnU-Net trainings at once, one on each GPU. This is because data parallelism never scales perfectly linearly, especially not with small networks such as the ones used by nnU-Net. Example: ```bash CUDA_VISIBLE_DEVICES=0 nnUNetv2_train DATASET_NAME_OR_ID 2d 0 [--npz] & # train on GPU 0 CUDA_VISIBLE_DEVICES=1 nnUNetv2_train DATASET_NAME_OR_ID 2d 1 [--npz] & # train on GPU 1 CUDA_VISIBLE_DEVICES=2 nnUNetv2_train DATASET_NAME_OR_ID 2d 2 [--npz] & # train on GPU 2 CUDA_VISIBLE_DEVICES=3 nnUNetv2_train DATASET_NAME_OR_ID 2d 3 [--npz] & # train on GPU 3 CUDA_VISIBLE_DEVICES=4 nnUNetv2_train DATASET_NAME_OR_ID 2d 4 [--npz] & # train on GPU 4 ... wait ``` **Important: The first time a training is run nnU-Net will extract the preprocessed data into uncompressed numpy arrays for speed reasons! This operation must be completed before starting more than one training of the same configuration! Wait with starting subsequent folds until the first training is using the GPU! Depending on the dataset size and your System this should only take a couple of minutes at most.** If you insist on running DDP multi-GPU training, we got you covered: `nnUNetv2_train DATASET_NAME_OR_ID 2d 0 [--npz] -num_gpus X` Again, note that this will be slower than running separate training on separate GPUs. DDP only makes sense if you have manually interfered with the nnU-Net configuration and are training larger models with larger patch and/or batch sizes! Important when using `-num_gpus`: 1) If you train using, say, 2 GPUs but have more GPUs in the system you need to specify which GPUs should be used via CUDA_VISIBLE_DEVICES=0,1 (or whatever your ids are). 2) You cannot specify more GPUs than you have samples in your minibatches. If the batch size is 2, 2 GPUs is the maximum! 
3) Make sure your batch size is divisible by the numbers of GPUs you use or you will not make good use of your hardware. In contrast to the old nnU-Net, DDP is now completely hassle free. Enjoy! ### Automatically determine the best configuration Once the desired configurations were trained (full cross-validation) you can tell nnU-Net to automatically identify the best combination for you: ```commandline nnUNetv2_find_best_configuration DATASET_NAME_OR_ID -c CONFIGURATIONS ``` `CONFIGURATIONS` hereby is the list of configurations you would like to explore. Per default, ensembling is enabled meaning that nnU-Net will generate all possible combinations of ensembles (2 configurations per ensemble). This requires the .npz files containing the predicted probabilities of the validation set to be present (use `nnUNetv2_train` with `--npz` flag, see above). You can disable ensembling by setting the `--disable_ensembling` flag. See `nnUNetv2_find_best_configuration -h` for more options. nnUNetv2_find_best_configuration will also automatically determine the postprocessing that should be used. Postprocessing in nnU-Net only considers the removal of all but the largest component in the prediction (once for foreground vs background and once for each label/region). Once completed, the command will print to your console exactly what commands you need to run to make predictions. It will also create two files in the `nnUNet_results/DATASET_NAME` folder for you to inspect: - `inference_instructions.txt` again contains the exact commands you need to use for predictions - `inference_information.json` can be inspected to see the performance of all configurations and ensembles, as well as the effect of the postprocessing plus some debug information. 
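To make the postprocessing step described above more concrete, here is a minimal pure-Python sketch of "remove all but the largest component" on a 2D binary mask. It is an illustration under stated assumptions, not nnU-Net's implementation: `keep_largest_component` is a hypothetical name, the mask is a nested list, and 4-connectivity is assumed.

```python
from collections import deque


def keep_largest_component(mask):
    """Return a copy of a 2D binary mask that keeps only its largest 4-connected component."""
    rows = len(mask)
    cols = len(mask[0]) if rows else 0
    seen = [[False] * cols for _ in range(rows)]
    largest = []
    for r in range(rows):
        for c in range(cols):
            if not mask[r][c] or seen[r][c]:
                continue
            # Breadth-first search collects all pixels connected to (r, c)
            component = []
            queue = deque([(r, c)])
            seen[r][c] = True
            while queue:
                y, x = queue.popleft()
                component.append((y, x))
                for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
                    if 0 <= ny < rows and 0 <= nx < cols and mask[ny][nx] and not seen[ny][nx]:
                        seen[ny][nx] = True
                        queue.append((ny, nx))
            if len(component) > len(largest):
                largest = component
    # Write only the largest component into an otherwise empty mask
    cleaned = [[0] * cols for _ in range(rows)]
    for y, x in largest:
        cleaned[y][x] = 1
    return cleaned


mask = [
    [1, 1, 0, 0],
    [1, 0, 0, 1],
    [0, 0, 0, 0],
]
print(keep_largest_component(mask))  # the isolated pixel at row 1, column 3 is removed
```

Whether this removal actually helps is determined on the cross-validation results, once for all foreground vs background and once for each label/region, as described above.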
### Run inference
Remember that the data located in the input folder must have the same file endings as the dataset you trained the model on and must adhere to the nnU-Net naming scheme for image files (see [dataset format](dataset_format.md) and [inference data format](dataset_format_inference.md)!)

`nnUNetv2_find_best_configuration` (see above) will print a string to the terminal with the inference commands you need to use. The easiest way to run inference is to simply use these commands.

If you wish to manually specify the configuration(s) used for inference, use the following commands:

#### Run prediction

For each of the desired configurations, run:

```
nnUNetv2_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -d DATASET_NAME_OR_ID -c CONFIGURATION --save_probabilities
```

Only specify `--save_probabilities` if you intend to use ensembling. `--save_probabilities` will make the command save the predicted probabilities alongside the predicted segmentation masks, which requires a lot of disk space.

Please select a separate `OUTPUT_FOLDER` for each configuration!

Note that per default, inference will be done with all 5 folds from the cross-validation as an ensemble. We very strongly recommend you use all 5 folds. Thus, all 5 folds must have been trained prior to running inference.

If you wish to make predictions with a single model, train the `all` fold and specify it in `nnUNetv2_predict` with `-f all`.

#### Ensembling multiple configurations
If you wish to ensemble multiple predictions (typically from different configurations), you can do so with the following command:

```bash
nnUNetv2_ensemble -i FOLDER1 FOLDER2 ... -o OUTPUT_FOLDER -np NUM_PROCESSES
```

You can specify an arbitrary number of folders, but remember that each folder needs to contain npz files that were generated by `nnUNetv2_predict`. Again, `nnUNetv2_ensemble -h` will tell you more about additional options.
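Conceptually, ensembling combines the saved class probabilities of several predictions into one segmentation. The sketch below assumes that this amounts to averaging the probabilities per class and voxel and then taking the most probable class. The function name `ensemble_argmax` and the nested-list data layout are made up for illustration and are not the format of the .npz files.

```python
def ensemble_argmax(probability_maps):
    """Average class probabilities over models and return the most probable class per voxel.

    probability_maps: one entry per model, each shaped [num_classes][num_voxels].
    """
    num_models = len(probability_maps)
    num_classes = len(probability_maps[0])
    num_voxels = len(probability_maps[0][0])
    # Mean probability for every class and voxel across all ensemble members
    mean = [
        [sum(model[c][v] for model in probability_maps) / num_models for v in range(num_voxels)]
        for c in range(num_classes)
    ]
    # Pick the class with the highest mean probability for each voxel
    return [max(range(num_classes), key=lambda c: mean[c][v]) for v in range(num_voxels)]


# Two models, two classes, two voxels (rows are classes, columns are voxels)
model_a = [[0.9, 0.4], [0.1, 0.6]]
model_b = [[0.7, 0.7], [0.3, 0.3]]
print(ensemble_argmax([model_a, model_b]))
```

In this toy example `model_a` alone would assign class 1 to the second voxel, while the averaged prediction assigns class 0 to both voxels.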
#### Apply postprocessing Finally, apply the previously determined postprocessing to the (ensembled) predictions: ```commandline nnUNetv2_apply_postprocessing -i FOLDER_WITH_PREDICTIONS -o OUTPUT_FOLDER --pp_pkl_file POSTPROCESSING_FILE -plans_json PLANS_FILE -dataset_json DATASET_JSON_FILE ``` `nnUNetv2_find_best_configuration` (or its generated `inference_instructions.txt` file) will tell you where to find the postprocessing file. If not you can just look for it in your results folder (it's creatively named `postprocessing.pkl`). If your source folder is from an ensemble, you also need to specify a `-plans_json` file and a `-dataset_json` file that should be used (for single configuration predictions these are automatically copied from the respective training). You can pick these files from any of the ensemble members. ## How to run inference with pretrained models See [here](run_inference_with_pretrained_models.md) ## How to Deploy and Run Inference with YOUR Pretrained Models To facilitate the use of pretrained models on a different computer for inference purposes, follow these streamlined steps: 1. Exporting the Model: Utilize the `nnUNetv2_export_model_to_zip` function to package your trained model into a .zip file. This file will contain all necessary model files. 2. Transferring the Model: Transfer the .zip file to the target computer where inference will be performed. 3. Importing the Model: On the new PC, use the `nnUNetv2_install_pretrained_model_from_zip` to load the pretrained model from the .zip file. Please note that both computers must have nnU-Net installed along with all its dependencies to ensure compatibility and functionality of the model. [//]: # (## Examples) [//]: # () [//]: # (To get you started we compiled two simple to follow examples:) [//]: # (- run a training with the 3d full resolution U-Net on the Hippocampus dataset. See [here](documentation/training_example_Hippocampus.md).) 
[//]: # (- run inference with nnU-Net's pretrained models on the Prostate dataset. See [here](documentation/inference_example_Prostate.md).)

[//]: # ()
[//]: # (Usability not good enough? Let us know!)

================================================
FILE: documentation/ignore_label.md
================================================
# Ignore Label

The _ignore label_ can be used to mark regions that should be ignored by nnU-Net. This can be used to learn from images where only sparse annotations are available, for example in the form of scribbles or a limited amount of annotated slices. Internally, this is accomplished by using partial losses, i.e. losses that are only computed on annotated pixels while ignoring the rest. Take a look at our [`DC_and_BCE_loss` loss](../nnunetv2/training/loss/compound_losses.py) to see how this is done. During inference (validation and prediction), nnU-Net will always predict dense segmentations. Metric computation in validation is of course only done on annotated pixels.

Sparse annotations can be used to train a model for application to new, unseen images or to autocomplete the provided training cases given the sparse labels.

(See our [paper](https://arxiv.org/abs/2403.12834) for more information)

Typical use-cases for the ignore label are:
- Save annotation time through sparse annotation schemes
  - Annotation of all or a subset of slices with scribbles (Scribble Supervision)
  - Dense annotation of a subset of slices
  - Dense annotation of chosen patches/cubes within an image
- Coarsely masking out faulty segmentations in the reference segmentations
- Masking areas for other reasons

If you are using nnU-Net's ignore label, please cite the following paper in addition to the original nnU-net paper: ``` Gotkowski, K., Lüth, C., Jäger, P. F., Ziegler, S., Krämer, L., Denner, S., Xiao, S., Disch, N., H., K., & Isensee, F. (2024). Embarrassingly Simple Scribble Supervision for 3D Medical Segmentation. ArXiv.
/abs/2403.12834 ```

## Usecases

### Scribble Supervision

Scribbles are free-form drawings to coarsely annotate an image. As we have demonstrated in our recent [paper](https://arxiv.org/abs/2403.12834), nnU-Net's partial loss implementation enables state-of-the-art learning from partially annotated data and even surpasses many purpose-built methods for learning from scribbles. As a starting point, for each image slice and each class (including background), an interior and a border scribble should be generated:

- Interior Scribble: A scribble placed randomly within the class interior of a class instance
- Border Scribble: A scribble roughly delineating a small part of the class border of a class instance

An example of such scribble annotations is depicted in Figure 1 and an animation in Animation 1.
Depending on the availability of data and their variability it is also possible to annotate only a subset of selected slices.

Figure 1: Examples of segmentation types with (A) depicting a dense segmentation and (B) a scribble segmentation.

Animation 1: Depiction of a dense segmentation and a scribble annotation. Background scribbles have been excluded for better visualization.

### Dense annotation of a subset of slices Another form of sparse annotation is the dense annotation of a subset of slices. These slices should be selected by the user either randomly, based on visual class variation between slices or in an active learning setting. An example with only 10% of slices annotated is depicted in Figure 2.

Figure 2: Examples of a dense annotation of a subset of slices. The ignored areas are shown in red.
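If you want to simulate this annotation scheme from an existing dense segmentation, the idea can be sketched as follows. This is a hedged illustration, not an nnU-Net utility: `keep_annotated_slices` is a hypothetical helper, the volume is a nested list of 2D slices, and the ignore label value 3 assumes a dataset whose real labels are 0, 1 and 2.

```python
def keep_annotated_slices(segmentation, annotated_slices, ignore_label):
    """Keep the dense labels of the selected slices and set every other slice to the ignore label.

    segmentation: list of 2D slices, each a list of rows of integer labels.
    """
    keep = set(annotated_slices)
    sparse = []
    for index, slice_ in enumerate(segmentation):
        if index in keep:
            # Annotated slice: copy the dense labels unchanged
            sparse.append([row[:] for row in slice_])
        else:
            # Unannotated slice: every pixel is marked as ignore
            sparse.append([[ignore_label] * len(row) for row in slice_])
    return sparse


# Ten slices with labels 0..2; only slice 4 (10% of the slices) stays annotated
dense = [[[0, 1], [1, 2]] for _ in range(10)]
sparse = keep_annotated_slices(dense, annotated_slices=[4], ignore_label=3)
print(sparse[4], sparse[0])
```

Which slices to keep (random, based on visual variation, or chosen by active learning) is up to you; the helper only applies a given selection.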

## Usage within nnU-Net

Usage of the ignore label in nnU-Net is straightforward and only requires the definition of an _ignore_ label in the _dataset.json_. This ignore label MUST be the highest integer label value in the segmentation. For example, given background and two foreground classes (labels 0, 1 and 2), the ignore label must be 3. The ignore label must be named _ignore_ in the _dataset.json_. Given the BraTS dataset as an example the labels dict of the _dataset.json_ must look like this:

```python
...
"labels": {
    "background": 0,
    "edema": 1,
    "non_enhancing_and_necrosis": 2,
    "enhancing_tumor": 3,
    "ignore": 4
},
...
```

Of course, the ignore label is compatible with [region-based training](region_based_training.md):

```python
...
"labels": {
    "background": 0,
    "whole_tumor": [1, 2, 3],
    "tumor_core": [2, 3],
    "enhancing_tumor": 3,  # or [3]
    "ignore": 4
},
"regions_class_order": [1, 2, 3],  # don't declare ignore label here! It is not predicted
...
```

Then use the dataset as you would any other.

Remember that nnU-Net runs a cross-validation. Thus, it will also evaluate on your partially annotated data. This will of course work! If you wish to compare different sparse annotation strategies (through simulations for example), we recommend evaluating on densely annotated images by running inference and then using `nnUNetv2_evaluate_folder` or `nnUNetv2_evaluate_simple`.

================================================
FILE: documentation/installation_instructions.md
================================================
# System requirements

## Operating System
nnU-Net has been tested on Linux (Ubuntu 18.04, 20.04, 22.04; centOS, RHEL), Windows and MacOS! It should work out of the box!

## Hardware requirements
We support GPU (recommended), CPU and Apple M1/M2 as devices (currently Apple mps does not implement 3D convolutions, so you might have to use the CPU on those devices).
### Hardware requirements for Training We recommend you use a GPU for training as this will take a really long time on CPU or MPS (Apple M1/M2). For training a GPU with at least 10 GB (popular non-datacenter options are the RTX 2080ti, RTX 3080/3090 or RTX 4080/4090) is required. We also recommend a strong CPU to go along with the GPU. 6 cores (12 threads) are the bare minimum! CPU requirements are mostly related to data augmentation and scale with the number of input channels and target structures. Plus, the faster the GPU, the better the CPU should be! ### Hardware Requirements for inference Again we recommend a GPU to make predictions as this will be substantially faster than the other options. However, inference times are typically still manageable on CPU and MPS (Apple M1/M2). If using a GPU, it should have at least 4 GB of available (unused) VRAM. ### Example hardware configurations Example workstation configurations for training: - CPU: Ryzen 5800X - 5900X or 7900X would be even better! We have not yet tested Intel Alder/Raptor lake but they will likely work as well. - GPU: RTX 3090 or RTX 4090 - RAM: 64GB - Storage: SSD (M.2 PCIe Gen 3 or better!) Example Server configuration for training: - CPU: 2x AMD EPYC7763 for a total of 128C/256T. 16C/GPU are highly recommended for fast GPUs such as the A100! - GPU: 8xA100 PCIe (price/performance superior to SXM variant + they use less power) - RAM: 1 TB - Storage: local SSD storage (PCIe Gen 3 or better) or ultra fast network storage (nnU-net by default uses one GPU per training. The server configuration can run up to 8 model trainings simultaneously) ### Setting the correct number of Workers for data augmentation (training only) Note that you will need to manually set the number of processes nnU-Net uses for data augmentation according to your CPU/GPU ratio. For the server above (256 threads for 8 GPUs), a good value would be 24-30. 
You can do this by setting the `nnUNet_n_proc_DA` environment variable (`export nnUNet_n_proc_DA=XX`). Recommended values (assuming a recent CPU with good IPC) are 10-12 for RTX 2080 ti, 12 for a RTX 3090, 16-18 for RTX 4090, 28-32 for A100. Optimal values may vary depending on the number of input channels/modalities and number of classes. # Installation instructions We strongly recommend that you install nnU-Net in a virtual environment! Pip or anaconda are both fine. If you choose to compile PyTorch from source (see below), you will need to use conda instead of pip. Use a recent version of Python! 3.9 or newer is guaranteed to work! **nnU-Net v2 can coexist with nnU-Net v1! Both can be installed at the same time.** 1) Install [PyTorch](https://pytorch.org/get-started/locally/) as described on their website (conda/pip). Please install the latest version with support for your hardware (cuda, mps, cpu). **DO NOT JUST `pip install nnunetv2` WITHOUT PROPERLY INSTALLING PYTORCH FIRST**. For maximum speed, consider [compiling pytorch yourself](https://github.com/pytorch/pytorch#from-source) (experienced users only!). 2) Install nnU-Net depending on your use case: 1) For use as **standardized baseline**, **out-of-the-box segmentation algorithm** or for running **inference with pretrained models**: ```pip install nnunetv2``` 2) For use as integrative **framework** (this will create a copy of the nnU-Net code on your computer so that you can modify it as needed): ```bash git clone https://github.com/MIC-DKFZ/nnUNet.git cd nnUNet pip install -e . ``` 3) nnU-Net needs to know where you intend to save raw data, preprocessed data and trained models. For this you need to set a few environment variables. Please follow the instructions [here](setting_up_paths.md). 4) (OPTIONAL) Install [hiddenlayer](https://github.com/waleedka/hiddenlayer). 
hiddenlayer enables nnU-net to generate plots of the network topologies it generates (see [Model training](how_to_use_nnunet.md#model-training)). To install hiddenlayer, run the following command:

```bash
pip install --upgrade git+https://github.com/FabianIsensee/hiddenlayer.git
```

Installing nnU-Net will add several new commands to your terminal. These commands are used to run the entire nnU-Net pipeline. You can execute them from any location on your system. All nnU-Net commands have the prefix `nnUNetv2_` for easy identification.

Note that these commands simply execute python scripts. If you installed nnU-Net in a virtual environment, this environment must be activated when executing the commands. You can see what scripts/functions are executed by checking the project.scripts in the [pyproject.toml](../pyproject.toml) file.

All nnU-Net commands have a `-h` option which gives information on how to use them.

================================================
FILE: documentation/manual_data_splits.md
================================================
# How to generate custom splits in nnU-Net

Sometimes, the default 5-fold cross-validation split by nnU-Net does not fit a project. Maybe you want to run 3-fold cross-validation instead? Or maybe your training cases cannot be split randomly and require careful stratification. Fear not, for nnU-Net has got you covered (it really can do anything <3).

The splits nnU-Net uses are generated in the `do_split` function of nnUNetTrainer. This function will first look for existing splits, stored as a file, and if no split exists it will create one. So if you wish to influence the split, manually creating a split file that will then be recognized and used is the way to go!

The split file is located in the `nnUNet_preprocessed/DATASETXXX_NAME` folder. So it is best practice to first populate this folder by running `nnUNetv2_plan_and_preprocess`.

Splits are stored as a .json file. They are a simple python list.
The length of that list is the number of splits it contains (so it's 5 in the default nnU-Net). Each list entry is a dictionary with keys 'train' and 'val'. Values are again simply lists with the train identifiers in each set. To illustrate this, I am just messing with the Dataset002 file as an example: ```commandline In [1]: from batchgenerators.utilities.file_and_folder_operations import load_json In [2]: splits = load_json('splits_final.json') In [3]: len(splits) Out[3]: 5 In [4]: splits[0].keys() Out[4]: dict_keys(['train', 'val']) In [5]: len(splits[0]['train']) Out[5]: 16 In [6]: len(splits[0]['val']) Out[6]: 4 In [7]: print(splits[0]) {'train': ['la_003', 'la_004', 'la_005', 'la_009', 'la_010', 'la_011', 'la_014', 'la_017', 'la_018', 'la_019', 'la_020', 'la_022', 'la_023', 'la_026', 'la_029', 'la_030'], 'val': ['la_007', 'la_016', 'la_021', 'la_024']} ``` If you are still not sure what splits are supposed to look like, simply download some reference dataset from the [Medical Decathlon](http://medicaldecathlon.com/), start some training (to generate the splits) and manually inspect the .json file with your text editor of choice! In order to generate your custom splits, all you need to do is reproduce the data structure explained above and save it as `splits_final.json` in the `nnUNet_preprocessed/DATASETXXX_NAME` folder. Then use `nnUNetv2_train` etc. as usual. ================================================ FILE: documentation/pretraining_and_finetuning.md ================================================ # Pretraining with nnU-Net ## Intro So far nnU-Net only supports supervised pre-training, meaning that you train a regular nnU-Net on some pretraining dataset and then use the final network weights as initialization for your target dataset. As a reminder, many training hyperparameters such as patch size and network topology differ between datasets as a result of the automated dataset analysis and experiment planning nnU-Net is known for. 
So, out of the box, it is not possible to simply take the network weights from some dataset and then reuse them for another. Consequently, the plans need to be aligned between the two tasks. In this README we show how this can be achieved and how the resulting weights can then be used for initialization. ### Terminology Throughout this README we use the following terminology: - `pretraining dataset` is the dataset you intend to run the pretraining on - `finetuning dataset` is the dataset you are interested in; the one you wish to fine tune on ## Training on the pretraining dataset In order to obtain matching network topologies we need to transfer the plans from one dataset to another. Since we are only interested in the finetuning dataset, we first need to run experiment planning (and preprocessing) for it: ```bash nnUNetv2_plan_and_preprocess -d FINETUNING_DATASET ``` Then we need to extract the dataset fingerprint of the pretraining dataset, if not yet available: ```bash nnUNetv2_extract_fingerprint -d PRETRAINING_DATASET ``` Now we can take the plans from the finetuning dataset and transfer it to the pretraining dataset: ```bash nnUNetv2_move_plans_between_datasets -s FINETUNING_DATASET -t PRETRAINING_DATASET -sp FINETUNING_PLANS_IDENTIFIER -tp PRETRAINING_PLANS_IDENTIFIER ``` `FINETUNING_PLANS_IDENTIFIER` is hereby probably nnUNetPlans unless you changed the experiment planner in nnUNetv2_plan_and_preprocess. For `PRETRAINING_PLANS_IDENTIFIER` we recommend you set something custom in order to not overwrite default plans. Note that EVERYTHING is transferred between the datasets. Not just the network topology, batch size and patch size but also the normalization scheme! Therefore, a transfer between datasets that use different normalization schemes may not work well (but it could, depending on the schemes!). Note on CT normalization: Yes, also the clip values, mean and std are transferred! 
Now you can run the preprocessing on the pretraining dataset: ```bash nnUNetv2_preprocess -d PRETRAINING_DATASET -plans_name PRETRAINING_PLANS_IDENTIFIER ``` And run the training as usual: ```bash nnUNetv2_train PRETRAINING_DATASET CONFIG all -p PRETRAINING_PLANS_IDENTIFIER ``` Note how we use the 'all' fold to train on all available data. For pretraining it does not make sense to split the data. ## Using pretrained weights Once pretraining is completed (or you obtain compatible weights by other means) you can use them to initialize your model: ```bash nnUNetv2_train FINETUNING_DATASET CONFIG FOLD -pretrained_weights PATH_TO_CHECKPOINT ``` Specify the checkpoint in PATH_TO_CHECKPOINT. When loading pretrained weights, all layers except the segmentation layers will be used! So far there are no specific nnUNet trainers for fine tuning, so the current recommendation is to just use nnUNetTrainer. You can however easily write your own trainers with learning rate ramp up, fine-tuning of segmentation heads or shorter training time. ================================================ FILE: documentation/region_based_training.md ================================================ # Region-based training ## What is this about? In some segmentation tasks, most prominently the [Brain Tumor Segmentation Challenge](http://braintumorsegmentation.org/), the target areas (based on which the metric will be computed) are different from the labels provided in the training data. This is the case because for some clinical applications, it is more relevant to detect the whole tumor, tumor core and enhancing tumor instead of the individual labels (edema, necrosis and non-enhancing tumor, enhancing tumor). The figure shows an example BraTS case along with label-based representation of the task (top) and region-based representation (bottom). The challenge evaluation is done on the regions. 
As we have shown in our [BraTS 2018 contribution](https://arxiv.org/abs/1809.10483), directly optimizing those overlapping areas over the individual labels yields better scoring models!

## What can nnU-Net do?
nnU-Net's region-based training allows you to learn areas that are constructed by merging individual labels. For some segmentation tasks this provides a benefit, as this shifts the importance allocated to different labels during training. Most prominently, this feature can be used to represent **hierarchical classes**, for example when organs + substructures are to be segmented. Imagine a liver segmentation problem, where vessels and tumors are also to be segmented. The first target region could thus be the entire liver (including the substructures), while the remaining targets are the individual substructures.

Important: nnU-Net still requires integer label maps as input and will produce integer label maps as output! Region-based training can be used to learn overlapping labels, but there must be a way to model these overlaps for nnU-Net to work (see below how this is done).

## How do you use it?

When declaring the labels in the `dataset.json` file, BraTS would typically look like this:

```python
...
"labels": {
    "background": 0,
    "edema": 1,
    "non_enhancing_and_necrosis": 2,
    "enhancing_tumor": 3
},
...
```
(we use different int values than the challenge because nnU-Net needs consecutive integers!)

This representation corresponds to the upper row in the figure above.

For region-based training, the labels need to be changed to the following:

```python
...
"labels": {
    "background": 0,
    "whole_tumor": [1, 2, 3],
    "tumor_core": [2, 3],
    "enhancing_tumor": 3  # or [3]
},
"regions_class_order": [1, 2, 3],
...
```
This corresponds to the bottom row in the figure above. Note how an additional entry in the dataset.json is required: `regions_class_order`. This tells nnU-Net how to convert the region representations back to an integer map.
It essentially just tells nnU-Net what labels to place for which region in what order. The length of the list here needs to be the same as the number of regions (excl background). Each element in the list corresponds to the label that is placed instead of the region into the final segmentation. Later entries will overwrite earlier ones! Concretely, for the example given here, nnU-Net will firstly place the label 1 (edema) where the 'whole_tumor' region was predicted, then place the label 2 (non-enhancing tumor and necrosis) where the "tumor_core" was predicted and finally place the label 3 in the predicted 'enhancing_tumor' area. With each step, part of the previously set pixels will be overwritten with the new label! So when setting your `regions_class_order`, place encompassing regions (like whole tumor etc) first, followed by substructures. **IMPORTANT** Because the conversion back to a segmentation map is sensitive to the order in which the regions are declared ("place label X in the first region") you need to make sure that this order is not perturbed! When automatically generating the dataset.json, make sure the dictionary keys do not get sorted alphabetically! Set `sort_keys=False` in `json.dump()`!!! nnU-Net will perform the evaluation + model selection also on the regions, not the individual labels! That's all. Easy, huh? ================================================ FILE: documentation/resenc_presets.md ================================================ # Residual Encoder Presets in nnU-Net When using these presets, please cite our recent paper on the need for rigorous validation in 3D medical image segmentation: > Isensee, F.* , Wald, T.* , Ulrich, C.* , Baumgartner, M.* , Roy, S., Maier-Hein, K., Jaeger, P. (2024). nnU-Net Revisited: A Call for Rigorous Validation in 3D Medical Image Segmentation. arXiv preprint arXiv:2404.09556. 
*: shared first authors\
†: shared last authors

[PAPER LINK](https://arxiv.org/pdf/2404.09556.pdf)

Residual Encoder UNets have been supported by nnU-Net since our participation in KiTS2019, but have flown under the radar.
This is bound to change with our new nnUNetResEncUNet presets :raised_hands:! Especially on large datasets such as KiTS2023 and AMOS2022 they offer improved segmentation performance!

|                        | BTCV  | ACDC  | LiTS  | BraTS  | KiTS  | AMOS  | VRAM  | RT  | Arch. | nnU |
|------------------------|-------|-------|-------|--------|-------|-------|-------|-----|-------|-----|
|                        | n=30  | n=200 | n=131 | n=1251 | n=489 | n=360 |       |     |       |     |
| nnU-Net (org.) [1]     | 83.08 | 91.54 | 80.09 | 91.24  | 86.04 | 88.64 | 7.70  | 9   | CNN   | Yes |
| nnU-Net ResEnc M       | 83.31 | 91.99 | 80.75 | 91.26  | 86.79 | 88.77 | 9.10  | 12  | CNN   | Yes |
| nnU-Net ResEnc L       | 83.35 | 91.69 | 81.60 | 91.13  | 88.17 | 89.41 | 22.70 | 35  | CNN   | Yes |
| nnU-Net ResEnc XL      | 83.28 | 91.48 | 81.19 | 91.18  | 88.67 | 89.68 | 36.60 | 66  | CNN   | Yes |
| MedNeXt L k3 [2]       | 84.70 | 92.65 | 82.14 | 91.35  | 88.25 | 89.62 | 17.30 | 68  | CNN   | Yes |
| MedNeXt L k5 [2]       | 85.04 | 92.62 | 82.34 | 91.50  | 87.74 | 89.73 | 18.00 | 233 | CNN   | Yes |
| STU-Net S [3]          | 82.92 | 91.04 | 78.50 | 90.55  | 84.93 | 88.08 | 5.20  | 10  | CNN   | Yes |
| STU-Net B [3]          | 83.05 | 91.30 | 79.19 | 90.85  | 86.32 | 88.46 | 8.80  | 15  | CNN   | Yes |
| STU-Net L [3]          | 83.36 | 91.31 | 80.31 | 91.26  | 85.84 | 89.34 | 26.50 | 51  | CNN   | Yes |
| SwinUNETR [4]          | 78.89 | 91.29 | 76.50 | 90.68  | 81.27 | 83.81 | 13.10 | 15  | TF    | Yes |
| SwinUNETRV2 [5]        | 80.85 | 92.01 | 77.85 | 90.74  | 84.14 | 86.24 | 13.40 | 15  | TF    | Yes |
| nnFormer [6]           | 80.86 | 92.40 | 77.40 | 90.22  | 75.85 | 81.55 | 5.70  | 8   | TF    | Yes |
| CoTr [7]               | 81.95 | 90.56 | 79.10 | 90.73  | 84.59 | 88.02 | 8.20  | 18  | TF    | Yes |
| No-Mamba Base          | 83.69 | 91.89 | 80.57 | 91.26  | 85.98 | 89.04 | 12.0  | 24  | CNN   | Yes |
| U-Mamba Bot [8]        | 83.51 | 91.79 | 80.40 | 91.26  | 86.22 | 89.13 | 12.40 | 24  | Mam   | Yes |
| U-Mamba Enc [8]        | 82.41 | 91.22 | 80.27 | 90.91  | 86.34 | 88.38 | 24.90 | 47  | Mam   | Yes |
| A3DS SegResNet [9,11]  | 80.69 | 90.69 | 79.28 | 90.79  | 81.11 | 87.27 | 20.00 | 22  | CNN   | No  |
| A3DS DiNTS [10, 11]    | 78.18 | 82.97 | 69.05 | 87.75  | 65.28 | 82.35 | 29.20 | 16  | CNN   | No  |
| A3DS SwinUNETR [4, 11] | 76.54 | 82.68 | 68.59 | 89.90  | 52.82 | 85.05 | 34.50 | 9   | TF    | No  |

Results taken from our paper (see above), reported values are Dice scores computed over 5-fold cross-validation on each dataset. All models trained from scratch.

RT: training run time in hours (measured on 1x Nvidia A100 PCIe 40GB)\
VRAM: GPU VRAM in GB used during training, as reported by nvidia-smi\
Arch.: CNN = convolutional neural network; TF = transformer; Mam = Mamba\
nnU: whether the architecture was integrated and tested with the nnU-Net framework (either by us or the original authors)

## How to use the new presets

We offer three new presets, each targeted for a different GPU VRAM and compute budget:
- **nnU-Net ResEnc M**: similar GPU budget to the standard UNet configuration. Best suited for GPUs with 9-11GB VRAM. Training time: ~12h on A100
- **nnU-Net ResEnc L**: requires a GPU with 24GB VRAM. Training time: ~35h on A100
- **nnU-Net ResEnc XL**: requires a GPU with 40GB VRAM. Training time: ~66h on A100

### **:point_right: We recommend **nnU-Net ResEnc L** as the new default nnU-Net configuration! :point_left:**

The new presets are available as follows ((M/L/XL) = pick one!):
1. Specify the desired configuration when running experiment planning and preprocessing:
   `nnUNetv2_plan_and_preprocess -d DATASET -pl nnUNetPlannerResEnc(M/L/XL)`. These planners use the same preprocessed
   data folder as the standard 2d and 3d_fullres configurations since the preprocessed data is identical. Only the
   3d_lowres differs and will be saved in a different folder to allow all configurations to coexist!
   If you are only planning to run 3d_fullres/2d and you already have this data preprocessed, you can just run
   `nnUNetv2_plan_experiment -d DATASET -pl nnUNetPlannerResEnc(M/L/XL)` to avoid preprocessing again!
2. Now, just specify the correct plans when running `nnUNetv2_train`, `nnUNetv2_predict` etc. The interface is
   consistent across all nnU-Net commands: `-p nnUNetResEncUNet(M/L/XL)Plans`

Training results for the new presets will be stored in a dedicated folder and will not overwrite standard nnU-Net
results! So don't be afraid to give it a go!

## Scaling ResEnc nnU-Net beyond the Presets
The presets differ from `ResEncUNetPlanner` in two ways:
- They set new default values for `gpu_memory_target_in_gb` to target the respective VRAM consumptions
- They remove the batch size cap of 0.05 (= previously one batch could not cover more pixels than 5% of the entire dataset, now it can be arbitrarily large)

The presets are merely there to make life easier, and to provide standardized configurations people can benchmark with.
You can easily adapt the GPU memory target to match your GPU, and to scale beyond 40GB of GPU memory.

Here is an example for how to scale to 80GB VRAM on Dataset003_Liver:

`nnUNetv2_plan_experiment -d 3 -pl nnUNetPlannerResEncM -gpu_memory_target 80 -overwrite_plans_name nnUNetResEncUNetPlans_80G`

Just use `-p nnUNetResEncUNetPlans_80G` moving forward as outlined above! Running the example above will yield a
warning ("You are running nnUNetPlannerM with a non-standard gpu_memory_target_in_gb"). This warning can be ignored here.
**Always change the plans identifier with `-overwrite_plans_name NEW_PLANS_NAME` when messing with the VRAM target in
order to not overwrite preset plans!**

Why not use `ResEncUNetPlanner`? Because that one still has the 5% cap in place!
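To get a feeling for why the 5% cap matters when scaling up, the sketch below computes the largest batch size such a cap would allow. It is only a rough illustration: the function `capped_batch_size`, the rounding behaviour and all numbers are assumptions, not the planner's actual code.

```python
import math


def capped_batch_size(desired_batch_size, patch_size, total_dataset_voxels, max_fraction=0.05):
    """Limit a batch so it covers at most `max_fraction` of all dataset voxels.

    Pass max_fraction=None to disable the cap. Hypothetical helper, illustration only.
    """
    if max_fraction is None:
        return desired_batch_size
    voxels_per_patch = math.prod(patch_size)
    largest_allowed = math.floor(max_fraction * total_dataset_voxels / voxels_per_patch)
    return max(1, min(desired_batch_size, largest_allowed))


# Made-up numbers: 30 cases of 200 x 400 x 400 voxels, patch size 128^3.
total_voxels = 30 * 200 * 400 * 400
patch = (128, 128, 128)

print(capped_batch_size(40, patch, total_voxels))                     # 22, limited by the 5% cap
print(capped_batch_size(40, patch, total_voxels, max_fraction=None))  # 40, no cap
```

On a small dataset the cap can therefore hold the batch size well below what a large GPU could fit, which is why the presets drop it.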
### Scaling to multiple GPUs
When scaling to multiple GPUs, do not just specify the combined amount of VRAM to `nnUNetv2_plan_experiment` as this
may result in patch sizes that are too large to be processed by individual GPUs. It is best to let this command run for
the VRAM budget of one GPU, and then manually edit the plans file to increase the batch size. You can use [configuration inheritance](explanation_plans_files.md).
In the configurations dictionary of the generated plans JSON file, add the following entry:

```json
        "3d_fullres_bsXX": {
            "inherits_from": "3d_fullres",
            "batch_size": XX
        },
```
Where XX is the new batch size. If 3d_fullres has a batch size of 2 for one GPU and you are planning to scale to 8 GPUs, make the new batch size 2x8=16!
You can then train the new configuration using nnU-Net's multi-GPU settings:

```bash
nnUNetv2_train DATASETID 3d_fullres_bsXX FOLD -p nnUNetResEncUNetPlans_80G -num_gpus 8
```

## Proposing a new segmentation method? Benchmark the right way!
When benchmarking new segmentation methods against nnU-Net, we encourage you to benchmark against the residual encoder
variants. For a fair comparison, pick the variant that most closely matches the GPU memory and compute
requirements of your method!

## References
[1] Isensee, Fabian, et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nature methods 18.2 (2021): 203-211.\
[2] Roy, Saikat, et al. "Mednext: transformer-driven scaling of convnets for medical image segmentation." International Conference on Medical Image Computing and Computer-Assisted Intervention. Cham: Springer Nature Switzerland, 2023.\
[3] Huang, Ziyan, et al. "Stu-net: Scalable and transferable medical image segmentation models empowered by large-scale supervised pre-training." arXiv preprint arXiv:2304.06716 (2023).\
[4] Hatamizadeh, Ali, et al. "Swin unetr: Swin transformers for semantic segmentation of brain tumors in mri images." International MICCAI Brainlesion Workshop. Cham: Springer International Publishing, 2021.\
[5] He, Yufan, et al. "Swinunetr-v2: Stronger swin transformers with stagewise convolutions for 3d medical image segmentation." International Conference on Medical Image Computing and Computer-Assisted Intervention. Cham: Springer Nature Switzerland, 2023.\
[6] Zhou, Hong-Yu, et al. "nnformer: Interleaved transformer for volumetric segmentation." arXiv preprint arXiv:2109.03201 (2021).\
[7] Xie, Yutong, et al. "Cotr: Efficiently bridging cnn and transformer for 3d medical image segmentation." Medical Image Computing and Computer Assisted Intervention–MICCAI 2021: 24th International Conference, Strasbourg, France, September 27–October 1, 2021, Proceedings, Part III 24. Springer International Publishing, 2021.\
[8] Ma, Jun, Feifei Li, and Bo Wang. "U-mamba: Enhancing long-range dependency for biomedical image segmentation." arXiv preprint arXiv:2401.04722 (2024).\
[9] Myronenko, Andriy. "3D MRI brain tumor segmentation using autoencoder regularization." Brainlesion: Glioma, Multiple Sclerosis, Stroke and Traumatic Brain Injuries: 4th International Workshop, BrainLes 2018, Held in Conjunction with MICCAI 2018, Granada, Spain, September 16, 2018, Revised Selected Papers, Part II 4. Springer International Publishing, 2019.\
[10] He, Yufan, et al. "Dints: Differentiable neural network topology search for 3d medical image segmentation." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.\
[11] Auto3DSeg, MONAI 1.3.0, [LINK](https://github.com/Project-MONAI/tutorials/tree/ed8854fa19faa49083f48abf25a2c30ab9ac1c6b/auto3dseg)

================================================
FILE: documentation/run_inference_with_pretrained_models.md
================================================
# How to run inference with pretrained models
**Important:** Pretrained weights from nnU-Net v1 are NOT compatible with V2. You will need to retrain with the new version.
But honestly, you already have a fully trained model with which you can run inference (in v1), so
just continue using that!

Not yet available for V2 :-(
If you wish to run inference with pretrained models, check out the old nnU-Net for now. We are working on this full steam!

================================================
FILE: documentation/set_environment_variables.md
================================================
# How to set environment variables

nnU-Net requires some environment variables so that it always knows where the raw data, preprocessed data and trained
models are. Depending on the operating system, these environment variables need to be set in different ways.

Variables can either be set permanently (recommended!) or you can decide to set them every time you call nnU-Net.

# Linux & MacOS

## Permanent
Locate the `.bashrc` file in your home folder and add the following lines to the bottom:

```bash
export nnUNet_raw="/media/fabian/nnUNet_raw"
export nnUNet_preprocessed="/media/fabian/nnUNet_preprocessed"
export nnUNet_results="/media/fabian/nnUNet_results"
```

(Of course you need to adapt the paths to the actual folders you intend to use).
If you are using a different shell, such as zsh, you will need to find the correct script for it. For zsh this is `.zshrc`.

## Temporary
Just execute the following lines whenever you run nnU-Net:
```bash
export nnUNet_raw="/media/fabian/nnUNet_raw"
export nnUNet_preprocessed="/media/fabian/nnUNet_preprocessed"
export nnUNet_results="/media/fabian/nnUNet_results"
```
(Of course you need to adapt the paths to the actual folders you intend to use).

Important: These variables will be deleted if you close your terminal! They will also only apply to the current
terminal window and DO NOT transfer to other terminals!

Alternatively you can also just prefix them to your nnU-Net commands:

`nnUNet_results="/media/fabian/nnUNet_results" nnUNet_preprocessed="/media/fabian/nnUNet_preprocessed" nnUNetv2_train[...]`

## Verify that environment parameters are set
You can always execute `echo ${nnUNet_raw}` etc to print the environment variables. This will return an empty string if
they were not set.

# Windows
Useful links:
- [https://www3.ntu.edu.sg](https://www3.ntu.edu.sg/home/ehchua/programming/howto/Environment_Variables.html)
- [https://phoenixnap.com](https://phoenixnap.com/kb/windows-set-environment-variable)

## Permanent
See `Set Environment Variable in Windows via GUI` [here](https://phoenixnap.com/kb/windows-set-environment-variable).
Or read about setx (command prompt).

## Temporary
Just execute the following before you run nnU-Net:

(PowerShell)
```PowerShell
$Env:nnUNet_raw = "C:/Users/fabian/nnUNet_raw"
$Env:nnUNet_preprocessed = "C:/Users/fabian/nnUNet_preprocessed"
$Env:nnUNet_results = "C:/Users/fabian/nnUNet_results"
```

(Command Prompt)
```bat
set nnUNet_raw=C:/Users/fabian/nnUNet_raw
set nnUNet_preprocessed=C:/Users/fabian/nnUNet_preprocessed
set nnUNet_results=C:/Users/fabian/nnUNet_results
```

(Of course you need to adapt the paths to the actual folders you intend to use).

Important: These variables will be deleted if you close your session! They will also only apply to the current
window and DO NOT transfer to other sessions!
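
Independent of the operating system and shell, the variables can also be checked from Python. The following is a small sketch using only the standard library; the helper `find_missing_variables` is a hypothetical name and not part of nnU-Net. Run it from the same terminal or session in which you set the variables.

```python
import os

# The three environment variables nnU-Net expects.
REQUIRED_VARIABLES = ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results")


def find_missing_variables(environ=os.environ):
    """Return the names of required variables that are unset or empty."""
    return [name for name in REQUIRED_VARIABLES if not environ.get(name)]


missing = find_missing_variables()
if missing:
    print("Not set:", ", ".join(missing))
else:
    for name in REQUIRED_VARIABLES:
        print(f"{name} = {os.environ[name]}")
```
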
## Verify that environment parameters are set
Printing in Windows works differently depending on the environment you are in:

PowerShell: `echo $Env:[variable_name]`

Command Prompt: `echo %variable_name%`

================================================
FILE: documentation/setting_up_paths.md
================================================
# Setting up Paths

nnU-Net relies on environment variables to know where raw data, preprocessed data and trained model weights are stored.
To use the full functionality of nnU-Net, the following three environment variables must be set:

1) `nnUNet_raw`: This is where you place the raw datasets. This folder will have one subfolder for each dataset, named
DatasetXXX_YYY where XXX is a 3-digit identifier (such as 001, 002, 043, 999, ...) and YYY is the (unique)
dataset name. The datasets must be in nnU-Net format, see [here](dataset_format.md).

    Example tree structure:
    ```
    nnUNet_raw/Dataset001_NAME1
    ├── dataset.json
    ├── imagesTr
    │   ├── ...
    ├── imagesTs
    │   ├── ...
    └── labelsTr
        ├── ...
    nnUNet_raw/Dataset002_NAME2
    ├── dataset.json
    ├── imagesTr
    │   ├── ...
    ├── imagesTs
    │   ├── ...
    └── labelsTr
        ├── ...
    ```

2) `nnUNet_preprocessed`: This is the folder where the preprocessed data will be saved. The data will also be read from
this folder during training. It is important that this folder is located on a drive with low access latency and high
throughput (such as an NVMe SSD (PCIe gen 3 is sufficient)).

3) `nnUNet_results`: This specifies where nnU-Net will save the model weights. If pretrained models are downloaded, this
is where it will save them.

### How to set environment variables
See [here](set_environment_variables.md).

================================================
FILE: documentation/tldr_migration_guide_from_v1.md
================================================
# TLDR Migration Guide from nnU-Net V1

- nnU-Net V2 can be installed simultaneously with V1.
  They won't get in each other's way
- The environment variables needed for V2 have slightly different names. Read [this](setting_up_paths.md).
- nnU-Net V2 datasets are called DatasetXXX_NAME. Not Task.
- Datasets have the same structure (imagesTr, labelsTr, dataset.json) but we now support more
[file types](dataset_format.md#supported-file-formats). The dataset.json is simplified. Use `generate_dataset_json`
from `nnunetv2/dataset_conversion/generate_dataset_json.py`.
- Careful: labels are now no longer declared as value:name but name:value. This has to do with [hierarchical labels](region_based_training.md).
- nnU-Net v2 commands start with `nnUNetv2...`. They work mostly (but not entirely) the same. Just use the `-h` option.
- You can transfer your V1 raw datasets to V2 with `nnUNetv2_convert_old_nnUNet_dataset`. You cannot transfer trained
models. Continue to use the old nnU-Net version for running inference with those.
- These are the commands you are most likely to be using (in that order)
    - `nnUNetv2_plan_and_preprocess`. Example: `nnUNetv2_plan_and_preprocess -d 2`
    - `nnUNetv2_train`. Example: `nnUNetv2_train 2 3d_fullres 0`
    - `nnUNetv2_find_best_configuration`. Example: `nnUNetv2_find_best_configuration 2 -c 2d 3d_fullres`. This command
      will now create an `inference_instructions.txt` file in your `nnUNet_results/DatasetXXX_NAME/` folder which
      tells you exactly how to do inference.
    - `nnUNetv2_predict`.
      Example: `nnUNetv2_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -c 3d_fullres -d 2`
    - `nnUNetv2_apply_postprocessing` (see inference_instructions.txt)

================================================
FILE: nnunetv2/__init__.py
================================================

================================================
FILE: nnunetv2/batch_running/__init__.py
================================================

================================================
FILE: nnunetv2/batch_running/benchmarking/__init__.py
================================================

================================================
FILE: nnunetv2/batch_running/benchmarking/generate_benchmarking_commands.py
================================================
if __name__ == '__main__':
    """
    This code probably only works within the DKFZ infrastructure (using LSF). You will need to adapt it to your scheduler!
    """
    gpu_models = [#'NVIDIAA100_PCIE_40GB', 'NVIDIAGeForceRTX2080Ti', 'NVIDIATITANRTX', 'TeslaV100_SXM2_32GB',
                  'NVIDIAA100_SXM4_40GB']#, 'TeslaV100_PCIE_32GB']
    datasets = [2, 3, 4, 5]
    trainers = ['nnUNetTrainerBenchmark_5epochs', 'nnUNetTrainerBenchmark_5epochs_noDataLoading']
    plans = ['nnUNetPlans']
    configs = ['2d', '2d_bs3x', '2d_bs6x', '3d_fullres', '3d_fullres_bs3x', '3d_fullres_bs6x']
    num_gpus = 1

    benchmark_configurations = {d: configs for d in datasets}

    # quotes must be balanced because this string is pasted into a shell command
    exclude_hosts = "-R \"select[hname!='e230-dgxa100-1']\""
    resources = "-R \"tensorcore\""
    queue = "-q gpu"
    # the preamble opens a double quote that is closed when the command is written to the file
    preamble = "-L /bin/bash \"source ~/load_env_torch210.sh && "
    train_command = 'nnUNet_compile=False nnUNet_results=/dkfz/cluster/gpu/checkpoints/OE0441/isensee/nnUNet_results_remake_benchmark nnUNetv2_train'

    folds = (0, )

    use_these_modules = {
        tr: plans for tr in trainers
    }

    additional_arguments = f' -num_gpus {num_gpus}'  # ''

    output_file = "/home/isensee/deleteme.txt"
    with open(output_file, 'w') as f:
        for g in gpu_models:
            gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmodel={g}"
            for tr in use_these_modules.keys():
                for p in use_these_modules[tr]:
                    for dataset in benchmark_configurations.keys():
                        for config in benchmark_configurations[dataset]:
                            for fl in folds:
                                command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}'
                                if additional_arguments is not None and len(additional_arguments) > 0:
                                    command += f' {additional_arguments}'
                                f.write(f'{command}\"\n')

================================================
FILE: nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py
================================================
from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.paths import nnUNet_results
from nnunetv2.utilities.file_path_utilities import get_output_folder

if __name__ == '__main__':
    trainers = ['nnUNetTrainerBenchmark_5epochs', 'nnUNetTrainerBenchmark_5epochs_noDataLoading']
    datasets = [2, 3, 4, 5]
    plans = ['nnUNetPlans']
    configs = ['2d', '2d_bs3x', '2d_bs6x', '3d_fullres', '3d_fullres_bs3x', '3d_fullres_bs6x']
    output_file = join(nnUNet_results, 'benchmark_results.csv')

    torch_version = '2.1.0.dev20230330'#"2.0.0"#"2.1.0.dev20230328" #"1.11.0a0+gitbc2c6ed" #
    cudnn_version = 8700 # 8302 #
    num_gpus = 1

    unique_gpus = set()

    # collect results in the most janky way possible. Amazing coding skills!
    all_results = {}
    for tr in trainers:
        all_results[tr] = {}
        for p in plans:
            all_results[tr][p] = {}
            for c in configs:
                all_results[tr][p][c] = {}
                for d in datasets:
                    dataset_name = maybe_convert_to_dataset_name(d)
                    output_folder = get_output_folder(dataset_name, tr, p, c, fold=0)
                    expected_benchmark_file = join(output_folder, 'benchmark_result.json')
                    all_results[tr][p][c][d] = {}
                    if isfile(expected_benchmark_file):
                        # filter results for what we want
                        results = [i for i in load_json(expected_benchmark_file).values()
                                   if i['num_gpus'] == num_gpus and i['cudnn_version'] == cudnn_version and
                                   i['torch_version'] == torch_version]
                        for r in results:
                            all_results[tr][p][c][d][r['gpu_name']] = r
                            unique_gpus.add(r['gpu_name'])

    # haha. Fuck this. Collect GPUs in the code above.
    # unique_gpus = np.unique([i["gpu_name"] for tr in trainers for p in plans for c in configs for d in datasets for i in all_results[tr][p][c][d]])

    unique_gpus = list(unique_gpus)
    unique_gpus.sort()

    with open(output_file, 'w') as f:
        f.write('Dataset,Trainer,Plans,Config')
        for g in unique_gpus:
            f.write(f",{g}")
        f.write("\n")
        for d in datasets:
            for tr in trainers:
                for p in plans:
                    for c in configs:
                        gpu_results = []
                        for g in unique_gpus:
                            if g in all_results[tr][p][c][d].keys():
                                gpu_results.append(round(all_results[tr][p][c][d][g]["fastest_epoch"], ndigits=2))
                            else:
                                gpu_results.append("MISSING")
                        # skip if all are missing
                        if all([i == 'MISSING' for i in gpu_results]):
                            continue
                        f.write(f"{d},{tr},{p},{c}")
                        for g in gpu_results:
                            f.write(f",{g}")
                        f.write("\n")
            f.write("\n")

================================================
FILE: nnunetv2/batch_running/collect_results_custom_Decathlon.py
================================================
from typing import Tuple

import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *

from nnunetv2.evaluation.evaluate_predictions import load_summary_json
from nnunetv2.paths import nnUNet_results
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name, convert_dataset_name_to_id
from nnunetv2.utilities.file_path_utilities import get_output_folder


def collect_results(trainers: dict, datasets: List, output_file: str,
                    configurations=("2d", "3d_fullres", "3d_lowres", "3d_cascade_fullres"),
                    folds=tuple(np.arange(5))):
    results_dirs = (nnUNet_results,)
    datasets_names = [maybe_convert_to_dataset_name(i) for i in datasets]
    with open(output_file, 'w') as f:
        for i, d in zip(datasets, datasets_names):
            for c in configurations:
                for module in trainers.keys():
                    for plans in trainers[module]:
                        for r in results_dirs:
                            expected_output_folder = get_output_folder(d, module, plans, c)
                            if isdir(expected_output_folder):
                                results_folds = []
                                f.write(f"{d},{c},{module},{plans},{r}")
                                for fl in folds:
                                    expected_output_folder_fold = get_output_folder(d, module, plans, c, fl)
                                    expected_summary_file = join(expected_output_folder_fold, "validation",
                                                                 "summary.json")
                                    if not isfile(expected_summary_file):
                                        print('expected output file not found:', expected_summary_file)
                                        f.write(",")
                                        results_folds.append(np.nan)
                                    else:
                                        foreground_mean = load_summary_json(expected_summary_file)['foreground_mean'][
                                            'Dice']
                                        results_folds.append(foreground_mean)
                                        f.write(f",{foreground_mean:02.4f}")
                                f.write(f",{np.nanmean(results_folds):02.4f}\n")


def summarize(input_file, output_file, folds: Tuple[int, ...], configs: Tuple[str, ...], datasets, trainers):
    txt = np.loadtxt(input_file, dtype=str, delimiter=',')
    num_folds = txt.shape[1] - 6
    valid_configs = {}
    for d in datasets:
        if isinstance(d, int):
            d = maybe_convert_to_dataset_name(d)
        configs_in_txt = np.unique(txt[:, 1][txt[:, 0] == d])
        valid_configs[d] = [i for i in configs_in_txt if i in configs]
    assert max(folds) < num_folds

    with open(output_file, 'w') as f:
        f.write("name")
        for d in valid_configs.keys():
            for c in valid_configs[d]:
                f.write(",%d_%s" % (convert_dataset_name_to_id(d), c[:4]))
        f.write(',mean\n')
        valid_entries = txt[:, 4] == nnUNet_results
        for t in trainers.keys():
            trainer_locs = valid_entries & (txt[:, 2] == t)
            for pl in trainers[t]:
                f.write(f"{t}__{pl}")
                trainer_plan_locs = trainer_locs & (txt[:, 3] == pl)
                r = []
                for d in valid_configs.keys():
                    trainer_plan_d_locs = trainer_plan_locs & (txt[:, 0] == d)
                    for v in valid_configs[d]:
                        trainer_plan_d_config_locs = trainer_plan_d_locs & (txt[:, 1] == v)
                        if np.any(trainer_plan_d_config_locs):
                            # we cannot have more than one row
                            assert np.sum(trainer_plan_d_config_locs) == 1

                            # now check that we have all folds
                            selected_row = txt[np.argwhere(trainer_plan_d_config_locs)[0,0]]

                            fold_results = selected_row[[i + 5 for i in folds]]

                            if '' in fold_results:
                                print('missing fold in', t, pl, d, v)
                                f.write(",nan")
                                r.append(np.nan)
                            else:
                                mean_dice = np.mean([float(i) for i in fold_results])
                                f.write(f",{mean_dice:02.4f}")
                                r.append(mean_dice)
                        else:
                            print('missing:', t, pl, d, v)
                            f.write(",nan")
                            r.append(np.nan)
                f.write(f",{np.mean(r):02.4f}\n")


if __name__ == '__main__':
    use_these_trainers = {
        'nnUNetTrainer': ('nnUNetResEncUNetLPlans', ),
    }
    all_results_file= join(nnUNet_results, 'customDecResults.csv')
    datasets = [3, 5, 8, 10, 17, 27, 55, 220, 223, 226, 219]
    # datasets = [3, 4, 5, 8, 10, 17, 27, 55, 220, 223]
    collect_results(use_these_trainers, datasets, all_results_file)

    # folds = (0, 1, 2, 3, 4)
    # configs = ("2d", )
    # output_file = join(nnUNet_results, 'customDecResults_summary5fold.csv')
    # summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)

    folds = (0, )
    configs = ("3d_fullres", )
    output_file = join(nnUNet_results, 'summary_fold0.csv')
    summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)

================================================
FILE: nnunetv2/batch_running/collect_results_custom_Decathlon_2d.py
================================================
from batchgenerators.utilities.file_and_folder_operations import *

from nnunetv2.batch_running.collect_results_custom_Decathlon import collect_results, summarize
from nnunetv2.paths import nnUNet_results

if __name__ == '__main__':
    use_these_trainers = {
        'nnUNetTrainer': ('nnUNetPlans', ),
    }
    all_results_file = join(nnUNet_results, 'hrnet_results.csv')
    datasets = [2, 3, 4, 17, 20, 24, 27, 38, 55, 64, 82]
    collect_results(use_these_trainers, datasets, all_results_file)

    folds = (0, )
    configs = ('2d', )
    output_file = join(nnUNet_results, 'hrnet_results_summary_fold0.csv')
    summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)

================================================
FILE: nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py
================================================
from copy import deepcopy
import numpy as np


def merge(dict1, dict2):
    keys = np.unique(list(dict1.keys()) + list(dict2.keys()))
    keys = np.unique(keys)
    res = {}
    for k in keys:
        all_configs = []
        if dict1.get(k) is not None:
            all_configs += list(dict1[k])
        if dict2.get(k) is not None:
            all_configs += list(dict2[k])
        if len(all_configs) > 0:
            res[k] = tuple(np.unique(all_configs))
    return res


if __name__ == "__main__":
    # after the Nature Methods paper we switch our evaluation to a different (more stable/high quality) set of
    # datasets for evaluation and future development
    configurations_all = {
        2: ("3d_fullres", "2d"),
        3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
        4: ("2d", "3d_fullres"),
        5: ("2d", "3d_fullres"),
        8: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
        10: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
        17: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
        24: ("2d", "3d_fullres"),
        27: ("2d", "3d_fullres"),
        38: ("2d", "3d_fullres"),
        55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
        137: ("2d", "3d_fullres"),
        220: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
        # 221: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
        223: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
        219: ("2d", "3d_fullres"),
        226: ("2d", "3d_fullres"),
    }

    configurations_3d_fr_only = {
        i: ("3d_fullres", ) for i in configurations_all if "3d_fullres" in configurations_all[i]
    }

    configurations_3d_c_only = {
        i: ("3d_cascade_fullres", ) for i in configurations_all if "3d_cascade_fullres" in configurations_all[i]
    }

    configurations_3d_lr_only = {
        i: ("3d_lowres", ) for i in configurations_all if "3d_lowres" in configurations_all[i]
    }

    configurations_2d_only = {
        i: ("2d", ) for i in configurations_all if "2d" in configurations_all[i]
    }

    num_gpus = 1
    exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\""
    resources = ""
    gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmem=23G"#gmodel=NVIDIAA100_PCIE_40GB"
    queue = "-q gpu-pro"
    preamble = "\". /home/isensee/env_loading_scripts/continuous_performance_monitoring/load_env_torch280.sh && " # -L /bin/bash
    train_command = 'nnUNet_results=/dkfz/cluster/gpu/checkpoints/OE0441/isensee/results_nnUNet_master nnUNetv2_train'

    folds = (0, )
    # use_this = configurations_2d_only
    use_this = configurations_3d_fr_only
    # use_this = merge(use_this, configurations_3d_c_only)

    datasets = [3, 5, 8, 10, 17, 27, 55, 219, 220, 223, 226]
    use_this = {i: use_this[i] for i in datasets}

    use_these_modules = {
        # 'nnUNetTrainer_newSpatialAug': ('nnUNetPlans',),
        # 'nnUNetTrainerBN': ('nnUNetPlans',),
        # 'nnUNetTrainer_newSpatialAug_withElDef_noPref': ('nnUNetPlans',),
        # 'nnUNetTrainer': ('nnUNetConvNextEncUNetPlans_smallks_and_shallow',),
        # 'nnUNetTrainer_convnextenc_regularconvblock': ('nnUNetPlans_convnext', 'nnUNetConvNextEncUNetPlans_smallks_and_shallow'),
        # 'nnUNetTrainer_newSpatialAug_withElDef2': ('nnUNetPlans',),
        # 'nnUNetTrainer_newSpatialAug_withElDef3': ('nnUNetPlans',),
        # 'nnUNetTrainer_newSpatialAug_noPref': ('nnUNetPlans',),
        # 'nnUNetTrainerDiceCELoss_noSmooth': ('nnUNetPlans',),
        # 'nnUNetTrainer_DASegOrd0': ('nnUNetPlans',),
        # 'nnUNetTrainerAdamW_WDe2': ('nnUNetPlans',),
        # 'nnUNetTrainerUMambaBot': ('nnUNetPlans',),
        # 'nnUNetTrainerUMambaEnc': ('nnUNetPlans',),
        # 'nnUNetTrainer_fasterDA': ('nnUNetPlans', 'nnUNetResEncUNetLPlans'),
        # 'nnUNetTrainer_noDummy2DDA': ('nnUNetResEncUNetMPlans', ),
        'nnUNetTrainer': ('nnUNetResEncUNetMPlans', ),
        'nnUNetTrainerDA5': ('nnUNetResEncUNetMPlans', ),
        # 'nnUNetTrainer_probabilisticOversampling_033': ('nnUNetResEncUNetMPlans', ),
        # 'nnUNetTrainer_probabilisticOversampling_010': ('nnUNetResEncUNetMPlans',),
        # BN
    }
    additional_arguments = f' -num_gpus {num_gpus} --disable_checkpointing'  # ''

    output_file = "/home/isensee/deleteme.txt"
    with open(output_file, 'w') as f:
        for tr in use_these_modules.keys():
            for p in use_these_modules[tr]:
                for dataset in use_this.keys():
                    for config in use_this[dataset]:
                        for fl in folds:
                            command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}'
                            if additional_arguments is not None and len(additional_arguments) > 0:
                                command += f' {additional_arguments}'
                            f.write(f'{command}\"\n')

================================================
FILE: nnunetv2/batch_running/jobs.sh
================================================
# lsf22-gpu01
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=3 nnUNetv2_train 3 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=4 nnUNetv2_train 4 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=5 nnUNetv2_train 5 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=6 nnUNetv2_train 8 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=7 nnUNetv2_train 10 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"

# lsf22-gpu03
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=0 nnUNetv2_train 17 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=1 nnUNetv2_train 27 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=2 nnUNetv2_train 55 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=6 nnUNetv2_train 220 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"

# lsf22-gpu05
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=0 nnUNetv2_train 223 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=1 nnUNetv2_train 3 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=2 nnUNetv2_train 4 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=3 nnUNetv2_train 5 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=4 nnUNetv2_train 8 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=6 nnUNetv2_train 10 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=7 nnUNetv2_train 17 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"

# lsf22-gpu06
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=0 nnUNetv2_train 27 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=1 nnUNetv2_train 55 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=2 nnUNetv2_train 220 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=3 nnUNetv2_train 223 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=4 nnUNetv2_train 3 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=5 nnUNetv2_train 4 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=6 nnUNetv2_train 5 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=7 nnUNetv2_train 8 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"

# lsf22-gpu07
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=0 nnUNetv2_train 10 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=1 nnUNetv2_train 17 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=2 nnUNetv2_train 27 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=3 nnUNetv2_train 55 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=4 nnUNetv2_train 220 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=5 nnUNetv2_train 223 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=6 nnUNetv2_train 3 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=7 nnUNetv2_train 4 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"

# launched as jobs
bsub -R "select[hname!='e230-dgx2-2']" -R "select[hname!='e230-dgx2-1']" -q gpu -gpu num=1:j_exclusive=yes:gmem=1G ". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 5 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
bsub -R "select[hname!='e230-dgx2-2']" -R "select[hname!='e230-dgx2-1']" -q gpu -gpu num=1:j_exclusive=yes:gmem=1G ". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 8 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
bsub -R "select[hname!='e230-dgx2-2']" -R "select[hname!='e230-dgx2-1']" -q gpu -gpu num=1:j_exclusive=yes:gmem=1G ". 
~/load_env_torch224_balintsfix.sh && nnUNetv2_train 10 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing" bsub -R "select[hname!='e230-dgx2-2']" -R "select[hname!='e230-dgx2-1']" -q gpu -gpu num=1:j_exclusive=yes:gmem=1G ". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 17 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing" bsub -R "select[hname!='e230-dgx2-2']" -R "select[hname!='e230-dgx2-1']" -q gpu -gpu num=1:j_exclusive=yes:gmem=1G ". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 27 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing" bsub -R "select[hname!='e230-dgx2-2']" -R "select[hname!='e230-dgx2-1']" -q gpu -gpu num=1:j_exclusive=yes:gmem=1G ". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 55 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing" bsub -R "select[hname!='e230-dgx2-2']" -R "select[hname!='e230-dgx2-1']" -q gpu -gpu num=1:j_exclusive=yes:gmem=1G ". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 220 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing" bsub -R "select[hname!='e230-dgx2-2']" -R "select[hname!='e230-dgx2-1']" -q gpu -gpu num=1:j_exclusive=yes:gmem=1G ". 
~/load_env_torch224_balintsfix.sh && nnUNetv2_train 223 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing" ================================================ FILE: nnunetv2/batch_running/release_trainings/__init__.py ================================================ ================================================ FILE: nnunetv2/batch_running/release_trainings/nnunetv2_v1/__init__.py ================================================ ================================================ FILE: nnunetv2/batch_running/release_trainings/nnunetv2_v1/collect_results.py ================================================ from typing import Tuple import numpy as np from batchgenerators.utilities.file_and_folder_operations import * from nnunetv2.evaluation.evaluate_predictions import load_summary_json from nnunetv2.paths import nnUNet_results from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name, convert_dataset_name_to_id from nnunetv2.utilities.file_path_utilities import get_output_folder def collect_results(trainers: dict, datasets: List, output_file: str, configurations=("2d", "3d_fullres", "3d_lowres", "3d_cascade_fullres"), folds=tuple(np.arange(5))): results_dirs = (nnUNet_results,) datasets_names = [maybe_convert_to_dataset_name(i) for i in datasets] with open(output_file, 'w') as f: for i, d in zip(datasets, datasets_names): for c in configurations: for module in trainers.keys(): for plans in trainers[module]: for r in results_dirs: expected_output_folder = get_output_folder(d, module, plans, c) if isdir(expected_output_folder): results_folds = [] f.write(f"{d},{c},{module},{plans},{r}") for fl in folds: expected_output_folder_fold = get_output_folder(d, module, plans, c, fl) expected_summary_file = join(expected_output_folder_fold, "validation", "summary.json") if not isfile(expected_summary_file): print('expected output file not found:', expected_summary_file) f.write(",") results_folds.append(np.nan) else: 
foreground_mean = load_summary_json(expected_summary_file)['foreground_mean'][ 'Dice'] results_folds.append(foreground_mean) f.write(f",{foreground_mean:02.4f}") f.write(f",{np.nanmean(results_folds):02.4f}\n") def summarize(input_file, output_file, folds: Tuple[int, ...], configs: Tuple[str, ...], datasets, trainers): txt = np.loadtxt(input_file, dtype=str, delimiter=',') num_folds = txt.shape[1] - 6 valid_configs = {} for d in datasets: if isinstance(d, int): d = maybe_convert_to_dataset_name(d) configs_in_txt = np.unique(txt[:, 1][txt[:, 0] == d]) valid_configs[d] = [i for i in configs_in_txt if i in configs] assert max(folds) < num_folds with open(output_file, 'w') as f: f.write("name") for d in valid_configs.keys(): for c in valid_configs[d]: f.write(",%d_%s" % (convert_dataset_name_to_id(d), c[:4])) f.write(',mean\n') valid_entries = txt[:, 4] == nnUNet_results for t in trainers.keys(): trainer_locs = valid_entries & (txt[:, 2] == t) for pl in trainers[t]: f.write(f"{t}__{pl}") trainer_plan_locs = trainer_locs & (txt[:, 3] == pl) r = [] for d in valid_configs.keys(): trainer_plan_d_locs = trainer_plan_locs & (txt[:, 0] == d) for v in valid_configs[d]: trainer_plan_d_config_locs = trainer_plan_d_locs & (txt[:, 1] == v) if np.any(trainer_plan_d_config_locs): # we cannot have more than one row assert np.sum(trainer_plan_d_config_locs) == 1 # now check that we have all folds selected_row = txt[np.argwhere(trainer_plan_d_config_locs)[0,0]] fold_results = selected_row[[i + 5 for i in folds]] if '' in fold_results: print('missing fold in', t, pl, d, v) f.write(",nan") r.append(np.nan) else: mean_dice = np.mean([float(i) for i in fold_results]) f.write(f",{mean_dice:02.4f}") r.append(mean_dice) else: print('missing:', t, pl, d, v) f.write(",nan") r.append(np.nan) f.write(f",{np.mean(r):02.4f}\n") if __name__ == '__main__': use_these_trainers = { 'nnUNetTrainer': ('nnUNetPlans',), 'nnUNetTrainer_v1loss': ('nnUNetPlans',), } all_results_file = join(nnUNet_results, 
'customDecResults.csv') datasets = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 20, 24, 27, 35, 38, 48, 55, 64, 82] collect_results(use_these_trainers, datasets, all_results_file) folds = (0, 1, 2, 3, 4) configs = ("3d_fullres", "3d_lowres") output_file = join(nnUNet_results, 'customDecResults_summary5fold.csv') summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers) folds = (0, ) configs = ("3d_fullres", "3d_lowres") output_file = join(nnUNet_results, 'customDecResults_summaryfold0.csv') summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers) ================================================ FILE: nnunetv2/batch_running/release_trainings/nnunetv2_v1/generate_lsf_commands.py ================================================ from copy import deepcopy import numpy as np def merge(dict1, dict2): keys = np.unique(list(dict1.keys()) + list(dict2.keys())) keys = np.unique(keys) res = {} for k in keys: all_configs = [] if dict1.get(k) is not None: all_configs += list(dict1[k]) if dict2.get(k) is not None: all_configs += list(dict2[k]) if len(all_configs) > 0: res[k] = tuple(np.unique(all_configs)) return res if __name__ == "__main__": # after the Nature Methods paper we switch our evaluation to a different (more stable/high quality) set of # datasets for evaluation and future development configurations_all = { # 1: ("3d_fullres", "2d"), 2: ("3d_fullres", "2d"), # 3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), # 4: ("2d", "3d_fullres"), 5: ("2d", "3d_fullres"), # 6: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), # 7: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), # 8: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), # 9: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), # 10: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), # 17: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), 20: ("2d", "3d_fullres"), 24: ("2d", "3d_fullres"), 27: ("2d", "3d_fullres"), 35: 
("2d", "3d_fullres"), 38: ("2d", "3d_fullres"), # 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), # 64: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), # 82: ("2d", "3d_fullres"), # 83: ("2d", "3d_fullres"), } configurations_3d_fr_only = { i: ("3d_fullres", ) for i in configurations_all if "3d_fullres" in configurations_all[i] } configurations_3d_c_only = { i: ("3d_cascade_fullres", ) for i in configurations_all if "3d_cascade_fullres" in configurations_all[i] } configurations_3d_lr_only = { i: ("3d_lowres", ) for i in configurations_all if "3d_lowres" in configurations_all[i] } configurations_2d_only = { i: ("2d", ) for i in configurations_all if "2d" in configurations_all[i] } num_gpus = 1 exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\"" resources = "-R \"tensorcore\"" gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmem=1G" queue = "-q gpu-lowprio" preamble = "-L /bin/bash \"source ~/load_env_cluster4.sh && " train_command = 'nnUNet_keep_files_open=True nnUNet_results=/dkfz/cluster/gpu/data/OE0441/isensee/nnUNet_results_remake_release_normfix nnUNetv2_train' folds = (0, 1, 2, 3, 4) # use_this = configurations_2d_only # use_this = merge(configurations_3d_fr_only, configurations_3d_lr_only) # use_this = merge(use_this, configurations_3d_c_only) use_this = configurations_all use_these_modules = { 'nnUNetTrainer': ('nnUNetPlans',), } additional_arguments = f'--disable_checkpointing -num_gpus {num_gpus}' # '' output_file = "/home/isensee/deleteme.txt" with open(output_file, 'w') as f: for tr in use_these_modules.keys(): for p in use_these_modules[tr]: for dataset in use_this.keys(): for config in use_this[dataset]: for fl in folds: command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}' if additional_arguments is not None and len(additional_arguments) > 0: command += f' {additional_arguments}' 
f.write(f'{command}\"\n') ================================================ FILE: nnunetv2/configuration.py ================================================ import os from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA default_num_processes = 8 if 'nnUNet_def_n_proc' not in os.environ else int(os.environ['nnUNet_def_n_proc']) ANISO_THRESHOLD = 3 # determines when a sample is considered anisotropic (3 means that the spacing in the low # resolution axis must be 3x as large as the next largest spacing) default_n_proc_DA = get_allowed_n_proc_DA() ================================================ FILE: nnunetv2/dataset_conversion/Dataset015_018_RibFrac_RibSeg.py ================================================ from copy import deepcopy import numpy as np from batchgenerators.utilities.file_and_folder_operations import * import shutil from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json from nnunetv2.paths import nnUNet_raw import SimpleITK as sitk if __name__ == '__main__': """ Download RibFrac dataset. Links are at https://ribfrac.grand-challenge.org/ Download everything. Part1, 2, validation and test Extract EVERYTHING into one folder so that all images and labels are in there. Don't worry they all have unique file names. 
For RibSeg also download the dataset from https://github.com/M3DV/RibSeg (https://drive.google.com/file/d/1ZZGGrhd0y1fLyOZGo_Y-wlVUP4lkHVgm/view?usp=sharing) and extract it into that same folder (seg only, files end with -rib-seg.nii.gz) """ # all downloaded archives were extracted into this folder base = '/home/isensee/Downloads/RibFrac_all' files = nifti_files(base, join=False) identifiers = np.unique([i.split('-')[0] for i in files]) # RibFrac target_dataset_id = 15 target_dataset_name = f'Dataset{target_dataset_id:03.0f}_RibFrac' maybe_mkdir_p(join(nnUNet_raw, target_dataset_name)) imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr') imagesTs = join(nnUNet_raw, target_dataset_name, 'imagesTs') labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr') maybe_mkdir_p(imagesTr) maybe_mkdir_p(imagesTs) maybe_mkdir_p(labelsTr) n_tr = 0 for c in identifiers: print(c) img_file = join(base, c + '-image.nii.gz') seg_file = join(base, c + '-label.nii.gz') if not isfile(seg_file): # test case shutil.copy(img_file, join(imagesTs, c + '_0000.nii.gz')) continue n_tr += 1 shutil.copy(img_file, join(imagesTr, c + '_0000.nii.gz')) # we must open seg and map -1 to 5 seg_itk = sitk.ReadImage(seg_file) seg_npy = sitk.GetArrayFromImage(seg_itk) seg_npy[seg_npy == -1] = 5 seg_itk_out = sitk.GetImageFromArray(seg_npy.astype(np.uint8)) seg_itk_out.SetSpacing(seg_itk.GetSpacing()) seg_itk_out.SetDirection(seg_itk.GetDirection()) seg_itk_out.SetOrigin(seg_itk.GetOrigin()) sitk.WriteImage(seg_itk_out, join(labelsTr, c + '.nii.gz')) # - 0: it is background # - 1: it is a displaced rib fracture # - 2: it is a non-displaced rib fracture # - 3: it is a buckle rib fracture # - 4: it is a segmental rib fracture # - -1: it is a rib fracture, but we could not define its type due to # ambiguity, diagnosis difficulty, etc. Ignore it in the # classification task.
generate_dataset_json( join(nnUNet_raw, target_dataset_name), channel_names={0: 'CT'}, labels = { 'background': 0, 'fracture': (1, 2, 3, 4, 5), 'displaced rib fracture': 1, 'non-displaced rib fracture': 2, 'buckle rib fracture': 3, 'segmental rib fracture': 4, }, num_training_cases=n_tr, file_ending='.nii.gz', regions_class_order=(5, 1, 2, 3, 4), dataset_name=target_dataset_name, reference='https://ribfrac.grand-challenge.org/' ) # RibSeg # overall I am not happy with the GT quality here. But eh what can I do target_dataset_name_ribfrac = deepcopy(target_dataset_name) target_dataset_id = 18 target_dataset_name = f'Dataset{target_dataset_id:03.0f}_RibSeg' maybe_mkdir_p(join(nnUNet_raw, target_dataset_name)) imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr') labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr') maybe_mkdir_p(imagesTr) maybe_mkdir_p(labelsTr) # the authors have a Google Sheet where they highlight problems with their dataset: # https://docs.google.com/spreadsheets/d/1lz9liWPy8yHybKCdO3BCA9K76QH8a54XduiZS_9fK70/edit?gid=1416415020#gid=1416415020 # we exclude the cases marked in red.
They have unannotated ribs skip_identifiers = [ 'RibFrac452', 'RibFrac485', 'RibFrac490', 'RibFrac471', 'RibFrac462', 'RibFrac487', ] n_tr = 0 dataset = {} for c in identifiers: if c in skip_identifiers: continue print(c) tr_file = join('$nnUNet_raw', target_dataset_name_ribfrac, 'imagesTr', c + '_0000.nii.gz') ts_file = join('$nnUNet_raw', target_dataset_name_ribfrac, 'imagesTs', c + '_0000.nii.gz') if isfile(os.path.expandvars(tr_file)): img_file = tr_file elif isfile(os.path.expandvars(ts_file)): img_file = ts_file else: raise RuntimeError(f'Missing image file for identifier {c}') seg_file = join(base, c + '-rib-seg.nii.gz') n_tr += 1 shutil.copy(seg_file, join(labelsTr, c + '.nii.gz')) dataset[c] = { 'images': [img_file], 'label': join('labelsTr', c + '.nii.gz') } generate_dataset_json( join(nnUNet_raw, target_dataset_name), channel_names={0: 'CT'}, labels = { 'background': 0, **{'rib%02.0d' % i: i for i in range(1, 25)} }, num_training_cases=n_tr, file_ending='.nii.gz', dataset_name=target_dataset_name, reference='https://github.com/M3DV/RibSeg, https://ribfrac.grand-challenge.org/', dataset=dataset ) ================================================ FILE: nnunetv2/dataset_conversion/Dataset021_CTAAorta.py ================================================ from batchgenerators.utilities.file_and_folder_operations import * import shutil from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json from nnunetv2.paths import nnUNet_raw import SimpleITK as sitk if __name__ == '__main__': """ """ # extracted training.zip file is here base = '/home/isensee/Downloads/' target_dataset_id = 21 target_dataset_name = f'Dataset{target_dataset_id:03.0f}_CTAAorta' maybe_mkdir_p(join(nnUNet_raw, target_dataset_name)) imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr') labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr') maybe_mkdir_p(imagesTr) maybe_mkdir_p(labelsTr) cases = subfiles(join(base, 'images'), join=False,
prefix='subject') for case in cases: outname = case.replace('_CTA.mha', '') im = sitk.ReadImage(join(base, 'images', case)) sitk.WriteImage(im, join(imagesTr, outname + '_0000.nii.gz')) seg = sitk.ReadImage(join(base, 'masks', case.replace('_CTA.mha', '_label.mha'))) sitk.WriteImage(seg, join(labelsTr, outname + '.nii.gz')) labels = { "background": 0, "Zone_0": 1, "Innominate": 2, "Zone_1": 3, "Left_Common_Carotid": 4, "Zone_2": 5, "Left_Subclavian_Artery": 6, "Zone_3": 7, "Zone_4": 8, "Zone_5": 9, "Zone_6": 10, "Celiac_Artery": 11, "Zone_7": 12, "SMA": 13, "Zone_8": 14, "Right_Renal_Artery": 15, "Left_Renal_Artery": 16, "Zone_9": 17, "Zone_10_R_(Right_Common_Iliac_Artery)": 18, "Zone_10_L_(Left_Common_Iliac_Artery)": 19, "Right_Internal_Iliac_Artery_Dice_Score": 20, "Left_Internal_Iliac_Artery_Dice_Score": 21, "Zone_11_R_(Right_External_Iliac_Artery)": 22, "Zone_11_L_(Left_External_Iliac_Artery)": 23 } generate_dataset_json( join(nnUNet_raw, target_dataset_name), {0: 'CTA'}, labels, len(cases), '.nii.gz', None, target_dataset_name, overwrite_image_reader_writer='NibabelIOWithReorient', reference='https://aortaseg24.grand-challenge.org/', license='see ref' ) ================================================ FILE: nnunetv2/dataset_conversion/Dataset023_AbdomenAtlas1_1Mini.py ================================================ from batchgenerators.utilities.file_and_folder_operations import * import shutil from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json from nnunetv2.paths import nnUNet_raw if __name__ == '__main__': """ Download the dataset from huggingface: https://huggingface.co/datasets/AbdomenAtlas/_AbdomenAtlas1.1Mini#3--download-the-dataset IMPORTANT cases 5196-9262 currently do not have images, just the segmentation. 
This seems to be a mistake """ base = '/home/isensee/Downloads/AbdomenAtlas/uncompressed' target_dataset_id = 23 target_dataset_name = f'Dataset{target_dataset_id:03.0f}_AbdomenAtlas1.1Mini' cases = subdirs(base, join=False, prefix='BDMAP') maybe_mkdir_p(join(nnUNet_raw, target_dataset_name)) imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr') labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr') maybe_mkdir_p(imagesTr) maybe_mkdir_p(labelsTr) n_tr = 0 for case in cases: if not isfile(join(base, case, 'ct.nii.gz')): print(f'Skipping case {case} due to missing image') continue shutil.copy(join(base, case, 'ct.nii.gz'), join(imagesTr, case + '_0000.nii.gz')) shutil.copy(join(base, case, 'combined_labels.nii.gz'), join(labelsTr, case + '.nii.gz')) n_tr += 1 class_map = {1: 'aorta', 2: 'gall_bladder', 3: 'kidney_left', 4: 'kidney_right', 5: 'liver', 6: 'pancreas', 7: 'postcava', 8: 'spleen', 9: 'stomach', 10: 'adrenal_gland_left', 11: 'adrenal_gland_right', 12: 'bladder', 13: 'celiac_trunk', 14: 'colon', 15: 'duodenum', 16: 'esophagus', 17: 'femur_left', 18: 'femur_right', 19: 'hepatic_vessel', 20: 'intestine', 21: 'lung_left', 22: 'lung_right', 23: 'portal_vein_and_splenic_vein', 24: 'prostate', 25: 'rectum'} labels = { j: i for i, j in class_map.items() } labels['background'] = 0 # n_tr counts only the cases that were actually copied (skipped cases are excluded) generate_dataset_json( join(nnUNet_raw, target_dataset_name), {0: 'CT'}, labels, n_tr, '.nii.gz', None, target_dataset_name, overwrite_image_reader_writer='NibabelIOWithReorient', reference='https://huggingface.co/datasets/AbdomenAtlas/_AbdomenAtlas1.1Mini', license='Creative Commons Attribution Non Commercial Share Alike 4.0; see reference' ) ================================================ FILE: nnunetv2/dataset_conversion/Dataset027_ACDC.py ================================================ import os import shutil from pathlib import Path from typing import List from batchgenerators.utilities.file_and_folder_operations import nifti_files, join, maybe_mkdir_p, save_json from
nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed import numpy as np def make_out_dirs(dataset_id: int, task_name="ACDC"): dataset_name = f"Dataset{dataset_id:03d}_{task_name}" out_dir = Path(nnUNet_raw.replace('"', "")) / dataset_name out_train_dir = out_dir / "imagesTr" out_labels_dir = out_dir / "labelsTr" out_test_dir = out_dir / "imagesTs" os.makedirs(out_dir, exist_ok=True) os.makedirs(out_train_dir, exist_ok=True) os.makedirs(out_labels_dir, exist_ok=True) os.makedirs(out_test_dir, exist_ok=True) return out_dir, out_train_dir, out_labels_dir, out_test_dir def create_ACDC_split(labelsTr_folder: str, seed: int = 1234) -> List[dict[str, List]]: # labelsTr_folder = '/home/isensee/drives/gpu_data_root/OE0441/isensee/nnUNet_raw/nnUNet_raw_remake/Dataset027_ACDC/labelsTr' nii_files = nifti_files(labelsTr_folder, join=False) patients = np.unique([i[:len('patient000')] for i in nii_files]) rs = np.random.RandomState(seed) rs.shuffle(patients) splits = [] for fold in range(5): val_patients = patients[fold::5] train_patients = [i for i in patients if i not in val_patients] val_cases = [i[:-7] for i in nii_files for j in val_patients if i.startswith(j)] train_cases = [i[:-7] for i in nii_files for j in train_patients if i.startswith(j)] splits.append({'train': train_cases, 'val': val_cases}) return splits def copy_files(src_data_folder: Path, train_dir: Path, labels_dir: Path, test_dir: Path): """Copy files from the ACDC dataset to the nnUNet dataset folder. Returns the number of training cases.""" patients_train = sorted([f for f in (src_data_folder / "training").iterdir() if f.is_dir()]) patients_test = sorted([f for f in (src_data_folder / "testing").iterdir() if f.is_dir()]) num_training_cases = 0 # Copy training files and corresponding labels. 
for patient_dir in patients_train: for file in patient_dir.iterdir(): if file.suffix == ".gz" and "_gt" not in file.name and "_4d" not in file.name: # The stem is 'patient.nii', and the suffix is '.gz'. # We split the stem and append _0000 to the patient part. shutil.copy(file, train_dir / f"{file.stem.split('.')[0]}_0000.nii.gz") num_training_cases += 1 elif file.suffix == ".gz" and "_gt" in file.name: shutil.copy(file, labels_dir / file.name.replace("_gt", "")) # Copy test files. for patient_dir in patients_test: for file in patient_dir.iterdir(): if file.suffix == ".gz" and "_gt" not in file.name and "_4d" not in file.name: shutil.copy(file, test_dir / f"{file.stem.split('.')[0]}_0000.nii.gz") return num_training_cases def convert_acdc(src_data_folder: str, dataset_id=27): out_dir, train_dir, labels_dir, test_dir = make_out_dirs(dataset_id=dataset_id) num_training_cases = copy_files(Path(src_data_folder), train_dir, labels_dir, test_dir) generate_dataset_json( str(out_dir), channel_names={ 0: "cineMRI", }, labels={ "background": 0, "RV": 1, "MLV": 2, "LVC": 3, }, file_ending=".nii.gz", num_training_cases=num_training_cases, ) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument( "-i", "--input_folder", type=str, help="The downloaded ACDC dataset dir. 
Should contain extracted 'training' and 'testing' folders.", ) parser.add_argument( "-d", "--dataset_id", required=False, type=int, default=27, help="nnU-Net Dataset ID, default: 27" ) args = parser.parse_args() print("Converting...") convert_acdc(args.input_folder, args.dataset_id) dataset_name = f"Dataset{args.dataset_id:03d}_{'ACDC'}" labelsTr = join(nnUNet_raw, dataset_name, 'labelsTr') preprocessed_folder = join(nnUNet_preprocessed, dataset_name) maybe_mkdir_p(preprocessed_folder) split = create_ACDC_split(labelsTr) save_json(split, join(preprocessed_folder, 'splits_final.json'), sort_keys=False) print("Done!") ================================================ FILE: nnunetv2/dataset_conversion/Dataset042_BraTS18.py ================================================ import multiprocessing import shutil import SimpleITK as sitk import numpy as np from tqdm import tqdm from batchgenerators.utilities.file_and_folder_operations import * from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json from nnunetv2.paths import nnUNet_raw def copy_BraTS_segmentation_and_convert_labels_to_nnUNet(in_file: str, out_file: str) -> None: # use this for segmentation only!!! # nnUNet wants the labels to be continuous. 
BraTS is 0, 1, 2, 4 -> we make that into 0, 1, 2, 3 (4 becomes 3; labels 1 and 2 are swapped) img = sitk.ReadImage(in_file) img_npy = sitk.GetArrayFromImage(img) uniques = np.unique(img_npy) for u in uniques: if u not in [0, 1, 2, 4]: raise RuntimeError('unexpected label') seg_new = np.zeros_like(img_npy) seg_new[img_npy == 4] = 3 seg_new[img_npy == 2] = 1 seg_new[img_npy == 1] = 2 img_corr = sitk.GetImageFromArray(seg_new) img_corr.CopyInformation(img) sitk.WriteImage(img_corr, out_file) def convert_labels_back_to_BraTS(seg: np.ndarray): new_seg = np.zeros_like(seg) new_seg[seg == 1] = 2 new_seg[seg == 3] = 4 new_seg[seg == 2] = 1 return new_seg def load_convert_labels_back_to_BraTS(filename, input_folder, output_folder): a = sitk.ReadImage(join(input_folder, filename)) b = sitk.GetArrayFromImage(a) c = convert_labels_back_to_BraTS(b) d = sitk.GetImageFromArray(c) d.CopyInformation(a) sitk.WriteImage(d, join(output_folder, filename)) def convert_folder_with_preds_back_to_BraTS_labeling_convention(input_folder: str, output_folder: str, num_processes: int = 12): """ reads all prediction files (nifti) in the input folder, converts the labels back to BraTS convention and saves the results under the same file names in the output folder """ maybe_mkdir_p(output_folder) nii = subfiles(input_folder, suffix='.nii.gz', join=False) with multiprocessing.get_context("spawn").Pool(num_processes) as p: p.starmap(load_convert_labels_back_to_BraTS, zip(nii, [input_folder] * len(nii), [output_folder] * len(nii))) if __name__ == '__main__': brats_data_dir = ...
task_id = 42 task_name = "BraTS2018" foldername = "Dataset%03.0d_%s" % (task_id, task_name) # setting up nnU-Net folders out_base = join(nnUNet_raw, foldername) imagestr = join(out_base, "imagesTr") labelstr = join(out_base, "labelsTr") maybe_mkdir_p(imagestr) maybe_mkdir_p(labelstr) case_ids_hgg = subdirs(join(brats_data_dir, "HGG"), prefix='Brats', join=False) case_ids_lgg = subdirs(join(brats_data_dir, "LGG"), prefix="Brats", join=False) print("copying hggs") for c in tqdm(case_ids_hgg): shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1.nii"), join(imagestr, c + '_0000.nii')) shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1ce.nii"), join(imagestr, c + '_0001.nii')) shutil.copy(join(brats_data_dir, "HGG", c, c + "_t2.nii"), join(imagestr, c + '_0002.nii')) shutil.copy(join(brats_data_dir, "HGG", c, c + "_flair.nii"), join(imagestr, c + '_0003.nii')) copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "HGG", c, c + "_seg.nii"), join(labelstr, c + '.nii')) print("copying lggs") for c in tqdm(case_ids_lgg): shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1.nii"), join(imagestr, c + '_0000.nii')) shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1ce.nii"), join(imagestr, c + '_0001.nii')) shutil.copy(join(brats_data_dir, "LGG", c, c + "_t2.nii"), join(imagestr, c + '_0002.nii')) shutil.copy(join(brats_data_dir, "LGG", c, c + "_flair.nii"), join(imagestr, c + '_0003.nii')) copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "LGG", c, c + "_seg.nii"), join(labelstr, c + '.nii')) generate_dataset_json(out_base, channel_names={0: 'T1', 1: 'T1ce', 2: 'T2', 3: 'Flair'}, labels={ 'background': 0, 'whole tumor': (1, 2, 3), 'tumor core': (2, 3), 'enhancing tumor': (3,) }, num_training_cases=(len(case_ids_lgg) + len(case_ids_hgg)), file_ending='.nii', regions_class_order=(1, 2, 3), license='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863', reference='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863', 
dataset_release='1.0') ================================================ FILE: nnunetv2/dataset_conversion/Dataset043_BraTS19.py ================================================ import multiprocessing import shutil import SimpleITK as sitk import numpy as np from tqdm import tqdm from batchgenerators.utilities.file_and_folder_operations import * from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json from nnunetv2.paths import nnUNet_raw def copy_BraTS_segmentation_and_convert_labels_to_nnUNet(in_file: str, out_file: str) -> None: # use this for segmentation only!!! # nnUNet wants the labels to be consecutive. BraTS is 0, 1, 2, 4 -> we make that into 0, 1, 2, 3 (4 becomes 3; labels 1 and 2 are swapped) img = sitk.ReadImage(in_file) img_npy = sitk.GetArrayFromImage(img) uniques = np.unique(img_npy) for u in uniques: if u not in [0, 1, 2, 4]: raise RuntimeError('unexpected label') seg_new = np.zeros_like(img_npy) seg_new[img_npy == 4] = 3 seg_new[img_npy == 2] = 1 seg_new[img_npy == 1] = 2 img_corr = sitk.GetImageFromArray(seg_new) img_corr.CopyInformation(img) sitk.WriteImage(img_corr, out_file) def convert_labels_back_to_BraTS(seg: np.ndarray): new_seg = np.zeros_like(seg) new_seg[seg == 1] = 2 new_seg[seg == 3] = 4 new_seg[seg == 2] = 1 return new_seg def load_convert_labels_back_to_BraTS(filename, input_folder, output_folder): a = sitk.ReadImage(join(input_folder, filename)) b = sitk.GetArrayFromImage(a) c = convert_labels_back_to_BraTS(b) d = sitk.GetImageFromArray(c) d.CopyInformation(a) sitk.WriteImage(d, join(output_folder, filename)) def convert_folder_with_preds_back_to_BraTS_labeling_convention(input_folder: str, output_folder: str, num_processes: int = 12): """ reads all prediction files (nifti) in the input folder, converts the labels back to BraTS convention and saves the results under the same file names in the output folder """ maybe_mkdir_p(output_folder) nii = subfiles(input_folder, suffix='.nii.gz', join=False) with multiprocessing.get_context("spawn").Pool(num_processes) as p:
p.starmap(load_convert_labels_back_to_BraTS, zip(nii, [input_folder] * len(nii), [output_folder] * len(nii))) if __name__ == '__main__': brats_data_dir = ... task_id = 43 task_name = "BraTS2019" foldername = "Dataset%03.0d_%s" % (task_id, task_name) # setting up nnU-Net folders out_base = join(nnUNet_raw, foldername) imagestr = join(out_base, "imagesTr") labelstr = join(out_base, "labelsTr") maybe_mkdir_p(imagestr) maybe_mkdir_p(labelstr) case_ids_hgg = subdirs(join(brats_data_dir, "HGG"), prefix='BraTS', join=False) case_ids_lgg = subdirs(join(brats_data_dir, "LGG"), prefix="BraTS", join=False) print("copying hggs") for c in tqdm(case_ids_hgg): shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1.nii"), join(imagestr, c + '_0000.nii')) shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1ce.nii"), join(imagestr, c + '_0001.nii')) shutil.copy(join(brats_data_dir, "HGG", c, c + "_t2.nii"), join(imagestr, c + '_0002.nii')) shutil.copy(join(brats_data_dir, "HGG", c, c + "_flair.nii"), join(imagestr, c + '_0003.nii')) copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "HGG", c, c + "_seg.nii"), join(labelstr, c + '.nii')) print("copying lggs") for c in tqdm(case_ids_lgg): shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1.nii"), join(imagestr, c + '_0000.nii')) shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1ce.nii"), join(imagestr, c + '_0001.nii')) shutil.copy(join(brats_data_dir, "LGG", c, c + "_t2.nii"), join(imagestr, c + '_0002.nii')) shutil.copy(join(brats_data_dir, "LGG", c, c + "_flair.nii"), join(imagestr, c + '_0003.nii')) copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "LGG", c, c + "_seg.nii"), join(labelstr, c + '.nii')) generate_dataset_json(out_base, channel_names={0: 'T1', 1: 'T1ce', 2: 'T2', 3: 'Flair'}, labels={ 'background': 0, 'whole tumor': (1, 2, 3), 'tumor core': (2, 3), 'enhancing tumor': (3,) }, num_training_cases=(len(case_ids_hgg) + len(case_ids_lgg)), file_ending='.nii', 
regions_class_order=(1, 2, 3), license='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863', reference='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863', dataset_release='1.0') ================================================ FILE: nnunetv2/dataset_conversion/Dataset073_Fluo_C3DH_A549_SIM.py ================================================ from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed import tifffile from batchgenerators.utilities.file_and_folder_operations import * import shutil if __name__ == '__main__': """ This is going to be my test dataset for working with tif as input and output images. All we do here is copy the files and rename them. No file conversions take place """ dataset_name = 'Dataset073_Fluo_C3DH_A549_SIM' imagestr = join(nnUNet_raw, dataset_name, 'imagesTr') imagests = join(nnUNet_raw, dataset_name, 'imagesTs') labelstr = join(nnUNet_raw, dataset_name, 'labelsTr') maybe_mkdir_p(imagestr) maybe_mkdir_p(imagests) maybe_mkdir_p(labelstr) # we extract the downloaded train and test datasets to two separate folders and name them Fluo-C3DH-A549-SIM_train # and Fluo-C3DH-A549-SIM_test train_source = '/home/fabian/Downloads/Fluo-C3DH-A549-SIM_train' test_source = '/home/fabian/Downloads/Fluo-C3DH-A549-SIM_test' # with the old nnU-Net we had to convert all the files to nifti. This is no longer required. We can just copy the # tif files # tif is broken when it comes to spacing. No standards. Grr. So when we use tif nnU-Net expects a separate file # that specifies the spacing. This file needs to exist for EVERY training/test case to allow for different spacings # between files. Important! The spacing must align with the axes. # Here when we do print(tifffile.imread('IMAGE').shape) we get (29, 300, 350). The low resolution axis is the first. # The spacing on the website is given in the wrong axis order. Great.
spacing = (1, 0.126, 0.126) # train set for seq in ['01', '02']: images_dir = join(train_source, seq) seg_dir = join(train_source, seq + '_GT', 'SEG') # if we were to be super clean we would go by IDs but here we just trust the files are sorted the correct way. # Simpler filenames in the cell tracking challenge would be soooo nice. images = subfiles(images_dir, suffix='.tif', sort=True, join=False) segs = subfiles(seg_dir, suffix='.tif', sort=True, join=False) for i, (im, se) in enumerate(zip(images, segs)): target_name = f'{seq}_image_{i:03d}' # we still need the '_0000' suffix for images! Otherwise we would not be able to support multiple input # channels distributed over separate files shutil.copy(join(images_dir, im), join(imagestr, target_name + '_0000.tif')) # spacing file! save_json({'spacing': spacing}, join(imagestr, target_name + '.json')) shutil.copy(join(seg_dir, se), join(labelstr, target_name + '.tif')) # spacing file! save_json({'spacing': spacing}, join(labelstr, target_name + '.json')) # test set, same as train just without the segmentations for seq in ['01', '02']: images_dir = join(test_source, seq) images = subfiles(images_dir, suffix='.tif', sort=True, join=False) for i, im in enumerate(images): target_name = f'{seq}_image_{i:03d}' shutil.copy(join(images_dir, im), join(imagests, target_name + '_0000.tif')) # spacing file! save_json({'spacing': spacing}, join(imagests, target_name + '.json')) # now we generate the dataset json generate_dataset_json( join(nnUNet_raw, dataset_name), {0: 'fluorescence_microscopy'}, {'background': 0, 'cell': 1}, 60, '.tif' ) # custom split to ensure we are stratifying properly.
This dataset only has 2 folds caseids = [i[:-4] for i in subfiles(labelstr, suffix='.tif', join=False)] splits = [] splits.append( {'train': [i for i in caseids if i.startswith('01_')], 'val': [i for i in caseids if i.startswith('02_')]} ) splits.append( {'train': [i for i in caseids if i.startswith('02_')], 'val': [i for i in caseids if i.startswith('01_')]} ) save_json(splits, join(nnUNet_preprocessed, dataset_name, 'splits_final.json')) ================================================ FILE: nnunetv2/dataset_conversion/Dataset114_MNMs.py ================================================ import csv import os import random from pathlib import Path import nibabel as nib from batchgenerators.utilities.file_and_folder_operations import load_json, save_json from nnunetv2.dataset_conversion.Dataset027_ACDC import make_out_dirs from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json from nnunetv2.paths import nnUNet_preprocessed def read_csv(csv_file: str): patient_info = {} with open(csv_file) as csvfile: reader = csv.reader(csvfile) headers = next(reader) patient_index = headers.index("External code") ed_index = headers.index("ED") es_index = headers.index("ES") vendor_index = headers.index("Vendor") for row in reader: patient_info[row[patient_index]] = { "ed": int(row[ed_index]), "es": int(row[es_index]), "vendor": row[vendor_index], } return patient_info # ------------------------------------------------------------------------------ # Conversion to nnUNet format # ------------------------------------------------------------------------------ def convert_mnms(src_data_folder: Path, csv_file_name: str, dataset_id: int): out_dir, out_train_dir, out_labels_dir, out_test_dir = make_out_dirs(dataset_id, task_name="MNMs") patients_train = [f for f in (src_data_folder / "Training" / "Labeled").iterdir() if f.is_dir()] patients_test = [f for f in (src_data_folder / "Testing").iterdir() if f.is_dir()] patient_info = read_csv(str(src_data_folder / 
csv_file_name)) save_cardiac_phases(patients_train, patient_info, out_train_dir, out_labels_dir) save_cardiac_phases(patients_test, patient_info, out_test_dir) # There are non-orthonormal direction cosines in the test and validation data. # Not sure if the data should be fixed, or we should skip the problematic data. # patients_val = [f for f in (src_data_folder / "Validation").iterdir() if f.is_dir()] # save_cardiac_phases(patients_val, patient_info, out_train_dir, out_labels_dir) generate_dataset_json( str(out_dir), channel_names={ 0: "cineMRI", }, labels={"background": 0, "LVBP": 1, "LVM": 2, "RV": 3}, file_ending=".nii.gz", num_training_cases=len(patients_train) * 2, # 2 since we have ED and ES for each patient ) def save_cardiac_phases( patients: list[Path], patient_info: dict[str, dict[str, int]], out_dir: Path, labels_dir: Path = None ): for patient in patients: print(f"Processing patient: {patient.name}") image = nib.load(patient / f"{patient.name}_sa.nii.gz") ed_frame = patient_info[patient.name]["ed"] es_frame = patient_info[patient.name]["es"] save_extracted_nifti_slice(image, ed_frame=ed_frame, es_frame=es_frame, out_dir=out_dir, patient=patient) if labels_dir: label = nib.load(patient / f"{patient.name}_sa_gt.nii.gz") save_extracted_nifti_slice(label, ed_frame=ed_frame, es_frame=es_frame, out_dir=labels_dir, patient=patient) def save_extracted_nifti_slice(image, ed_frame: int, es_frame: int, out_dir: Path, patient: Path): # Save only extracted diastole and systole slices from the 4D H x W x D x time volume. image_ed = nib.Nifti1Image(image.dataobj[..., ed_frame], image.affine) image_es = nib.Nifti1Image(image.dataobj[..., es_frame], image.affine) # Labels do not have modality identifiers. Labels always end with 'gt'. 
suffix = ".nii.gz" if image.get_filename().endswith("_gt.nii.gz") else "_0000.nii.gz" nib.save(image_ed, str(out_dir / f"{patient.name}_frame{ed_frame:02d}{suffix}")) nib.save(image_es, str(out_dir / f"{patient.name}_frame{es_frame:02d}{suffix}")) # ------------------------------------------------------------------------------ # Create custom splits # ------------------------------------------------------------------------------ def create_custom_splits(src_data_folder: Path, csv_file: str, dataset_id: int, num_val_patients: int = 25): existing_splits = os.path.join(nnUNet_preprocessed, f"Dataset{dataset_id}_MNMs", "splits_final.json") splits = load_json(existing_splits) patients_train = [f.name for f in (src_data_folder / "Training" / "Labeled").iterdir() if f.is_dir()] # Filter out any patients not in the training set patient_info = { patient: data for patient, data in read_csv(str(src_data_folder / csv_file)).items() if patient in patients_train } # Get train and validation patients for both vendors patients_a = [patient for patient, patient_data in patient_info.items() if patient_data["vendor"] == "A"] patients_b = [patient for patient, patient_data in patient_info.items() if patient_data["vendor"] == "B"] train_a, val_a = get_vendor_split(patients_a, num_val_patients) train_b, val_b = get_vendor_split(patients_b, num_val_patients) # Build filenames from corresponding patient frames train_a = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in train_a for frame in ["es", "ed"]] train_b = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in train_b for frame in ["es", "ed"]] train_a_mix_1, train_a_mix_2 = train_a[: len(train_a) // 2], train_a[len(train_a) // 2 :] train_b_mix_1, train_b_mix_2 = train_b[: len(train_b) // 2], train_b[len(train_b) // 2 :] val_a = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in val_a for frame in ["es", "ed"]] val_b = [f"{patient}_frame{patient_info[patient][frame]:02d}" for 
patient in val_b for frame in ["es", "ed"]] for train_set in [train_a, train_b, train_a_mix_1 + train_b_mix_1, train_a_mix_2 + train_b_mix_2]: # For each train set, we evaluate on A, B and (A + B) respectively # See table 3 from the original paper for more details. splits.append({"train": train_set, "val": val_a}) splits.append({"train": train_set, "val": val_b}) splits.append({"train": train_set, "val": val_a + val_b}) save_json(splits, existing_splits) def get_vendor_split(patients: list[str], num_val_patients: int): random.shuffle(patients) total_patients = len(patients) num_training_patients = total_patients - num_val_patients return patients[:num_training_patients], patients[num_training_patients:] if __name__ == "__main__": import argparse class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter): pass parser = argparse.ArgumentParser(add_help=False, formatter_class=RawTextArgumentDefaultsHelpFormatter) parser.add_argument( "-h", "--help", action="help", default=argparse.SUPPRESS, help="MNMs conversion utility helper. This script can be used to convert MNMs data into the expected nnUNet " "format. It can also be used to create additional custom splits, for explicitly training on combinations " "of vendors A and B (see `--custom-splits`).\n" "If you wish to generate the custom splits, run the following pipeline:\n\n" "(1) Run `Dataset114_MNMs -i \n" "(2) Run `nnUNetv2_plan_and_preprocess -d 114 --verify_dataset_integrity`\n" "(3) Start training, but stop after initial splits are created: `nnUNetv2_train 114 2d 0`\n" "(4) Re-run `Dataset114_MNMs`, with `-s True`.\n" "(5) Re-run training.\n", ) parser.add_argument( "-i", "--input_folder", type=str, default="./data/M&Ms/OpenDataset/", help="The downloaded MNMs dataset dir. 
Should contain a csv file, as well as Training, Validation and Testing " "folders.", ) parser.add_argument( "-c", "--csv_file_name", type=str, default="211230_M&Ms_Dataset_information_diagnosis_opendataset.csv", help="The csv file containing the dataset information.", ) parser.add_argument("-d", "--dataset_id", type=int, default=114, help="nnUNet Dataset ID.") parser.add_argument( "-s", "--custom_splits", type=lambda s: str(s).lower() in ("true", "1", "yes"), default=False, help="Whether to append custom splits for training and testing on different vendors. If True, will create " "splits for training on patients from vendors A, B or a mix of A and B. Splits are tested on hold-out " "validation sets of patients from A, B or A and B combined. See section 2.4 and table 3 from " "https://arxiv.org/abs/2011.07592 for more info.", ) args = parser.parse_args() args.input_folder = Path(args.input_folder) if args.custom_splits: print("Appending custom splits...") create_custom_splits(args.input_folder, args.csv_file_name, args.dataset_id) else: print("Converting...") convert_mnms(args.input_folder, args.csv_file_name, args.dataset_id) print("Done!") ================================================ FILE: nnunetv2/dataset_conversion/Dataset115_EMIDEC.py ================================================ import shutil from pathlib import Path from nnunetv2.dataset_conversion.Dataset027_ACDC import make_out_dirs from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json def copy_files(src_data_dir: Path, src_test_dir: Path, train_dir: Path, labels_dir: Path, test_dir: Path): """Copy files from the EMIDEC dataset to the nnUNet dataset folder. Returns the number of training cases.""" patients_train = sorted([f for f in src_data_dir.iterdir() if f.is_dir()]) patients_test = sorted([f for f in src_test_dir.iterdir() if f.is_dir()]) # Copy training files and corresponding labels.
for patient in patients_train: train_file = patient / "Images" / f"{patient.name}.nii.gz" label_file = patient / "Contours" / f"{patient.name}.nii.gz" shutil.copy(train_file, train_dir / f"{train_file.stem.split('.')[0]}_0000.nii.gz") shutil.copy(label_file, labels_dir) # Copy test files. for patient in patients_test: test_file = patient / "Images" / f"{patient.name}.nii.gz" shutil.copy(test_file, test_dir / f"{test_file.stem.split('.')[0]}_0000.nii.gz") return len(patients_train) def convert_emidec(src_data_dir: str, src_test_dir: str, dataset_id: int = 115): out_dir, train_dir, labels_dir, test_dir = make_out_dirs(dataset_id=dataset_id, task_name="EMIDEC") num_training_cases = copy_files(Path(src_data_dir), Path(src_test_dir), train_dir, labels_dir, test_dir) generate_dataset_json( str(out_dir), channel_names={ 0: "cineMRI", }, labels={ "background": 0, "cavity": 1, "normal_myocardium": 2, "myocardial_infarction": 3, "no_reflow": 4, }, file_ending=".nii.gz", num_training_cases=num_training_cases, ) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("-i", "--input_dir", type=str, help="The EMIDEC dataset directory.") parser.add_argument("-t", "--test_dir", type=str, help="The EMIDEC test set directory.") parser.add_argument( "-d", "--dataset_id", required=False, type=int, default=115, help="nnU-Net Dataset ID, default: 115" ) args = parser.parse_args() print("Converting...") convert_emidec(args.input_dir, args.test_dir, args.dataset_id) print("Done!") ================================================ FILE: nnunetv2/dataset_conversion/Dataset119_ToothFairy2_All.py ================================================ from typing import Dict, Any import os from os.path import join import json import random import multiprocessing import SimpleITK as sitk import numpy as np from tqdm import tqdm def mapping_DS119() -> Dict[int, int]: """Remove all NA Classes and make Class IDs continuous""" mapping = {} mapping.update({i: i for i in
range(1, 19)}) # [1-10]->[1-10] | [11-18]->[11-18] mapping.update({i: i - 2 for i in range(21, 29)}) # [21-28]->[19-26] mapping.update({i: i - 4 for i in range(31, 39)}) # [31-38]->[27-34] mapping.update({i: i - 6 for i in range(41, 49)}) # [41-48]->[35-42] return mapping def mapping_DS120() -> Dict[int, int]: """Only keep Teeth and Jaw Classes""" mapping = {} mapping.update({i: i for i in range(1, 3)}) # [1-2]->[1-2] mapping.update({i: i - 8 for i in range(11, 19)}) # [11-18]->[3-10] mapping.update({i: i - 10 for i in range(21, 29)}) # [21-28]->[11-18] mapping.update({i: i - 12 for i in range(31, 39)}) # [31-38]->[19-26] mapping.update({i: i - 14 for i in range(41, 49)}) # [41-48]->[27-34] return mapping def mapping_DS121() -> Dict[int, int]: """Only keep Teeth Classes""" mapping = {} mapping.update({i: i - 10 for i in range(11, 19)}) # [11-18]->[1-8] mapping.update({i: i - 12 for i in range(21, 29)}) # [21-28]->[9-16] mapping.update({i: i - 14 for i in range(31, 39)}) # [31-38]->[17-24] mapping.update({i: i - 16 for i in range(41, 49)}) # [41-48]->[25-32] return mapping def load_json(json_file: str) -> Any: with open(json_file, "r") as f: data = json.load(f) return data def write_json(json_file: str, data: Any, indent: int = 4) -> None: with open(json_file, "w") as f: json.dump(data, f, indent=indent) def image_to_nifi(input_path: str, output_path: str) -> None: image_sitk = sitk.ReadImage(input_path) sitk.WriteImage(image_sitk, output_path) def label_mapping(input_path: str, output_path: str, mapping: Dict[int, int] = None) -> None: label_sitk = sitk.ReadImage(input_path) if mapping is not None: label_np = sitk.GetArrayFromImage(label_sitk) label_np_new = np.zeros_like(label_np, dtype=np.uint8) for org_id, new_id in mapping.items(): label_np_new[label_np == org_id] = new_id label_sitk_new = sitk.GetImageFromArray(label_np_new) label_sitk_new.CopyInformation(label_sitk) sitk.WriteImage(label_sitk_new, output_path) else:
sitk.WriteImage(label_sitk, output_path) def process_images(files: str, img_dir_in: str, img_dir_out: str, n_processes: int = 12): os.makedirs(img_dir_out, exist_ok=True) iterable = [ { "input_path": join(img_dir_in, file), "output_path": join(img_dir_out, file.replace(".mha", ".nii.gz")), } for file in files ] with multiprocessing.Pool(processes=n_processes) as pool: jobs = [pool.apply_async(image_to_nifi, kwds={**args}) for args in iterable] _ = [job.get() for job in tqdm(jobs, desc="Process Images")] def process_labels( files: str, lbl_dir_in: str, lbl_dir_out: str, mapping: Dict[int, int], n_processes: int = 12 ) -> None: os.makedirs(lbl_dir_out, exist_ok=True) iterable = [ { "input_path": join(lbl_dir_in, file), "output_path": join(lbl_dir_out, file.replace(".mha", ".nii.gz")), "mapping": mapping, } for file in files ] with multiprocessing.Pool(processes=n_processes) as pool: jobs = [pool.apply_async(label_mapping, kwds={**args}) for args in iterable] _ = [job.get() for job in tqdm(jobs, desc="Process Labels...")] def process_ds( root: str, input_ds: str, output_ds: str, mapping: dict, image_link: str = None ) -> None: os.makedirs(join(root, output_ds), exist_ok=True) os.makedirs(join(root, output_ds, "labelsTr"), exist_ok=True) # --- Handle Labels --- # lbl_files = os.listdir(join(root, input_ds, "labelsTr")) lbl_dir_in = join(root, input_ds, "labelsTr") lbl_dir_out = join(root, output_ds, "labelsTr") process_labels(lbl_files, lbl_dir_in, lbl_dir_out, mapping, n_processes=12) # --- Handle Images --- # img_files = os.listdir(join(root, input_ds, "imagesTr")) dataset = {} if image_link is None: img_dir_in = join(root, input_ds, "imagesTr") img_dir_out = join(root, output_ds, "imagesTr") process_images(img_files, img_dir_in, img_dir_out, n_processes=12) else: base_name = [file.replace("_0000.mha", "") for file in img_files] for name in base_name: dataset[name] = { "images": [join("..", image_link, "imagesTr", name + "_0000.nii.gz")], "label": join("labelsTr", 
name + ".nii.gz"), } # --- Generate dataset.json --- # dataset_json = load_json(join(root, input_ds, "dataset.json")) dataset_json["file_ending"] = ".nii.gz" dataset_json["name"] = output_ds dataset_json["numTraining"] = len(lbl_files) if dataset != {}: dataset_json["dataset"] = dataset label_dict = dataset_json["labels"] label_dict_new = {"background": 0} for k, v in label_dict.items(): if v in mapping.keys(): label_dict_new[k] = mapping[v] dataset_json["labels"] = label_dict_new write_json(join(root, output_ds, "dataset.json"), dataset_json) # --- Generate splits_final.json --- # img_names = [file.replace("_0000.mha", "") for file in img_files] random_seed = 42 random.seed(random_seed) random.shuffle(img_names) split_index = int(len(img_names) * 0.7) # 70:30 split train_files = img_names[:split_index] val_files = img_names[split_index:] train_files.sort() val_files.sort() split = [{"train": train_files, "val": val_files}] write_json(join(root, output_ds, "splits_final.json"), split) if __name__ == "__main__": # Different nnUNet Datasets # Dataset 112: Raw # Dataset 119: Replace NaN classes # Dataset 120: Only Teeth + Jaw Classes # Dataset 121: Only Teeth Classes root = "/media/l727r/data/Teeth_Data/ToothFairy2_Dataset" process_ds(root, "Dataset112_ToothFairy2", "Dataset119_ToothFairy2_All", mapping_DS119(), None) # process_ds( # root, # "Dataset112_ToothFairy2", # "Dataset120_ToothFairy2_JawTeeth", # mapping_DS120(), # "Dataset119_ToothFairy2_All", # ) # process_ds( # root, # "Dataset112_ToothFairy2", # "Dataset121_ToothFairy2_Teeth", # mapping_DS121(), # "Dataset119_ToothFairy2_All", # ) ================================================ FILE: nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py ================================================ import multiprocessing import shutil from multiprocessing import Pool from batchgenerators.utilities.file_and_folder_operations import * from nnunetv2.dataset_conversion.generate_dataset_json import 
generate_dataset_json from nnunetv2.paths import nnUNet_raw from skimage import io from acvl_utils.morphology.morphology_helper import generic_filter_components from scipy.ndimage import binary_fill_holes def load_and_convert_case(input_image: str, input_seg: str, output_image: str, output_seg: str, min_component_size: int = 50): seg = io.imread(input_seg) seg[seg == 255] = 1 image = io.imread(input_image) image = image.sum(2) mask = image == (3 * 255) # the dataset has large white areas in which road segmentations can exist but no image information is available. # Remove the road label in these areas mask = generic_filter_components(mask, filter_fn=lambda ids, sizes: [i for j, i in enumerate(ids) if sizes[j] > min_component_size]) mask = binary_fill_holes(mask) seg[mask] = 0 io.imsave(output_seg, seg, check_contrast=False) shutil.copy(input_image, output_image) if __name__ == "__main__": # extracted archive from https://www.kaggle.com/datasets/insaff/massachusetts-roads-dataset?resource=download source = '/media/fabian/data/raw_datasets/Massachussetts_road_seg/road_segmentation_ideal' dataset_name = 'Dataset120_RoadSegmentation' imagestr = join(nnUNet_raw, dataset_name, 'imagesTr') imagests = join(nnUNet_raw, dataset_name, 'imagesTs') labelstr = join(nnUNet_raw, dataset_name, 'labelsTr') labelsts = join(nnUNet_raw, dataset_name, 'labelsTs') maybe_mkdir_p(imagestr) maybe_mkdir_p(imagests) maybe_mkdir_p(labelstr) maybe_mkdir_p(labelsts) train_source = join(source, 'training') test_source = join(source, 'testing') with multiprocessing.get_context("spawn").Pool(8) as p: # not all training images have a segmentation valid_ids = subfiles(join(train_source, 'output'), join=False, suffix='png') num_train = len(valid_ids) r = [] for v in valid_ids: r.append( p.starmap_async( load_and_convert_case, (( join(train_source, 'input', v), join(train_source, 'output', v), join(imagestr, v[:-4] + '_0000.png'), join(labelstr, v), 50 ),) ) ) # test set valid_ids = 
subfiles(join(test_source, 'output'), join=False, suffix='png') for v in valid_ids: r.append( p.starmap_async( load_and_convert_case, (( join(test_source, 'input', v), join(test_source, 'output', v), join(imagests, v[:-4] + '_0000.png'), join(labelsts, v), 50 ),) ) ) _ = [i.get() for i in r] generate_dataset_json(join(nnUNet_raw, dataset_name), {0: 'R', 1: 'G', 2: 'B'}, {'background': 0, 'road': 1}, num_train, '.png', dataset_name=dataset_name) ================================================ FILE: nnunetv2/dataset_conversion/Dataset137_BraTS21.py ================================================ import multiprocessing import shutil import SimpleITK as sitk import numpy as np from batchgenerators.utilities.file_and_folder_operations import * from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json from nnunetv2.paths import nnUNet_raw def copy_BraTS_segmentation_and_convert_labels_to_nnUNet(in_file: str, out_file: str) -> None: # use this for segmentation only!!! # nnUNet wants the labels to be continuous. 
BraTS is 0, 1, 2, 4 -> we make that into 0, 1, 2, 3 img = sitk.ReadImage(in_file) img_npy = sitk.GetArrayFromImage(img) uniques = np.unique(img_npy) for u in uniques: if u not in [0, 1, 2, 4]: raise RuntimeError('unexpected label') seg_new = np.zeros_like(img_npy) seg_new[img_npy == 4] = 3 seg_new[img_npy == 2] = 1 seg_new[img_npy == 1] = 2 img_corr = sitk.GetImageFromArray(seg_new) img_corr.CopyInformation(img) sitk.WriteImage(img_corr, out_file) def convert_labels_back_to_BraTS(seg: np.ndarray): new_seg = np.zeros_like(seg) new_seg[seg == 1] = 2 new_seg[seg == 3] = 4 new_seg[seg == 2] = 1 return new_seg def load_convert_labels_back_to_BraTS(filename, input_folder, output_folder): a = sitk.ReadImage(join(input_folder, filename)) b = sitk.GetArrayFromImage(a) c = convert_labels_back_to_BraTS(b) d = sitk.GetImageFromArray(c) d.CopyInformation(a) sitk.WriteImage(d, join(output_folder, filename)) def convert_folder_with_preds_back_to_BraTS_labeling_convention(input_folder: str, output_folder: str, num_processes: int = 12): """ Reads all prediction files (nifti) in the input folder, converts the labels back to BraTS convention and saves the converted files in the output folder """ maybe_mkdir_p(output_folder) nii = subfiles(input_folder, suffix='.nii.gz', join=False) with multiprocessing.get_context("spawn").Pool(num_processes) as p: p.starmap(load_convert_labels_back_to_BraTS, zip(nii, [input_folder] * len(nii), [output_folder] * len(nii))) if __name__ == '__main__': brats_data_dir = '/home/isensee/drives/E132-Rohdaten/BraTS_2021/training' task_id = 137 task_name = "BraTS2021" foldername = "Dataset%03.0d_%s" % (task_id, task_name) # setting up nnU-Net folders out_base = join(nnUNet_raw, foldername) imagestr = join(out_base, "imagesTr") labelstr = join(out_base, "labelsTr") maybe_mkdir_p(imagestr) maybe_mkdir_p(labelstr) case_ids = subdirs(brats_data_dir, prefix='BraTS', join=False) for c in case_ids: shutil.copy(join(brats_data_dir, c, c + "_t1.nii.gz"), join(imagestr, c + '_0000.nii.gz'))
shutil.copy(join(brats_data_dir, c, c + "_t1ce.nii.gz"), join(imagestr, c + '_0001.nii.gz')) shutil.copy(join(brats_data_dir, c, c + "_t2.nii.gz"), join(imagestr, c + '_0002.nii.gz')) shutil.copy(join(brats_data_dir, c, c + "_flair.nii.gz"), join(imagestr, c + '_0003.nii.gz')) copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, c, c + "_seg.nii.gz"), join(labelstr, c + '.nii.gz')) generate_dataset_json(out_base, channel_names={0: 'T1', 1: 'T1ce', 2: 'T2', 3: 'Flair'}, labels={ 'background': 0, 'whole tumor': (1, 2, 3), 'tumor core': (2, 3), 'enhancing tumor': (3, ) }, num_training_cases=len(case_ids), file_ending='.nii.gz', regions_class_order=(1, 2, 3), license='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863', reference='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863', dataset_release='1.0') ================================================ FILE: nnunetv2/dataset_conversion/Dataset218_Amos2022_task1.py ================================================ from batchgenerators.utilities.file_and_folder_operations import * import shutil from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json from nnunetv2.paths import nnUNet_raw def convert_amos_task1(amos_base_dir: str, nnunet_dataset_id: int = 218): """ AMOS doesn't say anything about how the validation set is supposed to be used. So we just incorporate that into the train set. 
Having a 5-fold cross-validation is superior to a single train:val split """ task_name = "AMOS2022_postChallenge_task1" foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name) # setting up nnU-Net folders out_base = join(nnUNet_raw, foldername) imagestr = join(out_base, "imagesTr") imagests = join(out_base, "imagesTs") labelstr = join(out_base, "labelsTr") maybe_mkdir_p(imagestr) maybe_mkdir_p(imagests) maybe_mkdir_p(labelstr) dataset_json_source = load_json(join(amos_base_dir, 'dataset.json')) training_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['training']] tr_ctr = 0 for tr in training_identifiers: if int(tr.split("_")[-1]) <= 410: # these are the CT images tr_ctr += 1 shutil.copy(join(amos_base_dir, 'imagesTr', tr + '.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz')) shutil.copy(join(amos_base_dir, 'labelsTr', tr + '.nii.gz'), join(labelstr, f'{tr}.nii.gz')) test_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['test']] for ts in test_identifiers: if int(ts.split("_")[-1]) <= 500: # these are the CT images shutil.copy(join(amos_base_dir, 'imagesTs', ts + '.nii.gz'), join(imagests, f'{ts}_0000.nii.gz')) val_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['validation']] for vl in val_identifiers: if int(vl.split("_")[-1]) <= 409: # these are the CT images tr_ctr += 1 shutil.copy(join(amos_base_dir, 'imagesVa', vl + '.nii.gz'), join(imagestr, f'{vl}_0000.nii.gz')) shutil.copy(join(amos_base_dir, 'labelsVa', vl + '.nii.gz'), join(labelstr, f'{vl}.nii.gz')) generate_dataset_json(out_base, {0: "CT"}, labels={v: int(k) for k,v in dataset_json_source['labels'].items()}, num_training_cases=tr_ctr, file_ending='.nii.gz', dataset_name=task_name, reference='https://amos22.grand-challenge.org/', release='https://zenodo.org/record/7262581', overwrite_image_reader_writer='NibabelIOWithReorient', description="This is the dataset as released AFTER the challenge event. 
It has the " "validation set gt in it! We just use the validation images as additional " "training cases because AMOS doesn't specify how they should be used. nnU-Net's" " 5-fold CV is better than some random train:val split.") if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('input_folder', type=str, help="The downloaded and extracted AMOS2022 (https://amos22.grand-challenge.org/) data. " "Use this link: https://zenodo.org/record/7262581." "You need to specify the folder with the imagesTr, imagesVal, labelsTr etc subfolders here!") parser.add_argument('-d', required=False, type=int, default=218, help='nnU-Net Dataset ID, default: 218') args = parser.parse_args() amos_base = args.input_folder convert_amos_task1(amos_base, args.d) ================================================ FILE: nnunetv2/dataset_conversion/Dataset219_Amos2022_task2.py ================================================ from batchgenerators.utilities.file_and_folder_operations import * import shutil from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json from nnunetv2.paths import nnUNet_raw def convert_amos_task2(amos_base_dir: str, nnunet_dataset_id: int = 219): """ AMOS doesn't say anything about how the validation set is supposed to be used. So we just incorporate that into the train set. 
Having a 5-fold cross-validation is superior to a single train:val split """ task_name = "AMOS2022_postChallenge_task2" foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name) # setting up nnU-Net folders out_base = join(nnUNet_raw, foldername) imagestr = join(out_base, "imagesTr") imagests = join(out_base, "imagesTs") labelstr = join(out_base, "labelsTr") maybe_mkdir_p(imagestr) maybe_mkdir_p(imagests) maybe_mkdir_p(labelstr) dataset_json_source = load_json(join(amos_base_dir, 'dataset.json')) training_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['training']] for tr in training_identifiers: shutil.copy(join(amos_base_dir, 'imagesTr', tr + '.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz')) shutil.copy(join(amos_base_dir, 'labelsTr', tr + '.nii.gz'), join(labelstr, f'{tr}.nii.gz')) test_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['test']] for ts in test_identifiers: shutil.copy(join(amos_base_dir, 'imagesTs', ts + '.nii.gz'), join(imagests, f'{ts}_0000.nii.gz')) val_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['validation']] for vl in val_identifiers: shutil.copy(join(amos_base_dir, 'imagesVa', vl + '.nii.gz'), join(imagestr, f'{vl}_0000.nii.gz')) shutil.copy(join(amos_base_dir, 'labelsVa', vl + '.nii.gz'), join(labelstr, f'{vl}.nii.gz')) generate_dataset_json(out_base, {0: "either_CT_or_MR"}, labels={v: int(k) for k,v in dataset_json_source['labels'].items()}, num_training_cases=len(training_identifiers) + len(val_identifiers), file_ending='.nii.gz', dataset_name=task_name, reference='https://amos22.grand-challenge.org/', release='https://zenodo.org/record/7262581', overwrite_image_reader_writer='NibabelIOWithReorient', description="This is the dataset as released AFTER the challenge event. It has the " "validation set gt in it! We just use the validation images as additional " "training cases because AMOS doesn't specify how they should be used. 
nnU-Net's" " 5-fold CV is better than some random train:val split.") if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('input_folder', type=str, help="The downloaded and extracted AMOS2022 (https://amos22.grand-challenge.org/) data. " "Use this link: https://zenodo.org/record/7262581. " "You need to specify the folder with the imagesTr, imagesVa, labelsTr etc. subfolders here!") parser.add_argument('-d', required=False, type=int, default=219, help='nnU-Net Dataset ID, default: 219') args = parser.parse_args() amos_base = args.input_folder convert_amos_task2(amos_base, args.d) # /home/isensee/Downloads/amos22/amos22/ ================================================ FILE: nnunetv2/dataset_conversion/Dataset220_KiTS2023.py ================================================ from batchgenerators.utilities.file_and_folder_operations import * import shutil from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json from nnunetv2.paths import nnUNet_raw def convert_kits2023(kits_base_dir: str, nnunet_dataset_id: int = 220): task_name = "KiTS2023" foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name) # setting up nnU-Net folders out_base = join(nnUNet_raw, foldername) imagestr = join(out_base, "imagesTr") labelstr = join(out_base, "labelsTr") maybe_mkdir_p(imagestr) maybe_mkdir_p(labelstr) cases = subdirs(kits_base_dir, prefix='case_', join=False) for tr in cases: shutil.copy(join(kits_base_dir, tr, 'imaging.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz')) shutil.copy(join(kits_base_dir, tr, 'segmentation.nii.gz'), join(labelstr, f'{tr}.nii.gz')) generate_dataset_json(out_base, {0: "CT"}, labels={ "background": 0, "kidney": (1, 2, 3), "masses": (2, 3), "tumor": 2 }, regions_class_order=(1, 3, 2), num_training_cases=len(cases), file_ending='.nii.gz', dataset_name=task_name, reference='none', release='0.1.3', overwrite_image_reader_writer='NibabelIOWithReorient', description="KiTS2023") if __name__ 
== '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('input_folder', type=str, help="The downloaded and extracted KiTS2023 dataset (must have case_XXXXX subfolders)") parser.add_argument('-d', required=False, type=int, default=220, help='nnU-Net Dataset ID, default: 220') args = parser.parse_args() kits_base = args.input_folder convert_kits2023(kits_base, args.d) # /media/isensee/raw_data/raw_datasets/kits23/dataset ================================================ FILE: nnunetv2/dataset_conversion/Dataset221_AutoPETII_2023.py ================================================ from batchgenerators.utilities.file_and_folder_operations import * import shutil from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed def convert_autopet(autopet_base_dir:str = '/media/isensee/My Book1/AutoPET/nifti/FDG-PET-CT-Lesions', nnunet_dataset_id: int = 221): task_name = "AutoPETII_2023" foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name) # setting up nnU-Net folders out_base = join(nnUNet_raw, foldername) imagestr = join(out_base, "imagesTr") labelstr = join(out_base, "labelsTr") maybe_mkdir_p(imagestr) maybe_mkdir_p(labelstr) patients = subdirs(autopet_base_dir, prefix='PETCT', join=False) n = 0 identifiers = [] for pat in patients: patient_acquisitions = subdirs(join(autopet_base_dir, pat), join=False) for pa in patient_acquisitions: n += 1 identifier = f"{pat}_{pa}" identifiers.append(identifier) if not isfile(join(imagestr, f'{identifier}_0000.nii.gz')): shutil.copy(join(autopet_base_dir, pat, pa, 'CTres.nii.gz'), join(imagestr, f'{identifier}_0000.nii.gz')) if not isfile(join(imagestr, f'{identifier}_0001.nii.gz')): shutil.copy(join(autopet_base_dir, pat, pa, 'SUV.nii.gz'), join(imagestr, f'{identifier}_0001.nii.gz')) if not isfile(join(labelstr, f'{identifier}.nii.gz')): shutil.copy(join(autopet_base_dir, pat, pa, 'SEG.nii.gz'), join(labelstr, 
f'{identifier}.nii.gz')) generate_dataset_json(out_base, {0: "CT", 1:"CT"}, labels={ "background": 0, "tumor": 1 }, num_training_cases=n, file_ending='.nii.gz', dataset_name=task_name, reference='https://autopet-ii.grand-challenge.org/', release='release', # overwrite_image_reader_writer='NibabelIOWithReorient', description=task_name) # manual split splits = [] for fold in range(5): val_patients = patients[fold :: 5] splits.append( { 'train': [i for i in identifiers if not any([i.startswith(v) for v in val_patients])], 'val': [i for i in identifiers if any([i.startswith(v) for v in val_patients])], } ) pp_out_dir = join(nnUNet_preprocessed, foldername) maybe_mkdir_p(pp_out_dir) save_json(splits, join(pp_out_dir, 'splits_final.json'), sort_keys=False) if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('input_folder', type=str, help="The downloaded and extracted autopet dataset (must have PETCT_XXX subfolders)") parser.add_argument('-d', required=False, type=int, default=221, help='nnU-Net Dataset ID, default: 221') args = parser.parse_args() amos_base = args.input_folder convert_autopet(amos_base, args.d) ================================================ FILE: nnunetv2/dataset_conversion/Dataset223_AMOS2022postChallenge.py ================================================ import shutil from batchgenerators.utilities.file_and_folder_operations import * from nnunetv2.paths import nnUNet_raw from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json if __name__ == '__main__': downloaded_amos_dir = '/home/isensee/amos22/amos22' # downloaded and extracted from https://zenodo.org/record/7155725#.Y0OOCOxBztM target_dataset_id = 223 target_dataset_name = f'Dataset{target_dataset_id:3.0f}_AMOS2022postChallenge' maybe_mkdir_p(join(nnUNet_raw, target_dataset_name)) imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr') imagesTs = join(nnUNet_raw, target_dataset_name, 'imagesTs') labelsTr = 
join(nnUNet_raw, target_dataset_name, 'labelsTr') maybe_mkdir_p(imagesTr) maybe_mkdir_p(imagesTs) maybe_mkdir_p(labelsTr) train_identifiers = [] # copy images source = join(downloaded_amos_dir, 'imagesTr') source_files = nifti_files(source, join=False) train_identifiers += source_files for s in source_files: shutil.copy(join(source, s), join(imagesTr, s[:-7] + '_0000.nii.gz')) source = join(downloaded_amos_dir, 'imagesVa') source_files = nifti_files(source, join=False) train_identifiers += source_files for s in source_files: shutil.copy(join(source, s), join(imagesTr, s[:-7] + '_0000.nii.gz')) source = join(downloaded_amos_dir, 'imagesTs') source_files = nifti_files(source, join=False) for s in source_files: shutil.copy(join(source, s), join(imagesTs, s[:-7] + '_0000.nii.gz')) # copy labels source = join(downloaded_amos_dir, 'labelsTr') source_files = nifti_files(source, join=False) for s in source_files: shutil.copy(join(source, s), join(labelsTr, s)) source = join(downloaded_amos_dir, 'labelsVa') source_files = nifti_files(source, join=False) for s in source_files: shutil.copy(join(source, s), join(labelsTr, s)) old_dataset_json = load_json(join(downloaded_amos_dir, 'dataset.json')) new_labels = {v: k for k, v in old_dataset_json['labels'].items()} generate_dataset_json(join(nnUNet_raw, target_dataset_name), {0: 'nonCT'}, new_labels, num_training_cases=len(train_identifiers), file_ending='.nii.gz', regions_class_order=None, dataset_name=target_dataset_name, reference='https://zenodo.org/record/7155725#.Y0OOCOxBztM', license=old_dataset_json['licence'], # typo in OG dataset.json description=old_dataset_json['description'], release=old_dataset_json['release']) ================================================ FILE: nnunetv2/dataset_conversion/Dataset224_AbdomenAtlas1.0.py ================================================ from batchgenerators.utilities.file_and_folder_operations import * import shutil from nnunetv2.dataset_conversion.generate_dataset_json import 
generate_dataset_json if __name__ == '__main__': """ How to train our submission to the JHU benchmark 1. Execute this script here to convert the dataset into nnU-Net format. Adapt the paths to your system! 2. Run planning and preprocessing: `nnUNetv2_plan_and_preprocess -d 224 -npfp 64 -np 64 -c 3d_fullres -pl nnUNetPlannerResEncL_torchres`. Adapt the number of processes to your System (-np; -npfp)! Note that each process will again spawn 4 threads for resampling. This custom planner replaces the nnU-Net default resampling scheme with a torch-based implementation which is faster but less accurate. This is needed to satisfy the inference speed constraints. 3. Run training with `nnUNetv2_train 224 3d_fullres all -p nnUNetResEncUNetLPlans_torchres`. 24GB VRAM required, training will take ~28-30h. """ base = '/home/isensee/Downloads/AbdomenAtlas1.0Mini' cases = subdirs(base, join=False, prefix='BDMAP') target_dataset_id = 224 target_dataset_name = f'Dataset{target_dataset_id:3.0f}_AbdomenAtlas1.0' raw_dir = '/home/isensee/drives/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/2024_JHU_benchmark' maybe_mkdir_p(join(raw_dir, target_dataset_name)) imagesTr = join(raw_dir, target_dataset_name, 'imagesTr') labelsTr = join(raw_dir, target_dataset_name, 'labelsTr') maybe_mkdir_p(imagesTr) maybe_mkdir_p(labelsTr) for case in cases: shutil.copy(join(base, case, 'ct.nii.gz'), join(imagesTr, case + '_0000.nii.gz')) shutil.copy(join(base, case, 'combined_labels.nii.gz'), join(labelsTr, case + '.nii.gz')) labels = { "background": 0, "aorta": 1, "gall_bladder": 2, "kidney_left": 3, "kidney_right": 4, "liver": 5, "pancreas": 6, "postcava": 7, "spleen": 8, "stomach": 9 } generate_dataset_json( join(raw_dir, target_dataset_name), {0: 'nonCT'}, # this was a mistake we did at the beginning and we keep it like that here for consistency labels, len(cases), '.nii.gz', None, target_dataset_name, overwrite_image_reader_writer='NibabelIOWithReorient' ) 
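Note on the folder-name format used above: the script builds the name with `f'Dataset{target_dataset_id:3.0f}_...'`, which pads with spaces rather than zeros. That is harmless for a three-digit ID such as 224, but would not yield the `DatasetXXX_NAME` form for smaller IDs. The sketch below is a minimal illustration of the difference; `dataset_folder_name` is a hypothetical helper introduced here for demonstration only and is not part of nnU-Net.

```python
def dataset_folder_name(dataset_id: int, name: str) -> str:
    """Hypothetical helper (not an nnU-Net function): zero-padded 'DatasetXXX_NAME'."""
    return f"Dataset{dataset_id:03d}_{name}"


# ':3.0f' right-aligns in a width of 3 using spaces; ':03d' pads with zeros.
space_padded = f"Dataset{7:3.0f}_Demo"
zero_padded = dataset_folder_name(7, "Demo")

print(repr(space_padded))
print(repr(zero_padded))

# For three-digit IDs both spellings agree, which is why the script above works as written.
print(dataset_folder_name(224, "AbdomenAtlas1.0") == f"Dataset{224:3.0f}_AbdomenAtlas1.0")
```

Using `:03d` keeps the name well-formed for any ID below 1000, which matches the `Dataset{...:03d}_...` spelling used by other conversion scripts in this directory.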
================================================ FILE: nnunetv2/dataset_conversion/Dataset226_BraTS2024-BraTS-GLI.py ================================================ from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json from batchgenerators.utilities.file_and_folder_operations import join, subdirs, subfiles, maybe_mkdir_p from nnunetv2.paths import nnUNet_raw if __name__ == '__main__': """ this dataset does not copy the data into nnunet format and just links to existing data. The dataset can only be used from one machine because the paths in the dataset.json are hard coded """ extracted_BraTS2024_GLI_dir = '/home/isensee/BraTS2024_traindata/training_data1' nnunet_dataset_name = 'BraTS2024-BraTS-GLI' nnunet_dataset_id = 226 dataset_name = f'Dataset{nnunet_dataset_id:03d}_{nnunet_dataset_name}' dataset_dir = join(nnUNet_raw, dataset_name) maybe_mkdir_p(dataset_dir) dataset = {} casenames = subdirs(extracted_BraTS2024_GLI_dir, join=False) for c in casenames: dataset[c] = { 'label': join(extracted_BraTS2024_GLI_dir, c, c + '-seg.nii.gz'), 'images': [ join(extracted_BraTS2024_GLI_dir, c, c + '-t1n.nii.gz'), join(extracted_BraTS2024_GLI_dir, c, c + '-t1c.nii.gz'), join(extracted_BraTS2024_GLI_dir, c, c + '-t2w.nii.gz'), join(extracted_BraTS2024_GLI_dir, c, c + '-t2f.nii.gz') ] } labels = { 'background': 0, 'NETC': 1, 'SNFH': 2, 'ET': 3, 'RC': 4, } generate_dataset_json( dataset_dir, { 0: 'T1', 1: "T1C", 2: "T2W", 3: "T2F" }, labels, num_training_cases=len(dataset), file_ending='.nii.gz', regions_class_order=None, dataset_name=dataset_name, reference='https://www.synapse.org/Synapse:syn53708249/wiki/627500', license='see https://www.synapse.org/Synapse:syn53708249/wiki/627508', dataset=dataset, description='This dataset does not copy the data into nnunet format and just links to existing data. 
' 'The dataset can only be used from one machine because the paths in the dataset.json are hard coded' ) ================================================ FILE: nnunetv2/dataset_conversion/Dataset227_TotalSegmentatorMRI.py ================================================ import SimpleITK import nibabel import numpy as np from batchgenerators.utilities.file_and_folder_operations import * import shutil from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json from nnunetv2.paths import nnUNet_raw if __name__ == '__main__': base = '/home/isensee/Downloads/TotalsegmentatorMRI_dataset_v100' cases = subdirs(base, join=False) target_dataset_id = 227 target_dataset_name = f'Dataset{target_dataset_id:3.0f}_TotalSegmentatorMRI' maybe_mkdir_p(join(nnUNet_raw, target_dataset_name)) imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr') labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr') maybe_mkdir_p(imagesTr) maybe_mkdir_p(labelsTr) # discover labels label_fnames = nifti_files(join(base, cases[0], 'segmentations'), join=False) label_dict = {i[:-7]: j + 1 for j, i in enumerate(label_fnames)} labelnames = list(label_dict.keys()) label_dict['background'] = 0 for case in cases: img = nibabel.load(join(base, case, 'mri.nii.gz')) nibabel.save(img, join(imagesTr, case + '_0000.nii.gz')) seg_nib = nibabel.load(join(base, case, 'segmentations', labelnames[0] + '.nii.gz')) init_seg_npy = np.asanyarray(seg_nib.dataobj) init_seg_npy[init_seg_npy > 0] = label_dict[labelnames[0]] for labelname in labelnames[1:]: seg = nibabel.load(join(base, case, 'segmentations', labelname + '.nii.gz')) seg = np.asanyarray(seg.dataobj) init_seg_npy[seg > 0] = label_dict[labelname] out = nibabel.Nifti1Image(init_seg_npy, affine=seg_nib.affine, header=seg_nib.header) nibabel.save(out, join(labelsTr, case + '.nii.gz')) generate_dataset_json( join(nnUNet_raw, target_dataset_name), {0: 'MRI'}, # this was a mistake we did at the beginning and we keep it like that here for 
consistency label_dict, len(cases), '.nii.gz', None, target_dataset_name, overwrite_image_reader_writer='NibabelIOWithReorient', release='1.0.0', reference='https://zenodo.org/records/11367005', license='see reference' ) ================================================ FILE: nnunetv2/dataset_conversion/Dataset987_dummyDataset4.py ================================================ import os from batchgenerators.utilities.file_and_folder_operations import * from nnunetv2.paths import nnUNet_raw from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets if __name__ == '__main__': # creates a dummy dataset where there are no files in imagestr and labelstr source_dataset = 'Dataset004_Hippocampus' target_dataset = 'Dataset987_dummyDataset4' target_dataset_dir = join(nnUNet_raw, target_dataset) maybe_mkdir_p(target_dataset_dir) dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, source_dataset)) # the returned dataset will have absolute paths. We should use relative paths so that you can freely copy # datasets around between systems. As long as the source dataset is there it will continue working even if # nnUNet_raw is in different locations # paths must be relative to target_dataset_dir!!! 
for k in dataset.keys(): dataset[k]['label'] = os.path.relpath(dataset[k]['label'], target_dataset_dir) dataset[k]['images'] = [os.path.relpath(i, target_dataset_dir) for i in dataset[k]['images']] # load old dataset.json dataset_json = load_json(join(nnUNet_raw, source_dataset, 'dataset.json')) dataset_json['dataset'] = dataset # save save_json(dataset_json, join(target_dataset_dir, 'dataset.json'), sort_keys=False) ================================================ FILE: nnunetv2/dataset_conversion/Dataset989_dummyDataset4_2.py ================================================ import os from batchgenerators.utilities.file_and_folder_operations import * from nnunetv2.paths import nnUNet_raw from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets if __name__ == '__main__': # creates a dummy dataset where there are no files in imagestr and labelstr source_dataset = 'Dataset004_Hippocampus' target_dataset = 'Dataset989_dummyDataset4_2' target_dataset_dir = join(nnUNet_raw, target_dataset) maybe_mkdir_p(target_dataset_dir) dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, source_dataset)) # the returned dataset will have absolute paths. We should use relative paths so that you can freely copy # datasets around between systems. As long as the source dataset is there it will continue working even if # nnUNet_raw is in different locations # in this variant the paths are made relative to nnUNet_raw and prefixed with the literal '$nnUNet_raw' placeholder (not relative to target_dataset_dir)
for k in dataset.keys(): dataset[k]['label'] = join('$nnUNet_raw', os.path.relpath(dataset[k]['label'], nnUNet_raw)) dataset[k]['images'] = [join('$nnUNet_raw', os.path.relpath(i, nnUNet_raw)) for i in dataset[k]['images']] # load old dataset.json dataset_json = load_json(join(nnUNet_raw, source_dataset, 'dataset.json')) dataset_json['dataset'] = dataset # save save_json(dataset_json, join(target_dataset_dir, 'dataset.json'), sort_keys=False) ================================================ FILE: nnunetv2/dataset_conversion/__init__.py ================================================ ================================================ FILE: nnunetv2/dataset_conversion/convert_MSD_dataset.py ================================================ import argparse import multiprocessing import shutil from typing import Optional import SimpleITK as sitk from batchgenerators.utilities.file_and_folder_operations import * from nnunetv2.paths import nnUNet_raw from nnunetv2.utilities.dataset_name_id_conversion import find_candidate_datasets from nnunetv2.configuration import default_num_processes import numpy as np def split_4d_nifti(filename, output_folder): img_itk = sitk.ReadImage(filename) dim = img_itk.GetDimension() file_base = os.path.basename(filename) if dim == 3: shutil.copy(filename, join(output_folder, file_base[:-7] + "_0000.nii.gz")) return elif dim != 4: raise RuntimeError("Unexpected dimensionality: %d of file %s, cannot split" % (dim, filename)) else: img_npy = sitk.GetArrayFromImage(img_itk) spacing = img_itk.GetSpacing() origin = img_itk.GetOrigin() direction = np.array(img_itk.GetDirection()).reshape(4,4) # now modify these to remove the fourth dimension spacing = tuple(list(spacing[:-1])) origin = tuple(list(origin[:-1])) direction = tuple(direction[:-1, :-1].reshape(-1)) for i, t in enumerate(range(img_npy.shape[0])): img = img_npy[t] img_itk_new = sitk.GetImageFromArray(img) img_itk_new.SetSpacing(spacing) img_itk_new.SetOrigin(origin) 
img_itk_new.SetDirection(direction) sitk.WriteImage(img_itk_new, join(output_folder, file_base[:-7] + "_%04.0d.nii.gz" % i)) def convert_msd_dataset(source_folder: str, overwrite_target_id: Optional[int] = None, num_processes: int = default_num_processes) -> None: if source_folder.endswith('/') or source_folder.endswith('\\'): source_folder = source_folder[:-1] labelsTr = join(source_folder, 'labelsTr') imagesTs = join(source_folder, 'imagesTs') imagesTr = join(source_folder, 'imagesTr') assert isdir(labelsTr), f"labelsTr subfolder missing in source folder" assert isdir(imagesTs), f"imagesTs subfolder missing in source folder" assert isdir(imagesTr), f"imagesTr subfolder missing in source folder" dataset_json = join(source_folder, 'dataset.json') assert isfile(dataset_json), f"dataset.json missing in source_folder" # infer source dataset id and name task, dataset_name = os.path.basename(source_folder).split('_') task_id = int(task[4:]) # check if target dataset id is taken target_id = task_id if overwrite_target_id is None else overwrite_target_id existing_datasets = find_candidate_datasets(target_id) assert len(existing_datasets) == 0, f"Target dataset id {target_id} is already taken, please consider changing " \ f"it using overwrite_target_id. 
Conflicting dataset: {existing_datasets} (check nnUNet_results, nnUNet_preprocessed and nnUNet_raw!)" target_dataset_name = f"Dataset{target_id:03d}_{dataset_name}" target_folder = join(nnUNet_raw, target_dataset_name) target_imagesTr = join(target_folder, 'imagesTr') target_imagesTs = join(target_folder, 'imagesTs') target_labelsTr = join(target_folder, 'labelsTr') maybe_mkdir_p(target_imagesTr) maybe_mkdir_p(target_imagesTs) maybe_mkdir_p(target_labelsTr) with multiprocessing.get_context("spawn").Pool(num_processes) as p: results = [] # convert 4d train images source_images = [i for i in subfiles(imagesTr, suffix='.nii.gz', join=False) if not i.startswith('.') and not i.startswith('_')] source_images = [join(imagesTr, i) for i in source_images] results.append( p.starmap_async( split_4d_nifti, zip(source_images, [target_imagesTr] * len(source_images)) ) ) # convert 4d test images source_images = [i for i in subfiles(imagesTs, suffix='.nii.gz', join=False) if not i.startswith('.') and not i.startswith('_')] source_images = [join(imagesTs, i) for i in source_images] results.append( p.starmap_async( split_4d_nifti, zip(source_images, [target_imagesTs] * len(source_images)) ) ) # copy segmentations source_images = [i for i in subfiles(labelsTr, suffix='.nii.gz', join=False) if not i.startswith('.') and not i.startswith('_')] for s in source_images: shutil.copy(join(labelsTr, s), join(target_labelsTr, s)) [i.get() for i in results] dataset_json = load_json(dataset_json) dataset_json['labels'] = {j: int(i) for i, j in dataset_json['labels'].items()} dataset_json['file_ending'] = ".nii.gz" dataset_json["channel_names"] = dataset_json["modality"] del dataset_json["modality"] del dataset_json["training"] del dataset_json["test"] save_json(dataset_json, join(nnUNet_raw, target_dataset_name, 'dataset.json'), sort_keys=False) def entry_point(): parser = argparse.ArgumentParser() parser.add_argument('-i', type=str, required=True, help='Downloaded and extracted MSD dataset 
folder. CANNOT be nnUNetv1 dataset! Example: ' '/home/fabian/Downloads/Task05_Prostate') parser.add_argument('-overwrite_id', type=int, required=False, default=None, help='Overwrite the dataset id. If not set we use the id of the MSD task (inferred from ' 'folder name). Only use this if you already have an equivalently numbered dataset!') parser.add_argument('-np', type=int, required=False, default=default_num_processes, help=f'Number of processes used. Default: {default_num_processes}') args = parser.parse_args() convert_msd_dataset(args.i, args.overwrite_id, args.np) if __name__ == '__main__': convert_msd_dataset('/home/fabian/Downloads/Task05_Prostate', overwrite_target_id=201) ================================================ FILE: nnunetv2/dataset_conversion/convert_raw_dataset_from_old_nnunet_format.py ================================================ import shutil from copy import deepcopy from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, isdir, load_json, save_json from nnunetv2.paths import nnUNet_raw def convert(source_folder, target_dataset_name): """ remember that old tasks were called TaskXXX_YYY and new ones are called DatasetXXX_YYY source_folder """ if isdir(join(nnUNet_raw, target_dataset_name)): raise RuntimeError(f'Target dataset name {target_dataset_name} already exists. Aborting... ' f'(we might break something). 
If you are sure you want to proceed, please manually ' f'delete {join(nnUNet_raw, target_dataset_name)}') maybe_mkdir_p(join(nnUNet_raw, target_dataset_name)) shutil.copytree(join(source_folder, 'imagesTr'), join(nnUNet_raw, target_dataset_name, 'imagesTr')) shutil.copytree(join(source_folder, 'labelsTr'), join(nnUNet_raw, target_dataset_name, 'labelsTr')) if isdir(join(source_folder, 'imagesTs')): shutil.copytree(join(source_folder, 'imagesTs'), join(nnUNet_raw, target_dataset_name, 'imagesTs')) if isdir(join(source_folder, 'labelsTs')): shutil.copytree(join(source_folder, 'labelsTs'), join(nnUNet_raw, target_dataset_name, 'labelsTs')) if isdir(join(source_folder, 'imagesVal')): shutil.copytree(join(source_folder, 'imagesVal'), join(nnUNet_raw, target_dataset_name, 'imagesVal')) if isdir(join(source_folder, 'labelsVal')): shutil.copytree(join(source_folder, 'labelsVal'), join(nnUNet_raw, target_dataset_name, 'labelsVal')) shutil.copy(join(source_folder, 'dataset.json'), join(nnUNet_raw, target_dataset_name)) dataset_json = load_json(join(nnUNet_raw, target_dataset_name, 'dataset.json')) del dataset_json['tensorImageSize'] del dataset_json['numTest'] del dataset_json['training'] del dataset_json['test'] dataset_json['channel_names'] = deepcopy(dataset_json['modality']) del dataset_json['modality'] dataset_json['labels'] = {j: int(i) for i, j in dataset_json['labels'].items()} dataset_json['file_ending'] = ".nii.gz" save_json(dataset_json, join(nnUNet_raw, target_dataset_name, 'dataset.json'), sort_keys=False) def convert_entry_point(): import argparse parser = argparse.ArgumentParser() parser.add_argument("input_folder", type=str, help='Raw old nnUNet dataset. This must be the folder with imagesTr,labelsTr etc subfolders! ' 'Please provide the PATH to the old Task, not just the task name. nnU-Net V2 does not ' 'know where v1 tasks are.') parser.add_argument("output_dataset_name", type=str, help='New dataset NAME (not path!). 
Must follow the DatasetXXX_NAME convention!') args = parser.parse_args() convert(args.input_folder, args.output_dataset_name) ================================================ FILE: nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py ================================================ import SimpleITK as sitk import shutil import numpy as np from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json, nifti_files from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name from nnunetv2.paths import nnUNet_raw from nnunetv2.utilities.label_handling.label_handling import LabelManager from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager def sparsify_segmentation(seg: np.ndarray, label_manager: LabelManager, percent_of_slices: float) -> np.ndarray: assert label_manager.has_ignore_label, "This preprocessor only works with datasets that have an ignore label!" 
seg_new = np.ones_like(seg) * label_manager.ignore_label x, y, z = seg.shape # x num_slices = max(1, round(x * percent_of_slices)) selected_slices = np.random.choice(x, num_slices, replace=False) seg_new[selected_slices] = seg[selected_slices] # y num_slices = max(1, round(y * percent_of_slices)) selected_slices = np.random.choice(y, num_slices, replace=False) seg_new[:, selected_slices] = seg[:, selected_slices] # z num_slices = max(1, round(z * percent_of_slices)) selected_slices = np.random.choice(z, num_slices, replace=False) seg_new[:, :, selected_slices] = seg[:, :, selected_slices] return seg_new if __name__ == '__main__': dataset_name = 'IntegrationTest_Hippocampus_regions_ignore' dataset_id = 996 dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}" try: existing_dataset_name = maybe_convert_to_dataset_name(dataset_id) if existing_dataset_name != dataset_name: raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If " f"you intend to delete it, remember to also remove it in nnUNet_preprocessed and " f"nnUNet_results!") except RuntimeError: pass if isdir(join(nnUNet_raw, dataset_name)): shutil.rmtree(join(nnUNet_raw, dataset_name)) source_dataset = maybe_convert_to_dataset_name(4) shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name)) # additionally optimize entire hippocampus region, remove Posterior dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json')) dj['labels'] = { 'background': 0, 'hippocampus': (1, 2), 'anterior': 1, 'ignore': 3 } dj['regions_class_order'] = (2, 1) save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False) # now add ignore label to segmentation images np.random.seed(1234) lm = LabelManager(label_dict=dj['labels'], regions_class_order=dj.get('regions_class_order')) segs = nifti_files(join(nnUNet_raw, dataset_name, 'labelsTr')) for s in segs: seg_itk = sitk.ReadImage(s) seg_npy = sitk.GetArrayFromImage(seg_itk) seg_npy = 
sparsify_segmentation(seg_npy, lm, 0.1 / 3) seg_itk_new = sitk.GetImageFromArray(seg_npy) seg_itk_new.CopyInformation(seg_itk) sitk.WriteImage(seg_itk_new, s) ================================================ FILE: nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py ================================================ import shutil from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name from nnunetv2.paths import nnUNet_raw if __name__ == '__main__': dataset_name = 'IntegrationTest_Hippocampus_regions' dataset_id = 997 dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}" try: existing_dataset_name = maybe_convert_to_dataset_name(dataset_id) if existing_dataset_name != dataset_name: raise FileExistsError( f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If " f"you intend to delete it, remember to also remove it in nnUNet_preprocessed and " f"nnUNet_results!") except RuntimeError: pass if isdir(join(nnUNet_raw, dataset_name)): shutil.rmtree(join(nnUNet_raw, dataset_name)) source_dataset = maybe_convert_to_dataset_name(4) shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name)) # additionally optimize entire hippocampus region, remove Posterior dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json')) dj['labels'] = { 'background': 0, 'hippocampus': (1, 2), 'anterior': 1 } dj['regions_class_order'] = (2, 1) save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False) ================================================ FILE: nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py ================================================ import shutil from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json from 
nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name from nnunetv2.paths import nnUNet_raw if __name__ == '__main__': dataset_name = 'IntegrationTest_Hippocampus_ignore' dataset_id = 998 dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}" try: existing_dataset_name = maybe_convert_to_dataset_name(dataset_id) if existing_dataset_name != dataset_name: raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If " f"you intend to delete it, remember to also remove it in nnUNet_preprocessed and " f"nnUNet_results!") except RuntimeError: pass if isdir(join(nnUNet_raw, dataset_name)): shutil.rmtree(join(nnUNet_raw, dataset_name)) source_dataset = maybe_convert_to_dataset_name(4) shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name)) # set class 2 to ignore label dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json')) dj['labels']['ignore'] = 2 del dj['labels']['Posterior'] save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False) ================================================ FILE: nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py ================================================ import shutil from batchgenerators.utilities.file_and_folder_operations import isdir, join from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name from nnunetv2.paths import nnUNet_raw if __name__ == '__main__': dataset_name = 'IntegrationTest_Hippocampus' dataset_id = 999 dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}" try: existing_dataset_name = maybe_convert_to_dataset_name(dataset_id) if existing_dataset_name != dataset_name: raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. 
If " f"you intend to delete it, remember to also remove it in nnUNet_preprocessed and " f"nnUNet_results!") except RuntimeError: pass if isdir(join(nnUNet_raw, dataset_name)): shutil.rmtree(join(nnUNet_raw, dataset_name)) source_dataset = maybe_convert_to_dataset_name(4) shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name)) ================================================ FILE: nnunetv2/dataset_conversion/datasets_for_integration_tests/__init__.py ================================================ ================================================ FILE: nnunetv2/dataset_conversion/generate_dataset_json.py ================================================ from typing import Tuple, Union, List from batchgenerators.utilities.file_and_folder_operations import save_json, join def generate_dataset_json(output_folder: str, channel_names: dict, labels: dict, num_training_cases: int, file_ending: str, citation: Union[List[str], str] = None, regions_class_order: Tuple[int, ...] = None, dataset_name: str = None, reference: str = None, release: str = None, description: str = None, overwrite_image_reader_writer: str = None, license: str = 'Whoever converted this dataset was lazy and didn\'t look it up!', converted_by: str = "Please enter your name, especially when sharing datasets with others in a common infrastructure!", **kwargs): """ Generates a dataset.json file in the output folder channel_names: Channel names must map the index to the name of the channel, example: { 0: 'T1', 1: 'CT' } Note that the channel names may influence the normalization scheme!! Learn more in the documentation. labels: This will tell nnU-Net what labels to expect. Important: This will also determine whether you use region-based training or not. 
Example regular labels: { 'background': 0, 'left atrium': 1, 'some other label': 2 } Example region-based training: { 'background': 0, 'whole tumor': (1, 2, 3), 'tumor core': (2, 3), 'enhancing tumor': 3 } Remember that nnU-Net expects consecutive values for labels! nnU-Net also expects 0 to be background! num_training_cases: is used to double check all cases are there! file_ending: needed for finding the files correctly. IMPORTANT! File endings must match between images and segmentations! dataset_name, reference, release, license, description: self-explanatory and not used by nnU-Net. Just for completeness and as a reminder that these would be great! overwrite_image_reader_writer: If you need a special IO class for your dataset you can derive it from BaseReaderWriter, place it into nnunetv2.imageio and reference it here by name kwargs: whatever you put here will be placed in the dataset.json as well """ has_regions: bool = any([isinstance(i, (tuple, list)) and len(i) > 1 for i in labels.values()]) if has_regions: assert regions_class_order is not None, f"You have defined regions but regions_class_order is not set. " \ f"You need that." # channel names need strings as keys keys = list(channel_names.keys()) for k in keys: if not isinstance(k, str): channel_names[str(k)] = channel_names[k] del channel_names[k] # labels need ints as values for l in labels.keys(): value = labels[l] if isinstance(value, (tuple, list)): value = tuple([int(i) for i in value]) labels[l] = value else: labels[l] = int(labels[l]) dataset_json = { 'channel_names': channel_names, # previously this was called 'modality'. I didn't like this so this is # channel_names now. Live with it. 
'labels': labels, 'numTraining': num_training_cases, 'file_ending': file_ending, 'licence': license, 'converted_by': converted_by } if dataset_name is not None: dataset_json['name'] = dataset_name if reference is not None: dataset_json['reference'] = reference if release is not None: dataset_json['release'] = release if citation is not None: dataset_json['citation'] = citation if description is not None: dataset_json['description'] = description if overwrite_image_reader_writer is not None: dataset_json['overwrite_image_reader_writer'] = overwrite_image_reader_writer if regions_class_order is not None: dataset_json['regions_class_order'] = regions_class_order dataset_json.update(kwargs) save_json(dataset_json, join(output_folder, 'dataset.json'), sort_keys=False) ================================================ FILE: nnunetv2/ensembling/__init__.py ================================================ ================================================ FILE: nnunetv2/ensembling/ensemble.py ================================================ import argparse import multiprocessing import shutil from copy import deepcopy from typing import List, Union, Tuple import numpy as np from batchgenerators.utilities.file_and_folder_operations import load_json, join, subfiles, \ maybe_mkdir_p, isdir, save_pickle, load_pickle, isfile from nnunetv2.configuration import default_num_processes from nnunetv2.imageio.base_reader_writer import BaseReaderWriter from nnunetv2.utilities.label_handling.label_handling import LabelManager from nnunetv2.utilities.plans_handling.plans_handler import PlansManager def average_probabilities(list_of_files: List[str]) -> np.ndarray: assert len(list_of_files), 'At least one file must be given in list_of_files' avg = None for f in list_of_files: if avg is None: avg = np.load(f)['probabilities'] # maybe increase precision to prevent rounding errors if avg.dtype != np.float32: avg = avg.astype(np.float32) else: avg += np.load(f)['probabilities'] avg /= 
len(list_of_files) return avg def merge_files(list_of_files, output_filename_truncated: str, output_file_ending: str, image_reader_writer: BaseReaderWriter, label_manager: LabelManager, save_probabilities: bool = False): # load the pkl file associated with the first file in list_of_files properties = load_pickle(list_of_files[0][:-4] + '.pkl') # load and average predictions probabilities = average_probabilities(list_of_files) segmentation = label_manager.convert_logits_to_segmentation(probabilities) image_reader_writer.write_seg(segmentation, output_filename_truncated + output_file_ending, properties) if save_probabilities: np.savez_compressed(output_filename_truncated + '.npz', probabilities=probabilities) save_pickle(properties, output_filename_truncated + '.pkl') def ensemble_folders(list_of_input_folders: List[str], output_folder: str, save_merged_probabilities: bool = False, num_processes: int = default_num_processes, dataset_json_file_or_dict: str = None, plans_json_file_or_dict: str = None): """we need too much shit for this function. Problem is that we now have to support region-based training plus multiple input/output formats so there isn't really a way around this. If plans and dataset json are not specified, we assume each of the folders has a corresponding plans.json and/or dataset.json in it. These are usually copied into those folders by nnU-Net during prediction. We just pick the dataset.json and plans.json from the first of the folders and we DON'T check whether the 5 folders contain the same plans etc! 
This can be a feature if results from different datasets are to be merged (only works if label dict in dataset.json is the same between these datasets!!!)""" if dataset_json_file_or_dict is not None: if isinstance(dataset_json_file_or_dict, str): dataset_json = load_json(dataset_json_file_or_dict) else: dataset_json = dataset_json_file_or_dict else: dataset_json = load_json(join(list_of_input_folders[0], 'dataset.json')) if plans_json_file_or_dict is not None: if isinstance(plans_json_file_or_dict, str): plans = load_json(plans_json_file_or_dict) else: plans = plans_json_file_or_dict else: plans = load_json(join(list_of_input_folders[0], 'plans.json')) plans_manager = PlansManager(plans) # now collect the files in each of the folders and enforce that all files are present in all folders files_per_folder = [set(subfiles(i, suffix='.npz', join=False)) for i in list_of_input_folders] # first build a set with all files s = deepcopy(files_per_folder[0]) for f in files_per_folder[1:]: s.update(f) for f in files_per_folder: assert len(s.difference(f)) == 0, "Not all folders contain the same files for ensembling. 
Please only " \ "provide folders that contain the predictions" lists_of_lists_of_files = [[join(fl, fi) for fl in list_of_input_folders] for fi in s] output_files_truncated = [join(output_folder, fi[:-4]) for fi in s] image_reader_writer = plans_manager.image_reader_writer_class() label_manager = plans_manager.get_label_manager(dataset_json) maybe_mkdir_p(output_folder) shutil.copy(join(list_of_input_folders[0], 'dataset.json'), output_folder) with multiprocessing.get_context("spawn").Pool(num_processes) as pool: num_preds = len(s) _ = pool.starmap( merge_files, zip( lists_of_lists_of_files, output_files_truncated, [dataset_json['file_ending']] * num_preds, [image_reader_writer] * num_preds, [label_manager] * num_preds, [save_merged_probabilities] * num_preds ) ) def entry_point_ensemble_folders(): parser = argparse.ArgumentParser() parser.add_argument('-i', nargs='+', type=str, required=True, help='list of input folders') parser.add_argument('-o', type=str, required=True, help='output folder') parser.add_argument('-np', type=int, required=False, default=default_num_processes, help=f"Numbers of processes used for ensembling. 
Default: {default_num_processes}") parser.add_argument('--save_npz', action='store_true', required=False, help='Set this flag to store output ' 'probabilities in separate .npz files') args = parser.parse_args() ensemble_folders(args.i, args.o, args.save_npz, args.np) def ensemble_crossvalidations(list_of_trained_model_folders: List[str], output_folder: str, folds: Union[Tuple[int, ...], List[int]] = (0, 1, 2, 3, 4), num_processes: int = default_num_processes, overwrite: bool = True) -> None: """ Feature: different configurations can now have different splits """ dataset_json = load_json(join(list_of_trained_model_folders[0], 'dataset.json')) plans_manager = PlansManager(join(list_of_trained_model_folders[0], 'plans.json')) # first collect all unique filenames files_per_folder = {} unique_filenames = set() for tr in list_of_trained_model_folders: files_per_folder[tr] = {} for f in folds: if not isdir(join(tr, f'fold_{f}', 'validation')): raise RuntimeError(f'Expected model output directory does not exist. You must train all requested ' f'folds of the specified model.\nModel: {tr}\nFold: {f}') files_here = subfiles(join(tr, f'fold_{f}', 'validation'), suffix='.npz', join=False) if len(files_here) == 0: raise RuntimeError(f"No .npz files found in folder {join(tr, f'fold_{f}', 'validation')}. Rerun your " f"validation with the --npz flag. Use nnUNetv2_train [...] --val --npz.") files_per_folder[tr][f] = subfiles(join(tr, f'fold_{f}', 'validation'), suffix='.npz', join=False) unique_filenames.update(files_per_folder[tr][f]) # verify that all trained_model_folders have all predictions ok = True for tr, fi in files_per_folder.items(): all_files_here = set() for f in folds: all_files_here.update(fi[f]) diff = unique_filenames.difference(all_files_here) if len(diff) > 0: ok = False print(f'model {tr} does not seem to contain all predictions. 
Missing: {diff}') if not ok: raise RuntimeError('There were missing files, see print statements above this one') # now we need to collect where these files are file_mapping = [] for tr in list_of_trained_model_folders: file_mapping.append({}) for f in folds: for fi in files_per_folder[tr][f]: # check for duplicates assert fi not in file_mapping[-1].keys(), f"Duplicate detected. Case {fi} is present in more than " \ f"one fold of model {tr}." file_mapping[-1][fi] = join(tr, f'fold_{f}', 'validation', fi) lists_of_lists_of_files = [[fm[i] for fm in file_mapping] for i in unique_filenames] output_files_truncated = [join(output_folder, fi[:-4]) for fi in unique_filenames] image_reader_writer = plans_manager.image_reader_writer_class() maybe_mkdir_p(output_folder) label_manager = plans_manager.get_label_manager(dataset_json) if not overwrite: tmp = [isfile(i + dataset_json['file_ending']) for i in output_files_truncated] lists_of_lists_of_files = [lists_of_lists_of_files[i] for i in range(len(tmp)) if not tmp[i]] output_files_truncated = [output_files_truncated[i] for i in range(len(tmp)) if not tmp[i]] with multiprocessing.get_context("spawn").Pool(num_processes) as pool: num_preds = len(lists_of_lists_of_files) _ = pool.starmap( merge_files, zip( lists_of_lists_of_files, output_files_truncated, [dataset_json['file_ending']] * num_preds, [image_reader_writer] * num_preds, [label_manager] * num_preds, [False] * num_preds ) ) shutil.copy(join(list_of_trained_model_folders[0], 'plans.json'), join(output_folder, 'plans.json')) shutil.copy(join(list_of_trained_model_folders[0], 'dataset.json'), join(output_folder, 'dataset.json')) ================================================ FILE: nnunetv2/evaluation/__init__.py ================================================ ================================================ FILE: nnunetv2/evaluation/accumulate_cv_results.py ================================================ import shutil from typing import Union, List, Tuple from 
batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, maybe_mkdir_p, subfiles, isfile from nnunetv2.configuration import default_num_processes from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed from nnunetv2.utilities.plans_handling.plans_handler import PlansManager def accumulate_cv_results(trained_model_folder, merged_output_folder: str, folds: Union[List[int], Tuple[int, ...]], num_processes: int = default_num_processes, overwrite: bool = True): """ There are a lot of things that can get fucked up, so the simplest way to deal with potential problems is to collect the cv results into a separate folder and then evaluate them again. No messing with summary_json files! """ if overwrite and isdir(merged_output_folder): shutil.rmtree(merged_output_folder) maybe_mkdir_p(merged_output_folder) dataset_json = load_json(join(trained_model_folder, 'dataset.json')) plans_manager = PlansManager(join(trained_model_folder, 'plans.json')) rw = plans_manager.image_reader_writer_class() shutil.copy(join(trained_model_folder, 'dataset.json'), join(merged_output_folder, 'dataset.json')) shutil.copy(join(trained_model_folder, 'plans.json'), join(merged_output_folder, 'plans.json')) did_we_copy_something = False for f in folds: expected_validation_folder = join(trained_model_folder, f'fold_{f}', 'validation') if not isdir(expected_validation_folder): raise RuntimeError(f"fold {f} of model {trained_model_folder} is missing. 
Please train it!") predicted_files = subfiles(expected_validation_folder, suffix=dataset_json['file_ending'], join=False) for pf in predicted_files: if overwrite and isfile(join(merged_output_folder, pf)): raise RuntimeError(f'More than one of your folds has a prediction for case {pf}') if overwrite or not isfile(join(merged_output_folder, pf)): shutil.copy(join(expected_validation_folder, pf), join(merged_output_folder, pf)) did_we_copy_something = True if did_we_copy_something or not isfile(join(merged_output_folder, 'summary.json')): label_manager = plans_manager.get_label_manager(dataset_json) gt_folder = join(nnUNet_raw, plans_manager.dataset_name, 'labelsTr') if not isdir(gt_folder): gt_folder = join(nnUNet_preprocessed, plans_manager.dataset_name, 'gt_segmentations') compute_metrics_on_folder(gt_folder, merged_output_folder, join(merged_output_folder, 'summary.json'), rw, dataset_json['file_ending'], label_manager.foreground_regions if label_manager.has_regions else label_manager.foreground_labels, label_manager.ignore_label, num_processes) ================================================ FILE: nnunetv2/evaluation/evaluate_predictions.py ================================================ import multiprocessing import os from copy import deepcopy from typing import Tuple, List, Union import numpy as np from batchgenerators.utilities.file_and_folder_operations import subfiles, join, save_json, load_json, \ isfile from nnunetv2.configuration import default_num_processes from nnunetv2.imageio.base_reader_writer import BaseReaderWriter from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json, \ determine_reader_writer_from_file_ending from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO # the Evaluator class of the previous nnU-Net was great and all but man was it overengineered. 
Keep it simple from nnunetv2.utilities.json_export import recursive_fix_for_json_export from nnunetv2.utilities.plans_handling.plans_handler import PlansManager def label_or_region_to_key(label_or_region: Union[int, Tuple[int]]): return str(label_or_region) def key_to_label_or_region(key: str): try: return int(key) except ValueError: key = key.replace('(', '') key = key.replace(')', '') split = key.split(',') return tuple([int(i) for i in split if len(i) > 0]) def save_summary_json(results: dict, output_file: str): """ json does not support tuples as keys (why does it have to be so shitty) so we need to convert that shit ourselves """ results_converted = deepcopy(results) # convert keys in mean metrics results_converted['mean'] = {label_or_region_to_key(k): results['mean'][k] for k in results['mean'].keys()} # convert metric_per_case for i in range(len(results_converted["metric_per_case"])): results_converted["metric_per_case"][i]['metrics'] = \ {label_or_region_to_key(k): results["metric_per_case"][i]['metrics'][k] for k in results["metric_per_case"][i]['metrics'].keys()} # sort_keys=True will make foreground_mean the first entry and thus easy to spot save_json(results_converted, output_file, sort_keys=True) def load_summary_json(filename: str): results = load_json(filename) # convert keys in mean metrics results['mean'] = {key_to_label_or_region(k): results['mean'][k] for k in results['mean'].keys()} # convert metric_per_case for i in range(len(results["metric_per_case"])): results["metric_per_case"][i]['metrics'] = \ {key_to_label_or_region(k): results["metric_per_case"][i]['metrics'][k] for k in results["metric_per_case"][i]['metrics'].keys()} return results def labels_to_list_of_regions(labels: List[int]): return [(i,) for i in labels] def region_or_label_to_mask(segmentation: np.ndarray, region_or_label: Union[int, Tuple[int, ...]]) -> np.ndarray: if np.isscalar(region_or_label): return segmentation == region_or_label else: mask = np.zeros_like(segmentation, 
dtype=bool) for r in region_or_label: mask[segmentation == r] = True return mask def compute_tp_fp_fn_tn(mask_ref: np.ndarray, mask_pred: np.ndarray, ignore_mask: np.ndarray = None): if ignore_mask is None: use_mask = np.ones_like(mask_ref, dtype=bool) else: use_mask = ~ignore_mask tp = np.sum((mask_ref & mask_pred) & use_mask) fp = np.sum(((~mask_ref) & mask_pred) & use_mask) fn = np.sum((mask_ref & (~mask_pred)) & use_mask) tn = np.sum(((~mask_ref) & (~mask_pred)) & use_mask) return tp, fp, fn, tn def compute_metrics(reference_file: str, prediction_file: str, image_reader_writer: BaseReaderWriter, labels_or_regions: Union[List[int], List[Union[int, Tuple[int, ...]]]], ignore_label: int = None) -> dict: # load images seg_ref, seg_ref_dict = image_reader_writer.read_seg(reference_file) seg_pred, seg_pred_dict = image_reader_writer.read_seg(prediction_file) ignore_mask = seg_ref == ignore_label if ignore_label is not None else None results = {} results['reference_file'] = reference_file results['prediction_file'] = prediction_file results['metrics'] = {} for r in labels_or_regions: results['metrics'][r] = {} mask_ref = region_or_label_to_mask(seg_ref, r) mask_pred = region_or_label_to_mask(seg_pred, r) tp, fp, fn, tn = compute_tp_fp_fn_tn(mask_ref, mask_pred, ignore_mask) if tp + fp + fn == 0: results['metrics'][r]['Dice'] = np.nan results['metrics'][r]['IoU'] = np.nan else: results['metrics'][r]['Dice'] = 2 * tp / (2 * tp + fp + fn) results['metrics'][r]['IoU'] = tp / (tp + fp + fn) results['metrics'][r]['FP'] = fp results['metrics'][r]['TP'] = tp results['metrics'][r]['FN'] = fn results['metrics'][r]['TN'] = tn results['metrics'][r]['n_pred'] = fp + tp results['metrics'][r]['n_ref'] = fn + tp return results def compute_metrics_on_folder(folder_ref: str, folder_pred: str, output_file: str, image_reader_writer: BaseReaderWriter, file_ending: str, regions_or_labels: Union[List[int], List[Union[int, Tuple[int, ...]]]], ignore_label: int = None, num_processes: int = 
default_num_processes, chill: bool = True) -> dict: """ output_file must end with .json; can be None """ if output_file is not None: assert output_file.endswith('.json'), 'output_file should end with .json' files_pred = subfiles(folder_pred, suffix=file_ending, join=False) files_ref = subfiles(folder_ref, suffix=file_ending, join=False) if not chill: present = [isfile(join(folder_pred, i)) for i in files_ref] assert all(present), "Not all files in folder_ref exist in folder_pred" files_ref = [join(folder_ref, i) for i in files_pred] files_pred = [join(folder_pred, i) for i in files_pred] with multiprocessing.get_context("spawn").Pool(num_processes) as pool: # for i in list(zip(files_ref, files_pred, [image_reader_writer] * len(files_pred), [regions_or_labels] * len(files_pred), [ignore_label] * len(files_pred))): # compute_metrics(*i) results = pool.starmap( compute_metrics, list(zip(files_ref, files_pred, [image_reader_writer] * len(files_pred), [regions_or_labels] * len(files_pred), [ignore_label] * len(files_pred))) ) # mean metric per class metric_list = list(results[0]['metrics'][regions_or_labels[0]].keys()) means = {} for r in regions_or_labels: means[r] = {} for m in metric_list: means[r][m] = np.nanmean([i['metrics'][r][m] for i in results]) # foreground mean foreground_mean = {} for m in metric_list: values = [] for k in means.keys(): if k == 0 or k == '0': continue values.append(means[k][m]) foreground_mean[m] = np.mean(values) [recursive_fix_for_json_export(i) for i in results] recursive_fix_for_json_export(means) recursive_fix_for_json_export(foreground_mean) result = {'metric_per_case': results, 'mean': means, 'foreground_mean': foreground_mean} if output_file is not None: save_summary_json(result, output_file) return result # print('DONE') def compute_metrics_on_folder2(folder_ref: str, folder_pred: str, dataset_json_file: str, plans_file: str, output_file: str = None, num_processes: int = default_num_processes, chill: bool = False): dataset_json = 
load_json(dataset_json_file) # get file ending file_ending = dataset_json['file_ending'] # get reader writer class example_file = subfiles(folder_ref, suffix=file_ending, join=True)[0] rw = determine_reader_writer_from_dataset_json(dataset_json, example_file)() # maybe auto set output file if output_file is None: output_file = join(folder_pred, 'summary.json') lm = PlansManager(plans_file).get_label_manager(dataset_json) compute_metrics_on_folder(folder_ref, folder_pred, output_file, rw, file_ending, lm.foreground_regions if lm.has_regions else lm.foreground_labels, lm.ignore_label, num_processes, chill=chill) def compute_metrics_on_folder_simple(folder_ref: str, folder_pred: str, labels: Union[Tuple[int, ...], List[int]], output_file: str = None, num_processes: int = default_num_processes, ignore_label: int = None, chill: bool = False): example_file = subfiles(folder_ref, join=True)[0] file_ending = os.path.splitext(example_file)[-1] rw = determine_reader_writer_from_file_ending(file_ending, example_file, allow_nonmatching_filename=True, verbose=False)() # maybe auto set output file if output_file is None: output_file = join(folder_pred, 'summary.json') compute_metrics_on_folder(folder_ref, folder_pred, output_file, rw, file_ending, labels, ignore_label=ignore_label, num_processes=num_processes, chill=chill) def evaluate_folder_entry_point(): import argparse parser = argparse.ArgumentParser() parser.add_argument('gt_folder', type=str, help='folder with gt segmentations') parser.add_argument('pred_folder', type=str, help='folder with predicted segmentations') parser.add_argument('-djfile', type=str, required=True, help='dataset.json file') parser.add_argument('-pfile', type=str, required=True, help='plans.json file') parser.add_argument('-o', type=str, required=False, default=None, help='Output file. Optional. Default: pred_folder/summary.json') parser.add_argument('-np', type=int, required=False, default=default_num_processes, help=f'number of processes used. 
Optional. Default: {default_num_processes}') parser.add_argument('--chill', action='store_true', help='do not crash if folder_pred does not have all files that are present in folder_gt') args = parser.parse_args() compute_metrics_on_folder2(args.gt_folder, args.pred_folder, args.djfile, args.pfile, args.o, args.np, chill=args.chill) def evaluate_simple_entry_point(): import argparse parser = argparse.ArgumentParser() parser.add_argument('gt_folder', type=str, help='folder with gt segmentations') parser.add_argument('pred_folder', type=str, help='folder with predicted segmentations') parser.add_argument('-l', type=int, nargs='+', required=True, help='list of labels') parser.add_argument('-il', type=int, required=False, default=None, help='ignore label') parser.add_argument('-o', type=str, required=False, default=None, help='Output file. Optional. Default: pred_folder/summary.json') parser.add_argument('-np', type=int, required=False, default=default_num_processes, help=f'number of processes used. Optional. 
Default: {default_num_processes}') parser.add_argument('--chill', action='store_true', help='do not crash if folder_pred does not have all files that are present in folder_gt') args = parser.parse_args() compute_metrics_on_folder_simple(args.gt_folder, args.pred_folder, args.l, args.o, args.np, args.il, chill=args.chill) if __name__ == '__main__': folder_ref = '/media/fabian/data/nnUNet_raw/Dataset004_Hippocampus/labelsTr' folder_pred = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation' output_file = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation/summary.json' image_reader_writer = SimpleITKIO() file_ending = '.nii.gz' regions = labels_to_list_of_regions([1, 2]) ignore_label = None num_processes = 12 compute_metrics_on_folder(folder_ref, folder_pred, output_file, image_reader_writer, file_ending, regions, ignore_label, num_processes) ================================================ FILE: nnunetv2/evaluation/find_best_configuration.py ================================================ import argparse import os.path from copy import deepcopy from typing import Union, List, Tuple from batchgenerators.utilities.file_and_folder_operations import ( load_json, join, isdir, listdir, save_json ) from nnunetv2.configuration import default_num_processes from nnunetv2.ensembling.ensemble import ensemble_crossvalidations from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder, load_summary_json from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw, nnUNet_results from nnunetv2.postprocessing.remove_connected_components import determine_postprocessing from nnunetv2.utilities.file_path_utilities import maybe_convert_to_dataset_name, get_output_folder, \ convert_identifier_to_trainer_plans_config, get_ensemble_name, folds_tuple_to_string from 
nnunetv2.utilities.plans_handling.plans_handler import PlansManager default_trained_models = tuple([ {'plans': 'nnUNetPlans', 'configuration': '2d', 'trainer': 'nnUNetTrainer'}, {'plans': 'nnUNetPlans', 'configuration': '3d_fullres', 'trainer': 'nnUNetTrainer'}, {'plans': 'nnUNetPlans', 'configuration': '3d_lowres', 'trainer': 'nnUNetTrainer'}, {'plans': 'nnUNetPlans', 'configuration': '3d_cascade_fullres', 'trainer': 'nnUNetTrainer'}, ]) def filter_available_models(model_dict: Union[List[dict], Tuple[dict, ...]], dataset_name_or_id: Union[str, int]): valid = [] for trained_model in model_dict: plans_manager = PlansManager(join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id), trained_model['plans'] + '.json')) # check if configuration exists # 3d_cascade_fullres and 3d_lowres do not exist for each dataset so we allow them to be absent IF they are not # specified in the plans file if trained_model['configuration'] not in plans_manager.available_configurations: print(f"Configuration {trained_model['configuration']} not found in plans {trained_model['plans']}.\n" f"Inferred plans file: {join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id), trained_model['plans'] + '.json')}.") continue # check if trained model output folder exists. This is a requirement. No mercy here. expected_output_folder = get_output_folder(dataset_name_or_id, trained_model['trainer'], trained_model['plans'], trained_model['configuration'], fold=None) if not isdir(expected_output_folder): raise RuntimeError(f"Trained model {trained_model} does not have an output folder. " f"Expected: {expected_output_folder}. Please run the training for this model! 
(don't forget " f"the --npz flag if you want to ensemble multiple configurations)") valid.append(trained_model) return valid def generate_inference_command(dataset_name_or_id: Union[int, str], configuration_name: str, plans_identifier: str = 'nnUNetPlans', trainer_name: str = 'nnUNetTrainer', folds: Union[List[int], Tuple[int, ...]] = (0, 1, 2, 3, 4), folder_with_segs_from_prev_stage: str = None, input_folder: str = 'INPUT_FOLDER', output_folder: str = 'OUTPUT_FOLDER', save_npz: bool = False): fold_str = '' for f in folds: fold_str += f' {f}' predict_command = '' trained_model_folder = get_output_folder(dataset_name_or_id, trainer_name, plans_identifier, configuration_name, fold=None) plans_manager = PlansManager(join(trained_model_folder, 'plans.json')) configuration_manager = plans_manager.get_configuration(configuration_name) if configuration_manager.previous_stage_name is not None: prev_stage = configuration_manager.previous_stage_name predict_command += generate_inference_command(dataset_name_or_id, prev_stage, plans_identifier, trainer_name, folds, None, output_folder='OUTPUT_FOLDER_PREV_STAGE') + '\n' folder_with_segs_from_prev_stage = 'OUTPUT_FOLDER_PREV_STAGE' predict_command += f'nnUNetv2_predict -d {dataset_name_or_id} -i {input_folder} -o {output_folder} -f {fold_str} ' \ f'-tr {trainer_name} -c {configuration_name} -p {plans_identifier}' if folder_with_segs_from_prev_stage is not None: predict_command += f' -prev_stage_predictions {folder_with_segs_from_prev_stage}' if save_npz: predict_command += ' --save_probabilities' return predict_command def find_best_configuration(dataset_name_or_id, allowed_trained_models: Union[List[dict], Tuple[dict, ...]] = default_trained_models, allow_ensembling: bool = True, num_processes: int = default_num_processes, overwrite: bool = True, folds: Union[List[int], Tuple[int, ...]] = (0, 1, 2, 3, 4), strict: bool = False): dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) all_results = {} 
allowed_trained_models = filter_available_models(deepcopy(allowed_trained_models), dataset_name_or_id) for m in allowed_trained_models: output_folder = get_output_folder(dataset_name_or_id, m['trainer'], m['plans'], m['configuration'], fold=None) if not isdir(output_folder) and strict: raise RuntimeError(f'{dataset_name}: The output folder of plans {m["plans"]} configuration ' f'{m["configuration"]} is missing. Please train the model (all requested folds!) first!') identifier = os.path.basename(output_folder) merged_output_folder = join(output_folder, f'crossval_results_folds_{folds_tuple_to_string(folds)}') accumulate_cv_results(output_folder, merged_output_folder, folds, num_processes, overwrite) all_results[identifier] = { 'source': merged_output_folder, 'result': load_summary_json(join(merged_output_folder, 'summary.json'))['foreground_mean']['Dice'] } if allow_ensembling: for i in range(len(allowed_trained_models)): for j in range(i + 1, len(allowed_trained_models)): m1, m2 = allowed_trained_models[i], allowed_trained_models[j] output_folder_1 = get_output_folder(dataset_name_or_id, m1['trainer'], m1['plans'], m1['configuration'], fold=None) output_folder_2 = get_output_folder(dataset_name_or_id, m2['trainer'], m2['plans'], m2['configuration'], fold=None) identifier = get_ensemble_name(output_folder_1, output_folder_2, folds) output_folder_ensemble = join(nnUNet_results, dataset_name, 'ensembles', identifier) ensemble_crossvalidations([output_folder_1, output_folder_2], output_folder_ensemble, folds, num_processes, overwrite=overwrite) # evaluate ensembled predictions plans_manager = PlansManager(join(output_folder_1, 'plans.json')) dataset_json = load_json(join(output_folder_1, 'dataset.json')) label_manager = plans_manager.get_label_manager(dataset_json) rw = plans_manager.image_reader_writer_class() compute_metrics_on_folder(join(nnUNet_preprocessed, dataset_name, 'gt_segmentations'), output_folder_ensemble, join(output_folder_ensemble, 'summary.json'), rw, 
dataset_json['file_ending'], label_manager.foreground_regions if label_manager.has_regions else label_manager.foreground_labels, label_manager.ignore_label, num_processes) all_results[identifier] = \ { 'source': output_folder_ensemble, 'result': load_summary_json(join(output_folder_ensemble, 'summary.json'))['foreground_mean']['Dice'] } # pick best and report inference command best_score = max([i['result'] for i in all_results.values()]) best_keys = [k for k in all_results.keys() if all_results[k]['result'] == best_score] # may never happen but theoretically # there can be a tie. Let's pick the first model in this case because it's going to be the simpler one (ensembles # come after single configs) best_key = best_keys[0] print() print('***All results:***') for k, v in all_results.items(): print(f'{k}: {v["result"]}') print(f'\n*Best*: {best_key}: {all_results[best_key]["result"]}') print() print('***Determining postprocessing for best model/ensemble***') determine_postprocessing(all_results[best_key]['source'], join(nnUNet_preprocessed, dataset_name, 'gt_segmentations'), plans_file_or_dict=join(all_results[best_key]['source'], 'plans.json'), dataset_json_file_or_dict=join(all_results[best_key]['source'], 'dataset.json'), num_processes=num_processes, keep_postprocessed_files=True) # in addition to just reading the console output (how it was previously) we should return the information # needed to run the full inference via API return_dict = { 'folds': folds, 'dataset_name_or_id': dataset_name_or_id, 'considered_models': allowed_trained_models, 'ensembling_allowed': allow_ensembling, 'all_results': {i: j['result'] for i, j in all_results.items()}, 'best_model_or_ensemble': { 'result_on_crossval_pre_pp': all_results[best_key]["result"], 'result_on_crossval_post_pp': load_json(join(all_results[best_key]['source'], 'postprocessed', 'summary.json'))['foreground_mean']['Dice'], 'postprocessing_file': join(all_results[best_key]['source'], 'postprocessing.pkl'), 
            'some_plans_file': join(all_results[best_key]['source'], 'plans.json'),
            # just needed for label handling, can come from any of the ensemble members (if any)
            'selected_model_or_models': []
        }
    }
    # convert best key to inference command:
    if best_key.startswith('ensemble___'):
        prefix, m1, m2, folds_string = best_key.split('___')
        tr1, pl1, c1 = convert_identifier_to_trainer_plans_config(m1)
        tr2, pl2, c2 = convert_identifier_to_trainer_plans_config(m2)
        return_dict['best_model_or_ensemble']['selected_model_or_models'].append(
            {
                'configuration': c1,
                'trainer': tr1,
                'plans_identifier': pl1,
            })
        return_dict['best_model_or_ensemble']['selected_model_or_models'].append(
            {
                'configuration': c2,
                'trainer': tr2,
                'plans_identifier': pl2,
            })
    else:
        tr, pl, c = convert_identifier_to_trainer_plans_config(best_key)
        return_dict['best_model_or_ensemble']['selected_model_or_models'].append(
            {
                'configuration': c,
                'trainer': tr,
                'plans_identifier': pl,
            })

    # save this so that we don't have to rerun everything every time someone wants to be reminded of the inference
    # commands. They can just load this file and give it to print_inference_instructions
    save_json(return_dict, join(nnUNet_results, dataset_name, 'inference_information.json'))

    # print it
    print_inference_instructions(return_dict,
                                 instructions_file=join(nnUNet_results, dataset_name, 'inference_instructions.txt'))
    return return_dict


def print_inference_instructions(inference_info_dict: dict, instructions_file: str = None):
    def _print_and_maybe_write_to_file(string):
        print(string)
        if f_handle is not None:
            f_handle.write(f'{string}\n')

    f_handle = open(instructions_file, 'w') if instructions_file is not None else None
    print()
    _print_and_maybe_write_to_file('***Run inference like this:***\n')
    output_folders = []

    dataset_name_or_id = inference_info_dict['dataset_name_or_id']
    if len(inference_info_dict['best_model_or_ensemble']['selected_model_or_models']) > 1:
        is_ensemble = True
        _print_and_maybe_write_to_file('An ensemble won! What a surprise!
Run the following commands to run predictions with the ensemble members:\n')
    else:
        is_ensemble = False

    for j, i in enumerate(inference_info_dict['best_model_or_ensemble']['selected_model_or_models']):
        tr, c, pl = i['trainer'], i['configuration'], i['plans_identifier']
        if is_ensemble:
            output_folder_name = f"OUTPUT_FOLDER_MODEL_{j+1}"
        else:
            output_folder_name = "OUTPUT_FOLDER"
        output_folders.append(output_folder_name)

        _print_and_maybe_write_to_file(generate_inference_command(dataset_name_or_id, c, pl, tr,
                                                                  inference_info_dict['folds'],
                                                                  save_npz=is_ensemble,
                                                                  output_folder=output_folder_name))

    if is_ensemble:
        output_folder_str = output_folders[0]
        for o in output_folders[1:]:
            output_folder_str += f' {o}'
        output_ensemble = "OUTPUT_FOLDER"
        _print_and_maybe_write_to_file('\nThen run ensembling with:\n')
        _print_and_maybe_write_to_file(f"nnUNetv2_ensemble -i {output_folder_str} -o {output_ensemble} "
                                       f"-np {default_num_processes}")

    _print_and_maybe_write_to_file("\n***Once inference is completed, run postprocessing like this:***\n")
    _print_and_maybe_write_to_file(
        f"nnUNetv2_apply_postprocessing -i OUTPUT_FOLDER -o OUTPUT_FOLDER_PP "
        f"-pp_pkl_file {inference_info_dict['best_model_or_ensemble']['postprocessing_file']} -np {default_num_processes} "
        f"-plans_json {inference_info_dict['best_model_or_ensemble']['some_plans_file']}")

    # close the instructions file (if one was opened) so that everything is flushed to disk
    if f_handle is not None:
        f_handle.close()


def dumb_trainer_config_plans_to_trained_models_dict(trainers: List[str], configs: List[str], plans: List[str]):
    """
    function is called dumb because it's dumb
    """
    ret = []
    for t in trainers:
        for c in configs:
            for p in plans:
                ret.append(
                    {'plans': p, 'configuration': c, 'trainer': t}
                )
    return tuple(ret)


def find_best_configuration_entry_point():
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset_name_or_id', type=str, help='Dataset Name or id')
    parser.add_argument('-p', nargs='+', required=False, default=['nnUNetPlans'],
                        help='List of plan identifiers.
Default: nnUNetPlans') parser.add_argument('-c', nargs='+', required=False, default=['2d', '3d_fullres', '3d_lowres', '3d_cascade_fullres'], help="List of configurations. Default: ['2d', '3d_fullres', '3d_lowres', '3d_cascade_fullres']") parser.add_argument('-tr', nargs='+', required=False, default=['nnUNetTrainer'], help='List of trainers. Default: nnUNetTrainer') parser.add_argument('-np', required=False, default=default_num_processes, type=int, help='Number of processes to use for ensembling, postprocessing etc') parser.add_argument('-f', nargs='+', type=int, default=(0, 1, 2, 3, 4), help='Folds to use. Default: 0 1 2 3 4') parser.add_argument('--disable_ensembling', action='store_true', required=False, help='Set this flag to disable ensembling') parser.add_argument('--no_overwrite', action='store_true', help='If set we will not overwrite already ensembled files etc. May speed up consecutive ' 'runs of this command (why would you want to do that?) at the risk of not updating ' 'outdated results.') args = parser.parse_args() model_dict = dumb_trainer_config_plans_to_trained_models_dict(args.tr, args.c, args.p) dataset_name = maybe_convert_to_dataset_name(args.dataset_name_or_id) find_best_configuration(dataset_name, model_dict, allow_ensembling=not args.disable_ensembling, num_processes=args.np, overwrite=not args.no_overwrite, folds=args.f, strict=False) def accumulate_crossval_results_entry_point(): parser = argparse.ArgumentParser('Copies all predicted segmentations from the individual folds into one joint ' 'folder and evaluates them') parser.add_argument('dataset_name_or_id', type=str, help='Dataset Name or id') parser.add_argument('-c', type=str, required=True, default='3d_fullres', help="Configuration") parser.add_argument('-o', type=str, required=False, default=None, help="Output folder. 
If not specified, the output folder will be located in the trained " \ "model directory (named crossval_results_folds_XXX).") parser.add_argument('-f', nargs='+', type=int, default=(0, 1, 2, 3, 4), help='Folds to use. Default: 0 1 2 3 4') parser.add_argument('-p', type=str, required=False, default='nnUNetPlans', help='Plan identifier in which to search for the specified configuration. Default: nnUNetPlans') parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer', help='Trainer class. Default: nnUNetTrainer') args = parser.parse_args() trained_model_folder = get_output_folder(args.dataset_name_or_id, args.tr, args.p, args.c) if args.o is None: merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(args.f)}') else: merged_output_folder = args.o if isdir(merged_output_folder) and len(listdir(merged_output_folder)) > 0: raise FileExistsError( f"Output folder {merged_output_folder} exists and is not empty. " f"To avoid data loss, nnUNet requires an empty output folder." 
) accumulate_cv_results(trained_model_folder, merged_output_folder, args.f) if __name__ == '__main__': find_best_configuration(4, default_trained_models, True, 8, False, (0, 1, 2, 3, 4)) ================================================ FILE: nnunetv2/experiment_planning/__init__.py ================================================ ================================================ FILE: nnunetv2/experiment_planning/dataset_fingerprint/__init__.py ================================================ ================================================ FILE: nnunetv2/experiment_planning/dataset_fingerprint/fingerprint_extractor.py ================================================ import multiprocessing import os from time import sleep from typing import List, Type, Union import numpy as np from batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p from tqdm import tqdm from nnunetv2.imageio.base_reader_writer import BaseReaderWriter from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets class DatasetFingerprintExtractor(object): def __init__(self, dataset_name_or_id: Union[str, int], num_processes: int = 8, verbose: bool = False): """ extracts the dataset fingerprint used for experiment planning. The dataset fingerprint will be saved as a json file in the input_folder Philosophy here is to do only what we really need. Don't store stuff that we can easily read from somewhere else. 
        Don't compute stuff we don't need (except for intensity_statistics_per_channel)
        """
        dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
        self.verbose = verbose
        self.show_progress_bar = True

        self.dataset_name = dataset_name
        self.input_folder = join(nnUNet_raw, dataset_name)
        self.num_processes = num_processes
        self.dataset_json = load_json(join(self.input_folder, 'dataset.json'))
        self.dataset = get_filenames_of_train_images_and_targets(self.input_folder, self.dataset_json)

        # We don't want to use all foreground voxels because that can accumulate a lot of data (out of memory). It is
        # also not critically important to get all pixels as long as there are enough. Let's use 10e7 (= 1e8) voxels
        # in total (for the entire dataset)
        self.num_foreground_voxels_for_intensitystats = 10e7

    @staticmethod
    def collect_foreground_intensities(segmentation: np.ndarray, images: np.ndarray, seed: int = 1234,
                                       num_samples: int = 10000):
        """
        images=image with multiple channels = shape (c, x, y(, z))
        """
        assert images.ndim == 4 and segmentation.ndim == 4
        assert not np.any(np.isnan(segmentation)), "Segmentation contains NaN values. grrrr.... :-("
        assert not np.any(np.isnan(images)), "Images contain NaN values. grrrr.... :-("

        rs = np.random.RandomState(seed)

        intensities_per_channel = []
        # we don't use the intensity_statistics_per_channel at all, it's just something that might be nice to have
        intensity_statistics_per_channel = []

        # segmentation is 4d: 1,x,y,z. We need to remove the empty dimension for the following code to work
        foreground_mask = segmentation[0] > 0
        percentiles = np.array((0.5, 50.0, 99.5))

        for i in range(len(images)):
            foreground_pixels = images[i][foreground_mask]
            num_fg = len(foreground_pixels)
            # sample with replacement so that we don't get issues with cases that have fewer than num_samples
            # foreground_pixels.
We could also just sample less in those cases but that would then cause these
            # training cases to be underrepresented
            intensities_per_channel.append(
                rs.choice(foreground_pixels, num_samples, replace=True) if num_fg > 0 else [])

            mean, median, mini, maxi, percentile_99_5, percentile_00_5 = np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
            if num_fg > 0:
                percentile_00_5, median, percentile_99_5 = np.percentile(foreground_pixels, percentiles)
                mean = np.mean(foreground_pixels)
                mini = np.min(foreground_pixels)
                maxi = np.max(foreground_pixels)

            intensity_statistics_per_channel.append({
                'mean': mean,
                'median': median,
                'min': mini,
                'max': maxi,
                'percentile_99_5': percentile_99_5,
                'percentile_00_5': percentile_00_5,
            })

        return intensities_per_channel, intensity_statistics_per_channel

    @staticmethod
    def analyze_case(image_files: List[str], segmentation_file: str, reader_writer_class: Type[BaseReaderWriter],
                     num_samples: int = 10000):
        rw = reader_writer_class()
        images, properties_images = rw.read_images(image_files)
        segmentation, properties_seg = rw.read_seg(segmentation_file)

        # we no longer crop and save the cropped images before this is run. Instead we run the cropping on the fly.
        # Downside is that we need to do this twice (once here and once during preprocessing). Upside is that we don't
        # need to save the cropped data anymore. Given that cropping is not too expensive it makes sense to do it this
        # way. This is only possible because we are now using our new input/output interface.
        data_cropped, seg_cropped, bbox = crop_to_nonzero(images, segmentation)

        foreground_intensities_per_channel, foreground_intensity_stats_per_channel = \
            DatasetFingerprintExtractor.collect_foreground_intensities(seg_cropped, data_cropped,
                                                                       num_samples=num_samples)

        spacing = properties_images['spacing']

        shape_before_crop = images.shape[1:]
        shape_after_crop = data_cropped.shape[1:]
        relative_size_after_cropping = np.prod(shape_after_crop) / np.prod(shape_before_crop)
        return shape_after_crop, spacing, foreground_intensities_per_channel, foreground_intensity_stats_per_channel, \
            relative_size_after_cropping

    def run(self, overwrite_existing: bool = False) -> dict:
        # we do not save the properties file in self.input_folder because that folder might be read-only. We can only
        # reliably write in nnUNet_preprocessed and nnUNet_results, so nnUNet_preprocessed it is
        preprocessed_output_folder = join(nnUNet_preprocessed, self.dataset_name)
        maybe_mkdir_p(preprocessed_output_folder)
        properties_file = join(preprocessed_output_folder, 'dataset_fingerprint.json')

        if not isfile(properties_file) or overwrite_existing:
            # the first image of the first training case serves as the example file for picking the reader/writer
            reader_writer_class = determine_reader_writer_from_dataset_json(
                self.dataset_json,
                self.dataset[next(iter(self.dataset.keys()))]['images'][0])

            # determine how many foreground voxels we need to sample per training case
            num_foreground_samples_per_case = int(self.num_foreground_voxels_for_intensitystats //
                                                  len(self.dataset))

            r = []
            with multiprocessing.get_context("spawn").Pool(self.num_processes) as p:
                for k in self.dataset.keys():
                    r.append(p.starmap_async(DatasetFingerprintExtractor.analyze_case,
                                             ((self.dataset[k]['images'], self.dataset[k]['label'], reader_writer_class,
                                               num_foreground_samples_per_case),)))
                remaining = list(range(len(self.dataset)))
                # p is pretty nifty. If we kill workers they just respawn but don't do any work.
                # So we need to store the original pool of workers.
workers = [j for j in p._pool] with tqdm(desc="Extracting dataset fingerprint", total=len(self.dataset), disable=not getattr(self, 'show_progress_bar', True)) as pbar: while len(remaining) > 0: all_alive = all([j.is_alive() for j in workers]) if not all_alive: raise RuntimeError('Some background worker is 6 feet under. Yuck. \n' 'OK jokes aside.\n' 'One of your background processes is missing. This could be because of ' 'an error (look for an error message) or because it was killed ' 'by your OS due to running out of RAM. If you don\'t see ' 'an error message, out of RAM is likely the problem. In that case ' 'reducing the number of workers might help') done = [i for i in remaining if r[i].ready()] for _ in done: pbar.update() remaining = [i for i in remaining if i not in done] sleep(0.1) # results = ptqdm(DatasetFingerprintExtractor.analyze_case, # (training_images_per_case, training_labels_per_case), # processes=self.num_processes, zipped=True, reader_writer_class=reader_writer_class, # num_samples=num_foreground_samples_per_case, disable=self.verbose) results = [i.get()[0] for i in r] shapes_after_crop = [r[0] for r in results] spacings = [r[1] for r in results] foreground_intensities_per_channel = [np.concatenate([r[2][i] for r in results]) for i in range(len(results[0][2]))] foreground_intensities_per_channel = np.array(foreground_intensities_per_channel) # we drop this so that the json file is somewhat human readable # foreground_intensity_stats_by_case_and_modality = [r[3] for r in results] median_relative_size_after_cropping = np.median([r[4] for r in results], 0) num_channels = len(self.dataset_json['channel_names'].keys() if 'channel_names' in self.dataset_json.keys() else self.dataset_json['modality'].keys()) intensity_statistics_per_channel = {} percentiles = np.array((0.5, 50.0, 99.5)) for i in range(num_channels): percentile_00_5, median, percentile_99_5 = np.percentile(foreground_intensities_per_channel[i], percentiles) 
intensity_statistics_per_channel[i] = { 'mean': float(np.mean(foreground_intensities_per_channel[i])), 'median': float(median), 'std': float(np.std(foreground_intensities_per_channel[i])), 'min': float(np.min(foreground_intensities_per_channel[i])), 'max': float(np.max(foreground_intensities_per_channel[i])), 'percentile_99_5': float(percentile_99_5), 'percentile_00_5': float(percentile_00_5), } fingerprint = { "spacings": spacings, "shapes_after_crop": shapes_after_crop, 'foreground_intensity_properties_per_channel': intensity_statistics_per_channel, "median_relative_size_after_cropping": median_relative_size_after_cropping } try: save_json(fingerprint, properties_file) except Exception as e: if isfile(properties_file): os.remove(properties_file) raise e else: fingerprint = load_json(properties_file) return fingerprint if __name__ == '__main__': dfe = DatasetFingerprintExtractor(2, 8) dfe.run(overwrite_existing=False) ================================================ FILE: nnunetv2/experiment_planning/experiment_planners/__init__.py ================================================ ================================================ FILE: nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py ================================================ import shutil from copy import deepcopy from typing import List, Union, Tuple import numpy as np import torch from batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p from dynamic_network_architectures.architectures.unet import PlainConvUNet from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm from nnunetv2.configuration import ANISO_THRESHOLD from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json from nnunetv2.paths import nnUNet_raw, 
nnUNet_preprocessed
from nnunetv2.preprocessing.normalization.map_channel_name_to_normalization import get_normalization_scheme
from nnunetv2.preprocessing.resampling.default_resampling import resample_data_or_seg_to_shape, compute_new_shape
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets


class ExperimentPlanner(object):
    def __init__(self, dataset_name_or_id: Union[str, int],
                 gpu_memory_target_in_gb: float = 8,
                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetPlans',
                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
                 suppress_transpose: bool = False):
        """
        overwrite_target_spacing only affects 3d_fullres! (By extension, 3d_lowres may also be affected because it
        starts from the fullres configuration.)
        """

        self.dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
        self.suppress_transpose = suppress_transpose
        self.raw_dataset_folder = join(nnUNet_raw, self.dataset_name)
        preprocessed_folder = join(nnUNet_preprocessed, self.dataset_name)
        self.dataset_json = load_json(join(self.raw_dataset_folder, 'dataset.json'))
        self.dataset = get_filenames_of_train_images_and_targets(self.raw_dataset_folder, self.dataset_json)

        # load dataset fingerprint
        if not isfile(join(preprocessed_folder, 'dataset_fingerprint.json')):
            raise RuntimeError('Fingerprint missing for this dataset.
Please run nnUNetv2_extract_fingerprint')

        self.dataset_fingerprint = load_json(join(preprocessed_folder, 'dataset_fingerprint.json'))

        self.anisotropy_threshold = ANISO_THRESHOLD

        self.UNet_base_num_features = 32
        self.UNet_class = PlainConvUNet
        # the following two numbers are really arbitrary and were set to reproduce nnU-Net v1's configurations as
        # much as possible
        self.UNet_reference_val_3d = 560000000  # 455600128  550000000
        self.UNet_reference_val_2d = 85000000  # 83252480
        self.UNet_reference_com_nfeatures = 32
        self.UNet_reference_val_corresp_GB = 8
        self.UNet_reference_val_corresp_bs_2d = 12
        self.UNet_reference_val_corresp_bs_3d = 2
        self.UNet_featuremap_min_edge_length = 4
        self.UNet_blocks_per_stage_encoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
        self.UNet_blocks_per_stage_decoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
        self.UNet_min_batch_size = 2
        self.UNet_max_features_2d = 512
        self.UNet_max_features_3d = 320
        self.max_dataset_covered = 0.05  # we limit the batch size so that no more than 5% of the dataset can be seen
        # in a single forward/backward pass

        self.UNet_vram_target_GB = gpu_memory_target_in_gb

        self.lowres_creation_threshold = 0.25  # if the patch size of fullres is less than 25% of the voxels in the
        # median shape then we need a lowres config as well

        self.preprocessor_name = preprocessor_name
        self.plans_identifier = plans_name
        self.overwrite_target_spacing = overwrite_target_spacing
        # the message asks for three floats, so check the length explicitly (a bare len() only rejects empty input)
        assert overwrite_target_spacing is None or len(overwrite_target_spacing) == 3, \
            'if overwrite_target_spacing is used then three floats must be given (as list or tuple)'
        assert overwrite_target_spacing is None or all([isinstance(i, float) for i in overwrite_target_spacing]), \
            'if overwrite_target_spacing is used then three floats must be given (as list or tuple)'

        self.plans = None

        if isfile(join(self.raw_dataset_folder, 'splits_final.json')):
            _maybe_copy_splits_file(join(self.raw_dataset_folder, 'splits_final.json'),
                                    join(preprocessed_folder, 'splits_final.json'))
def determine_reader_writer(self): example_image = self.dataset[self.dataset.keys().__iter__().__next__()]['images'][0] return determine_reader_writer_from_dataset_json(self.dataset_json, example_image) @staticmethod def static_estimate_VRAM_usage(patch_size: Tuple[int], input_channels: int, output_channels: int, arch_class_name: str, arch_kwargs: dict, arch_kwargs_req_import: Tuple[str, ...]): """ Works for PlainConvUNet, ResidualEncoderUNet """ a = torch.get_num_threads() torch.set_num_threads(get_allowed_n_proc_DA()) # print(f'instantiating network, patch size {patch_size}, pool op: {arch_kwargs["strides"]}') net = get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, output_channels, allow_init=False) ret = net.compute_conv_feature_map_size(patch_size) torch.set_num_threads(a) return ret def determine_resampling(self, *args, **kwargs): """ returns what functions to use for resampling data and seg, respectively. Also returns kwargs resampling function must be callable(data, current_spacing, new_spacing, **kwargs) determine_resampling is called within get_plans_for_configuration to allow for different functions for each configuration """ resampling_data = resample_data_or_seg_to_shape resampling_data_kwargs = { "is_seg": False, "order": 3, "order_z": 0, "force_separate_z": None, } resampling_seg = resample_data_or_seg_to_shape resampling_seg_kwargs = { "is_seg": True, "order": 1, "order_z": 0, "force_separate_z": None, } return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs def determine_segmentation_softmax_export_fn(self, *args, **kwargs): """ function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be used as target. 
current_spacing and new_spacing are merely there in case we want to use it somehow determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different functions for each configuration """ resampling_fn = resample_data_or_seg_to_shape resampling_fn_kwargs = { "is_seg": False, "order": 1, "order_z": 0, "force_separate_z": None, } return resampling_fn, resampling_fn_kwargs def determine_fullres_target_spacing(self) -> np.ndarray: """ per default we use the 50th percentile=median for the target spacing. Higher spacing results in smaller data and thus faster and easier training. Smaller spacing results in larger data and thus longer and harder training For some datasets the median is not a good choice. Those are the datasets where the spacing is very anisotropic (for example ACDC with (10, 1.5, 1.5)). These datasets still have examples with a spacing of 5 or 6 mm in the low resolution axis. Choosing the median here will result in bad interpolation artifacts that can substantially impact performance (due to the low number of slices). """ if self.overwrite_target_spacing is not None: return np.array(self.overwrite_target_spacing) spacings = np.vstack(self.dataset_fingerprint['spacings']) sizes = self.dataset_fingerprint['shapes_after_crop'] target = np.percentile(spacings, 50, 0) # todo sizes_after_resampling = [compute_new_shape(j, i, target) for i, j in zip(spacings, sizes)] target_size = np.percentile(np.vstack(sizes), 50, 0) # we need to identify datasets for which a different target spacing could be beneficial. 
These datasets have
        # the following properties:
        # - one axis with much lower resolution than the others
        # - the lowres axis has far fewer voxels than the others
        # - (the size in mm of the lowres axis is also reduced)
        worst_spacing_axis = np.argmax(target)
        other_axes = [i for i in range(len(target)) if i != worst_spacing_axis]
        other_spacings = [target[i] for i in other_axes]
        other_sizes = [target_size[i] for i in other_axes]

        has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * max(other_spacings))
        has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes)

        if has_aniso_spacing and has_aniso_voxels:
            spacings_of_that_axis = spacings[:, worst_spacing_axis]
            target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10)
            # don't let the spacing of that axis drop below the spacing of the other axes
            if target_spacing_of_that_axis < max(other_spacings):
                target_spacing_of_that_axis = max(max(other_spacings), target_spacing_of_that_axis) + 1e-5
            target[worst_spacing_axis] = target_spacing_of_that_axis
        return target

    def determine_normalization_scheme_and_whether_mask_is_used_for_norm(self) -> Tuple[List[str], List[bool]]:
        if 'channel_names' not in self.dataset_json.keys():
            print('WARNING: "modality" should be renamed to "channel_names" in dataset.json.
This will be ' 'enforced soon!') modalities = self.dataset_json['channel_names'] if 'channel_names' in self.dataset_json.keys() else \ self.dataset_json['modality'] normalization_schemes = [get_normalization_scheme(m) for m in modalities.values()] if self.dataset_fingerprint['median_relative_size_after_cropping'] < (3 / 4.): use_nonzero_mask_for_norm = [i.leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true for i in normalization_schemes] else: use_nonzero_mask_for_norm = [False] * len(normalization_schemes) assert all([i in (True, False) for i in use_nonzero_mask_for_norm]), 'use_nonzero_mask_for_norm must be ' \ 'True or False and cannot be None' normalization_schemes = [i.__name__ for i in normalization_schemes] return normalization_schemes, use_nonzero_mask_for_norm def determine_transpose(self): if self.suppress_transpose: return [0, 1, 2], [0, 1, 2] # todo we should use shapes for that as well. Not quite sure how yet target_spacing = self.determine_fullres_target_spacing() max_spacing_axis = np.argmax(target_spacing) remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis] transpose_forward = [max_spacing_axis] + remaining_axes transpose_backward = [np.argwhere(np.array(transpose_forward) == i)[0][0] for i in range(3)] return transpose_forward, transpose_backward def get_plans_for_configuration(self, spacing: Union[np.ndarray, Tuple[float, ...], List[float]], median_shape: Union[np.ndarray, Tuple[int, ...]], data_identifier: str, approximate_n_voxels_dataset: float, _cache: dict) -> dict: def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for i in range(num_stages)]) def _keygen(patch_size, strides): return str(patch_size) + '_' + str(strides) assert all([i > 0 for i in spacing]), f"Spacing must be > 0! 
Spacing: {spacing}" num_input_channels = len(self.dataset_json['channel_names'].keys() if 'channel_names' in self.dataset_json.keys() else self.dataset_json['modality'].keys()) max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d unet_conv_op = convert_dim_to_conv_op(len(spacing)) # print(spacing, median_shape, approximate_n_voxels_dataset) # find an initial patch size # we first use the spacing to get an aspect ratio tmp = 1 / np.array(spacing) # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same # volume as a patch of size 256 ** 3) # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be # ideal because large initial patch sizes increase computation time because more iterations in the while loop # further down may be required. if len(spacing) == 3: initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)] elif len(spacing) == 2: initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)] else: raise RuntimeError() # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that # this is different from how nnU-Net v1 does it! # todo patch size can still get too large because we pad the patch size to a multiple of 2**n initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])]) # use that to get the network topology. 
        # Note that this changes the patch_size depending on the number of
        # pooling operations (must be divisible by 2**num_pool in each axis)
        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \
            shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size,
                                                                 self.UNet_featuremap_min_edge_length,
                                                                 999999)
        num_stages = len(pool_op_kernel_sizes)

        norm = get_matching_instancenorm(unet_conv_op)
        architecture_kwargs = {
            'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__,
            'arch_kwargs': {
                'n_stages': num_stages,
                'features_per_stage': _features_per_stage(num_stages, max_num_features),
                'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__,
                'kernel_sizes': conv_kernel_sizes,
                'strides': pool_op_kernel_sizes,
                'n_conv_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages],
                'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1],
                'conv_bias': True,
                'norm_op': norm.__module__ + '.' + norm.__name__,
                'norm_op_kwargs': {'eps': 1e-5, 'affine': True},
                'dropout_op': None,
                'dropout_op_kwargs': None,
                'nonlin': 'torch.nn.LeakyReLU',
                'nonlin_kwargs': {'inplace': True},
            },
            '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'),
        }

        # now estimate vram consumption
        if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys():
            estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)]
        else:
            estimate = self.static_estimate_VRAM_usage(patch_size,
                                                       num_input_channels,
                                                       len(self.dataset_json['labels'].keys()),
                                                       architecture_kwargs['network_class_name'],
                                                       architecture_kwargs['arch_kwargs'],
                                                       architecture_kwargs['_kw_requires_import'],
                                                       )
            _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate

        # how large is the reference for us here (batch size etc)?
        # adapt for our vram target
        reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \
                    (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB)

        ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d
        # we enforce a batch size of at least two, reference values may have been computed for different batch sizes.
        # Correct for that in the while loop if statement
        while (estimate / ref_bs * 2) > reference:
            # print(patch_size, estimate, reference)
            # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the
            # aspect ratio the most (that is the largest relative to median shape)
            axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1]

            # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this
            # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256.
            # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size
            # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first
            # subtract shape_must_be_divisible_by, then recompute it and then subtract the
            # recomputed shape_must_be_divisible_by. Annoying.
            patch_size = list(patch_size)
            tmp = deepcopy(patch_size)
            tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
            _, _, _, _, shape_must_be_divisible_by = \
                get_pool_and_conv_props(spacing, tmp,
                                        self.UNet_featuremap_min_edge_length,
                                        999999)
            patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]

            # now recompute topology
            network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \
                shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size,
                                                                     self.UNet_featuremap_min_edge_length,
                                                                     999999)

            num_stages = len(pool_op_kernel_sizes)
            architecture_kwargs['arch_kwargs'].update({
                'n_stages': num_stages,
                'kernel_sizes': conv_kernel_sizes,
                'strides': pool_op_kernel_sizes,
                'features_per_stage': _features_per_stage(num_stages, max_num_features),
                'n_conv_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages],
                'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1],
            })
            if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys():
                estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)]
            else:
                estimate = self.static_estimate_VRAM_usage(
                    patch_size,
                    num_input_channels,
                    len(self.dataset_json['labels'].keys()),
                    architecture_kwargs['network_class_name'],
                    architecture_kwargs['arch_kwargs'],
                    architecture_kwargs['_kw_requires_import'],
                )
                _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate

        # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was
        # executed. If not, additional vram headroom is used to increase batch size
        batch_size = round((reference / estimate) * ref_bs)

        # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution.
        # We cannot go smaller than self.UNet_min_batch_size though
        bs_corresponding_to_5_percent = round(
            approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64))
        batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size)

        resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling()
        resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn()

        normalization_schemes, mask_is_used_for_norm = \
            self.determine_normalization_scheme_and_whether_mask_is_used_for_norm()

        plan = {
            'data_identifier': data_identifier,
            'preprocessor_name': self.preprocessor_name,
            'batch_size': batch_size,
            'patch_size': patch_size,
            'median_image_size_in_voxels': median_shape,
            'spacing': spacing,
            'normalization_schemes': normalization_schemes,
            'use_mask_for_norm': mask_is_used_for_norm,
            'resampling_fn_data': resampling_data.__name__,
            'resampling_fn_seg': resampling_seg.__name__,
            'resampling_fn_data_kwargs': resampling_data_kwargs,
            'resampling_fn_seg_kwargs': resampling_seg_kwargs,
            'resampling_fn_probabilities': resampling_softmax.__name__,
            'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs,
            'architecture': architecture_kwargs
        }
        return plan

    def plan_experiment(self):
        """
        MOVE EVERYTHING INTO THE PLANS. MAXIMUM FLEXIBILITY

        Ideally I would like to move transpose_forward/backward into the configurations so that this can also be done
        differently for each configuration but this would cause problems with identifying the correct axes for 2d. There
        surely is a way around that but eh. I'm feeling lazy and featuritis must also not be pushed to the extremes.

        So for now if you want a different transpose_forward/backward you need to create a new planner. Also not too
        hard.
        """
        # we use this as a cache to prevent having to instantiate the architecture too often.
        # Saves computation time
        _tmp = {}

        # first get transpose
        transpose_forward, transpose_backward = self.determine_transpose()

        # get fullres spacing and transpose it
        fullres_spacing = self.determine_fullres_target_spacing()
        fullres_spacing_transposed = fullres_spacing[transpose_forward]

        # get transposed new median shape (what we would have after resampling)
        new_shapes = [compute_new_shape(j, i, fullres_spacing) for i, j in
                      zip(self.dataset_fingerprint['spacings'], self.dataset_fingerprint['shapes_after_crop'])]
        new_median_shape = np.median(new_shapes, 0)
        new_median_shape_transposed = new_median_shape[transpose_forward]

        approximate_n_voxels_dataset = float(np.prod(new_median_shape_transposed, dtype=np.float64) *
                                             self.dataset_json['numTraining'])
        # only run 3d if this is a 3d dataset
        if new_median_shape_transposed[0] != 1:
            plan_3d_fullres = self.get_plans_for_configuration(fullres_spacing_transposed,
                                                               new_median_shape_transposed,
                                                               self.generate_data_identifier('3d_fullres'),
                                                               approximate_n_voxels_dataset, _tmp)
            # maybe add 3d_lowres as well
            patch_size_fullres = plan_3d_fullres['patch_size']
            median_num_voxels = np.prod(new_median_shape_transposed, dtype=np.float64)
            num_voxels_in_patch = np.prod(patch_size_fullres, dtype=np.float64)

            plan_3d_lowres = None
            lowres_spacing = deepcopy(plan_3d_fullres['spacing'])

            spacing_increase_factor = 1.03  # used to be 1.01 but that is slow with new GPU memory estimation!
            while num_voxels_in_patch / median_num_voxels < self.lowres_creation_threshold:
                # we incrementally increase the target spacing. We start with the anisotropic axis/axes until it/they
                # is/are similar (factor 2) to the other ax(i/e)s.
                max_spacing = max(lowres_spacing)
                if np.any((max_spacing / lowres_spacing) > 2):
                    lowres_spacing[(max_spacing / lowres_spacing) > 2] *= spacing_increase_factor
                else:
                    lowres_spacing *= spacing_increase_factor
                median_num_voxels = np.prod(plan_3d_fullres['spacing'] / lowres_spacing * new_median_shape_transposed,
                                            dtype=np.float64)
                # print(lowres_spacing)
                plan_3d_lowres = self.get_plans_for_configuration(lowres_spacing,
                                                                  tuple([round(i) for i in plan_3d_fullres['spacing'] /
                                                                         lowres_spacing * new_median_shape_transposed]),
                                                                  self.generate_data_identifier('3d_lowres'),
                                                                  float(np.prod(median_num_voxels) *
                                                                        self.dataset_json['numTraining']), _tmp)
                num_voxels_in_patch = np.prod(plan_3d_lowres['patch_size'], dtype=np.int64)
                print(f'Attempting to find 3d_lowres config. '
                      f'\nCurrent spacing: {lowres_spacing}. '
                      f'\nCurrent patch size: {plan_3d_lowres["patch_size"]}. '
                      f'\nCurrent median shape: {plan_3d_fullres["spacing"] / lowres_spacing * new_median_shape_transposed}')
            if np.prod(new_median_shape_transposed, dtype=np.float64) / median_num_voxels < 2:
                print(f'Dropping 3d_lowres config because the image size difference to 3d_fullres is too small. '
                      f'3d_fullres: {new_median_shape_transposed}, '
                      f'3d_lowres: {[round(i) for i in plan_3d_fullres["spacing"] / lowres_spacing * new_median_shape_transposed]}')
                plan_3d_lowres = None
            if plan_3d_lowres is not None:
                plan_3d_lowres['batch_dice'] = False
                plan_3d_fullres['batch_dice'] = True
            else:
                plan_3d_fullres['batch_dice'] = False
        else:
            plan_3d_fullres = None
            plan_3d_lowres = None

        # 2D configuration
        plan_2d = self.get_plans_for_configuration(fullres_spacing_transposed[1:],
                                                   new_median_shape_transposed[1:],
                                                   self.generate_data_identifier('2d'), approximate_n_voxels_dataset,
                                                   _tmp)
        plan_2d['batch_dice'] = True

        print('2D U-Net configuration:')
        print(plan_2d)
        print()

        # median spacing and shape, just for reference when printing the plans
        median_spacing = np.median(self.dataset_fingerprint['spacings'], 0)[transpose_forward]
        median_shape = np.median(self.dataset_fingerprint['shapes_after_crop'], 0)[transpose_forward]

        # instead of writing all that into the plans we just copy the original file. More files, but less crowded
        # per file.
        shutil.copy(join(self.raw_dataset_folder, 'dataset.json'),
                    join(nnUNet_preprocessed, self.dataset_name, 'dataset.json'))

        # json is ###. I hate it... "Object of type int64 is not JSON serializable"
        plans = {
            'dataset_name': self.dataset_name,
            'plans_name': self.plans_identifier,
            'original_median_spacing_after_transp': [float(i) for i in median_spacing],
            'original_median_shape_after_transp': [int(round(i)) for i in median_shape],
            'image_reader_writer': self.determine_reader_writer().__name__,
            'transpose_forward': [int(i) for i in transpose_forward],
            'transpose_backward': [int(i) for i in transpose_backward],
            'configurations': {'2d': plan_2d},
            'experiment_planner_used': self.__class__.__name__,
            'label_manager': 'LabelManager',
            'foreground_intensity_properties_per_channel': self.dataset_fingerprint[
                'foreground_intensity_properties_per_channel']
        }

        if plan_3d_lowres is not None:
            plans['configurations']['3d_lowres'] = plan_3d_lowres
            if plan_3d_fullres is not None:
                plans['configurations']['3d_lowres']['next_stage'] = '3d_cascade_fullres'
            print('3D lowres U-Net configuration:')
            print(plan_3d_lowres)
            print()
        if plan_3d_fullres is not None:
            plans['configurations']['3d_fullres'] = plan_3d_fullres
            print('3D fullres U-Net configuration:')
            print(plan_3d_fullres)
            print()
            if plan_3d_lowres is not None:
                plans['configurations']['3d_cascade_fullres'] = {
                    'inherits_from': '3d_fullres',
                    'previous_stage': '3d_lowres'
                }

        self.plans = plans
        self.save_plans(plans)
        return plans

    def save_plans(self, plans):
        recursive_fix_for_json_export(plans)

        plans_file = join(nnUNet_preprocessed, self.dataset_name, self.plans_identifier + '.json')

        # we don't want to overwrite potentially existing custom configurations every time this is executed.
        # So let's read the plans file if it already exists and keep any non-default configurations
        if isfile(plans_file):
            old_plans = load_json(plans_file)
            old_configurations = old_plans['configurations']
            for c in plans['configurations'].keys():
                if c in old_configurations.keys():
                    del (old_configurations[c])
            plans['configurations'].update(old_configurations)

        maybe_mkdir_p(join(nnUNet_preprocessed, self.dataset_name))
        save_json(plans, plans_file, sort_keys=False)
        print(f"Plans were saved to {join(nnUNet_preprocessed, self.dataset_name, self.plans_identifier + '.json')}")

    def generate_data_identifier(self, configuration_name: str) -> str:
        """
        configurations are unique within each plans file but different plans file can have configurations with the
        same name. In order to distinguish the associated data we need a data identifier that reflects not just the
        config but also the plans it originates from
        """
        return self.plans_identifier + '_' + configuration_name

    def load_plans(self, fname: str):
        self.plans = load_json(fname)


def _maybe_copy_splits_file(splits_file: str, target_fname: str):
    if not isfile(target_fname):
        shutil.copy(splits_file, target_fname)
    else:
        # split already exists, do not copy, but check that the splits match.
        # This code allows target_fname to contain more splits than splits_file. This is OK.
        splits_source = load_json(splits_file)
        splits_target = load_json(target_fname)
        # all folds in the source file must match the target file
        for i in range(len(splits_source)):
            train_source = set(splits_source[i]['train'])
            train_target = set(splits_target[i]['train'])
            assert train_target == train_source
            val_source = set(splits_source[i]['val'])
            val_target = set(splits_target[i]['val'])
            assert val_source == val_target


if __name__ == '__main__':
    ExperimentPlanner(2, 8).plan_experiment()


================================================
FILE: nnunetv2/experiment_planning/experiment_planners/network_topology.py
================================================
from copy import deepcopy
import numpy as np


def get_shape_must_be_divisible_by(net_numpool_per_axis):
    return 2 ** np.array(net_numpool_per_axis)


def pad_shape(shape, must_be_divisible_by):
    """
    pads shape so that it is divisible by must_be_divisible_by
    :param shape:
    :param must_be_divisible_by:
    :return:
    """
    if not isinstance(must_be_divisible_by, (tuple, list, np.ndarray)):
        must_be_divisible_by = [must_be_divisible_by] * len(shape)
    else:
        assert len(must_be_divisible_by) == len(shape)

    new_shp = [shape[i] + must_be_divisible_by[i] - shape[i] % must_be_divisible_by[i] for i in range(len(shape))]

    for i in range(len(shape)):
        if shape[i] % must_be_divisible_by[i] == 0:
            new_shp[i] -= must_be_divisible_by[i]
    new_shp = np.array(new_shp).astype(int)
    return new_shp


def get_pool_and_conv_props(spacing, patch_size, min_feature_map_size, max_numpool):
    """
    this is the same as get_pool_and_conv_props_v2 from old nnunet

    :param spacing:
    :param patch_size:
    :param min_feature_map_size: min edge length of feature maps in bottleneck
    :param max_numpool:
    :return:
    """
    # todo review this code
    dim = len(spacing)

    current_spacing = deepcopy(list(spacing))
    current_size = deepcopy(list(patch_size))

    pool_op_kernel_sizes = [[1] * len(spacing)]
    conv_kernel_sizes = []

    num_pool_per_axis = [0] * dim
    kernel_size = [1] * dim

    while True:
        # exclude axes that we cannot pool further because of min_feature_map_size constraint
        valid_axes_for_pool = [i for i in range(dim) if current_size[i] >= 2*min_feature_map_size]
        if len(valid_axes_for_pool) < 1:
            break

        spacings_of_axes = [current_spacing[i] for i in valid_axes_for_pool]

        # find axes whose spacing is within a factor of 2 of the smallest spacing
        min_spacing_of_valid = min(spacings_of_axes)
        valid_axes_for_pool = [i for i in valid_axes_for_pool if current_spacing[i] / min_spacing_of_valid < 2]

        # max_numpool constraint
        valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool]

        if len(valid_axes_for_pool) == 1:
            if current_size[valid_axes_for_pool[0]] >= 3 * min_feature_map_size:
                pass
            else:
                break
        if len(valid_axes_for_pool) < 1:
            break

        # now we need to find kernel sizes
        # kernel sizes are initialized to 1. They are successively set to 3 when their associated axis becomes within
        # factor 2 of min_spacing. Once they are 3 they remain 3
        for d in range(dim):
            if kernel_size[d] == 3:
                continue
            else:
                if current_spacing[d] / min(current_spacing) < 2:
                    kernel_size[d] = 3

        other_axes = [i for i in range(dim) if i not in valid_axes_for_pool]

        pool_kernel_sizes = [0] * dim
        for v in valid_axes_for_pool:
            pool_kernel_sizes[v] = 2
            num_pool_per_axis[v] += 1
            current_spacing[v] *= 2
            current_size[v] = np.ceil(current_size[v] / 2)
        for nv in other_axes:
            pool_kernel_sizes[nv] = 1

        pool_op_kernel_sizes.append(pool_kernel_sizes)
        conv_kernel_sizes.append(deepcopy(kernel_size))
        #print(conv_kernel_sizes)

    must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
    patch_size = pad_shape(patch_size, must_be_divisible_by)

    def _to_tuple(lst):
        return tuple(_to_tuple(i) if isinstance(i, list) else i for i in lst)

    # we need to add one more conv_kernel_size for the bottleneck.
    # We always use 3x3(x3) conv here
    conv_kernel_sizes.append([3]*dim)
    return num_pool_per_axis, _to_tuple(pool_op_kernel_sizes), _to_tuple(conv_kernel_sizes), tuple(patch_size), must_be_divisible_by


================================================
FILE: nnunetv2/experiment_planning/experiment_planners/resampling/__init__.py
================================================


================================================
FILE: nnunetv2/experiment_planning/experiment_planners/resampling/planners_no_resampling.py
================================================
from typing import Union, List, Tuple

from nnunetv2.experiment_planning.experiment_planners.residual_unets.residual_encoder_unet_planners import \
    nnUNetPlannerResEncL
from nnunetv2.preprocessing.resampling.no_resampling import no_resampling_hack


class nnUNetPlannerResEncL_noResampling(nnUNetPlannerResEncL):
    """
    This planner will generate 3d_lowres as well. Don't trust it. Everything will remain in the original shape.
    No resampling will ever be done.
    """
    def __init__(self, dataset_name_or_id: Union[str, int],
                 gpu_memory_target_in_gb: float = 24,
                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_noResampling',
                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
                 suppress_transpose: bool = False):
        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
                         overwrite_target_spacing, suppress_transpose)

    def generate_data_identifier(self, configuration_name: str) -> str:
        """
        configurations are unique within each plans file but different plans file can have configurations with the
        same name. In order to distinguish the associated data we need a data identifier that reflects not just the
        config but also the plans it originates from
        """
        return self.plans_identifier + '_' + configuration_name

    def determine_resampling(self, *args, **kwargs):
        """
        returns what functions to use for resampling data and seg, respectively.
        Also returns kwargs
        resampling function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs)

        determine_resampling is called within get_plans_for_configuration to allow for different functions for each
        configuration
        """
        resampling_data = no_resampling_hack
        resampling_data_kwargs = {}
        resampling_seg = no_resampling_hack
        resampling_seg_kwargs = {}
        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs

    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
        """
        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
        used as target. current_spacing and new_spacing are merely there in case we want to use it somehow

        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
        functions for each configuration
        """
        resampling_fn = no_resampling_hack
        resampling_fn_kwargs = {}
        return resampling_fn, resampling_fn_kwargs


================================================
FILE: nnunetv2/experiment_planning/experiment_planners/resampling/resample_with_torch.py
================================================
from typing import Union, List, Tuple

from nnunetv2.configuration import ANISO_THRESHOLD
from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
from nnunetv2.experiment_planning.experiment_planners.residual_unets.residual_encoder_unet_planners import \
    nnUNetPlannerResEncL

from nnunetv2.preprocessing.resampling.resample_torch import resample_torch_fornnunet


class nnUNetPlannerResEncL_torchres(nnUNetPlannerResEncL):
    def __init__(self, dataset_name_or_id: Union[str, int],
                 gpu_memory_target_in_gb: float = 24,
                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres',
                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
                 suppress_transpose: bool = False):
        super().__init__(dataset_name_or_id,
                         gpu_memory_target_in_gb, preprocessor_name, plans_name,
                         overwrite_target_spacing, suppress_transpose)

    def generate_data_identifier(self, configuration_name: str) -> str:
        """
        configurations are unique within each plans file but different plans file can have configurations with the
        same name. In order to distinguish the associated data we need a data identifier that reflects not just the
        config but also the plans it originates from
        """
        return self.plans_identifier + '_' + configuration_name

    def determine_resampling(self, *args, **kwargs):
        """
        returns what functions to use for resampling data and seg, respectively. Also returns kwargs
        resampling function must be callable(data, current_spacing, new_spacing, **kwargs)

        determine_resampling is called within get_plans_for_configuration to allow for different functions for each
        configuration
        """
        resampling_data = resample_torch_fornnunet
        resampling_data_kwargs = {
            "is_seg": False,
            'force_separate_z': False,
            'memefficient_seg_resampling': False
        }
        resampling_seg = resample_torch_fornnunet
        resampling_seg_kwargs = {
            "is_seg": True,
            'force_separate_z': False,
            'memefficient_seg_resampling': False
        }
        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs

    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
        """
        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs).
        The new_shape should be used as target.
        current_spacing and new_spacing are merely there in case we want to use it somehow

        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
        functions for each configuration
        """
        resampling_fn = resample_torch_fornnunet
        resampling_fn_kwargs = {
            "is_seg": False,
            'force_separate_z': False,
            'memefficient_seg_resampling': False
        }
        return resampling_fn, resampling_fn_kwargs


class nnUNetPlannerResEncL_torchres_sepz(nnUNetPlannerResEncL):
    def __init__(self, dataset_name_or_id: Union[str, int],
                 gpu_memory_target_in_gb: float = 24,
                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres_sepz',
                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
                 suppress_transpose: bool = False):
        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
                         overwrite_target_spacing, suppress_transpose)

    def generate_data_identifier(self, configuration_name: str) -> str:
        """
        configurations are unique within each plans file but different plans file can have configurations with the
        same name. In order to distinguish the associated data we need a data identifier that reflects not just the
        config but also the plans it originates from
        """
        return self.plans_identifier + '_' + configuration_name

    def determine_resampling(self, *args, **kwargs):
        """
        returns what functions to use for resampling data and seg, respectively.
        Also returns kwargs
        resampling function must be callable(data, current_spacing, new_spacing, **kwargs)

        determine_resampling is called within get_plans_for_configuration to allow for different functions for each
        configuration
        """
        resampling_data = resample_torch_fornnunet
        resampling_data_kwargs = {
            "is_seg": False,
            'force_separate_z': None,
            'memefficient_seg_resampling': False,
            'separate_z_anisotropy_threshold': ANISO_THRESHOLD
        }
        resampling_seg = resample_torch_fornnunet
        resampling_seg_kwargs = {
            "is_seg": True,
            'force_separate_z': None,
            'memefficient_seg_resampling': False,
            'separate_z_anisotropy_threshold': ANISO_THRESHOLD
        }
        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs

    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
        """
        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
        used as target. current_spacing and new_spacing are merely there in case we want to use it somehow

        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
        functions for each configuration
        """
        resampling_fn = resample_torch_fornnunet
        resampling_fn_kwargs = {
            "is_seg": False,
            'force_separate_z': None,
            'memefficient_seg_resampling': False,
            'separate_z_anisotropy_threshold': ANISO_THRESHOLD
        }
        return resampling_fn, resampling_fn_kwargs


class nnUNetPlanner_torchres(ExperimentPlanner):
    def __init__(self, dataset_name_or_id: Union[str, int],
                 gpu_memory_target_in_gb: float = 8,
                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetPlans_torchres',
                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
                 suppress_transpose: bool = False):
        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
                         overwrite_target_spacing, suppress_transpose)

    def generate_data_identifier(self, configuration_name: str) -> str:
        """
        configurations are unique within each plans file but
        different plans file can have configurations with the same name.
        In order to distinguish the associated data we need a data identifier that reflects not just the
        config but also the plans it originates from
        """
        return self.plans_identifier + '_' + configuration_name

    def determine_resampling(self, *args, **kwargs):
        """
        returns what functions to use for resampling data and seg, respectively. Also returns kwargs
        resampling function must be callable(data, current_spacing, new_spacing, **kwargs)

        determine_resampling is called within get_plans_for_configuration to allow for different functions for each
        configuration
        """
        resampling_data = resample_torch_fornnunet
        resampling_data_kwargs = {
            "is_seg": False,
            'force_separate_z': False,
            'memefficient_seg_resampling': False
        }
        resampling_seg = resample_torch_fornnunet
        resampling_seg_kwargs = {
            "is_seg": True,
            'force_separate_z': False,
            'memefficient_seg_resampling': False
        }
        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs

    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
        """
        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs).
        The new_shape should be used as target.
        current_spacing and new_spacing are merely there in case we want to use it somehow

        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
        functions for each configuration
        """
        resampling_fn = resample_torch_fornnunet
        resampling_fn_kwargs = {
            "is_seg": False,
            'force_separate_z': False,
            'memefficient_seg_resampling': False
        }
        return resampling_fn, resampling_fn_kwargs


================================================
FILE: nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py
================================================
import numpy as np
from copy import deepcopy
from typing import Union, List, Tuple

from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet
from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm
from torch import nn

from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner

from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props


class ResEncUNetPlanner(ExperimentPlanner):
    def __init__(self, dataset_name_or_id: Union[str, int],
                 gpu_memory_target_in_gb: float = 8,
                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetPlans',
                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
                 suppress_transpose: bool = False):
        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
                         overwrite_target_spacing, suppress_transpose)
        self.UNet_class = ResidualEncoderUNet
        # the following two numbers are really arbitrary and were set to reproduce default nnU-Net's configurations as
        # much as possible
        self.UNet_reference_val_3d = 680000000
        self.UNet_reference_val_2d = 135000000
        self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6)
        self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)

    def generate_data_identifier(self, configuration_name: str) -> str:
        """
        configurations are unique within each plans file but different plans file can have configurations with the
        same name. In order to distinguish the associated data we need a data identifier that reflects not just the
        config but also the plans it originates from
        """
        if configuration_name == '2d' or configuration_name == '3d_fullres':
            # we do not deviate from ExperimentPlanner so we can reuse its data
            return 'nnUNetPlans' + '_' + configuration_name
        else:
            return self.plans_identifier + '_' + configuration_name

    def get_plans_for_configuration(self,
                                    spacing: Union[np.ndarray, Tuple[float, ...], List[float]],
                                    median_shape: Union[np.ndarray, Tuple[int, ...]],
                                    data_identifier: str,
                                    approximate_n_voxels_dataset: float,
                                    _cache: dict) -> dict:
        def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]:
            return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for
                          i in range(num_stages)])

        def _keygen(patch_size, strides):
            return str(patch_size) + '_' + str(strides)

        assert all([i > 0 for i in spacing]), f"Spacing must be > 0! Spacing: {spacing}"
        num_input_channels = len(self.dataset_json['channel_names'].keys()
                                 if 'channel_names' in self.dataset_json.keys()
                                 else self.dataset_json['modality'].keys())
        max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d
        unet_conv_op = convert_dim_to_conv_op(len(spacing))

        # print(spacing, median_shape, approximate_n_voxels_dataset)
        # find an initial patch size
        # we first use the spacing to get an aspect ratio
        tmp = 1 / np.array(spacing)

        # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same
        # volume as a patch of size 256 ** 3)
        # this may need to be adapted when using absurdly large GPU memory targets.
        # Increasing this now would not be
        # ideal because large initial patch sizes increase computation time because more iterations in the while loop
        # further down may be required.
        if len(spacing) == 3:
            initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)]
        elif len(spacing) == 2:
            initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)]
        else:
            raise RuntimeError()

        # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that
        # this is different from how nnU-Net v1 does it!
        # todo patch size can still get too large because we pad the patch size to a multiple of 2**n
        initial_patch_size = np.minimum(initial_patch_size, median_shape[:len(spacing)])

        # use that to get the network topology. Note that this changes the patch_size depending on the number of
        # pooling operations (must be divisible by 2**num_pool in each axis)
        network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \
            shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size,
                                                                 self.UNet_featuremap_min_edge_length,
                                                                 999999)
        num_stages = len(pool_op_kernel_sizes)

        norm = get_matching_instancenorm(unet_conv_op)
        architecture_kwargs = {
            'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__,
            'arch_kwargs': {
                'n_stages': num_stages,
                'features_per_stage': _features_per_stage(num_stages, max_num_features),
                'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__,
                'kernel_sizes': conv_kernel_sizes,
                'strides': pool_op_kernel_sizes,
                'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages],
                'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1],
                'conv_bias': True,
                'norm_op': norm.__module__ + '.' + norm.__name__,
                'norm_op_kwargs': {'eps': 1e-5, 'affine': True},
                'dropout_op': None,
                'dropout_op_kwargs': None,
                'nonlin': 'torch.nn.LeakyReLU',
                'nonlin_kwargs': {'inplace': True},
            },
            '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'),
        }

        # now estimate vram consumption
        if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys():
            estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)]
        else:
            estimate = self.static_estimate_VRAM_usage(patch_size,
                                                       num_input_channels,
                                                       len(self.dataset_json['labels'].keys()),
                                                       architecture_kwargs['network_class_name'],
                                                       architecture_kwargs['arch_kwargs'],
                                                       architecture_kwargs['_kw_requires_import'],
                                                       )
            _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate

        # how large is the reference for us here (batch size etc)?
        # adapt for our vram target
        reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \
                    (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB)

        while estimate > reference:
            # print(patch_size)
            # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the
            # aspect ratio the most (that is the largest relative to median shape)
            axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1]

            # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this
            # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256.
            # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size
            # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first
            # subtract shape_must_be_divisible_by, then recompute it and then subtract the
            # recomputed shape_must_be_divisible_by. Annoying.
patch_size = list(patch_size) tmp = deepcopy(patch_size) tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] _, _, _, _, shape_must_be_divisible_by = \ get_pool_and_conv_props(spacing, tmp, self.UNet_featuremap_min_edge_length, 999999) patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] # now recompute topology network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size, self.UNet_featuremap_min_edge_length, 999999) num_stages = len(pool_op_kernel_sizes) architecture_kwargs['arch_kwargs'].update({ 'n_stages': num_stages, 'kernel_sizes': conv_kernel_sizes, 'strides': pool_op_kernel_sizes, 'features_per_stage': _features_per_stage(num_stages, max_num_features), 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], }) if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] else: estimate = self.static_estimate_VRAM_usage( patch_size, num_input_channels, len(self.dataset_json['labels'].keys()), architecture_kwargs['network_class_name'], architecture_kwargs['arch_kwargs'], architecture_kwargs['_kw_requires_import'], ) _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was # executed. If not, additional vram headroom is used to increase batch size ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d batch_size = round((reference / estimate) * ref_bs) # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. 
We cannot # go smaller than self.UNet_min_batch_size though bs_corresponding_to_5_percent = round( approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64)) batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn() normalization_schemes, mask_is_used_for_norm = \ self.determine_normalization_scheme_and_whether_mask_is_used_for_norm() plan = { 'data_identifier': data_identifier, 'preprocessor_name': self.preprocessor_name, 'batch_size': batch_size, 'patch_size': patch_size, 'median_image_size_in_voxels': median_shape, 'spacing': spacing, 'normalization_schemes': normalization_schemes, 'use_mask_for_norm': mask_is_used_for_norm, 'resampling_fn_data': resampling_data.__name__, 'resampling_fn_seg': resampling_seg.__name__, 'resampling_fn_data_kwargs': resampling_data_kwargs, 'resampling_fn_seg_kwargs': resampling_seg_kwargs, 'resampling_fn_probabilities': resampling_softmax.__name__, 'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs, 'architecture': architecture_kwargs } return plan if __name__ == '__main__': # we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320), conv_op=nn.Conv3d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2), n_blocks_per_stage=(1, 3, 4, 6, 6, 6), num_classes=3, n_conv_per_stage_decoder=(1, 1, 1, 1, 1), conv_bias=True, norm_op=nn.InstanceNorm3d, norm_op_kwargs={}, dropout_op=None, nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) print(net.compute_conv_feature_map_size((128, 128, 128))) # -> 558319104. 
The value you see above was finetuned # from this one to match the regular nnunetplans more closely net = ResidualEncoderUNet(input_channels=1, n_stages=7, features_per_stage=(32, 64, 128, 256, 512, 512, 512), conv_op=nn.Conv2d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2, 2), n_blocks_per_stage=(1, 3, 4, 6, 6, 6, 6), num_classes=3, n_conv_per_stage_decoder=(1, 1, 1, 1, 1, 1), conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={}, dropout_op=None, nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) print(net.compute_conv_feature_map_size((512, 512))) # -> 129793792 ================================================ FILE: nnunetv2/experiment_planning/experiment_planners/residual_unets/__init__.py ================================================ ================================================ FILE: nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py ================================================ import warnings import numpy as np from copy import deepcopy from typing import Union, List, Tuple from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm from nnunetv2.preprocessing.resampling.resample_torch import resample_torch_fornnunet from torch import nn from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props class ResEncUNetPlanner(ExperimentPlanner): def __init__(self, dataset_name_or_id: Union[str, int], gpu_memory_target_in_gb: float = 8, preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetPlans', overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, suppress_transpose: bool = False): super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, 
preprocessor_name, plans_name, overwrite_target_spacing, suppress_transpose) self.UNet_class = ResidualEncoderUNet # the following two numbers are really arbitrary and were set to reproduce default nnU-Net's configurations as # much as possible self.UNet_reference_val_3d = 680000000 self.UNet_reference_val_2d = 135000000 self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6) self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) def generate_data_identifier(self, configuration_name: str) -> str: """ configurations are unique within each plans file but different plans file can have configurations with the same name. In order to distinguish the associated data we need a data identifier that reflects not just the config but also the plans it originates from """ if configuration_name == '2d' or configuration_name == '3d_fullres': # we do not deviate from ExperimentPlanner so we can reuse its data return 'nnUNetPlans' + '_' + configuration_name else: return self.plans_identifier + '_' + configuration_name def get_plans_for_configuration(self, spacing: Union[np.ndarray, Tuple[float, ...], List[float]], median_shape: Union[np.ndarray, Tuple[int, ...]], data_identifier: str, approximate_n_voxels_dataset: float, _cache: dict) -> dict: def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]: return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for i in range(num_stages)]) def _keygen(patch_size, strides): return str(patch_size) + '_' + str(strides) assert all([i > 0 for i in spacing]), f"Spacing must be > 0! 
Spacing: {spacing}" num_input_channels = len(self.dataset_json['channel_names'].keys() if 'channel_names' in self.dataset_json.keys() else self.dataset_json['modality'].keys()) max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d unet_conv_op = convert_dim_to_conv_op(len(spacing)) # print(spacing, median_shape, approximate_n_voxels_dataset) # find an initial patch size # we first use the spacing to get an aspect ratio tmp = 1 / np.array(spacing) # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same # volume as a patch of size 256 ** 3) # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be # ideal because large initial patch sizes increase computation time because more iterations in the while loop # further down may be required. if len(spacing) == 3: initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)] elif len(spacing) == 2: initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)] else: raise RuntimeError() # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that # this is different from how nnU-Net v1 does it! # todo patch size can still get too large because we pad the patch size to a multiple of 2**n initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])]) # use that to get the network topology. 
Note that this changes the patch_size depending on the number of # pooling operations (must be divisible by 2**num_pool in each axis) network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size, self.UNet_featuremap_min_edge_length, 999999) num_stages = len(pool_op_kernel_sizes) norm = get_matching_instancenorm(unet_conv_op) architecture_kwargs = { 'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__, 'arch_kwargs': { 'n_stages': num_stages, 'features_per_stage': _features_per_stage(num_stages, max_num_features), 'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__, 'kernel_sizes': conv_kernel_sizes, 'strides': pool_op_kernel_sizes, 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], 'conv_bias': True, 'norm_op': norm.__module__ + '.' + norm.__name__, 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, 'dropout_op': None, 'dropout_op_kwargs': None, 'nonlin': 'torch.nn.LeakyReLU', 'nonlin_kwargs': {'inplace': True}, }, '_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'), } # now estimate vram consumption if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] else: estimate = self.static_estimate_VRAM_usage(patch_size, num_input_channels, len(self.dataset_json['labels'].keys()), architecture_kwargs['network_class_name'], architecture_kwargs['arch_kwargs'], architecture_kwargs['_kw_requires_import'], ) _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # how large is the reference for us here (batch size etc)? 
# adapt for our vram target reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \ (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) while estimate > reference: # print(patch_size) # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the # aspect ratio the most (that is the largest relative to median shape) axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1] # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256. # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first # subtract shape_must_be_divisible_by, then recompute it and then subtract the # recomputed shape_must_be_divisible_by. Annoying. 
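# Illustrative aside, not part of the planner: the two-step reduction described in the comment above can be
# checked with a single-axis toy model. `toy_divisor`, `naive_step` and `two_step` are hypothetical helpers.
# The real divisibility factor comes from get_pool_and_conv_props, which also looks at spacing and at all axes.
# The sketch assumes an axis keeps being pooled while its edge length is at least 2 * min_edge_length.

```python
def toy_divisor(edge_length: int, min_edge_length: int = 4) -> int:
    """Hypothetical single-axis stand-in for shape_must_be_divisible_by.

    Assumption: the axis is halved as long as its current size is at least
    2 * min_edge_length. The returned divisor is 2 ** (number of poolings).
    """
    num_pool = 0
    size = float(edge_length)
    while size >= 2 * min_edge_length:
        size /= 2
        num_pool += 1
    return 2 ** num_pool


def naive_step(edge_length: int, min_edge_length: int = 4) -> int:
    """Subtract the divisor of the current size once (can skip valid sizes)."""
    return edge_length - toy_divisor(edge_length, min_edge_length)


def two_step(edge_length: int, min_edge_length: int = 4) -> int:
    """Subtract tentatively, recompute the divisor, then subtract the recomputed one."""
    tentative = edge_length - toy_divisor(edge_length, min_edge_length)
    return edge_length - toy_divisor(tentative, min_edge_length)


print(naive_step(256))  # 192: one pooling level is lost and 224 is skipped
print(two_step(256))    # 224: the next smaller valid size
```

# In this toy model 224 is valid because 224 / 2**5 = 7, which is below 2 * 4, so no sixth pooling is required.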
patch_size = list(patch_size) tmp = deepcopy(patch_size) tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] _, _, _, _, shape_must_be_divisible_by = \ get_pool_and_conv_props(spacing, tmp, self.UNet_featuremap_min_edge_length, 999999) patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] # now recompute topology network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size, self.UNet_featuremap_min_edge_length, 999999) num_stages = len(pool_op_kernel_sizes) architecture_kwargs['arch_kwargs'].update({ 'n_stages': num_stages, 'kernel_sizes': conv_kernel_sizes, 'strides': pool_op_kernel_sizes, 'features_per_stage': _features_per_stage(num_stages, max_num_features), 'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages], 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], }) if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys(): estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)] else: estimate = self.static_estimate_VRAM_usage( patch_size, num_input_channels, len(self.dataset_json['labels'].keys()), architecture_kwargs['network_class_name'], architecture_kwargs['arch_kwargs'], architecture_kwargs['_kw_requires_import'], ) _cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was # executed. If not, additional vram headroom is used to increase batch size ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d batch_size = round((reference / estimate) * ref_bs) # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. 
We cannot # go smaller than self.UNet_min_batch_size though bs_corresponding_to_5_percent = round( approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64)) batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn() normalization_schemes, mask_is_used_for_norm = \ self.determine_normalization_scheme_and_whether_mask_is_used_for_norm() plan = { 'data_identifier': data_identifier, 'preprocessor_name': self.preprocessor_name, 'batch_size': batch_size, 'patch_size': patch_size, 'median_image_size_in_voxels': median_shape, 'spacing': spacing, 'normalization_schemes': normalization_schemes, 'use_mask_for_norm': mask_is_used_for_norm, 'resampling_fn_data': resampling_data.__name__, 'resampling_fn_seg': resampling_seg.__name__, 'resampling_fn_data_kwargs': resampling_data_kwargs, 'resampling_fn_seg_kwargs': resampling_seg_kwargs, 'resampling_fn_probabilities': resampling_softmax.__name__, 'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs, 'architecture': architecture_kwargs } return plan class nnUNetPlannerResEncM(ResEncUNetPlanner): """ Target is ~9-11 GB VRAM max -> older Titan, RTX 2080ti """ def __init__(self, dataset_name_or_id: Union[str, int], gpu_memory_target_in_gb: float = 8, preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetMPlans', overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, suppress_transpose: bool = False): if gpu_memory_target_in_gb != 8: warnings.warn("WARNING: You are running nnUNetPlannerM with a non-standard gpu_memory_target_in_gb. " f"Expected 8, got {gpu_memory_target_in_gb}." 
"You should only see this warning if you modified this value intentionally!!") super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, overwrite_target_spacing, suppress_transpose) self.UNet_class = ResidualEncoderUNet self.UNet_vram_target_GB = gpu_memory_target_in_gb self.UNet_reference_val_corresp_GB = 8 # this is supposed to give the same GPU memory requirement as the default nnU-Net self.UNet_reference_val_3d = 680000000 self.UNet_reference_val_2d = 135000000 self.max_dataset_covered = 1 class nnUNetPlannerResEncL(ResEncUNetPlanner): """ Target is ~24 GB VRAM max -> RTX 4090, Titan RTX, Quadro 6000 """ def __init__(self, dataset_name_or_id: Union[str, int], gpu_memory_target_in_gb: float = 24, preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans', overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, suppress_transpose: bool = False): if gpu_memory_target_in_gb != 24: warnings.warn("WARNING: You are running nnUNetPlannerL with a non-standard gpu_memory_target_in_gb. " f"Expected 24, got {gpu_memory_target_in_gb}." 
"You should only see this warning if you modified this value intentionally!!") super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, overwrite_target_spacing, suppress_transpose) self.UNet_class = ResidualEncoderUNet self.UNet_vram_target_GB = gpu_memory_target_in_gb self.UNet_reference_val_corresp_GB = 24 self.UNet_reference_val_3d = 2100000000 # 1840000000 self.UNet_reference_val_2d = 380000000 # 352666667 self.max_dataset_covered = 1 class nnUNetPlannerResEncXL(ResEncUNetPlanner): """ Target is 40 GB VRAM max -> A100 40GB, RTX 6000 Ada Generation """ def __init__(self, dataset_name_or_id: Union[str, int], gpu_memory_target_in_gb: float = 40, preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetXLPlans', overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, suppress_transpose: bool = False): if gpu_memory_target_in_gb != 40: warnings.warn("WARNING: You are running nnUNetPlannerXL with a non-standard gpu_memory_target_in_gb. " f"Expected 40, got {gpu_memory_target_in_gb}." 
"You should only see this warning if you modified this value intentionally!!") super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, overwrite_target_spacing, suppress_transpose) self.UNet_class = ResidualEncoderUNet self.UNet_vram_target_GB = gpu_memory_target_in_gb self.UNet_reference_val_corresp_GB = 40 self.UNet_reference_val_3d = 3600000000 self.UNet_reference_val_2d = 560000000 self.max_dataset_covered = 1 if __name__ == '__main__': # we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320), conv_op=nn.Conv3d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2), n_blocks_per_stage=(1, 3, 4, 6, 6, 6), num_classes=3, n_conv_per_stage_decoder=(1, 1, 1, 1, 1), conv_bias=True, norm_op=nn.InstanceNorm3d, norm_op_kwargs={}, dropout_op=None, nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) print(net.compute_conv_feature_map_size((128, 128, 128))) # -> 558319104. 
The value you see above was finetuned # from this one to match the regular nnunetplans more closely net = ResidualEncoderUNet(input_channels=1, n_stages=7, features_per_stage=(32, 64, 128, 256, 512, 512, 512), conv_op=nn.Conv2d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2, 2), n_blocks_per_stage=(1, 3, 4, 6, 6, 6, 6), num_classes=3, n_conv_per_stage_decoder=(1, 1, 1, 1, 1, 1), conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={}, dropout_op=None, nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) print(net.compute_conv_feature_map_size((512, 512))) # -> 129793792 ================================================ FILE: nnunetv2/experiment_planning/plan_and_preprocess_api.py ================================================ import warnings from typing import List, Type, Optional, Tuple, Union from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, load_json import nnunetv2 from nnunetv2.configuration import default_num_processes from nnunetv2.experiment_planning.dataset_fingerprint.fingerprint_extractor import DatasetFingerprintExtractor from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner from nnunetv2.experiment_planning.verify_dataset_integrity import verify_dataset_integrity from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name from nnunetv2.utilities.find_class_by_name import recursive_find_python_class from nnunetv2.utilities.plans_handling.plans_handler import PlansManager from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets def extract_fingerprint_dataset(dataset_id: int, fingerprint_extractor_class: Type[ DatasetFingerprintExtractor] = DatasetFingerprintExtractor, num_processes: int = default_num_processes, check_dataset_integrity: bool = False, clean: bool = True, verbose: bool = True, show_progress_bar: bool = True): """ Returns 
the fingerprint as a dictionary (in addition to saving it) """ dataset_name = convert_id_to_dataset_name(dataset_id) print(dataset_name) if check_dataset_integrity: verify_dataset_integrity(join(nnUNet_raw, dataset_name), num_processes) fpe = fingerprint_extractor_class(dataset_id, num_processes, verbose=verbose) if hasattr(fpe, 'show_progress_bar'): fpe.show_progress_bar = show_progress_bar return fpe.run(overwrite_existing=clean) def extract_fingerprints(dataset_ids: List[int], fingerprint_extractor_class_name: str = 'DatasetFingerprintExtractor', num_processes: int = default_num_processes, check_dataset_integrity: bool = False, clean: bool = True, verbose: bool = True, show_progress_bar: bool = True): """ clean = False skips the extraction if a fingerprint already exists. This is just a switch for use with nnUNetv2_plan_and_preprocess where we don't want to rerun fingerprint extraction every time. """ fingerprint_extractor_class = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"), fingerprint_extractor_class_name, current_module="nnunetv2.experiment_planning") for d in dataset_ids: extract_fingerprint_dataset(d, fingerprint_extractor_class, num_processes, check_dataset_integrity, clean, verbose, show_progress_bar) def plan_experiment_dataset(dataset_id: int, experiment_planner_class: Type[ExperimentPlanner] = ExperimentPlanner, gpu_memory_target_in_gb: float = None, preprocess_class_name: str = 'DefaultPreprocessor', overwrite_target_spacing: Optional[Tuple[float, ...]] = None, overwrite_plans_name: Optional[str] = None) -> Tuple[dict, str]: """ overwrite_target_spacing ONLY applies to 3d_fullres and 3d_cascade_fullres!
""" kwargs = {} if overwrite_plans_name is not None: kwargs['plans_name'] = overwrite_plans_name if gpu_memory_target_in_gb is not None: kwargs['gpu_memory_target_in_gb'] = gpu_memory_target_in_gb planner = experiment_planner_class(dataset_id, preprocessor_name=preprocess_class_name, overwrite_target_spacing=[float(i) for i in overwrite_target_spacing] if overwrite_target_spacing is not None else overwrite_target_spacing, suppress_transpose=False, # might expose this later, **kwargs ) ret = planner.plan_experiment() return ret, planner.plans_identifier def plan_experiments(dataset_ids: List[int], experiment_planner_class_name: str = 'ExperimentPlanner', gpu_memory_target_in_gb: float = None, preprocess_class_name: str = 'DefaultPreprocessor', overwrite_target_spacing: Optional[Tuple[float, ...]] = None, overwrite_plans_name: Optional[str] = None): """ overwrite_target_spacing ONLY applies to 3d_fullres and 3d_cascade fullres! """ if experiment_planner_class_name == 'ExperimentPlanner': print("\n############################\n" "INFO: You are using the old nnU-Net default planner. We have updated our recommendations. " "Please consider using those instead! 
" "Read more here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md" "\n############################\n") experiment_planner = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"), experiment_planner_class_name, current_module="nnunetv2.experiment_planning") plans_identifier = None for d in dataset_ids: _, plans_identifier = plan_experiment_dataset(d, experiment_planner, gpu_memory_target_in_gb, preprocess_class_name, overwrite_target_spacing, overwrite_plans_name) return plans_identifier def preprocess_dataset(dataset_id: int, plans_identifier: str = 'nnUNetPlans', configurations: Union[Tuple[str], List[str]] = ('2d', '3d_fullres', '3d_lowres'), num_processes: Union[int, Tuple[int, ...], List[int]] = (8, 4, 8), verbose: bool = False, show_progress_bar: bool = True) -> None: if not isinstance(num_processes, list): num_processes = list(num_processes) if len(num_processes) == 1: num_processes = num_processes * len(configurations) if len(num_processes) != len(configurations): raise RuntimeError( f'The list provided with num_processes must either have len 1 or as many elements as there are ' f'configurations (see --help). Number of configurations: {len(configurations)}, length ' f'of num_processes: ' f'{len(num_processes)}') dataset_name = convert_id_to_dataset_name(dataset_id) print(f'Preprocessing dataset {dataset_name}') plans_file = join(nnUNet_preprocessed, dataset_name, plans_identifier + '.json') plans_manager = PlansManager(plans_file) for n, c in zip(num_processes, configurations): print(f'Configuration: {c}...') if c not in plans_manager.available_configurations: print( f"INFO: Configuration {c} not found in plans file {plans_identifier + '.json'} of " f"dataset {dataset_name}. 
Skipping.") continue configuration_manager = plans_manager.get_configuration(c) print(configuration_manager) preprocessor = configuration_manager.preprocessor_class(verbose=verbose) if hasattr(preprocessor, 'show_progress_bar'): preprocessor.show_progress_bar = show_progress_bar preprocessor.run(dataset_id, c, plans_identifier, num_processes=n) # copy the gt to a folder in the nnUNet_preprocessed so that we can do validation even if the raw data is no # longer there (useful for compute cluster where only the preprocessed data is available) from distutils.file_util import copy_file maybe_mkdir_p(join(nnUNet_preprocessed, dataset_name, 'gt_segmentations')) dataset_json = load_json(join(nnUNet_raw, dataset_name, 'dataset.json')) dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, dataset_name), dataset_json) # only copy files that are newer than the ones already present for k in dataset: copy_file(dataset[k]['label'], join(nnUNet_preprocessed, dataset_name, 'gt_segmentations', k + dataset_json['file_ending']), update=True) def preprocess(dataset_ids: List[int], plans_identifier: str = 'nnUNetPlans', configurations: Union[Tuple[str], List[str]] = ('2d', '3d_fullres', '3d_lowres'), num_processes: Union[int, Tuple[int, ...], List[int]] = (8, 4, 8), verbose: bool = False, show_progress_bar: bool = True): for d in dataset_ids: preprocess_dataset(d, plans_identifier, configurations, num_processes, verbose, show_progress_bar) ================================================ FILE: nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py ================================================ from nnunetv2.configuration import default_num_processes from nnunetv2.experiment_planning.plan_and_preprocess_api import extract_fingerprints, plan_experiments, preprocess def _add_logging_args(parser): parser.add_argument('--verbose', required=False, action='store_true', help='Set this to print a lot of stuff. 
Useful for debugging.') parser.add_argument('--no_pbar', required=False, action='store_true', help='Set this flag to disable the progress bar. Recommended for cluster/HPC environments.') def extract_fingerprint_entry(): import argparse parser = argparse.ArgumentParser() parser.add_argument('-d', nargs='+', type=int, help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction, experiment " "planning and preprocessing for these datasets. Can of course also be just one dataset") parser.add_argument('-fpe', type=str, required=False, default='DatasetFingerprintExtractor', help='[OPTIONAL] Name of the Dataset Fingerprint Extractor class that should be used. Default is ' '\'DatasetFingerprintExtractor\'.') parser.add_argument('-np', type=int, default=default_num_processes, required=False, help=f'[OPTIONAL] Number of processes used for fingerprint extraction. ' f'Default: {default_num_processes}') parser.add_argument("--verify_dataset_integrity", required=False, default=False, action="store_true", help="[RECOMMENDED] set this flag to check the dataset integrity. This is useful and should be done once for " "each dataset!") parser.add_argument("--clean", required=False, default=False, action="store_true", help='[OPTIONAL] Set this flag to overwrite existing fingerprints. If this flag is not set and a ' 'fingerprint already exists, the fingerprint extractor will not run.') _add_logging_args(parser) args, unrecognized_args = parser.parse_known_args() extract_fingerprints(args.d, args.fpe, args.np, args.verify_dataset_integrity, args.clean, args.verbose, show_progress_bar=not args.no_pbar) def plan_experiment_entry(): import argparse parser = argparse.ArgumentParser() parser.add_argument('-d', nargs='+', type=int, help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction, experiment " "planning and preprocessing for these datasets. 
Can of course also be just one dataset") parser.add_argument('-pl', type=str, default='ExperimentPlanner', required=False, help='[OPTIONAL] Name of the Experiment Planner class that should be used. Default is ' '\'ExperimentPlanner\'. Note: There is no longer a distinction between 2d and 3d planner. ' 'It\'s an all in one solution now. Wuch. Such amazing.') parser.add_argument('-gpu_memory_target', default=None, type=float, required=False, help='[OPTIONAL] DANGER ZONE! Sets a custom GPU memory target (in GB). Default: None (=Planner ' 'class default is used). Changing this will ' 'affect patch and batch size and will ' 'definitely affect your model\'s performance! Only use this if you really know what you ' 'are doing and NEVER use this without running the default nnU-Net first as a baseline.') parser.add_argument('-preprocessor_name', default='DefaultPreprocessor', type=str, required=False, help='[OPTIONAL] DANGER ZONE! Sets a custom preprocessor class. This class must be located in ' 'nnunetv2.preprocessing. Default: \'DefaultPreprocessor\'. Changing this may affect your ' 'model\'s performance! Only use this if you really know what you ' 'are doing and NEVER use this without running the default nnU-Net first (as a baseline).') parser.add_argument('-overwrite_target_spacing', default=None, nargs='+', required=False, help='[OPTIONAL] DANGER ZONE! Sets a custom target spacing for the 3d_fullres and 3d_cascade_fullres ' 'configurations. Default: None [no changes]. Changing this will affect image size and ' 'potentially patch and batch ' 'size. This will definitely affect your model\'s performance! Only use this if you really ' 'know what you are doing and NEVER use this without running the default nnU-Net first ' '(as a baseline). Changing the target spacing for the other configurations is currently ' 'not implemented.
New target spacing must be a list of three numbers!') parser.add_argument('-overwrite_plans_name', default=None, required=False, help='[OPTIONAL] DANGER ZONE! If you used -gpu_memory_target, -preprocessor_name or ' '-overwrite_target_spacing it is best practice to use -overwrite_plans_name to generate a ' 'differently named plans file such that the nnunet default plans are not ' 'overwritten. You will then need to specify your custom plans file with -p whenever ' 'running other nnunet commands (training, inference etc)') args, unrecognized_args = parser.parse_known_args() plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name, args.overwrite_target_spacing, args.overwrite_plans_name) def preprocess_entry(): import argparse parser = argparse.ArgumentParser() parser.add_argument('-d', nargs='+', type=int, help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction, experiment " "planning and preprocessing for these datasets. Can of course also be just one dataset") parser.add_argument('-plans_name', default='nnUNetPlans', required=False, help='[OPTIONAL] You can use this to specify a custom plans file that you may have generated') parser.add_argument('-c', required=False, default=['2d', '3d_fullres', '3d_lowres'], nargs='+', help='[OPTIONAL] Configurations for which the preprocessing should be run. Default: 2d 3d_fullres ' '3d_lowres. 3d_cascade_fullres does not need to be specified because it uses the data ' 'from 3d_fullres. Configurations that do not exist for some dataset will be skipped.') parser.add_argument('-np', type=int, nargs='+', default=None, required=False, help="[OPTIONAL] Use this to define how many processes are to be used. If this is just one number then " "this number of processes is used for all configurations specified with -c. If it's a " "list of numbers this list must have as many elements as there are configurations. 
We " "then iterate over zip(configs, num_processes) to determine the number of processes " "used for each configuration. More processes are generally faster (up to the number of " "threads your PC can support, so 8 for a 4 core CPU with hyperthreading. If you don't " "know what that is then don't touch it, or at least don't increase it!). DANGER: More " "often than not the number of processes that can be used is limited by the amount of " "RAM available. Image resampling takes up a lot of RAM. MONITOR RAM USAGE AND " "DECREASE -np IF YOUR RAM FILLS UP TOO MUCH! Default: 8 processes for 2d, 4 " "for 3d_fullres, 8 for 3d_lowres and 4 for everything else") _add_logging_args(parser) args, unrecognized_args = parser.parse_known_args() if args.np is None: default_np = {"2d": 8, "3d_fullres": 4, "3d_lowres": 8} np = [default_np[c] if c in default_np.keys() else 4 for c in args.c] else: np = args.np preprocess(args.d, args.plans_name, configurations=args.c, num_processes=np, verbose=args.verbose, show_progress_bar=not args.no_pbar) def plan_and_preprocess_entry(): import argparse parser = argparse.ArgumentParser() parser.add_argument('-d', nargs='+', type=int, help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction, experiment " "planning and preprocessing for these datasets. Can of course also be just one dataset") parser.add_argument('-fpe', type=str, required=False, default='DatasetFingerprintExtractor', help='[OPTIONAL] Name of the Dataset Fingerprint Extractor class that should be used. Default is ' '\'DatasetFingerprintExtractor\'.') parser.add_argument('-npfp', type=int, default=8, required=False, help='[OPTIONAL] Number of processes used for fingerprint extraction. Default: 8') parser.add_argument("--verify_dataset_integrity", required=False, default=False, action="store_true", help="[RECOMMENDED] set this flag to check the dataset integrity. 
This is useful and should be done once for " "each dataset!") parser.add_argument('--no_pp', default=False, action='store_true', required=False, help='[OPTIONAL] Set this to only run fingerprint extraction and experiment planning (no ' 'preprocessing). Useful for debugging.') parser.add_argument("--clean", required=False, default=False, action="store_true", help='[OPTIONAL] Set this flag to overwrite existing fingerprints. If this flag is not set and a ' 'fingerprint already exists, the fingerprint extractor will not run. REQUIRED IF YOU ' 'CHANGE THE DATASET FINGERPRINT EXTRACTOR OR MAKE CHANGES TO THE DATASET!') parser.add_argument('-pl', type=str, default='ExperimentPlanner', required=False, help='[OPTIONAL] Name of the Experiment Planner class that should be used. Default is ' '\'ExperimentPlanner\'. Note: There is no longer a distinction between 2d and 3d planners. ' 'It\'s an all in one solution now. Wuch. Such amazing.') parser.add_argument('-gpu_memory_target', default=None, type=float, required=False, help='[OPTIONAL] DANGER ZONE! Sets a custom GPU memory target (in GB). Default: None (=Planner ' 'class default is used). Changing this will ' 'affect patch and batch size and will ' 'definitely affect your model\'s performance! Only use this if you really know what you ' 'are doing and NEVER use this without running the default nnU-Net first as a baseline.') parser.add_argument('-preprocessor_name', default='DefaultPreprocessor', type=str, required=False, help='[OPTIONAL] DANGER ZONE! Sets a custom preprocessor class. This class must be located in ' 'nnunetv2.preprocessing. Default: \'DefaultPreprocessor\'. Changing this may affect your ' 'model\'s performance! Only use this if you really know what you ' 'are doing and NEVER use this without running the default nnU-Net first (as a baseline).') parser.add_argument('-overwrite_target_spacing', default=None, nargs='+', required=False, help='[OPTIONAL] DANGER ZONE! 
Sets a custom target spacing for the 3d_fullres and 3d_cascade_fullres ' 'configurations. Default: None [no changes]. Changing this will affect image size and ' 'potentially patch and batch ' 'size. This will definitely affect your model\'s performance! Only use this if you really ' 'know what you are doing and NEVER use this without running the default nnU-Net first ' '(as a baseline). Changing the target spacing for the other configurations is currently ' 'not implemented. New target spacing must be a list of three numbers!') parser.add_argument('-overwrite_plans_name', default=None, required=False, help='[OPTIONAL] Use a custom plans identifier. If you used -gpu_memory_target, ' '-preprocessor_name or ' '-overwrite_target_spacing it is best practice to use -overwrite_plans_name to generate a ' 'differently named plans file such that the nnunet default plans are not ' 'overwritten. You will then need to specify your custom plans file with -p whenever ' 'running other nnunet commands (training, inference etc)') parser.add_argument('-c', required=False, default=['2d', '3d_fullres', '3d_lowres'], nargs='+', help='[OPTIONAL] Configurations for which the preprocessing should be run. Default: 2d 3d_fullres ' '3d_lowres. 3d_cascade_fullres does not need to be specified because it uses the data ' 'from 3d_fullres. Configurations that do not exist for some dataset will be skipped.') parser.add_argument('-np', type=int, nargs='+', default=None, required=False, help="[OPTIONAL] Use this to define how many processes are to be used. If this is just one number then " "this number of processes is used for all configurations specified with -c. If it's a " "list of numbers this list must have as many elements as there are configurations. We " "then iterate over zip(configs, num_processes) to determine the number of processes " "used for each configuration. More processes are generally faster (up to the number of " "threads your PC can support, so 8 for a 4 core CPU with hyperthreading. 
If you don't " "know what that is then don't touch it, or at least don't increase it!). DANGER: More " "often than not the number of processes that can be used is limited by the amount of " "RAM available. Image resampling takes up a lot of RAM. MONITOR RAM USAGE AND " "DECREASE -np IF YOUR RAM FILLS UP TOO MUCH! Default: 8 processes for 2d, 4 " "for 3d_fullres, 8 for 3d_lowres and 4 for everything else") _add_logging_args(parser) args = parser.parse_args() # fingerprint extraction print("Fingerprint extraction...") extract_fingerprints(args.d, args.fpe, args.npfp, args.verify_dataset_integrity, args.clean, args.verbose, show_progress_bar=not args.no_pbar) # experiment planning print('Experiment planning...') plans_identifier = plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name, args.overwrite_target_spacing, args.overwrite_plans_name) # manage default np if args.np is None: default_np = {"2d": 8, "3d_fullres": 4, "3d_lowres": 8} np = [default_np[c] if c in default_np.keys() else 4 for c in args.c] else: np = args.np # preprocessing if not args.no_pp: print('Preprocessing...') preprocess(args.d, plans_identifier, args.c, np, args.verbose, show_progress_bar=not args.no_pbar) if __name__ == '__main__': plan_and_preprocess_entry() ================================================ FILE: nnunetv2/experiment_planning/plans_for_pretraining/__init__.py ================================================ ================================================ FILE: nnunetv2/experiment_planning/plans_for_pretraining/move_plans_between_datasets.py ================================================ import argparse from typing import Union from batchgenerators.utilities.file_and_folder_operations import join, isdir, isfile, load_json, save_json from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw from nnunetv2.utilities.file_path_utilities import 
maybe_convert_to_dataset_name from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets def move_plans_between_datasets( source_dataset_name_or_id: Union[int, str], target_dataset_name_or_id: Union[int, str], source_plans_identifier: str, target_plans_identifier: str = None): source_dataset_name = maybe_convert_to_dataset_name(source_dataset_name_or_id) target_dataset_name = maybe_convert_to_dataset_name(target_dataset_name_or_id) if target_plans_identifier is None: target_plans_identifier = source_plans_identifier source_folder = join(nnUNet_preprocessed, source_dataset_name) assert isdir(source_folder), f"Cannot move plans because preprocessed directory of source dataset is missing. " \ f"Run nnUNetv2_plan_and_preprocess for source dataset first!" source_plans_file = join(source_folder, source_plans_identifier + '.json') assert isfile(source_plans_file), f"Source plans are missing. Run the corresponding experiment planning first! " \ f"Expected file: {source_plans_file}" source_plans = load_json(source_plans_file) source_plans['dataset_name'] = target_dataset_name # we need to change data_identifier to use target_plans_identifier if target_plans_identifier != source_plans_identifier: for c in source_plans['configurations'].keys(): if 'data_identifier' in source_plans['configurations'][c].keys(): old_identifier = source_plans['configurations'][c]["data_identifier"] if old_identifier.startswith(source_plans_identifier): new_identifier = target_plans_identifier + old_identifier[len(source_plans_identifier):] else: new_identifier = target_plans_identifier + '_' + old_identifier source_plans['configurations'][c]["data_identifier"] = new_identifier # we need to change the reader writer class! 
target_raw_data_dir = join(nnUNet_raw, target_dataset_name) target_dataset_json = load_json(join(target_raw_data_dir, 'dataset.json')) # we may need to change the reader/writer # pick any file from the target dataset dataset = get_filenames_of_train_images_and_targets(target_raw_data_dir, target_dataset_json) example_image = dataset[next(iter(dataset))]['images'][0] rw = determine_reader_writer_from_dataset_json(target_dataset_json, example_image, allow_nonmatching_filename=True, verbose=False) source_plans["image_reader_writer"] = rw.__name__ if target_plans_identifier is not None: source_plans["plans_name"] = target_plans_identifier save_json(source_plans, join(nnUNet_preprocessed, target_dataset_name, target_plans_identifier + '.json'), sort_keys=False) def entry_point_move_plans_between_datasets(): parser = argparse.ArgumentParser() parser.add_argument('-s', type=str, required=True, help='Source dataset name or id') parser.add_argument('-t', type=str, required=True, help='Target dataset name or id') parser.add_argument('-sp', type=str, required=True, help='Source plans identifier. If your plans are named "nnUNetPlans.json" then the ' 'identifier would be nnUNetPlans') parser.add_argument('-tp', type=str, required=False, default=None, help='Target plans identifier. Default is None meaning the source plans identifier will ' 'be kept. 
Not recommended if the source plans identifier is a default nnU-Net identifier ' 'such as nnUNetPlans!!!') args = parser.parse_args() move_plans_between_datasets(args.s, args.t, args.sp, args.tp) if __name__ == '__main__': move_plans_between_datasets(2, 4, 'nnUNetPlans', 'nnUNetPlansFrom2') ================================================ FILE: nnunetv2/experiment_planning/verify_dataset_integrity.py ================================================ # Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center # (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import multiprocessing from typing import List, Type import numpy as np import pandas as pd from batchgenerators.utilities.file_and_folder_operations import * from nnunetv2.imageio.base_reader_writer import BaseReaderWriter from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json from nnunetv2.paths import nnUNet_raw from nnunetv2.utilities.label_handling.label_handling import LabelManager from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets def verify_labels(label_file: str, readerclass: Type[BaseReaderWriter], expected_labels: List[int]) -> bool: rw = readerclass() seg, properties = rw.read_seg(label_file) found_labels = np.sort(pd.unique(seg.ravel())) # np.unique(seg) unexpected_labels = [i for i in found_labels if i not in expected_labels] # warn if the only label present is 0 (exactly one unique value, and that value is background) if len(found_labels) == 1 and found_labels[0] == 0: print('WARNING: File %s only has label 0 (which should be background). This may be intentional or not, ' 'up to you.' % label_file) if len(unexpected_labels) > 0: print("Error: Unexpected labels found in file %s.\nExpected: %s\nFound: %s" % (label_file, expected_labels, found_labels)) return False return True def check_cases(image_files: List[str], label_file: str, expected_num_channels: int, readerclass: Type[BaseReaderWriter]) -> bool: rw = readerclass() ret = True images, properties_image = rw.read_images(image_files) segmentation, properties_seg = rw.read_seg(label_file) # check for nans if np.any(np.isnan(images)): print(f'Images contain NaN pixel values. You need to fix that by ' f'replacing NaN values with something that makes sense for your images!\nImages:\n{image_files}') ret = False if np.any(np.isnan(segmentation)): print(f'Segmentation contains NaN pixel values. 
You need to fix that.\nSegmentation:\n{label_file}') ret = False # check shapes shape_image = images.shape[1:] shape_seg = segmentation.shape[1:] if shape_image != shape_seg: print('Error: Shape mismatch between segmentation and corresponding images. \nShape images: %s. ' '\nShape seg: %s. \nImage files: %s. \nSeg file: %s\n' % (shape_image, shape_seg, image_files, label_file)) ret = False # check spacings spacing_images = properties_image['spacing'] spacing_seg = properties_seg['spacing'] if not np.allclose(spacing_seg, spacing_images): print('Error: Spacing mismatch between segmentation and corresponding images. \nSpacing images: %s. ' '\nSpacing seg: %s. \nImage files: %s. \nSeg file: %s\n' % (spacing_images, spacing_seg, image_files, label_file)) ret = False # check modalities if not len(images) == expected_num_channels: print('Error: Unexpected number of modalities. \nExpected: %d. \nGot: %d. \nImages: %s\n' % (expected_num_channels, len(images), image_files)) ret = False # nibabel checks if 'nibabel_stuff' in properties_image.keys(): # this image was read with NibabelIO affine_image = properties_image['nibabel_stuff']['original_affine'] affine_seg = properties_seg['nibabel_stuff']['original_affine'] if not np.allclose(affine_image, affine_seg): print('WARNING: Affine is not the same for image and seg! \nAffine image: %s \nAffine seg: %s\n' 'Image files: %s. \nSeg file: %s.\nThis can be a problem but doesn\'t have to be. Please run ' 'nnUNetv2_plot_overlay_pngs to verify if everything is OK!\n' % (affine_image, affine_seg, image_files, label_file)) # sitk checks if 'sitk_stuff' in properties_image.keys(): # this image was read with SimpleITKIO # spacing has already been checked, only check direction and origin origin_image = properties_image['sitk_stuff']['origin'] origin_seg = properties_seg['sitk_stuff']['origin'] if not np.allclose(origin_image, origin_seg): print('Warning: Origin mismatch between segmentation and corresponding images. \nOrigin images: %s. 
' '\nOrigin seg: %s. \nImage files: %s. \nSeg file: %s\n' % (origin_image, origin_seg, image_files, label_file)) direction_image = properties_image['sitk_stuff']['direction'] direction_seg = properties_seg['sitk_stuff']['direction'] if not np.allclose(direction_image, direction_seg): print('Warning: Direction mismatch between segmentation and corresponding images. \nDirection images: %s. ' '\nDirection seg: %s. \nImage files: %s. \nSeg file: %s\n' % (direction_image, direction_seg, image_files, label_file)) return ret def verify_dataset_integrity(folder: str, num_processes: int = 8) -> None: """ folder needs the imagesTr, imagesTs and labelsTr subfolders. There also needs to be a dataset.json checks if the expected number of training cases and labels are present for each case, if possible, checks whether the pixel grids are aligned checks whether the labels really only contain values they should :param folder: :return: """ assert isfile(join(folder, "dataset.json")), f"There needs to be a dataset.json file in folder, folder={folder}" dataset_json = load_json(join(folder, "dataset.json")) if not 'dataset' in dataset_json.keys(): assert isdir(join(folder, "imagesTr")), f"There needs to be a imagesTr subfolder in folder, folder={folder}" assert isdir(join(folder, "labelsTr")), f"There needs to be a labelsTr subfolder in folder, folder={folder}" # make sure all required keys are there dataset_keys = list(dataset_json.keys()) required_keys = ['labels', "channel_names", "numTraining", "file_ending"] assert all([i in dataset_keys for i in required_keys]), 'not all required keys are present in dataset.json.' 
\ '\n\nRequired: \n%s\n\nPresent: \n%s\n\nMissing: ' \ '\n%s\n\nUnused by nnU-Net:\n%s' % \ (str(required_keys), str(dataset_keys), str([i for i in required_keys if i not in dataset_keys]), str([i for i in dataset_keys if i not in required_keys])) expected_num_training = dataset_json['numTraining'] num_modalities = len(dataset_json['channel_names'].keys() if 'channel_names' in dataset_json.keys() else dataset_json['modality'].keys()) file_ending = dataset_json['file_ending'] dataset = get_filenames_of_train_images_and_targets(folder, dataset_json) # check if the right number of training cases is present assert len(dataset) == expected_num_training, 'Did not find the expected number of training cases ' \ '(%d). Found %d instead.\nExamples: %s' % \ (expected_num_training, len(dataset), list(dataset.keys())[:5]) # check if corresponding labels are present if 'dataset' in dataset_json.keys(): # just check if everything is there ok = True missing_images = [] missing_labels = [] for k in dataset: for i in dataset[k]['images']: if not isfile(i): missing_images.append(i) ok = False if not isfile(dataset[k]['label']): missing_labels.append(dataset[k]['label']) ok = False if not ok: raise FileNotFoundError(f"Some expected files were missing. Make sure you are properly referencing them " f"in the dataset.json. Or use imagesTr & labelsTr folders!\nMissing images:" f"\n{missing_images}\n\nMissing labels:\n{missing_labels}") else: # old code that uses imagestr and labelstr folders labelfiles = subfiles(join(folder, 'labelsTr'), suffix=file_ending, join=False) label_identifiers = [i[:-len(file_ending)] for i in labelfiles] labels_present = [i in label_identifiers for i in dataset.keys()] missing = [i for j, i in enumerate(dataset.keys()) if not labels_present[j]] assert all(labels_present), f'not all training cases have a label file in labelsTr. Fix that. 
Missing: {missing}' labelfiles = [v['label'] for v in dataset.values()] image_files = [v['images'] for v in dataset.values()] # no plans exist yet, so we can't use PlansManager and gotta roll with the default. It's unlikely to cause # problems anyway label_manager = LabelManager(dataset_json['labels'], regions_class_order=dataset_json.get('regions_class_order')) expected_labels = label_manager.all_labels if label_manager.has_ignore_label: expected_labels.append(label_manager.ignore_label) labels_valid_consecutive = np.ediff1d(expected_labels) == 1 assert all( labels_valid_consecutive), f'Labels must be in consecutive order (0, 1, 2, ...). The labels {np.array(expected_labels)[1:][~labels_valid_consecutive]} do not satisfy this restriction' # determine reader/writer class reader_writer_class = determine_reader_writer_from_dataset_json(dataset_json, dataset[dataset.keys().__iter__().__next__()]['images'][0]) # check whether only the desired labels are present with multiprocessing.get_context("spawn").Pool(num_processes) as p: result = p.starmap( verify_labels, zip(labelfiles, [reader_writer_class] * len(labelfiles), [expected_labels] * len(labelfiles)) ) if not all(result): raise RuntimeError( 'Some segmentation images contained unexpected labels. Please check text output above to see which one(s).') # check whether shapes and spacings match between images and labels result = p.starmap( check_cases, zip(image_files, labelfiles, [num_modalities] * expected_num_training, [reader_writer_class] * expected_num_training) ) if not all(result): raise RuntimeError( 'Some images have errors. Please check text output above to see which one(s) and what\'s going on.') # check for nans # check all same orientation nibabel print('\n####################') print('verify_dataset_integrity Done. 
\nIf you didn\'t see any error messages then your dataset is most likely OK!') print('####################\n') if __name__ == "__main__": # investigate geometry issues example_folder = join(nnUNet_raw, 'Dataset250_COMPUTING_it0') num_processes = 6 verify_dataset_integrity(example_folder, num_processes) ================================================ FILE: nnunetv2/imageio/__init__.py ================================================ ================================================ FILE: nnunetv2/imageio/base_reader_writer.py ================================================ # Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center # (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod from typing import Tuple, Union, List import numpy as np class BaseReaderWriter(ABC): @staticmethod def _check_all_same(input_list): if len(input_list) == 1: return True else: # compare all entries to the first return np.allclose(input_list[0], input_list[1:]) @staticmethod def _check_all_same_array(input_list): # compare all entries to the first for i in input_list[1:]: if i.shape != input_list[0].shape or not np.allclose(i, input_list[0]): return False return True @abstractmethod def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: """ Reads a sequence of images and returns a 4d (!) np.ndarray along with a dictionary. 
The 4d array must have the modalities (or color channels, or however you would like to call them) in its first axis, followed by the spatial dimensions (so shape must be c,x,y,z where c is the number of modalities (can be 1)). Use the dictionary to store necessary meta information that is lost when converting to numpy arrays, for example the Spacing, Orientation and Direction of the image. This dictionary will be handed over to write_seg for exporting the predicted segmentations, so make sure you have everything you need in there! IMPORTANT: dict MUST have a 'spacing' key with a tuple/list of length 3 with the voxel spacing of the np.ndarray. Example: my_dict = {'spacing': (3, 0.5, 0.5), ...}. This is needed for planning and preprocessing. The ordering of the numbers must correspond to the axis ordering in the returned numpy array. So if the array has shape c,x,y,z and the spacing is (a,b,c) then a must be the spacing of x, b the spacing of y and c the spacing of z. In the case of 2D images, the returned array should have shape (c, 1, x, y) and the spacing should be (999, sp_x, sp_y). Make sure 999 is larger than sp_x and sp_y! Example: shape=(3, 1, 224, 224), spacing=(999, 1, 1) For images that don't have a spacing, set the spacing to 1 (2d exception with 999 for the first axis still applies!) :param image_fnames: :return: 1) a np.ndarray of shape (c, x, y, z) where c is the number of image channels (can be 1) and x, y, z are the spatial dimensions (set x=1 for 2D! Example: (3, 1, 224, 224) for RGB image). 2) a dictionary with metadata. This can be anything. BUT it HAS to include a {'spacing': (a, b, c)} where a is the spacing of x, b of y and c of z! If an image doesn't have spacing, just set this to 1. For 2D, set a=999 (largest spacing value! Make it larger than b and c) """ pass @abstractmethod def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: """ Same requirements as BaseReaderWriter.read_images. Returned segmentations must have shape 1,x,y,z. 
Multiple segmentations are not (yet?) allowed. If images and segmentations can be read the same way you can just `return self.read_images((seg_fname,))` :param seg_fname: :return: 1) a np.ndarray of shape (1, x, y, z) where x, y, z are the spatial dimensions (set x=1 for 2D! Example: (1, 1, 224, 224) for 2D segmentation). 2) a dictionary with metadata. This can be anything. BUT it HAS to include a {'spacing': (a, b, c)} where a is the spacing of x, b of y and c of z! If an image doesn't have spacing, just set this to 1. For 2D, set a=999 (largest spacing value! Make it larger than b and c) """ pass @abstractmethod def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: """ Export the predicted segmentation to the desired file format. The given seg array will have the same shape and orientation as the corresponding image data, so you don't need to do any resampling or whatever. Just save :-) properties is the same dictionary you created during read_images/read_seg so you can use the information here to restore metadata IMPORTANT: Segmentations are always 3D! If your input images were 2d then the segmentation will have shape 1,x,y. You need to catch that and export accordingly (for 2d images you need to convert the 3d segmentation to 2d via seg = seg[0])! :param seg: A segmentation (np.ndarray, integer) of shape (x, y, z). For 2D segmentations this will be (1, y, z)! :param output_fname: :param properties: the dictionary that you created in read_images (the ones this segmentation is based on). 
Use this to restore metadata :return: """ pass ================================================ FILE: nnunetv2/imageio/natural_image_reader_writer.py ================================================ # Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center # (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Tuple, Union, List import numpy as np from nnunetv2.imageio.base_reader_writer import BaseReaderWriter from skimage import io class NaturalImage2DIO(BaseReaderWriter): """ ONLY SUPPORTS 2D IMAGES!!! """ # there are surely more we could add here. Everything that can be read by skimage.io should be supported supported_file_endings = [ '.png', # '.jpg', # '.jpeg', # jpg not supported because we cannot allow lossy compression! segmentation maps! '.bmp', '.tif' ] def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: images = [] for f in image_fnames: npy_img = io.imread(f) if npy_img.ndim == 3: # rgb image, last dimension should be the color channel and the size of that channel should be 3 # (or 4 if we have alpha) assert npy_img.shape[-1] == 3 or npy_img.shape[-1] == 4, "If image has three dimensions then the last " \ "dimension must have shape 3 or 4 " \ f"(RGB or RGBA). 
Image shape here is {npy_img.shape}" # move RGB(A) to front, add additional dim so that we have shape (c, 1, X, Y), where c is either 3 or 4 images.append(npy_img.transpose((2, 0, 1))[:, None]) elif npy_img.ndim == 2: # grayscale image images.append(npy_img[None, None]) if not self._check_all_same([i.shape for i in images]): print('ERROR! Not all input images have the same shape!') print('Shapes:') print([i.shape for i in images]) print('Image files:') print(image_fnames) raise RuntimeError() return np.vstack(images, dtype=np.float32, casting='unsafe'), {'spacing': (999, 1, 1)} def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: return self.read_images((seg_fname, )) def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: io.imsave(output_fname, seg[0].astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False), check_contrast=False) if __name__ == '__main__': images = ('/media/fabian/data/nnUNet_raw/Dataset120_RoadSegmentation/imagesTr/img-11_0000.png',) segmentation = '/media/fabian/data/nnUNet_raw/Dataset120_RoadSegmentation/labelsTr/img-11.png' imgio = NaturalImage2DIO() img, props = imgio.read_images(images) seg, segprops = imgio.read_seg(segmentation) ================================================ FILE: nnunetv2/imageio/nibabel_reader_writer.py ================================================ # Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center # (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import warnings from typing import Tuple, Union, List import numpy as np from nibabel.orientations import io_orientation, axcodes2ornt, ornt_transform from nnunetv2.imageio.base_reader_writer import BaseReaderWriter import nibabel from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO class NibabelIO(BaseReaderWriter): """ Nibabel loads the images in a different order than sitk. We convert the axes to the sitk order to be consistent. This is of course considered properly in segmentation export as well. IMPORTANT: Run nnUNetv2_plot_overlay_pngs to verify that this did not destroy the alignment of data and seg! """ supported_file_endings = [ '.nii', '.nii.gz', ] def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: images = [] original_affines = [] spacings_for_nnunet = [] for f in image_fnames: nib_image = nibabel.load(f) assert nib_image.ndim == 3, 'only 3d images are supported by NibabelIO' original_affine = nib_image.affine original_affines.append(original_affine) # spacing is taken in reverse order to be consistent with SimpleITK axis ordering (confusing, I know...) spacings_for_nnunet.append( [float(i) for i in nib_image.header.get_zooms()[::-1]] ) # transpose image to be consistent with the way SimpleITk reads images. Yeah. Annoying. images.append(nib_image.get_fdata().transpose((2, 1, 0))[None]) if not self._check_all_same([i.shape for i in images]): print('ERROR! Not all input images have the same shape!') print('Shapes:') print([i.shape for i in images]) print('Image files:') print(image_fnames) raise RuntimeError() if not self._check_all_same_array(original_affines): print('WARNING! Not all input images have the same original_affines!') print('Affines:') print(original_affines) print('Image files:') print(image_fnames) print( 'It is up to you to decide whether that\'s a problem. 
You should run nnUNetv2_plot_overlay_pngs to verify ' 'that segmentations and data overlap.') if not self._check_all_same(spacings_for_nnunet): print('ERROR! Not all input images have the same spacing_for_nnunet! This might be caused by them not ' 'having the same affine') print('spacings_for_nnunet:') print(spacings_for_nnunet) print('Image files:') print(image_fnames) raise RuntimeError() dict = { 'nibabel_stuff': { 'original_affine': original_affines[0], }, 'spacing': spacings_for_nnunet[0] } return np.vstack(images, dtype=np.float32, casting='unsafe'), dict def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: return self.read_images((seg_fname,)) def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: # revert transpose seg = seg.transpose((2, 1, 0)).astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False) seg_nib = nibabel.Nifti1Image(seg, affine=properties['nibabel_stuff']['original_affine']) nibabel.save(seg_nib, output_fname) class NibabelIOWithReorient(BaseReaderWriter): """ Reorients images to RAS Nibabel loads the images in a different order than sitk. We convert the axes to the sitk order to be consistent. This is of course considered properly in segmentation export as well. IMPORTANT: Run nnUNetv2_plot_overlay_pngs to verify that this did not destroy the alignment of data and seg! 
""" supported_file_endings = [ '.nii', '.nii.gz', ] def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: images = [] original_affines = [] reoriented_affines = [] spacings_for_nnunet = [] for f in image_fnames: nib_image = nibabel.load(f) assert nib_image.ndim == 3, 'only 3d images are supported by NibabelIO' original_affine = nib_image.affine reoriented_image = nib_image.as_reoriented(io_orientation(original_affine)) reoriented_affine = reoriented_image.affine original_affines.append(original_affine) reoriented_affines.append(reoriented_affine) # spacing is taken in reverse order to be consistent with SimpleITK axis ordering (confusing, I know...) spacings_for_nnunet.append( [float(i) for i in reoriented_image.header.get_zooms()[::-1]] ) # transpose image to be consistent with the way SimpleITk reads images. Yeah. Annoying. images.append(reoriented_image.get_fdata().transpose((2, 1, 0))[None]) if not self._check_all_same([i.shape for i in images]): print('ERROR! Not all input images have the same shape!') print('Shapes:') print([i.shape for i in images]) print('Image files:') print(image_fnames) raise RuntimeError() if not self._check_all_same_array(reoriented_affines): print('WARNING! Not all input images have the same reoriented_affines!') print('Affines:') print(reoriented_affines) print('Image files:') print(image_fnames) print( 'It is up to you to decide whether that\'s a problem. You should run nnUNetv2_plot_overlay_pngs to verify ' 'that segmentations and data overlap.') if not self._check_all_same(spacings_for_nnunet): print('ERROR! Not all input images have the same spacing_for_nnunet! 
This might be caused by them not ' 'having the same affine') print('spacings_for_nnunet:') print(spacings_for_nnunet) print('Image files:') print(image_fnames) raise RuntimeError() dict = { 'nibabel_stuff': { 'original_affine': original_affines[0], 'reoriented_affine': reoriented_affines[0], }, 'spacing': spacings_for_nnunet[0] } return np.vstack(images, dtype=np.float32, casting='unsafe'), dict def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: return self.read_images((seg_fname,)) def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: # revert transpose seg = seg.transpose((2, 1, 0)).astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False) seg_nib = nibabel.Nifti1Image(seg, affine=properties['nibabel_stuff']['reoriented_affine']) # Solution from https://github.com/nipy/nibabel/issues/1063#issuecomment-967124057 img_ornt = io_orientation(properties['nibabel_stuff']['original_affine']) ras_ornt = axcodes2ornt("RAS") from_canonical = ornt_transform(ras_ornt, img_ornt) seg_nib_reoriented = seg_nib.as_reoriented(from_canonical) if not np.allclose(properties['nibabel_stuff']['original_affine'], seg_nib_reoriented.affine): print(f'WARNING: Restored affine does not match original affine. 
File: {output_fname}')
            print('Original affine\n', properties['nibabel_stuff']['original_affine'])
            print('Restored affine\n', seg_nib_reoriented.affine)

        nibabel.save(seg_nib_reoriented, output_fname)


if __name__ == '__main__':
    img_file = '/media/isensee/raw_data/nnUNet_raw/Dataset220_KiTS2023/imagesTr/case_00004_0000.nii.gz'
    seg_file = '/media/isensee/raw_data/nnUNet_raw/Dataset220_KiTS2023/labelsTr/case_00004.nii.gz'

    nibio = NibabelIO()
    # images, dct = nibio.read_images([img_file])
    seg, dctseg = nibio.read_seg(seg_file)

    nibio_r = NibabelIOWithReorient()
    # images_r, dct_r = nibio_r.read_images([img_file])
    seg_r, dctseg_r = nibio_r.read_seg(seg_file)

    sitkio = SimpleITKIO()
    # images_sitk, dct_sitk = sitkio.read_images([img_file])
    seg_sitk, dctseg_sitk = sitkio.read_seg(seg_file)

    # write reoriented and original segmentation
    nibio.write_seg(seg[0], '/home/isensee/seg_nibio.nii.gz', dctseg)
    nibio_r.write_seg(seg_r[0], '/home/isensee/seg_nibio_r.nii.gz', dctseg_r)
    sitkio.write_seg(seg_sitk[0], '/home/isensee/seg_nibio_sitk.nii.gz', dctseg_sitk)

    # now load all with sitk to make sure no shapes got messed up
    a, d1 = sitkio.read_seg('/home/isensee/seg_nibio.nii.gz')
    b, d2 = sitkio.read_seg('/home/isensee/seg_nibio_r.nii.gz')
    c, d3 = sitkio.read_seg('/home/isensee/seg_nibio_sitk.nii.gz')

    assert a.shape == b.shape
    assert b.shape == c.shape

    assert np.all(a == b)
    assert np.all(b == c)

================================================
FILE: nnunetv2/imageio/reader_writer_registry.py
================================================
import traceback
from typing import Type

from batchgenerators.utilities.file_and_folder_operations import join

import nnunetv2
from nnunetv2.imageio.natural_image_reader_writer import NaturalImage2DIO
from nnunetv2.imageio.nibabel_reader_writer import NibabelIO, NibabelIOWithReorient
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
from nnunetv2.imageio.tif_reader_writer import Tiff3DIO
from nnunetv2.imageio.base_reader_writer import
BaseReaderWriter from nnunetv2.utilities.find_class_by_name import recursive_find_python_class LIST_OF_IO_CLASSES = [ NaturalImage2DIO, SimpleITKIO, Tiff3DIO, NibabelIO, NibabelIOWithReorient ] def determine_reader_writer_from_dataset_json(dataset_json_content: dict, example_file: str = None, allow_nonmatching_filename: bool = False, verbose: bool = True ) -> Type[BaseReaderWriter]: if 'overwrite_image_reader_writer' in dataset_json_content.keys() and \ dataset_json_content['overwrite_image_reader_writer'] != 'None': ioclass_name = dataset_json_content['overwrite_image_reader_writer'] # trying to find that class in the nnunetv2.imageio module try: ret = recursive_find_reader_writer_by_name(ioclass_name) if verbose: print(f'Using {ret} reader/writer') return ret except RuntimeError: if verbose: print(f'Warning: Unable to find ioclass specified in dataset.json: {ioclass_name}') if verbose: print('Trying to automatically determine desired class') return determine_reader_writer_from_file_ending(dataset_json_content['file_ending'], example_file, allow_nonmatching_filename, verbose) def determine_reader_writer_from_file_ending(file_ending: str, example_file: str = None, allow_nonmatching_filename: bool = False, verbose: bool = True): for rw in LIST_OF_IO_CLASSES: if file_ending.lower() in rw.supported_file_endings: if example_file is not None: # if an example file is provided, try if we can actually read it. 
                # If not move on to the next reader
                try:
                    tmp = rw()
                    _ = tmp.read_images((example_file,))
                    if verbose: print(f'Using {rw} as reader/writer')
                    return rw
                except:
                    if verbose: print(f'Failed to open file {example_file} with reader {rw}:')
                    traceback.print_exc()
                    pass
            else:
                if verbose: print(f'Using {rw} as reader/writer')
                return rw
        else:
            if allow_nonmatching_filename and example_file is not None:
                try:
                    tmp = rw()
                    _ = tmp.read_images((example_file,))
                    if verbose: print(f'Using {rw} as reader/writer')
                    return rw
                except:
                    if verbose: print(f'Failed to open file {example_file} with reader {rw}:')
                    if verbose: traceback.print_exc()
                    pass
    raise RuntimeError(f"Unable to determine a reader for file ending {file_ending} and file {example_file} (file None means no file provided).")


def recursive_find_reader_writer_by_name(rw_class_name: str) -> Type[BaseReaderWriter]:
    ret = recursive_find_python_class(join(nnunetv2.__path__[0], "imageio"), rw_class_name, 'nnunetv2.imageio')
    if ret is None:
        raise RuntimeError("Unable to find reader writer class '%s'. Please make sure this class is located in the "
                           "nnunetv2.imageio module." % rw_class_name)
    else:
        return ret

================================================
FILE: nnunetv2/imageio/readme.md
================================================
- Derive your adapter from `BaseReaderWriter`.
- Reimplement all abstractmethods.
- make sure to support 2d and 3d input images (or raise some error).
- place it in this folder or nnU-Net won't find it!
- add it to LIST_OF_IO_CLASSES in `reader_writer_registry.py`

Bam, you're done!

================================================
FILE: nnunetv2/imageio/simpleitk_reader_writer.py
================================================
#    Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
#    (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Tuple, Union, List import numpy as np from nnunetv2.imageio.base_reader_writer import BaseReaderWriter import SimpleITK as sitk class SimpleITKIO(BaseReaderWriter): supported_file_endings = [ '.nii.gz', '.nrrd', '.mha', '.gipl' ] def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: images = [] spacings = [] origins = [] directions = [] spacings_for_nnunet = [] for f in image_fnames: itk_image = sitk.ReadImage(f) spacings.append(itk_image.GetSpacing()) origins.append(itk_image.GetOrigin()) directions.append(itk_image.GetDirection()) npy_image = sitk.GetArrayFromImage(itk_image) if npy_image.ndim == 2: # 2d npy_image = npy_image[None, None] max_spacing = max(spacings[-1]) spacings_for_nnunet.append((max_spacing * 999, *list(spacings[-1])[::-1])) elif npy_image.ndim == 3: # 3d, as in original nnunet npy_image = npy_image[None] spacings_for_nnunet.append(list(spacings[-1])[::-1]) elif npy_image.ndim == 4: # 4d, multiple modalities in one file spacings_for_nnunet.append(list(spacings[-1])[::-1][1:]) pass else: raise RuntimeError(f"Unexpected number of dimensions: {npy_image.ndim} in file {f}") images.append(npy_image) spacings_for_nnunet[-1] = list(np.abs(spacings_for_nnunet[-1])) if not self._check_all_same([i.shape for i in images]): print('ERROR! Not all input images have the same shape!') print('Shapes:') print([i.shape for i in images]) print('Image files:') print(image_fnames) raise RuntimeError() if not self._check_all_same(spacings): print('ERROR! 
Not all input images have the same spacing!')
            print('Spacings:')
            print(spacings)
            print('Image files:')
            print(image_fnames)
            raise RuntimeError()
        if not self._check_all_same(origins):
            print('WARNING! Not all input images have the same origin!')
            print('Origins:')
            print(origins)
            print('Image files:')
            print(image_fnames)
            print('It is up to you to decide whether that\'s a problem. You should run nnUNetv2_plot_overlay_pngs to verify '
                  'that segmentations and data overlap.')
        if not self._check_all_same(directions):
            print('WARNING! Not all input images have the same direction!')
            print('Directions:')
            print(directions)
            print('Image files:')
            print(image_fnames)
            print('It is up to you to decide whether that\'s a problem. You should run nnUNetv2_plot_overlay_pngs to verify '
                  'that segmentations and data overlap.')
        if not self._check_all_same(spacings_for_nnunet):
            print('ERROR! Not all input images have the same spacing_for_nnunet! (This should not happen and must be a '
                  'bug. Please report!)')
            print('spacings_for_nnunet:')
            print(spacings_for_nnunet)
            print('Image files:')
            print(image_fnames)
            raise RuntimeError()

        dict = {
            'sitk_stuff': {
                # this saves the sitk geometry information. This part is NOT used by nnU-Net!
                'spacing': spacings[0],
                'origin': origins[0],
                'direction': directions[0]
            },
            # the spacing is inverted with [::-1] because sitk returns the spacing in the wrong order lol. Image arrays
            # are returned z,y,x but spacing is returned x,y,z. Duh.
            'spacing': spacings_for_nnunet[0]
        }
        return np.vstack(images, dtype=np.float32, casting='unsafe'), dict

    def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
        return self.read_images((seg_fname, ))

    def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
        assert seg.ndim == 3, 'segmentation must be 3d.
If you are exporting a 2d segmentation, please provide it as shape 1,x,y' output_dimension = len(properties['sitk_stuff']['spacing']) assert 1 < output_dimension < 4 if output_dimension == 2: seg = seg[0] itk_image = sitk.GetImageFromArray(seg.astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False)) itk_image.SetSpacing(properties['sitk_stuff']['spacing']) itk_image.SetOrigin(properties['sitk_stuff']['origin']) itk_image.SetDirection(properties['sitk_stuff']['direction']) sitk.WriteImage(itk_image, output_fname, True) class SimpleITKIOWithReorient(SimpleITKIO): def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]], orientation = "RAS") -> Tuple[np.ndarray, dict]: images = [] spacings = [] origins = [] directions = [] spacings_for_nnunet = [] for f in image_fnames: itk_image = sitk.ReadImage(f) original_orientation = sitk.DICOMOrientImageFilter_GetOrientationFromDirectionCosines(itk_image.GetDirection()) itk_image = sitk.DICOMOrient(itk_image, orientation) # print(sitk.DICOMOrientImageFilter_GetOrientationFromDirectionCosines(itk_image.GetDirection())) spacings.append(itk_image.GetSpacing()) origins.append(itk_image.GetOrigin()) directions.append(itk_image.GetDirection()) npy_image = sitk.GetArrayFromImage(itk_image) if npy_image.ndim == 2: # 2d npy_image = npy_image[None, None] max_spacing = max(spacings[-1]) spacings_for_nnunet.append((max_spacing * 999, *list(spacings[-1])[::-1])) elif npy_image.ndim == 3: # 3d, as in original nnunet npy_image = npy_image[None] spacings_for_nnunet.append(list(spacings[-1])[::-1]) elif npy_image.ndim == 4: # 4d, multiple modalities in one file spacings_for_nnunet.append(list(spacings[-1])[::-1][1:]) pass else: raise RuntimeError(f"Unexpected number of dimensions: {npy_image.ndim} in file {f}") images.append(npy_image) spacings_for_nnunet[-1] = list(np.abs(spacings_for_nnunet[-1])) if not self._check_all_same([i.shape for i in images]): print('ERROR! 
Not all input images have the same shape!')
            print('Shapes:')
            print([i.shape for i in images])
            print('Image files:')
            print(image_fnames)
            raise RuntimeError()
        if not self._check_all_same(spacings):
            print('ERROR! Not all input images have the same spacing!')
            print('Spacings:')
            print(spacings)
            print('Image files:')
            print(image_fnames)
            raise RuntimeError()
        if not self._check_all_same(origins):
            print('WARNING! Not all input images have the same origin!')
            print('Origins:')
            print(origins)
            print('Image files:')
            print(image_fnames)
            print('It is up to you to decide whether that\'s a problem. You should run nnUNetv2_plot_overlay_pngs to verify '
                  'that segmentations and data overlap.')
        if not self._check_all_same(directions):
            print('WARNING! Not all input images have the same direction!')
            print('Directions:')
            print(directions)
            print('Image files:')
            print(image_fnames)
            print('It is up to you to decide whether that\'s a problem. You should run nnUNetv2_plot_overlay_pngs to verify '
                  'that segmentations and data overlap.')
        if not self._check_all_same(spacings_for_nnunet):
            print('ERROR! Not all input images have the same spacing_for_nnunet! (This should not happen and must be a '
                  'bug. Please report!)')
            print('spacings_for_nnunet:')
            print(spacings_for_nnunet)
            print('Image files:')
            print(image_fnames)
            raise RuntimeError()

        dict = {
            'sitk_stuff': {
                # this saves the sitk geometry information. This part is NOT used by nnU-Net!
                'spacing': spacings[0],
                'origin': origins[0],
                'direction': directions[0],
                'original_orientation': original_orientation
            },
            # the spacing is inverted with [::-1] because sitk returns the spacing in the wrong order lol. Image arrays
            # are returned z,y,x but spacing is returned x,y,z. Duh.
            'spacing': spacings_for_nnunet[0]
        }
        return np.vstack(images, dtype=np.float32, casting='unsafe'), dict

    def write_seg(self, seg, output_fname, properties):
        assert seg.ndim == 3, 'segmentation must be 3d.
If you are exporting a 2d segmentation, please provide it as shape 1,x,y'
        output_dimension = len(properties['sitk_stuff']['spacing'])
        assert 1 < output_dimension < 4
        if output_dimension == 2:
            seg = seg[0]

        # fall back to uint16 when labels do not fit into uint8, mirroring SimpleITKIO.write_seg
        itk_image = sitk.GetImageFromArray(seg.astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False))
        itk_image.SetSpacing(properties['sitk_stuff']['spacing'])
        itk_image.SetOrigin(properties['sitk_stuff']['origin'])
        itk_image.SetDirection(properties['sitk_stuff']['direction'])
        itk_image = sitk.DICOMOrient(itk_image, properties['sitk_stuff']['original_orientation'])

        sitk.WriteImage(itk_image, output_fname, True)

================================================
FILE: nnunetv2/imageio/tif_reader_writer.py
================================================
#    Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
#    (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
import os.path
from typing import Tuple, Union, List
import numpy as np
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
import tifffile
from batchgenerators.utilities.file_and_folder_operations import isfile, load_json, save_json, split_path, join


class Tiff3DIO(BaseReaderWriter):
    """
    reads and writes 3D tif(f) images. Uses tifffile package. Ignores metadata (for now)!

    If you have 2D tiffs, use NaturalImage2DIO

    Supports the use of auxiliary files for spacing information.
    If used, the auxiliary files are expected to end with .json and omit the channel identifier.
    So, for example, the auxiliary file corresponding to image image1_0000.tif is expected to be image1.json!
    """
    supported_file_endings = [
        '.tif',
        '.tiff',
    ]

    def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
        # figure out file ending used here
        ending = '.' + image_fnames[0].split('.')[-1]
        assert ending.lower() in self.supported_file_endings, f'Ending {ending} not supported by {self.__class__.__name__}'
        ending_length = len(ending)
        truncate_length = ending_length + 5  # 5 comes from len(_0000)

        images = []
        for f in image_fnames:
            image = tifffile.imread(f)
            if image.ndim != 3:
                raise RuntimeError(f"Only 3D images are supported! File: {f}")
            images.append(image[None])

        # see if aux file can be found
        expected_aux_file = image_fnames[0][:-truncate_length] + '.json'
        if isfile(expected_aux_file):
            spacing = load_json(expected_aux_file)['spacing']
            assert len(spacing) == 3, f'spacing must have 3 entries, one for each dimension of the image. File: {expected_aux_file}'
        else:
            print(f'WARNING no spacing file found for images {image_fnames}\nAssuming spacing (1, 1, 1).')
            spacing = (1, 1, 1)

        if not self._check_all_same([i.shape for i in images]):
            print('ERROR!
Not all input images have the same shape!') print('Shapes:') print([i.shape for i in images]) print('Image files:') print(image_fnames) raise RuntimeError() return np.vstack(images, dtype=np.float32, casting='unsafe'), {'spacing': spacing} def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: # not ideal but I really have no clue how to set spacing/resolution information properly in tif files haha tifffile.imwrite(output_fname, data=seg.astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False), compression='zlib') file = os.path.basename(output_fname) out_dir = os.path.dirname(output_fname) ending = file.split('.')[-1] save_json({'spacing': properties['spacing']}, join(out_dir, file[:-(len(ending) + 1)] + '.json')) def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: # figure out file ending used here ending = '.' + seg_fname.split('.')[-1] assert ending.lower() in self.supported_file_endings, f'Ending {ending} not supported by {self.__class__.__name__}' ending_length = len(ending) seg = tifffile.imread(seg_fname) if seg.ndim != 3: raise RuntimeError(f"Only 3D images are supported! File: {seg_fname}") seg = seg[None] # see if aux file can be found expected_aux_file = seg_fname[:-ending_length] + '.json' if isfile(expected_aux_file): spacing = load_json(expected_aux_file)['spacing'] assert len(spacing) == 3, f'spacing must have 3 entries, one for each dimension of the image. 
File: {expected_aux_file}' assert all([i > 0 for i in spacing]), f"Spacing must be > 0, spacing: {spacing}" else: print(f'WARNING no spacing file found for segmentation {seg_fname}\nAssuming spacing (1, 1, 1).') spacing = (1, 1, 1) return seg.astype(np.float32, copy=False), {'spacing': spacing} ================================================ FILE: nnunetv2/inference/JHU_inference.py ================================================ import argparse import multiprocessing import os from time import sleep from typing import Union import numpy as np import torch from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter from batchgenerators.utilities.file_and_folder_operations import load_json, save_pickle, join, maybe_mkdir_p, subdirs from nnunetv2.configuration import default_num_processes from nnunetv2.inference.export_prediction import convert_predicted_logits_to_segmentation_with_correct_shape from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor from nnunetv2.inference.sliding_window_prediction import compute_gaussian from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy from nnunetv2.utilities.helpers import empty_cache from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager def export_prediction_from_logits_singleFiles( predicted_array_or_file: Union[np.ndarray, torch.Tensor], properties_dict: dict, configuration_manager: ConfigurationManager, plans_manager: PlansManager, dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str, save_probabilities: bool = False): """ This function generates the output structure expected by the JHU benchmark. We interpret output_file_truncated as the output folder. 
We create 'predictions' subfolders and populate them with the label maps """ if isinstance(dataset_json_dict_or_file, str): dataset_json_dict_or_file = load_json(dataset_json_dict_or_file) label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file) ret = convert_predicted_logits_to_segmentation_with_correct_shape( predicted_array_or_file, plans_manager, configuration_manager, label_manager, properties_dict, return_probabilities=save_probabilities ) del predicted_array_or_file # save if save_probabilities: segmentation_final, probabilities_final = ret np.savez_compressed(output_file_truncated + '.npz', probabilities=probabilities_final) save_pickle(properties_dict, output_file_truncated + '.pkl') del probabilities_final, ret else: segmentation_final = ret del ret rw = plans_manager.image_reader_writer_class() output_folder = join(output_file_truncated, 'predictions') maybe_mkdir_p(output_folder) label_name_dict = {j: i for i, j in label_manager.label_dict.items()} for l in label_manager.foreground_labels: label_name = label_name_dict[l] rw.write_seg( (segmentation_final == l).astype(np.uint8, copy=False), join(output_folder, label_name + dataset_json_dict_or_file['file_ending']), properties_dict ) class JHUPredictor(nnUNetPredictor): def predict_from_data_iterator(self, data_iterator, save_probabilities: bool = False, num_processes_segmentation_export: int = default_num_processes): """ We replace export_prediction_from_logits with export_prediction_from_logits_singleFiles to comply with JHU benchmark output format expectations """ with multiprocessing.get_context("spawn").Pool(num_processes_segmentation_export) as export_pool: worker_list = [i for i in export_pool._pool] r = [] for preprocessed in data_iterator: data = preprocessed['data'] if isinstance(data, str): delfile = data data = torch.from_numpy(np.load(data)) os.remove(delfile) ofile = preprocessed['ofile'] if ofile is not None: print(f'\nPredicting {os.path.basename(ofile)}:') else: 
                    print(f'\nPredicting image of shape {data.shape}:')
                print(f'perform_everything_on_device: {self.perform_everything_on_device}')

                properties = preprocessed['data_properties']

                # let's not get into a runaway situation where the GPU predicts so fast that the disk has to be
                # swamped with npy files
                proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
                while not proceed:
                    sleep(0.1)
                    proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)

                prediction = self.predict_logits_from_preprocessed_data(data).cpu()

                if ofile is not None:
                    # this needs to go into background processes
                    # export_prediction_from_logits(prediction, properties, self.configuration_manager, self.plans_manager,
                    #                               self.dataset_json, ofile, save_probabilities)
                    print('sending off prediction to background worker for resampling and export')
                    r.append(
                        export_pool.starmap_async(
                            export_prediction_from_logits_singleFiles,
                            ((prediction, properties, self.configuration_manager, self.plans_manager,
                              self.dataset_json, ofile, save_probabilities),)
                        )
                    )
                else:
                    # convert_predicted_logits_to_segmentation_with_correct_shape(
                    #             prediction, self.plans_manager,
                    #              self.configuration_manager, self.label_manager,
                    #              properties,
                    #              save_probabilities)
                    print('sending off prediction to background worker for resampling')
                    r.append(
                        export_pool.starmap_async(
                            convert_predicted_logits_to_segmentation_with_correct_shape, (
                                (prediction, self.plans_manager,
                                 self.configuration_manager, self.label_manager,
                                 properties,
                                 save_probabilities),)
                        )
                    )
                if ofile is not None:
                    print(f'done with {os.path.basename(ofile)}')
                else:
                    print(f'\nDone with image of shape {data.shape}:')
            ret = [i.get()[0] for i in r]

        if isinstance(data_iterator, MultiThreadedAugmenter):
            data_iterator._finish()

        # clear lru cache
        compute_gaussian.cache_clear()
        # clear device cache
        empty_cache(self.device)
        return ret


if __name__ == '__main__':
    # python nnunetv2/inference/JHU_inference.py
/home/isensee/Downloads/AbdomenAtlasTest /home/isensee/Downloads/AbdomenAtlasTest_pred -model /home/isensee/temp/JHU/trained_model_ep3850 # /home/isensee/temp/JHU/trained_model_ep3850 # /home/isensee/Downloads/AbdomenAtlasTest # /home/isensee/Downloads/AbdomenAtlasTest_pred os.environ['nnUNet_compile'] = 'f' parser = argparse.ArgumentParser() parser.add_argument('input_dir', type=str) parser.add_argument('output_dir', type=str) parser.add_argument('-model', required=True, type=str) parser.add_argument('--disable_tqdm', required=False, action='store_true', default=False) args = parser.parse_args() predictor = JHUPredictor( tile_step_size=0.5, use_gaussian=True, use_mirroring=True, perform_everything_on_device=True, device=torch.device('cuda', 0), verbose=False, verbose_preprocessing=False, allow_tqdm=not args.disable_tqdm ) predictor.initialize_from_trained_model_folder( args.model, ('all', ), 'checkpoint_final.pth' ) # we need to create list of list of input files input_caseids = subdirs(args.input_dir, join=False) input_files = [[join(args.input_dir, i, 'ct.nii.gz')] for i in input_caseids] output_folders = [join(args.output_dir, i) for i in input_caseids] predictor.predict_from_files( input_files, output_folders, save_probabilities=False, overwrite=True, num_processes_preprocessing=2, num_processes_segmentation_export=3, folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0 ) ================================================ FILE: nnunetv2/inference/__init__.py ================================================ ================================================ FILE: nnunetv2/inference/data_iterators.py ================================================ import multiprocessing import queue from torch.multiprocessing import Event, Queue, Manager from time import sleep from typing import Union, List import numpy as np import torch from batchgenerators.dataloading.data_loader import DataLoader from nnunetv2.preprocessing.preprocessors.default_preprocessor import 
DefaultPreprocessor from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager def preprocess_fromfiles_save_to_queue(list_of_lists: List[List[str]], list_of_segs_from_prev_stage_files: Union[None, List[str]], output_filenames_truncated: Union[None, List[str]], plans_manager: PlansManager, dataset_json: dict, configuration_manager: ConfigurationManager, target_queue: Queue, done_event: Event, abort_event: Event, verbose: bool = False): try: label_manager = plans_manager.get_label_manager(dataset_json) preprocessor = configuration_manager.preprocessor_class(verbose=verbose) for idx in range(len(list_of_lists)): data, seg, data_properties = preprocessor.run_case(list_of_lists[idx], list_of_segs_from_prev_stage_files[ idx] if list_of_segs_from_prev_stage_files is not None else None, plans_manager, configuration_manager, dataset_json) if list_of_segs_from_prev_stage_files is not None and list_of_segs_from_prev_stage_files[idx] is not None: seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype) data = np.vstack((data, seg_onehot)) data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format) item = {'data': data, 'data_properties': data_properties, 'ofile': output_filenames_truncated[idx] if output_filenames_truncated is not None else None} success = False while not success: try: if abort_event.is_set(): return target_queue.put(item, timeout=0.01) success = True except queue.Full: pass done_event.set() except Exception as e: # print(Exception, e) abort_event.set() raise e def preprocessing_iterator_fromfiles(list_of_lists: List[List[str]], list_of_segs_from_prev_stage_files: Union[None, List[str]], output_filenames_truncated: Union[None, List[str]], plans_manager: PlansManager, dataset_json: dict, configuration_manager: ConfigurationManager, num_processes: int, pin_memory: bool 
= False, verbose: bool = False): context = multiprocessing.get_context('spawn') manager = Manager() num_processes = min(len(list_of_lists), num_processes) assert num_processes >= 1 processes = [] done_events = [] target_queues = [] abort_event = manager.Event() for i in range(num_processes): event = manager.Event() queue = manager.Queue(maxsize=1) pr = context.Process(target=preprocess_fromfiles_save_to_queue, args=( list_of_lists[i::num_processes], list_of_segs_from_prev_stage_files[ i::num_processes] if list_of_segs_from_prev_stage_files is not None else None, output_filenames_truncated[ i::num_processes] if output_filenames_truncated is not None else None, plans_manager, dataset_json, configuration_manager, queue, event, abort_event, verbose ), daemon=True) pr.start() target_queues.append(queue) done_events.append(event) processes.append(pr) worker_ctr = 0 while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()): # import IPython;IPython.embed() if not target_queues[worker_ctr].empty(): item = target_queues[worker_ctr].get() worker_ctr = (worker_ctr + 1) % num_processes else: all_ok = all( [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set() if not all_ok: raise RuntimeError('Background workers died. Look for the error message further up! If there is ' 'none then your RAM was full and the worker was killed by the OS. 
Use fewer ' 'workers or get more RAM in that case!') sleep(0.01) continue if pin_memory: [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)] yield item [p.join() for p in processes] class PreprocessAdapter(DataLoader): def __init__(self, list_of_lists: List[List[str]], list_of_segs_from_prev_stage_files: Union[None, List[str]], preprocessor: DefaultPreprocessor, output_filenames_truncated: Union[None, List[str]], plans_manager: PlansManager, dataset_json: dict, configuration_manager: ConfigurationManager, num_threads_in_multithreaded: int = 1): self.preprocessor, self.plans_manager, self.configuration_manager, self.dataset_json = \ preprocessor, plans_manager, configuration_manager, dataset_json self.label_manager = plans_manager.get_label_manager(dataset_json) if list_of_segs_from_prev_stage_files is None: list_of_segs_from_prev_stage_files = [None] * len(list_of_lists) if output_filenames_truncated is None: output_filenames_truncated = [None] * len(list_of_lists) super().__init__(list(zip(list_of_lists, list_of_segs_from_prev_stage_files, output_filenames_truncated)), 1, num_threads_in_multithreaded, seed_for_shuffle=1, return_incomplete=True, shuffle=False, infinite=False, sampling_probabilities=None) self.indices = list(range(len(list_of_lists))) def generate_train_batch(self): idx = self.get_indices()[0] files, seg_prev_stage, ofile = self._data[idx] # if we have a segmentation from the previous stage we have to process it together with the images so that we # can crop it appropriately (if needed). 
Otherwise it would just be resized to the shape of the data after # preprocessing and then there might be misalignments data, seg, data_properties = self.preprocessor.run_case(files, seg_prev_stage, self.plans_manager, self.configuration_manager, self.dataset_json) if seg_prev_stage is not None: seg_onehot = convert_labelmap_to_one_hot(seg[0], self.label_manager.foreground_labels, data.dtype) data = np.vstack((data, seg_onehot)) data = torch.from_numpy(data) return {'data': data, 'data_properties': data_properties, 'ofile': ofile} class PreprocessAdapterFromNpy(DataLoader): def __init__(self, list_of_images: List[np.ndarray], list_of_segs_from_prev_stage: Union[List[np.ndarray], None], list_of_image_properties: List[dict], truncated_ofnames: Union[List[str], None], plans_manager: PlansManager, dataset_json: dict, configuration_manager: ConfigurationManager, num_threads_in_multithreaded: int = 1, verbose: bool = False): preprocessor = configuration_manager.preprocessor_class(verbose=verbose) self.preprocessor, self.plans_manager, self.configuration_manager, self.dataset_json, self.truncated_ofnames = \ preprocessor, plans_manager, configuration_manager, dataset_json, truncated_ofnames self.label_manager = plans_manager.get_label_manager(dataset_json) if list_of_segs_from_prev_stage is None: list_of_segs_from_prev_stage = [None] * len(list_of_images) if truncated_ofnames is None: truncated_ofnames = [None] * len(list_of_images) super().__init__( list(zip(list_of_images, list_of_segs_from_prev_stage, list_of_image_properties, truncated_ofnames)), 1, num_threads_in_multithreaded, seed_for_shuffle=1, return_incomplete=True, shuffle=False, infinite=False, sampling_probabilities=None) self.indices = list(range(len(list_of_images))) def generate_train_batch(self): idx = self.get_indices()[0] image, seg_prev_stage, props, ofname = self._data[idx] # if we have a segmentation from the previous stage we have to process it together with the images so that we # can crop it 
appropriately (if needed). Otherwise it would just be resized to the shape of the data after # preprocessing and then there might be misalignments data, seg, props = self.preprocessor.run_case_npy(image, seg_prev_stage, props, self.plans_manager, self.configuration_manager, self.dataset_json) if seg_prev_stage is not None: seg_onehot = convert_labelmap_to_one_hot(seg[0], self.label_manager.foreground_labels, data.dtype) data = np.vstack((data, seg_onehot)) data = torch.from_numpy(data) return {'data': data, 'data_properties': props, 'ofile': ofname} def preprocess_fromnpy_save_to_queue(list_of_images: List[np.ndarray], list_of_segs_from_prev_stage: Union[List[np.ndarray], None], list_of_image_properties: List[dict], truncated_ofnames: Union[List[str], None], plans_manager: PlansManager, dataset_json: dict, configuration_manager: ConfigurationManager, target_queue: Queue, done_event: Event, abort_event: Event, verbose: bool = False): try: label_manager = plans_manager.get_label_manager(dataset_json) preprocessor = configuration_manager.preprocessor_class(verbose=verbose) for idx in range(len(list_of_images)): data, seg, props = preprocessor.run_case_npy(list_of_images[idx], list_of_segs_from_prev_stage[ idx] if list_of_segs_from_prev_stage is not None else None, list_of_image_properties[idx], plans_manager, configuration_manager, dataset_json) list_of_image_properties[idx] = props if list_of_segs_from_prev_stage is not None and list_of_segs_from_prev_stage[idx] is not None: seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype) data = np.vstack((data, seg_onehot)) data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format) item = {'data': data, 'data_properties': list_of_image_properties[idx], 'ofile': truncated_ofnames[idx] if truncated_ofnames is not None else None} success = False while not success: try: if abort_event.is_set(): return target_queue.put(item, timeout=0.01) success = True 
except queue.Full: pass done_event.set() except Exception as e: abort_event.set() raise e def preprocessing_iterator_fromnpy(list_of_images: List[np.ndarray], list_of_segs_from_prev_stage: Union[List[np.ndarray], None], list_of_image_properties: List[dict], truncated_ofnames: Union[List[str], None], plans_manager: PlansManager, dataset_json: dict, configuration_manager: ConfigurationManager, num_processes: int, pin_memory: bool = False, verbose: bool = False): context = multiprocessing.get_context('spawn') manager = Manager() num_processes = min(len(list_of_images), num_processes) assert num_processes >= 1 target_queues = [] processes = [] done_events = [] abort_event = manager.Event() for i in range(num_processes): event = manager.Event() queue = manager.Queue(maxsize=1) pr = context.Process(target=preprocess_fromnpy_save_to_queue, args=( list_of_images[i::num_processes], list_of_segs_from_prev_stage[ i::num_processes] if list_of_segs_from_prev_stage is not None else None, list_of_image_properties[i::num_processes], truncated_ofnames[i::num_processes] if truncated_ofnames is not None else None, plans_manager, dataset_json, configuration_manager, queue, event, abort_event, verbose ), daemon=True) pr.start() done_events.append(event) processes.append(pr) target_queues.append(queue) worker_ctr = 0 while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()): if not target_queues[worker_ctr].empty(): item = target_queues[worker_ctr].get() worker_ctr = (worker_ctr + 1) % num_processes else: all_ok = all( [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set() if not all_ok: raise RuntimeError('Background workers died. Look for the error message further up! If there is ' 'none then your RAM was full and the worker was killed by the OS. 
Use fewer ' 'workers or get more RAM in that case!') sleep(0.01) continue if pin_memory: [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)] yield item [p.join() for p in processes] ================================================ FILE: nnunetv2/inference/examples.py ================================================ if __name__ == '__main__': from nnunetv2.paths import nnUNet_results, nnUNet_raw import torch from batchgenerators.utilities.file_and_folder_operations import join from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO # nnUNetv2_predict -d 3 -f 0 -c 3d_lowres -i imagesTs -o imagesTs_predlowres --continue_prediction # instantiate the nnUNetPredictor predictor = nnUNetPredictor( tile_step_size=0.5, use_gaussian=True, use_mirroring=True, perform_everything_on_device=True, device=torch.device('cuda', 0), verbose=False, verbose_preprocessing=False, allow_tqdm=True ) # initializes the network architecture, loads the checkpoint predictor.initialize_from_trained_model_folder( join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'), use_folds=(0,), checkpoint_name='checkpoint_final.pth', ) # variant 1: give input and output folders predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'), join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'), save_probabilities=False, overwrite=False, num_processes_preprocessing=2, num_processes_segmentation_export=2, folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0) # variant 2, use list of files as inputs. Note how we use nested lists!!! 
    indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs')
    outdir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres')
    predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')],
                                  [join(indir, 'liver_142_0000.nii.gz')]],
                                 [join(outdir, 'liver_152.nii.gz'),
                                  join(outdir, 'liver_142.nii.gz')],
                                 save_probabilities=False, overwrite=True,
                                 num_processes_preprocessing=2, num_processes_segmentation_export=2,
                                 folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)

    # variant 2.5, returns segmentations
    indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs')
    predicted_segmentations = predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')],
                                                            [join(indir, 'liver_142_0000.nii.gz')]],
                                                           None,
                                                           save_probabilities=True, overwrite=True,
                                                           num_processes_preprocessing=2,
                                                           num_processes_segmentation_export=2,
                                                           folder_with_segs_from_prev_stage=None, num_parts=1,
                                                           part_id=0)

    # predict several npy images
    img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])
    img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])
    img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])
    img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])
    # we do not set output files so that the segmentations will be returned. You can of course also specify output
    # files instead (no return value in that case)
    ret = predictor.predict_from_list_of_npy_arrays([img, img2, img3, img4],
                                                    None,
                                                    [props, props2, props3, props4],
                                                    None, 2, save_probabilities=False,
                                                    num_processes_segmentation_export=2)

    # predict a single numpy array
    img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])
    ret = predictor.predict_single_npy_array(img, props, None, None, True)

    # custom iterator
    img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])
    img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])
    img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])
    img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])

    # each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properties' keys!
    # If 'ofile' is None, the result will be returned instead of written to a file
    # the iterator is responsible for performing the correct preprocessing!
    # note how the iterator here does not use multiprocessing -> preprocessing will be done in the main thread!
    # take a look at the default iterators for predict_from_files and predict_from_list_of_npy_arrays
    # (they both use predictor.predict_from_data_iterator) for inspiration!
def my_iterator(list_of_input_arrs, list_of_input_props): preprocessor = predictor.configuration_manager.preprocessor_class(verbose=predictor.verbose) for a, p in zip(list_of_input_arrs, list_of_input_props): data, seg, p = preprocessor.run_case_npy(a, None, p, predictor.plans_manager, predictor.configuration_manager, predictor.dataset_json) yield {'data': torch.from_numpy(data).contiguous().pin_memory(), 'data_properties': p, 'ofile': None} ret = predictor.predict_from_data_iterator(my_iterator([img, img2, img3, img4], [props, props2, props3, props4]), save_probabilities=False, num_processes_segmentation_export=3) ================================================ FILE: nnunetv2/inference/export_prediction.py ================================================ from typing import Union, List import numpy as np import torch from acvl_utils.cropping_and_padding.bounding_boxes import insert_crop_into_image from batchgenerators.utilities.file_and_folder_operations import load_json, save_pickle from nnunetv2.configuration import default_num_processes from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDatasetBlosc2 from nnunetv2.utilities.label_handling.label_handling import LabelManager from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager def convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits: Union[torch.Tensor, np.ndarray], plans_manager: PlansManager, configuration_manager: ConfigurationManager, label_manager: LabelManager, properties_dict: dict, return_probabilities: bool = False, num_threads_torch: int = default_num_processes): old_threads = torch.get_num_threads() torch.set_num_threads(num_threads_torch) # resample to original shape spacing_transposed = [properties_dict['spacing'][i] for i in plans_manager.transpose_forward] current_spacing = configuration_manager.spacing if \ len(configuration_manager.spacing) == \ len(properties_dict['shape_after_cropping_and_before_resampling']) else \ 
        [spacing_transposed[0], *configuration_manager.spacing]
    predicted_logits = configuration_manager.resampling_fn_probabilities(
        predicted_logits,
        properties_dict['shape_after_cropping_and_before_resampling'],
        current_spacing,
        [properties_dict['spacing'][i] for i in plans_manager.transpose_forward])
    # return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because
    # apply_inference_nonlin will convert to torch
    if not return_probabilities:
        # this has a faster computation path because we can skip the softmax in regular (not region based) training
        segmentation = label_manager.convert_logits_to_segmentation(predicted_logits)
    else:
        predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits)
        segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities)
    del predicted_logits

    # put segmentation in bbox (revert cropping)
    segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'],
                                              dtype=np.uint8 if len(label_manager.foreground_labels) < 255
                                              else np.uint16)
    segmentation_reverted_cropping = insert_crop_into_image(segmentation_reverted_cropping, segmentation,
                                                            properties_dict['bbox_used_for_cropping'])
    del segmentation

    # segmentation may be torch.Tensor but we continue with numpy
    if isinstance(segmentation_reverted_cropping, torch.Tensor):
        segmentation_reverted_cropping = segmentation_reverted_cropping.cpu().numpy()

    # revert transpose
    segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(plans_manager.transpose_backward)
    if return_probabilities:
        # revert cropping
        predicted_probabilities = label_manager.revert_cropping_on_probabilities(
            predicted_probabilities,
            properties_dict['bbox_used_for_cropping'],
            properties_dict['shape_before_cropping'])
        predicted_probabilities = predicted_probabilities.cpu().numpy()
        # revert transpose
        predicted_probabilities = predicted_probabilities.transpose([0] + [i + 1 for i in
                                                                           plans_manager.transpose_backward])
torch.set_num_threads(old_threads) return segmentation_reverted_cropping, predicted_probabilities else: torch.set_num_threads(old_threads) return segmentation_reverted_cropping def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, torch.Tensor], properties_dict: dict, configuration_manager: ConfigurationManager, plans_manager: PlansManager, dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str, save_probabilities: bool = False, num_threads_torch: int = default_num_processes): # if isinstance(predicted_array_or_file, str): # tmp = deepcopy(predicted_array_or_file) # if predicted_array_or_file.endswith('.npy'): # predicted_array_or_file = np.load(predicted_array_or_file) # elif predicted_array_or_file.endswith('.npz'): # predicted_array_or_file = np.load(predicted_array_or_file)['softmax'] # os.remove(tmp) if isinstance(dataset_json_dict_or_file, str): dataset_json_dict_or_file = load_json(dataset_json_dict_or_file) label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file) ret = convert_predicted_logits_to_segmentation_with_correct_shape( predicted_array_or_file, plans_manager, configuration_manager, label_manager, properties_dict, return_probabilities=save_probabilities, num_threads_torch=num_threads_torch ) del predicted_array_or_file # save if save_probabilities: segmentation_final, probabilities_final = ret np.savez_compressed(output_file_truncated + '.npz', probabilities=probabilities_final) save_pickle(properties_dict, output_file_truncated + '.pkl') del probabilities_final, ret else: segmentation_final = ret del ret rw = plans_manager.image_reader_writer_class() rw.write_seg(segmentation_final, output_file_truncated + dataset_json_dict_or_file['file_ending'], properties_dict) def resample_and_save(predicted: Union[torch.Tensor, np.ndarray], target_shape: List[int], output_file: str, plans_manager: PlansManager, configuration_manager: ConfigurationManager, properties_dict: dict, 
dataset_json_dict_or_file: Union[dict, str], num_threads_torch: int = default_num_processes, dataset_class=None) \ -> None: old_threads = torch.get_num_threads() torch.set_num_threads(num_threads_torch) if isinstance(dataset_json_dict_or_file, str): dataset_json_dict_or_file = load_json(dataset_json_dict_or_file) spacing_transposed = [properties_dict['spacing'][i] for i in plans_manager.transpose_forward] # resample to original shape current_spacing = configuration_manager.spacing if \ len(configuration_manager.spacing) == len(properties_dict['shape_after_cropping_and_before_resampling']) else \ [spacing_transposed[0], *configuration_manager.spacing] target_spacing = configuration_manager.spacing if len(configuration_manager.spacing) == \ len(properties_dict['shape_after_cropping_and_before_resampling']) else \ [spacing_transposed[0], *configuration_manager.spacing] predicted_array_or_file = configuration_manager.resampling_fn_probabilities(predicted, target_shape, current_spacing, target_spacing) # create segmentation (argmax, regions, etc) label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file) segmentation = label_manager.convert_logits_to_segmentation(predicted_array_or_file) # segmentation may be torch.Tensor but we continue with numpy if isinstance(segmentation, torch.Tensor): segmentation = segmentation.cpu().numpy() if dataset_class is None: nnUNetDatasetBlosc2.save_seg(segmentation.astype(dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16), output_file) else: dataset_class.save_seg(segmentation.astype(dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16), output_file) torch.set_num_threads(old_threads) ================================================ FILE: nnunetv2/inference/predict_from_raw_data.py ================================================ import inspect import itertools import multiprocessing import os from copy import deepcopy from queue import Queue from threading import Thread 
from time import sleep from typing import Tuple, Union, List, Optional import numpy as np import torch from acvl_utils.cropping_and_padding.padding import pad_nd_image from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir, subdirs, \ save_json from torch import nn from torch._dynamo import OptimizedModule from torch.nn.parallel import DistributedDataParallel from tqdm import tqdm import nnunetv2 from nnunetv2.configuration import default_num_processes from nnunetv2.inference.data_iterators import PreprocessAdapterFromNpy, preprocessing_iterator_fromfiles, \ preprocessing_iterator_fromnpy from nnunetv2.inference.export_prediction import export_prediction_from_logits, \ convert_predicted_logits_to_segmentation_with_correct_shape from nnunetv2.inference.sliding_window_prediction import compute_gaussian, \ compute_steps_for_sliding_window from nnunetv2.utilities.file_path_utilities import get_output_folder, check_workers_alive_and_busy from nnunetv2.utilities.find_class_by_name import recursive_find_python_class from nnunetv2.utilities.helpers import empty_cache, dummy_context from nnunetv2.utilities.json_export import recursive_fix_for_json_export from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager from nnunetv2.utilities.utils import create_lists_from_splitted_dataset_folder class nnUNetPredictor(object): def __init__(self, tile_step_size: float = 0.5, use_gaussian: bool = True, use_mirroring: bool = True, perform_everything_on_device: bool = True, device: torch.device = torch.device('cuda'), verbose: bool = False, verbose_preprocessing: bool = False, allow_tqdm: bool = True): self.verbose = verbose self.verbose_preprocessing = verbose_preprocessing self.allow_tqdm = allow_tqdm 
self.plans_manager, self.configuration_manager, self.list_of_parameters, self.network, self.dataset_json, \ self.trainer_name, self.allowed_mirroring_axes, self.label_manager = None, None, None, None, None, None, None, None self.tile_step_size = tile_step_size self.use_gaussian = use_gaussian self.use_mirroring = use_mirroring if device.type == 'cuda': torch.backends.cudnn.benchmark = True else: print(f'perform_everything_on_device=True is only supported for cuda devices! Setting this to False') perform_everything_on_device = False self.device = device self.perform_everything_on_device = perform_everything_on_device def initialize_from_trained_model_folder(self, model_training_output_dir: str, use_folds: Union[Tuple[Union[int, str]], None], checkpoint_name: str = 'checkpoint_final.pth'): """ This is used when making predictions with a trained model """ if use_folds is None: use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name) dataset_json = load_json(join(model_training_output_dir, 'dataset.json')) plans = load_json(join(model_training_output_dir, 'plans.json')) plans_manager = PlansManager(plans) if isinstance(use_folds, str): use_folds = [use_folds] parameters = [] for i, f in enumerate(use_folds): f = int(f) if f != 'all' else f checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name), map_location=torch.device('cpu'), weights_only=False) if i == 0: trainer_name = checkpoint['trainer_name'] configuration_name = checkpoint['init_args']['configuration'] inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \ 'inference_allowed_mirroring_axes' in checkpoint.keys() else None parameters.append(checkpoint['network_weights']) configuration_manager = plans_manager.get_configuration(configuration_name) # restore network num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json) trainer_class = 
recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), trainer_name, 'nnunetv2.training.nnUNetTrainer') if trainer_class is None: raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. ' f'Please place it there (in any .py file)!') network = trainer_class.build_network_architecture( configuration_manager.network_arch_class_name, configuration_manager.network_arch_init_kwargs, configuration_manager.network_arch_init_kwargs_req_import, num_input_channels, plans_manager.get_label_manager(dataset_json).num_segmentation_heads, enable_deep_supervision=False ) self.plans_manager = plans_manager self.configuration_manager = configuration_manager self.list_of_parameters = parameters # initialize network with first set of parameters, also see https://github.com/MIC-DKFZ/nnUNet/issues/2520 network.load_state_dict(parameters[0]) self.network = network self.dataset_json = dataset_json self.trainer_name = trainer_name self.allowed_mirroring_axes = inference_allowed_mirroring_axes self.label_manager = plans_manager.get_label_manager(dataset_json) if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \ and not isinstance(self.network, OptimizedModule): print('Using torch.compile') self.network = torch.compile(self.network) def manual_initialization(self, network: nn.Module, plans_manager: PlansManager, configuration_manager: ConfigurationManager, parameters: Optional[List[dict]], dataset_json: dict, trainer_name: str, inference_allowed_mirroring_axes: Optional[Tuple[int, ...]]): """ This is used by the nnUNetTrainer to initialize nnUNetPredictor for the final validation """ self.plans_manager = plans_manager self.configuration_manager = configuration_manager self.list_of_parameters = parameters self.network = network self.dataset_json = dataset_json self.trainer_name = trainer_name self.allowed_mirroring_axes = inference_allowed_mirroring_axes 
self.label_manager = plans_manager.get_label_manager(dataset_json) allow_compile = True allow_compile = allow_compile and ('nnUNet_compile' in os.environ.keys()) and ( os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) allow_compile = allow_compile and not isinstance(self.network, OptimizedModule) if isinstance(self.network, DistributedDataParallel): allow_compile = allow_compile and isinstance(self.network.module, OptimizedModule) if allow_compile: print('Using torch.compile') self.network = torch.compile(self.network) @staticmethod def auto_detect_available_folds(model_training_output_dir, checkpoint_name): print('use_folds is None, attempting to auto detect available folds') fold_folders = subdirs(model_training_output_dir, prefix='fold_', join=False) fold_folders = [i for i in fold_folders if i != 'fold_all'] fold_folders = [i for i in fold_folders if isfile(join(model_training_output_dir, i, checkpoint_name))] use_folds = [int(i.split('_')[-1]) for i in fold_folders] print(f'found the following folds: {use_folds}') return use_folds def _manage_input_and_output_lists(self, list_of_lists_or_source_folder: Union[str, List[List[str]]], output_folder_or_list_of_truncated_output_files: Union[None, str, List[str]], folder_with_segs_from_prev_stage: str = None, overwrite: bool = True, part_id: int = 0, num_parts: int = 1, save_probabilities: bool = False): if isinstance(list_of_lists_or_source_folder, str): list_of_lists_or_source_folder = create_lists_from_splitted_dataset_folder(list_of_lists_or_source_folder, self.dataset_json['file_ending']) print(f'There are {len(list_of_lists_or_source_folder)} cases in the source folder') list_of_lists_or_source_folder = list_of_lists_or_source_folder[part_id::num_parts] caseids = [os.path.basename(i[0])[:-(len(self.dataset_json['file_ending']) + 5)] for i in list_of_lists_or_source_folder] print( f'I am processing {part_id} out of {num_parts} (max process ID is {num_parts - 1}, we start counting with 0!)') 
print(f'There are {len(caseids)} cases that I would like to predict') if isinstance(output_folder_or_list_of_truncated_output_files, str): output_filename_truncated = [join(output_folder_or_list_of_truncated_output_files, i) for i in caseids] elif isinstance(output_folder_or_list_of_truncated_output_files, list): output_filename_truncated = output_folder_or_list_of_truncated_output_files[part_id::num_parts] else: output_filename_truncated = None seg_from_prev_stage_files = [join(folder_with_segs_from_prev_stage, i + self.dataset_json['file_ending']) if folder_with_segs_from_prev_stage is not None else None for i in caseids] # remove already predicted files from the lists if not overwrite and output_filename_truncated is not None: tmp = [isfile(i + self.dataset_json['file_ending']) for i in output_filename_truncated] if save_probabilities: tmp2 = [isfile(i + '.npz') for i in output_filename_truncated] tmp = [i and j for i, j in zip(tmp, tmp2)] not_existing_indices = [i for i, j in enumerate(tmp) if not j] output_filename_truncated = [output_filename_truncated[i] for i in not_existing_indices] list_of_lists_or_source_folder = [list_of_lists_or_source_folder[i] for i in not_existing_indices] seg_from_prev_stage_files = [seg_from_prev_stage_files[i] for i in not_existing_indices] print(f'overwrite was set to {overwrite}, so I am only working on cases that haven\'t been predicted yet. 
' f'That\'s {len(not_existing_indices)} cases.')
        return list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files

    def predict_from_files(self,
                           list_of_lists_or_source_folder: Union[str, List[List[str]]],
                           output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]],
                           save_probabilities: bool = False,
                           overwrite: bool = True,
                           num_processes_preprocessing: int = default_num_processes,
                           num_processes_segmentation_export: int = default_num_processes,
                           folder_with_segs_from_prev_stage: str = None,
                           num_parts: int = 1,
                           part_id: int = 0):
        """
        This is nnU-Net's default function for making predictions. It works best for batch predictions
        (predicting many images at once).
        """
        # part_id is zero-based, so it must be strictly smaller than num_parts
        assert part_id < num_parts, ("Part ID must be smaller than num_parts. Remember that we start counting with 0. "
                                     "So if there are 3 parts then valid part IDs are 0, 1, 2")
        if isinstance(output_folder_or_list_of_truncated_output_files, str):
            output_folder = output_folder_or_list_of_truncated_output_files
        elif isinstance(output_folder_or_list_of_truncated_output_files, list):
            output_folder = os.path.dirname(output_folder_or_list_of_truncated_output_files[0])
        else:
            output_folder = None

        ########################
        # let's store the input arguments so that it's clear what was used to generate the prediction
        if output_folder is not None:
            my_init_kwargs = {}
            for k in inspect.signature(self.predict_from_files).parameters.keys():
                my_init_kwargs[k] = locals()[k]
            my_init_kwargs = deepcopy(my_init_kwargs)  # let's not unintentionally change anything in-place
            recursive_fix_for_json_export(my_init_kwargs)
            maybe_mkdir_p(output_folder)
            save_json(my_init_kwargs, join(output_folder, 'predict_from_raw_data_args.json'))

            # we need these two if we want to do things with the predictions like for example apply postprocessing
            save_json(self.dataset_json, join(output_folder, 'dataset.json'), sort_keys=False)
            save_json(self.plans_manager.plans, join(output_folder, 'plans.json'), sort_keys=False)
        #######################

        # check if we need a prediction from the previous stage
        if self.configuration_manager.previous_stage_name is not None:
            assert folder_with_segs_from_prev_stage is not None, \
                f'The requested configuration is a cascaded network. It requires the segmentations of the previous ' \
                f'stage ({self.configuration_manager.previous_stage_name}) as input. Please provide the folder where' \
                f' they are located via folder_with_segs_from_prev_stage'

        # sort out input and output filenames
        list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \
            self._manage_input_and_output_lists(list_of_lists_or_source_folder,
                                                output_folder_or_list_of_truncated_output_files,
                                                folder_with_segs_from_prev_stage, overwrite, part_id, num_parts,
                                                save_probabilities)
        if len(list_of_lists_or_source_folder) == 0:
            return

        data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder,
                                                                                 seg_from_prev_stage_files,
                                                                                 output_filename_truncated,
                                                                                 num_processes_preprocessing)

        return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export)

    def _internal_get_data_iterator_from_lists_of_filenames(self,
                                                            input_list_of_lists: List[List[str]],
                                                            seg_from_prev_stage_files: Union[List[str], None],
                                                            output_filenames_truncated: Union[List[str], None],
                                                            num_processes: int):
        return preprocessing_iterator_fromfiles(input_list_of_lists, seg_from_prev_stage_files,
                                                output_filenames_truncated, self.plans_manager, self.dataset_json,
                                                self.configuration_manager,
num_processes, self.device.type == 'cuda', self.verbose_preprocessing) # preprocessor = self.configuration_manager.preprocessor_class(verbose=self.verbose_preprocessing) # # hijack batchgenerators, yo # # we use the multiprocessing of the batchgenerators dataloader to handle all the background worker stuff. This # # way we don't have to reinvent the wheel here. # num_processes = max(1, min(num_processes, len(input_list_of_lists))) # ppa = PreprocessAdapter(input_list_of_lists, seg_from_prev_stage_files, preprocessor, # output_filenames_truncated, self.plans_manager, self.dataset_json, # self.configuration_manager, num_processes) # if num_processes == 0: # mta = SingleThreadedAugmenter(ppa, None) # else: # mta = MultiThreadedAugmenter(ppa, None, num_processes, 1, None, pin_memory=pin_memory) # return mta def get_data_iterator_from_raw_npy_data(self, image_or_list_of_images: Union[np.ndarray, List[np.ndarray]], segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None, np.ndarray, List[ np.ndarray]], properties_or_list_of_properties: Union[dict, List[dict]], truncated_ofname: Union[str, List[str], None], num_processes: int = 3): list_of_images = [image_or_list_of_images] if not isinstance(image_or_list_of_images, list) else \ image_or_list_of_images if isinstance(segs_from_prev_stage_or_list_of_segs_from_prev_stage, np.ndarray): segs_from_prev_stage_or_list_of_segs_from_prev_stage = [ segs_from_prev_stage_or_list_of_segs_from_prev_stage] if isinstance(truncated_ofname, str): truncated_ofname = [truncated_ofname] if isinstance(properties_or_list_of_properties, dict): properties_or_list_of_properties = [properties_or_list_of_properties] num_processes = min(num_processes, len(list_of_images)) pp = preprocessing_iterator_fromnpy( list_of_images, segs_from_prev_stage_or_list_of_segs_from_prev_stage, properties_or_list_of_properties, truncated_ofname, self.plans_manager, self.dataset_json, self.configuration_manager, num_processes, self.device.type == 'cuda', 
self.verbose_preprocessing ) return pp def predict_from_list_of_npy_arrays(self, image_or_list_of_images: Union[np.ndarray, List[np.ndarray]], segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None, np.ndarray, List[ np.ndarray]], properties_or_list_of_properties: Union[dict, List[dict]], truncated_ofname: Union[str, List[str], None], num_processes: int = 3, save_probabilities: bool = False, num_processes_segmentation_export: int = default_num_processes): iterator = self.get_data_iterator_from_raw_npy_data(image_or_list_of_images, segs_from_prev_stage_or_list_of_segs_from_prev_stage, properties_or_list_of_properties, truncated_ofname, num_processes) return self.predict_from_data_iterator(iterator, save_probabilities, num_processes_segmentation_export) def predict_from_data_iterator(self, data_iterator, save_probabilities: bool = False, num_processes_segmentation_export: int = default_num_processes): """ each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properties' keys! 
If 'ofile' is None, the result will be returned instead of written to a file """ with multiprocessing.get_context("spawn").Pool(num_processes_segmentation_export) as export_pool: worker_list = [i for i in export_pool._pool] r = [] for preprocessed in data_iterator: data = preprocessed['data'] if isinstance(data, str): delfile = data data = torch.from_numpy(np.load(data)) os.remove(delfile) ofile = preprocessed['ofile'] if ofile is not None: print(f'\nPredicting {os.path.basename(ofile)}:') else: print(f'\nPredicting image of shape {data.shape}:') print(f'perform_everything_on_device: {self.perform_everything_on_device}') properties = preprocessed['data_properties'] # let's not get into a runaway situation where the GPU predicts so fast that the disk has to be swamped with # npy files proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2) while not proceed: sleep(0.1) proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2) # convert to numpy to prevent uncatchable memory alignment errors from multiprocessing serialization of torch tensors prediction = self.predict_logits_from_preprocessed_data(data).cpu().detach().numpy() if ofile is not None: print('sending off prediction to background worker for resampling and export') r.append( export_pool.starmap_async( export_prediction_from_logits, ((prediction, properties, self.configuration_manager, self.plans_manager, self.dataset_json, ofile, save_probabilities),) ) ) else: print('sending off prediction to background worker for resampling') r.append( export_pool.starmap_async( convert_predicted_logits_to_segmentation_with_correct_shape, ( (prediction, self.plans_manager, self.configuration_manager, self.label_manager, properties, save_probabilities),) ) ) if ofile is not None: print(f'done with {os.path.basename(ofile)}') else: print(f'\nDone with image of shape {data.shape}:') ret = [i.get()[0] for i in r] if isinstance(data_iterator, 
MultiThreadedAugmenter): data_iterator._finish() # clear lru cache compute_gaussian.cache_clear() # clear device cache empty_cache(self.device) return ret def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict, segmentation_previous_stage: np.ndarray = None, output_file_truncated: str = None, save_or_return_probabilities: bool = False): """ WARNING: SLOW. ONLY USE THIS IF YOU CANNOT GIVE NNUNET MULTIPLE IMAGES AT ONCE FOR SOME REASON. input_image: Make sure to load the image in the way nnU-Net expects! nnU-Net is trained on a certain axis ordering which cannot be disturbed in inference, otherwise you will get bad results. The easiest way to achieve that is to use the same I/O class for loading images as was used during nnU-Net preprocessing! You can find that class in your plans.json file under the key "image_reader_writer". If you decide to freestyle, know that the default axis ordering for medical images is the one from SimpleITK. If you load with nibabel, you need to transpose your axes AND your spacing from [x,y,z] to [z,y,x]! image_properties must only have a 'spacing' key! 
""" ppa = PreprocessAdapterFromNpy([input_image], [segmentation_previous_stage], [image_properties], [output_file_truncated], self.plans_manager, self.dataset_json, self.configuration_manager, num_threads_in_multithreaded=1, verbose=self.verbose) if self.verbose: print('preprocessing') dct = next(ppa) if self.verbose: print('predicting') predicted_logits = self.predict_logits_from_preprocessed_data(dct['data']).cpu() if self.verbose: print('resampling to original shape') if output_file_truncated is not None: export_prediction_from_logits(predicted_logits, dct['data_properties'], self.configuration_manager, self.plans_manager, self.dataset_json, output_file_truncated, save_or_return_probabilities) else: ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager, self.configuration_manager, self.label_manager, dct['data_properties'], return_probabilities= save_or_return_probabilities) if save_or_return_probabilities: return ret[0], ret[1] else: return ret @torch.inference_mode() def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor: """ IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE! RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE. SEE convert_predicted_logits_to_segmentation_with_correct_shape """ n_threads = torch.get_num_threads() torch.set_num_threads(default_num_processes if default_num_processes < n_threads else n_threads) prediction = None for params in self.list_of_parameters: # messing with state dict names... if not isinstance(self.network, OptimizedModule): self.network.load_state_dict(params) else: self.network._orig_mod.load_state_dict(params) # why not leave prediction on device if perform_everything_on_device? 
Because this may cause the # second iteration to crash due to OOM. Grabbing that with try except cause way more bloated code than # this actually saves computation time if prediction is None: prediction = self.predict_sliding_window_return_logits(data).to('cpu') else: prediction += self.predict_sliding_window_return_logits(data).to('cpu') if len(self.list_of_parameters) > 1: prediction /= len(self.list_of_parameters) if self.verbose: print('Prediction done') torch.set_num_threads(n_threads) return prediction def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]): slicers = [] if len(self.configuration_manager.patch_size) < len(image_size): assert len(self.configuration_manager.patch_size) == len( image_size) - 1, 'if tile_size has less entries than image_size, ' \ 'len(tile_size) ' \ 'must be one shorter than len(image_size) ' \ '(only dimension ' \ 'discrepancy of 1 allowed).' steps = compute_steps_for_sliding_window(image_size[1:], self.configuration_manager.patch_size, self.tile_step_size) if self.verbose: print(f'n_steps {image_size[0] * len(steps[0]) * len(steps[1])}, image size is' f' {image_size}, tile_size {self.configuration_manager.patch_size}, ' f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}') for d in range(image_size[0]): for sx in steps[0]: for sy in steps[1]: slicers.append( tuple([slice(None), d, *[slice(si, si + ti) for si, ti in zip((sx, sy), self.configuration_manager.patch_size)]])) else: steps = compute_steps_for_sliding_window(image_size, self.configuration_manager.patch_size, self.tile_step_size) if self.verbose: print( f'n_steps {np.prod([len(i) for i in steps])}, image size is {image_size}, tile_size {self.configuration_manager.patch_size}, ' f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}') for sx in steps[0]: for sy in steps[1]: for sz in steps[2]: slicers.append( tuple([slice(None), *[slice(si, si + ti) for si, ti in zip((sx, sy, sz), self.configuration_manager.patch_size)]])) return slicers 
@torch.inference_mode() def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor: mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None prediction = self.network(x) if mirror_axes is not None: # check for invalid numbers in mirror_axes # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3 assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!' mirror_axes = [m + 2 for m in mirror_axes] axes_combinations = [ c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1) ] for axes in axes_combinations: prediction += torch.flip(self.network(torch.flip(x, axes)), axes) prediction /= (len(axes_combinations) + 1) return prediction @torch.inference_mode() def _internal_predict_sliding_window_return_logits(self, data: torch.Tensor, slicers, do_on_device: bool = True, ): predicted_logits = n_predictions = prediction = gaussian = workon = None results_device = self.device if do_on_device else torch.device('cpu') def producer(d, slh, q): for s in slh: q.put((torch.clone(d[s][None], memory_format=torch.contiguous_format).to(self.device), s)) q.put('end') try: empty_cache(self.device) # move data to device if self.verbose: print(f'move image to device {results_device}') data = data.to(results_device) queue = Queue(maxsize=2) t = Thread(target=producer, args=(data, slicers, queue)) t.start() # preallocate arrays if self.verbose: print(f'preallocating results arrays on device {results_device}') predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]), dtype=torch.half, device=results_device) n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device) if self.use_gaussian: gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. 
/ 8, value_scaling_factor=10, device=results_device) else: gaussian = 1 if not self.allow_tqdm and self.verbose: print(f'running prediction: {len(slicers)} steps') with tqdm(desc=None, total=len(slicers), disable=not self.allow_tqdm) as pbar: while True: item = queue.get() if item == 'end': queue.task_done() break workon, sl = item prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device) if self.use_gaussian: prediction *= gaussian predicted_logits[sl] += prediction n_predictions[sl[1:]] += gaussian queue.task_done() pbar.update() queue.join() # predicted_logits /= n_predictions torch.div(predicted_logits, n_predictions, out=predicted_logits) # check for infs if torch.any(torch.isinf(predicted_logits)): raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, ' 'reduce value_scaling_factor in compute_gaussian or increase the dtype of ' 'predicted_logits to fp32') except Exception as e: del predicted_logits, n_predictions, prediction, gaussian, workon empty_cache(self.device) empty_cache(results_device) raise e return predicted_logits @torch.inference_mode() def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \ -> Union[np.ndarray, torch.Tensor]: assert isinstance(input_image, torch.Tensor) self.network = self.network.to(self.device) self.network.eval() empty_cache(self.device) # Autocast can be annoying # If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection) # and needs to be disabled. # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False # is set. Whyyyyyyy. (this is why we don't make use of enabled=False) # So autocast will only be active if we have a cuda device. 
with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)' if self.verbose: print(f'Input shape: {input_image.shape}') print("step_size:", self.tile_step_size) print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None) # if input_image is smaller than tile_size we need to pad it to tile_size. data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size, 'constant', {'value': 0}, True, None) slicers = self._internal_get_sliding_window_slicers(data.shape[1:]) if self.perform_everything_on_device and self.device != 'cpu': # we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device try: predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, self.perform_everything_on_device) except RuntimeError: print( 'Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU') empty_cache(self.device) predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False) else: predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, self.perform_everything_on_device) empty_cache(self.device) # revert padding predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])] return predicted_logits def predict_from_files_sequential(self, list_of_lists_or_source_folder: Union[str, List[List[str]]], output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]], save_probabilities: bool = False, overwrite: bool = True, folder_with_segs_from_prev_stage: str = None): """ Just like predict_from_files but doesn't use any multiprocessing. 
Slow, but sometimes necessary """ if isinstance(output_folder_or_list_of_truncated_output_files, str): output_folder = output_folder_or_list_of_truncated_output_files elif isinstance(output_folder_or_list_of_truncated_output_files, list): output_folder = os.path.dirname(output_folder_or_list_of_truncated_output_files[0]) if len(output_folder) == 0: # just a file was given without a folder output_folder = os.path.curdir else: output_folder = None ######################## # let's store the input arguments so that its clear what was used to generate the prediction if output_folder is not None: my_init_kwargs = {} for k in inspect.signature(self.predict_from_files_sequential).parameters.keys(): my_init_kwargs[k] = locals()[k] my_init_kwargs = deepcopy( my_init_kwargs) # let's not unintentionally change anything in-place. Take this as a recursive_fix_for_json_export(my_init_kwargs) save_json(my_init_kwargs, join(output_folder, 'predict_from_raw_data_args.json')) # we need these two if we want to do things with the predictions like for example apply postprocessing save_json(self.dataset_json, join(output_folder, 'dataset.json'), sort_keys=False) save_json(self.plans_manager.plans, join(output_folder, 'plans.json'), sort_keys=False) ####################### # check if we need a prediction from the previous stage if self.configuration_manager.previous_stage_name is not None: assert folder_with_segs_from_prev_stage is not None, \ f'The requested configuration is a cascaded network. It requires the segmentations of the previous ' \ f'stage ({self.configuration_manager.previous_stage_name}) as input. 
Please provide the folder where' \
                f' they are located via folder_with_segs_from_prev_stage'

        # sort out input and output filenames
        list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \
            self._manage_input_and_output_lists(list_of_lists_or_source_folder,
                                                output_folder_or_list_of_truncated_output_files,
                                                folder_with_segs_from_prev_stage, overwrite, 0, 1,
                                                save_probabilities)
        if len(list_of_lists_or_source_folder) == 0:
            return

        label_manager = self.plans_manager.get_label_manager(self.dataset_json)
        preprocessor = self.configuration_manager.preprocessor_class(verbose=self.verbose)

        if output_filename_truncated is None:
            output_filename_truncated = [None] * len(list_of_lists_or_source_folder)
        if seg_from_prev_stage_files is None:
            # one placeholder per input case; calling len() on None would raise a TypeError
            seg_from_prev_stage_files = [None] * len(list_of_lists_or_source_folder)

        ret = []
        for li, of, sps in zip(list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files):
            data, seg, data_properties = preprocessor.run_case(
                li,
                sps,
                self.plans_manager,
                self.configuration_manager,
                self.dataset_json
            )

            print(f'perform_everything_on_device: {self.perform_everything_on_device}')

            prediction = self.predict_logits_from_preprocessed_data(torch.from_numpy(data)).cpu()

            if of is not None:
                export_prediction_from_logits(prediction, data_properties, self.configuration_manager,
                                              self.plans_manager, self.dataset_json, of, save_probabilities)
            else:
                ret.append(convert_predicted_logits_to_segmentation_with_correct_shape(prediction, self.plans_manager,
                                                                                       self.configuration_manager,
                                                                                       self.label_manager,
                                                                                       data_properties,
                                                                                       save_probabilities))

        # clear lru cache
        compute_gaussian.cache_clear()
        # clear device cache
        empty_cache(self.device)
        return ret


def _getDefaultValue(env: str, dtype: type, default: any,) -> any:
    # read a default from the environment; fall back to `default` if the variable is unset or cannot be converted
    try:
        val = dtype(os.environ.get(env) or default)
    except Exception:
        val = default
    return val


def predict_entry_point_modelfolder():
    import argparse
    parser = argparse.ArgumentParser(description='Use this to run inference with 
nnU-Net. This function is used when ' 'you want to manually specify a folder containing a trained nnU-Net ' 'model. This is useful when the nnunet environment variables ' '(nnUNet_results) are not set.') parser.add_argument('-i', type=str, required=True, help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). ' 'File endings must be the same as the training dataset!') parser.add_argument('-o', type=str, required=True, help='Output folder. If it does not exist it will be created. Predicted segmentations will ' 'have the same name as their source images.') parser.add_argument('-m', type=str, required=True, help='Folder in which the trained model is. Must have subfolders fold_X for the different ' 'folds you trained') parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4), help='Specify the folds of the trained model that should be used for prediction. ' 'Default: (0, 1, 2, 3, 4)') parser.add_argument('-step_size', type=float, required=False, default=0.5, help='Step size for sliding window prediction. The larger it is the faster but less accurate ' 'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.') parser.add_argument('--disable_tta', action='store_true', required=False, default=False, help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, ' 'but less accurate inference. Not recommended.') parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have " "to be a good listener/reader.") parser.add_argument('--save_probabilities', action='store_true', help='Set this to export predicted class "probabilities". 
Required if you want to ensemble ' 'multiple configurations.') parser.add_argument('--continue_prediction', '--c', action='store_true', help='Continue an aborted previous prediction (will not overwrite existing files)') parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth', help='Name of the checkpoint you want to use. Default: checkpoint_final.pth') parser.add_argument('-npp', type=int, required=False, default=3, help='Number of processes used for preprocessing. More is not always better. Beware of ' 'out-of-RAM issues. Default: 3') parser.add_argument('-nps', type=int, required=False, default=3, help='Number of processes used for segmentation export. More is not always better. Beware of ' 'out-of-RAM issues. Default: 3') parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None, help='Folder containing the predictions of the previous stage. Required for cascaded models.') parser.add_argument('-device', type=str, default='cuda', required=False, help="Use this to set the device the inference should run with. Available options are 'cuda' " "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!") parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False, help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive ' 'jobs)') parser.add_argument('--not_on_device', action='store_true', required=False, default=False, help="Set this flag to disable perform_everything_on_device. Recommended for large cases that " "occupy more VRAM than available") print( "\n#######################################################################\nPlease cite the following paper " "when using nnU-Net:\n" "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). 
" "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. " "Nature methods, 18(2), 203-211.\n#######################################################################\n") args = parser.parse_args() args.f = [i if i == 'all' else int(i) for i in args.f] if not isdir(args.o): maybe_mkdir_p(args.o) assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.' if args.device == 'cpu': # let's allow torch to use hella threads import multiprocessing torch.set_num_threads(multiprocessing.cpu_count()) device = torch.device('cpu') elif args.device == 'cuda': # multithreading in torch doesn't help nnU-Net if run on GPU torch.set_num_threads(1) torch.set_num_interop_threads(1) device = torch.device('cuda') else: device = torch.device('mps') predictor = nnUNetPredictor(tile_step_size=args.step_size, use_gaussian=True, use_mirroring=not args.disable_tta, perform_everything_on_device=not args.not_on_device, device=device, verbose=args.verbose, allow_tqdm=not args.disable_progress_bar, verbose_preprocessing=args.verbose) predictor.initialize_from_trained_model_folder(args.m, args.f, args.chk) predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities, overwrite=not args.continue_prediction, num_processes_preprocessing=args.npp, num_processes_segmentation_export=args.nps, folder_with_segs_from_prev_stage=args.prev_stage_predictions, num_parts=1, part_id=0) def predict_entry_point(): import argparse parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when ' 'you want to manually specify a folder containing a trained nnU-Net ' 'model. This is useful when the nnunet environment variables ' '(nnUNet_results) are not set.') parser.add_argument('-i', type=str, required=True, help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). 
' 'File endings must be the same as the training dataset!') parser.add_argument('-o', type=str, required=True, help='Output folder. If it does not exist it will be created. Predicted segmentations will ' 'have the same name as their source images.') parser.add_argument('-d', type=str, required=True, help='Dataset with which you would like to predict. You can specify either dataset name or id') parser.add_argument('-p', type=str, required=False, default='nnUNetPlans', help='Plans identifier. Specify the plans in which the desired configuration is located. ' 'Default: nnUNetPlans') parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer', help='What nnU-Net trainer class was used for training? Default: nnUNetTrainer') parser.add_argument('-c', type=str, required=True, help='nnU-Net configuration that should be used for prediction. Config must be located ' 'in the plans specified with -p') parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4), help='Specify the folds of the trained model that should be used for prediction. ' 'Default: (0, 1, 2, 3, 4)') parser.add_argument('-step_size', type=float, required=False, default=0.5, help='Step size for sliding window prediction. The larger it is the faster but less accurate ' 'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.') parser.add_argument('--disable_tta', action='store_true', required=False, default=False, help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, ' 'but less accurate inference. Not recommended.') parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have " "to be a good listener/reader.") parser.add_argument('--save_probabilities', action='store_true', help='Set this to export predicted class "probabilities". 
Required if you want to ensemble '
                             'multiple configurations.')
    parser.add_argument('--continue_prediction', action='store_true',
                        help='Continue an aborted previous prediction (will not overwrite existing files)')
    parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',
                        help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')
    parser.add_argument('-npp', type=int, required=False, default=_getDefaultValue('nnUNet_npp', int, 3),
                        help='Number of processes used for preprocessing. More is not always better. Beware of '
                             'out-of-RAM issues. Default: 3 (or the value of the nnUNet_npp environment variable, '
                             'if set)')
    parser.add_argument('-nps', type=int, required=False, default=_getDefaultValue('nnUNet_nps', int, 3),
                        help='Number of processes used for segmentation export. More is not always better. Beware of '
                             'out-of-RAM issues. Default: 3 (or the value of the nnUNet_nps environment variable, '
                             'if set)')
    parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,
                        help='Folder containing the predictions of the previous stage. Required for cascaded models.')
    parser.add_argument('-num_parts', type=int, required=False, default=1,
                        help='Number of separate nnUNetv2_predict calls that you will be making. Default: 1 (= this one '
                             'call predicts everything)')
    parser.add_argument('-part_id', type=int, required=False, default=0,
                        help='If multiple nnUNetv2_predict calls exist, which one is this? IDs start with 0 and end with '
                             'num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set -num_parts '
                             '5 and use -part_id 0, 1, 2, 3 and 4. Simple, right? Note: You are yourself responsible '
                             'to make these run on separate GPUs! Use CUDA_VISIBLE_DEVICES (google, yo!)')
    parser.add_argument('-device', type=str, default='cuda', required=False,
                        help="Use this to set the device the inference should run with. Available options are 'cuda' "
                             "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
                             "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] 
instead!") parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False, help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive ' 'jobs)') parser.add_argument('--not_on_device', action='store_true', required=False, default=False, help="Set this flag to disable perform_everything_on_device. Recommended for large cases that " "occupy more VRAM than available") print( "\n#######################################################################\nPlease cite the following paper " "when using nnU-Net:\n" "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). " "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. " "Nature methods, 18(2), 203-211.\n#######################################################################\n") args = parser.parse_args() args.f = [i if i == 'all' else int(i) for i in args.f] model_folder = get_output_folder(args.d, args.tr, args.p, args.c) if not isdir(args.o): maybe_mkdir_p(args.o) # slightly passive aggressive haha assert args.part_id < args.num_parts, 'Do you even read the documentation? See nnUNetv2_predict -h.' assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.' 
if args.device == 'cpu': # let's allow torch to use hella threads import multiprocessing torch.set_num_threads(multiprocessing.cpu_count()) device = torch.device('cpu') elif args.device == 'cuda': # multithreading in torch doesn't help nnU-Net if run on GPU torch.set_num_threads(1) torch.set_num_interop_threads(1) device = torch.device('cuda') else: device = torch.device('mps') predictor = nnUNetPredictor(tile_step_size=args.step_size, use_gaussian=True, use_mirroring=not args.disable_tta, perform_everything_on_device=not args.not_on_device, device=device, verbose=args.verbose, verbose_preprocessing=args.verbose, allow_tqdm=not args.disable_progress_bar) predictor.initialize_from_trained_model_folder( model_folder, args.f, checkpoint_name=args.chk ) run_sequential = args.nps == 0 and args.npp == 0 if run_sequential: print("Running in non-multiprocessing mode") predictor.predict_from_files_sequential(args.i, args.o, save_probabilities=args.save_probabilities, overwrite=not args.continue_prediction, folder_with_segs_from_prev_stage=args.prev_stage_predictions) else: predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities, overwrite=not args.continue_prediction, num_processes_preprocessing=args.npp, num_processes_segmentation_export=args.nps, folder_with_segs_from_prev_stage=args.prev_stage_predictions, num_parts=args.num_parts, part_id=args.part_id) # r = predict_from_raw_data(args.i, # args.o, # model_folder, # args.f, # args.step_size, # use_gaussian=True, # use_mirroring=not args.disable_tta, # perform_everything_on_device=True, # verbose=args.verbose, # save_probabilities=args.save_probabilities, # overwrite=not args.continue_prediction, # checkpoint_name=args.chk, # num_processes_preprocessing=args.npp, # num_processes_segmentation_export=args.nps, # folder_with_segs_from_prev_stage=args.prev_stage_predictions, # num_parts=args.num_parts, # part_id=args.part_id, # device=device) if __name__ == '__main__': 
########################## predict a bunch of files from nnunetv2.paths import nnUNet_results, nnUNet_raw predictor = nnUNetPredictor( tile_step_size=0.5, use_gaussian=True, use_mirroring=True, perform_everything_on_device=True, device=torch.device('cuda', 0), verbose=False, verbose_preprocessing=False, allow_tqdm=True ) predictor.initialize_from_trained_model_folder( join(nnUNet_results, 'Dataset004_Hippocampus/nnUNetTrainer_5epochs__nnUNetPlans__3d_fullres'), use_folds=(0,), checkpoint_name='checkpoint_final.pth', ) # predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'), # join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'), # save_probabilities=False, overwrite=False, # num_processes_preprocessing=2, num_processes_segmentation_export=2, # folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0) # # # predict a numpy array # from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO # # img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')]) # ret = predictor.predict_single_npy_array(img, props, None, None, False) # # iterator = predictor.get_data_iterator_from_raw_npy_data([img], None, [props], None, 1) # ret = predictor.predict_from_data_iterator(iterator, False, 1) ret = predictor.predict_from_files_sequential( [['/media/isensee/raw_data/nnUNet_raw/Dataset004_Hippocampus/imagesTs/hippocampus_002_0000.nii.gz'], ['/media/isensee/raw_data/nnUNet_raw/Dataset004_Hippocampus/imagesTs/hippocampus_005_0000.nii.gz']], '/home/isensee/temp/tmp', False, True, None ) ================================================ FILE: nnunetv2/inference/readme.md ================================================ The nnU-Net inference is now much more dynamic than before, allowing you to more seamlessly integrate nnU-Net into your existing workflows. This readme will give you a quick rundown of your options. This is not a complete guide. Look into the code to learn all the details! 
# Preface
In terms of speed, the most efficient inference strategy is the one done by the nnU-Net defaults! Images are read on the fly and preprocessed in background workers. The main process takes the preprocessed images, predicts them and sends the prediction off to another set of background workers which will resize the resulting logits, convert them to a segmentation and export the segmentation.

The default setup is the best option because

1) loading and preprocessing as well as segmentation export are interlaced with the prediction. The main process can focus on communicating with the compute device (i.e. your GPU) and does not have to do any other processing. This uses your resources as well as possible!
2) only the images and segmentations that are currently needed are stored in RAM! Imagine predicting many images and having to store all of them + the results in your system memory!

# nnUNetPredictor
The new nnUNetPredictor class encapsulates the inferencing code and makes it simple to switch between modes. Your code can hold a nnUNetPredictor instance and perform prediction on the fly. Previously this was not possible and each new prediction request resulted in reloading the parameters and reinstantiating the network architecture. Not ideal.

The nnUNetPredictor must be initialized manually! You will want to use the `predictor.initialize_from_trained_model_folder` function for 99% of use cases!
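
The "hold one instance, predict on the fly" idea can be sketched without any nnU-Net code. The example below is only an illustration of the pattern: `StandInPredictor`, `get_predictor` and `handle_request` are hypothetical names that are not part of nnU-Net, and the stand-in class merely counts how often the expensive initialization runs. In real code the factory would build an nnUNetPredictor and call `initialize_from_trained_model_folder` once.

```python
from functools import lru_cache


class StandInPredictor:
    """Hypothetical stand-in for an initialized predictor (NOT the nnU-Net class)."""
    load_count = 0  # counts how often the expensive initialization happened

    def __init__(self, model_folder: str):
        # in real code: build the network and load the checkpoint here
        StandInPredictor.load_count += 1
        self.model_folder = model_folder

    def predict(self, case: str) -> str:
        return f"segmentation of {case}"


@lru_cache(maxsize=1)
def get_predictor(model_folder: str) -> StandInPredictor:
    # the cache keeps one initialized instance alive between requests
    return StandInPredictor(model_folder)


def handle_request(model_folder: str, case: str) -> str:
    # every request reuses the cached instance instead of reloading the weights
    return get_predictor(model_folder).predict(case)


results = [handle_request("/models/liver", c) for c in ("case_a", "case_b", "case_c")]
print(StandInPredictor.load_count)
```

Three requests trigger a single initialization here. Whether a cache, a module-level variable or an attribute of a long-lived service object is the better fit depends on the surrounding application.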
New feature: If you do not specify an output folder / output files then the predicted segmentations will be returned.

## Recommended nnU-Net default: predict from source files

tldr:
- loads images on the fly
- performs preprocessing in background workers
- main process focuses only on making predictions
- results are again given to background workers for resampling and (optional) export

pros:
- best suited for predicting a large number of images
- nicer to your RAM

cons:
- not ideal when single images are to be predicted
- requires images to be present as files

Example:
```python
from nnunetv2.paths import nnUNet_results, nnUNet_raw
import torch
from batchgenerators.utilities.file_and_folder_operations import join
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

# instantiate the nnUNetPredictor
predictor = nnUNetPredictor(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=True,
    perform_everything_on_device=True,
    device=torch.device('cuda', 0),
    verbose=False,
    verbose_preprocessing=False,
    allow_tqdm=True
)
# initializes the network architecture, loads the checkpoint
predictor.initialize_from_trained_model_folder(
    join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'),
    use_folds=(0,),
    checkpoint_name='checkpoint_final.pth',
)
# variant 1: give input and output folders
predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),
                             join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'),
                             save_probabilities=False, overwrite=False,
                             num_processes_preprocessing=2, num_processes_segmentation_export=2,
                             folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
```

Instead of giving input and output folders you can also give concrete files. If you give concrete files, there is no need for the _0000 suffix anymore! This can be useful in situations where you have no control over the filenames!
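
When concrete files are passed, each case becomes one inner list holding the files of all its channels. Concrete files do not need to follow any naming scheme, but if yours still use the `_XXXX` channel suffix, the nested list can be built automatically. The following is a minimal sketch under that assumption: `group_files_by_case` is a hypothetical helper (not part of nnU-Net), the file names are made up, and it assumes the numeric suffix reflects the channel order defined in dataset.json.

```python
import os
import re
from collections import defaultdict
from typing import Dict, List, Tuple


def group_files_by_case(filenames: List[str], file_ending: str = ".nii.gz") -> Tuple[List[str], List[List[str]]]:
    """Group files named <case>_<XXXX><file_ending> into one list per case, sorted by channel index."""
    pattern = re.compile(r"^(?P<case>.+)_(?P<channel>\d{4})" + re.escape(file_ending) + "$")
    cases: Dict[str, List[Tuple[int, str]]] = defaultdict(list)
    for f in filenames:
        m = pattern.match(os.path.basename(f))
        if m is None:
            raise ValueError(f"File does not follow the <case>_<XXXX>{file_ending} pattern: {f}")
        cases[m["case"]].append((int(m["channel"]), f))
    case_ids = sorted(cases)
    # inner lists are ordered by channel index, outer list by case identifier
    nested = [[f for _, f in sorted(cases[c])] for c in case_ids]
    return case_ids, nested


files = [
    "/data/imagesTs/brain_002_0001.nii.gz",
    "/data/imagesTs/brain_001_0000.nii.gz",
    "/data/imagesTs/brain_002_0000.nii.gz",
    "/data/imagesTs/brain_001_0001.nii.gz",
]
case_ids, nested = group_files_by_case(files)
print(case_ids)
print(nested)
```

The returned `nested` list has the shape expected for file-based input, and `case_ids` can be used to derive one output file name per case. If your file names follow a different scheme, build the nested list by hand instead.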
Remember that the files must be given as 'list of lists' where each entry in the outer list is a case to be predicted and the inner list contains all the files belonging to that case. There is just one file for datasets with just one input modality (such as CT) but may be more files for others (such as MRI where there is sometimes T1, T2, Flair etc). IMPORTANT: the order in which the files for each case are given must match the order of the channels as defined in the dataset.json! If you give files as input, you need to give individual output files as output! ```python # variant 2, use list of files as inputs. Note how we use nested lists!!! indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs') outdir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres') predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')], [join(indir, 'liver_142_0000.nii.gz')]], [join(outdir, 'liver_152'), join(outdir, 'liver_142')], save_probabilities=False, overwrite=False, num_processes_preprocessing=2, num_processes_segmentation_export=2, folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0) ``` Did you know? 
If you do not specify output files, the predicted segmentations will be returned: ```python # variant 2.5, returns segmentations indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs') predicted_segmentations = predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')], [join(indir, 'liver_142_0000.nii.gz')]], None, save_probabilities=False, overwrite=True, num_processes_preprocessing=2, num_processes_segmentation_export=2, folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0) ``` ## Prediction from npy arrays tldr: - you give images as a list of npy arrays - performs preprocessing in background workers - main process focuses only on making predictions - results are again given to background workers for resampling and (optional) export pros: - the correct variant for when you have images in RAM already - well suited for predicting multiple images cons: - uses more ram than the default - unsuited for large number of images as all images must be held in RAM ```python from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')]) img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')]) img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')]) img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')]) # we do not set output files so that the segmentations will be returned. You can of course also specify output # files instead (no return value on that case) ret = predictor.predict_from_list_of_npy_arrays([img, img2, img3, img4], None, [props, props2, props3, props4], None, 2, save_probabilities=False, num_processes_segmentation_export=2) ``` ## Predicting a single npy array tldr: - you give one image as npy array - axes ordering must match the corresponding training data. 
The easiest way to achieve that is to use the same I/O class for loading images as was used during nnU-Net preprocessing! You can find that class in your plans.json file under the key "image_reader_writer". If you decide to freestyle, know that the default axis ordering for medical images is the one from SimpleITK. If you load with nibabel, you need to transpose your axes AND your spacing from [x,y,z] to [z,y,x]! - everything is done in the main process: preprocessing, prediction, resampling, (export) - no interlacing, slowest variant! - ONLY USE THIS IF YOU CANNOT GIVE NNUNET MULTIPLE IMAGES AT ONCE FOR SOME REASON pros: - no messing with multiprocessing - no messing with data iterator blabla cons: - slows as heck, yo - never the right choice unless you can only give a single image at a time to nnU-Net ```python # predict a single numpy array (SimpleITKIO) img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')]) ret = predictor.predict_single_npy_array(img, props, None, None, False) # predict a single numpy array (NibabelIO) img, props = NibabelIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')]) ret = predictor.predict_single_npy_array(img, props, None, None, False) # The following IS NOT RECOMMENDED. Use nnunetv2.imageio! 
# nibabel, we need to transpose axes and spacing to match the training axes ordering for the nnU-Net default:
import nibabel as nib
import numpy as np
img_nii = nib.load('Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')
img = np.asanyarray(img_nii.dataobj).transpose([2, 1, 0])  # reverse axis order to match SITK
props = {'spacing': img_nii.header.get_zooms()[::-1]}  # reverse axis order to match SITK
ret = predictor.predict_single_npy_array(img, props, None, None, False)
```

## Predicting with a custom data iterator
tldr:
- highly flexible
- not for newbies

pros:
- you can do everything yourself
- you have all the freedom you want
- really fast if you remember to use multiprocessing in your iterator

cons:
- you need to do everything yourself
- harder than you might think

```python
img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])
img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])
img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])
img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])

# each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properties' keys!
# If 'ofile' is None, the result will be returned instead of written to a file
# the iterator is responsible for performing the correct preprocessing!
# note how the iterator here does not use multiprocessing -> preprocessing will be done in the main thread!
# take a look at the default iterators for predict_from_files and predict_from_list_of_npy_arrays
# (they both use predictor.predict_from_data_iterator) for inspiration!
def my_iterator(list_of_input_arrs, list_of_input_props): preprocessor = predictor.configuration_manager.preprocessor_class(verbose=predictor.verbose) for a, p in zip(list_of_input_arrs, list_of_input_props): data, seg = preprocessor.run_case_npy(a, None, p, predictor.plans_manager, predictor.configuration_manager, predictor.dataset_json) yield {'data': torch.from_numpy(data).contiguous().pin_memory(), 'data_properties': p, 'ofile': None} ret = predictor.predict_from_data_iterator(my_iterator([img, img2, img3, img4], [props, props2, props3, props4]), save_probabilities=False, num_processes_segmentation_export=3) ``` ================================================ FILE: nnunetv2/inference/sliding_window_prediction.py ================================================ from functools import lru_cache import numpy as np import torch from typing import Union, Tuple, List from acvl_utils.cropping_and_padding.padding import pad_nd_image from scipy.ndimage import gaussian_filter @lru_cache(maxsize=2) def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale: float = 1. / 8, value_scaling_factor: float = 1, dtype=torch.float16, device=torch.device('cuda', 0)) \ -> torch.Tensor: tmp = np.zeros(tile_size) center_coords = [i // 2 for i in tile_size] sigmas = [i * sigma_scale for i in tile_size] tmp[tuple(center_coords)] = 1 gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0) gaussian_importance_map = torch.from_numpy(gaussian_importance_map) gaussian_importance_map /= (torch.max(gaussian_importance_map) / value_scaling_factor) gaussian_importance_map = gaussian_importance_map.to(device=device, dtype=dtype) # gaussian_importance_map cannot be 0, otherwise we may end up with nans! 
    mask = gaussian_importance_map == 0
    gaussian_importance_map[mask] = torch.min(gaussian_importance_map[~mask])
    return gaussian_importance_map


def compute_steps_for_sliding_window(image_size: Tuple[int, ...], tile_size: Tuple[int, ...], tile_step_size: float) -> \
        List[List[int]]:
    # note: all(...) is required here. Asserting on a non-empty list is always true and would never fail
    assert all(i >= j for i, j in zip(image_size, tile_size)), "image size must be as large or larger than patch_size"
    assert 0 < tile_step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1'

    # our step width is patch_size*step_size at most, but can be narrower. For example if we have image size of
    # 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps starting at coordinate 0, 23, 46
    target_step_sizes_in_voxels = [i * tile_step_size for i in tile_size]

    num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, tile_size)]

    steps = []
    for dim in range(len(tile_size)):
        # the highest step value for this dimension is
        max_step_value = image_size[dim] - tile_size[dim]
        if num_steps[dim] > 1:
            actual_step_size = max_step_value / (num_steps[dim] - 1)
        else:
            actual_step_size = 99999999999  # does not matter because there is only one step at 0

        steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])]

        steps.append(steps_here)

    return steps


if __name__ == '__main__':
    a = torch.rand((4, 2, 32, 23))
    a_npy = a.numpy()

    a_padded = pad_nd_image(a, new_shape=(48, 27))
    a_npy_padded = pad_nd_image(a_npy, new_shape=(48, 27))
    assert all([i == j for i, j in zip(a_padded.shape, (4, 2, 48, 27))])
    assert all([i == j for i, j in zip(a_npy_padded.shape, (4, 2, 48, 27))])
    assert np.all(a_padded.numpy() == a_npy_padded)

================================================
FILE: nnunetv2/model_sharing/__init__.py
================================================

================================================
FILE: nnunetv2/model_sharing/entry_points.py
================================================
from
nnunetv2.model_sharing.model_download import download_and_install_from_url
from nnunetv2.model_sharing.model_export import export_pretrained_model
from nnunetv2.model_sharing.model_import import install_model_from_zip_file


def print_license_warning():
    print('')
    print('######################################################')
    print('!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!')
    print('######################################################')
    print("Using the pretrained model weights is subject to the license of the dataset they were trained on. Some "
          "allow commercial use, others don't. It is your responsibility to make sure you use them appropriately! Use "
          "nnUNet_print_pretrained_model_info(task_name) to see a summary of the dataset and where to find its license!")
    print('######################################################')
    print('')


def download_by_url():
    import argparse
    parser = argparse.ArgumentParser(
        description="Use this to download pretrained models. This script is intended to download models via url only. "
                    "CAREFUL: This script will overwrite "
                    "existing models (if they share the same trainer class and plans as "
                    "the pretrained model).")
    parser.add_argument("url", type=str, help='URL of the pretrained model')
    args = parser.parse_args()
    url = args.url
    download_and_install_from_url(url)


def install_from_zip_entry_point():
    import argparse
    parser = argparse.ArgumentParser(
        description="Use this to install a zip file containing a pretrained model.")
    parser.add_argument("zip", type=str, help='zip file')
    args = parser.parse_args()
    zip = args.zip
    install_model_from_zip_file(zip)


def export_pretrained_model_entry():
    import argparse
    parser = argparse.ArgumentParser(
        description="Use this to export a trained model as a zip file.")
    parser.add_argument('-d', type=str, required=True, help='Dataset name or id')
    parser.add_argument('-o', type=str, required=True, help='Output file name')
    parser.add_argument('-c', nargs='+', type=str, required=False,
                        default=('3d_lowres', '3d_fullres', '2d', '3d_cascade_fullres'),
                        help="List of configuration names")
    parser.add_argument('-tr', required=False, type=str, default='nnUNetTrainer', help='Trainer class')
    parser.add_argument('-p', required=False, type=str, default='nnUNetPlans', help='plans identifier')
    parser.add_argument('-f', required=False, nargs='+', type=str, default=(0, 1, 2, 3, 4), help='list of fold ids')
    parser.add_argument('-chk', required=False, nargs='+', type=str, default=('checkpoint_final.pth', ),
                        help='List of checkpoint names to export. 
Default: checkpoint_final.pth')
    # store_true is needed here: with action='store_false' and default=False the flag could never change the value
    parser.add_argument('--not_strict', action='store_true', default=False, required=False,
                        help='Set this to allow missing folds and/or configurations')
    parser.add_argument('--exp_cv_preds', action='store_true', required=False,
                        help='Set this to export the cross-validation predictions as well')
    args = parser.parse_args()

    export_pretrained_model(dataset_name_or_id=args.d, output_file=args.o, configurations=args.c, trainer=args.tr,
                            plans_identifier=args.p, folds=args.f, strict=not args.not_strict,
                            save_checkpoints=args.chk, export_crossval_predictions=args.exp_cv_preds)

================================================
FILE: nnunetv2/model_sharing/model_download.py
================================================
from typing import Optional

import requests
from batchgenerators.utilities.file_and_folder_operations import *
from time import time
from nnunetv2.model_sharing.model_import import install_model_from_zip_file
from nnunetv2.paths import nnUNet_results
from tqdm import tqdm


def download_and_install_from_url(url):
    assert nnUNet_results is not None, "Cannot install model because nnUNet_results is not " \
                                       "set (nnUNet_results missing as environment variable, see " \
                                       "documentation/setting_up_paths.md)"
    print('Downloading pretrained model from url:', url)
    import http.client
    http.client.HTTPConnection._http_vsn = 10
    http.client.HTTPConnection._http_vsn_str = 'HTTP/1.0'

    import os
    home = os.path.expanduser('~')
    random_number = int(time() * 1e7)
    tempfile = join(home, f'.nnunetdownload_{str(random_number)}')

    try:
        download_file(url=url, local_filename=tempfile, chunk_size=8192 * 16)
        print("Download finished. 
Extracting...") install_model_from_zip_file(tempfile) print("Done") except Exception as e: raise e finally: if isfile(tempfile): os.remove(tempfile) def download_file(url: str, local_filename: str, chunk_size: Optional[int] = 8192 * 16) -> str: # borrowed from https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests # NOTE the stream=True parameter below with requests.get(url, stream=True, timeout=100) as r: r.raise_for_status() with tqdm.wrapattr(open(local_filename, 'wb'), "write", total=int(r.headers.get("Content-Length"))) as f: for chunk in r.iter_content(chunk_size=chunk_size): f.write(chunk) return local_filename ================================================ FILE: nnunetv2/model_sharing/model_export.py ================================================ import zipfile from nnunetv2.utilities.file_path_utilities import * def export_pretrained_model(dataset_name_or_id: Union[int, str], output_file: str, configurations: Tuple[str] = ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), trainer: str = 'nnUNetTrainer', plans_identifier: str = 'nnUNetPlans', folds: Tuple[int, ...] = (0, 1, 2, 3, 4), strict: bool = True, save_checkpoints: Tuple[str, ...] 
= ('checkpoint_final.pth',), export_crossval_predictions: bool = False) -> None: dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) with(zipfile.ZipFile(output_file, 'w', zipfile.ZIP_DEFLATED)) as zipf: for c in configurations: print(f"Configuration {c}") trainer_output_dir = get_output_folder(dataset_name, trainer, plans_identifier, c) if not isdir(trainer_output_dir): if strict: raise RuntimeError(f"{dataset_name} is missing the trained model of configuration {c}") else: continue expected_fold_folder = [f"fold_{i}" if i != 'all' else 'fold_all' for i in folds] assert all([isdir(join(trainer_output_dir, i)) for i in expected_fold_folder]), \ f"not all requested folds are present; {dataset_name} {c}; requested folds: {folds}" assert isfile(join(trainer_output_dir, "plans.json")), f"plans.json missing, {dataset_name} {c}" for fold_folder in expected_fold_folder: print(f"Exporting {fold_folder}") # debug.json, does not exist yet source_file = join(trainer_output_dir, fold_folder, "debug.json") if isfile(source_file): zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) # all requested checkpoints for chk in save_checkpoints: source_file = join(trainer_output_dir, fold_folder, chk) zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) # progress.png source_file = join(trainer_output_dir, fold_folder, "progress.png") zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) # if it exists, network architecture.png source_file = join(trainer_output_dir, fold_folder, "network_architecture.pdf") if isfile(source_file): zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) # validation folder with all predicted segmentations etc if export_crossval_predictions: source_folder = join(trainer_output_dir, fold_folder, "validation") files = [i for i in subfiles(source_folder, join=False) if not i.endswith('.npz') and not i.endswith('.pkl')] for f in files: zipf.write(join(source_folder, f), 
os.path.relpath(join(source_folder, f), nnUNet_results)) # just the summary.json file from the validation else: source_file = join(trainer_output_dir, fold_folder, "validation", "summary.json") zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) source_folder = join(trainer_output_dir, f'crossval_results_folds_{folds_tuple_to_string(folds)}') if isdir(source_folder): if export_crossval_predictions: source_files = subfiles(source_folder, join=True) else: source_files = [ join(trainer_output_dir, f'crossval_results_folds_{folds_tuple_to_string(folds)}', i) for i in ['summary.json', 'postprocessing.pkl', 'postprocessing.json'] ] for s in source_files: if isfile(s): zipf.write(s, os.path.relpath(s, nnUNet_results)) # plans source_file = join(trainer_output_dir, "plans.json") zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) # fingerprint source_file = join(trainer_output_dir, "dataset_fingerprint.json") zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) # dataset source_file = join(trainer_output_dir, "dataset.json") zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) ensemble_dir = join(nnUNet_results, dataset_name, 'ensembles') if not isdir(ensemble_dir): print("No ensemble directory found for task", dataset_name_or_id) return subd = subdirs(ensemble_dir, join=False) # figure out whether the models in the ensemble are all within the exported models here for ens in subd: identifiers, folds = convert_ensemble_folder_to_model_identifiers_and_folds(ens) ok = True for i in identifiers: tr, pl, c = convert_identifier_to_trainer_plans_config(i) if tr == trainer and pl == plans_identifier and c in configurations: pass else: ok = False if ok: print(f'found matching ensemble: {ens}') source_folder = join(ensemble_dir, ens) if export_crossval_predictions: source_files = subfiles(source_folder, join=True) else: source_files = [ join(source_folder, i) for i in ['summary.json', 'postprocessing.pkl', 
'postprocessing.json'] if isfile(join(source_folder, i)) ] for s in source_files: zipf.write(s, os.path.relpath(s, nnUNet_results)) inference_information_file = join(nnUNet_results, dataset_name, 'inference_information.json') if isfile(inference_information_file): zipf.write(inference_information_file, os.path.relpath(inference_information_file, nnUNet_results)) inference_information_txt_file = join(nnUNet_results, dataset_name, 'inference_information.txt') if isfile(inference_information_txt_file): zipf.write(inference_information_txt_file, os.path.relpath(inference_information_txt_file, nnUNet_results)) print('Done') if __name__ == '__main__': export_pretrained_model(2, '/home/fabian/temp/dataset2.zip', strict=False, export_crossval_predictions=True, folds=(0, )) ================================================ FILE: nnunetv2/model_sharing/model_import.py ================================================ import zipfile from nnunetv2.paths import nnUNet_results def install_model_from_zip_file(zip_file: str): with zipfile.ZipFile(zip_file, 'r') as zip_ref: zip_ref.extractall(nnUNet_results) ================================================ FILE: nnunetv2/paths.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os

"""
PLEASE READ documentation/setting_up_paths.md FOR INFORMATION ON HOW TO SET THIS UP
"""

nnUNet_raw = os.environ.get('nnUNet_raw')
nnUNet_preprocessed = os.environ.get('nnUNet_preprocessed')
nnUNet_results = os.environ.get('nnUNet_results')

if nnUNet_raw is None:
    print("nnUNet_raw is not defined and nnU-Net can only be used on data for which preprocessed files "
          "are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like "
          "this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set "
          "this up properly.")

if nnUNet_preprocessed is None:
    print("nnUNet_preprocessed is not defined and nnU-Net cannot be used for preprocessing "
          "or training. If this is not intended, please read documentation/setting_up_paths.md for information on how "
          "to set this up.")

if nnUNet_results is None:
    print("nnUNet_results is not defined and nnU-Net cannot be used for training or "
          "inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information "
          "on how to set this up.")

================================================
FILE: nnunetv2/postprocessing/__init__.py
================================================

================================================
FILE: nnunetv2/postprocessing/remove_connected_components.py
================================================
import argparse
import multiprocessing
import shutil
from typing import Union, Tuple, List, Callable

import numpy as np
from acvl_utils.morphology.morphology_helper import remove_all_but_largest_component
from batchgenerators.utilities.file_and_folder_operations import load_json, subfiles, maybe_mkdir_p, join, isfile, \
    isdir, save_pickle, load_pickle, save_json
from nnunetv2.configuration import default_num_processes
from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results
from nnunetv2.evaluation.evaluate_predictions import region_or_label_to_mask, 
compute_metrics_on_folder, \ load_summary_json, label_or_region_to_key from nnunetv2.imageio.base_reader_writer import BaseReaderWriter from nnunetv2.paths import nnUNet_raw from nnunetv2.utilities.file_path_utilities import folds_tuple_to_string from nnunetv2.utilities.json_export import recursive_fix_for_json_export from nnunetv2.utilities.plans_handling.plans_handler import PlansManager def remove_all_but_largest_component_from_segmentation(segmentation: np.ndarray, labels_or_regions: Union[int, Tuple[int, ...], List[Union[int, Tuple[int, ...]]]], background_label: int = 0) -> np.ndarray: mask = np.zeros_like(segmentation, dtype=bool) if not isinstance(labels_or_regions, list): labels_or_regions = [labels_or_regions] for l_or_r in labels_or_regions: mask |= region_or_label_to_mask(segmentation, l_or_r) mask_keep = remove_all_but_largest_component(mask) ret = np.copy(segmentation) # do not modify the input! ret[mask & ~mask_keep] = background_label return ret def apply_postprocessing(segmentation: np.ndarray, pp_fns: List[Callable], pp_fn_kwargs: List[dict]): for fn, kwargs in zip(pp_fns, pp_fn_kwargs): segmentation = fn(segmentation, **kwargs) return segmentation def load_postprocess_save(segmentation_file: str, output_fname: str, image_reader_writer: BaseReaderWriter, pp_fns: List[Callable], pp_fn_kwargs: List[dict]): seg, props = image_reader_writer.read_seg(segmentation_file) seg = apply_postprocessing(seg[0], pp_fns, pp_fn_kwargs) image_reader_writer.write_seg(seg, output_fname, props) def determine_postprocessing(folder_predictions: str, folder_ref: str, plans_file_or_dict: Union[str, dict], dataset_json_file_or_dict: Union[str, dict], num_processes: int = default_num_processes, keep_postprocessed_files: bool = True): """ Determines nnUNet postprocessing. Its output is a postprocessing.pkl file in folder_predictions which can be used with apply_postprocessing_to_folder. Postprocessed files are saved in folder_predictions/postprocessed. 
    Set keep_postprocessed_files=False to delete these files after this function is done (temp files will be created
    and deleted regardless).

    If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for them in folder_predictions
    """
    output_folder = join(folder_predictions, 'postprocessed')

    if plans_file_or_dict is None:
        expected_plans_file = join(folder_predictions, 'plans.json')
        if not isfile(expected_plans_file):
            raise RuntimeError(f"Expected plans file missing: {expected_plans_file}. The plans file should have been "
                               f"created while running nnUNetv2_predict. Sadge.")
        plans_file_or_dict = load_json(expected_plans_file)
    plans_manager = PlansManager(plans_file_or_dict)

    if dataset_json_file_or_dict is None:
        expected_dataset_json_file = join(folder_predictions, 'dataset.json')
        if not isfile(expected_dataset_json_file):
            raise RuntimeError(
                f"Expected dataset.json file missing: {expected_dataset_json_file}. The dataset.json file should have "
                f"been created while running nnUNetv2_predict. Sadge.")
        dataset_json_file_or_dict = load_json(expected_dataset_json_file)

    if not isinstance(dataset_json_file_or_dict, dict):
        dataset_json = load_json(dataset_json_file_or_dict)
    else:
        dataset_json = dataset_json_file_or_dict

    rw = plans_manager.image_reader_writer_class()
    label_manager = plans_manager.get_label_manager(dataset_json)
    labels_or_regions = label_manager.foreground_regions if label_manager.has_regions else label_manager.foreground_labels

    predicted_files = subfiles(folder_predictions, suffix=dataset_json['file_ending'], join=False)
    ref_files = subfiles(folder_ref, suffix=dataset_json['file_ending'], join=False)
    # we should print a warning if not all files from folder_ref are present in folder_predictions
    if not all([i in predicted_files for i in ref_files]):
        print(f'WARNING: Not all files in folder_ref were found in folder_predictions. 
Determining postprocessing ' f'should always be done on the entire dataset!') # before we start we should evaluate the images in the source folder if not isfile(join(folder_predictions, 'summary.json')): compute_metrics_on_folder(folder_ref, folder_predictions, join(folder_predictions, 'summary.json'), rw, dataset_json['file_ending'], labels_or_regions, label_manager.ignore_label, num_processes) # we save the postprocessing functions in here pp_fns = [] pp_fn_kwargs = [] # pool party! with multiprocessing.get_context("spawn").Pool(num_processes) as pool: # now let's see whether removing all but the largest foreground region improves the scores output_here = join(output_folder, 'temp', 'keep_largest_fg') maybe_mkdir_p(output_here) pp_fn = remove_all_but_largest_component_from_segmentation kwargs = { 'labels_or_regions': label_manager.foreground_labels, } pool.starmap( load_postprocess_save, zip( [join(folder_predictions, i) for i in predicted_files], [join(output_here, i) for i in predicted_files], [rw] * len(predicted_files), [[pp_fn]] * len(predicted_files), [[kwargs]] * len(predicted_files) ) ) compute_metrics_on_folder(folder_ref, output_here, join(output_here, 'summary.json'), rw, dataset_json['file_ending'], labels_or_regions, label_manager.ignore_label, num_processes) # now we need to figure out if doing this improved the dice scores. We will implement that defensively in so far # that if a single class got worse as a result we won't do this. 
We can change this in the future but right now I # prefer to do it this way baseline_results = load_summary_json(join(folder_predictions, 'summary.json')) pp_results = load_summary_json(join(output_here, 'summary.json')) do_this = pp_results['foreground_mean']['Dice'] > baseline_results['foreground_mean']['Dice'] if do_this: for class_id in pp_results['mean'].keys(): if pp_results['mean'][class_id]['Dice'] < baseline_results['mean'][class_id]['Dice']: do_this = False break if do_this: print(f'Results were improved by removing all but the largest foreground region. ' f'Mean dice before: {round(baseline_results["foreground_mean"]["Dice"], 5)} ' f'after: {round(pp_results["foreground_mean"]["Dice"], 5)}') source = output_here pp_fns.append(pp_fn) pp_fn_kwargs.append(kwargs) else: print(f'Removing all but the largest foreground region did not improve results!') source = folder_predictions # in the old nnU-Net we could just apply all-but-largest component removal to all classes at the same time and # then evaluate for each class whether this improved results. This is no longer possible because we now support # region-based predictions and regions can overlap, causing interactions # in principle the order with which the postprocessing is applied to the regions matter as well and should be # investigated, but due to some things that I am too lazy to explain right now it's going to be alright (I think) # to stick to the order in which they are declared in dataset.json (if you want to think about it then think about # region_class_order) # 2023_02_06: I hate myself for the comment above. 
Thanks past me if len(labels_or_regions) > 1: for label_or_region in labels_or_regions: pp_fn = remove_all_but_largest_component_from_segmentation kwargs = { 'labels_or_regions': label_or_region, } output_here = join(output_folder, 'temp', 'keep_largest_perClassOrRegion') maybe_mkdir_p(output_here) pool.starmap( load_postprocess_save, zip( [join(source, i) for i in predicted_files], [join(output_here, i) for i in predicted_files], [rw] * len(predicted_files), [[pp_fn]] * len(predicted_files), [[kwargs]] * len(predicted_files) ) ) compute_metrics_on_folder(folder_ref, output_here, join(output_here, 'summary.json'), rw, dataset_json['file_ending'], labels_or_regions, label_manager.ignore_label, num_processes) baseline_results = load_summary_json(join(source, 'summary.json')) pp_results = load_summary_json(join(output_here, 'summary.json')) do_this = pp_results['mean'][label_or_region]['Dice'] > baseline_results['mean'][label_or_region]['Dice'] if do_this: print(f'Results were improved by removing all but the largest component for {label_or_region}. ' f'Dice before: {round(baseline_results["mean"][label_or_region]["Dice"], 5)} ' f'after: {round(pp_results["mean"][label_or_region]["Dice"], 5)}') if isdir(join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest')): shutil.rmtree(join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest')) shutil.move(output_here, join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest'), ) source = join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest') pp_fns.append(pp_fn) pp_fn_kwargs.append(kwargs) else: print(f'Removing all but the largest component for {label_or_region} did not improve results! 
' f'Dice before: {round(baseline_results["mean"][label_or_region]["Dice"], 5)} ' f'after: {round(pp_results["mean"][label_or_region]["Dice"], 5)}') [shutil.copy(join(source, i), join(output_folder, i)) for i in subfiles(source, join=False)] save_pickle((pp_fns, pp_fn_kwargs), join(folder_predictions, 'postprocessing.pkl')) baseline_results = load_summary_json(join(folder_predictions, 'summary.json')) final_results = load_summary_json(join(output_folder, 'summary.json')) tmp = { 'input_folder': {i: baseline_results[i] for i in ['foreground_mean', 'mean']}, 'postprocessed': {i: final_results[i] for i in ['foreground_mean', 'mean']}, 'postprocessing_fns': [i.__name__ for i in pp_fns], 'postprocessing_kwargs': pp_fn_kwargs, } # json is very annoying. Can't handle tuples as dict keys. tmp['input_folder']['mean'] = {label_or_region_to_key(k): tmp['input_folder']['mean'][k] for k in tmp['input_folder']['mean'].keys()} tmp['postprocessed']['mean'] = {label_or_region_to_key(k): tmp['postprocessed']['mean'][k] for k in tmp['postprocessed']['mean'].keys()} # did I already say that I hate json? "TypeError: Object of type int64 is not JSON serializable" recursive_fix_for_json_export(tmp) save_json(tmp, join(folder_predictions, 'postprocessing.json')) shutil.rmtree(join(output_folder, 'temp')) if not keep_postprocessed_files: shutil.rmtree(output_folder) return pp_fns, pp_fn_kwargs def apply_postprocessing_to_folder(input_folder: str, output_folder: str, pp_fns: List[Callable], pp_fn_kwargs: List[dict], plans_file_or_dict: Union[str, dict] = None, dataset_json_file_or_dict: Union[str, dict] = None, num_processes=8) -> None: """ If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for them in input_folder """ if plans_file_or_dict is None: expected_plans_file = join(input_folder, 'plans.json') if not isfile(expected_plans_file): raise RuntimeError(f"Expected plans file missing: {expected_plans_file}. 
The plans file should have been " f"created while running nnUNetv2_predict. Sadge. If the folder you want to apply " f"postprocessing to was created from an ensemble then just specify one of the " f"plans files of the ensemble members in plans_file_or_dict") plans_file_or_dict = load_json(expected_plans_file) plans_manager = PlansManager(plans_file_or_dict) if dataset_json_file_or_dict is None: expected_dataset_json_file = join(input_folder, 'dataset.json') if not isfile(expected_dataset_json_file): raise RuntimeError( f"Expected dataset.json file missing: {expected_dataset_json_file}. The dataset.json should have been " f"copied while running nnUNetv2_predict/nnUNetv2_ensemble. Sadge.") dataset_json_file_or_dict = load_json(expected_dataset_json_file) if not isinstance(dataset_json_file_or_dict, dict): dataset_json = load_json(dataset_json_file_or_dict) else: dataset_json = dataset_json_file_or_dict rw = plans_manager.image_reader_writer_class() maybe_mkdir_p(output_folder) with multiprocessing.get_context("spawn").Pool(num_processes) as p: files = subfiles(input_folder, suffix=dataset_json['file_ending'], join=False) _ = p.starmap(load_postprocess_save, zip( [join(input_folder, i) for i in files], [join(output_folder, i) for i in files], [rw] * len(files), [pp_fns] * len(files), [pp_fn_kwargs] * len(files) ) ) def entry_point_determine_postprocessing_folder(): parser = argparse.ArgumentParser('Writes postprocessing.pkl and postprocessing.json in input_folder.') parser.add_argument('-i', type=str, required=True, help='Input folder') parser.add_argument('-ref', type=str, required=True, help='Folder with gt labels') parser.add_argument('-plans_json', type=str, required=False, default=None, help="plans file to use. If not specified we will look for the plans.json file in the " "input folder (input_folder/plans.json)") parser.add_argument('-dataset_json', type=str, required=False, default=None, help="dataset.json file to use. 
If not specified we will look for the dataset.json file in the " "input folder (input_folder/dataset.json)") parser.add_argument('-np', type=int, required=False, default=default_num_processes, help=f"number of processes to use. Default: {default_num_processes}") parser.add_argument('--remove_postprocessed', action='store_true', required=False, help='set this if you don\'t want to keep the postprocessed files') args = parser.parse_args() determine_postprocessing(args.i, args.ref, args.plans_json, args.dataset_json, args.np, not args.remove_postprocessed) def entry_point_apply_postprocessing(): parser = argparse.ArgumentParser('Applies postprocessing specified in pp_pkl_file to input folder.') parser.add_argument('-i', type=str, required=True, help='Input folder') parser.add_argument('-o', type=str, required=True, help='Output folder') parser.add_argument('-pp_pkl_file', type=str, required=True, help='postprocessing.pkl file') parser.add_argument('-np', type=int, required=False, default=default_num_processes, help=f"number of processes to use. Default: {default_num_processes}") parser.add_argument('-plans_json', type=str, required=False, default=None, help="plans file to use. If not specified we will look for the plans.json file in the " "input folder (input_folder/plans.json)") parser.add_argument('-dataset_json', type=str, required=False, default=None, help="dataset.json file to use. 
If not specified we will look for the dataset.json file in the " "input folder (input_folder/dataset.json)") args = parser.parse_args() pp_fns, pp_fn_kwargs = load_pickle(args.pp_pkl_file) apply_postprocessing_to_folder(args.i, args.o, pp_fns, pp_fn_kwargs, args.plans_json, args.dataset_json, args.np) if __name__ == '__main__': trained_model_folder = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetTrainer__nnUNetPlans__3d_fullres' labelstr = join(nnUNet_raw, 'Dataset004_Hippocampus', 'labelsTr') plans_manager = PlansManager(join(trained_model_folder, 'plans.json')) dataset_json = load_json(join(trained_model_folder, 'dataset.json')) folds = (0, 1, 2, 3, 4) label_manager = plans_manager.get_label_manager(dataset_json) merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(folds)}') accumulate_cv_results(trained_model_folder, merged_output_folder, folds, 8, False) fns, kwargs = determine_postprocessing(merged_output_folder, labelstr, plans_manager.plans, dataset_json, 8, keep_postprocessed_files=True) save_pickle((fns, kwargs), join(trained_model_folder, 'postprocessing.pkl')) fns, kwargs = load_pickle(join(trained_model_folder, 'postprocessing.pkl')) apply_postprocessing_to_folder(merged_output_folder, merged_output_folder + '_pp', fns, kwargs, plans_manager.plans, dataset_json, 8) compute_metrics_on_folder(labelstr, merged_output_folder + '_pp', join(merged_output_folder + '_pp', 'summary.json'), plans_manager.image_reader_writer_class(), dataset_json['file_ending'], label_manager.foreground_regions if label_manager.has_regions else label_manager.foreground_labels, label_manager.ignore_label, 8) ================================================ FILE: nnunetv2/preprocessing/__init__.py ================================================ ================================================ FILE: nnunetv2/preprocessing/cropping/__init__.py ================================================ 
================================================ FILE: nnunetv2/preprocessing/cropping/cropping.py ================================================ import numpy as np from scipy.ndimage import binary_fill_holes from acvl_utils.cropping_and_padding.bounding_boxes import get_bbox_from_mask, bounding_box_to_slice def create_nonzero_mask(data): """ :param data: :return: the mask is True where the data is nonzero """ assert data.ndim in (3, 4), "data must have shape (C, X, Y, Z) or shape (C, X, Y)" nonzero_mask = data[0] != 0 for c in range(1, data.shape[0]): nonzero_mask |= data[c] != 0 return binary_fill_holes(nonzero_mask) def crop_to_nonzero(data, seg=None, nonzero_label=-1): """ :param data: :param seg: :param nonzero_label: this will be written into the segmentation map :return: """ nonzero_mask = create_nonzero_mask(data) bbox = get_bbox_from_mask(nonzero_mask) slicer = bounding_box_to_slice(bbox) nonzero_mask = nonzero_mask[slicer][None] slicer = (slice(None), ) + slicer data = data[slicer] if seg is not None: seg = seg[slicer] seg[(seg == 0) & (~nonzero_mask)] = nonzero_label else: seg = np.where(nonzero_mask, np.int8(0), np.int8(nonzero_label)) return data, seg, bbox ================================================ FILE: nnunetv2/preprocessing/normalization/__init__.py ================================================ ================================================ FILE: nnunetv2/preprocessing/normalization/default_normalization_schemes.py ================================================ from abc import ABC, abstractmethod from typing import Type import numpy as np from numpy import number class ImageNormalization(ABC): leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = None def __init__(self, use_mask_for_norm: bool = None, intensityproperties: dict = None, target_dtype: Type[number] = np.float32): assert use_mask_for_norm is None or isinstance(use_mask_for_norm, bool) self.use_mask_for_norm = use_mask_for_norm assert 
isinstance(intensityproperties, dict) self.intensityproperties = intensityproperties self.target_dtype = target_dtype @abstractmethod def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: """ Image and seg must have the same shape. Seg is not always used """ pass class ZScoreNormalization(ImageNormalization): leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = True def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: """ here seg is used to store the zero valued region. The value for that region in the segmentation is -1 by default. """ eps = 1e-8 if not self.target_dtype == np.float16 else 1e-4 image = image.astype(self.target_dtype, copy=False) if self.use_mask_for_norm: assert seg is not None, ("use_mask_for_norm is set, please provide a mask for the nonzero areas of the " "image via seg. The mask will be computed as `mask = seg >= 0`. You can use " "create_nonzero_mask from nnunetv2/preprocessing/cropping") if seg is not None and self.use_mask_for_norm: # negative values in the segmentation encode the 'outside' region (think zero values around the brain as # in BraTS). We want to run the normalization only in the brain region, so we need to mask the image. # The default nnU-net sets use_mask_for_norm to True if cropping to the nonzero region substantially # reduced the image size. 
mask = seg >= 0 mean = image[mask].mean() std = image[mask].std() image[mask] = (image[mask] - mean) / (max(std, eps)) else: mean = image.mean() std = image.std() image -= mean image /= (max(std, eps)) return image class CTNormalization(ImageNormalization): leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: assert self.intensityproperties is not None, "CTNormalization requires intensity properties" eps = 1e-8 if not self.target_dtype == np.float16 else 1e-4 mean_intensity = self.intensityproperties['mean'] std_intensity = self.intensityproperties['std'] lower_bound = self.intensityproperties['percentile_00_5'] upper_bound = self.intensityproperties['percentile_99_5'] image = image.astype(self.target_dtype, copy=False) np.clip(image, lower_bound, upper_bound, out=image) image -= mean_intensity image /= max(std_intensity, eps) return image class NoNormalization(ImageNormalization): leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: return image.astype(self.target_dtype, copy=False) class RescaleTo01Normalization(ImageNormalization): leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: eps = 1e-8 if not self.target_dtype == np.float16 else 1e-4 image = image.astype(self.target_dtype, copy=False) image -= image.min() image /= np.clip(image.max(), a_min=eps, a_max=None) return image class RGBTo01Normalization(ImageNormalization): leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: assert image.min() >= 0, "RGB images are uint 8, for whatever reason I found pixel values smaller than 0. 
" \ "Your images do not seem to be RGB images" assert image.max() <= 255, "RGB images are uint 8, for whatever reason I found pixel values greater than 255" \ ". Your images do not seem to be RGB images" image = image.astype(self.target_dtype, copy=False) image /= 255. return image ================================================ FILE: nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py ================================================ from typing import Type from nnunetv2.preprocessing.normalization.default_normalization_schemes import CTNormalization, NoNormalization, \ ZScoreNormalization, RescaleTo01Normalization, RGBTo01Normalization, ImageNormalization channel_name_to_normalization_mapping = { 'ct': CTNormalization, 'nonorm': NoNormalization, 'zscore': ZScoreNormalization, 'rescale_to_0_1': RescaleTo01Normalization, 'rgb_to_0_1': RGBTo01Normalization } def get_normalization_scheme(channel_name: str) -> Type[ImageNormalization]: """ If we find the channel_name in channel_name_to_normalization_mapping return the corresponding normalization. If it is not found, use the default (ZScoreNormalization) """ norm_scheme = channel_name_to_normalization_mapping.get(channel_name.casefold()) if norm_scheme is None: norm_scheme = ZScoreNormalization # print('Using %s for image normalization' % norm_scheme.__name__) return norm_scheme ================================================ FILE: nnunetv2/preprocessing/normalization/readme.md ================================================ The channel_names entry in dataset.json only determines the normlaization scheme. 
So if you want to use something different then you can just
- create a new subclass of ImageNormalization
- map your custom channel identifier to that subclass in channel_name_to_normalization_mapping
- run plan and preprocess again with your custom normalization scheme

================================================
FILE: nnunetv2/preprocessing/preprocessors/__init__.py
================================================

================================================
FILE: nnunetv2/preprocessing/preprocessors/default_preprocessor.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing import shutil from time import sleep from typing import Tuple from typing import Union import SimpleITK import numpy as np from batchgenerators.utilities.file_and_folder_operations import * from tqdm import tqdm import nnunetv2 from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero from nnunetv2.preprocessing.resampling.default_resampling import compute_new_shape from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDatasetBlosc2 from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name from nnunetv2.utilities.find_class_by_name import recursive_find_python_class from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets class DefaultPreprocessor(object): def __init__(self, verbose: bool = True): self.verbose = verbose self.show_progress_bar = True """ Everything we need is in the plans. Those are given when run() is called """ def run_case_npy(self, data: np.ndarray, seg: Union[np.ndarray, None], properties: dict, plans_manager: PlansManager, configuration_manager: ConfigurationManager, dataset_json: Union[dict, str]): # let's not mess up the inputs! data = data.astype(np.float32) # this creates a copy if seg is not None: assert data.shape[1:] == seg.shape[1:], "Shape mismatch between image and segmentation. Please fix your dataset and make use of the --verify_dataset_integrity flag to ensure everything is correct" seg = np.copy(seg) has_seg = seg is not None # apply transpose_forward, this also needs to be applied to the spacing! 
data = data.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]]) if seg is not None: seg = seg.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]]) original_spacing = [properties['spacing'][i] for i in plans_manager.transpose_forward] # crop, remember to store size before cropping! shape_before_cropping = data.shape[1:] properties['shape_before_cropping'] = shape_before_cropping # this command will generate a segmentation. This is important because of the nonzero mask which we may need data, seg, bbox = crop_to_nonzero(data, seg) properties['bbox_used_for_cropping'] = bbox # print(data.shape, seg.shape) properties['shape_after_cropping_and_before_resampling'] = data.shape[1:] # resample target_spacing = configuration_manager.spacing # this should already be transposed if len(target_spacing) < len(data.shape[1:]): # target spacing for 2d has 2 entries but the data and original_spacing have three because everything is 3d # in 2d configuration we do not change the spacing between slices target_spacing = [original_spacing[0]] + target_spacing new_shape = compute_new_shape(data.shape[1:], original_spacing, target_spacing) # normalize # normalization MUST happen before resampling or we get huge problems with resampled nonzero masks no # longer fitting the images perfectly! 
data = self._normalize(data, seg, configuration_manager, plans_manager.foreground_intensity_properties_per_channel) # print('current shape', data.shape[1:], 'current_spacing', original_spacing, # '\ntarget shape', new_shape, 'target_spacing', target_spacing) old_shape = data.shape[1:] data = configuration_manager.resampling_fn_data(data, new_shape, original_spacing, target_spacing) seg = configuration_manager.resampling_fn_seg(seg, new_shape, original_spacing, target_spacing) if self.verbose: print(f'old shape: {old_shape}, new_shape: {new_shape}, old_spacing: {original_spacing}, ' f'new_spacing: {target_spacing}, fn_data: {configuration_manager.resampling_fn_data}') # if we have a segmentation, sample foreground locations for oversampling and add those to properties if has_seg: # reinstantiating LabelManager for each case is not ideal. We could replace the dataset_json argument # with a LabelManager Instance in this function because that's all its used for. Dunno what's better. # LabelManager is pretty light computation-wise. label_manager = plans_manager.get_label_manager(dataset_json) collect_for_this = label_manager.foreground_regions if label_manager.has_regions \ else label_manager.foreground_labels # when using the ignore label we want to sample only from annotated regions. 
Therefore we also need to # collect samples uniformly from all classes (incl background) if label_manager.has_ignore_label: collect_for_this.append([-1] + label_manager.all_labels) # no need to filter background in regions because it is already filtered in handle_labels # print(all_labels, regions) properties['class_locations'] = self._sample_foreground_locations(seg, collect_for_this, verbose=self.verbose) seg = self.modify_seg_fn(seg, plans_manager, dataset_json, configuration_manager) if np.max(seg) > 127: seg = seg.astype(np.int16) else: seg = seg.astype(np.int8) return data, seg, properties def run_case(self, image_files: List[str], seg_file: Union[str, None], plans_manager: PlansManager, configuration_manager: ConfigurationManager, dataset_json: Union[dict, str]): """ seg file can be none (test cases) order of operations is: transpose -> crop -> resample so when we export we need to run the following order: resample -> crop -> transpose (we could also run transpose at a different place, but reverting the order of operations done during preprocessing seems cleaner) """ if isinstance(dataset_json, str): dataset_json = load_json(dataset_json) rw = plans_manager.image_reader_writer_class() # load image(s) data, data_properties = rw.read_images(image_files) # if possible, load seg if seg_file is not None: seg, _ = rw.read_seg(seg_file) else: seg = None if self.verbose: print(seg_file) data, seg, data_properties = self.run_case_npy(data, seg, data_properties, plans_manager, configuration_manager, dataset_json) return data, seg, data_properties def run_case_save(self, output_filename_truncated: str, image_files: List[str], seg_file: str, plans_manager: PlansManager, configuration_manager: ConfigurationManager, dataset_json: Union[dict, str]): data, seg, properties = self.run_case(image_files, seg_file, plans_manager, configuration_manager, dataset_json) data = data.astype(np.float32, copy=False) seg = seg.astype(np.int16, copy=False) # print('dtypes', data.dtype, 
seg.dtype) block_size_data, chunk_size_data = nnUNetDatasetBlosc2.comp_blosc2_params( data.shape, tuple(configuration_manager.patch_size), data.itemsize) block_size_seg, chunk_size_seg = nnUNetDatasetBlosc2.comp_blosc2_params( seg.shape, tuple(configuration_manager.patch_size), seg.itemsize) nnUNetDatasetBlosc2.save_case(data, seg, properties, output_filename_truncated, chunks=chunk_size_data, blocks=block_size_data, chunks_seg=chunk_size_seg, blocks_seg=block_size_seg) @staticmethod def _sample_foreground_locations( seg: np.ndarray, classes_or_regions: Union[List[int], List[Tuple[int, ...]]], seed: int = 1234, verbose: bool = False, min_num_samples=10000, min_percent_coverage = 0.01 ): rndst = np.random.RandomState(seed) class_locs = {} # Normalize requested labels and compute the set of all labels we might need normalized = [] requested_labels = set() for c in classes_or_regions: if isinstance(c, (tuple, list)): labs = tuple(int(x) for x in c) normalized.append(labs) requested_labels.update(labs) else: lab = int(c) normalized.append(lab) requested_labels.add(lab) # Create mask for all requested labels (this includes 0 if requested) requested_labels_arr = np.fromiter(requested_labels, dtype=np.int32) valid_mask = np.isin(seg, requested_labels_arr) coords = np.argwhere(valid_mask) seg_sel = seg[valid_mask] del valid_mask n = seg_sel.size if n == 0: for c in classes_or_regions: k = tuple(c) if isinstance(c, (tuple, list)) else int(c) class_locs[k] = [] return class_locs # sort once, then compute label blocks order = np.argsort(seg_sel, kind="stable") lab_sorted = seg_sel[order] coords_sorted = coords[order] change = np.flatnonzero(lab_sorted[1:] != lab_sorted[:-1]) + 1 starts = np.r_[0, change] ends = np.r_[change, n] labels_present = lab_sorted[starts] label_to_range = {int(l): (int(s), int(e)) for l, s, e in zip(labels_present, starts, ends)} present_labels = set(label_to_range.keys()) for c in classes_or_regions: is_region = isinstance(c, (tuple, list)) labs = 
tuple(int(x) for x in c) if is_region else (int(c),) k = labs if is_region else labs[0] # Skip if none of the labels are present if not any(lab in present_labels for lab in labs): class_locs[k] = [] continue # Collect ranges for present labels in this class/region ranges = [] counts = [] for lab in labs: r = label_to_range.get(lab) if r is None: continue s, e = r cnt = e - s if cnt > 0: ranges.append((s, e)) counts.append(cnt) if len(counts) == 0: class_locs[k] = [] continue total = int(np.sum(counts)) target_num_samples = min(min_num_samples, total) target_num_samples = max(target_num_samples, int(np.ceil(total * min_percent_coverage))) # Sample uniformly without replacement from the union of ranges, without building an n-sized mask # Draw target_num_samples unique offsets in [0, total) offsets = rndst.choice(total, target_num_samples, replace=False) # Map offsets -> (range index, in-range offset) using cumulative counts cum = np.cumsum(counts) which = np.searchsorted(cum, offsets, side="right") prev = np.concatenate(([0], cum[:-1])) in_range = offsets - prev[which] # Convert to indices in coords_sorted starts_for_pick = np.fromiter((ranges[i][0] for i in which), dtype=np.int64, count=which.size) picked_idx = starts_for_pick + in_range.astype(np.int64) selected = coords_sorted[picked_idx] class_locs[k] = selected if verbose: print(c, target_num_samples) return class_locs # @staticmethod # def _sample_foreground_locations(seg: np.ndarray, classes_or_regions: Union[List[int], List[Tuple[int, ...]]], # seed: int = 1234, verbose: bool = False): # num_samples = 10000 # min_percent_coverage = 0.01 # at least 1% of the class voxels need to be selected, otherwise it may be too # # sparse # rndst = np.random.RandomState(seed) # class_locs = {} # foreground_mask = seg != 0 # foreground_coords = np.argwhere(foreground_mask) # seg = seg[foreground_mask] # del foreground_mask # unique_labels = pd.unique(seg.ravel()) # # # We don't need more than 1e7 foreground samples. 
That's insanity. Cap here # if len(foreground_coords) > 1e7: # take_every = math.floor(len(foreground_coords) / 1e7) # # keep computation time reasonable # if verbose: # print(f'Subsampling foreground pixels 1:{take_every} for computational reasons') # foreground_coords = foreground_coords[::take_every] # seg = seg[::take_every] # # for c in classes_or_regions: # k = c if not isinstance(c, list) else tuple(c) # # # check if any of the labels are in seg, if not skip c # if isinstance(c, (tuple, list)): # if not any([ci in unique_labels for ci in c]): # class_locs[k] = [] # continue # else: # if c not in unique_labels: # class_locs[k] = [] # continue # # if isinstance(c, (tuple, list)): # mask = seg == c[0] # for cc in c[1:]: # mask = mask | (seg == cc) # all_locs = foreground_coords[mask] # else: # mask = seg == c # all_locs = foreground_coords[mask] # if len(all_locs) == 0: # class_locs[k] = [] # continue # target_num_samples = min(num_samples, len(all_locs)) # target_num_samples = max(target_num_samples, int(np.ceil(len(all_locs) * min_percent_coverage))) # # selected = all_locs[rndst.choice(len(all_locs), target_num_samples, replace=False)] # class_locs[k] = selected # if verbose: # print(c, target_num_samples) # seg = seg[~mask] # foreground_coords = foreground_coords[~mask] # return class_locs def _normalize(self, data: np.ndarray, seg: np.ndarray, configuration_manager: ConfigurationManager, foreground_intensity_properties_per_channel: dict) -> np.ndarray: for c in range(data.shape[0]): scheme = configuration_manager.normalization_schemes[c] normalizer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing", "normalization"), scheme, 'nnunetv2.preprocessing.normalization') if normalizer_class is None: raise RuntimeError(f'Unable to locate class \'{scheme}\' for normalization') normalizer = normalizer_class(use_mask_for_norm=configuration_manager.use_mask_for_norm[c], 
intensityproperties=foreground_intensity_properties_per_channel[str(c)]) data[c] = normalizer.run(data[c], seg[0]) return data def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plans_identifier: str, num_processes: int): """ data identifier = configuration name in plans. EZ. """ dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) assert isdir(join(nnUNet_raw, dataset_name)), "The requested dataset could not be found in nnUNet_raw" plans_file = join(nnUNet_preprocessed, dataset_name, plans_identifier + '.json') assert isfile(plans_file), "Expected plans file (%s) not found. Run corresponding nnUNet_plan_experiment " \ "first." % plans_file plans = load_json(plans_file) plans_manager = PlansManager(plans) configuration_manager = plans_manager.get_configuration(configuration_name) dataset_json_file = join(nnUNet_preprocessed, dataset_name, 'dataset.json') dataset_json = load_json(dataset_json_file) output_directory = join(nnUNet_preprocessed, dataset_name, configuration_manager.data_identifier) if isdir(output_directory): shutil.rmtree(output_directory) maybe_mkdir_p(output_directory) dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, dataset_name), dataset_json) # identifiers = [os.path.basename(i[:-len(dataset_json['file_ending'])]) for i in seg_fnames] # output_filenames_truncated = [join(output_directory, i) for i in identifiers] # multiprocessing magic. r = [] with multiprocessing.get_context("spawn").Pool(num_processes) as p: remaining = list(range(len(dataset))) # p is pretty nifti. If we kill workers they just respawn but don't do any work. # So we need to store the original pool of workers. 
workers = [j for j in p._pool] for k in dataset.keys(): r.append(p.starmap_async(self.run_case_save, ((join(output_directory, k), dataset[k]['images'], dataset[k]['label'], plans_manager, configuration_manager, dataset_json),))) with tqdm(desc="Preprocessing cases", total=len(dataset), disable=not getattr(self, 'show_progress_bar', True)) as pbar: while len(remaining) > 0: all_alive = all([j.is_alive() for j in workers]) if not all_alive: raise RuntimeError('Some background worker is 6 feet under. Yuck. \n' 'OK jokes aside.\n' 'One of your background processes is missing. This could be because of ' 'an error (look for an error message) or because it was killed ' 'by your OS due to running out of RAM. If you don\'t see ' 'an error message, out of RAM is likely the problem. In that case ' 'reducing the number of workers might help') done = [i for i in remaining if r[i].ready()] for i in done: r[i].get() # trigger any errors from worker (single call, no duplicate) pbar.update() remaining = [i for i in remaining if i not in done] sleep(0.1) def modify_seg_fn(self, seg: np.ndarray, plans_manager: PlansManager, dataset_json: dict, configuration_manager: ConfigurationManager) -> np.ndarray: # this function will be called at the end of self.run_case. Can be used to change the segmentation # after resampling. 
Useful for experimenting with sparse annotations: I can introduce sparsity after resampling # and don't have to create a new dataset each time I modify my experiments return seg def example_test_case_preprocessing(): # (paths to files may need adaptations) plans_file = '/home/isensee/drives/gpu_data/nnUNet_preprocessed/Dataset219_AMOS2022_postChallenge_task2/nnUNetPlans.json' dataset_json_file = '/home/isensee/drives/gpu_data/nnUNet_preprocessed/Dataset219_AMOS2022_postChallenge_task2/dataset.json' input_images = ['/home/isensee/drives/e132-rohdaten/nnUNetv2/Dataset219_AMOS2022_postChallenge_task2/imagesTr/amos_0600_0000.nii.gz', ] # if you only have one channel, you still need a list: ['case000_0000.nii.gz'] configuration = '3d_fullres' pp = DefaultPreprocessor() # _ because this position would be the segmentation if seg_file was not None (training case) # even if you have the segmentation, don't put the file there! You should always evaluate in the original # resolution. What comes out of the preprocessor might have been resampled to some other image resolution (as # specified by plans) plans_manager = PlansManager(plans_file) data, _, properties = pp.run_case(input_images, seg_file=None, plans_manager=plans_manager, configuration_manager=plans_manager.get_configuration(configuration), dataset_json=dataset_json_file) # voila. Now plug data into your prediction function of choice. 
We of course recommend nnU-Net's default (TODO) return data def _verify_class_locations(shape, outfile, class_locs): import numpy as np import SimpleITK as sitk out = np.zeros(shape, dtype=np.uint16) # allow many labels safely for i, k in enumerate(class_locs.keys()): class_coords = class_locs[k] if class_coords is None: continue class_coords = np.asarray(class_coords) if class_coords.size == 0: continue # classes that are not present are stored as an empty list, so only slice after the checks above class_coords = class_coords[:, 1:] # Expect coords in (N, 3) as (z, y, x) if class_coords.ndim != 2 or class_coords.shape[1] != 3: raise ValueError(f"class_locs[{k}] must have shape (N, 3), got {class_coords.shape}") z = class_coords[:, 0].astype(np.int64) y = class_coords[:, 1].astype(np.int64) x = class_coords[:, 2].astype(np.int64) # Optional bounds check (cheap and prevents hard-to-debug indexing errors) if (z.min() < 0 or y.min() < 0 or x.min() < 0 or z.max() >= shape[0] or y.max() >= shape[1] or x.max() >= shape[2]): raise ValueError(f"Coordinates for {k} are out of bounds for shape={shape}") out[z, y, x] = i + 1 # label 1..K img = sitk.GetImageFromArray(out) # SimpleITK assumes array is z,y,x sitk.WriteImage(img, outfile) if __name__ == '__main__': # example_test_case_preprocessing() # pp = DefaultPreprocessor() # pp.run(2, '2d', 'nnUNetPlans', 8) ########################################################################################################### # how to process a test case?
This is an example: # example_test_case_preprocessing() seg = SimpleITK.GetArrayFromImage(SimpleITK.ReadImage('/home/isensee/temp/H-mito-val-v2.nii.gz'))[None] a = DefaultPreprocessor._sample_foreground_locations(seg, np.arange(1, np.max(seg) + 1), min_percent_coverage=0.50) _verify_class_locations(seg.shape[1:], '/home/isensee/temp/deleteme.nii.gz', a) ================================================ FILE: nnunetv2/preprocessing/resampling/__init__.py ================================================ ================================================ FILE: nnunetv2/preprocessing/resampling/default_resampling.py ================================================ from collections import OrderedDict from copy import deepcopy from typing import Union, Tuple, List import numpy as np import pandas as pd import torch from batchgenerators.augmentations.utils import resize_segmentation from scipy.ndimage import map_coordinates from skimage.transform import resize from nnunetv2.configuration import ANISO_THRESHOLD def get_do_separate_z(spacing: Union[Tuple[float, ...], List[float], np.ndarray], anisotropy_threshold=ANISO_THRESHOLD): do_separate_z = (np.max(spacing) / np.min(spacing)) > anisotropy_threshold return do_separate_z def get_lowres_axis(new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]): axis = np.where(max(new_spacing) / np.array(new_spacing) == 1)[0] # find which axis is anisotropic return axis def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray], old_spacing: Union[Tuple[float, ...], List[float], np.ndarray], new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> np.ndarray: assert len(old_spacing) == len(old_shape) assert len(old_shape) == len(new_spacing) new_shape = np.array([int(round(i / j * k)) for i, j, k in zip(old_spacing, new_spacing, old_shape)]) return new_shape def determine_do_sep_z_and_axis( force_separate_z: bool, current_spacing, new_spacing, separate_z_anisotropy_threshold: float = ANISO_THRESHOLD) 
-> Tuple[bool, Union[int, None]]: if force_separate_z is not None: do_separate_z = force_separate_z if force_separate_z: axis = get_lowres_axis(current_spacing) else: axis = None else: if get_do_separate_z(current_spacing, separate_z_anisotropy_threshold): do_separate_z = True axis = get_lowres_axis(current_spacing) elif get_do_separate_z(new_spacing, separate_z_anisotropy_threshold): do_separate_z = True axis = get_lowres_axis(new_spacing) else: do_separate_z = False axis = None if axis is not None: if len(axis) == 3: do_separate_z = False axis = None elif len(axis) == 2: # this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample # separately in the out of plane axis do_separate_z = False axis = None else: axis = axis[0] return do_separate_z, axis def resample_data_or_seg_to_spacing(data: np.ndarray, current_spacing: Union[Tuple[float, ...], List[float], np.ndarray], new_spacing: Union[Tuple[float, ...], List[float], np.ndarray], is_seg: bool = False, order: int = 3, order_z: int = 0, force_separate_z: Union[bool, None] = False, separate_z_anisotropy_threshold: float = ANISO_THRESHOLD): do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing, separate_z_anisotropy_threshold) if data is not None: assert data.ndim == 4, "data must be c x y z" shape = np.array(data.shape) new_shape = compute_new_shape(shape[1:], current_spacing, new_spacing) data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z) return data_reshaped def resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray], new_shape: Union[Tuple[int, ...], List[int], np.ndarray], current_spacing: Union[Tuple[float, ...], List[float], np.ndarray], new_spacing: Union[Tuple[float, ...], List[float], np.ndarray], is_seg: bool = False, order: int = 3, order_z: int = 0, force_separate_z: Union[bool, None] = False, separate_z_anisotropy_threshold: float = ANISO_THRESHOLD): 
""" needed for segmentation export. Stupid, I know """ if isinstance(data, torch.Tensor): data = data.numpy() do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing, separate_z_anisotropy_threshold) if data is not None: assert data.ndim == 4, "data must be c x y z" data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z) return data_reshaped def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray], is_seg: bool = False, axis: Union[None, int] = None, order: int = 3, do_separate_z: bool = False, order_z: int = 0, dtype_out = None): """ separate_z=True will resample with order 0 along z :param data: :param new_shape: :param is_seg: :param axis: :param order: :param do_separate_z: :param order_z: only applies if do_separate_z is True :return: """ assert data.ndim == 4, "data must be (c, x, y, z)" assert len(new_shape) == data.ndim - 1 if is_seg: resize_fn = resize_segmentation kwargs = OrderedDict() else: resize_fn = resize kwargs = {'mode': 'edge', 'anti_aliasing': False} shape = np.array(data[0].shape) new_shape = np.array(new_shape) if dtype_out is None: dtype_out = data.dtype reshaped_final = np.zeros((data.shape[0], *new_shape), dtype=dtype_out) if np.any(shape != new_shape): data = data.astype(float, copy=False) if do_separate_z: assert axis is not None, 'If do_separate_z, we need to know what axis is anisotropic' if axis == 0: new_shape_2d = new_shape[1:] elif axis == 1: new_shape_2d = new_shape[[0, 2]] else: new_shape_2d = new_shape[:-1] for c in range(data.shape[0]): tmp = deepcopy(new_shape) tmp[axis] = shape[axis] reshaped_here = np.zeros(tmp) for slice_id in range(shape[axis]): if axis == 0: reshaped_here[slice_id] = resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs) elif axis == 1: reshaped_here[:, slice_id] = resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs) else: reshaped_here[:, :, slice_id] = 
resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs) if shape[axis] != new_shape[axis]: # The following few lines are blatantly copied and modified from skimage's resize() rows, cols, dim = new_shape[0], new_shape[1], new_shape[2] orig_rows, orig_cols, orig_dim = reshaped_here.shape # align_corners=False row_scale = float(orig_rows) / rows col_scale = float(orig_cols) / cols dim_scale = float(orig_dim) / dim map_rows, map_cols, map_dims = np.mgrid[:rows, :cols, :dim] map_rows = row_scale * (map_rows + 0.5) - 0.5 map_cols = col_scale * (map_cols + 0.5) - 0.5 map_dims = dim_scale * (map_dims + 0.5) - 0.5 coord_map = np.array([map_rows, map_cols, map_dims]) if not is_seg or order_z == 0: reshaped_final[c] = map_coordinates(reshaped_here, coord_map, order=order_z, mode='nearest')[None] else: unique_labels = np.sort(pd.unique(reshaped_here.ravel())) # np.unique(reshaped_data) for i, cl in enumerate(unique_labels): reshaped_final[c][np.round( map_coordinates((reshaped_here == cl).astype(float), coord_map, order=order_z, mode='nearest')) > 0.5] = cl else: reshaped_final[c] = reshaped_here else: for c in range(data.shape[0]): reshaped_final[c] = resize_fn(data[c], new_shape, order, **kwargs) return reshaped_final else: # print("no resampling necessary") return data if __name__ == '__main__': input_array = np.random.random((1, 42, 231, 142)) output_shape = (52, 256, 256) out = resample_data_or_seg(input_array, output_shape, is_seg=False, axis=2, order=1, order_z=0, do_separate_z=True) print(out.shape, input_array.shape) ================================================ FILE: nnunetv2/preprocessing/resampling/no_resampling.py ================================================ from typing import Union, Tuple, List import numpy as np import torch def no_resampling_hack( data: Union[torch.Tensor, np.ndarray], new_shape: Union[Tuple[int, ...], List[int], np.ndarray], current_spacing: Union[Tuple[float, ...], List[float], np.ndarray], new_spacing: Union[Tuple[float,
...], List[float], np.ndarray] ): return data ================================================ FILE: nnunetv2/preprocessing/resampling/resample_torch.py ================================================ from copy import deepcopy from typing import Union, Tuple, List import numpy as np import torch from einops import rearrange from torch.nn import functional as F from nnunetv2.configuration import ANISO_THRESHOLD from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO from nnunetv2.preprocessing.resampling.default_resampling import determine_do_sep_z_and_axis def resample_torch_simple( data: Union[torch.Tensor, np.ndarray], new_shape: Union[Tuple[int, ...], List[int], np.ndarray], is_seg: bool = False, num_threads: int = 4, device: torch.device = torch.device('cpu'), memefficient_seg_resampling: bool = False, mode='linear' ): if mode == 'linear': if data.ndim == 4: torch_mode = 'trilinear' elif data.ndim == 3: torch_mode = 'bilinear' else: raise RuntimeError else: torch_mode = mode if isinstance(new_shape, np.ndarray): new_shape = [int(i) for i in new_shape] if all([i == j for i, j in zip(new_shape, data.shape[1:])]): return data else: n_threads = torch.get_num_threads() torch.set_num_threads(num_threads) new_shape = tuple(new_shape) with torch.no_grad(): input_was_numpy = isinstance(data, np.ndarray) if input_was_numpy: data = torch.from_numpy(data).to(device) else: orig_device = deepcopy(data.device) data = data.to(device) if is_seg: unique_values = torch.unique(data) result_dtype = torch.int8 if max(unique_values) < 127 else torch.int16 result = torch.zeros((data.shape[0], *new_shape), dtype=result_dtype, device=device) if not memefficient_seg_resampling: # believe it or not, the implementation below is 3x as fast (at least on Liver CT and on CPU) # Why? Because argmax is slow. 
The implementation below immediately sets most locations and only lets the # uncertain ones be determined by argmax # unique_values = torch.unique(data) # result = torch.zeros((len(unique_values), data.shape[0], *new_shape), dtype=torch.float16) # for i, u in enumerate(unique_values): # result[i] = F.interpolate((data[None] == u).float() * 1000, new_shape, mode='trilinear', antialias=False)[0] # result = unique_values[result.argmax(0)] result_tmp = torch.empty((len(unique_values), data.shape[0], *new_shape), dtype=torch.float16, device=device) scale_factor = 1000 done_mask = torch.zeros_like(result, dtype=torch.bool, device=device) for i, u in enumerate(unique_values): result_tmp[i] = \ F.interpolate((data[None] == u).float() * scale_factor, new_shape, mode=torch_mode, antialias=False)[0] mask = result_tmp[i] > (0.7 * scale_factor) result[mask] = u.item() done_mask |= mask if not torch.all(done_mask): # print('resolving argmax', torch.sum(~done_mask), "voxels to go") result[~done_mask] = unique_values[result_tmp[:, ~done_mask].argmax(0)].to(result_dtype) else: for i, u in enumerate(unique_values): if u == 0: pass result[F.interpolate((data[None] == u).float(), new_shape, mode=torch_mode, antialias=False)[ 0] > 0.5] = u else: result = F.interpolate(data[None].float(), new_shape, mode=torch_mode, antialias=False)[0] if input_was_numpy: result = result.cpu().numpy() else: result = result.to(orig_device) torch.set_num_threads(n_threads) return result def resample_torch_fornnunet( data: Union[torch.Tensor, np.ndarray], new_shape: Union[Tuple[int, ...], List[int], np.ndarray], current_spacing: Union[Tuple[float, ...], List[float], np.ndarray], new_spacing: Union[Tuple[float, ...], List[float], np.ndarray], is_seg: bool = False, num_threads: int = 4, device: torch.device = torch.device('cpu'), memefficient_seg_resampling: bool = False, force_separate_z: Union[bool, None] = None, separate_z_anisotropy_threshold: float = ANISO_THRESHOLD, mode='linear', 
aniso_axis_mode='nearest-exact' ): """ data must be c, x, y, z """ assert data.ndim == 4, "data must be c, x, y, z" new_shape = [int(i) for i in new_shape] orig_shape = data.shape do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing, separate_z_anisotropy_threshold) if not isinstance(axis, (tuple, list)): axis = (axis,) # print('shape', data.shape, 'current_spacing', current_spacing, 'new_spacing', new_spacing, 'do_separate_z', do_separate_z, 'axis', axis) if do_separate_z: was_numpy = isinstance(data, np.ndarray) if was_numpy: data = torch.from_numpy(data) assert len(axis) == 1 axis = axis[0] tmp = "xyz" axis_letter = tmp[axis] others_int = [i for i in range(3) if i != axis] others = [tmp[i] for i in others_int] # reshape by overloading c channel data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}") # reshape in-plane tmp_new_shape = [new_shape[i] for i in others_int] data = resample_torch_simple(data, tmp_new_shape, is_seg=is_seg, num_threads=num_threads, device=device, memefficient_seg_resampling=memefficient_seg_resampling, mode=mode) data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z", **{ axis_letter: orig_shape[axis + 1], others[0]: tmp_new_shape[0], others[1]: tmp_new_shape[1] } ) # reshape out of plane w/ nearest data = resample_torch_simple(data, new_shape, is_seg=is_seg, num_threads=num_threads, device=device, memefficient_seg_resampling=memefficient_seg_resampling, mode=aniso_axis_mode) if was_numpy: data = data.numpy() return data else: return resample_torch_simple(data, new_shape, is_seg, num_threads, device, memefficient_seg_resampling) if __name__ == '__main__': torch.set_num_threads(16) img_file = '/media/isensee/raw_data/nnUNet_raw/Dataset027_ACDC/imagesTr/patient041_frame01_0000.nii.gz' seg_file = '/media/isensee/raw_data/nnUNet_raw/Dataset027_ACDC/labelsTr/patient041_frame01.nii.gz' io = SimpleITKIO() data, pkl = io.read_images((img_file, )) seg, pkl 
= io.read_seg(seg_file) target_shape = (15, 256, 312) spacing = pkl['spacing'] use = data is_seg = False ret_nosep = resample_torch_fornnunet(use, target_shape, spacing, spacing, is_seg) ret_sep = resample_torch_fornnunet(use, target_shape, spacing, spacing, is_seg, force_separate_z=False) ================================================ FILE: nnunetv2/preprocessing/resampling/utils.py ================================================ from typing import Callable import nnunetv2 from batchgenerators.utilities.file_and_folder_operations import join from nnunetv2.utilities.find_class_by_name import recursive_find_python_class def recursive_find_resampling_fn_by_name(resampling_fn: str) -> Callable: ret = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing", "resampling"), resampling_fn, 'nnunetv2.preprocessing.resampling') if ret is None: raise RuntimeError("Unable to find resampling function named '%s'. Please make sure this fn is located in the " "nnunetv2.preprocessing.resampling module." % resampling_fn) else: return ret ================================================ FILE: nnunetv2/run/__init__.py ================================================ ================================================ FILE: nnunetv2/run/load_pretrained_weights.py ================================================ import torch from torch._dynamo import OptimizedModule from torch.nn.parallel import DistributedDataParallel as DDP import torch.distributed as dist def load_pretrained_weights(network, fname, verbose=False): """ Transfers all weights between matching keys in state_dicts. Matching is done by name, and weights are only transferred if the shape is also the same. Segmentation layers (the 1x1(x1) layers that produce the segmentation maps, identified by keys containing '.seg_layers.') are not transferred! If the pretrained weights were obtained with a training outside nnU-Net and DDP or torch.compile was used, you need to change the keys of the pretrained state_dict.
DDP adds a 'module.' prefix and torch.compile adds an '_orig_mod.' prefix. You DO NOT need to worry about this if pretraining was done with nnU-Net as nnUNetTrainer.save_checkpoint takes care of that! """ if dist.is_initialized(): saved_model = torch.load(fname, map_location=torch.device('cuda', dist.get_rank()), weights_only=False) else: saved_model = torch.load(fname, weights_only=False) pretrained_dict = saved_model['network_weights'] skip_strings_in_pretrained = [ '.seg_layers.', ] if isinstance(network, DDP): mod = network.module else: mod = network if isinstance(mod, OptimizedModule): mod = mod._orig_mod model_dict = mod.state_dict() # verify that all but the segmentation layers have the same shape for key, _ in model_dict.items(): if all([i not in key for i in skip_strings_in_pretrained]): assert key in pretrained_dict, \ f"Key {key} is missing in the pretrained model weights. The pretrained weights do not seem to be " \ f"compatible with your network." assert model_dict[key].shape == pretrained_dict[key].shape, \ f"The shape of the parameters of key {key} is not the same. Pretrained model: " \ f"{pretrained_dict[key].shape}; your network: {model_dict[key].shape}. The pretrained model " \ f"does not seem to be compatible with your network." # fun fact: in principle this allows loading from parameters that do not cover the entire network. For example pretrained # encoders. Not supported by this function though (see assertions above) # commenting out this abomination of a dict comprehension for preservation in the archives of 'what not to do' # pretrained_dict = {'module.' + k if is_ddp else k: v # for k, v in pretrained_dict.items() # if (('module.'
+ k if is_ddp else k) in model_dict) and # all([i not in k for i in skip_strings_in_pretrained])} pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys() and all([i not in k for i in skip_strings_in_pretrained])} model_dict.update(pretrained_dict) print("################### Loading pretrained weights from file ", fname, '###################') if verbose: print("Below is the list of overlapping blocks in pretrained model and nnUNet architecture:") for key, value in pretrained_dict.items(): print(key, 'shape', value.shape) print("################### Done ###################") mod.load_state_dict(model_dict) ================================================ FILE: nnunetv2/run/run_training.py ================================================ import multiprocessing import os import socket from typing import Union, Optional import nnunetv2 import torch.cuda import torch.distributed as dist import torch.multiprocessing as mp from batchgenerators.utilities.file_and_folder_operations import join, isfile, load_json from nnunetv2.paths import nnUNet_preprocessed from nnunetv2.run.load_pretrained_weights import load_pretrained_weights from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name from nnunetv2.utilities.find_class_by_name import recursive_find_python_class from torch.backends import cudnn def find_free_network_port() -> int: """Finds a free port on localhost. It is useful in single-node training when we don't want to connect to a real main node but have to set the `MASTER_PORT` environment variable. 
""" s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(("", 0)) port = s.getsockname()[1] s.close() return port def get_trainer_from_args(dataset_name_or_id: Union[int, str], configuration: str, fold: int, trainer_name: str = 'nnUNetTrainer', plans_identifier: str = 'nnUNetPlans', continue_training: bool = False, device: torch.device = torch.device('cuda')): # load nnunet class and do sanity checks nnunet_trainer = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), trainer_name, 'nnunetv2.training.nnUNetTrainer') if nnunet_trainer is None: raise RuntimeError(f'Could not find requested nnunet trainer {trainer_name} in ' f'nnunetv2.training.nnUNetTrainer (' f'{join(nnunetv2.__path__[0], "training", "nnUNetTrainer")}). If it is located somewhere ' f'else, please move it there.') assert issubclass(nnunet_trainer, nnUNetTrainer), 'The requested nnunet trainer class must inherit from ' \ 'nnUNetTrainer' # handle dataset input. If it's an ID we need to convert to int from string if dataset_name_or_id.startswith('Dataset'): pass else: try: dataset_name_or_id = int(dataset_name_or_id) except ValueError: raise ValueError(f'dataset_name_or_id must either be an integer or a valid dataset name with the pattern ' f'DatasetXXX_YYY where XXX are the three(!) task ID digits. 
Your ' f'input: {dataset_name_or_id}') # initialize nnunet trainer preprocessed_dataset_folder_base = join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id)) plans_file = join(preprocessed_dataset_folder_base, plans_identifier + '.json') plans = load_json(plans_file) plans["continue_training"] = continue_training dataset_json = load_json(join(preprocessed_dataset_folder_base, 'dataset.json')) nnunet_trainer = nnunet_trainer(plans=plans, configuration=configuration, fold=fold, dataset_json=dataset_json, device=device) return nnunet_trainer def maybe_load_checkpoint(nnunet_trainer: nnUNetTrainer, continue_training: bool, validation_only: bool, pretrained_weights_file: str = None): if continue_training and pretrained_weights_file is not None: raise RuntimeError('Cannot both continue a training AND load pretrained weights. Pretrained weights can only ' 'be used at the beginning of the training.') if continue_training: expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth') if not isfile(expected_checkpoint_file): expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_latest.pth') # special case where --c is used to run a previously aborted validation if not isfile(expected_checkpoint_file): expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_best.pth') if not isfile(expected_checkpoint_file): print(f"WARNING: Cannot continue training because there seems to be no checkpoint available to " f"continue from. 
Starting a new training...") expected_checkpoint_file = None elif validation_only: expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth') if not isfile(expected_checkpoint_file): raise RuntimeError(f"Cannot run validation because the training is not finished yet!") else: if pretrained_weights_file is not None: if not nnunet_trainer.was_initialized: nnunet_trainer.initialize() load_pretrained_weights(nnunet_trainer.network, pretrained_weights_file, verbose=True) expected_checkpoint_file = None if expected_checkpoint_file is not None: nnunet_trainer.load_checkpoint(expected_checkpoint_file) def setup_ddp(rank, world_size): # initialize the process group dist.init_process_group("nccl", rank=rank, world_size=world_size) def cleanup_ddp(): dist.destroy_process_group() def run_ddp(rank, dataset_name_or_id, configuration, fold, tr, p, disable_checkpointing, c, val, pretrained_weights, npz, val_with_best, world_size): setup_ddp(rank, world_size) torch.cuda.set_device(torch.device('cuda', dist.get_rank())) nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, tr, p, c) if disable_checkpointing: nnunet_trainer.disable_checkpointing = disable_checkpointing assert not (c and val), f'Cannot set --c and --val flag at the same time. Dummy.' 
maybe_load_checkpoint(nnunet_trainer, c, val, pretrained_weights) if torch.cuda.is_available(): cudnn.deterministic = False cudnn.benchmark = True if not val: nnunet_trainer.run_training() if val_with_best: nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth')) nnunet_trainer.perform_actual_validation(npz) cleanup_ddp() def run_training(dataset_name_or_id: Union[str, int], configuration: str, fold: Union[int, str], trainer_class_name: str = 'nnUNetTrainer', plans_identifier: str = 'nnUNetPlans', pretrained_weights: Optional[str] = None, num_gpus: int = 1, export_validation_probabilities: bool = False, continue_training: bool = False, only_run_validation: bool = False, disable_checkpointing: bool = False, val_with_best: bool = False, device: torch.device = torch.device('cuda')): if plans_identifier == 'nnUNetPlans': print("\n############################\n" "INFO: You are using the old nnU-Net default plans. We have updated our recommendations. " "Please consider using those instead! " "Read more here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md" "\n############################\n") if isinstance(fold, str): if fold != 'all': try: fold = int(fold) except ValueError as e: print(f'Unable to convert given value for fold to int: {fold}. fold must be either "all" or an integer!') raise e if val_with_best: assert not disable_checkpointing, '--val_best is not compatible with --disable_checkpointing' if num_gpus > 1: assert device.type == 'cuda', f"DDP training (triggered by num_gpus > 1) is only implemented for cuda devices.
Your device: {device}" os.environ['MASTER_ADDR'] = 'localhost' if 'MASTER_PORT' not in os.environ.keys(): port = str(find_free_network_port()) print(f"using port {port}") os.environ['MASTER_PORT'] = port # str(port) mp.spawn(run_ddp, args=( dataset_name_or_id, configuration, fold, trainer_class_name, plans_identifier, disable_checkpointing, continue_training, only_run_validation, pretrained_weights, export_validation_probabilities, val_with_best, num_gpus), nprocs=num_gpus, join=True) else: nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name, plans_identifier, continue_training, device=device) if disable_checkpointing: nnunet_trainer.disable_checkpointing = disable_checkpointing assert not (continue_training and only_run_validation), f'Cannot set --c and --val flag at the same time. Dummy.' maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights) if torch.cuda.is_available(): cudnn.deterministic = False cudnn.benchmark = True if not only_run_validation: nnunet_trainer.run_training() if val_with_best: nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth')) nnunet_trainer.perform_actual_validation(export_validation_probabilities) def run_training_entry(): import argparse parser = argparse.ArgumentParser() parser.add_argument('dataset_name_or_id', type=str, help="Dataset name or ID to train with") parser.add_argument('configuration', type=str, help="Configuration that should be trained") parser.add_argument('fold', type=str, help='Fold of the 5-fold cross-validation. Should be an int between 0 and 4.') parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer', help='[OPTIONAL] Use this flag to specify a custom trainer. Default: nnUNetTrainer') parser.add_argument('-p', type=str, required=False, default='nnUNetPlans', help='[OPTIONAL] Use this flag to specify a custom plans identifier. 
Default: nnUNetPlans') parser.add_argument('-pretrained_weights', type=str, required=False, default=None, help='[OPTIONAL] path to nnU-Net checkpoint file to be used as pretrained model. Will only ' 'be used when actually training. Beta. Use with caution.') parser.add_argument('-num_gpus', type=int, default=1, required=False, help='Specify the number of GPUs to use for training') parser.add_argument('--npz', action='store_true', required=False, help='[OPTIONAL] Save softmax predictions from final validation as npz files (in addition to predicted ' 'segmentations). Needed for finding the best ensemble.') parser.add_argument('--c', action='store_true', required=False, help='[OPTIONAL] Continue training from latest checkpoint') parser.add_argument('--val', action='store_true', required=False, help='[OPTIONAL] Set this flag to only run the validation. Requires training to have finished.') parser.add_argument('--val_best', action='store_true', required=False, help='[OPTIONAL] If set, the validation will be performed with the checkpoint_best instead ' 'of checkpoint_final. NOT COMPATIBLE with --disable_checkpointing! ' 'WARNING: This will use the same \'validation\' folder as the regular validation ' 'with no way of distinguishing the two!') parser.add_argument('--disable_checkpointing', action='store_true', required=False, help='[OPTIONAL] Set this flag to disable checkpointing. Ideal for testing things out and ' 'you dont want to flood your hard drive with checkpoints.') parser.add_argument('-device', type=str, default='cuda', required=False, help="Use this to set the device the training should run with. Available options are 'cuda' " "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] instead!") args = parser.parse_args() assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.' 
if args.device == 'cpu': # let's allow torch to use hella threads torch.set_num_threads(multiprocessing.cpu_count()) device = torch.device('cpu') elif args.device == 'cuda': # multithreading in torch doesn't help nnU-Net if run on GPU torch.set_num_threads(1) torch.set_num_interop_threads(1) device = torch.device('cuda') else: device = torch.device('mps') run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights, args.num_gpus, args.npz, args.c, args.val, args.disable_checkpointing, args.val_best, device=device) if __name__ == '__main__': os.environ['OMP_NUM_THREADS'] = '1' os.environ['MKL_NUM_THREADS'] = '1' os.environ['OPENBLAS_NUM_THREADS'] = '1' # reduces the number of threads used for compiling. More threads don't help and can cause problems os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1' # multiprocessing.set_start_method("spawn") run_training_entry() ================================================ FILE: nnunetv2/tests/__init__.py ================================================ ================================================ FILE: nnunetv2/tests/integration_tests/__init__.py ================================================ ================================================ FILE: nnunetv2/tests/integration_tests/add_lowres_and_cascade.py ================================================ from copy import deepcopy from batchgenerators.utilities.file_and_folder_operations import * from nnunetv2.paths import nnUNet_preprocessed from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('-d', nargs='+', type=int, help='List of dataset ids') args = parser.parse_args() for d in args.d: dataset_name = maybe_convert_to_dataset_name(d) plans = load_json(join(nnUNet_preprocessed, dataset_name, 'nnUNetPlans.json')) plans['configurations']['3d_lowres'] = { "data_identifier": 
"nnUNetPlans_3d_lowres", # do not be a dumbo and forget this. I was a dumbo. And I paid dearly with ~10 min debugging time 'inherits_from': '3d_fullres', "patch_size": [20, 28, 20], "median_image_size_in_voxels": [18.0, 25.0, 18.0], "spacing": [2.0, 2.0, 2.0], "architecture": deepcopy(plans['configurations']['3d_fullres']["architecture"]), "next_stage": "3d_cascade_fullres" } plans['configurations']['3d_lowres']['architecture']["arch_kwargs"]['n_conv_per_stage'] = [2, 2, 2] plans['configurations']['3d_lowres']['architecture']["arch_kwargs"]['n_conv_per_stage_decoder'] = [2, 2] plans['configurations']['3d_lowres']['architecture']["arch_kwargs"]['strides'] = [[1, 1, 1], [2, 2, 2], [2, 2, 2]] plans['configurations']['3d_lowres']['architecture']["arch_kwargs"]['kernel_sizes'] = [[3, 3, 3], [3, 3, 3], [3, 3, 3]] plans['configurations']['3d_lowres']['architecture']["arch_kwargs"]['n_stages'] = 3 plans['configurations']['3d_lowres']['architecture']["arch_kwargs"]['features_per_stage'] = [ 32, 64, 128 ] plans['configurations']['3d_cascade_fullres'] = { 'inherits_from': '3d_fullres', "previous_stage": "3d_lowres" } save_json(plans, join(nnUNet_preprocessed, dataset_name, 'nnUNetPlans.json'), sort_keys=False) ================================================ FILE: nnunetv2/tests/integration_tests/cleanup_integration_test.py ================================================ import shutil from batchgenerators.utilities.file_and_folder_operations import isdir, join from nnunetv2.paths import nnUNet_raw, nnUNet_results, nnUNet_preprocessed if __name__ == '__main__': # deletes everything! 
dataset_names = [ 'Dataset996_IntegrationTest_Hippocampus_regions_ignore', 'Dataset997_IntegrationTest_Hippocampus_regions', 'Dataset998_IntegrationTest_Hippocampus_ignore', 'Dataset999_IntegrationTest_Hippocampus', ] for fld in [nnUNet_raw, nnUNet_preprocessed, nnUNet_results]: for d in dataset_names: if isdir(join(fld, d)): shutil.rmtree(join(fld, d)) ================================================ FILE: nnunetv2/tests/integration_tests/lsf_commands.sh ================================================ bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 996" bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 997" bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 998" bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 999" bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 996" bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . 
nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 997" bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 998" bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 999" ================================================ FILE: nnunetv2/tests/integration_tests/prepare_integration_tests.sh ================================================ # assumes you are in the nnunet repo! # prepare raw datasets python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py # now run experiment planning without preprocessing nnUNetv2_plan_and_preprocess -d 996 997 998 999 --no_pp # now add 3d lowres and cascade python nnunetv2/tests/integration_tests/add_lowres_and_cascade.py -d 996 997 998 999 # now preprocess everything nnUNetv2_preprocess -d 996 997 998 999 -c 2d 3d_lowres 3d_fullres -np 8 8 8 # no need to preprocess cascade as its the same data as 3d_fullres # done ================================================ FILE: nnunetv2/tests/integration_tests/readme.md ================================================ # Preface I am just a mortal with many tasks and limited time. Aint nobody got time for unittests. 
HOWEVER, at least some integration tests should be performed testing nnU-Net from start to finish.

# Introduction - What the heck is happening?
This test covers all possible labeling scenarios (standard labels, regions, ignore labels and regions with ignore labels). It runs the entire nnU-Net pipeline from start to finish:

- fingerprint extraction
- experiment planning
- preprocessing
- train all 4 configurations (2d, 3d_lowres, 3d_fullres, 3d_cascade_fullres) as 5-fold CV
- automatically find the best model or ensemble
- determine the postprocessing used for this
- predict some test set
- apply postprocessing to the test set

To speed things up, we do the following:

- pick Dataset004_Hippocampus because it is quadratisch, praktisch, gut ("square, practical, good"). The MNIST of medical image segmentation
- by default this dataset does not have 3d_lowres or cascade. We just manually add them (cool new feature, eh?). See `add_lowres_and_cascade.py` to learn more!
- we use nnUNetTrainer_5epochs for a short training

# How to run it?

Set your pwd to be the nnunet repo folder (the one where the `nnunetv2` folder and the `setup.py` are located!)

Now generate the 4 dummy datasets (ids 996, 997, 998, 999) from dataset 4. This will crash if you don't have Dataset004!
```commandline
bash nnunetv2/tests/integration_tests/prepare_integration_tests.sh
```

Now you can run the integration test for each of the datasets:
```commandline
bash nnunetv2/tests/integration_tests/run_integration_test.sh DATASET_ID
```
Use DATASET_ID 996, 997, 998 and 999. You can run these independently on different GPUs/systems to speed things up. This will take, I dunno, like 10-30 minutes!?

Also run
```commandline
bash nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh DATASET_ID
```
to verify DDP is working (needs 2 GPUs!)

# How to check if the test was successful?
If I were not as lazy as I am, I would have programmed an automated check that verifies that Dice scores etc. are in an acceptable range.
So you need to do the following: 1) check that none of your runs crashed (duh) 2) for each run, navigate to `nnUNet_results/DATASET_NAME` and take a look at the `inference_information.json` file. Does it make sense? If so: NICE! Once the integration test is completed you can delete all the temporary files associated with it by running: ```commandline python nnunetv2/tests/integration_tests/cleanup_integration_test.py ``` ================================================ FILE: nnunetv2/tests/integration_tests/run_integration_test.sh ================================================ nnUNetv2_train $1 3d_fullres 0 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 3d_fullres 1 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 3d_fullres 2 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 3d_fullres 3 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 3d_fullres 4 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 2d 0 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 2d 1 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 2d 2 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 2d 3 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 2d 4 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 3d_lowres 0 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 3d_lowres 1 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 3d_lowres 2 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 3d_lowres 3 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 3d_lowres 4 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 3d_cascade_fullres 0 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 3d_cascade_fullres 1 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 3d_cascade_fullres 2 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 3d_cascade_fullres 3 -tr nnUNetTrainer_5epochs --npz nnUNetv2_train $1 3d_cascade_fullres 4 -tr nnUNetTrainer_5epochs --npz python nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py -d $1 ================================================ FILE: 
nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py
================================================
import argparse

import torch
from batchgenerators.utilities.file_and_folder_operations import join, load_pickle

from nnunetv2.ensembling.ensemble import ensemble_folders
from nnunetv2.evaluation.find_best_configuration import find_best_configuration, \
    dumb_trainer_config_plans_to_trained_models_dict
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.paths import nnUNet_raw, nnUNet_results
from nnunetv2.postprocessing.remove_connected_components import apply_postprocessing_to_folder
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.file_path_utilities import get_output_folder

if __name__ == '__main__':
    """
    Predicts the imagesTs folder with the best configuration and applies postprocessing
    """
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)

    parser = argparse.ArgumentParser()
    parser.add_argument('-d', type=int, help='dataset id')
    args = parser.parse_args()
    d = args.d

    dataset_name = maybe_convert_to_dataset_name(d)
    source_dir = join(nnUNet_raw, dataset_name, 'imagesTs')
    target_dir_base = join(nnUNet_results, dataset_name)

    models = dumb_trainer_config_plans_to_trained_models_dict(['nnUNetTrainer_5epochs'],
                                                              ['2d', '3d_lowres', '3d_cascade_fullres', '3d_fullres'],
                                                              ['nnUNetPlans'])
    ret = find_best_configuration(d, models, allow_ensembling=True, num_processes=8, overwrite=True,
                                  folds=(0, 1, 2, 3, 4), strict=True)

    has_ensemble = len(ret['best_model_or_ensemble']['selected_model_or_models']) > 1

    # we don't use all folds to speed stuff up
    used_folds = (0, 3)
    output_folders = []
    for im in ret['best_model_or_ensemble']['selected_model_or_models']:
        output_dir = join(target_dir_base, f"pred_{im['configuration']}")
        model_folder = get_output_folder(d, im['trainer'], im['plans_identifier'], im['configuration'])
        # note that if the best model is the ensemble of 3d_lowres and 3d cascade then 3d_lowres will be predicted
        # twice (once standalone and once to generate the predictions for the cascade) because we don't reuse the
        # prediction here. Proper way would be to check for that and
        # then give the output of 3d_lowres inference to the folder_with_segs_from_prev_stage kwarg in
        # predict_from_raw_data. Since we allow for
        # dynamically setting 'previous_stage' in the plans I am too lazy to implement this here. This is just an
        # integration test after all. Take a closer look at how this is handled in predict_from_raw_data
        predictor = nnUNetPredictor(verbose=False, allow_tqdm=False)
        predictor.initialize_from_trained_model_folder(model_folder, used_folds)
        predictor.predict_from_files(source_dir, output_dir, has_ensemble, overwrite=True)
        # predict_from_raw_data(list_of_lists_or_source_folder=source_dir, output_folder=output_dir,
        #                       model_training_output_dir=model_folder, use_folds=used_folds,
        #                       save_probabilities=has_ensemble, verbose=False, overwrite=True)
        output_folders.append(output_dir)

    # if we have an ensemble, we need to ensemble the results
    if has_ensemble:
        ensemble_folders(output_folders, join(target_dir_base, 'ensemble_predictions'), save_merged_probabilities=False)
        folder_for_pp = join(target_dir_base, 'ensemble_predictions')
    else:
        folder_for_pp = output_folders[0]

    # apply postprocessing
    pp_fns, pp_fn_kwargs = load_pickle(ret['best_model_or_ensemble']['postprocessing_file'])
    apply_postprocessing_to_folder(folder_for_pp, join(target_dir_base, 'ensemble_predictions_postprocessed'),
                                   pp_fns, pp_fn_kwargs,
                                   plans_file_or_dict=ret['best_model_or_ensemble']['some_plans_file'])

================================================
FILE: nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh
================================================
nnUNetv2_train $1 3d_fullres 0 -tr nnUNetTrainer_10epochs -num_gpus 2

================================================
FILE: 
nnunetv2/tests/integration_tests/run_nnunet_inference.py
================================================
import os
import shutil
import subprocess
from pathlib import Path

import nibabel as nib
import numpy as np


def dice_score(y_true, y_pred):
    intersect = np.sum(y_true * y_pred)
    denominator = np.sum(y_true) + np.sum(y_pred)
    f1 = (2 * intersect) / (denominator + 1e-6)
    return f1


def run_tests_and_exit_on_failure():
    """
    Runs inference of a simple nnU-Net for CT body segmentation on a small example CT image
    and checks if the output is correct.
    """
    # Set nnUNet_results env var
    weights_dir = Path.home() / "github_actions_nnunet" / "results"
    os.environ["nnUNet_results"] = str(weights_dir)

    # Copy example file
    os.makedirs("nnunetv2/tests/github_actions_output", exist_ok=True)
    shutil.copy("nnunetv2/tests/example_data/example_ct_sm.nii.gz", "nnunetv2/tests/github_actions_output/example_ct_sm_0000.nii.gz")

    # Run nnunet
    subprocess.call(f"nnUNetv2_predict -i nnunetv2/tests/github_actions_output -o nnunetv2/tests/github_actions_output -d 300 -tr nnUNetTrainer -c 3d_fullres -f 0 -device cpu", shell=True)

    # Check if the nnunet segmentation is correct
    img_gt = nib.load(f"nnunetv2/tests/example_data/example_ct_sm_T300_output.nii.gz").get_fdata()
    img_pred = nib.load(f"nnunetv2/tests/github_actions_output/example_ct_sm.nii.gz").get_fdata()
    dice = dice_score(img_gt, img_pred)
    images_equal = dice > 0.99  # allow for a small difference in the segmentation, otherwise the test will fail often
    assert images_equal, f"The nnunet segmentation is not correct (dice: {dice:.5f})."
# Clean up shutil.rmtree("nnunetv2/tests/github_actions_output") shutil.rmtree(Path.home() / "github_actions_nnunet") if __name__ == "__main__": run_tests_and_exit_on_failure() ================================================ FILE: nnunetv2/training/__init__.py ================================================ ================================================ FILE: nnunetv2/training/data_augmentation/__init__.py ================================================ ================================================ FILE: nnunetv2/training/data_augmentation/compute_initial_patch_size.py ================================================ import numpy as np def get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range): if isinstance(rot_x, (tuple, list)): rot_x = max(np.abs(rot_x)) if isinstance(rot_y, (tuple, list)): rot_y = max(np.abs(rot_y)) if isinstance(rot_z, (tuple, list)): rot_z = max(np.abs(rot_z)) rot_x = min(90 / 360 * 2. * np.pi, rot_x) rot_y = min(90 / 360 * 2. * np.pi, rot_y) rot_z = min(90 / 360 * 2. 
* np.pi, rot_z) from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d coords = np.array(final_patch_size) final_shape = np.copy(coords) if len(coords) == 3: final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, rot_x, 0, 0)), final_shape)), 0) final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, rot_y, 0)), final_shape)), 0) final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, 0, rot_z)), final_shape)), 0) elif len(coords) == 2: final_shape = np.max(np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0) final_shape /= min(scale_range) return final_shape.astype(int) ================================================ FILE: nnunetv2/training/data_augmentation/custom_transforms/__init__.py ================================================ ================================================ FILE: nnunetv2/training/data_augmentation/custom_transforms/cascade_transforms.py ================================================ from typing import Union, List, Tuple, Callable import numpy as np from acvl_utils.morphology.morphology_helper import label_with_component_sizes from batchgenerators.transforms.abstract_transforms import AbstractTransform from skimage.morphology import ball from skimage.morphology.binary import binary_erosion, binary_dilation, binary_closing, binary_opening class MoveSegAsOneHotToData(AbstractTransform): def __init__(self, index_in_origin: int, all_labels: Union[Tuple[int, ...], List[int]], key_origin="seg", key_target="data", remove_from_origin=True): """ Takes data_dict[seg][:, index_in_origin], converts it to one hot encoding and appends it to data_dict[key_target]. Optionally removes index_in_origin from data_dict[seg]. 
""" self.remove_from_origin = remove_from_origin self.all_labels = all_labels self.key_target = key_target self.key_origin = key_origin self.index_in_origin = index_in_origin def __call__(self, **data_dict): seg = data_dict[self.key_origin][:, self.index_in_origin:self.index_in_origin+1] seg_onehot = np.zeros((seg.shape[0], len(self.all_labels), *seg.shape[2:]), dtype=data_dict[self.key_target].dtype) for i, l in enumerate(self.all_labels): seg_onehot[:, i][seg[:, 0] == l] = 1 data_dict[self.key_target] = np.concatenate((data_dict[self.key_target], seg_onehot), 1) if self.remove_from_origin: remaining_channels = [i for i in range(data_dict[self.key_origin].shape[1]) if i != self.index_in_origin] data_dict[self.key_origin] = data_dict[self.key_origin][:, remaining_channels] return data_dict class RemoveRandomConnectedComponentFromOneHotEncodingTransform(AbstractTransform): def __init__(self, channel_idx: Union[int, List[int]], key: str = "data", p_per_sample: float = 0.2, fill_with_other_class_p: float = 0.25, dont_do_if_covers_more_than_x_percent: float = 0.25, p_per_label: float = 1): """ Randomly removes connected components in the specified channel_idx of data_dict[key]. Only considers components smaller than dont_do_if_covers_more_than_X_percent of the sample. 
Also has the option of simulating misclassification as another class (fill_with_other_class_p) """ self.p_per_label = p_per_label self.dont_do_if_covers_more_than_x_percent = dont_do_if_covers_more_than_x_percent self.fill_with_other_class_p = fill_with_other_class_p self.p_per_sample = p_per_sample self.key = key if not isinstance(channel_idx, (list, tuple)): channel_idx = [channel_idx] self.channel_idx = channel_idx def __call__(self, **data_dict): data = data_dict.get(self.key) for b in range(data.shape[0]): if np.random.uniform() < self.p_per_sample: for c in self.channel_idx: if np.random.uniform() < self.p_per_label: # print(np.unique(data[b, c])) ## should be [0, 1] workon = data[b, c].astype(bool) if not np.any(workon): continue num_voxels = np.prod(workon.shape, dtype=np.uint64) lab, component_sizes = label_with_component_sizes(workon.astype(bool)) if len(component_sizes) > 0: valid_component_ids = [i for i, j in component_sizes.items() if j < num_voxels*self.dont_do_if_covers_more_than_x_percent] # print('RemoveRandomConnectedComponentFromOneHotEncodingTransform', c, # np.unique(data[b, c]), len(component_sizes), valid_component_ids, # len(valid_component_ids)) if len(valid_component_ids) > 0: random_component = np.random.choice(valid_component_ids) data[b, c][lab == random_component] = 0 if np.random.uniform() < self.fill_with_other_class_p: other_ch = [i for i in self.channel_idx if i != c] if len(other_ch) > 0: other_class = np.random.choice(other_ch) data[b, other_class][lab == random_component] = 1 data_dict[self.key] = data return data_dict class ApplyRandomBinaryOperatorTransform(AbstractTransform): def __init__(self, channel_idx: Union[int, List[int], Tuple[int, ...]], p_per_sample: float = 0.3, any_of_these: Tuple[Callable] = (binary_dilation, binary_erosion, binary_closing, binary_opening), key: str = "data", strel_size: Tuple[int, int] = (1, 10), p_per_label: float = 1): """ Applies random binary operations (specified by any_of_these) with 
        random ball size (radius is uniformly sampled from interval strel_size) to specified channels. Expects the
        channel_idx to correspond to a one hot encoded segmentation (see for example MoveSegAsOneHotToData)
        """
        self.p_per_label = p_per_label
        self.strel_size = strel_size
        self.key = key
        self.any_of_these = any_of_these
        self.p_per_sample = p_per_sample

        if not isinstance(channel_idx, (list, tuple)):
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx

    def __call__(self, **data_dict):
        for b in range(data_dict[self.key].shape[0]):
            if np.random.uniform() < self.p_per_sample:
                # this needs to be applied in random order to the channels
                np.random.shuffle(self.channel_idx)
                for c in self.channel_idx:
                    if np.random.uniform() < self.p_per_label:
                        operation = np.random.choice(self.any_of_these)
                        selem = ball(np.random.uniform(*self.strel_size))
                        workon = data_dict[self.key][b, c].astype(bool)
                        if not np.any(workon):
                            continue
                        # print(np.unique(workon))
                        res = operation(workon, selem).astype(data_dict[self.key].dtype)
                        # print('ApplyRandomBinaryOperatorTransform', c, operation, np.sum(workon), np.sum(res))
                        data_dict[self.key][b, c] = res

                        # if class was added, we need to remove it in ALL other channels to keep one hot encoding
                        # properties
                        other_ch = [i for i in self.channel_idx if i != c]
                        if len(other_ch) > 0:
                            was_added_mask = (res - workon) > 0
                            for oc in other_ch:
                                data_dict[self.key][b, oc][was_added_mask] = 0
                            # if class was removed, leave it at background
        return data_dict

================================================
FILE: nnunetv2/training/data_augmentation/custom_transforms/deep_supervision_donwsampling.py
================================================
from typing import Tuple, Union, List

from batchgenerators.augmentations.utils import resize_segmentation
from batchgenerators.transforms.abstract_transforms import AbstractTransform
import numpy as np


class DownsampleSegForDSTransform2(AbstractTransform):
    '''
    data_dict['output_key'] will be a list of segmentations scaled according to ds_scales
    '''
    def __init__(self, ds_scales: Union[List, Tuple],
                 order: int = 0, input_key: str = "seg",
                 output_key: str = "seg", axes: Tuple[int] = None):
        """
        Downscales data_dict[input_key] according to ds_scales. Each entry in ds_scales specifies one deep supervision
        output and its resolution relative to the original data, for example 0.25 specifies 1/4 of the original shape.
        ds_scales can also be a tuple of tuples, for example ((1, 1, 1), (0.5, 0.5, 0.5)) to specify the downsampling
        for each axis independently
        """
        self.axes = axes
        self.output_key = output_key
        self.input_key = input_key
        self.order = order
        self.ds_scales = ds_scales

    def __call__(self, **data_dict):
        if self.axes is None:
            axes = list(range(2, data_dict[self.input_key].ndim))
        else:
            axes = self.axes

        output = []
        for s in self.ds_scales:
            if not isinstance(s, (tuple, list)):
                s = [s] * len(axes)
            else:
                assert len(s) == len(axes), f'If ds_scales is a tuple for each resolution (one downsampling factor ' \
                                            f'for each axis) then the number of entries in that tuple (here ' \
                                            f'{len(s)}) must be the same as the number of axes (here {len(axes)}).'
if all([i == 1 for i in s]): output.append(data_dict[self.input_key]) else: new_shape = np.array(data_dict[self.input_key].shape).astype(float) for i, a in enumerate(axes): new_shape[a] *= s[i] new_shape = np.round(new_shape).astype(int) out_seg = np.zeros(new_shape, dtype=data_dict[self.input_key].dtype) for b in range(data_dict[self.input_key].shape[0]): for c in range(data_dict[self.input_key].shape[1]): out_seg[b, c] = resize_segmentation(data_dict[self.input_key][b, c], new_shape[2:], self.order) output.append(out_seg) data_dict[self.output_key] = output return data_dict ================================================ FILE: nnunetv2/training/data_augmentation/custom_transforms/masking.py ================================================ from typing import List from batchgenerators.transforms.abstract_transforms import AbstractTransform class MaskTransform(AbstractTransform): def __init__(self, apply_to_channels: List[int], mask_idx_in_seg: int = 0, set_outside_to: int = 0, data_key: str = "data", seg_key: str = "seg"): """ Sets everything outside the mask to 0. CAREFUL! outside is defined as < 0, not =0 (in the Mask)!!! 
""" self.apply_to_channels = apply_to_channels self.seg_key = seg_key self.data_key = data_key self.set_outside_to = set_outside_to self.mask_idx_in_seg = mask_idx_in_seg def __call__(self, **data_dict): mask = data_dict[self.seg_key][:, self.mask_idx_in_seg] < 0 for c in self.apply_to_channels: data_dict[self.data_key][:, c][mask] = self.set_outside_to return data_dict ================================================ FILE: nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py ================================================ from typing import List, Tuple, Union from batchgenerators.transforms.abstract_transforms import AbstractTransform import numpy as np class ConvertSegmentationToRegionsTransform(AbstractTransform): def __init__(self, regions: Union[List, Tuple], seg_key: str = "seg", output_key: str = "seg", seg_channel: int = 0): """ regions are tuple of tuples where each inner tuple holds the class indices that are merged into one region, example: regions= ((1, 2), (2, )) will result in 2 regions: one covering the region of labels 1&2 and the other just 2 :param regions: :param seg_key: :param output_key: """ self.seg_channel = seg_channel self.output_key = output_key self.seg_key = seg_key self.regions = regions def __call__(self, **data_dict): seg = data_dict.get(self.seg_key) if seg is not None: b, c, *shape = seg.shape region_output = np.zeros((b, len(self.regions), *shape), dtype=bool) for region_id, region_labels in enumerate(self.regions): region_output[:, region_id] |= np.isin(seg[:, self.seg_channel], region_labels) data_dict[self.output_key] = region_output.astype(np.uint8, copy=False) return data_dict ================================================ FILE: nnunetv2/training/data_augmentation/custom_transforms/transforms_for_dummy_2d.py ================================================ from typing import Tuple, Union, List from batchgenerators.transforms.abstract_transforms import AbstractTransform class 
Convert3DTo2DTransform(AbstractTransform): def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')): """ Transforms a 5D array (b, c, x, y, z) to a 4D array (b, c * x, y, z) by overloading the color channel """ self.apply_to_keys = apply_to_keys def __call__(self, **data_dict): for k in self.apply_to_keys: shp = data_dict[k].shape assert len(shp) == 5, 'This transform only works on 3D data, so expects 5D tensor (b, c, x, y, z) as input.' data_dict[k] = data_dict[k].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4])) shape_key = f'orig_shape_{k}' assert shape_key not in data_dict.keys(), f'Convert3DTo2DTransform needs to store the original shape. ' \ f'It does that using the {shape_key} key. That key is ' \ f'already taken. Bummer.' data_dict[shape_key] = shp return data_dict class Convert2DTo3DTransform(AbstractTransform): def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')): """ Reverts Convert3DTo2DTransform by transforming a 4D array (b, c * x, y, z) back to 5D (b, c, x, y, z) """ self.apply_to_keys = apply_to_keys def __call__(self, **data_dict): for k in self.apply_to_keys: shape_key = f'orig_shape_{k}' assert shape_key in data_dict.keys(), f'Did not find key {shape_key} in data_dict. Shitty. ' \ f'Convert2DTo3DTransform only works in tandem with ' \ f'Convert3DTo2DTransform and you probably forgot to add ' \ f'Convert3DTo2DTransform to your pipeline. 
(Convert3DTo2DTransform ' \ f'is where the missing key is generated)' original_shape = data_dict[shape_key] current_shape = data_dict[k].shape data_dict[k] = data_dict[k].reshape((original_shape[0], original_shape[1], original_shape[2], current_shape[-2], current_shape[-1])) return data_dict ================================================ FILE: nnunetv2/training/dataloading/__init__.py ================================================ ================================================ FILE: nnunetv2/training/dataloading/data_loader.py ================================================ import os import warnings from typing import Union, Tuple, List import numpy as np import torch from batchgenerators.dataloading.data_loader import DataLoader from batchgenerators.utilities.file_and_folder_operations import join, load_json from threadpoolctl import threadpool_limits from nnunetv2.paths import nnUNet_preprocessed from nnunetv2.training.dataloading.nnunet_dataset import nnUNetBaseDataset from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDatasetBlosc2 from nnunetv2.utilities.label_handling.label_handling import LabelManager from nnunetv2.utilities.plans_handling.plans_handler import PlansManager from acvl_utils.cropping_and_padding.bounding_boxes import crop_and_pad_nd class nnUNetDataLoader(DataLoader): def __init__(self, data: nnUNetBaseDataset, batch_size: int, patch_size: Union[List[int], Tuple[int, ...], np.ndarray], final_patch_size: Union[List[int], Tuple[int, ...], np.ndarray], label_manager: LabelManager, oversample_foreground_percent: float = 0.0, sampling_probabilities: Union[List[int], Tuple[int, ...], np.ndarray] = None, pad_sides: Union[List[int], Tuple[int, ...]] = None, probabilistic_oversampling: bool = False, transforms=None): """ If we get a 2D patch size, make it pseudo 3D and remember to remove the singleton dimension before returning the batch """ super().__init__(data, batch_size, 1, None, True, False, True, sampling_probabilities) if 
len(patch_size) == 2: final_patch_size = (1, *final_patch_size) patch_size = (1, *patch_size) self.patch_size_was_2d = True else: self.patch_size_was_2d = False # this is used by DataLoader for sampling train cases! self.indices = data.identifiers self.oversample_foreground_percent = oversample_foreground_percent self.final_patch_size = final_patch_size self.patch_size = patch_size # need_to_pad denotes by how much we need to pad the data so that if we sample a patch of size final_patch_size # (which is what the network will get) these patches will also cover the border of the images self.need_to_pad = (np.array(patch_size) - np.array(final_patch_size)).astype(int) if pad_sides is not None: if self.patch_size_was_2d: pad_sides = (0, *pad_sides) for d in range(len(self.need_to_pad)): self.need_to_pad[d] += pad_sides[d] self.num_channels = None self.pad_sides = pad_sides self.sampling_probabilities = sampling_probabilities self.annotated_classes_key = tuple([-1] + label_manager.all_labels) self.has_ignore = label_manager.has_ignore_label self.get_do_oversample = self._oversample_last_XX_percent if not probabilistic_oversampling \ else self._probabilistic_oversampling self.transforms = transforms self.data_shape, self.seg_shape = self.determine_shapes() def _oversample_last_XX_percent(self, sample_idx: int) -> bool: """ determines whether sample sample_idx in a minibatch needs to be guaranteed foreground """ return not sample_idx < round(self.batch_size * (1 - self.oversample_foreground_percent)) def _probabilistic_oversampling(self, sample_idx: int) -> bool: # print('YEAH BOIIIIII') return np.random.uniform() < self.oversample_foreground_percent def determine_shapes(self): # load one case data, seg, seg_prev, properties = self._data.load_case(self._data.identifiers[0]) num_color_channels = data.shape[0] if self.patch_size_was_2d: spatial_shape = self.final_patch_size[1:] if self.transforms is not None else self.patch_size[1:] else: spatial_shape = 
                self.final_patch_size if self.transforms is not None else self.patch_size

        data_shape = (self.batch_size, num_color_channels, *spatial_shape)
        channels_seg = seg.shape[0]
        if seg_prev is not None:
            channels_seg += 1
        seg_shape = (self.batch_size, channels_seg, *spatial_shape)
        return data_shape, seg_shape

    def get_bbox(self, data_shape: np.ndarray, force_fg: bool, class_locations: Union[dict, None],
                 overwrite_class: Union[int, Tuple[int, ...]] = None, verbose: bool = False):
        # in dataloader 2d we need to select the slice prior to this and also modify the class_locations to only have
        # locations for the given slice
        need_to_pad = self.need_to_pad.copy()
        dim = len(data_shape)

        for d in range(dim):
            # if case_all_data.shape + need_to_pad is still < patch size we need to pad more! We always pad on both
            # sides
            if need_to_pad[d] + data_shape[d] < self.patch_size[d]:
                need_to_pad[d] = self.patch_size[d] - data_shape[d]

        # we can now choose the bbox from -need_to_pad // 2 to shape - patch_size + need_to_pad // 2. Here we
        # define what the upper and lower bound can be to then sample from them with np.random.randint
        lbs = [- need_to_pad[i] // 2 for i in range(dim)]
        ubs = [data_shape[i] + need_to_pad[i] // 2 + need_to_pad[i] % 2 - self.patch_size[i] for i in range(dim)]

        # if not force_fg then we can just sample the bbox randomly from lb and ub. Else we need to make sure we get
        # at least one of the foreground classes in the patch
        if not force_fg and not self.has_ignore:
            bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)]
            # print('I want a random location')
        else:
            if not force_fg and self.has_ignore:
                selected_class = self.annotated_classes_key
                if len(class_locations[selected_class]) == 0:
                    # no annotated pixels in this case. Not good. But we can hardly skip it here
                    warnings.warn('Warning! 
No annotated pixels in image!') selected_class = None elif force_fg: assert class_locations is not None, 'if force_fg is set class_locations cannot be None' if overwrite_class is not None: assert overwrite_class in class_locations.keys(), 'desired class ("overwrite_class") does not ' \ 'have class_locations (missing key)' # this saves us a np.unique. Preprocessing already did that for all cases. Neat. # class_locations keys can also be tuple eligible_classes_or_regions = [i for i in class_locations.keys() if len(class_locations[i]) > 0] # if we have annotated_classes_key locations and other classes are present, remove the annotated_classes_key from the list # strange formulation needed to circumvent # ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() tmp = [i == self.annotated_classes_key if isinstance(i, tuple) else False for i in eligible_classes_or_regions] if any(tmp): if len(eligible_classes_or_regions) > 1: eligible_classes_or_regions.pop(np.where(tmp)[0][0]) if len(eligible_classes_or_regions) == 0: # this only happens if some image does not contain foreground voxels at all selected_class = None if verbose: print('case does not contain any foreground classes') else: # I hate myself. Future me aint gonna be happy to read this # 2022_11_25: had to read it today. Wasn't too bad selected_class = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))] if \ (overwrite_class is None or (overwrite_class not in eligible_classes_or_regions)) else overwrite_class # print(f'I want to have foreground, selected class: {selected_class}') else: raise RuntimeError('lol what!?') if selected_class is not None: voxels_of_that_class = class_locations[selected_class] selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))] # selected voxel is center voxel. Subtract half the patch size to get lower bbox voxel. 
# Make sure it is within the bounds of lb and ub # i + 1 because we have first dimension 0! bbox_lbs = [max(lbs[i], selected_voxel[i + 1] - self.patch_size[i] // 2) for i in range(dim)] else: # If the image does not contain any foreground classes, we fall back to random cropping bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)] bbox_ubs = [bbox_lbs[i] + self.patch_size[i] for i in range(dim)] return bbox_lbs, bbox_ubs def generate_train_batch(self): selected_keys = self.get_indices() # preallocate output tensors in final patch size and write transformed samples directly data_all = torch.empty(self.data_shape, dtype=torch.float32) seg_all = None with torch.no_grad(): with threadpool_limits(limits=1, user_api=None): for j, i in enumerate(selected_keys): # oversampling foreground will improve stability of model training, especially if many patches are empty # (Lung for example) force_fg = self.get_do_oversample(j) data, seg, seg_prev, properties = self._data.load_case(i) # If we are doing the cascade then the segmentation from the previous stage will already have been loaded by # self._data.load_case(i) (see nnUNetDataset.load_case) shape = data.shape[1:] bbox_lbs, bbox_ubs = self.get_bbox(shape, force_fg, properties['class_locations']) bbox = [[i, j] for i, j in zip(bbox_lbs, bbox_ubs)] data_cropped = torch.from_numpy(crop_and_pad_nd(data, bbox, 0)).float() seg_cropped = torch.from_numpy(crop_and_pad_nd(seg, bbox, -1)).to(torch.int16) if seg_prev is not None: seg_prev_cropped = torch.from_numpy(crop_and_pad_nd(seg_prev, bbox, -1)).to(torch.int16) seg_cropped = torch.cat((seg_cropped, seg_prev_cropped[None]), dim=0) if self.patch_size_was_2d: data_cropped = data_cropped[:, 0] seg_cropped = seg_cropped[:, 0] if self.transforms is not None: transformed = self.transforms(**{'image': data_cropped, 'segmentation': seg_cropped}) data_sample = transformed['image'] seg_sample = transformed['segmentation'] else: data_sample = data_cropped seg_sample = 
seg_cropped data_all[j] = data_sample if isinstance(seg_sample, list): if seg_all is None: seg_all = [torch.empty((self.batch_size, *s.shape), dtype=s.dtype) for s in seg_sample] for s_idx, s in enumerate(seg_sample): seg_all[s_idx][j] = s else: if seg_all is None: seg_all = torch.empty((self.batch_size, *seg_sample.shape), dtype=seg_sample.dtype) seg_all[j] = seg_sample return {'data': data_all, 'target': seg_all, 'keys': selected_keys} if __name__ == '__main__': folder = join(nnUNet_preprocessed, 'Dataset002_Heart', 'nnUNetPlans_3d_fullres') ds = nnUNetDatasetBlosc2(folder) # this should not load the properties! pm = PlansManager(join(folder, os.pardir, 'nnUNetPlans.json')) lm = pm.get_label_manager(load_json(join(folder, os.pardir, 'dataset.json'))) dl = nnUNetDataLoader(ds, 5, (16, 16, 16), (16, 16, 16), lm, 0.33, None, None) a = next(dl) ================================================ FILE: nnunetv2/training/dataloading/nnunet_dataset.py ================================================ import os import warnings from abc import ABC, abstractmethod from copy import deepcopy from functools import lru_cache from typing import List, Union, Type, Tuple import numpy as np import blosc2 import shutil from blosc2 import Filter, Codec from batchgenerators.utilities.file_and_folder_operations import join, load_pickle, isfile, write_pickle, subfiles from nnunetv2.configuration import default_num_processes from nnunetv2.training.dataloading.utils import unpack_dataset import math class nnUNetBaseDataset(ABC): """ Defines the interface """ def __init__(self, folder: str, identifiers: List[str] = None, folder_with_segs_from_previous_stage: str = None): super().__init__() # print('loading dataset') if identifiers is None: identifiers = self.get_identifiers(folder) identifiers.sort() self.source_folder = folder self.folder_with_segs_from_previous_stage = folder_with_segs_from_previous_stage self.identifiers = identifiers def __getitem__(self, identifier): return 
self.load_case(identifier) @abstractmethod def load_case(self, identifier): pass @staticmethod @abstractmethod def save_case( data: np.ndarray, seg: np.ndarray, properties: dict, output_filename_truncated: str ): pass @staticmethod @abstractmethod def get_identifiers(folder: str) -> List[str]: pass @staticmethod def unpack_dataset(folder: str, overwrite_existing: bool = False, num_processes: int = default_num_processes, verify: bool = True): pass class nnUNetDatasetNumpy(nnUNetBaseDataset): def load_case(self, identifier): data_npy_file = join(self.source_folder, identifier + '.npy') if not isfile(data_npy_file): data = np.load(join(self.source_folder, identifier + '.npz'))['data'] else: data = np.load(data_npy_file, mmap_mode='r') seg_npy_file = join(self.source_folder, identifier + '_seg.npy') if not isfile(seg_npy_file): seg = np.load(join(self.source_folder, identifier + '.npz'))['seg'] else: seg = np.load(seg_npy_file, mmap_mode='r') if self.folder_with_segs_from_previous_stage is not None: prev_seg_npy_file = join(self.folder_with_segs_from_previous_stage, identifier + '.npy') if isfile(prev_seg_npy_file): seg_prev = np.load(prev_seg_npy_file, 'r') else: seg_prev = np.load(join(self.folder_with_segs_from_previous_stage, identifier + '.npz'))['seg'] else: seg_prev = None properties = load_pickle(join(self.source_folder, identifier + '.pkl')) return data, seg, seg_prev, properties @staticmethod def save_case( data: np.ndarray, seg: np.ndarray, properties: dict, output_filename_truncated: str ): np.savez_compressed(output_filename_truncated + '.npz', data=data, seg=seg) write_pickle(properties, output_filename_truncated + '.pkl') @staticmethod def save_seg( seg: np.ndarray, output_filename_truncated: str ): np.savez_compressed(output_filename_truncated + '.npz', seg=seg) @staticmethod def get_identifiers(folder: str) -> List[str]: """ returns all identifiers in the preprocessed data folder """ case_identifiers = [i[:-4] for i in os.listdir(folder) if 
i.endswith("npz")] return case_identifiers @staticmethod def unpack_dataset(folder: str, overwrite_existing: bool = False, num_processes: int = default_num_processes, verify: bool = True): return unpack_dataset(folder, True, overwrite_existing, num_processes, verify) class nnUNetDatasetBlosc2(nnUNetBaseDataset): def __init__(self, folder: str, identifiers: List[str] = None, folder_with_segs_from_previous_stage: str = None): super().__init__(folder, identifiers, folder_with_segs_from_previous_stage) blosc2.set_nthreads(1) def __getitem__(self, identifier): return self.load_case(identifier) def load_case(self, identifier): dparams = { 'nthreads': 1 } data_b2nd_file = join(self.source_folder, identifier + '.b2nd') # mmap does not work with Windows -> https://github.com/MIC-DKFZ/nnUNet/issues/2723 mmap_kwargs = {} if os.name == "nt" else {'mmap_mode': 'r'} data = blosc2.open(urlpath=data_b2nd_file, mode='r', dparams=dparams, **mmap_kwargs) seg_b2nd_file = join(self.source_folder, identifier + '_seg.b2nd') seg = blosc2.open(urlpath=seg_b2nd_file, mode='r', dparams=dparams, **mmap_kwargs) if self.folder_with_segs_from_previous_stage is not None: prev_seg_b2nd_file = join(self.folder_with_segs_from_previous_stage, identifier + '.b2nd') seg_prev = blosc2.open(urlpath=prev_seg_b2nd_file, mode='r', dparams=dparams, **mmap_kwargs) else: seg_prev = None properties = load_pickle(join(self.source_folder, identifier + '.pkl')) return data, seg, seg_prev, properties @staticmethod def save_case( data: np.ndarray, seg: np.ndarray, properties: dict, output_filename_truncated: str, chunks=None, blocks=None, chunks_seg=None, blocks_seg=None, clevel: int = 8, codec=blosc2.Codec.ZSTD ): blosc2.set_nthreads(1) if chunks_seg is None: chunks_seg = chunks if blocks_seg is None: blocks_seg = blocks cparams = { 'codec': codec, # 'filters': [blosc2.Filter.SHUFFLE], # 'splitmode': blosc2.SplitMode.ALWAYS_SPLIT, 'clevel': clevel, } # print(output_filename_truncated, data.shape, seg.shape, blocks, 
        #       chunks, blocks_seg, chunks_seg, data.dtype, seg.dtype)
        blosc2.asarray(np.ascontiguousarray(data), urlpath=output_filename_truncated + '.b2nd', chunks=chunks,
                       blocks=blocks, cparams=cparams)
        blosc2.asarray(np.ascontiguousarray(seg), urlpath=output_filename_truncated + '_seg.b2nd', chunks=chunks_seg,
                       blocks=blocks_seg, cparams=cparams)
        write_pickle(properties, output_filename_truncated + '.pkl')

    @staticmethod
    def save_seg(
            seg: np.ndarray,
            output_filename_truncated: str,
            chunks_seg=None,
            blocks_seg=None
    ):
        blosc2.asarray(seg, urlpath=output_filename_truncated + '.b2nd', chunks=chunks_seg, blocks=blocks_seg)

    @staticmethod
    def get_identifiers(folder: str) -> List[str]:
        """
        returns all identifiers in the preprocessed data folder
        """
        case_identifiers = [i[:-5] for i in os.listdir(folder) if i.endswith(".b2nd") and not i.endswith("_seg.b2nd")]
        return case_identifiers

    @staticmethod
    def unpack_dataset(folder: str, overwrite_existing: bool = False,
                       num_processes: int = default_num_processes,
                       verify: bool = True):
        pass

    @staticmethod
    def comp_blosc2_params(
            image_size: Tuple[int, int, int, int],
            patch_size: Union[Tuple[int, int], Tuple[int, int, int]],
            bytes_per_pixel: int = 4,  # float32 uses 4 bytes
            l1_cache_size_per_core_in_bytes=32768,  # 1 Kibibyte (KiB) = 2^10 Byte; 32 KiB = 32768 Byte
            l3_cache_size_per_core_in_bytes=1441792,  # 1 Mebibyte (MiB) = 2^20 Byte = 1,048,576 Byte; 1.375 MiB = 1441792 Byte
            safety_factor: float = 0.8  # we don't fill the caches to the brim. 0.8 means we target 80% of the caches
    ):
        """
        Computes a recommended block and chunk size for saving arrays with blosc v2.

        Bloscv2 NDIM docs: "Remember that having a second partition means that we have better flexibility to fit the
        different partitions at the different CPU cache levels; typically the first partition (aka chunks) should
        be made to fit in L3 cache, whereas the second partition (aka blocks) should rather fit in L2/L1 caches
        (depending on whether compression ratio or speed is desired)."
        (https://www.blosc.org/posts/blosc2-ndim-intro/)
        -> We are not 100% sure how to optimize for that. For now we try to fit the uncompressed block in L1. This
        might spill over into L2, which is fine in our books.

        Note: this is optimized for nnU-Net dataloading where each read operation is done by one core. We cannot use
        threading.

        Cache default values computed based on old Intel 4110 CPU with 32K L1, 128K L2 and 1408K L3 cache per core.
        We cannot optimize further for more modern CPUs with more cache as the data will need to be read by the old
        ones as well.

        Args:
            image_size: Image size, must be 4D (c, x, y, z). For 2D images, make x=1
            patch_size: Patch size, spatial dimensions only. So (x, y) or (x, y, z)
            bytes_per_pixel: Number of bytes per element. Example: float32 -> 4 bytes
            l1_cache_size_per_core_in_bytes: The size of the L1 cache per core in Bytes.
            l3_cache_size_per_core_in_bytes: The size of the L3 cache exclusively accessible by each core. Usually
                the global size of the L3 cache divided by the number of cores.
            safety_factor: Fraction of each cache that we target. 0.8 means we target 80% of the caches.

        Returns:
            The recommended block size and the recommended chunk size, in that order.
        """
        # Fabians code is ugly, but eh

        num_channels = image_size[0]
        if len(patch_size) == 2:
            patch_size = [1, *patch_size]
        patch_size = np.array(patch_size)
        block_size = np.array((num_channels, *[2 ** (max(0, math.ceil(math.log2(i)))) for i in patch_size]))

        # shrink the block size until it fits in L1
        estimated_nbytes_block = np.prod(block_size) * bytes_per_pixel
        while estimated_nbytes_block > (l1_cache_size_per_core_in_bytes * safety_factor):
            # pick largest deviation from patch_size that is not 1
            axis_order = np.argsort(block_size[1:] / patch_size)[::-1]
            idx = 0
            picked_axis = axis_order[idx]
            # skip axes that are already at size 1 (the original condition checked this twice)
            while block_size[picked_axis + 1] == 1:
                idx += 1
                picked_axis = axis_order[idx]
            # now reduce that axis to the next lowest power of 2
            block_size[picked_axis + 1] = 2 ** (max(0, math.floor(math.log2(block_size[picked_axis + 1] - 1))))
            block_size[picked_axis + 1] = min(block_size[picked_axis + 1], image_size[picked_axis + 1])
            estimated_nbytes_block = np.prod(block_size) * bytes_per_pixel

        block_size = np.array([min(i, j) for i, j in zip(image_size, block_size)])

        # note: there is no use extending the chunk size to 3d when we have a 2d patch size!
This would unnecessarily # load data into L3 # now tile the blocks into chunks until we hit image_size or the l3 cache per core limit chunk_size = deepcopy(block_size) estimated_nbytes_chunk = np.prod(chunk_size) * bytes_per_pixel while estimated_nbytes_chunk < (l3_cache_size_per_core_in_bytes * safety_factor): if patch_size[0] == 1 and all([i == j for i, j in zip(chunk_size[2:], image_size[2:])]): break if all([i == j for i, j in zip(chunk_size, image_size)]): break # find axis that deviates from block_size the most axis_order = np.argsort(chunk_size[1:] / block_size[1:]) idx = 0 picked_axis = axis_order[idx] while chunk_size[picked_axis + 1] == image_size[picked_axis + 1] or patch_size[picked_axis] == 1: idx += 1 picked_axis = axis_order[idx] chunk_size[picked_axis + 1] += block_size[picked_axis + 1] chunk_size[picked_axis + 1] = min(chunk_size[picked_axis + 1], image_size[picked_axis + 1]) estimated_nbytes_chunk = np.prod(chunk_size) * bytes_per_pixel if np.mean([i / j for i, j in zip(chunk_size[1:], patch_size)]) > 1.5: # chunk size should not exceed patch size * 1.5 on average chunk_size[picked_axis + 1] -= block_size[picked_axis + 1] break # better safe than sorry chunk_size = [min(i, j) for i, j in zip(image_size, chunk_size)] # print(image_size, chunk_size, block_size) return tuple(block_size), tuple(chunk_size) file_ending_dataset_mapping = { 'npz': nnUNetDatasetNumpy, 'b2nd': nnUNetDatasetBlosc2 } def infer_dataset_class(folder: str) -> Union[Type[nnUNetDatasetBlosc2], Type[nnUNetDatasetNumpy]]: file_endings = set([os.path.basename(i).split('.')[-1] for i in subfiles(folder, join=False)]) if 'pkl' in file_endings: file_endings.remove('pkl') if 'npy' in file_endings: file_endings.remove('npy') assert len(file_endings) == 1, (f'Found more than one file ending in the folder {folder}. 
' f'Unable to infer nnUNetDataset variant!')
    return file_ending_dataset_mapping[list(file_endings)[0]]


================================================
FILE: nnunetv2/training/dataloading/utils.py
================================================
from __future__ import annotations
import multiprocessing
import os
from typing import List
from pathlib import Path
from warnings import warn

import numpy as np
from batchgenerators.utilities.file_and_folder_operations import isfile, subfiles
from nnunetv2.configuration import default_num_processes


def _convert_to_npy(npz_file: str, unpack_segmentation: bool = True, overwrite_existing: bool = False,
                    verify_npy: bool = False, fail_ctr: int = 0) -> None:
    data_npy = npz_file[:-3] + "npy"
    seg_npy = npz_file[:-4] + "_seg.npy"
    try:
        npz_content = None  # will only be opened on demand

        if overwrite_existing or not isfile(data_npy):
            try:
                npz_content = np.load(npz_file) if npz_content is None else npz_content
            except Exception as e:
                print(f"Unable to open preprocessed file {npz_file}. Rerun nnUNetv2_preprocess!")
                raise e
            np.save(data_npy, npz_content['data'])

        if unpack_segmentation and (overwrite_existing or not isfile(seg_npy)):
            try:
                npz_content = np.load(npz_file) if npz_content is None else npz_content
            except Exception as e:
                print(f"Unable to open preprocessed file {npz_file}. Rerun nnUNetv2_preprocess!")
                raise e
            np.save(seg_npy, npz_content['seg'])

        if verify_npy:
            try:
                np.load(data_npy, mmap_mode='r')
                if isfile(seg_npy):
                    np.load(seg_npy, mmap_mode='r')
            except ValueError:
                # remove the broken files before retrying. The seg file only exists if it was unpacked
                os.remove(data_npy)
                if isfile(seg_npy):
                    os.remove(seg_npy)
                print(f"Error when checking {data_npy} and {seg_npy}, fixing...")
                if fail_ctr < 2:
                    _convert_to_npy(npz_file, unpack_segmentation, overwrite_existing, verify_npy, fail_ctr+1)
                else:
                    raise RuntimeError("Unable to fix unpacking. 
                                       Please check your system or rerun nnUNetv2_preprocess")

    except KeyboardInterrupt:
        # clean up partially written files, then re-raise the original interrupt
        if isfile(data_npy):
            os.remove(data_npy)
        if isfile(seg_npy):
            os.remove(seg_npy)
        raise


def unpack_dataset(folder: str, unpack_segmentation: bool = True, overwrite_existing: bool = False,
                   num_processes: int = default_num_processes,
                   verify: bool = False):
    """
    all npz files in this folder belong to the dataset, unpack them all
    """
    with multiprocessing.get_context("spawn").Pool(num_processes) as p:
        npz_files = subfiles(folder, True, None, ".npz", True)
        p.starmap(_convert_to_npy, zip(npz_files,
                                       [unpack_segmentation] * len(npz_files),
                                       [overwrite_existing] * len(npz_files),
                                       [verify] * len(npz_files))
                  )


if __name__ == '__main__':
    unpack_dataset('/media/fabian/data/nnUNet_preprocessed/Dataset002_Heart/2d')


================================================
FILE: nnunetv2/training/logging/__init__.py
================================================


================================================
FILE: nnunetv2/training/logging/nnunet_logger.py
================================================
import matplotlib
from batchgenerators.utilities.file_and_folder_operations import join

matplotlib.use('agg')
import seaborn as sns
import matplotlib.pyplot as plt
from typing import Any
from pathlib import Path
import shutil
import os

try:
    import wandb
except ImportError:
    wandb = None


def get_cluster_job_id():
    job_id = None
    if "LSB_JOBID" in os.environ:
        job_id = os.environ["LSB_JOBID"]
    if "SLURM_JOB_ID" in os.environ:
        job_id = os.environ["SLURM_JOB_ID"]
    return job_id


class MetaLogger(object):
    """A meta logger that bundles multiple loggers behind a single interface.

    The default configuration includes a local logger used for reading values,
    plotting progress, and checkpointing.
    """
    def __init__(self, output_folder, resume, verbose: bool = False):
        """Initialize the meta logger.

        Args:
            output_folder: The output folder.
            resume: Whether to resume training if possible.
verbose: Whether to enable verbose logging in the local logger. """ self.output_folder = output_folder self.resume = resume self.loggers = [] self.local_logger = LocalLogger(verbose) if self._is_logger_enabled("nnUNet_wandb_enabled"): self.loggers.append(WandbLogger(output_folder, resume)) def update_config(self, config: dict): """Add a new or update an existing experiment configuration to the logger. Args: config: Logger configuration options. """ for logger in self.loggers: logger.update_config(config) def log(self, key: str, value: Any, step: int): """Log a value for a given step. Args: key: Metric or field name. value: Value to log. step: Step index (typically epoch). """ self.local_logger.log(key, value, step) if isinstance(value, list): for i, val in enumerate(value): for logger in self.loggers: logger.log(f"{key}/class_{i+1}", val, step) else: for logger in self.loggers: logger.log(key, value, step) # handle the ema_fg_dice special case! It is automatically logged when we add a new mean_fg_dice if key == 'mean_fg_dice': new_ema_pseudo_dice = self.get_value('ema_fg_dice', step=step-1) * 0.9 + 0.1 * value \ if len(self.get_value('ema_fg_dice', step=None)) > 0 else value self.log('ema_fg_dice', new_ema_pseudo_dice, step) def log_summary(self, key: str, value: Any): """Log a summary value. These are usually values that are not logged every step but only once. This can be for example the final validation Dice. Args: key: Metric or field name. value: Value to summarize. """ for logger in self.loggers: logger.log_summary(key, value) def get_value(self, key: str, step: Any): """Fetch a logged value from the local logger. Args: key: Metric or field name. step: Step index to retrieve, or None to return all values. Returns: The logged value or list of values from the local logger. """ return self.local_logger.get_value(key, step) def plot_progress_png(self, output_folder: str): """Write a progress plot PNG using local logger data. 
        Args:
            output_folder: Directory where the plot image is saved.
        """
        self.local_logger.plot_progress_png(output_folder)

    def get_checkpoint(self):
        """Return the local logger checkpoint data.

        Returns:
            The checkpoint payload used to restore logging state.
        """
        return self.local_logger.get_checkpoint()

    def load_checkpoint(self, checkpoint: dict):
        """Restore the local logger from a checkpoint payload.

        Args:
            checkpoint: Checkpoint data returned by `get_checkpoint`.
        """
        self.local_logger.load_checkpoint(checkpoint)

    def _is_logger_enabled(self, env_var):
        env_var_result = str(os.getenv(env_var, "0"))
        if env_var_result in ("0", "False", "false"):
            return False
        elif env_var_result in ("1", "True", "true"):
            return True
        else:
            raise RuntimeError("nnU-Net logger environment variable has the wrong value. Must be '0', 'False' or "
                               "'false' (disabled) or '1', 'True' or 'true' (enabled).")


class LocalLogger:
    """
    This class is really trivial. Don't expect cool functionality here. This is my makeshift solution to problems
    arising from out-of-sync epoch numbers and numbers of logged loss values. It also simplifies the trainer class a
    little

    YOU MUST LOG EXACTLY ONE VALUE PER EPOCH FOR EACH OF THE LOGGING ITEMS! DONT FUCK IT UP
    """
    def __init__(self, verbose: bool = False):
        self.my_fantastic_logging = {
            'mean_fg_dice': list(),
            'ema_fg_dice': list(),
            'dice_per_class_or_region': list(),
            'train_losses': list(),
            'val_losses': list(),
            'lrs': list(),
            'epoch_start_timestamps': list(),
            'epoch_end_timestamps': list()
        }
        self.verbose = verbose
        # shut up, this logging is great

    def log(self, key, value, epoch: int):
        """
        sometimes shit gets messed up. 
        We try to catch that here
        """
        assert key in self.my_fantastic_logging.keys() and isinstance(self.my_fantastic_logging[key], list), \
            'This function is only intended to log stuff to lists and to have one entry per epoch'

        if self.verbose:
            print(f'logging {key}: {value} for epoch {epoch}')

        if len(self.my_fantastic_logging[key]) < (epoch + 1):
            self.my_fantastic_logging[key].append(value)
        else:
            assert len(self.my_fantastic_logging[key]) == (epoch + 1), 'something went horribly wrong. My logging ' \
                                                                       'lists length is off by more than 1'
            print(f'maybe some logging issue!? logging {key} and {value}')
            self.my_fantastic_logging[key][epoch] = value

    def get_value(self, key, step):
        if step is not None:
            return self.my_fantastic_logging[key][step]
        else:
            return self.my_fantastic_logging[key]

    def plot_progress_png(self, output_folder):
        # we infer the epoch from our internal logging
        epoch = min([len(i) for i in self.my_fantastic_logging.values()]) - 1  # lists of epoch 0 have len 1
        sns.set(font_scale=2.5)
        fig, ax_all = plt.subplots(3, 1, figsize=(30, 54))
        # regular progress.png as we are used to from previous nnU-Net versions
        ax = ax_all[0]
        ax2 = ax.twinx()
        x_values = list(range(epoch + 1))
        ax.plot(x_values, self.my_fantastic_logging['train_losses'][:epoch + 1], color='b', ls='-', label="loss_tr", linewidth=4)
        ax.plot(x_values, self.my_fantastic_logging['val_losses'][:epoch + 1], color='r', ls='-', label="loss_val", linewidth=4)
        ax2.plot(x_values, self.my_fantastic_logging['mean_fg_dice'][:epoch + 1], color='g', ls='dotted', label="pseudo dice",
                 linewidth=3)
        ax2.plot(x_values, self.my_fantastic_logging['ema_fg_dice'][:epoch + 1], color='g', ls='-', label="pseudo dice (mov. 
avg.)", linewidth=4)
        ax.set_xlabel("epoch")
        ax.set_ylabel("loss")
        ax2.set_ylabel("pseudo dice")
        ax.legend(loc=(0, 1))
        ax2.legend(loc=(0.2, 1))

        # epoch times to see whether the training speed is consistent (inconsistent means there are other jobs
        # clogging up the system)
        ax = ax_all[1]
        ax.plot(x_values, [i - j for i, j in zip(self.my_fantastic_logging['epoch_end_timestamps'][:epoch + 1],
                                                 self.my_fantastic_logging['epoch_start_timestamps'])][:epoch + 1], color='b',
                ls='-', label="epoch duration", linewidth=4)
        ylim = [0] + [ax.get_ylim()[1]]
        ax.set(ylim=ylim)
        ax.set_xlabel("epoch")
        ax.set_ylabel("time [s]")
        ax.legend(loc=(0, 1))

        # learning rate
        ax = ax_all[2]
        ax.plot(x_values, self.my_fantastic_logging['lrs'][:epoch + 1], color='b', ls='-', label="learning rate", linewidth=4)
        ax.set_xlabel("epoch")
        ax.set_ylabel("learning rate")
        ax.legend(loc=(0, 1))

        plt.tight_layout()

        fig.savefig(join(output_folder, "progress.png"))
        plt.close()

    def get_checkpoint(self):
        return self.my_fantastic_logging

    def load_checkpoint(self, checkpoint: dict):
        self.my_fantastic_logging = checkpoint


class WandbLogger:
    """Weights & Biases logger for nnU-Net training runs.

    Environment Variables:
        nnUNet_wandb_enabled: Whether W&B logger is enabled (default: 0 -> Disabled)
        nnUNet_wandb_project: W&B project name (default: "nnunet").
        nnUNet_wandb_mode: W&B mode, e.g. "online" or "offline" (default: "online").
    """
    def __init__(self, output_folder, resume):
        """Initialize a W&B run and handle resume behavior.

        Args:
            output_folder: Directory where W&B run data is stored.
            resume: Whether to resume a previous W&B run if present.
        """
        if wandb is None:
            raise RuntimeError("W&B is not installed. 
Please install W&B with 'pip install wandb' before using the WandbLogger.") self.output_folder = Path(output_folder) self.resume = resume self.project = os.getenv("nnUNet_wandb_project", "nnunet") self.mode = os.getenv("nnUNet_wandb_mode", "online") wandb_id = None if (self.output_folder / "wandb").is_dir(): if self.resume: wandb_dir = self.output_folder / "wandb" / "latest-run" wandb_filename = next(filename for filename in wandb_dir.iterdir() if filename.suffix == ".wandb") wandb_id = wandb_filename.stem[4:] else: shutil.rmtree(str(self.output_folder / "wandb")) _resume = "allow" if self.resume else "never" self.run = wandb.init(project=self.project, dir=str(self.output_folder), id=wandb_id, mode=self.mode, resume=_resume) self.run.config.update({"JobID": get_cluster_job_id()}) self.wandb_init_step = self.run.step def update_config(self, config: dict): """Update W&B config with training metadata. Args: config: Configuration values to merge into the run config. """ self.run.config.update(config) def log(self, key, value, step: int): """Log a scalar value to W&B. Args: key: Metric or field name. value: Value to log. step: Step index (typically epoch). """ self.log_summary("current_epoch", step) if self.resume and step < self.wandb_init_step: return self.run.log({key: value}, step=step) def log_summary(self, key, value): """Write a summary value to W&B. Args: key: Metric or field name. value: Summary value to store. 
""" self.run.summary[key] = value ================================================ FILE: nnunetv2/training/loss/__init__.py ================================================ ================================================ FILE: nnunetv2/training/loss/compound_losses.py ================================================ import torch from nnunetv2.training.loss.dice import SoftDiceLoss, MemoryEfficientSoftDiceLoss from nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss, TopKLoss from nnunetv2.utilities.helpers import softmax_helper_dim1 from torch import nn class DC_and_CE_loss(nn.Module): def __init__(self, soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label=None, dice_class=SoftDiceLoss): """ Weights for CE and Dice do not need to sum to one. You can set whatever you want. :param soft_dice_kwargs: :param ce_kwargs: :param aggregate: :param square_dice: :param weight_ce: :param weight_dice: """ super(DC_and_CE_loss, self).__init__() if ignore_label is not None: ce_kwargs['ignore_index'] = ignore_label self.weight_dice = weight_dice self.weight_ce = weight_ce self.ignore_label = ignore_label self.ce = RobustCrossEntropyLoss(**ce_kwargs) self.dc = dice_class(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs) def forward(self, net_output: torch.Tensor, target: torch.Tensor): """ target must be b, c, x, y(, z) with c=1 :param net_output: :param target: :return: """ if self.ignore_label is not None: assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \ '(DC_and_CE_loss)' mask = target != self.ignore_label # remove ignore label from target, replace with one of the known labels. 
It doesn't matter because we # ignore gradients in those areas anyway target_dice = torch.where(mask, target, 0) num_fg = mask.sum() else: target_dice = target mask = None dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \ if self.weight_dice != 0 else 0 ce_loss = self.ce(net_output, target[:, 0]) \ if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0 result = self.weight_ce * ce_loss + self.weight_dice * dc_loss return result class DC_and_BCE_loss(nn.Module): def __init__(self, bce_kwargs, soft_dice_kwargs, weight_ce=1, weight_dice=1, use_ignore_label: bool = False, dice_class=MemoryEfficientSoftDiceLoss): """ DO NOT APPLY NONLINEARITY IN YOUR NETWORK! target must be one hot encoded IMPORTANT: We assume the ignore label is located in target[:, -1]!!! :param soft_dice_kwargs: :param bce_kwargs: :param aggregate: """ super(DC_and_BCE_loss, self).__init__() if use_ignore_label: bce_kwargs['reduction'] = 'none' self.weight_dice = weight_dice self.weight_ce = weight_ce self.use_ignore_label = use_ignore_label self.ce = nn.BCEWithLogitsLoss(**bce_kwargs) self.dc = dice_class(apply_nonlin=torch.sigmoid, **soft_dice_kwargs) def forward(self, net_output: torch.Tensor, target: torch.Tensor): if self.use_ignore_label: # target is one hot encoded here. invert it so that it is True wherever we can compute the loss if target.dtype == torch.bool: mask = ~target[:, -1:] else: mask = (1 - target[:, -1:]).bool() # remove ignore channel now that we have the mask # why did we use clone in the past? Should have documented that...
# target_regions = torch.clone(target[:, :-1]) target_regions = target[:, :-1] else: target_regions = target mask = None dc_loss = self.dc(net_output, target_regions, loss_mask=mask) target_regions = target_regions.float() if mask is not None: ce_loss = (self.ce(net_output, target_regions) * mask).sum() / torch.clip(mask.sum(), min=1e-8) else: ce_loss = self.ce(net_output, target_regions) result = self.weight_ce * ce_loss + self.weight_dice * dc_loss return result class DC_and_topk_loss(nn.Module): def __init__(self, soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label=None): """ Weights for CE and Dice do not need to sum to one. You can set whatever you want. :param soft_dice_kwargs: :param ce_kwargs: :param aggregate: :param square_dice: :param weight_ce: :param weight_dice: """ super().__init__() if ignore_label is not None: ce_kwargs['ignore_index'] = ignore_label self.weight_dice = weight_dice self.weight_ce = weight_ce self.ignore_label = ignore_label self.ce = TopKLoss(**ce_kwargs) self.dc = SoftDiceLoss(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs) def forward(self, net_output: torch.Tensor, target: torch.Tensor): """ target must be b, c, x, y(, z) with c=1 :param net_output: :param target: :return: """ if self.ignore_label is not None: assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \ '(DC_and_topk_loss)' mask = (target != self.ignore_label).bool() # remove ignore label from target, replace with one of the known labels.
It doesn't matter because we # ignore gradients in those areas anyway target_dice = torch.clone(target) target_dice[target == self.ignore_label] = 0 num_fg = mask.sum() else: target_dice = target mask = None dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \ if self.weight_dice != 0 else 0 ce_loss = self.ce(net_output, target) \ if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0 result = self.weight_ce * ce_loss + self.weight_dice * dc_loss return result ================================================ FILE: nnunetv2/training/loss/deep_supervision.py ================================================ from torch import nn class DeepSupervisionWrapper(nn.Module): def __init__(self, loss, weight_factors=None): """ Wraps a loss function so that it can be applied to multiple outputs. Forward accepts an arbitrary number of inputs. Each input is expected to be a tuple/list. Each tuple/list must have the same length. The loss is then applied to each entry like this: l = w0 * loss(input0[0], input1[0], ...) + w1 * loss(input0[1], input1[1], ...) + ... If weights are None, all w will be 1. """ super(DeepSupervisionWrapper, self).__init__() # weight_factors=None is allowed and means that every output gets weight 1 (see forward) assert weight_factors is None or any([x != 0 for x in weight_factors]), "At least one weight factor should be != 0.0" self.weight_factors = tuple(weight_factors) if weight_factors is not None else None self.loss = loss def forward(self, *args): assert all([isinstance(i, (tuple, list)) for i in args]), \ f"all args must be either tuple or list, got {[type(i) for i in args]}" # we could check for equal lengths here as well, but we really shouldn't overdo it with checks because # this code is executed a lot of times!
if self.weight_factors is None: weights = (1, ) * len(args[0]) else: weights = self.weight_factors return sum([weights[i] * self.loss(*inputs) for i, inputs in enumerate(zip(*args)) if weights[i] != 0.0]) ================================================ FILE: nnunetv2/training/loss/dice.py ================================================ from typing import Callable import torch from nnunetv2.utilities.ddp_allgather import AllGatherGrad from torch import nn class SoftDiceLoss(nn.Module): def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1., ddp: bool = True, clip_tp: float = None): """ """ super(SoftDiceLoss, self).__init__() self.do_bg = do_bg self.batch_dice = batch_dice self.apply_nonlin = apply_nonlin self.smooth = smooth self.clip_tp = clip_tp self.ddp = ddp def forward(self, x, y, loss_mask=None): shp_x = x.shape if self.batch_dice: axes = [0] + list(range(2, len(shp_x))) else: axes = list(range(2, len(shp_x))) if self.apply_nonlin is not None: x = self.apply_nonlin(x) tp, fp, fn, _ = get_tp_fp_fn_tn(x, y, axes, loss_mask, False) if self.ddp and self.batch_dice: tp = AllGatherGrad.apply(tp).sum(0, dtype=torch.float32) fp = AllGatherGrad.apply(fp).sum(0, dtype=torch.float32) fn = AllGatherGrad.apply(fn).sum(0, dtype=torch.float32) if self.clip_tp is not None: tp = torch.clip(tp, min=self.clip_tp , max=None) nominator = 2 * tp denominator = 2 * tp + fp + fn dc = (nominator + self.smooth) / (torch.clip(denominator + self.smooth, 1e-8)) if not self.do_bg: if self.batch_dice: dc = dc[1:] else: dc = dc[:, 1:] dc = dc.mean() return -dc class MemoryEfficientSoftDiceLoss(nn.Module): def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1., ddp: bool = True): """ saves 1.6 GB on Dataset017 3d_lowres """ super(MemoryEfficientSoftDiceLoss, self).__init__() self.do_bg = do_bg self.batch_dice = batch_dice self.apply_nonlin = apply_nonlin self.smooth = 
smooth self.ddp = ddp def forward(self, x, y, loss_mask=None): if self.apply_nonlin is not None: x = self.apply_nonlin(x) # make everything shape (b, c) axes = tuple(range(2, x.ndim)) with torch.no_grad(): if x.ndim != y.ndim: y = y.view((y.shape[0], 1, *y.shape[1:])) if x.shape == y.shape: # if this is the case then gt is probably already a one hot encoding y_onehot = y.to(torch.float32) else: y_onehot = torch.zeros(x.shape, device=x.device, dtype=torch.float32) y_onehot.scatter_(1, y.long(), 1) if not self.do_bg: y_onehot = y_onehot[:, 1:] sum_gt = y_onehot.sum(axes, dtype=torch.float32) if loss_mask is None else (y_onehot * loss_mask).sum(axes, dtype=torch.float32) # this one MUST be outside the with torch.no_grad(): context. Otherwise no gradients for you if not self.do_bg: x = x[:, 1:] if loss_mask is None: intersect = (x * y_onehot).sum(axes, dtype=torch.float32) sum_pred = x.sum(axes, dtype=torch.float32) else: intersect = (x * y_onehot * loss_mask).sum(axes, dtype=torch.float32) sum_pred = (x * loss_mask).sum(axes, dtype=torch.float32) if self.batch_dice: if self.ddp: intersect = AllGatherGrad.apply(intersect).sum(0, dtype=torch.float32) sum_pred = AllGatherGrad.apply(sum_pred).sum(0, dtype=torch.float32) sum_gt = AllGatherGrad.apply(sum_gt).sum(0, dtype=torch.float32) intersect = intersect.sum(0, dtype=torch.float32) sum_pred = sum_pred.sum(0, dtype=torch.float32) sum_gt = sum_gt.sum(0, dtype=torch.float32) dc = (2 * intersect + self.smooth) / (sum_gt + sum_pred + float(self.smooth)).clamp_min(1e-8) dc = dc.mean() return -dc def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False): """ net_output must be (b, c, x, y(, z))) gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z)) if mask is provided it must have shape (b, 1, x, y(, z))) :param net_output: :param gt: :param axes: can be (, ) = no summation :param mask: mask must be 1 for valid pixels and 0 for invalid pixels :param square: 
if True then fp, tp and fn will be squared before summation :return: """ if axes is None: axes = tuple(range(2, net_output.ndim)) with torch.no_grad(): if net_output.ndim != gt.ndim: gt = gt.view((gt.shape[0], 1, *gt.shape[1:])) if net_output.shape == gt.shape: # if this is the case then gt is probably already a one hot encoding y_onehot = gt.to(torch.float32) else: y_onehot = torch.zeros(net_output.shape, device=net_output.device, dtype=torch.float32) y_onehot.scatter_(1, gt.long(), 1) tp = net_output * y_onehot fp = net_output * (1 - y_onehot) fn = (1 - net_output) * y_onehot tn = (1 - net_output) * (1 - y_onehot) if mask is not None: with torch.no_grad(): mask_here = torch.tile(mask, (1, tp.shape[1], *[1 for _ in range(2, tp.ndim)])) tp *= mask_here fp *= mask_here fn *= mask_here tn *= mask_here # benchmark whether tiling the mask would be faster (torch.tile). It probably is for large batch sizes # OK it barely makes a difference but the implementation above is a tiny bit faster + uses less vram # (using nnUNetv2_train 998 3d_fullres 0) # tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1) # fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1) # fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1) # tn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tn, dim=1)), dim=1) if square: tp = tp ** 2 fp = fp ** 2 fn = fn ** 2 tn = tn ** 2 if len(axes) > 0: tp = tp.sum(dim=axes, keepdim=False, dtype=torch.float32) fp = fp.sum(dim=axes, keepdim=False, dtype=torch.float32) fn = fn.sum(dim=axes, keepdim=False, dtype=torch.float32) tn = tn.sum(dim=axes, keepdim=False, dtype=torch.float32) return tp, fp, fn, tn if __name__ == '__main__': from nnunetv2.utilities.helpers import softmax_helper_dim1 pred = torch.rand((2, 3, 32, 32, 32)) ref = torch.randint(0, 3, (2, 32, 32, 32)) dl_old = SoftDiceLoss(apply_nonlin=softmax_helper_dim1, batch_dice=True, do_bg=False, smooth=0, 
ddp=False) dl_new = MemoryEfficientSoftDiceLoss(apply_nonlin=softmax_helper_dim1, batch_dice=True, do_bg=False, smooth=0, ddp=False) res_old = dl_old(pred, ref) res_new = dl_new(pred, ref) print(res_old, res_new) ================================================ FILE: nnunetv2/training/loss/robust_ce_loss.py ================================================ import torch from torch import nn, Tensor import numpy as np class RobustCrossEntropyLoss(nn.CrossEntropyLoss): """ this is just a compatibility layer because my target tensor is float and has an extra dimension input must be logits, not probabilities! """ def forward(self, input: Tensor, target: Tensor) -> Tensor: if target.ndim == input.ndim: assert target.shape[1] == 1 target = target[:, 0] return super().forward(input, target.long()) class TopKLoss(RobustCrossEntropyLoss): """ input must be logits, not probabilities! """ def __init__(self, weight=None, ignore_index: int = -100, k: float = 10, label_smoothing: float = 0): self.k = k super(TopKLoss, self).__init__(weight, False, ignore_index, reduce=False, label_smoothing=label_smoothing) def forward(self, inp, target): target = target[:, 0].long() res = super(TopKLoss, self).forward(inp, target) num_voxels = np.prod(res.shape, dtype=np.int64) res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False) return res.mean() ================================================ FILE: nnunetv2/training/lr_scheduler/__init__.py ================================================ ================================================ FILE: nnunetv2/training/lr_scheduler/polylr.py ================================================ from torch.optim.lr_scheduler import _LRScheduler class PolyLRScheduler(_LRScheduler): def __init__(self, optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9, current_step: int = None): self.optimizer = optimizer self.initial_lr = initial_lr self.max_steps = max_steps self.exponent = exponent self.ctr = 0 
super().__init__(optimizer, current_step if current_step is not None else -1) def step(self, current_step=None): if current_step is None or current_step == -1: current_step = self.ctr self.ctr += 1 new_lr = self.initial_lr * (1 - current_step / self.max_steps) ** self.exponent for param_group in self.optimizer.param_groups: param_group['lr'] = new_lr self._last_lr = [group['lr'] for group in self.optimizer.param_groups] def get_last_lr(self): return self._last_lr ================================================ FILE: nnunetv2/training/lr_scheduler/warmup.py ================================================ import math import warnings from typing import Optional, cast, List from torch import Tensor from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR, _enable_get_lr_call class Lin_incr_LRScheduler(_LRScheduler): def __init__(self, optimizer, max_lr: float, max_steps: int, current_step: int = None): self.optimizer = optimizer self.max_lr = max_lr self.max_steps = max_steps self.ctr = 0 super().__init__(optimizer, current_step if current_step is not None else -1) def step(self, current_step=None): if current_step is None or current_step == -1: current_step = self.ctr self.ctr += 1 new_lr = self.max_lr / self.max_steps * (1 + current_step) for param_group in self.optimizer.param_groups: param_group["lr"] = new_lr class Lin_incr_offset_LRScheduler(_LRScheduler): def __init__(self, optimizer, max_lr: float, max_steps: int, start_step: int, current_step: int = None): self.optimizer = optimizer self.max_lr = max_lr self.max_steps = max_steps self.start_step = start_step self.ctr = 0 super().__init__(optimizer, current_step if current_step is not None else -1) def step(self, current_step=None): if current_step is None or current_step == -1: current_step = self.ctr self.ctr += 1 new_lr = self.max_lr / self.max_steps * (1 + current_step - self.start_step) for param_group in self.optimizer.param_groups: param_group["lr"] = 
new_lr class PolyLRScheduler_offset(_LRScheduler): def __init__( self, optimizer, initial_lr: float, max_steps: int, start_step: int, exponent: float = 0.9, current_step: int = None, ): self.optimizer = optimizer self.initial_lr = initial_lr self.max_steps = max_steps - start_step self.start_step = start_step self.exponent = exponent self.ctr = 0 super().__init__(optimizer, current_step if current_step is not None else -1) def step(self, current_step=None): if current_step is None or current_step == -1: current_step = self.ctr self.ctr += 1 current_step = current_step - self.start_step if current_step <= 0: current_step = 0 new_lr = self.initial_lr * (1 - current_step / self.max_steps) ** self.exponent for param_group in self.optimizer.param_groups: param_group["lr"] = new_lr class CosineAnnealingLR_offset(CosineAnnealingLR): def __init__( self, optimizer: Optimizer, T_max: int, eta_min=0, last_epoch=-1, verbose="deprecated", offset: int = 0 ): self.offset = offset super().__init__( optimizer, T_max, eta_min, last_epoch, verbose, ) def _get_closed_form_lr(self): return [ self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.last_epoch - self.offset) / (self.T_max - self.offset))) / 2 for base_lr in self.base_lrs ] def step(self, epoch: Optional[int] = None): # Raise a warning if old pattern is detected # https://github.com/pytorch/pytorch/issues/20124 if self._step_count == 1: if not hasattr(self.optimizer.step, "_wrapped_by_lr_sched"): warnings.warn( "Seems like `optimizer.step()` has been overridden after learning rate scheduler " "initialization. Please, make sure to call `optimizer.step()` before " "`lr_scheduler.step()`. 
See more details at " "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning, ) # Just check if there were two first lr_scheduler.step() calls before optimizer.step() elif not getattr(self.optimizer, "_opt_called", False): warnings.warn( "Detected call of `lr_scheduler.step()` before `optimizer.step()`. " "In PyTorch 1.1.0 and later, you should call them in the opposite order: " "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " "will result in PyTorch skipping the first value of the learning rate schedule. " "See more details at " "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning, ) self._step_count += 1 with _enable_get_lr_call(self): if epoch is None: self.last_epoch += 1 else: self.last_epoch = epoch values = cast(List[float], self._get_closed_form_lr()) for i, data in enumerate(zip(self.optimizer.param_groups, values)): param_group, lr = data if isinstance(param_group["lr"], Tensor): lr_val = lr.item() if isinstance(lr, Tensor) else lr # type: ignore[attr-defined] param_group["lr"].fill_(lr_val) else: param_group["lr"] = lr self._last_lr: List[float] = [group["lr"] for group in self.optimizer.param_groups] ================================================ FILE: nnunetv2/training/nnUNetTrainer/__init__.py ================================================ ================================================ FILE: nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py ================================================ import inspect import multiprocessing import os import shutil import sys import warnings from copy import deepcopy from datetime import datetime from time import time, sleep from typing import Tuple, Union, List import numpy as np import torch from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter from 
batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile, save_json, maybe_mkdir_p from batchgeneratorsv2.helpers.scalar_type import RandomScalar from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform from batchgeneratorsv2.transforms.intensity.contrast import ContrastTransform, BGContrast from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform from batchgeneratorsv2.transforms.nnunet.random_binary_operator import ApplyRandomBinaryOperatorTransform from batchgeneratorsv2.transforms.nnunet.remove_connected_components import \ RemoveRandomConnectedComponentFromOneHotEncodingTransform from batchgeneratorsv2.transforms.nnunet.seg_to_onehot import MoveSegAsOneHotToDataTransform from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform from batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform from batchgeneratorsv2.transforms.utils.pseudo2d import Convert3DTo2DTransform, Convert2DTo3DTransform from batchgeneratorsv2.transforms.utils.random import RandomTransform from batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform from batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform from torch import autocast, nn from 
torch import distributed as dist from torch._dynamo import OptimizedModule from torch.cuda import device_count from torch import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from nnunetv2.configuration import ANISO_THRESHOLD, default_num_processes from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder from nnunetv2.inference.export_prediction import export_prediction_from_logits, resample_and_save from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor from nnunetv2.inference.sliding_window_prediction import compute_gaussian from nnunetv2.paths import nnUNet_preprocessed, nnUNet_results from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size from nnunetv2.training.dataloading.nnunet_dataset import infer_dataset_class from nnunetv2.training.dataloading.data_loader import nnUNetDataLoader from nnunetv2.training.logging.nnunet_logger import MetaLogger from nnunetv2.training.loss.compound_losses import DC_and_CE_loss, DC_and_BCE_loss from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler from nnunetv2.utilities.collate_outputs import collate_outputs from nnunetv2.utilities.crossval_split import generate_crossval_split from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy from nnunetv2.utilities.get_network_from_plans import get_network_from_plans from nnunetv2.utilities.helpers import empty_cache, dummy_context from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels from nnunetv2.utilities.plans_handling.plans_handler import PlansManager class nnUNetTrainer(object): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, 
device: torch.device = torch.device('cuda')): # From https://grugbrain.dev/. Worth a read ya big brains ;-) # apex predator of grug is complexity # complexity bad # say again: # complexity very bad # you say now: # complexity very, very bad # given choice between complexity or one on one against t-rex, grug take t-rex: at least grug see t-rex # complexity is spirit demon that enter codebase through well-meaning but ultimately very clubbable non grug-brain developers and project managers who not fear complexity spirit demon or even know about sometime # one day code base understandable and grug can get work done, everything good! # next day impossible: complexity demon spirit has entered code and very dangerous situation! # OK OK I am guilty. But I tried. # https://www.osnews.com/images/comics/wtfm.jpg # https://i.pinimg.com/originals/26/b2/50/26b250a738ea4abc7a5af4d42ad93af0.jpg self.is_ddp = dist.is_available() and dist.is_initialized() self.local_rank = 0 if not self.is_ddp else dist.get_rank() self.device = device # print what device we are using if self.is_ddp: # implicitly it's clear that we use cuda in this case print(f"I am local rank {self.local_rank}. {device_count()} GPUs are available. The world size is " f"{dist.get_world_size()}. " f"Setting device to {self.device}") self.device = torch.device(type='cuda', index=self.local_rank) else: if self.device.type == 'cuda': # we might want to let the user pick this but for now please pick the correct GPU with CUDA_VISIBLE_DEVICES=X self.device = torch.device(type='cuda', index=0) print(f"Using device: {self.device}") # loading and saving this class for continuing from checkpoint should not happen based on pickling. This # would also pickle the network etc. Bad, bad. Instead we just reinstantiate and then load the checkpoint we # need.
So let's save the init args self.my_init_kwargs = {} for k in inspect.signature(self.__init__).parameters.keys(): self.my_init_kwargs[k] = locals()[k] ### Saving all the init args into class variables for later access continue_training = plans.pop("continue_training") logger_config = {"plans": plans, "configuration": configuration, "fold": fold, "dataset": dataset_json} self.plans_manager = PlansManager(plans) self.configuration_manager = self.plans_manager.get_configuration(configuration) self.configuration_name = configuration self.dataset_json = dataset_json self.fold = fold ### Setting all the folder names. We need to make sure things don't crash in case we are just running # inference and some of the folders may not be defined! self.preprocessed_dataset_folder_base = join(nnUNet_preprocessed, self.plans_manager.dataset_name) \ if nnUNet_preprocessed is not None else None self.output_folder_base = join(nnUNet_results, self.plans_manager.dataset_name, self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + configuration) \ if nnUNet_results is not None else None self.output_folder = join(self.output_folder_base, f'fold_{fold}') self.preprocessed_dataset_folder = join(self.preprocessed_dataset_folder_base, self.configuration_manager.data_identifier) self.dataset_class = None # -> initialize # unlike the previous nnunet folder_with_segs_from_previous_stage is now part of the plans. For now it has to # be a different configuration in the same plans # IMPORTANT! the mapping must be bijective, so lowres must point to fullres and vice versa (using # "previous_stage" and "next_stage"). Otherwise it won't work! 
self.is_cascaded = self.configuration_manager.previous_stage_name is not None self.folder_with_segs_from_previous_stage = \ join(nnUNet_results, self.plans_manager.dataset_name, self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + self.configuration_manager.previous_stage_name, 'predicted_next_stage', self.configuration_name) \ if self.is_cascaded else None ### Some hyperparameters for you to fiddle with self.initial_lr = 1e-2 self.weight_decay = 3e-5 self.oversample_foreground_percent = 0.33 self.probabilistic_oversampling = False self.num_iterations_per_epoch = 250 self.num_val_iterations_per_epoch = 50 self.num_epochs = 1000 self.current_epoch = 0 self.enable_deep_supervision = True ### Dealing with labels/regions self.label_manager = self.plans_manager.get_label_manager(dataset_json) # labels can either be a list of int (regular training) or a list of tuples of int (region-based training) # needed for predictions. We do sigmoid in case of (overlapping) regions self.num_input_channels = None # -> self.initialize() self.network = None # -> self.build_network_architecture() self.optimizer = self.lr_scheduler = None # -> self.initialize self.grad_scaler = GradScaler("cuda") if self.device.type == 'cuda' else None self.loss = None # -> self.initialize ### Simple logging. Don't take that away from me! # initialize log file. This is just our log for the print statements etc. 
Not to be confused with lightning # logging timestamp = datetime.now() maybe_mkdir_p(self.output_folder) self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" % (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute, timestamp.second)) self.logger = MetaLogger(self.output_folder, continue_training) self.logger.update_config(logger_config) ### placeholders self.dataloader_train = self.dataloader_val = None # see on_train_start ### initializing stuff for remembering things and such self._best_ema = None ### inference things self.inference_allowed_mirroring_axes = None # this variable is set in # self.configure_rotation_dummyDA_mirroring_and_inital_patch_size and will be saved in checkpoints ### checkpoint saving stuff self.save_every = 50 self.disable_checkpointing = False self.was_initialized = False self.print_to_log_file("\n#######################################################################\n" "Please cite the following paper when using nnU-Net:\n" "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). " "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. 
" "Nature methods, 18(2), 203-211.\n" "#######################################################################\n", also_print_to_console=True, add_timestamp=False) def initialize(self): if not self.was_initialized: ## DDP batch size and oversampling can differ between workers and needs adaptation # we need to change the batch size in DDP because we don't use any of those distributed samplers self._set_batch_size_and_oversample() self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager, self.dataset_json) self.network = self.build_network_architecture( self.configuration_manager.network_arch_class_name, self.configuration_manager.network_arch_init_kwargs, self.configuration_manager.network_arch_init_kwargs_req_import, self.num_input_channels, self.label_manager.num_segmentation_heads, self.enable_deep_supervision ).to(self.device) # compile network for free speedup if self._do_i_compile(): self.print_to_log_file('Using torch.compile...') self.network = torch.compile(self.network) self.optimizer, self.lr_scheduler = self.configure_optimizers() # if ddp, wrap in DDP wrapper if self.is_ddp: self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network) self.network = DDP(self.network, device_ids=[self.local_rank]) self.loss = self._build_loss() self.dataset_class = infer_dataset_class(self.preprocessed_dataset_folder) # torch 2.2.2 crashes upon compiling CE loss # if self._do_i_compile(): # self.loss = torch.compile(self.loss) self.was_initialized = True logger_config_hparas = { "initial_lr": self.initial_lr, "weight_decay": self.weight_decay, "oversample_foreground_percent": self.oversample_foreground_percent, "probabilistic_oversampling": self.probabilistic_oversampling, "num_iterations_per_epoch": self.num_iterations_per_epoch, "num_val_iterations_per_epoch": self.num_val_iterations_per_epoch, "num_epochs": self.num_epochs, "enable_deep_supervision": self.enable_deep_supervision, "batch_size": 
self.configuration_manager.batch_size } self.logger.update_config({"hparas": logger_config_hparas}) else: raise RuntimeError("You have called self.initialize even though the trainer was already initialized. " "That should not happen.") def _do_i_compile(self): # new default: compile is enabled! # compile does not work on mps if self.device == torch.device('mps'): if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'): self.print_to_log_file("INFO: torch.compile disabled because of unsupported mps device") return False # CPU compile crashes for 2D models. Not sure if we even want to support CPU compile!? Better disable if self.device == torch.device('cpu'): if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'): self.print_to_log_file("INFO: torch.compile disabled because device is CPU") return False # default torch.compile doesn't work on windows because there are apparently no triton wheels for it # https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2 if os.name == 'nt': if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'): self.print_to_log_file("INFO: torch.compile disabled because Windows is not natively supported. 
If " "you know what you are doing, check https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2") return False if 'nnUNet_compile' not in os.environ.keys(): return True else: return os.environ['nnUNet_compile'].lower() in ('true', '1', 't') def _save_debug_information(self): # saving some debug information if self.local_rank == 0: dct = {} for k in self.__dir__(): if not k.startswith("__"): if not callable(getattr(self, k)) or k in ['loss', ]: dct[k] = str(getattr(self, k)) elif k in ['network', ]: dct[k] = str(getattr(self, k).__class__.__name__) else: # print(k) pass if k in ['dataloader_train', 'dataloader_val']: dl = getattr(self, k) if hasattr(dl, 'generator'): dct[k + '.generator'] = str(dl.generator) if hasattr(dl.generator, 'transforms'): try: dct[k + '.generator.transforms'] = str(dl.generator.transforms) except Exception as e: dct[k + '.generator.transforms'] = f"Could not stringify generator.transforms: {type(e).__name__}: {e}" if hasattr(dl, 'num_processes'): dct[k + '.num_processes'] = str(dl.num_processes) if hasattr(dl, 'transform'): dct[k + '.transform'] = str(dl.transform) import subprocess hostname = subprocess.getoutput(['hostname']) dct['hostname'] = hostname torch_version = torch.__version__ if self.device.type == 'cuda': gpu_name = torch.cuda.get_device_name() dct['gpu_name'] = gpu_name cudnn_version = torch.backends.cudnn.version() else: cudnn_version = 'None' dct['device'] = str(self.device) dct['torch_version'] = torch_version dct['cudnn_version'] = cudnn_version save_json(dct, join(self.output_folder, "debug.json")) @staticmethod def build_network_architecture(architecture_class_name: str, arch_init_kwargs: dict, arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]], num_input_channels: int, num_output_channels: int, enable_deep_supervision: bool = True) -> nn.Module: """ This is where you build the architecture according to the plans. 
There is no obligation to use get_network_from_plans, this is just a utility we use for the nnU-Net default architectures. You can do what you want. Even ignore the plans and just return something static (as long as it can process the requested patch size) but don't bug us with your bugs arising from fiddling with this :-P This is the function that is called in inference as well! This is needed so that all network architecture variants can be loaded at inference time (inference will use the same nnUNetTrainer that was used for training, so if you change the network architecture during training by deriving a new trainer class then inference will know about it). If you need to know how many segmentation outputs your custom architecture needs to have, use the following snippet: > label_manager = plans_manager.get_label_manager(dataset_json) > label_manager.num_segmentation_heads (why so complicated? -> We can have either classical training (classes) or regions. If we have regions, the number of outputs is != the number of classes. Also there is the ignore label for which no output should be generated. label_manager takes care of all that for you.) 
""" return get_network_from_plans( architecture_class_name, arch_init_kwargs, arch_init_kwargs_req_import, num_input_channels, num_output_channels, allow_init=True, deep_supervision=enable_deep_supervision) def _get_deep_supervision_scales(self): if self.enable_deep_supervision: deep_supervision_scales = list(list(i) for i in 1 / np.cumprod(np.vstack( self.configuration_manager.pool_op_kernel_sizes), axis=0))[:-1] else: deep_supervision_scales = None # for train and val_transforms return deep_supervision_scales def _set_batch_size_and_oversample(self): if not self.is_ddp: # set batch size to what the plan says, leave oversample untouched self.batch_size = self.configuration_manager.batch_size else: # batch size is distributed over DDP workers and we need to change oversample_percent for each worker world_size = dist.get_world_size() my_rank = dist.get_rank() global_batch_size = self.configuration_manager.batch_size assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \ 'GPUs... Duh.' batch_size_per_GPU = [global_batch_size // world_size] * world_size batch_size_per_GPU = [batch_size_per_GPU[i] + 1 if (batch_size_per_GPU[i] * world_size + i) < global_batch_size else batch_size_per_GPU[i] for i in range(len(batch_size_per_GPU))] assert sum(batch_size_per_GPU) == global_batch_size sample_id_low = 0 if my_rank == 0 else np.sum(batch_size_per_GPU[:my_rank]) sample_id_high = np.sum(batch_size_per_GPU[:my_rank + 1]) # This is how oversampling is determined in DataLoader # round(self.batch_size * (1 - self.oversample_foreground_percent)) # We need to use the same scheme here because an oversample of 0.33 with a batch size of 2 will be rounded # to an oversample of 0.5 (1 sample random, one oversampled). 
This may get lost if we just numerically # compute oversample oversample = [True if not i < round(global_batch_size * (1 - self.oversample_foreground_percent)) else False for i in range(global_batch_size)] if sample_id_high / global_batch_size < (1 - self.oversample_foreground_percent): oversample_percent = 0.0 elif sample_id_low / global_batch_size > (1 - self.oversample_foreground_percent): oversample_percent = 1.0 else: oversample_percent = sum(oversample[sample_id_low:sample_id_high]) / batch_size_per_GPU[my_rank] print("worker", my_rank, "oversample", oversample_percent) print("worker", my_rank, "batch_size", batch_size_per_GPU[my_rank]) self.batch_size = batch_size_per_GPU[my_rank] self.oversample_foreground_percent = oversample_percent def _build_loss(self): if self.label_manager.has_regions: loss = DC_and_BCE_loss({}, {'batch_dice': self.configuration_manager.batch_dice, 'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp}, use_ignore_label=self.label_manager.ignore_label is not None, dice_class=MemoryEfficientSoftDiceLoss) else: loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice, 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1, ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss) if self._do_i_compile(): loss.dc = torch.compile(loss.dc) # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases # this gives higher resolution outputs more weight in the loss if self.enable_deep_supervision: deep_supervision_scales = self._get_deep_supervision_scales() weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) if self.is_ddp and not self._do_i_compile(): # very strange and stupid interaction. DDP crashes and complains about unused parameters due to # weights[-1] = 0. Interestingly this crash doesn't happen with torch.compile enabled. Strange stuff. 
# Anywho, the simple fix is to set a very low weight to this. weights[-1] = 1e-6 else: weights[-1] = 0 # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 weights = weights / weights.sum() # now wrap the loss loss = DeepSupervisionWrapper(loss, weights) return loss def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): """ This function is stupid and certainly one of the weakest spots of this implementation. Not entirely sure how we can fix it. """ patch_size = self.configuration_manager.patch_size dim = len(patch_size) # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation) if dim == 2: do_dummy_2d_data_aug = False # todo revisit this parametrization if max(patch_size) / min(patch_size) > 1.5: rotation_for_DA = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi) else: rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi) mirror_axes = (0, 1) elif dim == 3: # todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad # order of the axes is determined by spacing, not image size do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD if do_dummy_2d_data_aug: # why do we rotate 180 deg here all the time? We should also restrict it rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi) else: rotation_for_DA = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) mirror_axes = (0, 1, 2) else: raise RuntimeError() # todo this function is stupid. 
It doesn't even use the correct scale range (we keep things as they were in the # old nnunet for now) initial_patch_size = get_patch_size(patch_size[-dim:], rotation_for_DA, rotation_for_DA, rotation_for_DA, (0.85, 1.25)) if do_dummy_2d_data_aug: initial_patch_size[0] = patch_size[0] self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}') self.inference_allowed_mirroring_axes = mirror_axes return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True): if self.local_rank == 0: timestamp = time() dt_object = datetime.fromtimestamp(timestamp) if add_timestamp: args = (f"{dt_object}:", *args) successful = False max_attempts = 5 ctr = 0 while not successful and ctr < max_attempts: try: with open(self.log_file, 'a+') as f: for a in args: f.write(str(a)) f.write(" ") f.write("\n") successful = True except IOError: print(f"{datetime.fromtimestamp(timestamp)}: failed to log: ", sys.exc_info()) sleep(0.5) ctr += 1 if also_print_to_console: print(*args) elif also_print_to_console: print(*args) def print_plans(self): if self.local_rank == 0: dct = deepcopy(self.plans_manager.plans) del dct['configurations'] self.print_to_log_file(f"\nThis is the configuration used by this " f"training:\nConfiguration name: {self.configuration_name}\n", self.configuration_manager, '\n', add_timestamp=False) self.print_to_log_file('These are the global plan.json settings:\n', dct, '\n', add_timestamp=False) def configure_optimizers(self): optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, momentum=0.99, nesterov=True) lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) return optimizer, lr_scheduler def plot_network_architecture(self): if self._do_i_compile(): self.print_to_log_file("Unable to plot network architecture: nnUNet_compile is enabled!") return if self.local_rank == 0: try: # raise 
NotImplementedError('hiddenlayer no longer works and we do not have a viable alternative :-(') # pip install git+https://github.com/saugatkandel/hiddenlayer.git # from torchviz import make_dot # # not viable. # make_dot(tuple(self.network(torch.rand((1, self.num_input_channels, # *self.configuration_manager.patch_size), # device=self.device)))).render( # join(self.output_folder, "network_architecture.pdf"), format='pdf') # self.optimizer.zero_grad() # broken. import hiddenlayer as hl g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.configuration_manager.patch_size), device=self.device), transforms=None) g.save(join(self.output_folder, "network_architecture.pdf")) del g except Exception as e: self.print_to_log_file("Unable to plot network architecture:") self.print_to_log_file(e) # self.print_to_log_file("\nprinting the network instead:\n") # self.print_to_log_file(self.network) # self.print_to_log_file("\n") finally: empty_cache(self.device) def do_split(self): """ The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded, so always the same) and save it as splits_final.json file in the preprocessed data directory. Sometimes you may want to create your own split for various reasons. For this you will need to create your own splits_final.json file. If this file is present, nnU-Net is going to use it and whatever splits are defined in it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3) and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to use a random 80:20 data split. 
:return: """ if self.dataset_class is None: self.dataset_class = infer_dataset_class(self.preprocessed_dataset_folder) if self.fold == "all": # if fold==all then we use all images for training and validation case_identifiers = self.dataset_class.get_identifiers(self.preprocessed_dataset_folder) tr_keys = case_identifiers val_keys = tr_keys else: splits_file = join(self.preprocessed_dataset_folder_base, "splits_final.json") dataset = self.dataset_class(self.preprocessed_dataset_folder, identifiers=None, folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage) # if the split file does not exist we need to create it if not isfile(splits_file): self.print_to_log_file("Creating new 5-fold cross-validation split...") all_keys_sorted = list(np.sort(list(dataset.identifiers))) splits = generate_crossval_split(all_keys_sorted, seed=12345, n_splits=5) save_json(splits, splits_file) else: self.print_to_log_file("Using splits from existing split file:", splits_file) splits = load_json(splits_file) self.print_to_log_file(f"The split file contains {len(splits)} splits.") self.print_to_log_file("Desired fold for training: %d" % self.fold) if self.fold < len(splits): tr_keys = splits[self.fold]['train'] val_keys = splits[self.fold]['val'] self.print_to_log_file("This split has %d training and %d validation cases." % (len(tr_keys), len(val_keys))) else: self.print_to_log_file("INFO: You requested fold %d for training but splits " "contain only %d folds. I am now creating a " "random (but seeded) 80:20 split!" 
% (self.fold, len(splits))) # if we request a fold that is not in the split file, create a random 80:20 split rnd = np.random.RandomState(seed=12345 + self.fold) keys = np.sort(list(dataset.identifiers)) idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False) idx_val = [i for i in range(len(keys)) if i not in idx_tr] tr_keys = [keys[i] for i in idx_tr] val_keys = [keys[i] for i in idx_val] self.print_to_log_file("This random 80:20 split has %d training and %d validation cases." % (len(tr_keys), len(val_keys))) if any([i in val_keys for i in tr_keys]): self.print_to_log_file('WARNING: Some validation cases are also in the training set. Please check the ' 'splits_final.json or ignore if this is intentional.') return tr_keys, val_keys def get_tr_and_val_datasets(self): # create dataset split tr_keys, val_keys = self.do_split() # load the datasets for training and validation. Note that we always draw random samples so we really don't # care about distributing training cases across GPUs. dataset_tr = self.dataset_class(self.preprocessed_dataset_folder, tr_keys, folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage) dataset_val = self.dataset_class(self.preprocessed_dataset_folder, val_keys, folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage) return dataset_tr, dataset_val def get_dataloaders(self): if self.dataset_class is None: self.dataset_class = infer_dataset_class(self.preprocessed_dataset_folder) # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be patch_size = self.configuration_manager.patch_size # needed for deep supervision: how much do we need to downscale the segmentation targets for the different # outputs? 
deep_supervision_scales = self._get_deep_supervision_scales() ( rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes, ) = self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() # training pipeline tr_transforms = self.get_training_transforms( patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, use_mask_for_norm=self.configuration_manager.use_mask_for_norm, is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels, regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, ignore_label=self.label_manager.ignore_label) # validation pipeline val_transforms = self.get_validation_transforms(deep_supervision_scales, is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels, regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, ignore_label=self.label_manager.ignore_label) dataset_tr, dataset_val = self.get_tr_and_val_datasets() dl_tr = nnUNetDataLoader(dataset_tr, self.batch_size, initial_patch_size, self.configuration_manager.patch_size, self.label_manager, oversample_foreground_percent=self.oversample_foreground_percent, sampling_probabilities=None, pad_sides=None, transforms=tr_transforms, probabilistic_oversampling=self.probabilistic_oversampling) dl_val = nnUNetDataLoader(dataset_val, self.batch_size, self.configuration_manager.patch_size, self.configuration_manager.patch_size, self.label_manager, oversample_foreground_percent=self.oversample_foreground_percent, sampling_probabilities=None, pad_sides=None, transforms=val_transforms, probabilistic_oversampling=self.probabilistic_oversampling) allowed_num_processes = get_allowed_n_proc_DA() if allowed_num_processes == 0: mt_gen_train = SingleThreadedAugmenter(dl_tr, None) mt_gen_val = SingleThreadedAugmenter(dl_val, None) else: mt_gen_train = NonDetMultiThreadedAugmenter(data_loader=dl_tr, transform=None, 
num_processes=allowed_num_processes, num_cached=max(6, allowed_num_processes // 2), seeds=None, pin_memory=self.device.type == 'cuda', wait_time=0.002) mt_gen_val = NonDetMultiThreadedAugmenter(data_loader=dl_val, transform=None, num_processes=max(1, allowed_num_processes // 2), num_cached=max(3, allowed_num_processes // 4), seeds=None, pin_memory=self.device.type == 'cuda', wait_time=0.002) # # let's get this party started _ = next(mt_gen_train) _ = next(mt_gen_val) return mt_gen_train, mt_gen_val @staticmethod def get_training_transforms( patch_size: Union[np.ndarray, Tuple[int]], rotation_for_DA: RandomScalar, deep_supervision_scales: Union[List, Tuple, None], mirror_axes: Tuple[int, ...], do_dummy_2d_data_aug: bool, use_mask_for_norm: List[bool] = None, is_cascaded: bool = False, foreground_labels: Union[Tuple[int, ...], List[int]] = None, regions: List[Union[List[int], Tuple[int, ...], int]] = None, ignore_label: int = None, ) -> BasicTransform: transforms = [] if do_dummy_2d_data_aug: ignore_axes = (0,) transforms.append(Convert3DTo2DTransform()) patch_size_spatial = patch_size[1:] else: patch_size_spatial = patch_size ignore_axes = None transforms.append( SpatialTransform( patch_size_spatial, patch_center_dist_from_border=0, random_crop=False, p_elastic_deform=0, p_rotation=0.2, rotation=rotation_for_DA, p_scaling=0.2, scaling=(0.7, 1.4), p_synchronize_scaling_across_axes=1, bg_style_seg_sampling=False # , mode_seg='nearest' ) ) if do_dummy_2d_data_aug: transforms.append(Convert2DTo3DTransform()) transforms.append(RandomTransform( GaussianNoiseTransform( noise_variance=(0, 0.1), p_per_channel=1, synchronize_channels=True ), apply_probability=0.1 )) transforms.append(RandomTransform( GaussianBlurTransform( blur_sigma=(0.5, 1.), synchronize_channels=False, synchronize_axes=False, p_per_channel=0.5, benchmark=True ), apply_probability=0.2 )) transforms.append(RandomTransform( MultiplicativeBrightnessTransform( multiplier_range=BGContrast((0.75, 1.25)), 
synchronize_channels=False, p_per_channel=1 ), apply_probability=0.15 )) transforms.append(RandomTransform( ContrastTransform( contrast_range=BGContrast((0.75, 1.25)), preserve_range=True, synchronize_channels=False, p_per_channel=1 ), apply_probability=0.15 )) transforms.append(RandomTransform( SimulateLowResolutionTransform( scale=(0.5, 1), synchronize_channels=False, synchronize_axes=True, ignore_axes=ignore_axes, allowed_channels=None, p_per_channel=0.5 ), apply_probability=0.25 )) transforms.append(RandomTransform( GammaTransform( gamma=BGContrast((0.7, 1.5)), p_invert_image=1, synchronize_channels=False, p_per_channel=1, p_retain_stats=1 ), apply_probability=0.1 )) transforms.append(RandomTransform( GammaTransform( gamma=BGContrast((0.7, 1.5)), p_invert_image=0, synchronize_channels=False, p_per_channel=1, p_retain_stats=1 ), apply_probability=0.3 )) if mirror_axes is not None and len(mirror_axes) > 0: transforms.append( MirrorTransform( allowed_axes=mirror_axes ) ) if use_mask_for_norm is not None and any(use_mask_for_norm): transforms.append(MaskImageTransform( apply_to_channels=[i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]], channel_idx_in_seg=0, set_outside_to=0, )) transforms.append( RemoveLabelTansform(-1, 0) ) if is_cascaded: assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations' transforms.append( MoveSegAsOneHotToDataTransform( source_channel_idx=1, all_labels=foreground_labels, remove_channel_from_source=True ) ) transforms.append( RandomTransform( ApplyRandomBinaryOperatorTransform( channel_idx=list(range(-len(foreground_labels), 0)), strel_size=(1, 8), p_per_label=1 ), apply_probability=0.4 ) ) transforms.append( RandomTransform( RemoveRandomConnectedComponentFromOneHotEncodingTransform( channel_idx=list(range(-len(foreground_labels), 0)), fill_with_other_class_p=0, dont_do_if_covers_more_than_x_percent=0.15, p_per_label=1 ), apply_probability=0.2 ) ) if regions is not None: # the ignore 
label must also be converted transforms.append( ConvertSegmentationToRegionsTransform( regions=list(regions) + [ignore_label] if ignore_label is not None else regions, channel_in_seg=0 ) ) if deep_supervision_scales is not None: transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales)) return ComposeTransforms(transforms) @staticmethod def get_validation_transforms( deep_supervision_scales: Union[List, Tuple, None], is_cascaded: bool = False, foreground_labels: Union[Tuple[int, ...], List[int]] = None, regions: List[Union[List[int], Tuple[int, ...], int]] = None, ignore_label: int = None, ) -> BasicTransform: transforms = [] transforms.append( RemoveLabelTansform(-1, 0) ) if is_cascaded: transforms.append( MoveSegAsOneHotToDataTransform( source_channel_idx=1, all_labels=foreground_labels, remove_channel_from_source=True ) ) if regions is not None: # the ignore label must also be converted transforms.append( ConvertSegmentationToRegionsTransform( regions=list(regions) + [ignore_label] if ignore_label is not None else regions, channel_in_seg=0 ) ) if deep_supervision_scales is not None: transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales)) return ComposeTransforms(transforms) def set_deep_supervision_enabled(self, enabled: bool): """ This function is specific for the default architecture in nnU-Net. If you change the architecture, there are chances you need to change this as well! 
""" if self.is_ddp: mod = self.network.module else: mod = self.network if isinstance(mod, OptimizedModule): mod = mod._orig_mod mod.decoder.deep_supervision = enabled def on_train_start(self): if not self.was_initialized: self.initialize() # dataloaders must be instantiated here (instead of __init__) because they need access to the training data # which may not be present when doing inference self.dataloader_train, self.dataloader_val = self.get_dataloaders() maybe_mkdir_p(self.output_folder) # make sure deep supervision is on in the network self.set_deep_supervision_enabled(self.enable_deep_supervision) self.print_plans() empty_cache(self.device) # maybe unpack if self.local_rank == 0: self.dataset_class.unpack_dataset( self.preprocessed_dataset_folder, overwrite_existing=False, num_processes=max(1, round(get_allowed_n_proc_DA() // 2)), verify=True) if self.is_ddp: dist.barrier() # copy plans and dataset.json so that they can be used for restoring everything we need for inference save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False) save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False) # we don't really need the fingerprint but its still handy to have it with the others shutil.copy(join(self.preprocessed_dataset_folder_base, 'dataset_fingerprint.json'), join(self.output_folder_base, 'dataset_fingerprint.json')) # produces a pdf in output folder self.plot_network_architecture() self._save_debug_information() # print(f"batch size: {self.batch_size}") # print(f"oversample: {self.oversample_foreground_percent}") def on_train_end(self): # dirty hack because on_epoch_end increments the epoch counter and this is executed afterwards. 
# This will lead to the wrong current epoch to be stored self.current_epoch -= 1 self.save_checkpoint(join(self.output_folder, "checkpoint_final.pth")) self.current_epoch += 1 # now we can delete latest if self.local_rank == 0 and isfile(join(self.output_folder, "checkpoint_latest.pth")): os.remove(join(self.output_folder, "checkpoint_latest.pth")) # shut down dataloaders old_stdout = sys.stdout with open(os.devnull, 'w') as f: sys.stdout = f if self.dataloader_train is not None and \ isinstance(self.dataloader_train, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)): self.dataloader_train._finish() if self.dataloader_val is not None and \ isinstance(self.dataloader_val, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)): self.dataloader_val._finish() sys.stdout = old_stdout empty_cache(self.device) self.print_to_log_file("Training done.") def on_train_epoch_start(self): self.network.train() self.lr_scheduler.step(self.current_epoch) self.print_to_log_file('') self.print_to_log_file(f'Epoch {self.current_epoch}') self.print_to_log_file( f"Current learning rate: {np.round(self.optimizer.param_groups[0]['lr'], decimals=5)}") # lrs are the same for all workers so we don't need to gather them in case of DDP training self.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch) def train_step(self, batch: dict) -> dict: data = batch['data'] target = batch['target'] data = data.to(self.device, non_blocking=True) if isinstance(target, list): target = [i.to(self.device, non_blocking=True) for i in target] else: target = target.to(self.device, non_blocking=True) self.optimizer.zero_grad(set_to_none=True) # Autocast can be annoying # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) # So autocast will only be active if we have a cuda device. 
with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): output = self.network(data) # del data l = self.loss(output, target) if self.grad_scaler is not None: self.grad_scaler.scale(l).backward() self.grad_scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) self.grad_scaler.step(self.optimizer) self.grad_scaler.update() else: l.backward() torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) self.optimizer.step() return {'loss': l.detach().cpu().numpy()} def on_train_epoch_end(self, train_outputs: List[dict]): outputs = collate_outputs(train_outputs) if self.is_ddp: losses_tr = [None for _ in range(dist.get_world_size())] dist.all_gather_object(losses_tr, outputs['loss']) loss_here = np.vstack(losses_tr).mean() else: loss_here = np.mean(outputs['loss']) self.logger.log('train_losses', loss_here, self.current_epoch) def on_validation_epoch_start(self): self.network.eval() def validation_step(self, batch: dict) -> dict: data = batch['data'] target = batch['target'] data = data.to(self.device, non_blocking=True) if isinstance(target, list): target = [i.to(self.device, non_blocking=True) for i in target] else: target = target.to(self.device, non_blocking=True) # Autocast can be annoying # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) # So autocast will only be active if we have a cuda device. with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): output = self.network(data) del data l = self.loss(output, target) # we only need the output with the highest output resolution (if DS enabled) if self.enable_deep_supervision: output = output[0] target = target[0] # the following is needed for online evaluation. 
Fake dice (green line) axes = [0] + list(range(2, output.ndim)) if self.label_manager.has_regions: predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long() else: # no need for softmax output_seg = output.argmax(1)[:, None] predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float16) predicted_segmentation_onehot.scatter_(1, output_seg, 1) del output_seg if self.label_manager.has_ignore_label: if not self.label_manager.has_regions: mask = (target != self.label_manager.ignore_label).float() # CAREFUL that you don't rely on target after this line! target[target == self.label_manager.ignore_label] = 0 else: if target.dtype == torch.bool: mask = ~target[:, -1:] else: mask = 1 - target[:, -1:] # CAREFUL that you don't rely on target after this line! target = target[:, :-1] else: mask = None tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask) tp_hard = tp.detach().cpu().numpy() fp_hard = fp.detach().cpu().numpy() fn_hard = fn.detach().cpu().numpy() if not self.label_manager.has_regions: # if we train with regions all segmentation heads predict some kind of foreground. In conventional # (softmax training) there needs to be one output for the background. 
We are not interested in the # background Dice # [1:] in order to remove background tp_hard = tp_hard[1:] fp_hard = fp_hard[1:] fn_hard = fn_hard[1:] return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} def on_validation_epoch_end(self, val_outputs: List[dict]): outputs_collated = collate_outputs(val_outputs) tp = np.sum(outputs_collated['tp_hard'], 0) fp = np.sum(outputs_collated['fp_hard'], 0) fn = np.sum(outputs_collated['fn_hard'], 0) if self.is_ddp: world_size = dist.get_world_size() tps = [None for _ in range(world_size)] dist.all_gather_object(tps, tp) tp = np.vstack([i[None] for i in tps]).sum(0) fps = [None for _ in range(world_size)] dist.all_gather_object(fps, fp) fp = np.vstack([i[None] for i in fps]).sum(0) fns = [None for _ in range(world_size)] dist.all_gather_object(fns, fn) fn = np.vstack([i[None] for i in fns]).sum(0) losses_val = [None for _ in range(world_size)] dist.all_gather_object(losses_val, outputs_collated['loss']) loss_here = np.vstack(losses_val).mean() else: loss_here = np.mean(outputs_collated['loss']) global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in zip(tp, fp, fn)]] mean_fg_dice = np.nanmean(global_dc_per_class) self.logger.log('mean_fg_dice', mean_fg_dice, self.current_epoch) self.logger.log('dice_per_class_or_region', global_dc_per_class, self.current_epoch) self.logger.log('val_losses', loss_here, self.current_epoch) def on_epoch_start(self): self.logger.log('epoch_start_timestamps', time(), self.current_epoch) def on_epoch_end(self): self.logger.log('epoch_end_timestamps', time(), self.current_epoch) self.print_to_log_file('train_loss', np.round(self.logger.get_value('train_losses', step=-1), decimals=4)) self.print_to_log_file('val_loss', np.round(self.logger.get_value('val_losses', step=-1), decimals=4)) self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in self.logger.get_value('dice_per_class_or_region', step=-1)]) 
self.print_to_log_file( f"Epoch time: {np.round(self.logger.get_value('epoch_end_timestamps', step=-1) - self.logger.get_value('epoch_start_timestamps', step=-1), decimals=2)} s") # handling periodic checkpointing current_epoch = self.current_epoch if (current_epoch + 1) % self.save_every == 0 and current_epoch != (self.num_epochs - 1): self.save_checkpoint(join(self.output_folder, 'checkpoint_latest.pth')) # handle 'best' checkpointing. ema_fg_dice is computed by the logger and can be accessed like this if self._best_ema is None or self.logger.get_value('ema_fg_dice', step=-1) > self._best_ema: self._best_ema = self.logger.get_value('ema_fg_dice', step=-1) self.print_to_log_file(f"Yayy! New best EMA pseudo Dice: {np.round(self._best_ema, decimals=4)}") self.save_checkpoint(join(self.output_folder, 'checkpoint_best.pth')) if self.local_rank == 0: self.logger.plot_progress_png(self.output_folder) self.current_epoch += 1 def save_checkpoint(self, filename: str) -> None: if self.local_rank == 0: if not self.disable_checkpointing: if self.is_ddp: mod = self.network.module else: mod = self.network if isinstance(mod, OptimizedModule): mod = mod._orig_mod checkpoint = { 'network_weights': mod.state_dict(), 'optimizer_state': self.optimizer.state_dict(), 'grad_scaler_state': self.grad_scaler.state_dict() if self.grad_scaler is not None else None, 'logging': self.logger.get_checkpoint(), '_best_ema': self._best_ema, 'current_epoch': self.current_epoch + 1, 'init_args': self.my_init_kwargs, 'trainer_name': self.__class__.__name__, 'inference_allowed_mirroring_axes': self.inference_allowed_mirroring_axes, } torch.save(checkpoint, filename) else: self.print_to_log_file('No checkpoint written, checkpointing is disabled') def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None: if not self.was_initialized: self.initialize() if isinstance(filename_or_checkpoint, str): checkpoint = torch.load(filename_or_checkpoint, map_location=self.device, weights_only=False) 
# if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not # match. Use heuristic to make it match new_state_dict = {} for k, value in checkpoint['network_weights'].items(): key = k if key not in self.network.state_dict().keys() and key.startswith('module.'): key = key[7:] new_state_dict[key] = value self.my_init_kwargs = checkpoint['init_args'] self.current_epoch = checkpoint['current_epoch'] self.logger.load_checkpoint(checkpoint['logging']) self._best_ema = checkpoint['_best_ema'] self.inference_allowed_mirroring_axes = checkpoint[ 'inference_allowed_mirroring_axes'] if 'inference_allowed_mirroring_axes' in checkpoint.keys() else self.inference_allowed_mirroring_axes # messing with state dict naming schemes. Facepalm. if self.is_ddp: if isinstance(self.network.module, OptimizedModule): self.network.module._orig_mod.load_state_dict(new_state_dict) else: self.network.module.load_state_dict(new_state_dict) else: if isinstance(self.network, OptimizedModule): self.network._orig_mod.load_state_dict(new_state_dict) else: self.network.load_state_dict(new_state_dict) self.optimizer.load_state_dict(checkpoint['optimizer_state']) if self.grad_scaler is not None: if checkpoint['grad_scaler_state'] is not None: self.grad_scaler.load_state_dict(checkpoint['grad_scaler_state']) def perform_actual_validation(self, save_probabilities: bool = False): self.set_deep_supervision_enabled(False) self.network.eval() if self.is_ddp and self.batch_size == 1 and self.enable_deep_supervision and self._do_i_compile(): self.print_to_log_file("WARNING! batch size is 1 during training and torch.compile is enabled. If you " "encounter crashes in validation then this is because torch.compile forgets " "to trigger a recompilation of the model with deep supervision disabled. " "This causes torch.flip to complain about getting a tuple as input. Just rerun the " "validation with --val (exactly the same as before) and then it will work. " "Why? 
Because --val triggers nnU-Net to ONLY run validation meaning that the first " "forward pass (where compile is triggered) already has deep supervision disabled. " "This is exactly what we need in perform_actual_validation") predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True, perform_everything_on_device=True, device=self.device, verbose=False, verbose_preprocessing=False, allow_tqdm=False) predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None, self.dataset_json, self.__class__.__name__, self.inference_allowed_mirroring_axes) with multiprocessing.get_context("spawn").Pool(default_num_processes) as segmentation_export_pool: worker_list = [i for i in segmentation_export_pool._pool] validation_output_folder = join(self.output_folder, 'validation') maybe_mkdir_p(validation_output_folder) # we cannot use self.get_tr_and_val_datasets() here because we might be DDP and then we have to distribute # the validation keys across the workers. 
_, val_keys = self.do_split() if self.is_ddp: last_barrier_at_idx = len(val_keys) // dist.get_world_size() - 1 val_keys = val_keys[self.local_rank:: dist.get_world_size()] # we cannot just have barriers all over the place because the number of keys each GPU receives can be # different dataset_val = self.dataset_class(self.preprocessed_dataset_folder, val_keys, folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage) next_stages = self.configuration_manager.next_stage_names if next_stages is not None: _ = [maybe_mkdir_p(join(self.output_folder_base, 'predicted_next_stage', n)) for n in next_stages] results = [] for i, k in enumerate(dataset_val.identifiers): proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, allowed_num_queued=2) while not proceed: sleep(0.1) proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, allowed_num_queued=2) self.print_to_log_file(f"predicting {k}") data, _, seg_prev, properties = dataset_val.load_case(k) # we do [:] to convert blosc2 to numpy data = data[:] if self.is_cascaded: seg_prev = seg_prev[:] data = np.vstack((data, convert_labelmap_to_one_hot(seg_prev, self.label_manager.foreground_labels, output_dtype=data.dtype))) with warnings.catch_warnings(): # ignore 'The given NumPy array is not writable' warning warnings.simplefilter("ignore") data = torch.from_numpy(data) self.print_to_log_file(f'{k}, shape {data.shape}, rank {self.local_rank}') output_filename_truncated = join(validation_output_folder, k) prediction = predictor.predict_sliding_window_return_logits(data) prediction = prediction.cpu() # this needs to go into background processes results.append( segmentation_export_pool.starmap_async( export_prediction_from_logits, ( (prediction, properties, self.configuration_manager, self.plans_manager, self.dataset_json, output_filename_truncated, save_probabilities), ) ) ) # for debug purposes # export_prediction_from_logits( # 
prediction, properties, self.configuration_manager, self.plans_manager, # self.dataset_json, output_filename_truncated, save_probabilities # ) # if needed, export the softmax prediction for the next stage if next_stages is not None: for n in next_stages: next_stage_config_manager = self.plans_manager.get_configuration(n) expected_preprocessed_folder = join(nnUNet_preprocessed, self.plans_manager.dataset_name, next_stage_config_manager.data_identifier) # next stage may have a different dataset class, do not use self.dataset_class dataset_class = infer_dataset_class(expected_preprocessed_folder) try: # we do this so that we can use load_case and do not have to hard code how loading training cases is implemented tmp = dataset_class(expected_preprocessed_folder, [k]) d, _, _, _ = tmp.load_case(k) except FileNotFoundError: self.print_to_log_file( f"Predicting next stage {n} failed for case {k} because the preprocessed file is missing! " f"Run the preprocessing for this configuration first!") continue target_shape = d.shape[1:] output_folder = join(self.output_folder_base, 'predicted_next_stage', n) output_file_truncated = join(output_folder, k) # resample_and_save(prediction, target_shape, output_file_truncated, self.plans_manager, # self.configuration_manager, # properties, # self.dataset_json, # default_num_processes, # dataset_class) results.append(segmentation_export_pool.starmap_async( resample_and_save, ( (prediction, target_shape, output_file_truncated, self.plans_manager, self.configuration_manager, properties, self.dataset_json, default_num_processes, dataset_class), ) )) # if we don't barrier from time to time we will get nccl timeouts for large datasets. Yuck. 
if self.is_ddp and i < last_barrier_at_idx and (i + 1) % 20 == 0: dist.barrier() _ = [r.get() for r in results] if self.is_ddp: dist.barrier() if self.local_rank == 0: metrics = compute_metrics_on_folder(join(self.preprocessed_dataset_folder_base, 'gt_segmentations'), validation_output_folder, join(validation_output_folder, 'summary.json'), self.plans_manager.image_reader_writer_class(), self.dataset_json["file_ending"], self.label_manager.foreground_regions if self.label_manager.has_regions else self.label_manager.foreground_labels, self.label_manager.ignore_label, chill=True, num_processes=default_num_processes * dist.get_world_size() if self.is_ddp else default_num_processes) for label in metrics["mean"]: self.logger.log_summary(f"final_val/class_{label}_dice", metrics["mean"][label]["Dice"]) self.logger.log_summary("final_val/foreground_dice", metrics['foreground_mean']["Dice"]) self.print_to_log_file("Validation complete", also_print_to_console=True) self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]), also_print_to_console=True) self.set_deep_supervision_enabled(True) compute_gaussian.cache_clear() def run_training(self): self.on_train_start() for epoch in range(self.current_epoch, self.num_epochs): self.on_epoch_start() self.on_train_epoch_start() train_outputs = [] for batch_id in range(self.num_iterations_per_epoch): train_outputs.append(self.train_step(next(self.dataloader_train))) self.on_train_epoch_end(train_outputs) with torch.no_grad(): self.on_validation_epoch_start() val_outputs = [] for batch_id in range(self.num_val_iterations_per_epoch): val_outputs.append(self.validation_step(next(self.dataloader_val))) self.on_validation_epoch_end(val_outputs) self.on_epoch_end() self.on_train_end() ================================================ FILE: nnunetv2/training/nnUNetTrainer/primus/__init__.py ================================================ ================================================ FILE: 
nnunetv2/training/nnUNetTrainer/primus/primus_trainers.py
================================================
from abc import abstractmethod
from typing import List, Tuple, Union

import torch
from torch import nn, autocast
from dynamic_network_architectures.architectures.primus import Primus
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.training.nnUNetTrainer.variants.lr_schedule.nnUNetTrainer_warmup import nnUNetTrainer_warmup
from torch.nn.parallel import DistributedDataParallel as DDP

from nnunetv2.training.lr_scheduler.warmup import Lin_incr_LRScheduler, PolyLRScheduler_offset
from nnunetv2.utilities.helpers import empty_cache, dummy_context


######################################################
# See this paper for information on Primus!
# Wald*, T., Roy*, S., Isensee*, F., Ulrich, C., Ziegler, S., Trofimova, D., ... & Maier-Hein, K. (2025). Primus: Enforcing attention usage for 3d medical image segmentation. arXiv preprint arXiv:2503.01835.
# * equal contribution
######################################################


class AbstractPrimus(nnUNetTrainer_warmup):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, device)
        self.initial_lr = 3e-4
        self.weight_decay = 5e-2
        self.enable_deep_supervision = False

    @abstractmethod
    def build_network_architecture(
        self,
        architecture_class_name: str,
        arch_init_kwargs: dict,
        arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
        num_input_channels: int,
        num_output_channels: int,
        enable_deep_supervision: bool = True,
    ) -> nn.Module:
        raise NotImplementedError()

    def configure_optimizers(self, stage: str = "warmup_all"):
        assert stage in ["warmup_all", "train"]

        if self.training_stage == stage:
            return self.optimizer, self.lr_scheduler

        if isinstance(self.network, DDP):
            params = self.network.module.parameters()
        else:
            params = self.network.parameters()
        if stage == "warmup_all":
            self.print_to_log_file("train whole net, warmup")
            optimizer = torch.optim.AdamW(
                params, self.initial_lr, weight_decay=self.weight_decay, amsgrad=False, betas=(0.9, 0.98), fused=True
            )
            lr_scheduler = Lin_incr_LRScheduler(optimizer, self.initial_lr, self.warmup_duration_whole_net)
            self.print_to_log_file(f"Initialized warmup_all optimizer and lr_scheduler at epoch {self.current_epoch}")
        else:
            self.print_to_log_file("train whole net, default schedule")
            if self.training_stage == "warmup_all":
                # we can keep the existing optimizer and don't need to create a new one. This will allow us to keep
                # the accumulated momentum terms which already point in a useful direction
                optimizer = self.optimizer
            else:
                optimizer = torch.optim.AdamW(
                    params,
                    self.initial_lr,
                    weight_decay=self.weight_decay,
                    amsgrad=False,
                    betas=(0.9, 0.98),
                    fused=True,
                )
            lr_scheduler = PolyLRScheduler_offset(
                optimizer, self.initial_lr, self.num_epochs, self.warmup_duration_whole_net
            )
            self.print_to_log_file(f"Initialized train optimizer and lr_scheduler at epoch {self.current_epoch}")
        self.training_stage = stage
        empty_cache(self.device)
        return optimizer, lr_scheduler

    def train_step(self, batch: dict) -> dict:
        data = batch["data"]
        target = batch["target"]

        data = data.to(self.device, non_blocking=True)
        if isinstance(target, list):
            target = [i.to(self.device, non_blocking=True) for i in target]
        else:
            target = target.to(self.device, non_blocking=True)

        self.optimizer.zero_grad(set_to_none=True)
        # Autocast can be annoying
        # If the device_type is 'cpu' then it's slow as heck and needs to be disabled.
        # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
        # So autocast will only be active if we have a cuda device.
        with autocast(self.device.type, enabled=True) if self.device.type == "cuda" else dummy_context():
            output = self.network(data)
            # del data
            l = self.loss(output, target)

        if self.grad_scaler is not None:
            self.grad_scaler.scale(l).backward()
            self.grad_scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1)
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            l.backward()
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1)
            self.optimizer.step()
        return {"loss": l.detach().cpu().numpy()}

    def set_deep_supervision_enabled(self, enabled: bool):
        pass


class nnUNet_Primus_S_Trainer(AbstractPrimus):
    def build_network_architecture(
        self,
        architecture_class_name: str,
        arch_init_kwargs: dict,
        arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
        num_input_channels: int,
        num_output_channels: int,
        enable_deep_supervision: bool = True,
    ) -> nn.Module:
        # this architecture will crash if the patch size is not divisible by 8!
        model = Primus(
            num_input_channels,
            396,
            (8, 8, 8),
            num_output_channels,
            12,
            6,
            self.configuration_manager.patch_size,
            drop_path_rate=0.2,
            scale_attn_inner=True,
            init_values=0.1,
        )
        return model


class nnUNet_Primus_B_Trainer(AbstractPrimus):
    def build_network_architecture(
        self,
        architecture_class_name: str,
        arch_init_kwargs: dict,
        arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
        num_input_channels: int,
        num_output_channels: int,
        enable_deep_supervision: bool = True,
    ) -> nn.Module:
        # this architecture will crash if the patch size is not divisible by 8!
        model = Primus(
            num_input_channels,
            792,
            (8, 8, 8),
            num_output_channels,
            12,
            12,
            self.configuration_manager.patch_size,
            drop_path_rate=0.2,
            scale_attn_inner=True,
            init_values=0.1,
        )
        return model


class nnUNet_Primus_M_Trainer(AbstractPrimus):
    def build_network_architecture(
        self,
        architecture_class_name: str,
        arch_init_kwargs: dict,
        arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
        num_input_channels: int,
        num_output_channels: int,
        enable_deep_supervision: bool = True,
    ) -> nn.Module:
        # this architecture will crash if the patch size is not divisible by 8!
        model = Primus(
            num_input_channels,
            864,
            (8, 8, 8),
            num_output_channels,
            16,
            12,
            self.configuration_manager.patch_size,
            drop_path_rate=0.2,
            scale_attn_inner=True,
            init_values=0.1,
        )
        return model


class nnUNet_Primus_M_Trainer_BS8(nnUNet_Primus_M_Trainer):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, device)
        self.configuration_manager.configuration["batch_size"] = 8


class nnUNet_Primus_M_Trainer_BS8_2e4(nnUNet_Primus_M_Trainer):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, device)
        self.initial_lr = 2e-4
        self.configuration_manager.configuration["batch_size"] = 8


class nnUNet_Trainer_BS8(nnUNetTrainer):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, device)
        self.configuration_manager.configuration["batch_size"] = 8


class nnUNet_Primus_L_Trainer(AbstractPrimus):
    def build_network_architecture(
        self,
        architecture_class_name: str,
        arch_init_kwargs: dict,
        arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
        num_input_channels: int,
        num_output_channels: int,
        enable_deep_supervision: bool = True,
    ) -> nn.Module:
        # this architecture will crash if the patch size is not divisible by 8!
        model = Primus(
            num_input_channels,
            1056,
            (8, 8, 8),
            num_output_channels,
            24,
            16,
            self.configuration_manager.patch_size,
            drop_path_rate=0.2,
            scale_attn_inner=True,
            init_values=0.1,
        )
        return model


class _Primus_S_96_BS1(nnUNet_Primus_S_Trainer):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        device: torch.device = torch.device("cuda"),
    ):
        plans["configurations"][configuration]["patch_size"] = (96, 96, 96)  # As per repository
        plans["configurations"][configuration]["batch_size"] = 1
        super().__init__(plans, configuration, fold, dataset_json, device)


class _Primus_B_96_BS1(nnUNet_Primus_B_Trainer):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        device: torch.device = torch.device("cuda"),
    ):
        plans["configurations"][configuration]["patch_size"] = (96, 96, 96)  # As per repository
        plans["configurations"][configuration]["batch_size"] = 1
        super().__init__(plans, configuration, fold, dataset_json, device)


class _Primus_M_96_BS1(nnUNet_Primus_M_Trainer):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        device: torch.device = torch.device("cuda"),
    ):
        plans["configurations"][configuration]["patch_size"] = (96, 96, 96)  # As per repository
        plans["configurations"][configuration]["batch_size"] = 1
        super().__init__(plans, configuration, fold, dataset_json, device)


class _Primus_L_48_BS1(nnUNet_Primus_L_Trainer):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        device: torch.device = torch.device("cuda"),
    ):
        plans["configurations"][configuration]["patch_size"] = (48, 48, 48)  # As per repository
        plans["configurations"][configuration]["batch_size"] = 1
        super().__init__(plans, configuration, fold, dataset_json, device)

================================================
FILE:
nnunetv2/training/nnUNetTrainer/variants/__init__.py
================================================

================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/benchmarking/__init__.py
================================================

================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py
================================================
import subprocess

import torch
from batchgenerators.utilities.file_and_folder_operations import save_json, join, isfile, load_json
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from torch import distributed as dist


class nnUNetTrainerBenchmark_5epochs(nnUNetTrainer):
    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
                 device: torch.device = torch.device('cuda')):
        super().__init__(plans, configuration, fold, dataset_json, device)
        assert self.fold == 0, "It makes absolutely no sense to specify a certain fold. Stick with 0 so that we can parse the results."
        self.disable_checkpointing = True
        self.num_epochs = 5
        assert torch.cuda.is_available(), "This only works on GPU"
        self.crashed_with_runtime_error = False

    def perform_actual_validation(self, save_probabilities: bool = False):
        pass

    def save_checkpoint(self, filename: str) -> None:
        # do not trust people to remember that self.disable_checkpointing must be True for this trainer
        pass

    def run_training(self):
        try:
            super().run_training()
        except RuntimeError:
            self.crashed_with_runtime_error = True
            self.on_train_end()

    def on_train_end(self):
        super().on_train_end()

        if not self.is_ddp or self.local_rank == 0:
            torch_version = torch.__version__
            cudnn_version = torch.backends.cudnn.version()
            gpu_name = torch.cuda.get_device_name()
            if self.crashed_with_runtime_error:
                fastest_epoch = 'Not enough VRAM!'
            else:
                epoch_times = [i - j for i, j in zip(self.logger.get_value('epoch_end_timestamps', step=None),
                                                     self.logger.get_value('epoch_start_timestamps', step=None))]
                fastest_epoch = min(epoch_times)

            if self.is_ddp:
                num_gpus = dist.get_world_size()
            else:
                num_gpus = 1

            benchmark_result_file = join(self.output_folder, 'benchmark_result.json')
            if isfile(benchmark_result_file):
                old_results = load_json(benchmark_result_file)
            else:
                old_results = {}
            # generate some unique key
            hostname = subprocess.getoutput('hostname')
            my_key = f"{hostname}__{cudnn_version}__{torch_version.replace(' ', '')}__{gpu_name.replace(' ', '')}__num_gpus_{num_gpus}"
            old_results[my_key] = {
                'torch_version': torch_version,
                'cudnn_version': cudnn_version,
                'gpu_name': gpu_name,
                'fastest_epoch': fastest_epoch,
                'num_gpus': num_gpus,
                'hostname': hostname
            }
            save_json(old_results, join(self.output_folder, 'benchmark_result.json'))

================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py
================================================
import torch

from nnunetv2.training.nnUNetTrainer.variants.benchmarking.nnUNetTrainerBenchmark_5epochs import (
    nnUNetTrainerBenchmark_5epochs,
)
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels


class nnUNetTrainerBenchmark_5epochs_noDataLoading(nnUNetTrainerBenchmark_5epochs):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, device)
        self._set_batch_size_and_oversample()
        num_input_channels = determine_num_input_channels(
            self.plans_manager, self.configuration_manager, self.dataset_json
        )
        patch_size = self.configuration_manager.patch_size
        dummy_data = torch.rand((self.batch_size, num_input_channels, *patch_size), device=self.device)
        if self.enable_deep_supervision:
            dummy_target = [
                torch.round(
                    torch.rand((self.batch_size, 1, *[int(i * j) for i, j in zip(patch_size, k)]), device=self.device)
                    * max(self.label_manager.all_labels)
                )
                for k in self._get_deep_supervision_scales()
            ]
        else:
            # this branch is reached when deep supervision is disabled: the dummy targets above are only
            # built for the deep supervision case
            raise NotImplementedError("This trainer requires deep supervision to be enabled")
        self.dummy_batch = {"data": dummy_data, "target": dummy_target}

    def get_dataloaders(self):
        return None, None

    def run_training(self):
        try:
            self.on_train_start()

            for epoch in range(self.current_epoch, self.num_epochs):
                self.on_epoch_start()

                self.on_train_epoch_start()
                train_outputs = []
                for batch_id in range(self.num_iterations_per_epoch):
                    train_outputs.append(self.train_step(self.dummy_batch))
                self.on_train_epoch_end(train_outputs)

                with torch.no_grad():
                    self.on_validation_epoch_start()
                    val_outputs = []
                    for batch_id in range(self.num_val_iterations_per_epoch):
                        val_outputs.append(self.validation_step(self.dummy_batch))
                    self.on_validation_epoch_end(val_outputs)

                self.on_epoch_end()

            self.on_train_end()
        except RuntimeError:
            self.crashed_with_runtime_error = True

================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/competitions/__init__.py
================================================

================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/competitions/aortaseg24.py
================================================
from nnunetv2.training.nnUNetTrainer.variants.data_augmentation.nnUNetTrainerNoMirroring import nnUNetTrainer_onlyMirror01
from nnunetv2.training.nnUNetTrainer.variants.data_augmentation.nnUNetTrainerDA5 import nnUNetTrainerDA5


class nnUNetTrainer_onlyMirror01_DA5(nnUNetTrainer_onlyMirror01, nnUNetTrainerDA5):
    pass

================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/data_augmentation/__init__.py
================================================

================================================
FILE:
nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py ================================================ from typing import List, Union, Tuple import numpy as np import torch from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter from batchgenerators.transforms.abstract_transforms import AbstractTransform from batchgenerators.transforms.abstract_transforms import Compose from batchgenerators.transforms.color_transforms import BrightnessTransform, ContrastAugmentationTransform, \ GammaTransform from batchgenerators.transforms.local_transforms import BrightnessGradientAdditiveTransform, LocalGammaTransform from batchgenerators.transforms.noise_transforms import MedianFilterTransform, GaussianBlurTransform, \ GaussianNoiseTransform, BlankRectangleTransform, SharpeningTransform from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform from batchgenerators.transforms.spatial_transforms import SpatialTransform, Rot90Transform, TransposeAxesTransform, \ MirrorTransform from batchgenerators.transforms.utility_transforms import OneOfTransform, RemoveLabelTransform, RenameTransform, \ NumpyToTensor from batchgeneratorsv2.helpers.scalar_type import RandomScalar from torch import autocast from nnunetv2.configuration import ANISO_THRESHOLD from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size from nnunetv2.training.data_augmentation.custom_transforms.cascade_transforms import MoveSegAsOneHotToData, \ ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \ DownsampleSegForDSTransform2 from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform from 
nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \ ConvertSegmentationToRegionsTransform from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert3DTo2DTransform, \ Convert2DTo3DTransform from nnunetv2.training.dataloading.data_loader import nnUNetDataLoader from nnunetv2.training.loss.dice import get_tp_fp_fn_tn from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA from nnunetv2.utilities.helpers import dummy_context class TensorToNumpy(AbstractTransform): def __init__(self, keys=None, cast_to=None): """ Converts torch tensors to numpy ndarrays. :param keys: specify keys to be converted. If None then all tensor values will be converted. Can be a key (string) or a list/tuple of keys. :param cast_to: optional numpy dtype as string, e.g. 'float32', 'float16', 'int64', 'bool' """ if keys is not None and not isinstance(keys, (list, tuple)): keys = [keys] self.keys = keys self.cast_to = cast_to def cast(self, array: np.ndarray): if self.cast_to is not None: try: array = array.astype(self.cast_to, copy=False) except TypeError: raise ValueError(f"Unknown value for cast_to: {self.cast_to}") return array def _to_numpy(self, tensor): import torch if isinstance(tensor, torch.Tensor): # Important: detach + move to CPU before numpy conversion array = tensor.detach().cpu().numpy() return self.cast(array) return tensor def __call__(self, **data_dict): import torch if self.keys is None: for key, val in data_dict.items(): if isinstance(val, torch.Tensor): data_dict[key] = self._to_numpy(val) elif isinstance(val, (list, tuple)) and all(isinstance(i, torch.Tensor) for i in val): data_dict[key] = [self._to_numpy(i) for i in val] else: for key in self.keys: val = data_dict[key] if isinstance(val, torch.Tensor): data_dict[key] = self._to_numpy(val) elif isinstance(val, (list, tuple)) and all(isinstance(i, torch.Tensor) for i 
in val): data_dict[key] = [self._to_numpy(i) for i in val] return data_dict class nnUNetTrainerDA5(nnUNetTrainer): def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): patch_size = self.configuration_manager.patch_size dim = len(patch_size) # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation) if dim == 2: do_dummy_2d_data_aug = False # todo revisit this parametrization if max(patch_size) / min(patch_size) > 1.5: rotation_for_DA = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi) else: rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi) mirror_axes = (0, 1) elif dim == 3: # todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad # order of the axes is determined by spacing, not image size do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD if do_dummy_2d_data_aug: # why do we rotate 180 deg here all the time? We should also restrict it rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi) else: rotation_for_DA = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) mirror_axes = (0, 1, 2) else: raise RuntimeError() # todo this function is stupid. 
It doesn't even use the correct scale range (we keep things as they were in the # old nnunet for now) initial_patch_size = get_patch_size(patch_size[-dim:], rotation_for_DA, rotation_for_DA, rotation_for_DA, (0.7, 1.43)) if do_dummy_2d_data_aug: initial_patch_size[0] = patch_size[0] self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}') self.inference_allowed_mirroring_axes = mirror_axes return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes @staticmethod def get_training_transforms( patch_size: Union[np.ndarray, Tuple[int]], rotation_for_DA: RandomScalar, deep_supervision_scales: Union[List, Tuple, None], mirror_axes: Tuple[int, ...], do_dummy_2d_data_aug: bool, use_mask_for_norm: List[bool] = None, is_cascaded: bool = False, foreground_labels: Union[Tuple[int, ...], List[int]] = None, regions: List[Union[List[int], Tuple[int, ...], int]] = None, ignore_label: int = None, ) -> AbstractTransform: matching_axes = np.array([sum([i == j for j in patch_size]) for i in patch_size]) valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0]) tr_transforms = [] tr_transforms.append(TensorToNumpy()) tr_transforms.append(RenameTransform('target', 'seg', True)) if do_dummy_2d_data_aug: ignore_axes = (0,) tr_transforms.append(Convert3DTo2DTransform()) patch_size_spatial = patch_size[1:] else: patch_size_spatial = patch_size ignore_axes = None tr_transforms.append( SpatialTransform( patch_size_spatial, patch_center_dist_from_border=None, do_elastic_deform=False, do_rotation=True, angle_x=rotation_for_DA, angle_y=rotation_for_DA, angle_z=rotation_for_DA, p_rot_per_axis=0.5, do_scale=True, scale=(0.7, 1.43), border_mode_data="constant", border_cval_data=0, order_data=3, border_mode_seg="constant", border_cval_seg=-1, order_seg=1, random_crop=False, p_el_per_sample=0.2, p_scale_per_sample=0.2, p_rot_per_sample=0.4, independent_scale_for_each_axis=True, ) ) if do_dummy_2d_data_aug: tr_transforms.append(Convert2DTo3DTransform()) if 
np.any(matching_axes > 1): tr_transforms.append( Rot90Transform( (0, 1, 2, 3), axes=valid_axes, data_key='data', label_key='seg', p_per_sample=0.5 ), ) if np.any(matching_axes > 1): tr_transforms.append( TransposeAxesTransform(valid_axes, data_key='data', label_key='seg', p_per_sample=0.5) ) tr_transforms.append(OneOfTransform([ MedianFilterTransform( (2, 8), same_for_each_channel=False, p_per_sample=0.2, p_per_channel=0.5 ), GaussianBlurTransform((0.3, 1.5), different_sigma_per_channel=True, p_per_sample=0.2, p_per_channel=0.5) ])) tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1)) tr_transforms.append(BrightnessTransform(0, 0.5, per_channel=True, p_per_sample=0.1, p_per_channel=0.5 ) ) tr_transforms.append(OneOfTransform( [ ContrastAugmentationTransform( contrast_range=(0.5, 2), preserve_range=True, per_channel=True, data_key='data', p_per_sample=0.2, p_per_channel=0.5 ), ContrastAugmentationTransform( contrast_range=(0.5, 2), preserve_range=False, per_channel=True, data_key='data', p_per_sample=0.2, p_per_channel=0.5 ), ] )) tr_transforms.append( SimulateLowResolutionTransform(zoom_range=(0.25, 1), per_channel=True, p_per_channel=0.5, order_downsample=0, order_upsample=3, p_per_sample=0.15, ignore_axes=ignore_axes ) ) tr_transforms.append( GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1)) tr_transforms.append( GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1)) if mirror_axes is not None and len(mirror_axes) > 0: tr_transforms.append(MirrorTransform(mirror_axes)) tr_transforms.append( BlankRectangleTransform([[max(1, p // 10), p // 3] for p in patch_size], rectangle_value=np.mean, num_rectangles=(1, 5), force_square=False, p_per_sample=0.4, p_per_channel=0.5 ) ) tr_transforms.append( BrightnessGradientAdditiveTransform( _brightnessadditive_localgamma_transform_scale, (-0.5, 1.5), max_strength=_brightness_gradient_additive_max_strength, 
mean_centered=False, same_for_all_channels=False, p_per_sample=0.3, p_per_channel=0.5 ) ) tr_transforms.append( LocalGammaTransform( _brightnessadditive_localgamma_transform_scale, (-0.5, 1.5), _local_gamma_gamma, same_for_all_channels=False, p_per_sample=0.3, p_per_channel=0.5 ) ) tr_transforms.append( SharpeningTransform( strength=(0.1, 1), same_for_each_channel=False, p_per_sample=0.2, p_per_channel=0.5 ) ) if use_mask_for_norm is not None and any(use_mask_for_norm): tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]], mask_idx_in_seg=0, set_outside_to=0)) tr_transforms.append(RemoveLabelTransform(-1, 0)) if is_cascaded: if ignore_label is not None: raise NotImplementedError('ignore label not yet supported in cascade') assert foreground_labels is not None, 'We need all_labels for cascade augmentations' use_labels = [i for i in foreground_labels if i != 0] tr_transforms.append(MoveSegAsOneHotToData(1, use_labels, 'seg', 'data')) tr_transforms.append(ApplyRandomBinaryOperatorTransform( channel_idx=list(range(-len(use_labels), 0)), p_per_sample=0.4, key="data", strel_size=(1, 8), p_per_label=1)) tr_transforms.append( RemoveRandomConnectedComponentFromOneHotEncodingTransform( channel_idx=list(range(-len(use_labels), 0)), key="data", p_per_sample=0.2, fill_with_other_class_p=0, dont_do_if_covers_more_than_x_percent=0.15)) tr_transforms.append(RenameTransform('seg', 'target', True)) if regions is not None: # the ignore label must also be converted tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label] if ignore_label is not None else regions, 'target', 'target')) if deep_supervision_scales is not None: tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', output_key='target')) tr_transforms.append(NumpyToTensor(['data', 'target'], 'float')) tr_transforms = Compose(tr_transforms) return tr_transforms @staticmethod def 
get_validation_transforms(
            deep_supervision_scales: Union[List, Tuple, None],
            is_cascaded: bool = False,
            foreground_labels: Union[Tuple[int, ...], List[int]] = None,
            regions: List[Union[List[int], Tuple[int, ...], int]] = None,
            ignore_label: int = None,
    ) -> AbstractTransform:
        val_transforms = []
        val_transforms.append(TensorToNumpy())
        val_transforms.append(RenameTransform('target', 'seg', True))
        val_transforms.append(RemoveLabelTransform(-1, 0))

        if is_cascaded:
            val_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data'))

        val_transforms.append(RenameTransform('seg', 'target', True))

        if regions is not None:
            # the ignore label must also be converted
            val_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
                                                                        if ignore_label is not None else regions,
                                                                        'target', 'target'))

        if deep_supervision_scales is not None:
            val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
                                                               output_key='target'))

        val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
        val_transforms = Compose(val_transforms)
        return val_transforms

    def get_dataloaders(self):
        # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether
        # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be
        patch_size = self.configuration_manager.patch_size

        # needed for deep supervision: how much do we need to downscale the segmentation targets for the different
        # outputs?
        deep_supervision_scales = self._get_deep_supervision_scales()

        (
            rotation_for_DA,
            do_dummy_2d_data_aug,
            initial_patch_size,
            mirror_axes,
        ) = self.configure_rotation_dummyDA_mirroring_and_inital_patch_size()

        # training pipeline
        tr_transforms = self.get_training_transforms(
            patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug,
            use_mask_for_norm=self.configuration_manager.use_mask_for_norm,
            is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels,
            regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None,
            ignore_label=self.label_manager.ignore_label)

        # validation pipeline
        val_transforms = self.get_validation_transforms(deep_supervision_scales,
                                                        is_cascaded=self.is_cascaded,
                                                        foreground_labels=self.label_manager.foreground_labels,
                                                        regions=self.label_manager.foreground_regions if
                                                        self.label_manager.has_regions else None,
                                                        ignore_label=self.label_manager.ignore_label)

        dataset_tr, dataset_val = self.get_tr_and_val_datasets()

        # we set transforms=None because this trainer still uses batchgenerators, which expects transforms to be
        # passed to the augmenter (see below) instead of to the data loader
        dl_tr = nnUNetDataLoader(dataset_tr, self.batch_size,
                                 initial_patch_size,
                                 self.configuration_manager.patch_size,
                                 self.label_manager,
                                 oversample_foreground_percent=self.oversample_foreground_percent,
                                 sampling_probabilities=None, pad_sides=None, transforms=None,
                                 probabilistic_oversampling=self.probabilistic_oversampling)
        dl_val = nnUNetDataLoader(dataset_val, self.batch_size,
                                  self.configuration_manager.patch_size,
                                  self.configuration_manager.patch_size,
                                  self.label_manager,
                                  oversample_foreground_percent=self.oversample_foreground_percent,
                                  sampling_probabilities=None, pad_sides=None, transforms=None,
                                  probabilistic_oversampling=self.probabilistic_oversampling)

        allowed_num_processes = get_allowed_n_proc_DA()
        if allowed_num_processes == 0:
            mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
            mt_gen_val = SingleThreadedAugmenter(dl_val,
val_transforms)
        else:
            mt_gen_train = NonDetMultiThreadedAugmenter(data_loader=dl_tr, transform=tr_transforms,
                                                        num_processes=allowed_num_processes,
                                                        num_cached=6, seeds=None,
                                                        pin_memory=self.device.type == 'cuda', wait_time=0.02)
            mt_gen_val = NonDetMultiThreadedAugmenter(data_loader=dl_val, transform=val_transforms,
                                                      num_processes=max(1, allowed_num_processes // 2),
                                                      num_cached=3, seeds=None,
                                                      pin_memory=self.device.type == 'cuda', wait_time=0.02)
        # # let's get this party started
        _ = next(mt_gen_train)
        _ = next(mt_gen_val)
        return mt_gen_train, mt_gen_val

    def validation_step(self, batch: dict) -> dict:
        data = batch['data']
        target = batch['target']

        data = data.to(self.device, non_blocking=True)
        if isinstance(target, list):
            target = [i.to(self.device, non_blocking=True) for i in target]
        else:
            target = target.to(self.device, non_blocking=True)

        # Autocast can be annoying
        # If the device_type is 'cpu' then it's slow as heck and needs to be disabled.
        # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
        # So autocast will only be active if we have a cuda device.
        with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
            output = self.network(data)
            del data
            l = self.loss(output, target)

        # we only need the output with the highest output resolution (if DS enabled)
        if self.enable_deep_supervision:
            output = output[0]
            target = target[0]

        # the following is needed for online evaluation.
        # Fake dice (green line)
        axes = [0] + list(range(2, output.ndim))

        if self.label_manager.has_regions:
            predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long()
        else:
            # no need for softmax
            output_seg = output.argmax(1)[:, None]
            predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float16)
            predicted_segmentation_onehot.scatter_(1, output_seg, 1)
            del output_seg

        if self.label_manager.has_ignore_label:
            if not self.label_manager.has_regions:
                mask = target != self.label_manager.ignore_label
                # CAREFUL that you don't rely on target after this line!
                target[target == self.label_manager.ignore_label] = 0
            else:
                if target.dtype == torch.bool:
                    mask = ~target[:, -1:]
                else:
                    mask = (1 - target[:, -1:]).bool()
                # CAREFUL that you don't rely on target after this line!
                target = target[:, :-1].bool()
        else:
            mask = None

        tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask)

        tp_hard = tp.detach().cpu().numpy()
        fp_hard = fp.detach().cpu().numpy()
        fn_hard = fn.detach().cpu().numpy()
        if not self.label_manager.has_regions:
            # if we train with regions all segmentation heads predict some kind of foreground. In conventional
            # (softmax training) there needs to be one output for the background.
We are not interested in the # background Dice # [1:] in order to remove background tp_hard = tp_hard[1:] fp_hard = fp_hard[1:] fn_hard = fn_hard[1:] return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} class nnUNetTrainerDA5ord0(nnUNetTrainerDA5): @staticmethod def get_training_transforms( patch_size: Union[np.ndarray, Tuple[int]], rotation_for_DA: RandomScalar, deep_supervision_scales: Union[List, Tuple, None], mirror_axes: Tuple[int, ...], do_dummy_2d_data_aug: bool, use_mask_for_norm: List[bool] = None, is_cascaded: bool = False, foreground_labels: Union[Tuple[int, ...], List[int]] = None, regions: List[Union[List[int], Tuple[int, ...], int]] = None, ignore_label: int = None, ) -> AbstractTransform: matching_axes = np.array([sum([i == j for j in patch_size]) for i in patch_size]) valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0]) tr_transforms = [] tr_transforms.append(RenameTransform('target', 'seg', True)) if do_dummy_2d_data_aug: ignore_axes = (0,) tr_transforms.append(Convert3DTo2DTransform()) patch_size_spatial = patch_size[1:] else: patch_size_spatial = patch_size ignore_axes = None tr_transforms.append( SpatialTransform( patch_size_spatial, patch_center_dist_from_border=None, do_elastic_deform=False, do_rotation=True, angle_x=rotation_for_DA, angle_y=rotation_for_DA, angle_z=rotation_for_DA, p_rot_per_axis=0.5, do_scale=True, scale=(0.7, 1.43), border_mode_data="constant", border_cval_data=0, order_data=0, border_mode_seg="constant", border_cval_seg=-1, order_seg=0, random_crop=False, p_el_per_sample=0.2, p_scale_per_sample=0.2, p_rot_per_sample=0.4, independent_scale_for_each_axis=True, ) ) if do_dummy_2d_data_aug: tr_transforms.append(Convert2DTo3DTransform()) if np.any(matching_axes > 1): tr_transforms.append( Rot90Transform( (0, 1, 2, 3), axes=valid_axes, data_key='data', label_key='seg', p_per_sample=0.5 ), ) if np.any(matching_axes > 1): tr_transforms.append( 
TransposeAxesTransform(valid_axes, data_key='data', label_key='seg', p_per_sample=0.5) ) tr_transforms.append(OneOfTransform([ MedianFilterTransform( (2, 8), same_for_each_channel=False, p_per_sample=0.2, p_per_channel=0.5 ), GaussianBlurTransform((0.3, 1.5), different_sigma_per_channel=True, p_per_sample=0.2, p_per_channel=0.5) ])) tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1)) tr_transforms.append(BrightnessTransform(0, 0.5, per_channel=True, p_per_sample=0.1, p_per_channel=0.5 ) ) tr_transforms.append(OneOfTransform( [ ContrastAugmentationTransform( contrast_range=(0.5, 2), preserve_range=True, per_channel=True, data_key='data', p_per_sample=0.2, p_per_channel=0.5 ), ContrastAugmentationTransform( contrast_range=(0.5, 2), preserve_range=False, per_channel=True, data_key='data', p_per_sample=0.2, p_per_channel=0.5 ), ] )) tr_transforms.append( SimulateLowResolutionTransform(zoom_range=(0.25, 1), per_channel=True, p_per_channel=0.5, order_downsample=0, order_upsample=3, p_per_sample=0.15, ignore_axes=ignore_axes ) ) tr_transforms.append( GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1)) tr_transforms.append( GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1)) if mirror_axes is not None and len(mirror_axes) > 0: tr_transforms.append(MirrorTransform(mirror_axes)) tr_transforms.append( BlankRectangleTransform([[max(1, p // 10), p // 3] for p in patch_size], rectangle_value=np.mean, num_rectangles=(1, 5), force_square=False, p_per_sample=0.4, p_per_channel=0.5 ) ) tr_transforms.append( BrightnessGradientAdditiveTransform( _brightnessadditive_localgamma_transform_scale, (-0.5, 1.5), max_strength=_brightness_gradient_additive_max_strength, mean_centered=False, same_for_all_channels=False, p_per_sample=0.3, p_per_channel=0.5 ) ) tr_transforms.append( LocalGammaTransform( _brightnessadditive_localgamma_transform_scale, (-0.5, 1.5), _local_gamma_gamma, 
same_for_all_channels=False, p_per_sample=0.3, p_per_channel=0.5 ) ) tr_transforms.append( SharpeningTransform( strength=(0.1, 1), same_for_each_channel=False, p_per_sample=0.2, p_per_channel=0.5 ) ) if use_mask_for_norm is not None and any(use_mask_for_norm): tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]], mask_idx_in_seg=0, set_outside_to=0)) tr_transforms.append(RemoveLabelTransform(-1, 0)) if is_cascaded: if ignore_label is not None: raise NotImplementedError('ignore label not yet supported in cascade') assert foreground_labels is not None, 'We need all_labels for cascade augmentations' use_labels = [i for i in foreground_labels if i != 0] tr_transforms.append(MoveSegAsOneHotToData(1, use_labels, 'seg', 'data')) tr_transforms.append(ApplyRandomBinaryOperatorTransform( channel_idx=list(range(-len(use_labels), 0)), p_per_sample=0.4, key="data", strel_size=(1, 8), p_per_label=1)) tr_transforms.append( RemoveRandomConnectedComponentFromOneHotEncodingTransform( channel_idx=list(range(-len(use_labels), 0)), key="data", p_per_sample=0.2, fill_with_other_class_p=0, dont_do_if_covers_more_than_x_percent=0.15)) tr_transforms.append(RenameTransform('seg', 'target', True)) if regions is not None: # the ignore label must also be converted tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label] if ignore_label is not None else regions, 'target', 'target')) if deep_supervision_scales is not None: tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', output_key='target')) tr_transforms.append(NumpyToTensor(['data', 'target'], 'float')) tr_transforms = Compose(tr_transforms) return tr_transforms def _brightnessadditive_localgamma_transform_scale(x, y): return np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y]))) def _brightness_gradient_additive_max_strength(_x, _y): return np.random.uniform(-5, -1) if np.random.uniform() < 0.5 else 
np.random.uniform(1, 5) def _local_gamma_gamma(): return np.random.uniform(0.01, 0.8) if np.random.uniform() < 0.5 else np.random.uniform(1.5, 4) class nnUNetTrainerDA5Segord0(nnUNetTrainerDA5): @staticmethod def get_training_transforms( patch_size: Union[np.ndarray, Tuple[int]], rotation_for_DA: RandomScalar, deep_supervision_scales: Union[List, Tuple, None], mirror_axes: Tuple[int, ...], do_dummy_2d_data_aug: bool, use_mask_for_norm: List[bool] = None, is_cascaded: bool = False, foreground_labels: Union[Tuple[int, ...], List[int]] = None, regions: List[Union[List[int], Tuple[int, ...], int]] = None, ignore_label: int = None, ) -> AbstractTransform: matching_axes = np.array([sum([i == j for j in patch_size]) for i in patch_size]) valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0]) tr_transforms = [] tr_transforms.append(RenameTransform('target', 'seg', True)) if do_dummy_2d_data_aug: ignore_axes = (0,) tr_transforms.append(Convert3DTo2DTransform()) patch_size_spatial = patch_size[1:] else: patch_size_spatial = patch_size ignore_axes = None tr_transforms.append( SpatialTransform( patch_size_spatial, patch_center_dist_from_border=None, do_elastic_deform=False, do_rotation=True, angle_x=rotation_for_DA, angle_y=rotation_for_DA, angle_z=rotation_for_DA, p_rot_per_axis=0.5, do_scale=True, scale=(0.7, 1.43), border_mode_data="constant", border_cval_data=0, order_data=3, border_mode_seg="constant", border_cval_seg=-1, order_seg=0, random_crop=False, p_el_per_sample=0.2, p_scale_per_sample=0.2, p_rot_per_sample=0.4, independent_scale_for_each_axis=True, ) ) if do_dummy_2d_data_aug: tr_transforms.append(Convert2DTo3DTransform()) if np.any(matching_axes > 1): tr_transforms.append( Rot90Transform( (0, 1, 2, 3), axes=valid_axes, data_key='data', label_key='seg', p_per_sample=0.5 ), ) if np.any(matching_axes > 1): tr_transforms.append( TransposeAxesTransform(valid_axes, data_key='data', label_key='seg', p_per_sample=0.5) ) 
tr_transforms.append(OneOfTransform([ MedianFilterTransform( (2, 8), same_for_each_channel=False, p_per_sample=0.2, p_per_channel=0.5 ), GaussianBlurTransform((0.3, 1.5), different_sigma_per_channel=True, p_per_sample=0.2, p_per_channel=0.5) ])) tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1)) tr_transforms.append(BrightnessTransform(0, 0.5, per_channel=True, p_per_sample=0.1, p_per_channel=0.5 ) ) tr_transforms.append(OneOfTransform( [ ContrastAugmentationTransform( contrast_range=(0.5, 2), preserve_range=True, per_channel=True, data_key='data', p_per_sample=0.2, p_per_channel=0.5 ), ContrastAugmentationTransform( contrast_range=(0.5, 2), preserve_range=False, per_channel=True, data_key='data', p_per_sample=0.2, p_per_channel=0.5 ), ] )) tr_transforms.append( SimulateLowResolutionTransform(zoom_range=(0.25, 1), per_channel=True, p_per_channel=0.5, order_downsample=0, order_upsample=3, p_per_sample=0.15, ignore_axes=ignore_axes ) ) tr_transforms.append( GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1)) tr_transforms.append( GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1)) if mirror_axes is not None and len(mirror_axes) > 0: tr_transforms.append(MirrorTransform(mirror_axes)) tr_transforms.append( BlankRectangleTransform([[max(1, p // 10), p // 3] for p in patch_size], rectangle_value=np.mean, num_rectangles=(1, 5), force_square=False, p_per_sample=0.4, p_per_channel=0.5 ) ) tr_transforms.append( BrightnessGradientAdditiveTransform( _brightnessadditive_localgamma_transform_scale, (-0.5, 1.5), max_strength=_brightness_gradient_additive_max_strength, mean_centered=False, same_for_all_channels=False, p_per_sample=0.3, p_per_channel=0.5 ) ) tr_transforms.append( LocalGammaTransform( _brightnessadditive_localgamma_transform_scale, (-0.5, 1.5), _local_gamma_gamma, same_for_all_channels=False, p_per_sample=0.3, p_per_channel=0.5 ) ) tr_transforms.append( 
SharpeningTransform( strength=(0.1, 1), same_for_each_channel=False, p_per_sample=0.2, p_per_channel=0.5 ) ) if use_mask_for_norm is not None and any(use_mask_for_norm): tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]], mask_idx_in_seg=0, set_outside_to=0)) tr_transforms.append(RemoveLabelTransform(-1, 0)) if is_cascaded: if ignore_label is not None: raise NotImplementedError('ignore label not yet supported in cascade') assert foreground_labels is not None, 'We need all_labels for cascade augmentations' use_labels = [i for i in foreground_labels if i != 0] tr_transforms.append(MoveSegAsOneHotToData(1, use_labels, 'seg', 'data')) tr_transforms.append(ApplyRandomBinaryOperatorTransform( channel_idx=list(range(-len(use_labels), 0)), p_per_sample=0.4, key="data", strel_size=(1, 8), p_per_label=1)) tr_transforms.append( RemoveRandomConnectedComponentFromOneHotEncodingTransform( channel_idx=list(range(-len(use_labels), 0)), key="data", p_per_sample=0.2, fill_with_other_class_p=0, dont_do_if_covers_more_than_x_percent=0.15)) tr_transforms.append(RenameTransform('seg', 'target', True)) if regions is not None: # the ignore label must also be converted tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label] if ignore_label is not None else regions, 'target', 'target')) if deep_supervision_scales is not None: tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', output_key='target')) tr_transforms.append(NumpyToTensor(['data', 'target'], 'float')) tr_transforms = Compose(tr_transforms) return tr_transforms class nnUNetTrainerDA5_10epochs(nnUNetTrainerDA5): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 10 ================================================ FILE: 
nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py ================================================ from typing import Union, Tuple, List from batchgeneratorsv2.helpers.scalar_type import RandomScalar from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform from batchgeneratorsv2.transforms.intensity.contrast import ContrastTransform, BGContrast from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform from batchgeneratorsv2.transforms.nnunet.random_binary_operator import ApplyRandomBinaryOperatorTransform from batchgeneratorsv2.transforms.nnunet.remove_connected_components import \ RemoveRandomConnectedComponentFromOneHotEncodingTransform from batchgeneratorsv2.transforms.nnunet.seg_to_onehot import MoveSegAsOneHotToDataTransform from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform from batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform from batchgeneratorsv2.transforms.utils.pseudo2d import Convert3DTo2DTransform, Convert2DTo3DTransform from batchgeneratorsv2.transforms.utils.random import RandomTransform from batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform from batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform from batchgenerators.dataloading.nondet_multi_threaded_augmenter 
import NonDetMultiThreadedAugmenter from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter from nnunetv2.training.dataloading.data_loader import nnUNetDataLoader from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA import numpy as np class nnUNetTrainer_DASegOrd0(nnUNetTrainer): @staticmethod def get_training_transforms( patch_size: Union[np.ndarray, Tuple[int]], rotation_for_DA: RandomScalar, deep_supervision_scales: Union[List, Tuple, None], mirror_axes: Tuple[int, ...], do_dummy_2d_data_aug: bool, use_mask_for_norm: List[bool] = None, is_cascaded: bool = False, foreground_labels: Union[Tuple[int, ...], List[int]] = None, regions: List[Union[List[int], Tuple[int, ...], int]] = None, ignore_label: int = None, ) -> BasicTransform: transforms = [] if do_dummy_2d_data_aug: ignore_axes = (0,) transforms.append(Convert3DTo2DTransform()) patch_size_spatial = patch_size[1:] else: patch_size_spatial = patch_size ignore_axes = None transforms.append( SpatialTransform( patch_size_spatial, patch_center_dist_from_border=0, random_crop=False, p_elastic_deform=0, p_rotation=0.2, rotation=rotation_for_DA, p_scaling=0.2, scaling=(0.7, 1.4), p_synchronize_scaling_across_axes=1, bg_style_seg_sampling=False, mode_seg='nearest' ) ) if do_dummy_2d_data_aug: transforms.append(Convert2DTo3DTransform()) transforms.append(RandomTransform( GaussianNoiseTransform( noise_variance=(0, 0.1), p_per_channel=1, synchronize_channels=True ), apply_probability=0.1 )) transforms.append(RandomTransform( GaussianBlurTransform( blur_sigma=(0.5, 1.), synchronize_channels=False, synchronize_axes=False, p_per_channel=0.5, benchmark=True ), apply_probability=0.2 )) transforms.append(RandomTransform( MultiplicativeBrightnessTransform( multiplier_range=BGContrast((0.75, 1.25)), synchronize_channels=False, p_per_channel=1 ), apply_probability=0.15 )) transforms.append(RandomTransform( 
ContrastTransform( contrast_range=BGContrast((0.75, 1.25)), preserve_range=True, synchronize_channels=False, p_per_channel=1 ), apply_probability=0.15 )) transforms.append(RandomTransform( SimulateLowResolutionTransform( scale=(0.5, 1), synchronize_channels=False, synchronize_axes=True, ignore_axes=ignore_axes, allowed_channels=None, p_per_channel=0.5 ), apply_probability=0.25 )) transforms.append(RandomTransform( GammaTransform( gamma=BGContrast((0.7, 1.5)), p_invert_image=1, synchronize_channels=False, p_per_channel=1, p_retain_stats=1 ), apply_probability=0.1 )) transforms.append(RandomTransform( GammaTransform( gamma=BGContrast((0.7, 1.5)), p_invert_image=0, synchronize_channels=False, p_per_channel=1, p_retain_stats=1 ), apply_probability=0.3 )) if mirror_axes is not None and len(mirror_axes) > 0: transforms.append( MirrorTransform( allowed_axes=mirror_axes ) ) if use_mask_for_norm is not None and any(use_mask_for_norm): transforms.append(MaskImageTransform( apply_to_channels=[i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]], channel_idx_in_seg=0, set_outside_to=0, )) transforms.append( RemoveLabelTansform(-1, 0) ) if is_cascaded: assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations' transforms.append( MoveSegAsOneHotToDataTransform( source_channel_idx=1, all_labels=foreground_labels, remove_channel_from_source=True ) ) transforms.append( RandomTransform( ApplyRandomBinaryOperatorTransform( channel_idx=list(range(-len(foreground_labels), 0)), strel_size=(1, 8), p_per_label=1 ), apply_probability=0.4 ) ) transforms.append( RandomTransform( RemoveRandomConnectedComponentFromOneHotEncodingTransform( channel_idx=list(range(-len(foreground_labels), 0)), fill_with_other_class_p=0, dont_do_if_covers_more_than_x_percent=0.15, p_per_label=1 ), apply_probability=0.2 ) ) if regions is not None: # the ignore label must also be converted transforms.append( ConvertSegmentationToRegionsTransform( regions=list(regions) + 
[ignore_label] if ignore_label is not None else regions,
                    channel_in_seg=0
                )
            )

        if deep_supervision_scales is not None:
            transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales))

        return ComposeTransforms(transforms)


class nnUNetTrainer_DASegOrd0_NoMirroring(nnUNetTrainer_DASegOrd0):
    def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
        rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
            super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
        mirror_axes = None
        self.inference_allowed_mirroring_axes = None
        return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes


================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py
================================================
from typing import Union, Tuple, List

import numpy as np
from batchgeneratorsv2.helpers.scalar_type import RandomScalar
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform

from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer


class nnUNetTrainerNoDA(nnUNetTrainer):
    @staticmethod
    def get_training_transforms(
            patch_size: Union[np.ndarray, Tuple[int]],
            rotation_for_DA: RandomScalar,
            deep_supervision_scales: Union[List, Tuple, None],
            mirror_axes: Tuple[int, ...],
            do_dummy_2d_data_aug: bool,
            use_mask_for_norm: List[bool] = None,
            is_cascaded: bool = False,
            foreground_labels: Union[Tuple[int, ...], List[int]] = None,
            regions: List[Union[List[int], Tuple[int, ...], int]] = None,
            ignore_label: int = None,
    ) -> BasicTransform:
        return nnUNetTrainer.get_validation_transforms(deep_supervision_scales, is_cascaded, foreground_labels,
                                                       regions, ignore_label)

    def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
        # we need to disable mirroring here so that no mirroring will be applied in inference!
        rotation_for_DA, do_dummy_2d_data_aug, _, _ = \
            super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
        mirror_axes = None
        self.inference_allowed_mirroring_axes = None
        initial_patch_size = self.configuration_manager.patch_size
        return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes


================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py
================================================
from typing import Union, Tuple, List

import numpy as np
import torch
from batchgeneratorsv2.helpers.scalar_type import RandomScalar
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform
from batchgeneratorsv2.transforms.intensity.contrast import ContrastTransform, BGContrast
from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform
from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform
from batchgeneratorsv2.transforms.nnunet.random_binary_operator import ApplyRandomBinaryOperatorTransform
from batchgeneratorsv2.transforms.nnunet.remove_connected_components import \
    RemoveRandomConnectedComponentFromOneHotEncodingTransform
from batchgeneratorsv2.transforms.nnunet.seg_to_onehot import MoveSegAsOneHotToDataTransform
from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform
from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform
from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform
from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform
from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms
from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform
from batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform
from
batchgeneratorsv2.transforms.utils.pseudo2d import Convert3DTo2DTransform, Convert2DTo3DTransform from batchgeneratorsv2.transforms.utils.random import RandomTransform from batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform from batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer class nnUNetTrainerNoMirroring(nnUNetTrainer): def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() mirror_axes = None self.inference_allowed_mirroring_axes = None return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes class nnUNetTrainer_onlyMirror01(nnUNetTrainer): """ Only mirrors along spatial axes 0 and 1 for 3D and 0 for 2D """ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() patch_size = self.configuration_manager.patch_size dim = len(patch_size) if dim == 2: mirror_axes = (0, ) else: mirror_axes = (0, 1) self.inference_allowed_mirroring_axes = mirror_axes return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes class nnUNetTrainer_onlyMirror01_1500ep(nnUNetTrainer_onlyMirror01): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 1500 class nnUNetTrainer_onlyMirror01_DASegOrd0(nnUNetTrainer_onlyMirror01): @staticmethod def get_training_transforms( patch_size: Union[np.ndarray, Tuple[int]], rotation_for_DA: RandomScalar, deep_supervision_scales: Union[List, Tuple, None], mirror_axes: Tuple[int, ...], do_dummy_2d_data_aug: 
bool, use_mask_for_norm: List[bool] = None, is_cascaded: bool = False, foreground_labels: Union[Tuple[int, ...], List[int]] = None, regions: List[Union[List[int], Tuple[int, ...], int]] = None, ignore_label: int = None, ) -> BasicTransform: transforms = [] if do_dummy_2d_data_aug: ignore_axes = (0,) transforms.append(Convert3DTo2DTransform()) patch_size_spatial = patch_size[1:] else: patch_size_spatial = patch_size ignore_axes = None transforms.append( SpatialTransform( patch_size_spatial, patch_center_dist_from_border=0, random_crop=False, p_elastic_deform=0, p_rotation=0.2, rotation=rotation_for_DA, p_scaling=0.2, scaling=(0.7, 1.4), p_synchronize_scaling_across_axes=1, bg_style_seg_sampling=False, mode_seg='nearest' ) ) if do_dummy_2d_data_aug: transforms.append(Convert2DTo3DTransform()) transforms.append(RandomTransform( GaussianNoiseTransform( noise_variance=(0, 0.1), p_per_channel=1, synchronize_channels=True ), apply_probability=0.1 )) transforms.append(RandomTransform( GaussianBlurTransform( blur_sigma=(0.5, 1.), synchronize_channels=False, synchronize_axes=False, p_per_channel=0.5, benchmark=True ), apply_probability=0.2 )) transforms.append(RandomTransform( MultiplicativeBrightnessTransform( multiplier_range=BGContrast((0.75, 1.25)), synchronize_channels=False, p_per_channel=1 ), apply_probability=0.15 )) transforms.append(RandomTransform( ContrastTransform( contrast_range=BGContrast((0.75, 1.25)), preserve_range=True, synchronize_channels=False, p_per_channel=1 ), apply_probability=0.15 )) transforms.append(RandomTransform( SimulateLowResolutionTransform( scale=(0.5, 1), synchronize_channels=False, synchronize_axes=True, ignore_axes=ignore_axes, allowed_channels=None, p_per_channel=0.5 ), apply_probability=0.25 )) transforms.append(RandomTransform( GammaTransform( gamma=BGContrast((0.7, 1.5)), p_invert_image=1, synchronize_channels=False, p_per_channel=1, p_retain_stats=1 ), apply_probability=0.1 )) transforms.append(RandomTransform( GammaTransform( 
gamma=BGContrast((0.7, 1.5)), p_invert_image=0, synchronize_channels=False, p_per_channel=1, p_retain_stats=1 ), apply_probability=0.3 )) if mirror_axes is not None and len(mirror_axes) > 0: transforms.append( MirrorTransform( allowed_axes=mirror_axes ) ) if use_mask_for_norm is not None and any(use_mask_for_norm): transforms.append(MaskImageTransform( apply_to_channels=[i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]], channel_idx_in_seg=0, set_outside_to=0, )) transforms.append( RemoveLabelTansform(-1, 0) ) if is_cascaded: assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations' transforms.append( MoveSegAsOneHotToDataTransform( source_channel_idx=1, all_labels=foreground_labels, remove_channel_from_source=True ) ) transforms.append( RandomTransform( ApplyRandomBinaryOperatorTransform( channel_idx=list(range(-len(foreground_labels), 0)), strel_size=(1, 8), p_per_label=1 ), apply_probability=0.4 ) ) transforms.append( RandomTransform( RemoveRandomConnectedComponentFromOneHotEncodingTransform( channel_idx=list(range(-len(foreground_labels), 0)), fill_with_other_class_p=0, dont_do_if_covers_more_than_x_percent=0.15, p_per_label=1 ), apply_probability=0.2 ) ) if regions is not None: # the ignore label must also be converted transforms.append( ConvertSegmentationToRegionsTransform( regions=list(regions) + [ignore_label] if ignore_label is not None else regions, channel_in_seg=0 ) ) if deep_supervision_scales is not None: transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales)) return ComposeTransforms(transforms) ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainer_noDummy2DDA.py ================================================ from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer import numpy as np class 
nnUNetTrainer_noDummy2DDA(nnUNetTrainer): def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): do_dummy_2d_data_aug = False patch_size = self.configuration_manager.patch_size dim = len(patch_size) if dim == 2: if max(patch_size) / min(patch_size) > 1.5: rotation_for_DA = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi) else: rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi) mirror_axes = (0, 1) elif dim == 3: rotation_for_DA = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi) mirror_axes = (0, 1, 2) else: raise RuntimeError() # todo this function is stupid. It doesn't even use the correct scale range (we keep things as they were in the # old nnunet for now) initial_patch_size = get_patch_size(patch_size[-dim:], rotation_for_DA, rotation_for_DA, rotation_for_DA, (0.85, 1.25)) self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}') self.inference_allowed_mirroring_axes = mirror_axes return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/loss/__init__.py ================================================ ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py ================================================ import torch from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer from nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss import numpy as np class nnUNetTrainerCELoss(nnUNetTrainer): def _build_loss(self): assert not self.label_manager.has_regions, "regions not supported by this trainer" loss = RobustCrossEntropyLoss( weight=None, ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100 ) # we give each output a weight which decreases exponentially (division by 2) as the resolution 
decreases # this gives higher resolution outputs more weight in the loss if self.enable_deep_supervision: deep_supervision_scales = self._get_deep_supervision_scales() weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) weights[-1] = 0 # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 weights = weights / weights.sum() # now wrap the loss loss = DeepSupervisionWrapper(loss, weights) return loss class nnUNetTrainerCELoss_5epochs(nnUNetTrainerCELoss): def __init__( self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device("cuda"), ): """used for debugging plans etc""" super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 5 ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py ================================================ import numpy as np import torch from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper from nnunetv2.training.loss.dice import MemoryEfficientSoftDiceLoss from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer from nnunetv2.utilities.helpers import softmax_helper_dim1 class nnUNetTrainerDiceLoss(nnUNetTrainer): def _build_loss(self): loss = MemoryEfficientSoftDiceLoss(**{'batch_dice': self.configuration_manager.batch_dice, 'do_bg': self.label_manager.has_regions, 'smooth': 1e-5, 'ddp': self.is_ddp}, apply_nonlin=torch.sigmoid if self.label_manager.has_regions else softmax_helper_dim1) if self.enable_deep_supervision: deep_supervision_scales = self._get_deep_supervision_scales() # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases # this gives higher resolution outputs more weight in the loss weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) 
weights[-1] = 0 # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 weights = weights / weights.sum() # now wrap the loss loss = DeepSupervisionWrapper(loss, weights) return loss class nnUNetTrainerDiceCELoss_noSmooth(nnUNetTrainer): def _build_loss(self): # set smooth to 0 if self.label_manager.has_regions: loss = DC_and_BCE_loss({}, {'batch_dice': self.configuration_manager.batch_dice, 'do_bg': True, 'smooth': 0, 'ddp': self.is_ddp}, use_ignore_label=self.label_manager.ignore_label is not None, dice_class=MemoryEfficientSoftDiceLoss) else: loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice, 'smooth': 0, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1, ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss) if self.enable_deep_supervision: deep_supervision_scales = self._get_deep_supervision_scales() # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases # this gives higher resolution outputs more weight in the loss weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) weights[-1] = 0 # we don't use the lowest 2 outputs. 
Normalize weights so that they sum to 1 weights = weights / weights.sum() # now wrap the loss loss = DeepSupervisionWrapper(loss, weights) return loss ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py ================================================ from nnunetv2.training.loss.compound_losses import DC_and_topk_loss from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer import numpy as np from nnunetv2.training.loss.robust_ce_loss import TopKLoss class nnUNetTrainerTopk10Loss(nnUNetTrainer): def _build_loss(self): assert not self.label_manager.has_regions, "regions not supported by this trainer" loss = TopKLoss( ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, k=10 ) if self.enable_deep_supervision: deep_supervision_scales = self._get_deep_supervision_scales() # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases # this gives higher resolution outputs more weight in the loss weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) weights[-1] = 0 # we don't use the lowest 2 outputs. 
Normalize weights so that they sum to 1 weights = weights / weights.sum() # now wrap the loss loss = DeepSupervisionWrapper(loss, weights) return loss class nnUNetTrainerTopk10LossLS01(nnUNetTrainer): def _build_loss(self): assert not self.label_manager.has_regions, "regions not supported by this trainer" loss = TopKLoss( ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, k=10, label_smoothing=0.1, ) if self.enable_deep_supervision: deep_supervision_scales = self._get_deep_supervision_scales() # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases # this gives higher resolution outputs more weight in the loss weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) weights[-1] = 0 # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 weights = weights / weights.sum() # now wrap the loss loss = DeepSupervisionWrapper(loss, weights) return loss class nnUNetTrainerDiceTopK10Loss(nnUNetTrainer): def _build_loss(self): assert not self.label_manager.has_regions, "regions not supported by this trainer" loss = DC_and_topk_loss( {"batch_dice": self.configuration_manager.batch_dice, "smooth": 1e-5, "do_bg": False, "ddp": self.is_ddp}, {"k": 10, "label_smoothing": 0.0}, weight_ce=1, weight_dice=1, ignore_label=self.label_manager.ignore_label, ) if self.enable_deep_supervision: deep_supervision_scales = self._get_deep_supervision_scales() # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases # this gives higher resolution outputs more weight in the loss weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) weights[-1] = 0 # we don't use the lowest 2 outputs. 
Normalize weights so that they sum to 1 weights = weights / weights.sum() # now wrap the loss loss = DeepSupervisionWrapper(loss, weights) return loss ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/lr_schedule/__init__.py ================================================ ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/lr_schedule/nnUNetTrainerCosAnneal.py ================================================ import torch from torch.optim.lr_scheduler import CosineAnnealingLR from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer class nnUNetTrainerCosAnneal(nnUNetTrainer): def configure_optimizers(self): optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, momentum=0.99, nesterov=True) lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs) return optimizer, lr_scheduler ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/lr_schedule/nnUNetTrainer_warmup.py ================================================ from typing import Union import torch from torch._dynamo import OptimizedModule from nnunetv2.training.lr_scheduler.warmup import Lin_incr_LRScheduler, PolyLRScheduler_offset from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer from torch.nn.parallel import DistributedDataParallel as DDP from nnunetv2.utilities.helpers import empty_cache class nnUNetTrainer_warmup(nnUNetTrainer): """ Does a warmup of the entire architecture Then does normal training """ def __init__( self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device("cuda"), ): super().__init__(plans, configuration, fold, dataset_json, device) #### hyperparameters for warmup self.warmup_duration_whole_net = 50 # lin increase whole network self.num_epochs = 1000 self.training_stage = None # 'warmup_all', 'train' def 
configure_optimizers(self, stage: str = "warmup_all"): assert stage in ["warmup_all", "train"] if self.training_stage == stage: return self.optimizer, self.lr_scheduler if isinstance(self.network, DDP): params = self.network.module.parameters() else: params = self.network.parameters() if stage == "warmup_all": self.print_to_log_file("train whole net, warmup") optimizer = torch.optim.SGD( params, self.initial_lr, weight_decay=self.weight_decay, momentum=0.99, nesterov=True ) lr_scheduler = Lin_incr_LRScheduler(optimizer, self.initial_lr, self.warmup_duration_whole_net) self.print_to_log_file(f"Initialized warmup_all optimizer and lr_scheduler at epoch {self.current_epoch}") else: self.print_to_log_file("train whole net, default schedule") if self.training_stage == "warmup_all": # we can keep the existing optimizer and don't need to create a new one. This will allow us to keep # the accumulated momentum terms which already point in a useful direction optimizer = self.optimizer else: optimizer = torch.optim.SGD( params, self.initial_lr, weight_decay=self.weight_decay, momentum=0.99, nesterov=True ) lr_scheduler = PolyLRScheduler_offset( optimizer, self.initial_lr, self.num_epochs, self.warmup_duration_whole_net ) self.print_to_log_file(f"Initialized train optimizer and lr_scheduler at epoch {self.current_epoch}") self.training_stage = stage empty_cache(self.device) return optimizer, lr_scheduler def on_train_epoch_start(self): if self.current_epoch == 0: self.optimizer, self.lr_scheduler = self.configure_optimizers("warmup_all") elif self.current_epoch == self.warmup_duration_whole_net: self.optimizer, self.lr_scheduler = self.configure_optimizers("train") super().on_train_epoch_start() def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None: """ We need to overwrite that entire function because we need to fiddle the correct optimizer in between loading the checkpoint and applying the optimizer states. Yuck.
""" if not self.was_initialized: self.initialize() if isinstance(filename_or_checkpoint, str): checkpoint = torch.load(filename_or_checkpoint, map_location=self.device) # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not # match. Use heuristic to make it match new_state_dict = {} for k, value in checkpoint["network_weights"].items(): key = k if key not in self.network.state_dict().keys() and key.startswith("module."): key = key[7:] new_state_dict[key] = value self.my_init_kwargs = checkpoint["init_args"] self.current_epoch = checkpoint["current_epoch"] self.logger.load_checkpoint(checkpoint["logging"]) self._best_ema = checkpoint["_best_ema"] self.inference_allowed_mirroring_axes = ( checkpoint["inference_allowed_mirroring_axes"] if "inference_allowed_mirroring_axes" in checkpoint.keys() else self.inference_allowed_mirroring_axes ) # messing with state dict naming schemes. Facepalm. if self.is_ddp: if isinstance(self.network.module, OptimizedModule): self.network.module._orig_mod.load_state_dict(new_state_dict) else: self.network.module.load_state_dict(new_state_dict) else: if isinstance(self.network, OptimizedModule): self.network._orig_mod.load_state_dict(new_state_dict) else: self.network.load_state_dict(new_state_dict) # it's fine to do this every time we load because configure_optimizers will be a no-op if the correct optimizer # and lr scheduler are already set up if self.current_epoch < self.warmup_duration_whole_net: self.optimizer, self.lr_scheduler = self.configure_optimizers("warmup_all") else: self.optimizer, self.lr_scheduler = self.configure_optimizers("train") self.optimizer.load_state_dict(checkpoint["optimizer_state"]) if self.grad_scaler is not None: if checkpoint["grad_scaler_state"] is not None: self.grad_scaler.load_state_dict(checkpoint["grad_scaler_state"]) ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/network_architecture/__init__.py 
================================================ ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py ================================================ from typing import Union, Tuple, List from dynamic_network_architectures.building_blocks.helper import get_matching_batchnorm from torch import nn from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer class nnUNetTrainerBN(nnUNetTrainer): @staticmethod def build_network_architecture(architecture_class_name: str, arch_init_kwargs: dict, arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]], num_input_channels: int, num_output_channels: int, enable_deep_supervision: bool = True) -> nn.Module: if 'norm_op' not in arch_init_kwargs.keys(): raise RuntimeError("'norm_op' not found in arch_init_kwargs. This does not look like an architecture " "I can hack BN into. This trainer only works with default nnU-Net architectures.") from pydoc import locate conv_op = locate(arch_init_kwargs['conv_op']) bn_class = get_matching_batchnorm(conv_op) arch_init_kwargs['norm_op'] = bn_class.__module__ + '.' 
+ bn_class.__name__ arch_init_kwargs['norm_op_kwargs'] = {'eps': 1e-5, 'affine': True} return nnUNetTrainer.build_network_architecture(architecture_class_name, arch_init_kwargs, arch_init_kwargs_req_import, num_input_channels, num_output_channels, enable_deep_supervision) ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py ================================================ from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer import torch class nnUNetTrainerNoDeepSupervision(nnUNetTrainer): def __init__( self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device("cuda"), ): super().__init__(plans, configuration, fold, dataset_json, device) self.enable_deep_supervision = False ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/optimizer/__init__.py ================================================ ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdam.py ================================================ import torch from torch.optim import Adam, AdamW from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer class nnUNetTrainerAdam(nnUNetTrainer): def configure_optimizers(self): optimizer = AdamW(self.network.parameters(), lr=self.initial_lr, weight_decay=self.weight_decay, amsgrad=True) # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, # momentum=0.99, nesterov=True) lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) return optimizer, lr_scheduler class nnUNetTrainerVanillaAdam(nnUNetTrainer): def configure_optimizers(self): optimizer = Adam(self.network.parameters(), lr=self.initial_lr, weight_decay=self.weight_decay) # optimizer = 
torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, # momentum=0.99, nesterov=True) lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) return optimizer, lr_scheduler class nnUNetTrainerVanillaAdam1en3(nnUNetTrainerVanillaAdam): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.initial_lr = 1e-3 class nnUNetTrainerVanillaAdam3en4(nnUNetTrainerVanillaAdam): # https://twitter.com/karpathy/status/801621764144971776?lang=en def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.initial_lr = 3e-4 class nnUNetTrainerAdam1en3(nnUNetTrainerAdam): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.initial_lr = 1e-3 class nnUNetTrainerAdam3en4(nnUNetTrainerAdam): # https://twitter.com/karpathy/status/801621764144971776?lang=en def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.initial_lr = 3e-4 ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdan.py ================================================ import torch from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer from torch.optim.lr_scheduler import CosineAnnealingLR try: from adan_pytorch import Adan except ImportError: Adan = None class nnUNetTrainerAdan(nnUNetTrainer): def configure_optimizers(self): if Adan is None: 
raise RuntimeError('This trainer requires adan_pytorch to be installed, install with "pip install adan-pytorch"') optimizer = Adan(self.network.parameters(), lr=self.initial_lr, # betas=(0.02, 0.08, 0.01), defaults weight_decay=self.weight_decay) # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, # momentum=0.99, nesterov=True) lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) return optimizer, lr_scheduler class nnUNetTrainerAdan1en3(nnUNetTrainerAdan): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.initial_lr = 1e-3 class nnUNetTrainerAdan3en4(nnUNetTrainerAdan): # https://twitter.com/karpathy/status/801621764144971776?lang=en def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.initial_lr = 3e-4 class nnUNetTrainerAdan1en1(nnUNetTrainerAdan): # this trainer makes no sense -> nan! 
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.initial_lr = 1e-1 class nnUNetTrainerAdanCosAnneal(nnUNetTrainerAdan): # def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, # device: torch.device = torch.device('cuda')): # super().__init__(plans, configuration, fold, dataset_json, device) # self.num_epochs = 15 def configure_optimizers(self): if Adan is None: raise RuntimeError('This trainer requires adan_pytorch to be installed, install with "pip install adan-pytorch"') optimizer = Adan(self.network.parameters(), lr=self.initial_lr, # betas=(0.02, 0.08, 0.01), defaults weight_decay=self.weight_decay) # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, # momentum=0.99, nesterov=True) lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs) return optimizer, lr_scheduler ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/sampling/__init__.py ================================================ ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py ================================================ import numpy as np import torch from torch import distributed as dist from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer class nnUNetTrainer_probabilisticOversampling(nnUNetTrainer): """ sampling of foreground happens randomly and not for the last 33% of samples in a batch since most trainings happen with batch size 2 and nnunet guarantees at least one fg sample, effectively this can be 50% Here we compute the actual oversampling percentage used by nnUNetTrainer in order to be as consistent as possible. 
If we switch to this oversampling then we can keep it at a constant 0.33 or whatever. """ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.probabilistic_oversampling = True self.oversample_foreground_percent = float(np.mean( [not sample_idx < round(self.configuration_manager.batch_size * (1 - self.oversample_foreground_percent)) for sample_idx in range(self.configuration_manager.batch_size)])) self.print_to_log_file(f"self.oversample_foreground_percent {self.oversample_foreground_percent}") def _set_batch_size_and_oversample(self): if not self.is_ddp: # set batch size to what the plan says, leave oversample untouched self.batch_size = self.configuration_manager.batch_size else: # batch size is distributed over DDP workers and we need to change oversample_percent for each worker world_size = dist.get_world_size() my_rank = dist.get_rank() global_batch_size = self.configuration_manager.batch_size assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \ 'GPUs... Duh.' 
batch_size_per_GPU = [global_batch_size // world_size] * world_size batch_size_per_GPU = [batch_size_per_GPU[i] + 1 if (batch_size_per_GPU[i] * world_size + i) < global_batch_size else batch_size_per_GPU[i] for i in range(len(batch_size_per_GPU))] assert sum(batch_size_per_GPU) == global_batch_size print("worker", my_rank, "batch_size", batch_size_per_GPU[my_rank]) print("worker", my_rank, "oversample", self.oversample_foreground_percent) self.batch_size = batch_size_per_GPU[my_rank] class nnUNetTrainer_probabilisticOversampling_033(nnUNetTrainer_probabilisticOversampling): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.oversample_foreground_percent = 0.33 class nnUNetTrainer_probabilisticOversampling_010(nnUNetTrainer_probabilisticOversampling): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.oversample_foreground_percent = 0.1 ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/training_length/__init__.py ================================================ ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py ================================================ import torch from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer class nnUNetTrainer_5epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): """used for debugging plans etc""" super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 5 class nnUNetTrainer_1epoch(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, 
dataset_json: dict, device: torch.device = torch.device('cuda')): """used for debugging plans etc""" super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 1 class nnUNetTrainer_10epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): """used for debugging plans etc""" super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 10 class nnUNetTrainer_20epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 20 class nnUNetTrainer_50epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 50 class nnUNetTrainer_100epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 100 class nnUNetTrainer_250epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 250 class nnUNetTrainer_500epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 500 class nnUNetTrainer_750epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, 
configuration, fold, dataset_json, device) self.num_epochs = 750 class nnUNetTrainer_2000epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 2000 class nnUNetTrainer_4000epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 4000 class nnUNetTrainer_8000epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 8000 ================================================ FILE: nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py ================================================ import torch from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer class nnUNetTrainer_250epochs_NoMirroring(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 250 def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() mirror_axes = None self.inference_allowed_mirroring_axes = None return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes class nnUNetTrainer_2000epochs_NoMirroring(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, 
configuration, fold, dataset_json, device) self.num_epochs = 2000 def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() mirror_axes = None self.inference_allowed_mirroring_axes = None return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes class nnUNetTrainer_4000epochs_NoMirroring(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 4000 def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() mirror_axes = None self.inference_allowed_mirroring_axes = None return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes class nnUNetTrainer_8000epochs_NoMirroring(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, device: torch.device = torch.device('cuda')): super().__init__(plans, configuration, fold, dataset_json, device) self.num_epochs = 8000 def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() mirror_axes = None self.inference_allowed_mirroring_axes = None return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes ================================================ FILE: nnunetv2/utilities/__init__.py ================================================ ================================================ FILE: nnunetv2/utilities/collate_outputs.py ================================================ from typing import List import 
numpy as np def collate_outputs(outputs: List[dict]): """ used to collate default train_step and validation_step outputs. If you want something different then you gotta extend this. We expect outputs to be a list of dictionaries where each dict has the same set of keys """ collated = {} for k in outputs[0].keys(): if np.isscalar(outputs[0][k]): collated[k] = [o[k] for o in outputs] elif isinstance(outputs[0][k], np.ndarray): collated[k] = np.vstack([o[k][None] for o in outputs]) elif isinstance(outputs[0][k], list): collated[k] = [item for o in outputs for item in o[k]] else: raise ValueError(f'Cannot collate input of type {type(outputs[0][k])}. ' f'Modify collate_outputs to add this functionality') return collated ================================================ FILE: nnunetv2/utilities/crossval_split.py ================================================ from typing import List import numpy as np from sklearn.model_selection import KFold def generate_crossval_split(train_identifiers: List[str], seed=12345, n_splits=5) -> List[dict[str, List[str]]]: splits = [] kfold = KFold(n_splits=n_splits, shuffle=True, random_state=seed) for i, (train_idx, test_idx) in enumerate(kfold.split(train_identifiers)): train_keys = np.array(train_identifiers)[train_idx] test_keys = np.array(train_identifiers)[test_idx] splits.append({}) splits[-1]['train'] = list(train_keys) splits[-1]['val'] = list(test_keys) return splits ================================================ FILE: nnunetv2/utilities/dataset_name_id_conversion.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Union from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw, nnUNet_results from batchgenerators.utilities.file_and_folder_operations import * import numpy as np def find_candidate_datasets(dataset_id: int): startswith = "Dataset%03.0d" % dataset_id if nnUNet_preprocessed is not None and isdir(nnUNet_preprocessed): candidates_preprocessed = subdirs(nnUNet_preprocessed, prefix=startswith, join=False) else: candidates_preprocessed = [] if nnUNet_raw is not None and isdir(nnUNet_raw): candidates_raw = subdirs(nnUNet_raw, prefix=startswith, join=False) else: candidates_raw = [] candidates_trained_models = [] if nnUNet_results is not None and isdir(nnUNet_results): candidates_trained_models += subdirs(nnUNet_results, prefix=startswith, join=False) all_candidates = candidates_preprocessed + candidates_raw + candidates_trained_models unique_candidates = np.unique(all_candidates) return unique_candidates def convert_id_to_dataset_name(dataset_id: int): unique_candidates = find_candidate_datasets(dataset_id) if len(unique_candidates) > 1: raise RuntimeError("More than one dataset name found for dataset id %d. Please correct that. (I looked in the " "following folders:\n%s\n%s\n%s" % (dataset_id, nnUNet_raw, nnUNet_preprocessed, nnUNet_results)) if len(unique_candidates) == 0: raise RuntimeError(f"Could not find a dataset with the ID {dataset_id}. Make sure the requested dataset ID " f"exists and that nnU-Net knows where raw and preprocessed data are located " f"(see Documentation - Installation). 
Here are your currently defined folders:\n" f"nnUNet_preprocessed={os.environ.get('nnUNet_preprocessed') if os.environ.get('nnUNet_preprocessed') is not None else 'None'}\n" f"nnUNet_results={os.environ.get('nnUNet_results') if os.environ.get('nnUNet_results') is not None else 'None'}\n" f"nnUNet_raw={os.environ.get('nnUNet_raw') if os.environ.get('nnUNet_raw') is not None else 'None'}\n" f"If something is not right, adapt your environment variables.") return unique_candidates[0] def convert_dataset_name_to_id(dataset_name: str): assert dataset_name.startswith("Dataset") dataset_id = int(dataset_name[7:10]) return dataset_id def maybe_convert_to_dataset_name(dataset_name_or_id: Union[int, str]) -> str: if isinstance(dataset_name_or_id, str) and dataset_name_or_id.startswith("Dataset"): return dataset_name_or_id if isinstance(dataset_name_or_id, str): try: dataset_name_or_id = int(dataset_name_or_id) except ValueError: raise ValueError("dataset_name_or_id was a string and did not start with 'Dataset' so we tried to " "convert it to a dataset ID (int). That failed, however. Please give an integer number " "('1', '2', etc) or a correct dataset name. Your input: %s" % dataset_name_or_id) return convert_id_to_dataset_name(dataset_name_or_id) ================================================ FILE: nnunetv2/utilities/ddp_allgather.py ================================================ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Optional, Tuple import torch from torch import distributed def print_if_rank0(*args): if distributed.get_rank() == 0: print(*args) class AllGatherGrad(torch.autograd.Function): # stolen from pytorch lightning @staticmethod def forward( ctx: Any, tensor: torch.Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, ) -> torch.Tensor: ctx.group = group gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] torch.distributed.all_gather(gathered_tensor, tensor, group=group) gathered_tensor = torch.stack(gathered_tensor, dim=0) return gathered_tensor @staticmethod def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: grad_output = torch.cat(grad_output) torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group) return grad_output[torch.distributed.get_rank()], None ================================================ FILE: nnunetv2/utilities/default_n_proc_DA.py ================================================ import subprocess import os def get_allowed_n_proc_DA(): """ This function is used to set the number of processes used on different Systems. It is specific to our cluster infrastructure at DKFZ. You can modify it to suit your needs. Everything is allowed. IMPORTANT: if the environment variable nnUNet_n_proc_DA is set it will overwrite anything in this script (see first line). Interpret the output as the number of processes used for data augmentation PER GPU. The way it is implemented here is simply a look up table. We know the hostnames, CPU and GPU configurations of our systems and set the numbers accordingly. For example, a system with 4 GPUs and 48 threads can use 12 threads per GPU without overloading the CPU (technically 11 because we have a main process as well), so that's what we use. 
""" if 'nnUNet_n_proc_DA' in os.environ.keys(): use_this = int(os.environ['nnUNet_n_proc_DA']) else: hostname = subprocess.getoutput(['hostname']) if hostname in ['Fabian', 'isensee-']: use_this = 12 elif hostname in ['hdf19-gpu16', 'hdf19-gpu17', 'hdf19-gpu18', 'hdf19-gpu19', 'e230-AMDworkstation']: use_this = 16 elif hostname.startswith('e230-dgx1'): use_this = 10 elif hostname.startswith('hdf18-gpu') or hostname.startswith('e132-comp'): use_this = 16 elif hostname.startswith('e230-dgx2'): use_this = 6 elif hostname.startswith('e230-dgxa100-'): use_this = 28 elif hostname.startswith('e230-thinka100-'): use_this = 20 elif hostname.startswith('lsf22-gpu'): use_this = 28 elif hostname.startswith('hdf19-gpu') or hostname.startswith('e071-gpu'): use_this = 12 elif 'superh200' in hostname.lower(): use_this = 48 elif 'superl40s' in hostname.lower(): use_this = 28 else: use_this = 12 # default value use_this = min(use_this, os.cpu_count()) return use_this ================================================ FILE: nnunetv2/utilities/file_path_utilities.py ================================================ from multiprocessing import Pool from typing import Union, Tuple import numpy as np from batchgenerators.utilities.file_and_folder_operations import * from nnunetv2.configuration import default_num_processes from nnunetv2.paths import nnUNet_results from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name def convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration): return f'{trainer_name}__{plans_identifier}__{configuration}' def convert_identifier_to_trainer_plans_config(identifier: str): return os.path.basename(identifier).split('__') def get_output_folder(dataset_name_or_id: Union[str, int], trainer_name: str = 'nnUNetTrainer', plans_identifier: str = 'nnUNetPlans', configuration: str = '3d_fullres', fold: Union[str, int] = None) -> str: tmp = join(nnUNet_results, 
maybe_convert_to_dataset_name(dataset_name_or_id), convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration)) if fold is not None: tmp = join(tmp, f'fold_{fold}') return tmp def parse_dataset_trainer_plans_configuration_from_path(path: str): folders = split_path(path) # this here can be a little tricky because we are making assumptions. Let's hope this never fails lol # safer to make this depend on two conditions, the fold_x and the DatasetXXX # first let's see if some fold_X is present fold_x_present = [i.startswith('fold_') for i in folders] if any(fold_x_present): idx = fold_x_present.index(True) # OK now two entries before that there should be DatasetXXX assert len(folders[:idx]) >= 2, 'Bad path, cannot extract what I need. Your path needs to be at least ' \ 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work' if folders[idx - 2].startswith('Dataset'): split = folders[idx - 1].split('__') assert len(split) == 3, 'Bad path, cannot extract what I need. Your path needs to be at least ' \ 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work' return folders[idx - 2], *split else: # we can only check for dataset followed by a string that is separable into three strings by splitting with '__' # look for DatasetXXX dataset_folder = [i.startswith('Dataset') for i in folders] if any(dataset_folder): idx = dataset_folder.index(True) assert len(folders) >= (idx + 1), 'Bad path, cannot extract what I need. Your path needs to be at least ' \ 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work' split = folders[idx + 1].split('__') assert len(split) == 3, 'Bad path, cannot extract what I need. 
Your path needs to be at least ' \ 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work' return folders[idx], *split def get_ensemble_name(model1_folder, model2_folder, folds: Tuple[int, ...]): identifier = 'ensemble___' + os.path.basename(model1_folder) + '___' + \ os.path.basename(model2_folder) + '___' + folds_tuple_to_string(folds) return identifier def get_ensemble_name_from_d_tr_c(dataset, tr1, p1, c1, tr2, p2, c2, folds: Tuple[int, ...]): model1_folder = get_output_folder(dataset, tr1, p1, c1) model2_folder = get_output_folder(dataset, tr2, p2, c2) return get_ensemble_name(model1_folder, model2_folder, folds) def convert_ensemble_folder_to_model_identifiers_and_folds(ensemble_folder: str): prefix, *models, folds = os.path.basename(ensemble_folder).split('___') return models, folds def folds_tuple_to_string(folds: Union[List[int], Tuple[int, ...]]): s = str(folds[0]) for f in folds[1:]: s += f"_{f}" return s def folds_string_to_tuple(folds_string: str): folds = folds_string.split('_') res = [] for f in folds: try: res.append(int(f)) except ValueError: res.append(f) return res def check_workers_alive_and_busy(export_pool: Pool, worker_list: List, results_list: List, allowed_num_queued: int = 0): """ returns True if the number of results that are not ready is greater than or equal to the number of available workers + allowed_num_queued """ alive = [i.is_alive() for i in worker_list] if not all(alive): raise RuntimeError('Some background workers are no longer alive') not_ready = [not i.ready() for i in results_list] if sum(not_ready) >= (len(export_pool._pool) + allowed_num_queued): return True return False if __name__ == '__main__': ### well at this point I could just write tests... 
path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres' print(parse_dataset_trainer_plans_configuration_from_path(path)) path = 'Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres' print(parse_dataset_trainer_plans_configuration_from_path(path)) path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres/fold_all' print(parse_dataset_trainer_plans_configuration_from_path(path)) try: path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/' print(parse_dataset_trainer_plans_configuration_from_path(path)) except AssertionError: print('yayy, assertion works') ================================================ FILE: nnunetv2/utilities/find_class_by_name.py ================================================ import importlib import pkgutil from batchgenerators.utilities.file_and_folder_operations import * def recursive_find_python_class(folder: str, class_name: str, current_module: str): tr = None for importer, modname, ispkg in pkgutil.iter_modules([folder]): # print(modname, ispkg) if not ispkg: m = importlib.import_module(current_module + "." + modname) if hasattr(m, class_name): tr = getattr(m, class_name) break if tr is None: for importer, modname, ispkg in pkgutil.iter_modules([folder]): if ispkg: next_current_module = current_module + "." 
+ modname tr = recursive_find_python_class(join(folder, modname), class_name, current_module=next_current_module) if tr is not None: break return tr ================================================ FILE: nnunetv2/utilities/get_network_from_plans.py ================================================ import pydoc import warnings from typing import Union from nnunetv2.utilities.find_class_by_name import recursive_find_python_class from batchgenerators.utilities.file_and_folder_operations import join def get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, output_channels, allow_init=True, deep_supervision: Union[bool, None] = None): network_class = arch_class_name architecture_kwargs = dict(**arch_kwargs) for ri in arch_kwargs_req_import: if architecture_kwargs[ri] is not None: architecture_kwargs[ri] = pydoc.locate(architecture_kwargs[ri]) nw_class = pydoc.locate(network_class) # sometimes things move around, this makes it so that we can at least recover some of that if nw_class is None: warnings.warn(f'Network class {network_class} not found. 
Attempting to locate it within ' f'dynamic_network_architectures.architectures...') import dynamic_network_architectures nw_class = recursive_find_python_class(join(dynamic_network_architectures.__path__[0], "architectures"), network_class.split(".")[-1], 'dynamic_network_architectures.architectures') if nw_class is not None: print(f'FOUND IT: {nw_class}') else: raise ImportError('Network class could not be found, please check/correct your plans file') if deep_supervision is not None: architecture_kwargs['deep_supervision'] = deep_supervision network = nw_class( input_channels=input_channels, num_classes=output_channels, **architecture_kwargs ) if hasattr(network, 'initialize') and allow_init: network.apply(network.initialize) return network if __name__ == "__main__": import torch model = get_network_from_plans( arch_class_name="dynamic_network_architectures.architectures.unet.ResidualEncoderUNet", arch_kwargs={ "n_stages": 7, "features_per_stage": [32, 64, 128, 256, 512, 512, 512], "conv_op": "torch.nn.modules.conv.Conv2d", "kernel_sizes": [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]], "strides": [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]], "n_blocks_per_stage": [1, 3, 4, 6, 6, 6, 6], "n_conv_per_stage_decoder": [1, 1, 1, 1, 1, 1], "conv_bias": True, "norm_op": "torch.nn.modules.instancenorm.InstanceNorm2d", "norm_op_kwargs": {"eps": 1e-05, "affine": True}, "dropout_op": None, "dropout_op_kwargs": None, "nonlin": "torch.nn.LeakyReLU", "nonlin_kwargs": {"inplace": True}, }, arch_kwargs_req_import=["conv_op", "norm_op", "dropout_op", "nonlin"], input_channels=1, output_channels=4, allow_init=True, deep_supervision=True, ) data = torch.rand((8, 1, 256, 256)) target = torch.rand(size=(8, 1, 256, 256)) outputs = model(data) # this should be a list of torch.Tensor ================================================ FILE: nnunetv2/utilities/helpers.py ================================================ import torch def softmax_helper_dim0(x: 
torch.Tensor) -> torch.Tensor: return torch.softmax(x, 0) def softmax_helper_dim1(x: torch.Tensor) -> torch.Tensor: return torch.softmax(x, 1) def empty_cache(device: torch.device): if device.type == 'cuda': torch.cuda.empty_cache() elif device.type == 'mps': from torch import mps mps.empty_cache() else: pass class dummy_context(object): def __enter__(self): pass def __exit__(self, exc_type, exc_val, exc_tb): pass ================================================ FILE: nnunetv2/utilities/json_export.py ================================================ from collections.abc import Iterable import numpy as np import torch def recursive_fix_for_json_export(my_dict: dict): # json is ... a very nice thing to have # 'cannot serialize object of type bool_/int64/float64'. Apart from that of course... keys = list(my_dict.keys()) # cannot iterate over keys() if we change keys.... for k in keys: if isinstance(k, (np.int64, np.int32, np.int8, np.uint8)): tmp = my_dict[k] del my_dict[k] my_dict[int(k)] = tmp del tmp k = int(k) if isinstance(my_dict[k], dict): recursive_fix_for_json_export(my_dict[k]) elif isinstance(my_dict[k], np.ndarray): assert my_dict[k].ndim == 1, 'only 1d arrays are supported' my_dict[k] = fix_types_iterable(my_dict[k], output_type=list) elif isinstance(my_dict[k], (np.bool_,)): my_dict[k] = bool(my_dict[k]) elif isinstance(my_dict[k], (np.int64, np.int32, np.int8, np.uint8)): my_dict[k] = int(my_dict[k]) elif isinstance(my_dict[k], (np.float32, np.float64, np.float16)): my_dict[k] = float(my_dict[k]) elif isinstance(my_dict[k], list): my_dict[k] = fix_types_iterable(my_dict[k], output_type=type(my_dict[k])) elif isinstance(my_dict[k], tuple): my_dict[k] = fix_types_iterable(my_dict[k], output_type=tuple) elif isinstance(my_dict[k], torch.device): my_dict[k] = str(my_dict[k]) else: pass # pray it can be serialized def fix_types_iterable(iterable, output_type): # this sh!t is hacky as hell and will break if you use it for anything outside nnunet. 
Keep your hands off of this. out = [] for i in iterable: if type(i) in (np.int64, np.int32, np.int8, np.uint8): out.append(int(i)) elif isinstance(i, dict): recursive_fix_for_json_export(i) out.append(i) elif type(i) in (np.float32, np.float64, np.float16): out.append(float(i)) elif type(i) in (np.bool_,): out.append(bool(i)) elif isinstance(i, str): out.append(i) elif isinstance(i, Iterable): # print('recursive call on', i, type(i)) out.append(fix_types_iterable(i, type(i))) else: out.append(i) return output_type(out) ================================================ FILE: nnunetv2/utilities/label_handling/__init__.py ================================================ ================================================ FILE: nnunetv2/utilities/label_handling/label_handling.py ================================================ from __future__ import annotations from time import time from typing import Union, List, Tuple, Type import numpy as np import torch from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice, insert_crop_into_image from batchgenerators.utilities.file_and_folder_operations import join import nnunetv2 from nnunetv2.utilities.find_class_by_name import recursive_find_python_class from nnunetv2.utilities.helpers import softmax_helper_dim0 from typing import TYPE_CHECKING # see https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/ if TYPE_CHECKING: from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager class LabelManager(object): def __init__(self, label_dict: dict, regions_class_order: Union[List[int], None], force_use_labels: bool = False, inference_nonlin=None): self._sanity_check(label_dict) self.label_dict = label_dict self.regions_class_order = regions_class_order self._force_use_labels = force_use_labels if force_use_labels: self._has_regions = False else: self._has_regions: bool = any( [isinstance(i, (tuple, list)) and len(i) > 1 for i in 
self.label_dict.values()]) self._ignore_label: Union[None, int] = self._determine_ignore_label() self._all_labels: List[int] = self._get_all_labels() self._regions: Union[None, List[Union[int, Tuple[int, ...]]]] = self._get_regions() if self.has_ignore_label: assert self.ignore_label == max( self.all_labels) + 1, 'If you use the ignore label it must have the highest ' \ 'label value! It cannot be 0 or in between other labels. ' \ 'Sorry bro.' if inference_nonlin is None: self.inference_nonlin = torch.sigmoid if self.has_regions else softmax_helper_dim0 else: self.inference_nonlin = inference_nonlin def _sanity_check(self, label_dict: dict): if not 'background' in label_dict.keys(): raise RuntimeError('Background label not declared (remember that this should be label 0!)') bg_label = label_dict['background'] if isinstance(bg_label, (tuple, list)): raise RuntimeError(f"Background label must be 0. Not a list. Not a tuple. Your background label: {bg_label}") assert int(bg_label) == 0, f"Background label must be 0. Your background label: {bg_label}" # not sure if we want to allow regions that contain background. I don't immediately see how this could cause # problems so we allow it for now. That doesn't mean that this is explicitly supported. It could be that this # just crashes. def _get_all_labels(self) -> List[int]: all_labels = [] for k, r in self.label_dict.items(): # ignore label is not going to be used, hence the name. Duh. if k == 'ignore': continue if isinstance(r, (tuple, list)): for ri in r: all_labels.append(int(ri)) else: all_labels.append(int(r)) all_labels = list(np.unique(all_labels)) all_labels.sort() return all_labels def _get_regions(self) -> Union[None, List[Union[int, Tuple[int, ...]]]]: if not self._has_regions or self._force_use_labels: return None else: assert self.regions_class_order is not None, 'if region-based training is requested then you need to ' \ 'define regions_class_order!' 
regions = [] for k, r in self.label_dict.items(): # ignore ignore label if k == 'ignore': continue # ignore regions that are background if (np.isscalar(r) and r == 0) \ or \ (isinstance(r, (tuple, list)) and len(np.unique(r)) == 1 and np.unique(r)[0] == 0): continue if isinstance(r, list): r = tuple(r) regions.append(r) assert len(self.regions_class_order) == len(regions), 'regions_class_order must have as ' \ 'many entries as there are ' \ 'regions' return regions def _determine_ignore_label(self) -> Union[None, int]: ignore_label = self.label_dict.get('ignore') if ignore_label is not None: assert isinstance(ignore_label, int), f'Ignore label has to be an integer. It cannot be a region ' \ f'(list/tuple). Got {type(ignore_label)}.' return ignore_label @property def has_regions(self) -> bool: return self._has_regions @property def has_ignore_label(self) -> bool: return self.ignore_label is not None @property def all_regions(self) -> Union[None, List[Union[int, Tuple[int, ...]]]]: return self._regions @property def all_labels(self) -> List[int]: return self._all_labels @property def ignore_label(self) -> Union[None, int]: return self._ignore_label def apply_inference_nonlin(self, logits: Union[np.ndarray, torch.Tensor]) -> \ Union[np.ndarray, torch.Tensor]: """ logits has to have shape (c, x, y(, z)) where c is the number of classes/regions """ if isinstance(logits, np.ndarray): logits = torch.from_numpy(logits) with torch.no_grad(): # softmax etc is not implemented for half logits = logits.float() probabilities = self.inference_nonlin(logits) return probabilities @torch.inference_mode() def convert_probabilities_to_segmentation(self, predicted_probabilities: Union[np.ndarray, torch.Tensor]) -> \ Union[np.ndarray, torch.Tensor]: """ assumes that inference_nonlinearity was already applied! 
predicted_probabilities has to have shape (c, x, y(, z)) where c is the number of classes/regions """ if not isinstance(predicted_probabilities, (np.ndarray, torch.Tensor)): raise RuntimeError(f"Unexpected input type. Expected np.ndarray or torch.Tensor," f" got {type(predicted_probabilities)}") if self.has_regions: assert self.regions_class_order is not None, 'if region-based training is requested then you need to ' \ 'define regions_class_order!' # check correct number of outputs assert predicted_probabilities.shape[0] == self.num_segmentation_heads, \ f'unexpected number of channels in predicted_probabilities. Expected {self.num_segmentation_heads}, ' \ f'got {predicted_probabilities.shape[0]}. Remember that predicted_probabilities should have shape ' \ f'(c, x, y(, z)).' if self.has_regions: if isinstance(predicted_probabilities, np.ndarray): segmentation = np.zeros(predicted_probabilities.shape[1:], dtype=np.uint16) else: # no uint16 in torch segmentation = torch.zeros(predicted_probabilities.shape[1:], dtype=torch.int16, device=predicted_probabilities.device) for i, c in enumerate(self.regions_class_order): segmentation[predicted_probabilities[i] > 0.5] = c else: # numpy is faster than torch. :facepalm: is_numpy = isinstance(predicted_probabilities, np.ndarray) if not is_numpy: predicted_probabilities = predicted_probabilities.numpy() segmentation = predicted_probabilities.argmax(0) if not is_numpy: segmentation = torch.from_numpy(segmentation) return segmentation @torch.inference_mode() def convert_logits_to_segmentation(self, predicted_logits: Union[np.ndarray, torch.Tensor]) -> \ Union[np.ndarray, torch.Tensor]: input_is_numpy = isinstance(predicted_logits, np.ndarray) # we can skip this step if we do not have region. 
Argmax is the same between logits or probabilities if self.has_regions: probabilities = self.apply_inference_nonlin(predicted_logits) else: probabilities = predicted_logits if input_is_numpy and isinstance(probabilities, torch.Tensor): probabilities = probabilities.cpu().numpy() return self.convert_probabilities_to_segmentation(probabilities) def revert_cropping_on_probabilities(self, predicted_probabilities: Union[torch.Tensor, np.ndarray], bbox: List[List[int]], original_shape: Union[List[int], Tuple[int, ...]]): """ ONLY USE THIS WITH PROBABILITIES, DO NOT USE LOGITS AND DO NOT USE FOR SEGMENTATION MAPS!!! predicted_probabilities must be (c, x, y(, z)) Why do we do this here? Well if we pad probabilities we need to make sure that convert_logits_to_segmentation correctly returns background in the padded areas. Also we want to be able to look at the padded probabilities and not have strange artifacts. Only LabelManager knows how this needs to be done. So let's let him/her do it, ok? """ # revert cropping probs_reverted_cropping = np.zeros((predicted_probabilities.shape[0], *original_shape), dtype=predicted_probabilities.dtype) \ if isinstance(predicted_probabilities, np.ndarray) else \ torch.zeros((predicted_probabilities.shape[0], *original_shape), dtype=predicted_probabilities.dtype) if not self.has_regions: probs_reverted_cropping[0] = 1 probs_reverted_cropping = insert_crop_into_image(probs_reverted_cropping, predicted_probabilities, bbox) return probs_reverted_cropping @staticmethod def filter_background(classes_or_regions: Union[List[int], List[Union[int, Tuple[int, ...]]]]): # heck yeah # This is definitely taking list comprehension too far. Enjoy. 
        return [i for i in classes_or_regions if
                ((not isinstance(i, (tuple, list))) and i != 0)
                or
                (isinstance(i, (tuple, list)) and not (
                        len(np.unique(i)) == 1 and np.unique(i)[0] == 0))]

    @property
    def foreground_regions(self):
        return self.filter_background(self.all_regions)

    @property
    def foreground_labels(self):
        return self.filter_background(self.all_labels)

    @property
    def num_segmentation_heads(self):
        if self.has_regions:
            return len(self.foreground_regions)
        else:
            return len(self.all_labels)


def get_labelmanager_class_from_plans(plans: dict) -> Type[LabelManager]:
    if 'label_manager' not in plans.keys():
        print('No label manager specified in plans. Using default: LabelManager')
        return LabelManager
    else:
        labelmanager_class = recursive_find_python_class(join(nnunetv2.__path__[0], "utilities", "label_handling"),
                                                         plans['label_manager'],
                                                         current_module="nnunetv2.utilities.label_handling")
        return labelmanager_class


def convert_labelmap_to_one_hot(segmentation: Union[np.ndarray, torch.Tensor],
                                all_labels: Union[List, torch.Tensor, np.ndarray, tuple],
                                output_dtype=None) -> Union[np.ndarray, torch.Tensor]:
    """
    if output_dtype is None then we use np.uint8/torch.uint8
    if input is torch.Tensor then output will be on the same device

    np.ndarray is faster than torch.Tensor

    if segmentation is torch.Tensor, this function will be faster if it is LongTensor. If it is something else we have
    to cast which takes time.

    IMPORTANT: This function only works properly if your labels are consecutive integers, so something like 0, 1, 2, 3, ...
    DO NOT use it with 0, 32, 123, 255, ... or whatever (fix your labels, yo)
    """
    if isinstance(segmentation, torch.Tensor):
        result = torch.zeros((len(all_labels), *segmentation.shape),
                             dtype=output_dtype if output_dtype is not None else
                             (torch.uint8 if max(all_labels) < 255 else torch.uint16),
                             device=segmentation.device)
        # variant 1, 2x faster than 2
        result.scatter_(0, segmentation[None].long(), 1)  # why does this have to be long!?
        # variant 2, slower than 1
        # for i, l in enumerate(all_labels):
        #     result[i] = segmentation == l
    else:
        result = np.zeros((len(all_labels), *segmentation.shape),
                          dtype=output_dtype if output_dtype is not None else
                          (np.uint8 if max(all_labels) < 255 else np.uint16))
        # variant 1, fastest in my testing
        for i, l in enumerate(all_labels):
            result[i] = segmentation == l
        # variant 2. Takes about twice as long so nah
        # result = np.eye(len(all_labels))[segmentation].transpose((3, 0, 1, 2))
    return result


def determine_num_input_channels(plans_manager: PlansManager,
                                 configuration_or_config_manager: Union[str, ConfigurationManager],
                                 dataset_json: dict) -> int:
    if isinstance(configuration_or_config_manager, str):
        config_manager = plans_manager.get_configuration(configuration_or_config_manager)
    else:
        config_manager = configuration_or_config_manager

    label_manager = plans_manager.get_label_manager(dataset_json)
    num_modalities = len(dataset_json['modality']) if 'modality' in dataset_json.keys() else len(dataset_json['channel_names'])

    # cascade has different number of input channels
    if config_manager.previous_stage_name is not None:
        num_label_inputs = len(label_manager.foreground_labels)
        num_input_channels = num_modalities + num_label_inputs
    else:
        num_input_channels = num_modalities
    return num_input_channels


if __name__ == '__main__':
    # this code used to be able to differentiate variant 1 and 2 to measure time.
    num_labels = 7
    seg = np.random.randint(0, num_labels, size=(256, 256, 256), dtype=np.uint8)
    seg_torch = torch.from_numpy(seg)
    st = time()
    onehot_npy = convert_labelmap_to_one_hot(seg, np.arange(num_labels))
    time_1 = time()
    onehot_npy2 = convert_labelmap_to_one_hot(seg, np.arange(num_labels))
    time_2 = time()
    onehot_torch = convert_labelmap_to_one_hot(seg_torch, np.arange(num_labels))
    time_torch = time()
    onehot_torch2 = convert_labelmap_to_one_hot(seg_torch, np.arange(num_labels))
    time_torch2 = time()
    print(
        f'np: {time_1 - st}, np2: {time_2 - time_1}, torch: {time_torch - time_2}, torch2: {time_torch2 - time_torch}')
    onehot_torch = onehot_torch.numpy()
    onehot_torch2 = onehot_torch2.numpy()
    print(np.all(onehot_torch == onehot_npy))
    print(np.all(onehot_torch2 == onehot_npy))


================================================
FILE: nnunetv2/utilities/network_initialization.py
================================================
from torch import nn


class InitWeights_He(object):
    def __init__(self, neg_slope=1e-2):
        self.neg_slope = neg_slope

    def __call__(self, module):
        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
            module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
            if module.bias is not None:
                module.bias = nn.init.constant_(module.bias, 0)


================================================
FILE: nnunetv2/utilities/overlay_plots.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
from multiprocessing.pool import Pool
from typing import Tuple, Union

import numpy as np
import pandas as pd
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.configuration import default_num_processes
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json
from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
from nnunetv2.training.dataloading.nnunet_dataset import infer_dataset_class, nnUNetBaseDataset
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager
from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \
    get_filenames_of_train_images_and_targets

color_cycle = (
    "000000",
    "4363d8",
    "f58231",
    "3cb44b",
    "e6194B",
    "911eb4",
    "ffe119",
    "bfef45",
    "42d4f4",
    "f032e6",
    "000075",
    "9A6324",
    "808000",
    "800000",
    "469990",
)


def hex_to_rgb(hex: str):
    assert len(hex) == 6
    return tuple(int(hex[i:i + 2], 16) for i in (0, 2, 4))


def generate_overlay(input_image: np.ndarray, segmentation: np.ndarray, mapping: dict = None,
                     color_cycle: Tuple[str, ...] = color_cycle,
                     overlay_intensity: float = 0.6):
    """
    image can be 2d greyscale or 2d RGB (color channel in last dimension!)

    Segmentation must be label map of same shape as image (w/o color channels)

    mapping can be label_id -> idx_in_cycle or None

    returned image is scaled to [0, 255] (uint8)!!!
    """
    # create a copy of image
    image = np.copy(input_image)

    if image.ndim == 2:
        image = np.tile(image[:, :, None], (1, 1, 3))
    elif image.ndim == 3:
        if image.shape[2] == 1:
            image = np.tile(image, (1, 1, 3))
        # RGB input (3 color channels in the last dimension) is used as is, as promised by the docstring
        elif image.shape[2] != 3:
            raise RuntimeError(f'if 3d image is given the last dimension must be the color channels (3 channels). '
                               f'Only 2D images are supported. Your image shape: {image.shape}')
    else:
        raise RuntimeError("unexpected image shape. only 2D images and 2D images with color channels (color in "
                           "last dimension) are supported")

    # rescale image to [0, 255]
    image = image - image.min()
    image = image / image.max() * 255

    # create output
    if mapping is None:
        uniques = np.sort(pd.unique(segmentation.ravel()))  # np.unique(segmentation)
        mapping = {i: c for c, i in enumerate(uniques)}

    for l in mapping.keys():
        image[segmentation == l] += overlay_intensity * np.array(hex_to_rgb(color_cycle[mapping[l]]))

    # rescale result to [0, 255]
    image = image / image.max() * 255
    return image.astype(np.uint8)


def select_slice_to_plot(image: np.ndarray, segmentation: np.ndarray) -> int:
    """
    image and segmentation are expected to be 3D

    selects the slice with the largest amount of fg (regardless of label)

    we give image so that we can easily replace this function if needed
    """
    fg_mask = segmentation != 0
    fg_per_slice = fg_mask.sum((1, 2))
    selected_slice = int(np.argmax(fg_per_slice))
    return selected_slice


def select_slice_to_plot2(image: np.ndarray, segmentation: np.ndarray) -> int:
    """
    image and segmentation are expected to be 3D (or 1, x, y)

    selects the slice with the largest amount of fg (how much percent of each class are in each slice?
    pick slice with highest avg percent)

    we give image so that we can easily replace this function if needed
    """
    classes = [i for i in np.sort(pd.unique(segmentation.ravel())) if i > 0]
    fg_per_slice = np.zeros((image.shape[0], len(classes)))
    for i, c in enumerate(classes):
        fg_mask = segmentation == c
        fg_per_slice[:, i] = fg_mask.sum((1, 2))
        # normalize by this class's own voxel count so that each column holds the per-slice fraction of class c
        fg_per_slice[:, i] /= fg_per_slice[:, i].sum()
    fg_per_slice = fg_per_slice.mean(1)
    return int(np.argmax(fg_per_slice))


def plot_overlay(image_file: str, segmentation_file: str, image_reader_writer: BaseReaderWriter, output_file: str,
                 overlay_intensity: float = 0.6):
    import matplotlib.pyplot as plt

    image, props = image_reader_writer.read_images((image_file, ))
    image = image[0]
    seg, props_seg = image_reader_writer.read_seg(segmentation_file)
    seg = seg[0]
    assert image.shape == seg.shape, "image and seg do not have the same shape: %s, %s" % (
        image_file, segmentation_file)

    assert image.ndim == 3, 'only 3D images/segs are supported'

    selected_slice = select_slice_to_plot2(image, seg)
    # print(image.shape, selected_slice)

    overlay = generate_overlay(image[selected_slice], seg[selected_slice], overlay_intensity=overlay_intensity)

    plt.imsave(output_file, overlay)


def plot_overlay_preprocessed(dataset: nnUNetBaseDataset, k: str, output_folder: str,
                              overlay_intensity: float = 0.6, channel_idx=0):
    import matplotlib.pyplot as plt
    data, seg, _, properties = dataset.load_case(k)

    assert channel_idx < (data.shape[0]), 'This dataset only supports channel index up to %d' % (data.shape[0] - 1)

    image = data[channel_idx]
    seg = seg[0]

    selected_slice = select_slice_to_plot2(image, seg)

    seg = np.copy(seg[selected_slice])
    seg[seg < 0] = 0
    overlay = generate_overlay(image[selected_slice], seg, overlay_intensity=overlay_intensity)

    plt.imsave(join(output_folder, k + '.png'), overlay)


def multiprocessing_plot_overlay(list_of_image_files, list_of_seg_files, image_reader_writer,
                                 list_of_output_files, overlay_intensity,
                                 num_processes=8):
    with
    multiprocessing.get_context("spawn").Pool(num_processes) as p:
        r = p.starmap_async(plot_overlay, zip(
            list_of_image_files, list_of_seg_files, [image_reader_writer] * len(list_of_output_files),
            list_of_output_files, [overlay_intensity] * len(list_of_output_files)
        ))
        r.get()


def multiprocessing_plot_overlay_preprocessed(dataset: nnUNetBaseDataset, output_folder, overlay_intensity,
                                              num_processes=8, channel_idx=0):
    with multiprocessing.get_context("spawn").Pool(num_processes) as p:
        r = []
        for k in dataset.identifiers:
            r.append(
                p.starmap_async(plot_overlay_preprocessed,
                                ((
                                     dataset,
                                     k,
                                     output_folder,
                                     overlay_intensity,
                                     channel_idx
                                 ),))
            )
        _ = [i.get() for i in r]


def generate_overlays_from_raw(dataset_name_or_id: Union[int, str], output_folder: str,
                               num_processes: int = 8, channel_idx: int = 0, overlay_intensity: float = 0.6):
    dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
    folder = join(nnUNet_raw, dataset_name)
    dataset_json = load_json(join(folder, 'dataset.json'))
    dataset = get_filenames_of_train_images_and_targets(folder, dataset_json)

    image_files = [v['images'][channel_idx] for v in dataset.values()]
    seg_files = [v['label'] for v in dataset.values()]

    assert all([isfile(i) for i in image_files])
    assert all([isfile(i) for i in seg_files])

    maybe_mkdir_p(output_folder)
    output_files = [join(output_folder, i + '.png') for i in dataset.keys()]

    image_reader_writer = determine_reader_writer_from_dataset_json(dataset_json, image_files[0])()
    multiprocessing_plot_overlay(image_files, seg_files, image_reader_writer, output_files, overlay_intensity,
                                 num_processes)


def generate_overlays_from_preprocessed(dataset_name_or_id: Union[int, str], output_folder: str,
                                        num_processes: int = 8, channel_idx: int = 0,
                                        configuration: str = None,
                                        plans_identifier: str = 'nnUNetPlans',
                                        overlay_intensity: float = 0.6):
    dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
    folder = join(nnUNet_preprocessed, dataset_name)
    if not isdir(folder):
        raise RuntimeError("run
preprocessing for that task first")

    plans = load_json(join(folder, plans_identifier + '.json'))
    if configuration is None:
        if '3d_fullres' in plans['configurations'].keys():
            configuration = '3d_fullres'
        else:
            configuration = '2d'
    cm = ConfigurationManager(plans['configurations'][configuration])
    preprocessed_folder = join(folder, cm.data_identifier)

    if not isdir(preprocessed_folder):
        raise RuntimeError(f"Preprocessed data folder for configuration {configuration} of plans identifier "
                           f"{plans_identifier} ({dataset_name}) does not exist. Run preprocessing for this "
                           f"configuration first!")

    dc = infer_dataset_class(preprocessed_folder)
    dataset = dc(preprocessed_folder)

    maybe_mkdir_p(output_folder)

    multiprocessing_plot_overlay_preprocessed(dataset, output_folder, overlay_intensity=overlay_intensity,
                                              num_processes=num_processes, channel_idx=channel_idx)


def entry_point_generate_overlay():
    import argparse
    parser = argparse.ArgumentParser("Plots png overlays of the slice with the most foreground. Note that this "
                                     "disregards spacing information!")
    parser.add_argument('-d', type=str, help="Dataset name or id", required=True)
    parser.add_argument('-o', type=str, help="output folder", required=True)
    parser.add_argument('-np', type=int, default=default_num_processes, required=False,
                        help=f"number of processes used. Default: {default_num_processes}")
    parser.add_argument('-channel_idx', type=int, default=0, required=False,
                        help="channel index used (0 = _0000). Default: 0")
    parser.add_argument('--use_raw', action='store_true', required=False, help="if set then we use raw data. else "
                                                                               "we use preprocessed")
    parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
                        help='plans identifier. Only used if --use_raw is not set! Default: nnUNetPlans')
    parser.add_argument('-c', type=str, required=False, default=None,
                        help='configuration name. Only used if --use_raw is not set!
Default: None = '
                             '3d_fullres if available, else 2d')
    parser.add_argument('-overlay_intensity', type=float, required=False, default=0.6,
                        help='overlay intensity. Higher = brighter/less transparent')
    args = parser.parse_args()

    if args.use_raw:
        generate_overlays_from_raw(args.d, args.o, args.np, args.channel_idx,
                                   overlay_intensity=args.overlay_intensity)
    else:
        generate_overlays_from_preprocessed(args.d, args.o, args.np, args.channel_idx, args.c, args.p,
                                            overlay_intensity=args.overlay_intensity)


if __name__ == '__main__':
    entry_point_generate_overlay()


================================================
FILE: nnunetv2/utilities/plans_handling/__init__.py
================================================


================================================
FILE: nnunetv2/utilities/plans_handling/plans_handler.py
================================================
from __future__ import annotations

import warnings

from copy import deepcopy
from functools import lru_cache, partial
from typing import Union, Tuple, List, Type, Callable

import numpy as np
import torch

from nnunetv2.preprocessing.resampling.utils import recursive_find_resampling_fn_by_name

import nnunetv2
from batchgenerators.utilities.file_and_folder_operations import load_json, join

from nnunetv2.imageio.reader_writer_registry import recursive_find_reader_writer_by_name
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.label_handling.label_handling import get_labelmanager_class_from_plans

# see https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/
from typing import TYPE_CHECKING
from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm

if TYPE_CHECKING:
    from nnunetv2.utilities.label_handling.label_handling import LabelManager
    from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
    from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor
    from
    nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner


class ConfigurationManager(object):
    def __init__(self, configuration_dict: dict):
        self.configuration = configuration_dict

        # backwards compatibility
        if 'architecture' not in self.configuration.keys():
            warnings.warn("Detected old nnU-Net plans format. Attempting to reconstruct network architecture "
                          "parameters. If this fails, rerun nnUNetv2_plan_experiment for your dataset. If you use a "
                          "custom architecture, please downgrade nnU-Net to the version you implemented this "
                          "or update your implementation + plans.")
            # try to build the architecture information from old plans, modify configuration dict to match new standard
            unet_class_name = self.configuration["UNet_class_name"]
            if unet_class_name == "PlainConvUNet":
                network_class_name = "dynamic_network_architectures.architectures.unet.PlainConvUNet"
            elif unet_class_name == 'ResidualEncoderUNet':
                network_class_name = "dynamic_network_architectures.architectures.residual_unet.ResidualEncoderUNet"
            else:
                raise RuntimeError(f'Unknown architecture {unet_class_name}. This conversion only supports '
                                   f'PlainConvUNet and ResidualEncoderUNet')

            n_stages = len(self.configuration["n_conv_per_stage_encoder"])

            dim = len(self.configuration["patch_size"])
            conv_op = convert_dim_to_conv_op(dim)
            instnorm = get_matching_instancenorm(dimension=dim)

            convs_or_blocks = "n_conv_per_stage" if unet_class_name == "PlainConvUNet" else "n_blocks_per_stage"

            arch_dict = {
                'network_class_name': network_class_name,
                'arch_kwargs': {
                    "n_stages": n_stages,
                    "features_per_stage": [min(self.configuration["UNet_base_num_features"] * 2 ** i,
                                               self.configuration["unet_max_num_features"])
                                           for i in range(n_stages)],
                    "conv_op": conv_op.__module__ + '.'
                               + conv_op.__name__,
                    "kernel_sizes": deepcopy(self.configuration["conv_kernel_sizes"]),
                    "strides": deepcopy(self.configuration["pool_op_kernel_sizes"]),
                    convs_or_blocks: deepcopy(self.configuration["n_conv_per_stage_encoder"]),
                    "n_conv_per_stage_decoder": deepcopy(self.configuration["n_conv_per_stage_decoder"]),
                    "conv_bias": True,
                    "norm_op": instnorm.__module__ + '.' + instnorm.__name__,
                    "norm_op_kwargs": {
                        "eps": 1e-05,
                        "affine": True
                    },
                    "dropout_op": None,
                    "dropout_op_kwargs": None,
                    "nonlin": "torch.nn.LeakyReLU",
                    "nonlin_kwargs": {
                        "inplace": True
                    }
                },
                # these need to be imported with locate in order to use them:
                # `conv_op = pydoc.locate(architecture_kwargs['conv_op'])`
                "_kw_requires_import": [
                    "conv_op",
                    "norm_op",
                    "dropout_op",
                    "nonlin"
                ]
            }
            del self.configuration["UNet_class_name"], self.configuration["UNet_base_num_features"], \
                self.configuration["n_conv_per_stage_encoder"], self.configuration["n_conv_per_stage_decoder"], \
                self.configuration["num_pool_per_axis"], self.configuration["pool_op_kernel_sizes"],\
                self.configuration["conv_kernel_sizes"], self.configuration["unet_max_num_features"]
            self.configuration["architecture"] = arch_dict

    def __repr__(self):
        return self.configuration.__repr__()

    @property
    def data_identifier(self) -> str:
        return self.configuration['data_identifier']

    @property
    def preprocessor_name(self) -> str:
        return self.configuration['preprocessor_name']

    @property
    @lru_cache(maxsize=1)
    def preprocessor_class(self) -> Type[DefaultPreprocessor]:
        preprocessor_class = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing"),
                                                         self.preprocessor_name,
                                                         current_module="nnunetv2.preprocessing")
        return preprocessor_class

    @property
    def batch_size(self) -> int:
        return self.configuration['batch_size']

    @property
    def patch_size(self) -> List[int]:
        return self.configuration['patch_size']

    @property
    def median_image_size_in_voxels(self) -> List[int]:
        return self.configuration['median_image_size_in_voxels']

    @property
    def spacing(self) ->
    List[float]:
        return self.configuration['spacing']

    @property
    def normalization_schemes(self) -> List[str]:
        return self.configuration['normalization_schemes']

    @property
    def use_mask_for_norm(self) -> List[bool]:
        return self.configuration['use_mask_for_norm']

    @property
    def network_arch_class_name(self) -> str:
        return self.configuration['architecture']['network_class_name']

    @property
    def network_arch_init_kwargs(self) -> dict:
        return self.configuration['architecture']['arch_kwargs']

    @property
    def network_arch_init_kwargs_req_import(self) -> Union[Tuple[str, ...], List[str]]:
        return self.configuration['architecture']['_kw_requires_import']

    @property
    def pool_op_kernel_sizes(self) -> Tuple[Tuple[int, ...], ...]:
        return self.configuration['architecture']['arch_kwargs']['strides']

    @property
    @lru_cache(maxsize=1)
    def resampling_fn_data(self) -> Callable[
        [Union[torch.Tensor, np.ndarray],
         Union[Tuple[int, ...], List[int], np.ndarray],
         Union[Tuple[float, ...], List[float], np.ndarray],
         Union[Tuple[float, ...], List[float], np.ndarray]
         ],
        Union[torch.Tensor, np.ndarray]]:
        fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_data'])
        fn = partial(fn, **self.configuration['resampling_fn_data_kwargs'])
        return fn

    @property
    @lru_cache(maxsize=1)
    def resampling_fn_probabilities(self) -> Callable[
        [Union[torch.Tensor, np.ndarray],
         Union[Tuple[int, ...], List[int], np.ndarray],
         Union[Tuple[float, ...], List[float], np.ndarray],
         Union[Tuple[float, ...], List[float], np.ndarray]
         ],
        Union[torch.Tensor, np.ndarray]]:
        fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_probabilities'])
        fn = partial(fn, **self.configuration['resampling_fn_probabilities_kwargs'])
        return fn

    @property
    @lru_cache(maxsize=1)
    def resampling_fn_seg(self) -> Callable[
        [Union[torch.Tensor, np.ndarray],
         Union[Tuple[int, ...], List[int], np.ndarray],
         Union[Tuple[float, ...], List[float], np.ndarray],
         Union[Tuple[float, ...], List[float], np.ndarray]
         ],
        Union[torch.Tensor,
        np.ndarray]]:
        fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_seg'])
        fn = partial(fn, **self.configuration['resampling_fn_seg_kwargs'])
        return fn

    @property
    def batch_dice(self) -> bool:
        return self.configuration['batch_dice']

    @property
    def next_stage_names(self) -> Union[List[str], None]:
        ret = self.configuration.get('next_stage')
        if ret is not None:
            if isinstance(ret, str):
                ret = [ret]
        return ret

    @property
    def previous_stage_name(self) -> Union[str, None]:
        return self.configuration.get('previous_stage')


class PlansManager(object):
    def __init__(self, plans_file_or_dict: Union[str, dict]):
        """
        Why do we need this?
        1) resolve inheritance in configurations
        2) expose otherwise annoying stuff like getting the label manager or IO class from a string
        3) clearly expose the things that are in the plans instead of hiding them in a dict
        4) cache the results of expensive lookups

        This class does not prevent you from going wild. You can still use the plans directly if you prefer
        (PlansManager.plans['key'])
        """
        self.plans = plans_file_or_dict if isinstance(plans_file_or_dict, dict) else load_json(plans_file_or_dict)

    def __repr__(self):
        return self.plans.__repr__()

    def _internal_resolve_configuration_inheritance(self, configuration_name: str,
                                                    visited: Tuple[str, ...] = None) -> dict:
        if configuration_name not in self.plans['configurations'].keys():
            raise ValueError(f'The configuration {configuration_name} does not exist in the plans I have. Valid '
                             f'configuration names are {list(self.plans["configurations"].keys())}.')
        configuration = deepcopy(self.plans['configurations'][configuration_name])
        if 'inherits_from' in configuration:
            parent_config_name = configuration['inherits_from']

            if visited is None:
                visited = (configuration_name,)
            else:
                if parent_config_name in visited:
                    raise RuntimeError(f"Circular dependency detected. The following configurations were visited "
                                       f"while solving inheritance (in that order!): {visited}. "
                                       f"Current configuration: {configuration_name}.
Its parent configuration "
                                       f"is {parent_config_name}.")
                visited = (*visited, configuration_name)

            base_config = self._internal_resolve_configuration_inheritance(parent_config_name, visited)
            base_config.update(configuration)
            configuration = base_config
        return configuration

    @lru_cache(maxsize=10)
    def get_configuration(self, configuration_name: str):
        if configuration_name not in self.plans['configurations'].keys():
            raise RuntimeError(f"Requested configuration {configuration_name} not found in plans. "
                               f"Available configurations: {list(self.plans['configurations'].keys())}")

        configuration_dict = self._internal_resolve_configuration_inheritance(configuration_name)
        return ConfigurationManager(configuration_dict)

    @property
    def dataset_name(self) -> str:
        return self.plans['dataset_name']

    @property
    def plans_name(self) -> str:
        return self.plans['plans_name']

    @property
    def original_median_spacing_after_transp(self) -> List[float]:
        return self.plans['original_median_spacing_after_transp']

    @property
    def original_median_shape_after_transp(self) -> List[float]:
        return self.plans['original_median_shape_after_transp']

    @property
    @lru_cache(maxsize=1)
    def image_reader_writer_class(self) -> Type[BaseReaderWriter]:
        return recursive_find_reader_writer_by_name(self.plans['image_reader_writer'])

    @property
    def transpose_forward(self) -> List[int]:
        return self.plans['transpose_forward']

    @property
    def transpose_backward(self) -> List[int]:
        return self.plans['transpose_backward']

    @property
    def available_configurations(self) -> List[str]:
        return list(self.plans['configurations'].keys())

    @property
    @lru_cache(maxsize=1)
    def experiment_planner_class(self) -> Type[ExperimentPlanner]:
        planner_name = self.experiment_planner_name
        experiment_planner = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"),
                                                         planner_name,
                                                         current_module="nnunetv2.experiment_planning")
        return experiment_planner

    @property
    def experiment_planner_name(self) -> str:
        return
        self.plans['experiment_planner_used']

    @property
    @lru_cache(maxsize=1)
    def label_manager_class(self) -> Type[LabelManager]:
        return get_labelmanager_class_from_plans(self.plans)

    def get_label_manager(self, dataset_json: dict, **kwargs) -> LabelManager:
        return self.label_manager_class(label_dict=dataset_json['labels'],
                                        regions_class_order=dataset_json.get('regions_class_order'),
                                        **kwargs)

    @property
    def foreground_intensity_properties_per_channel(self) -> dict:
        if 'foreground_intensity_properties_per_channel' not in self.plans.keys():
            if 'foreground_intensity_properties_by_modality' in self.plans.keys():
                return self.plans['foreground_intensity_properties_by_modality']
        return self.plans['foreground_intensity_properties_per_channel']


if __name__ == '__main__':
    from nnunetv2.paths import nnUNet_preprocessed
    from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name

    plans = load_json(join(nnUNet_preprocessed, maybe_convert_to_dataset_name(3), 'nnUNetPlans.json'))
    # build new configuration that inherits from 3d_fullres
    plans['configurations']['3d_fullres_bs4'] = {
        'batch_size': 4,
        'inherits_from': '3d_fullres'
    }
    # now get plans and configuration managers
    plans_manager = PlansManager(plans)
    configuration_manager = plans_manager.get_configuration('3d_fullres_bs4')
    print(configuration_manager)  # look for batch size 4


================================================
FILE: nnunetv2/utilities/utils.py
================================================
# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
# (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
from functools import lru_cache
from typing import Union

from batchgenerators.utilities.file_and_folder_operations import *
import numpy as np
import re

from nnunetv2.paths import nnUNet_raw
from multiprocessing import Pool


def get_identifiers_from_splitted_dataset_folder(folder: str, file_ending: str):
    files = subfiles(folder, suffix=file_ending, join=False)
    # all files have a 4 digit channel index (_XXXX)
    crop = len(file_ending) + 5
    files = [i[:-crop] for i in files]
    # only unique image ids
    files = np.unique(files)
    return files


def create_paths_fn(folder, files, file_ending, f):
    p = re.compile(re.escape(f) + r"_\d\d\d\d" + re.escape(file_ending))
    return [join(folder, i) for i in files if p.fullmatch(i)]


def create_lists_from_splitted_dataset_folder(folder: str, file_ending: str, identifiers: List[str] = None,
                                              num_processes: int = 12) -> List[List[str]]:
    """
    does not rely on dataset.json
    """
    if identifiers is None:
        identifiers = get_identifiers_from_splitted_dataset_folder(folder, file_ending)

    files = subfiles(folder, suffix=file_ending, join=False, sort=True)
    list_of_lists = []

    params_list = [(folder, files, file_ending, f) for f in identifiers]
    with Pool(processes=num_processes) as pool:
        list_of_lists = pool.starmap(create_paths_fn, params_list)

    return list_of_lists


def get_filenames_of_train_images_and_targets(raw_dataset_folder: str, dataset_json: dict = None):
    if dataset_json is None:
        dataset_json = load_json(join(raw_dataset_folder, 'dataset.json'))

    if 'dataset' in dataset_json.keys():
        dataset = dataset_json['dataset']
        for k in dataset.keys():
            expanded_label_file = os.path.expandvars(dataset[k]['label'])
            dataset[k]['label'] = os.path.abspath(join(raw_dataset_folder, expanded_label_file)) if not os.path.isabs(expanded_label_file) else expanded_label_file
            dataset[k]['images'] = [os.path.abspath(join(raw_dataset_folder, os.path.expandvars(i))) if not os.path.isabs(os.path.expandvars(i)) else os.path.expandvars(i) for i in dataset[k]['images']]
    else:
        identifiers = get_identifiers_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending'])
        images = create_lists_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending'], identifiers)
        segs = [join(raw_dataset_folder, 'labelsTr', i + dataset_json['file_ending']) for i in identifiers]
        dataset = {i: {'images': im, 'label': se} for i, im, se in zip(identifiers, images, segs)}
    return dataset


if __name__ == '__main__':
    print(get_filenames_of_train_images_and_targets(join(nnUNet_raw, 'Dataset002_Heart')))


================================================
FILE: pyproject.toml
================================================
[project]
name = "nnunetv2"
version = "2.6.4"
requires-python = ">=3.10"
description = "nnU-Net is a framework for out-of-the-box image segmentation."
readme = "readme.md"
license = { file = "LICENSE" }
authors = [
    { name = "Fabian Isensee", email = "f.isensee@dkfz-heidelberg.de"},
    { name = "Helmholtz Imaging Applied Computer Vision Lab" }
]
classifiers = [
    "Development Status :: 5 - Production/Stable",
    "Intended Audience :: Developers",
    "Intended Audience :: Science/Research",
    "Intended Audience :: Healthcare Industry",
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: Apache Software License",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Scientific/Engineering :: Image Recognition",
    "Topic :: Scientific/Engineering :: Medical Science Apps.",
]
keywords = [
    'deep learning',
    'image segmentation',
    'semantic segmentation',
    'medical image analysis',
    'medical image segmentation',
    'nnU-Net',
    'nnunet'
]
dependencies = [
    "torch>=2.1.2,!=2.9.*",
    "acvl-utils>=0.2.3,<0.3",  # 0.3 may bring breaking changes. Careful!
    "dynamic-network-architectures>=0.4.1,<0.5",
    "tqdm",
    "scipy",
    "batchgenerators>=0.25.1",
    "numpy>=1.24",
    "scikit-learn",
    "scikit-image>=0.19.3",
    "SimpleITK>=2.2.1",
    "pandas",
    "graphviz",
    'tifffile',
    'requests',
    "nibabel",
    "matplotlib",
    "seaborn",
    "imagecodecs",
    "yacs",
    "batchgeneratorsv2>=0.3.0",
    "einops",
    "blosc2>=3.0.0b1"
]

[project.urls]
homepage = "https://github.com/MIC-DKFZ/nnUNet"
repository = "https://github.com/MIC-DKFZ/nnUNet"

[project.scripts]
nnUNetv2_plan_and_preprocess = "nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:plan_and_preprocess_entry"
nnUNetv2_extract_fingerprint = "nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:extract_fingerprint_entry"
nnUNetv2_plan_experiment = "nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:plan_experiment_entry"
nnUNetv2_preprocess = "nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:preprocess_entry"
nnUNetv2_train = "nnunetv2.run.run_training:run_training_entry"
nnUNetv2_predict_from_modelfolder = "nnunetv2.inference.predict_from_raw_data:predict_entry_point_modelfolder"
nnUNetv2_predict = "nnunetv2.inference.predict_from_raw_data:predict_entry_point"
nnUNetv2_convert_old_nnUNet_dataset = "nnunetv2.dataset_conversion.convert_raw_dataset_from_old_nnunet_format:convert_entry_point"
nnUNetv2_find_best_configuration = "nnunetv2.evaluation.find_best_configuration:find_best_configuration_entry_point"
nnUNetv2_determine_postprocessing = "nnunetv2.postprocessing.remove_connected_components:entry_point_determine_postprocessing_folder"
nnUNetv2_apply_postprocessing = "nnunetv2.postprocessing.remove_connected_components:entry_point_apply_postprocessing"
nnUNetv2_ensemble = "nnunetv2.ensembling.ensemble:entry_point_ensemble_folders"
nnUNetv2_accumulate_crossval_results = "nnunetv2.evaluation.find_best_configuration:accumulate_crossval_results_entry_point"
nnUNetv2_plot_overlay_pngs = "nnunetv2.utilities.overlay_plots:entry_point_generate_overlay"
nnUNetv2_download_pretrained_model_by_url = "nnunetv2.model_sharing.entry_points:download_by_url"
nnUNetv2_install_pretrained_model_from_zip = "nnunetv2.model_sharing.entry_points:install_from_zip_entry_point"
nnUNetv2_export_model_to_zip = "nnunetv2.model_sharing.entry_points:export_pretrained_model_entry"
nnUNetv2_move_plans_between_datasets = "nnunetv2.experiment_planning.plans_for_pretraining.move_plans_between_datasets:entry_point_move_plans_between_datasets"
nnUNetv2_evaluate_folder = "nnunetv2.evaluation.evaluate_predictions:evaluate_folder_entry_point"
nnUNetv2_evaluate_simple = "nnunetv2.evaluation.evaluate_predictions:evaluate_simple_entry_point"
nnUNetv2_convert_MSD_dataset = "nnunetv2.dataset_conversion.convert_MSD_dataset:entry_point"

[project.optional-dependencies]
dev = [
    "black",
    "ruff",
    "pre-commit"
]

[build-system]
requires = ["setuptools>=67.8.0"]
build-backend = "setuptools.build_meta"

[tool.codespell]
skip = '.git,*.pdf,*.svg'
#
# ignore-words-list = ''

================================================
FILE: readme.md
================================================
# Welcome to the new nnU-Net!

Click [here](https://github.com/MIC-DKFZ/nnUNet/tree/nnunetv1) if you were looking for the old one instead.

Coming from V1? Check out the [TLDR Migration Guide](documentation/tldr_migration_guide_from_v1.md). Reading the rest of the documentation is still strongly recommended ;-)

## **2025-10-23 There seems to be a [severe performance regression with torch 2.9.0 and 3D convs](https://github.com/pytorch/pytorch/issues/166122) (when using AMP). Please use torch 2.8.0 or lower with nnU-Net!**

## **2024-04-18 UPDATE: New residual encoder UNet presets available!**
Residual encoder UNet presets substantially improve segmentation performance.
They ship for a variety of GPU memory targets. It's all awesome stuff, promised!
Read more :point_right: [here](documentation/resenc_presets.md) :point_left:

Also check out our [new paper](https://arxiv.org/pdf/2404.09556.pdf) on systematically benchmarking recent developments in medical image segmentation. You might be surprised!

# What is nnU-Net?
Image datasets are enormously diverse: image dimensionality (2D, 3D), modalities/input channels (RGB image, CT, MRI, microscopy, ...), image sizes, voxel sizes, class ratio, target structure properties and more change substantially between datasets.
Traditionally, given a new problem, a tailored solution needs to be manually designed and optimized - a process that is prone to errors, not scalable and where success is overwhelmingly determined by the skill of the experimenter. Even for experts, this process is anything but simple: there are not only many design choices and data properties that need to be considered, but they are also tightly interconnected, rendering reliable manual pipeline optimization all but impossible!

![nnU-Net overview](documentation/assets/nnU-Net_overview.png)

**nnU-Net is a semantic segmentation method that automatically adapts to a given dataset.
It will analyze the provided training cases and automatically configure a matching U-Net-based segmentation pipeline. No expertise required on your end! You can simply train the models and use them for your application**.

Upon release, nnU-Net was evaluated on 23 datasets belonging to competitions from the biomedical domain. Despite competing with handcrafted solutions for each respective dataset, nnU-Net's fully automated pipeline scored several first places on open leaderboards! Since then nnU-Net has stood the test of time: it continues to be used as a baseline and method development framework ([9 out of 10 challenge winners at MICCAI 2020](https://arxiv.org/abs/2101.00232) and 5 out of 7 in MICCAI 2021 built their methods on top of nnU-Net, [we won AMOS2022 with nnU-Net](https://amos22.grand-challenge.org/final-ranking/))!

Please cite the [following paper](https://www.google.com/url?q=https://www.nature.com/articles/s41592-020-01008-z&sa=D&source=docs&ust=1677235958581755&usg=AOvVaw3dWL0SrITLhCJUBiNIHCQO) when using nnU-Net:

    Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring
    method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211.

## What can nnU-Net do for you?
If you are a **domain scientist** (biologist, radiologist, ...) looking to analyze your own images, nnU-Net provides an out-of-the-box solution that is all but guaranteed to provide excellent results on your individual dataset. Simply convert your dataset into the nnU-Net format and enjoy the power of AI - no expertise required!
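
To make "convert your dataset into the nnU-Net format" a little more concrete, here is a minimal sketch of the folder layout and `dataset.json` metadata that a raw dataset is expected to have. The folder names `imagesTr` and `labelsTr` and the `file_ending` key appear in the repository code; the remaining keys (`channel_names`, `labels`, `numTraining`) and the four-digit channel suffix are taken from the dataset format documentation as we understand it. The helper names `build_dataset_json` and `expected_case_files` and the dataset name `Dataset500_Example` are hypothetical and not part of `nnunetv2`. Treat [Dataset conversion](documentation/dataset_format.md) as the authoritative reference.

```python
import json
import os
import tempfile


def build_dataset_json(channel_names, labels, num_training, file_ending):
    """Assemble the metadata dict stored as dataset.json (illustrative helper, not nnunetv2 API)."""
    return {
        # channel index (as a string) -> channel name, e.g. {"0": "CT"}
        "channel_names": {str(i): name for i, name in enumerate(channel_names)},
        # label name -> integer value in the segmentation, background is 0
        "labels": labels,
        "numTraining": num_training,
        "file_ending": file_ending,
    }


def expected_case_files(dataset_dir, case_id, num_channels, file_ending):
    """Return the image paths (one per channel) and the label path for one training case."""
    images = [
        os.path.join(dataset_dir, "imagesTr", f"{case_id}_{c:04d}{file_ending}")
        for c in range(num_channels)
    ]
    label = os.path.join(dataset_dir, "labelsTr", case_id + file_ending)
    return images, label


if __name__ == "__main__":
    # Assumption: a throwaway directory stands in for the nnUNet_raw folder.
    with tempfile.TemporaryDirectory() as raw_root:
        dataset_dir = os.path.join(raw_root, "Dataset500_Example")
        os.makedirs(os.path.join(dataset_dir, "imagesTr"))
        os.makedirs(os.path.join(dataset_dir, "labelsTr"))

        meta = build_dataset_json(["CT"], {"background": 0, "organ": 1}, 1, ".nii.gz")
        with open(os.path.join(dataset_dir, "dataset.json"), "w") as f:
            json.dump(meta, f, indent=4)

        # Shows where the single-channel image and its label would need to be placed.
        print(expected_case_files(dataset_dir, "case_001", 1, ".nii.gz"))
```

The sketch only writes metadata and computes paths; the images and segmentations themselves still have to be exported in a supported file format with matching geometry.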

If you are an **AI researcher** developing segmentation methods, nnU-Net:
- offers a fantastic out-of-the-box applicable baseline algorithm to compete against
- can act as a method development framework to test your contribution on a large number of datasets without having to tune individual pipelines (for example evaluating a new loss function)
- provides a strong starting point for further dataset-specific optimizations. This is particularly useful when competing in segmentation challenges
- provides a new perspective on the design of segmentation methods: maybe you can find better connections between dataset properties and best-fitting segmentation pipelines?

## What is the scope of nnU-Net?
nnU-Net is built for semantic segmentation. It can handle 2D and 3D images with arbitrary input modalities/channels. It can understand voxel spacings, anisotropies and is robust even when classes are highly imbalanced.

nnU-Net relies on supervised learning, which means that you need to provide training cases for your application. The number of required training cases varies heavily depending on the complexity of the segmentation problem. No one-size-fits-all number can be provided here! nnU-Net does not require more training cases than other solutions - maybe even fewer due to our extensive use of data augmentation.

nnU-Net expects to be able to process entire images at once during preprocessing and postprocessing, so it cannot handle enormous images. As a reference: we tested images from 40x40x40 pixels all the way up to 1500x1500x1500 in 3D and 40x40 up to ~30000x30000 in 2D! If your RAM allows it, larger is always possible.

## How does nnU-Net work?
Given a new dataset, nnU-Net will systematically analyze the provided training cases and create a 'dataset fingerprint'.
nnU-Net then creates several U-Net configurations for each dataset:

- `2d`: a 2D U-Net (for 2D and 3D datasets)
- `3d_fullres`: a 3D U-Net that operates on a high image resolution (for 3D datasets only)
- `3d_lowres` → `3d_cascade_fullres`: a 3D U-Net cascade where first a 3D U-Net operates on low resolution images and then a second high-resolution 3D U-Net refines the predictions of the former (for 3D datasets with large image sizes only)

**Note that not all U-Net configurations are created for all datasets. In datasets with small image sizes, the U-Net cascade (and with it the 3d_lowres configuration) is omitted because the patch size of the full resolution U-Net already covers a large part of the input images.**

nnU-Net configures its segmentation pipelines based on a three-step recipe:
- **Fixed parameters** are not adapted. During development of nnU-Net we identified a robust configuration (that is, certain architecture and training properties) that can simply be used all the time. This includes, for example, nnU-Net's loss function, (most of the) data augmentation strategy and learning rate.
- **Rule-based parameters** use the dataset fingerprint to adapt certain segmentation pipeline properties by following hard-coded heuristic rules. For example, the network topology (pooling behavior and depth of the network architecture) is adapted to the patch size; the patch size, network topology and batch size are optimized jointly given some GPU memory constraint.
- **Empirical parameters** are essentially trial-and-error. For example, the selection of the best U-Net configuration for the given dataset (2D, 3D full resolution, 3D low resolution, 3D cascade) and the optimization of the postprocessing strategy.

## How to get started?
Read these:
- [Installation instructions](documentation/installation_instructions.md)
- [Dataset conversion](documentation/dataset_format.md)
- [Usage instructions](documentation/how_to_use_nnunet.md)

Additional information:
- [Contributing to nnU-Net](CONTRIBUTING.md)
- [Learning from sparse annotations (scribbles, slices)](documentation/ignore_label.md)
- [Region-based training](documentation/region_based_training.md)
- [Manual data splits](documentation/manual_data_splits.md)
- [Pretraining and finetuning](documentation/pretraining_and_finetuning.md)
- [Intensity Normalization in nnU-Net](documentation/explanation_normalization.md)
- [Training logging (Local + Weights & Biases)](documentation/explanation_logging.md)
- [Manually editing nnU-Net configurations](documentation/explanation_plans_files.md)
- [Extending nnU-Net](documentation/extending_nnunet.md)
- [What is different in V2?](documentation/changelog.md)

Competitions:
- [AutoPET II](documentation/competitions/AutoPETII.md)

[//]: # (- [Ignore label](documentation/ignore_label.md))

## Where does nnU-Net perform well and where does it not perform?
nnU-Net excels in segmentation problems that need to be solved by training from scratch, for example: research applications that feature non-standard image modalities and input channels, challenge datasets from the biomedical domain, the majority of 3D segmentation problems, etc. We have yet to find a dataset for which nnU-Net's working principle fails!

Note: On standard segmentation problems, such as 2D RGB images in ADE20k and Cityscapes, fine-tuning a foundation model (that was pretrained on a large corpus of similar images, e.g. ImageNet 22k, JFT-300M) will provide better performance than nnU-Net! That is simply because these models allow much better initialization.
Foundation models are not supported by nnU-Net as they 1) are not useful for segmentation problems that deviate from the standard setting (see above mentioned datasets), 2) would typically only support 2D architectures and 3) conflict with our core design principle of carefully adapting the network topology for each dataset (if the topology is changed one can no longer transfer pretrained weights!)

## What happened to the old nnU-Net?
The core of the old nnU-Net was hacked together in a short time period while participating in the Medical Segmentation Decathlon challenge in 2018. Consequently, code structure and quality were not the best. Many features were added later on and didn't quite fit into the nnU-Net design principles. Overall quite messy, really. And annoying to work with.

nnU-Net V2 is a complete overhaul. The "delete everything and start again" kind. So everything is better (in the author's opinion haha). While the segmentation performance [remains the same](https://docs.google.com/spreadsheets/d/13gqjIKEMPFPyMMMwA1EML57IyoBjfC3-QCTn4zRN_Mg/edit?usp=sharing), a lot of cool stuff has been added. It is now also much easier to use it as a development framework and to manually fine-tune its configuration to new datasets. A big driver for the reimplementation was also the emergence of [Helmholtz Imaging](http://helmholtz-imaging.de), prompting us to extend nnU-Net to more image formats and domains. Take a look [here](documentation/changelog.md) for some highlights.

# Acknowledgements
nnU-Net is developed and maintained by the Applied Computer Vision Lab (ACVL) of [Helmholtz Imaging](http://helmholtz-imaging.de) and the [Division of Medical Image Computing](https://www.dkfz.de/en/mic/index.php) at the [German Cancer Research Center (DKFZ)](https://www.dkfz.de/en/index.html).

================================================
FILE: setup.py
================================================
import setuptools

if __name__ == "__main__":
    setuptools.setup()