Repository: MIC-DKFZ/nnUNet
Branch: master
Commit: 4e8361bd9f2a
Files: 237
Total size: 1.1 MB
Directory structure:
gitextract_vkpapda3/
├── .github/
│ └── workflows/
│ ├── codespell.yml
│ └── run_tests_nnunet.yml
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── documentation/
│ ├── __init__.py
│ ├── benchmarking.md
│ ├── changelog.md
│ ├── competitions/
│ │ ├── AortaSeg24.md
│ │ ├── AutoPETII.md
│ │ ├── FLARE24/
│ │ │ ├── Task_1/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── inference_flare_task1.py
│ │ │ │ └── readme.md
│ │ │ ├── Task_2/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── inference_flare_task2.py
│ │ │ │ └── readme.md
│ │ │ └── __init__.py
│ │ ├── Toothfairy2/
│ │ │ ├── __init__.py
│ │ │ ├── inference_script_semseg_only_customInf2.py
│ │ │ └── readme.md
│ │ └── __init__.py
│ ├── convert_msd_dataset.md
│ ├── dataset_format.md
│ ├── dataset_format_inference.md
│ ├── explanation_logging.md
│ ├── explanation_normalization.md
│ ├── explanation_plans_files.md
│ ├── extending_nnunet.md
│ ├── how_to_use_nnunet.md
│ ├── ignore_label.md
│ ├── installation_instructions.md
│ ├── manual_data_splits.md
│ ├── pretraining_and_finetuning.md
│ ├── region_based_training.md
│ ├── resenc_presets.md
│ ├── run_inference_with_pretrained_models.md
│ ├── set_environment_variables.md
│ ├── setting_up_paths.md
│ └── tldr_migration_guide_from_v1.md
├── nnunetv2/
│ ├── __init__.py
│ ├── batch_running/
│ │ ├── __init__.py
│ │ ├── benchmarking/
│ │ │ ├── __init__.py
│ │ │ ├── generate_benchmarking_commands.py
│ │ │ └── summarize_benchmark_results.py
│ │ ├── collect_results_custom_Decathlon.py
│ │ ├── collect_results_custom_Decathlon_2d.py
│ │ ├── generate_lsf_runs_customDecathlon.py
│ │ ├── jobs.sh
│ │ └── release_trainings/
│ │ ├── __init__.py
│ │ └── nnunetv2_v1/
│ │ ├── __init__.py
│ │ ├── collect_results.py
│ │ └── generate_lsf_commands.py
│ ├── configuration.py
│ ├── dataset_conversion/
│ │ ├── Dataset015_018_RibFrac_RibSeg.py
│ │ ├── Dataset021_CTAAorta.py
│ │ ├── Dataset023_AbdomenAtlas1_1Mini.py
│ │ ├── Dataset027_ACDC.py
│ │ ├── Dataset042_BraTS18.py
│ │ ├── Dataset043_BraTS19.py
│ │ ├── Dataset073_Fluo_C3DH_A549_SIM.py
│ │ ├── Dataset114_MNMs.py
│ │ ├── Dataset115_EMIDEC.py
│ │ ├── Dataset119_ToothFairy2_All.py
│ │ ├── Dataset120_RoadSegmentation.py
│ │ ├── Dataset137_BraTS21.py
│ │ ├── Dataset218_Amos2022_task1.py
│ │ ├── Dataset219_Amos2022_task2.py
│ │ ├── Dataset220_KiTS2023.py
│ │ ├── Dataset221_AutoPETII_2023.py
│ │ ├── Dataset223_AMOS2022postChallenge.py
│ │ ├── Dataset224_AbdomenAtlas1.0.py
│ │ ├── Dataset226_BraTS2024-BraTS-GLI.py
│ │ ├── Dataset227_TotalSegmentatorMRI.py
│ │ ├── Dataset987_dummyDataset4.py
│ │ ├── Dataset989_dummyDataset4_2.py
│ │ ├── __init__.py
│ │ ├── convert_MSD_dataset.py
│ │ ├── convert_raw_dataset_from_old_nnunet_format.py
│ │ ├── datasets_for_integration_tests/
│ │ │ ├── Dataset996_IntegrationTest_Hippocampus_regions_ignore.py
│ │ │ ├── Dataset997_IntegrationTest_Hippocampus_regions.py
│ │ │ ├── Dataset998_IntegrationTest_Hippocampus_ignore.py
│ │ │ ├── Dataset999_IntegrationTest_Hippocampus.py
│ │ │ └── __init__.py
│ │ └── generate_dataset_json.py
│ ├── ensembling/
│ │ ├── __init__.py
│ │ └── ensemble.py
│ ├── evaluation/
│ │ ├── __init__.py
│ │ ├── accumulate_cv_results.py
│ │ ├── evaluate_predictions.py
│ │ └── find_best_configuration.py
│ ├── experiment_planning/
│ │ ├── __init__.py
│ │ ├── dataset_fingerprint/
│ │ │ ├── __init__.py
│ │ │ └── fingerprint_extractor.py
│ │ ├── experiment_planners/
│ │ │ ├── __init__.py
│ │ │ ├── default_experiment_planner.py
│ │ │ ├── network_topology.py
│ │ │ ├── resampling/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── planners_no_resampling.py
│ │ │ │ └── resample_with_torch.py
│ │ │ ├── resencUNet_planner.py
│ │ │ └── residual_unets/
│ │ │ ├── __init__.py
│ │ │ └── residual_encoder_unet_planners.py
│ │ ├── plan_and_preprocess_api.py
│ │ ├── plan_and_preprocess_entrypoints.py
│ │ ├── plans_for_pretraining/
│ │ │ ├── __init__.py
│ │ │ └── move_plans_between_datasets.py
│ │ └── verify_dataset_integrity.py
│ ├── imageio/
│ │ ├── __init__.py
│ │ ├── base_reader_writer.py
│ │ ├── natural_image_reader_writer.py
│ │ ├── nibabel_reader_writer.py
│ │ ├── reader_writer_registry.py
│ │ ├── readme.md
│ │ ├── simpleitk_reader_writer.py
│ │ └── tif_reader_writer.py
│ ├── inference/
│ │ ├── JHU_inference.py
│ │ ├── __init__.py
│ │ ├── data_iterators.py
│ │ ├── examples.py
│ │ ├── export_prediction.py
│ │ ├── predict_from_raw_data.py
│ │ ├── readme.md
│ │ └── sliding_window_prediction.py
│ ├── model_sharing/
│ │ ├── __init__.py
│ │ ├── entry_points.py
│ │ ├── model_download.py
│ │ ├── model_export.py
│ │ └── model_import.py
│ ├── paths.py
│ ├── postprocessing/
│ │ ├── __init__.py
│ │ └── remove_connected_components.py
│ ├── preprocessing/
│ │ ├── __init__.py
│ │ ├── cropping/
│ │ │ ├── __init__.py
│ │ │ └── cropping.py
│ │ ├── normalization/
│ │ │ ├── __init__.py
│ │ │ ├── default_normalization_schemes.py
│ │ │ ├── map_channel_name_to_normalization.py
│ │ │ └── readme.md
│ │ ├── preprocessors/
│ │ │ ├── __init__.py
│ │ │ └── default_preprocessor.py
│ │ └── resampling/
│ │ ├── __init__.py
│ │ ├── default_resampling.py
│ │ ├── no_resampling.py
│ │ ├── resample_torch.py
│ │ └── utils.py
│ ├── run/
│ │ ├── __init__.py
│ │ ├── load_pretrained_weights.py
│ │ └── run_training.py
│ ├── tests/
│ │ ├── __init__.py
│ │ └── integration_tests/
│ │ ├── __init__.py
│ │ ├── add_lowres_and_cascade.py
│ │ ├── cleanup_integration_test.py
│ │ ├── lsf_commands.sh
│ │ ├── prepare_integration_tests.sh
│ │ ├── readme.md
│ │ ├── run_integration_test.sh
│ │ ├── run_integration_test_bestconfig_inference.py
│ │ ├── run_integration_test_trainingOnly_DDP.sh
│ │ └── run_nnunet_inference.py
│ ├── training/
│ │ ├── __init__.py
│ │ ├── data_augmentation/
│ │ │ ├── __init__.py
│ │ │ ├── compute_initial_patch_size.py
│ │ │ └── custom_transforms/
│ │ │ ├── __init__.py
│ │ │ ├── cascade_transforms.py
│ │ │ ├── deep_supervision_donwsampling.py
│ │ │ ├── masking.py
│ │ │ ├── region_based_training.py
│ │ │ └── transforms_for_dummy_2d.py
│ │ ├── dataloading/
│ │ │ ├── __init__.py
│ │ │ ├── data_loader.py
│ │ │ ├── nnunet_dataset.py
│ │ │ └── utils.py
│ │ ├── logging/
│ │ │ ├── __init__.py
│ │ │ └── nnunet_logger.py
│ │ ├── loss/
│ │ │ ├── __init__.py
│ │ │ ├── compound_losses.py
│ │ │ ├── deep_supervision.py
│ │ │ ├── dice.py
│ │ │ └── robust_ce_loss.py
│ │ ├── lr_scheduler/
│ │ │ ├── __init__.py
│ │ │ ├── polylr.py
│ │ │ └── warmup.py
│ │ └── nnUNetTrainer/
│ │ ├── __init__.py
│ │ ├── nnUNetTrainer.py
│ │ ├── primus/
│ │ │ ├── __init__.py
│ │ │ └── primus_trainers.py
│ │ └── variants/
│ │ ├── __init__.py
│ │ ├── benchmarking/
│ │ │ ├── __init__.py
│ │ │ ├── nnUNetTrainerBenchmark_5epochs.py
│ │ │ └── nnUNetTrainerBenchmark_5epochs_noDataLoading.py
│ │ ├── competitions/
│ │ │ ├── __init__.py
│ │ │ └── aortaseg24.py
│ │ ├── data_augmentation/
│ │ │ ├── __init__.py
│ │ │ ├── nnUNetTrainerDA5.py
│ │ │ ├── nnUNetTrainerDAOrd0.py
│ │ │ ├── nnUNetTrainerNoDA.py
│ │ │ ├── nnUNetTrainerNoMirroring.py
│ │ │ └── nnUNetTrainer_noDummy2DDA.py
│ │ ├── loss/
│ │ │ ├── __init__.py
│ │ │ ├── nnUNetTrainerCELoss.py
│ │ │ ├── nnUNetTrainerDiceLoss.py
│ │ │ └── nnUNetTrainerTopkLoss.py
│ │ ├── lr_schedule/
│ │ │ ├── __init__.py
│ │ │ ├── nnUNetTrainerCosAnneal.py
│ │ │ └── nnUNetTrainer_warmup.py
│ │ ├── network_architecture/
│ │ │ ├── __init__.py
│ │ │ ├── nnUNetTrainerBN.py
│ │ │ └── nnUNetTrainerNoDeepSupervision.py
│ │ ├── optimizer/
│ │ │ ├── __init__.py
│ │ │ ├── nnUNetTrainerAdam.py
│ │ │ └── nnUNetTrainerAdan.py
│ │ ├── sampling/
│ │ │ ├── __init__.py
│ │ │ └── nnUNetTrainer_probabilisticOversampling.py
│ │ └── training_length/
│ │ ├── __init__.py
│ │ ├── nnUNetTrainer_Xepochs.py
│ │ └── nnUNetTrainer_Xepochs_NoMirroring.py
│ └── utilities/
│ ├── __init__.py
│ ├── collate_outputs.py
│ ├── crossval_split.py
│ ├── dataset_name_id_conversion.py
│ ├── ddp_allgather.py
│ ├── default_n_proc_DA.py
│ ├── file_path_utilities.py
│ ├── find_class_by_name.py
│ ├── get_network_from_plans.py
│ ├── helpers.py
│ ├── json_export.py
│ ├── label_handling/
│ │ ├── __init__.py
│ │ └── label_handling.py
│ ├── network_initialization.py
│ ├── overlay_plots.py
│ ├── plans_handling/
│ │ ├── __init__.py
│ │ └── plans_handler.py
│ └── utils.py
├── pyproject.toml
├── readme.md
└── setup.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/codespell.yml
================================================
---
name: Codespell
on:
  push:
    branches: [master]
  pull_request:
    branches: [master]
permissions:
  contents: read
jobs:
  codespell:
    name: Check for spelling errors
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v3
      - name: Codespell
        uses: codespell-project/actions-codespell@v2
================================================
FILE: .github/workflows/run_tests_nnunet.yml
================================================
name: Run nnunet tests on all OS
on:
  push:
    paths-ignore:
      - '**.md'
jobs:
  run-tests:
    strategy:
      matrix:
        # os: [ubuntu-latest, windows-latest, macos-latest] # fails on windows until https://github.com/MIC-DKFZ/nnUNet/issues/2396 is resolved
        os: [ubuntu-latest, macos-latest]
        python-version: ["3.10"]
    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Download and extract weights and download example file
        # use python instead of curl to avoid the need to install curl (more difficult for macos)
        run: |
          mkdir -p $HOME/github_actions_nnunet/results
          python -c "import urllib.request; urllib.request.urlretrieve('https://github.com/wasserth/TotalSegmentator/releases/download/v2.0.0-weights/Dataset300_body_6mm_1559subj.zip', '$HOME/github_actions_nnunet/results/tmp_download_file.zip')"
          unzip -o $HOME/github_actions_nnunet/results/tmp_download_file.zip -d $HOME/github_actions_nnunet/results
          rm $HOME/github_actions_nnunet/results/tmp_download_file.zip
      - name: Install dependencies on Ubuntu
        if: runner.os == 'Linux'
        run: |
          python -m pip install --upgrade pip
          pip install pytest Cython
          pip install torch==2.4.0 -f https://download.pytorch.org/whl/cpu
          pip install .
      - name: Install dependencies on Windows / MacOS
        if: runner.os == 'Windows' || runner.os == 'macOS'
        run: |
          python -m pip install --upgrade pip
          pip install pytest Cython
          pip install torch==2.4.0
          pip install .
      - name: Run test script
        run: python nnunetv2/tests/integration_tests/run_nnunet_inference.py
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# IPython Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# dotenv
.env
# virtualenv
venv/
ENV/
# Spyder project settings
.spyderproject
# Rope project settings
.ropeproject
*.memmap
*.png
*.zip
*.npz
*.npy
*.jpg
*.jpeg
.idea
*.txt
.idea/*
*.png
# *.nii.gz # nifti files needed for example_data for github actions tests
*.nii
*.tif
*.bmp
*.pkl
*.xml
*.pkl
*.pdf
*.png
*.jpg
*.jpeg
*.model
!documentation/assets/scribble_example.png
CLAUDE.md
AGENTS.md
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to nnU-Net
Thank you for your interest in contributing to nnU-Net.
nnU-Net is developed and maintained by researchers at DKFZ. There is no dedicated funding or staff for maintaining the
repository, and development happens alongside research and teaching responsibilities. Our bandwidth for reviewing
external contributions is therefore limited, and review times may be long.
## General principles
nnU-Net is intentionally designed to be focused, stable, and generally applicable across datasets and use cases.
Contributions should respect this philosophy and should not introduce unnecessary complexity or specialization.
New functionality must either be generally valid across datasets and setups or convincingly benefit a large enough
portion of the user base. We aim to avoid bloating the framework or increasing its complexity further.
## How to contribute
For larger features and refactors, please open a GitHub issue to discuss the idea before starting work. Tag
@FabianIsensee so that the discussion doesn't get missed.
To submit a contribution, fork the repository, make your changes on a branch, and open a pull request.
## Bug reports and bug fixes
Bug reports must include a minimal reproducible example. Without a repro, it is usually impossible for us to
investigate issues.
Pull requests fixing bugs should also include a clear reproduction of the issue and an explanation of how the fix
resolves it.
## Performance improvements
If a pull request claims performance improvements, it must include benchmarks demonstrating the effect. The benchmark
setup must be described clearly enough for us to reproduce the results independently. We may run additional
tests ourselves before merging.
## Contributions that are unlikely to be merged
To keep the framework maintainable and the workload manageable on our end, we deprioritize:
- dataset-specific code
- features that only apply to niche setups
- narrow custom architectures or training pipelines
- large refactorings without prior discussion
- small PRs fixing minor typos or formatting issues
## Final note
We appreciate the effort people invest in improving nnU-Net!
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [2019] [Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: documentation/__init__.py
================================================
================================================
FILE: documentation/benchmarking.md
================================================
# nnU-Netv2 benchmarks
Does your system run like it should? Is your epoch time longer than expected? What epoch times should you expect?
Look no further, for we have the solution here!
## What does the nnU-Netv2 benchmark do?
nnU-Net's benchmark trains models for 5 epochs. At the end, the fastest epoch will
be noted down, along with the GPU name, torch version and cudnn version. You can find the benchmark output in the
corresponding nnUNet_results subfolder (see example below). Don't worry, we also provide scripts to collect your
results. Or you can just start a benchmark and look at the console output. Everything is possible. Nothing is forbidden.
The benchmark implementation revolves around two trainers:
- `nnUNetTrainerBenchmark_5epochs` runs a regular training for 5 epochs. When completed, it writes a .json file with the fastest
epoch time as well as the GPU used and the torch and cudnn versions. Useful for speed testing the entire pipeline
(data loading, augmentation, GPU training).
- `nnUNetTrainerBenchmark_5epochs_noDataLoading` is the same, but it doesn't do any data loading or augmentation. It
just presents dummy arrays to the GPU. Useful for checking pure GPU speed.
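
In essence, the "fastest epoch" reported by the benchmark trainers is the minimum over the per-epoch wall-clock durations. The following is a minimal sketch of that bookkeeping, not nnU-Net's actual implementation. It assumes you have one start and one end timestamp (in seconds) per epoch, and `fastest_epoch_time` is a hypothetical helper name.

```python
from typing import Sequence


def fastest_epoch_time(start_timestamps: Sequence[float],
                       end_timestamps: Sequence[float]) -> float:
    """Return the shortest epoch duration in seconds.

    Hypothetical helper: assumes start_timestamps[i] and end_timestamps[i]
    belong to the same epoch.
    """
    if len(start_timestamps) != len(end_timestamps) or not start_timestamps:
        raise ValueError("need one start and one end timestamp per epoch")
    # Duration of each epoch = end - start; the benchmark reports the minimum.
    durations = [end - start for start, end in zip(start_timestamps, end_timestamps)]
    return min(durations)


if __name__ == "__main__":
    # Illustrative timestamps, not measured values.
    starts = [0.0, 45.0, 85.0, 124.0, 163.0]
    ends = [45.0, 85.0, 124.0, 163.0, 203.0]
    # Taking the minimum makes the result insensitive to one-off slow epochs,
    # such as a slower first epoch.
    print(fastest_epoch_time(starts, ends))
```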
## How to run the nnU-Netv2 benchmark?
It's quite simple, actually. It looks just like a regular nnU-Net training.
We provide reference numbers for some of the Medical Segmentation Decathlon datasets because they are easily
accessible: [download here](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2). If it needs to be
quick and dirty, focus on Tasks 2 and 4. Download and extract the data and convert them to the nnU-Net format with
`nnUNetv2_convert_MSD_dataset`.
Run `nnUNetv2_plan_and_preprocess` for them.
Then, for each dataset, run the following commands (only one per GPU! Or one after the other):
```bash
nnUNetv2_train DATASET_ID 2d 0 -tr nnUNetTrainerBenchmark_5epochs
nnUNetv2_train DATASET_ID 3d_fullres 0 -tr nnUNetTrainerBenchmark_5epochs
nnUNetv2_train DATASET_ID 2d 0 -tr nnUNetTrainerBenchmark_5epochs_noDataLoading
nnUNetv2_train DATASET_ID 3d_fullres 0 -tr nnUNetTrainerBenchmark_5epochs_noDataLoading
```
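
If you benchmark several datasets, it can help to generate the four commands per dataset instead of typing them. This is a sketch only: `benchmark_commands` is a hypothetical helper that merely builds the command strings shown above, and the example assumes the converted MSD datasets kept the IDs 2 and 4. It does not run anything, so you stay in control of running one command per GPU or one after the other.

```python
from typing import Iterable, List

TRAINERS = ("nnUNetTrainerBenchmark_5epochs",
            "nnUNetTrainerBenchmark_5epochs_noDataLoading")
CONFIGURATIONS = ("2d", "3d_fullres")


def benchmark_commands(dataset_ids: Iterable[int], fold: int = 0) -> List[str]:
    """Build the nnUNetv2_train benchmark commands for each dataset ID (hypothetical helper)."""
    commands = []
    for dataset_id in dataset_ids:
        for trainer in TRAINERS:
            for configuration in CONFIGURATIONS:
                commands.append(
                    f"nnUNetv2_train {dataset_id} {configuration} {fold} -tr {trainer}"
                )
    return commands


if __name__ == "__main__":
    # Assumed IDs of the converted Decathlon Tasks 2 and 4.
    for cmd in benchmark_commands([2, 4]):
        print(cmd)
```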
If you want to inspect the outcome manually, check (for example!) your
`nnUNet_results/DATASET_NAME/nnUNetTrainerBenchmark_5epochs__nnUNetPlans__3d_fullres/fold_0/` folder for the `benchmark_result.json` file.
Note that there can be multiple entries in this file if the benchmark was run on different GPU types, torch versions or cudnn versions!
If you want to summarize your results like we did in our [results](#results), check the
[summary script](../nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py). Here you need to change the
torch version, cudnn version and dataset you want to summarize, then execute the script. You can find the exact
values you need to put there in one of your `benchmark_result.json` files.
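
The sketch below locates `benchmark_result.json` using the folder layout described above and keeps only the entries recorded with a given torch (and optionally cudnn) version, which mirrors the values you have to set in the summary script. The key names (`torch_version`, `cudnn_version`, `gpu_name`, `fastest_epoch`) and the assumption that the file holds a JSON object of per-run entries are not confirmed here; check them against one of your own `benchmark_result.json` files. The helper names are hypothetical and the demo values are made up.

```python
import json
import os
import tempfile
from typing import Dict, List, Optional


def benchmark_result_path(nnunet_results: str, dataset_name: str,
                          trainer: str = "nnUNetTrainerBenchmark_5epochs",
                          configuration: str = "3d_fullres", fold: int = 0) -> str:
    """Path of benchmark_result.json, following the folder layout described above."""
    return os.path.join(nnunet_results, dataset_name,
                        f"{trainer}__nnUNetPlans__{configuration}",
                        f"fold_{fold}", "benchmark_result.json")


def select_entries(path: str, torch_version: str,
                   cudnn_version: Optional[int] = None) -> List[dict]:
    """Return the entries recorded with the given torch (and cudnn) version.

    ASSUMPTION: the file is a JSON object whose values are dicts with
    'torch_version', 'cudnn_version', 'gpu_name' and 'fastest_epoch' keys.
    """
    with open(path) as f:
        entries: Dict[str, dict] = json.load(f)
    selected = []
    for entry in entries.values():
        if entry.get("torch_version") != torch_version:
            continue
        if cudnn_version is not None and entry.get("cudnn_version") != cudnn_version:
            continue
        selected.append(entry)
    return selected


if __name__ == "__main__":
    # Demo with a throwaway results folder; all values below are made up.
    with tempfile.TemporaryDirectory() as results:
        path = benchmark_result_path(results, "Dataset004_Hippocampus")
        os.makedirs(os.path.dirname(path))
        fake = {
            "run_a": {"torch_version": "2.4.0", "cudnn_version": 90100,
                      "gpu_name": "SomeGPU", "fastest_epoch": 12.3},
            "run_b": {"torch_version": "2.3.1", "cudnn_version": 8902,
                      "gpu_name": "SomeGPU", "fastest_epoch": 13.0},
        }
        with open(path, "w") as f:
            json.dump(fake, f)
        for entry in select_entries(path, "2.4.0"):
            print(entry["gpu_name"], entry["fastest_epoch"])
```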
## Results
We have tested a variety of GPUs and summarized the results in a
[spreadsheet](https://docs.google.com/spreadsheets/d/12Cvt_gr8XU2qWaE0XJk5jJlxMEESPxyqW0CWbQhTNNY/edit?usp=sharing).
Note that you can select the torch and cudnn versions at the bottom! There may be comments in this spreadsheet. Read them!
## Result interpretation
Results are shown as epoch time in seconds. Lower is better (duh). Epoch times can fluctuate between runs, so as
long as you are within like 5-10% of the numbers we report, everything should be dandy.
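
To make the 5-10% rule of thumb concrete, the relative deviation is (measured - reference) / reference, and only being slower than the reference counts against you. A minimal sketch, assuming you compare your fastest epoch time against the spreadsheet value for the same GPU, dataset and configuration; the helper names and the example numbers are hypothetical.

```python
def relative_deviation(measured_s: float, reference_s: float) -> float:
    """(measured - reference) / reference; positive means slower than the reference."""
    if reference_s <= 0:
        raise ValueError("reference epoch time must be positive")
    return (measured_s - reference_s) / reference_s


def within_tolerance(measured_s: float, reference_s: float, tolerance: float = 0.10) -> bool:
    """True if the measured epoch time is at most `tolerance` slower than the reference.

    Being faster than the reference never counts as a problem.
    """
    return relative_deviation(measured_s, reference_s) <= tolerance


if __name__ == "__main__":
    reference = 40.0  # hypothetical spreadsheet value in seconds
    for measured in (38.5, 43.0, 48.0):
        status = "fine" if within_tolerance(measured, reference) else "investigate"
        print(f"{measured:.1f} s -> {relative_deviation(measured, reference):+.1%} ({status})")
```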
If not, here is how you can try to find the culprit!
The first thing to do is to compare the performance between the `nnUNetTrainerBenchmark_5epochs_noDataLoading` and
`nnUNetTrainerBenchmark_5epochs` trainers. If the difference is about the same as we report in our spreadsheet, but
both your numbers are worse, the problem is with your GPU:
- Are you certain you are comparing against the correct GPU? (duh)
- If yes, then you might want to install PyTorch in a different way. Never `pip install torch`! Go to the
[PyTorch installation](https://pytorch.org/get-started/locally/) page, select the most recent CUDA version your
system supports and only then copy and execute the correct command! Either pip or conda should work.
- If the problem is still not fixed, we recommend you try
[compiling pytorch from source](https://github.com/pytorch/pytorch#from-source). It's more difficult but that's
how we roll here at the DKFZ (at least the cool kids here).
- Another thing to consider is to try exactly the same torch + cudnn versions as listed in our spreadsheet.
Sometimes newer versions can actually degrade performance and there might be bugs from time to time. Older versions
are also often a lot slower!
- Finally, some very basic things that could impact your GPU performance:
  - Is the GPU cooled adequately? Check the temperature with `nvidia-smi`. Hot GPUs throttle performance in order to not self-destruct.
  - Is your OS using the GPU for displaying your desktop at the same time? If so, then you can expect a performance
    penalty (I dunno, maybe 10%?!). That's expected and OK.
  - Are other users using the GPU as well?
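
Before comparing numbers, it helps to know exactly which versions you are running. The sketch below gathers the Python, torch and cuDNN versions and the GPU name using standard PyTorch calls (`torch.__version__`, `torch.backends.cudnn.version()`, `torch.cuda.get_device_name()`). `environment_report` is a hypothetical helper, and it degrades gracefully if torch or a CUDA GPU is not available.

```python
import platform
from typing import Dict, Optional


def environment_report() -> Dict[str, Optional[str]]:
    """Collect the versions you need to compare against the spreadsheet."""
    report: Dict[str, Optional[str]] = {
        "python": platform.python_version(),
        "torch": None,
        "cudnn": None,
        "gpu": None,
    }
    try:
        import torch
    except ImportError:
        return report  # torch not installed; nothing else to report
    report["torch"] = str(torch.__version__)
    cudnn_version = torch.backends.cudnn.version()  # None if cuDNN is unavailable
    report["cudnn"] = None if cudnn_version is None else str(cudnn_version)
    if torch.cuda.is_available():
        report["gpu"] = torch.cuda.get_device_name(0)
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```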
If you see a large performance difference between `nnUNetTrainerBenchmark_5epochs_noDataLoading` (fast) and
`nnUNetTrainerBenchmark_5epochs` (slow) then the problem might be related to data loading and augmentation. As a
reminder, nnU-Net does not use pre-augmented images (offline augmentation) but instead generates augmented training
samples on the fly during training (no, you cannot switch it to offline). This requires that your system can do partial
reads of the image files fast enough (SSD storage required!) and that your CPU is powerful enough to run the augmentations.
Check the following:
- [CPU bottleneck] How many CPU threads are running during the training? nnU-Net uses 12 processes for data augmentation by default.
If you see those 12 running constantly during training, consider increasing the number of processes used for data
augmentation (provided there is headroom on your CPU!). Increase the number until you see fewer active workers than
you configured (or just set the number to 32 and forget about it). You can do so by setting the `nnUNet_n_proc_DA`
environment variable (Linux: `export nnUNet_n_proc_DA=24`). Read [here](set_environment_variables.md) on how to do this.
If your CPU does not support more processes (setting more processes than your CPU has threads makes
no sense!) you are out of luck and in desperate need of a system upgrade!
- [I/O bottleneck] If you don't see 12 (or `nnUNet_n_proc_DA`, if you set it) processes running but your training times
are still slow, open up `top` (sorry, Windows users, I don't know how to do this on Windows) and look at the value
to the left of 'wa' in the row that begins with '%Cpu(s)'. If this is >1.0 (an arbitrarily set threshold; essentially,
look for an unusually high 'wa'. In a healthy training, 'wa' will be almost 0), then your storage cannot keep up with
data loading. Make sure to set `nnUNet_preprocessed` to a folder that is located on an SSD. NVMe is preferred over SATA;
PCIe 3 is enough. A sequential read speed of 3000 MB/s is recommended.
- [funky stuff] Sometimes there is funky stuff going on, especially when batch sizes are large, files are small and
patch sizes are small as well. As part of the data loading process, nnU-Net needs to open and close a file for each
training sample. Now imagine a dataset like Dataset004_Hippocampus where for the 2d config we have a batch size of
366 and we run 250 iterations in <10s on an A100. That's a lotta files per second (366 * 250 / 10 = 9150 files per second).
Oof. If the files are on some network drive (even if it's NVMe) then (probably) good night. The good news: nnU-Net
has got you covered: add `export nnUNet_keep_files_open=True` to your .bashrc and the problem goes away. The neat
part: it causes new problems if you are not allowed to have enough open files. You may have to increase the number
of allowed open files. `ulimit -n` gives your current limit (Linux only). It should not be something like 1024.
Increasing that to 65535 works well for me. See here for how to change these limits:
[Link](https://kupczynski.info/posts/ubuntu-18-10-ulimits/)
(works for Ubuntu 18, google for your OS!).
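The checks above boil down to a bit of arithmetic and two system queries. The following is a minimal sketch, not part of nnU-Net, assuming a Unix system (the `resource` module does not exist on Windows). The helper names `files_per_second`, `capped_n_proc_da` and `open_file_soft_limit` are hypothetical; they reproduce the Hippocampus file-rate estimate, cap a desired `nnUNet_n_proc_DA` at the number of CPU threads, and read the same soft limit that `ulimit -n` reports.

```python
import os
import resource  # Unix-only module; not available on Windows


def files_per_second(batch_size: int, iterations: int, seconds: float) -> float:
    # nnU-Net opens one file per training sample, so files/s equals samples/s
    return batch_size * iterations / seconds


def capped_n_proc_da(requested: int) -> int:
    # more augmentation workers than CPU threads makes no sense
    available = os.cpu_count() or 1
    return max(1, min(requested, available))


def open_file_soft_limit() -> int:
    # the soft limit on open files, i.e. the value `ulimit -n` shows
    soft, _hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    return soft


if __name__ == "__main__":
    print(f"{files_per_second(366, 250, 10):.0f} files/s")  # 9150 files/s (Hippocampus 2d example)
    print("suggested nnUNet_n_proc_DA:", capped_n_proc_da(32))
    print("open-file soft limit:", open_file_soft_limit())
```

If the printed soft limit is as low as 1024 and `nnUNet_keep_files_open=True` is set, raising it (as described above) is likely necessary.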
================================================
FILE: documentation/changelog.md
================================================
# What is different in v2?
- We now support **hierarchical labels** (named regions in nnU-Net). For example, instead of training BraTS with the
'edema', 'necrosis' and 'enhancing tumor' labels you can directly train it on the target areas 'whole tumor',
'tumor core' and 'enhancing tumor'. See [here](region_based_training.md) for a detailed description + also have a look at the
[BraTS 2021 conversion script](../nnunetv2/dataset_conversion/Dataset137_BraTS21.py).
- Cross-platform support: CUDA, MPS (Apple M1/M2) and, of course, CPU! Simply select the device with
`-device` in `nnUNetv2_train` and `nnUNetv2_predict`.
- Unified trainer class: nnUNetTrainer. No messing around with cascaded trainer, DDP trainer, region-based trainer,
ignore trainer etc. All default functionality is in there!
- Supports more input/output data formats through ImageIO classes.
- I/O formats can be extended by implementing new Adapters based on `BaseReaderWriter`.
- The nnUNet_raw_cropped folder no longer exists -> saves disk space at no performance penalty. Magic! (jk, the
saving of cropped npz files was really slow, so it's actually faster to crop on the fly).
- Preprocessed data and segmentation are stored in different files when unpacked. Seg is stored as int8 and thus
takes 1/4 of the disk space per pixel (and I/O throughput) compared to v1.
- Native support for multi-GPU (DDP) TRAINING.
Multi-GPU INFERENCE should still be run with `CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] -num_parts Y -part_id X`.
There is no cross-GPU communication in inference, so it doesn't make sense to add additional complexity with DDP.
- All nnU-Net functionality is now also accessible via API. Check the corresponding entry point in `setup.py` to see
what functions you need to call.
- Dataset fingerprint is now explicitly created and saved in a json file (see nnUNet_preprocessed).
- Complete overhaul of plans files (read also [this](explanation_plans_files.md)):
  - Plans are now .json and can be opened and read more easily
  - Configurations are explicitly named ("3d_fullres", ...)
  - Configurations can inherit from each other to make manual experimentation easier
  - A ton of additional functionality is now included in and can be changed through the plans, for example normalization strategy, resampling etc.
  - Stages of the cascade are now explicitly listed in the plans. 3d_lowres has 'next_stage' (which can also be a
    list of configurations!). 3d_cascade_fullres has a 'previous_stage' entry. By manually editing plans files you can
    now connect anything you want, for example 2d with 3d_fullres or whatever. Be wild! (But don't create cycles!)
  - Multiple configurations can point to the same preprocessed data folder to save disk space. Careful! Only
    configurations that use the same spacing, resampling, normalization etc. should share a data source! By default,
    3d_fullres and 3d_cascade_fullres share the same data
  - Any number of configurations can be added to the plans (remember to give them a unique "data_identifier"!)
Folder structures are different and more user-friendly:
- nnUNet_preprocessed
  - By default, preprocessed data is now saved as: `nnUNet_preprocessed/DATASET_NAME/PLANS_IDENTIFIER_CONFIGURATION` to clearly link it to its corresponding plans and configuration
  - Name of the folder containing the preprocessed images can be adapted with the `data_identifier` key.
- nnUNet_results
  - Results are now sorted as follows: DATASET_NAME/TRAINERCLASS__PLANSIDENTIFIER__CONFIGURATION/FOLD
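To illustrate what configuration inheritance means for the effective settings, here is a minimal sketch (the function `resolve_configuration` is hypothetical and not nnU-Net's own code). It assumes a shallow, key-by-key override where the child's keys win; nnU-Net's `PlansManager` may differ in details. The configuration values are illustrative only. Note how the child keeps the parent's `data_identifier`, which is why inheriting configurations can share preprocessed data.

```python
from typing import Any, Dict, List, Optional


def resolve_configuration(configurations: Dict[str, Dict[str, Any]], name: str) -> Dict[str, Any]:
    """Merge a configuration with its 'inherits_from' ancestors; keys set by the child win."""
    chain: List[str] = []
    current: Optional[str] = name
    while current is not None:
        if current in chain:
            raise ValueError("inheritance cycle: " + " -> ".join(chain + [current]))
        chain.append(current)
        current = configurations[current].get("inherits_from")

    resolved: Dict[str, Any] = {}
    # apply settings from the root ancestor down to the requested configuration
    for config_name in reversed(chain):
        resolved.update({k: v for k, v in configurations[config_name].items() if k != "inherits_from"})
    return resolved


# illustrative values, not taken from a real plans file
configs = {
    "3d_fullres": {"data_identifier": "nnUNetPlans_3d_fullres", "batch_size": 2, "patch_size": [128, 128, 128]},
    "3d_fullres_bs8": {"inherits_from": "3d_fullres", "batch_size": 8},
}
print(resolve_configuration(configs, "3d_fullres_bs8"))
# {'data_identifier': 'nnUNetPlans_3d_fullres', 'batch_size': 8, 'patch_size': [128, 128, 128]}
```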
## What other changes are planned and not yet implemented?
- Integration into MONAI (together with our friends at Nvidia)
- New pretrained weights for a large number of datasets (coming very soon)
[//]: # (- nnU-Net now also natively supports an **ignore label**. Pixels with this label will not contribute to the loss. )
[//]: # (Use this to learn from sparsely annotated data, or excluding irrelevant areas from training. Read more [here](ignore_label.md).)
================================================
FILE: documentation/competitions/AortaSeg24.md
================================================
Authors: \
Maximilian Rokuss, Michael Baumgartner, Yannick Kirchhoff, Klaus H. Maier-Hein*, Fabian Isensee*
*: equal contribution
Author Affiliations:\
Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg \
Helmholtz Imaging
# Introduction
This document describes our submission to the [AortaSeg24 Challenge](https://aortaseg24.grand-challenge.org/).
Our model is essentially an nnU-Net ResEnc L with modified data augmentation. We disable left/right mirroring and use the heavy data augmentation [DA5 Trainer](../../nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py). Training was performed on an A100 40GB GPU.
# Experiment Planning and Preprocessing
After converting the data into the [nnUNet format](../../../nnUNet/documentation/dataset_format.md) (either keep and just rename the .mha files or convert them to .nii.gz), you can run the preprocessing:
```bash
nnUNetv2_plan_and_preprocess -d 610 -c 3d_fullres -pl nnUNetPlannerResEncL -np 16
```
# Training
We train our model using:
```bash
nnUNetv2_train 610 3d_fullres all -p nnUNetResEncUNetLPlans -tr nnUNetTrainer_onlyMirror01_DA5
```
Models are trained from scratch. We train one model using all the images and a five-fold cross-validation ensemble for the submission.
We recommend increasing the number of processes used for data augmentation. Otherwise, you can run into CPU bottlenecks.
Use `export nnUNet_n_proc_DA=32` or higher (if your system permits!).
# Inference
For inference you can use the default [nnUNet inference functionalities](../../../nnUNet/documentation/how_to_use_nnunet.md). Specifically, once the training is finished, run:
```bash
nnUNetv2_predict_from_modelfolder -i INPUT_FOLDER -o OUTPUT_FOLDER -m MODEL_FOLDER -f all
```
for the single model trained on all the data and
```bash
nnUNetv2_predict_from_modelfolder -i INPUT_FOLDER -o OUTPUT_FOLDER -m MODEL_FOLDER
```
for the five-fold ensemble.
================================================
FILE: documentation/competitions/AutoPETII.md
================================================
# Look Ma, no code: fine tuning nnU-Net for the AutoPET II challenge by only adjusting its JSON plans
Please cite our paper :-*
```text
COMING SOON
```
## Intro
See the [Challenge Website](https://autopet-ii.grand-challenge.org/) for details on the challenge.
Our solution to this challenge requires no code changes at all. All we do is optimize nnU-Net's hyperparameters
(architecture, batch size, patch size) by modifying the nnUNetPlans.json file.
## Prerequisites
Use the latest pytorch version!
We recommend you use the latest nnU-Net version as well! We ran our trainings with commit 913705f which you can try in case something doesn't work as expected:
`pip install git+https://github.com/MIC-DKFZ/nnUNet.git@913705f`
## How to reproduce our trainings
### Download and convert the data
1. Download and extract the AutoPET II dataset
2. Convert it to nnU-Net format by running `python nnunetv2/dataset_conversion/Dataset221_AutoPETII_2023.py FOLDER`, where FOLDER is the extracted AutoPET II dataset.
### Experiment planning and preprocessing
We deviate a little from the standard nnU-Net procedure because all our experiments are based on just the 3d_fullres configuration.
Run the following commands:
- `nnUNetv2_extract_fingerprint -d 221` extracts the dataset fingerprint
- `nnUNetv2_plan_experiment -d 221` does the planning for the plain unet
- `nnUNetv2_plan_experiment -d 221 -pl ResEncUNetPlanner` does the planning for the residual encoder unet
- `nnUNetv2_preprocess -d 221 -c 3d_fullres` runs all the preprocessing we need
### Modification of plans files
Please read the [information on how to modify plans files](../explanation_plans_files.md) first!!!
It is easier to have everything in one plans file, so the first thing we do is transfer the ResEnc UNet to the
default plans file. We use the configuration inheritance feature of nnU-Net to make it use the same data as the
3d_fullres configuration.
Add the following to the 'configurations' dict in 'nnUNetPlans.json':
```json
"3d_fullres_resenc": {
    "inherits_from": "3d_fullres",
    "network_arch_class_name": "ResidualEncoderUNet",
    "n_conv_per_stage_encoder": [1, 3, 4, 6, 6, 6],
    "n_conv_per_stage_decoder": [1, 1, 1, 1, 1]
},
```
(these values are basically just copied from the 'nnUNetResEncUNetPlans.json' file! With everything redundant being omitted thanks to inheritance from 3d_fullres)
Now we crank up the patch and batch sizes. Add the following configurations:
```json
"3d_fullres_resenc_bs80": {
    "inherits_from": "3d_fullres_resenc",
    "batch_size": 80
},
"3d_fullres_resenc_192x192x192_b24": {
    "inherits_from": "3d_fullres_resenc",
    "patch_size": [192, 192, 192],
    "batch_size": 24
}
```
Save the file (and check for potential syntax errors!)
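If you prefer not to edit the JSON by hand, the same three configurations can be added with a small standard-library script. This is a minimal sketch, not an nnU-Net tool: the names `NEW_CONFIGURATIONS`, `add_configurations` and `update_plans_file` are hypothetical, and it assumes the plans file stores its configurations under the top-level `configurations` key, as described above. Because `json.dumps` always emits valid JSON, the result cannot contain syntax errors.

```python
import json
from pathlib import Path

# the three configurations from the JSON snippets above
NEW_CONFIGURATIONS = {
    "3d_fullres_resenc": {
        "inherits_from": "3d_fullres",
        "network_arch_class_name": "ResidualEncoderUNet",
        "n_conv_per_stage_encoder": [1, 3, 4, 6, 6, 6],
        "n_conv_per_stage_decoder": [1, 1, 1, 1, 1],
    },
    "3d_fullres_resenc_bs80": {"inherits_from": "3d_fullres_resenc", "batch_size": 80},
    "3d_fullres_resenc_192x192x192_b24": {
        "inherits_from": "3d_fullres_resenc",
        "patch_size": [192, 192, 192],
        "batch_size": 24,
    },
}


def add_configurations(plans: dict, new_configs: dict) -> dict:
    """Add new_configs to plans['configurations'], checking that every parent exists."""
    configurations = plans["configurations"]
    for name, cfg in new_configs.items():
        parent = cfg.get("inherits_from")
        if parent is not None and parent not in configurations and parent not in new_configs:
            raise KeyError(f"{name!r} inherits from unknown configuration {parent!r}")
        configurations[name] = cfg
    return plans


def update_plans_file(plans_file: Path) -> None:
    plans = json.loads(plans_file.read_text())
    add_configurations(plans, NEW_CONFIGURATIONS)
    # json.dumps always produces syntactically valid JSON
    plans_file.write_text(json.dumps(plans, indent=4))
```

Call `update_plans_file` with the path to the `nnUNetPlans.json` in your dataset's nnUNet_preprocessed folder. Re-running it is harmless, since existing entries with the same names are simply overwritten.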
### Run trainings
Training each model requires 8 Nvidia A100 40GB GPUs. Expect training to run for 5-7 days. You'll need a really good
CPU to handle the data augmentation! 128C/256T are a must! If you have fewer threads available, scale down nnUNet_n_proc_DA accordingly.
```bash
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_bs80 0 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_bs80 1 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_bs80 2 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_bs80 3 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_bs80 4 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_192x192x192_b24 0 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_192x192x192_b24 1 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_192x192x192_b24 2 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_192x192x192_b24 3 -num_gpus 8
nnUNet_compile=T nnUNet_n_proc_DA=28 nnUNetv2_train 221 3d_fullres_resenc_192x192x192_b24 4 -num_gpus 8
```
Done!
(We also provide pretrained weights in case you don't want to invest the GPU resources, see below)
## How to make predictions with pretrained weights
Our final model is an ensemble of two configurations:
- ResEnc UNet with batch size 80
- ResEnc UNet with patch size 192x192x192 and batch size 24
To run inference with these models, do the following:
1. Download the pretrained model weights from [Zenodo](https://zenodo.org/record/8362371)
2. Install both .zip files using `nnUNetv2_install_pretrained_model_from_zip`
3. Make sure
4. Now you can run inference on new cases with `nnUNetv2_predict`:
   - `nnUNetv2_predict -i INPUT -o OUTPUT1 -d 221 -c 3d_fullres_resenc_bs80 -f 0 1 2 3 4 -step_size 0.6 --save_probabilities`
   - `nnUNetv2_predict -i INPUT -o OUTPUT2 -d 221 -c 3d_fullres_resenc_192x192x192_b24 -f 0 1 2 3 4 --save_probabilities`
   - `nnUNetv2_ensemble -i OUTPUT1 OUTPUT2 -o OUTPUT_ENSEMBLE`
Note that our inference Docker omitted TTA via mirroring along the axial direction during prediction (only sagittal +
coronal mirroring). This was done to keep the inference time below 10 minutes per image on a T4 GPU (we actually never
tested whether we could have left this enabled). Just leave it on! You can also leave the step_size at its default for
the 3d_fullres_resenc_bs80 configuration.
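Why dropping one mirror axis saves so much time: mirroring TTA runs one un-mirrored forward pass plus one pass per non-empty subset of the mirror axes, so the cost grows as 2^k for k axes. The sketch below follows the `itertools.combinations` enumeration used in `_internal_maybe_mirror_and_predict` of nnU-Net's predictors. The helper names are hypothetical, and the axis indices are illustrative (which index corresponds to the axial direction depends on the dataset's transpose).

```python
import itertools
from typing import List, Tuple


def mirror_combinations(mirror_axes: Tuple[int, ...]) -> List[Tuple[int, ...]]:
    # every non-empty subset of the mirror axes is flipped once during TTA
    return [c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)]


def forward_passes_per_patch(mirror_axes: Tuple[int, ...]) -> int:
    # one un-mirrored pass plus one pass per subset, i.e. 2 ** len(mirror_axes)
    return len(mirror_combinations(mirror_axes)) + 1


print(forward_passes_per_patch((0, 1, 2)))  # 8: mirroring along all three axes
print(forward_passes_per_patch((1, 2)))     # 4: one axis left out
```

Leaving out a single axis therefore halves the number of network evaluations per sliding-window patch.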
================================================
FILE: documentation/competitions/FLARE24/Task_1/__init__.py
================================================
================================================
FILE: documentation/competitions/FLARE24/Task_1/inference_flare_task1.py
================================================
from typing import Union, Tuple
import argparse
import numpy as np
import os
from os.path import join
from pathlib import Path
from time import time
import torch
from torch._dynamo import OptimizedModule
from nnunetv2.utilities.label_handling.label_handling import LabelManager
from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice
from batchgenerators.utilities.file_and_folder_operations import load_json
import nnunetv2
from nnunetv2.configuration import default_num_processes
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.imageio.nibabel_reader_writer import NibabelIOWithReorient
class FlarePredictor(nnUNetPredictor):
    def initialize_from_trained_model_folder(self, model_training_output_dir: str,
                                             use_folds: Union[Tuple[Union[int, str]], None],
                                             checkpoint_name: str = 'checkpoint_final.pth'):
        """
        This is used when making predictions with a trained model
        """
        if use_folds is None:
            use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name)

        dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
        plans = load_json(join(model_training_output_dir, 'plans.json'))
        plans_manager = PlansManager(plans)

        if isinstance(use_folds, str):
            use_folds = [use_folds]

        parameters = []
        for i, f in enumerate(use_folds):
            f = int(f) if f != 'all' else f
            checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name),
                                    map_location=torch.device('cpu'))
            if i == 0:
                trainer_name = checkpoint['trainer_name']
                configuration_name = checkpoint['init_args']['configuration']
                inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \
                    'inference_allowed_mirroring_axes' in checkpoint.keys() else None

            parameters.append(checkpoint['network_weights'])

        configuration_manager = plans_manager.get_configuration(configuration_name)
        # restore network
        num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
        trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                    trainer_name, 'nnunetv2.training.nnUNetTrainer')
        if trainer_class is None:
            raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
                               f'Please place it there (in any .py file)!')
        network = trainer_class.build_network_architecture(
            configuration_manager.network_arch_class_name,
            configuration_manager.network_arch_init_kwargs,
            configuration_manager.network_arch_init_kwargs_req_import,
            num_input_channels,
            plans_manager.get_label_manager(dataset_json).num_segmentation_heads,
            enable_deep_supervision=False
        )

        self.plans_manager = plans_manager
        self.configuration_manager = configuration_manager
        self.list_of_parameters = parameters
        self.network = network
        self.dataset_json = dataset_json
        self.trainer_name = trainer_name
        self.allowed_mirroring_axes = inference_allowed_mirroring_axes
        self.label_manager = plans_manager.get_label_manager(dataset_json)
        if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \
                and not isinstance(self.network, OptimizedModule):
            self.network = torch.compile(self.network)
def convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits: Union[torch.Tensor, np.ndarray],
                                                                plans_manager: PlansManager,
                                                                configuration_manager: ConfigurationManager,
                                                                label_manager: LabelManager,
                                                                properties_dict: dict,
                                                                return_probabilities: bool = False,
                                                                num_threads_torch: int = default_num_processes):
    old_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads_torch)

    # resample to original shape
    current_spacing = configuration_manager.spacing if \
        len(configuration_manager.spacing) == \
        len(properties_dict['shape_after_cropping_and_before_resampling']) else \
        [properties_dict['spacing'][0], *configuration_manager.spacing]
    predicted_logits = configuration_manager.resampling_fn_probabilities(predicted_logits,
                                                                         properties_dict['shape_after_cropping_and_before_resampling'],
                                                                         current_spacing,
                                                                         properties_dict['spacing'])
    # return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because
    # apply_inference_nonlin will convert to torch
    # predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits)
    # argmax over the logits yields the same labels as argmax over the softmax probabilities (softmax is
    # monotonic), so the inference nonlinearity is skipped to save time
    segmentation = predicted_logits.argmax(0)
    del predicted_logits
    # segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities)

    # segmentation may be torch.Tensor but we continue with numpy
    if isinstance(segmentation, torch.Tensor):
        segmentation = segmentation.cpu().numpy()

    # put segmentation in bbox (revert cropping)
    segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'],
                                              dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16)
    slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping'])
    segmentation_reverted_cropping[slicer] = segmentation
    del segmentation

    # revert transpose
    segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(plans_manager.transpose_backward)
    torch.set_num_threads(old_threads)
    return segmentation_reverted_cropping
def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, torch.Tensor], properties_dict: dict,
                                  configuration_manager: ConfigurationManager,
                                  plans_manager: PlansManager,
                                  dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str,
                                  save_probabilities: bool = False):
    if isinstance(dataset_json_dict_or_file, str):
        dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)

    label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)
    ret = convert_predicted_logits_to_segmentation_with_correct_shape(
        predicted_array_or_file, plans_manager, configuration_manager, label_manager, properties_dict,
        return_probabilities=save_probabilities
    )
    del predicted_array_or_file

    segmentation_final = ret

    rw = NibabelIOWithReorient()
    rw.write_seg(segmentation_final, output_file_truncated + dataset_json_dict_or_file['file_ending'],
                 properties_dict)
def predict_flare(input_dir, output_dir, model_folder, folds=("all",)):
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    input_files = sorted(input_dir.glob("*.nii.gz"))
    # strip the 12-character channel suffix '_0000.nii.gz' to get the truncated output file name
    output_files = [str(output_dir / f.name[:-12]) for f in input_files]

    for input_file, output_file in zip(input_files, output_files):
        print(f"Predicting {input_file.name}")
        start = time()

        plans_manager = PlansManager(load_json(join(model_folder, 'plans.json')))
        configuration_manager = plans_manager.get_configuration("3d_fullres")
        dataset_json = load_json(join(model_folder, 'dataset.json'))
        rw = NibabelIOWithReorient()
        image, props = rw.read_images([input_file,])

        with torch.no_grad():
            predictor = FlarePredictor(tile_step_size=0.5, use_mirroring=False)
            predictor.initialize_from_trained_model_folder(model_folder, use_folds=folds)

            preprocessor = configuration_manager.preprocessor_class(verbose=False)
            data, _ = preprocessor.run_case_npy(image,
                                                None,
                                                props,
                                                plans_manager,
                                                configuration_manager,
                                                dataset_json)

            data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)

            predicted_logits = predictor.predict_logits_from_preprocessed_data(data).cpu()

        export_prediction_from_logits(predicted_logits, props, configuration_manager,
                                      plans_manager, dataset_json, output_file,
                                      False)
        print(f"Prediction time: {time() - start:.2f}s")
if __name__ == '__main__':
    os.environ['nnUNet_compile'] = 'f'

    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", default="/workspace/inputs")
    parser.add_argument("-o", "--output", default="/workspace/outputs")
    parser.add_argument("-m", "--model", default="/opt/app/_trained_model")
    parser.add_argument("-f", "--folds", nargs="+", default=["all"])
    args = parser.parse_args()
    predict_flare(args.input, args.output, args.model, args.folds)
================================================
FILE: documentation/competitions/FLARE24/Task_1/readme.md
================================================
Authors: \
Yannick Kirchhoff, Maximilian Rouven Rokuss, Benjamin Hamm, Ashis Ravindran, Constantin Ulrich, Klaus Maier-Hein†, Fabian Isensee†
†: equal contribution
# Introduction
This document describes our contribution to [Task 1 of the FLARE24 Challenge](https://www.codabench.org/competitions/2319/).
Our model is basically a default nnU-Net trained with larger batch sizes of 4 and 8. We submitted the batch size 8 model and an ensemble of the batch size 4 and batch size 8 models to the final test set.
# Experiment Planning and Preprocessing
Bring the downloaded data into the [nnU-Net format](../../../nnUNet/documentation/dataset_format.md) and add the dataset.json file as given here:
```json
{
    "name": "Dataset301_FLARE24Task1_labeled",
    "description": "Pan Cancer Segmentation",
    "labels": {
        "background": 0,
        "lesion": 1
    },
    "file_ending": ".nii.gz",
    "channel_names": {
        "0": "CT"
    },
    "numTraining": 5000
}
```
Afterwards, you can run the default nnU-Net planning and preprocessing:
```bash
nnUNetv2_plan_and_preprocess -d 301 -c 3d_fullres
```
## Edit the plans files
In the generated `nnUNetPlans.json` file, add the following configurations:
```json
"3d_fullres_bs4": {
    "inherits_from": "3d_fullres",
    "batch_size": 4
},
"3d_fullres_bs8": {
    "inherits_from": "3d_fullres",
    "batch_size": 8
},
"3d_fullres_bs4u8": {
    "inherits_from": "3d_fullres",
    "batch_size": 48
}
```
Note that the last one is only used for the ensemble model during inference!
# Model training
Run the following commands to train the models with batch sizes 4 and 8. The larger batch size helps stabilize training despite the partial labels present in the dataset and helps cope with the large number of scans. We therefore keep the number of epochs at 1000.
```bash
nnUNetv2_train 301 3d_fullres_bs4 all
nnUNetv2_train 301 3d_fullres_bs8 all
```
# Inference
Our inference is optimized for efficient single scan prediction. For best performance, we strongly recommend running inference using the default `nnUNetv2_predict` command!
To run inference with the ensemble model, create a folder called `nnUNetTrainer__nnUNetPlans__3d_fullres_bs4u8` in the results folder. Copy `dataset.json`, `dataset_fingerprint.json` and `plans.json` from one of the other results folders into it, and copy the `fold_all` folders of both trainings into it as `fold_0` and `fold_1`, respectively. This allows for easy ensembling of both models.
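The folder setup for the ensemble can be scripted. The following is a minimal sketch, not part of the challenge code: `build_ensemble_folder` is a hypothetical helper, and it assumes the standard `nnUNet_results/DATASET_NAME/TRAINER__PLANS__CONFIGURATION` layout, with the batch size 4 training mapped to `fold_0` and the batch size 8 training to `fold_1`. The JSON files are taken from the batch size 8 folder; either folder should work, since both were trained from the same plans.

```python
import shutil
from pathlib import Path


def build_ensemble_folder(results_dir: Path, dataset_name: str) -> Path:
    base = results_dir / dataset_name
    src_bs4 = base / "nnUNetTrainer__nnUNetPlans__3d_fullres_bs4"
    src_bs8 = base / "nnUNetTrainer__nnUNetPlans__3d_fullres_bs8"
    target = base / "nnUNetTrainer__nnUNetPlans__3d_fullres_bs4u8"
    target.mkdir(parents=True, exist_ok=True)

    # metadata needed by the predictor
    for fname in ("dataset.json", "dataset_fingerprint.json", "plans.json"):
        shutil.copy2(src_bs8 / fname, target / fname)

    # each 'all' training becomes one "fold" of the ensemble
    shutil.copytree(src_bs4 / "fold_all", target / "fold_0", dirs_exist_ok=True)
    shutil.copytree(src_bs8 / "fold_all", target / "fold_1", dirs_exist_ok=True)
    return target
```

Copying keeps the original training folders untouched; symlinking the `fold_all` folders instead would save disk space.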
To run inference, simply run the following command with `folds` set to `all` for single-model inference or `0 1` for the ensemble. `model_folder` is the folder containing the training results, for example `nnUNetTrainer__nnUNetPlans__3d_fullres_bs8`.
```bash
python inference_flare_task1.py -i input_folder -o output_folder -m model_folder -f folds
```
================================================
FILE: documentation/competitions/FLARE24/Task_2/__init__.py
================================================
================================================
FILE: documentation/competitions/FLARE24/Task_2/inference_flare_task2.py
================================================
from typing import Union, List, Tuple
import argparse
import itertools
import multiprocessing
import numpy as np
import os
from os.path import join
from pathlib import Path
from time import time
import torch
from torch._dynamo import OptimizedModule
from tqdm import tqdm
import openvino as ov
import openvino.properties.hint as hints
from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice
from acvl_utils.cropping_and_padding.padding import pad_nd_image
from batchgenerators.utilities.file_and_folder_operations import load_json
from nnunetv2.utilities.label_handling.label_handling import LabelManager
import nnunetv2
from nnunetv2.configuration import default_num_processes
from nnunetv2.inference.data_iterators import PreprocessAdapterFromNpy
from nnunetv2.inference.sliding_window_prediction import compute_gaussian
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.helpers import empty_cache, dummy_context
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
torch.set_num_threads(multiprocessing.cpu_count())
class FlarePredictor(nnUNetPredictor):
    def __init__(self,
                 tile_step_size: float = 0.5,
                 use_gaussian: bool = True,
                 use_mirroring: bool = True,
                 perform_everything_on_device: bool = False,
                 device: torch.device = torch.device('cpu'),
                 verbose: bool = False,
                 verbose_preprocessing: bool = False,
                 allow_tqdm: bool = True,
                 ):
        super().__init__(tile_step_size, use_gaussian, use_mirroring, perform_everything_on_device, device, verbose,
                         verbose_preprocessing, allow_tqdm)
        # the network is executed with OpenVINO on the CPU; checking .type also catches 'cuda:0', 'cuda:1', ...
        if self.device.type == 'cuda':
            raise RuntimeError('CUDA is not supported for this task')
    def initialize_from_trained_model_folder(self, model_training_output_dir: str,
                                             use_folds: Union[Tuple[Union[int, str]], None],
                                             checkpoint_name: str = 'checkpoint_final.pth',
                                             save_model: bool = True):
        """
        This is used when making predictions with a trained model
        """
        if use_folds is None:
            use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name)

        dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
        plans = load_json(join(model_training_output_dir, 'plans.json'))
        plans_manager = PlansManager(plans)

        if isinstance(use_folds, str):
            use_folds = [use_folds]
        assert len(use_folds) == 1, 'Only one fold is supported for this task'

        parameters = []
        for i, f in enumerate(use_folds):
            f = int(f) if f != 'all' else f
            checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name),
                                    map_location=torch.device('cpu'))
            if i == 0:
                trainer_name = checkpoint['trainer_name']
                configuration_name = checkpoint['init_args']['configuration']
                inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \
                    'inference_allowed_mirroring_axes' in checkpoint.keys() else None
            if save_model:
                parameters.append(checkpoint['network_weights'])

        configuration_manager = plans_manager.get_configuration(configuration_name)
        if save_model:
            num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
            trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                        trainer_name, 'nnunetv2.training.nnUNetTrainer')
            if trainer_class is None:
                raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
                                   f'Please place it there (in any .py file)!')
            network = trainer_class.build_network_architecture(
                configuration_manager.network_arch_class_name,
                configuration_manager.network_arch_init_kwargs,
                configuration_manager.network_arch_init_kwargs_req_import,
                num_input_channels,
                plans_manager.get_label_manager(dataset_json).num_segmentation_heads,
                enable_deep_supervision=False
            )
            self.network = network
        self.allowed_mirroring_axes = inference_allowed_mirroring_axes
        self.plans_manager = plans_manager
        self.configuration_manager = configuration_manager
        self.dataset_json = dataset_json
        self.label_manager = plans_manager.get_label_manager(dataset_json)
        if save_model:
            if not isinstance(self.network, OptimizedModule):
                self.network.load_state_dict(parameters[0])
            else:
                self.network._orig_mod.load_state_dict(parameters[0])
            self.network.eval()

        config = {hints.performance_mode: hints.PerformanceMode.LATENCY,
                  hints.enable_cpu_pinning(): True,
                  }
        core = ov.Core()
        core.set_property(
            "CPU",
            {hints.execution_mode: hints.ExecutionMode.PERFORMANCE},
        )
        if save_model:
            # export mode: convert the PyTorch network to OpenVINO IR (model.xml) next to the checkpoints and stop.
            # Run again with save_model=False to predict with the exported model.
            input_tensor = torch.randn(1, num_input_channels, *configuration_manager.patch_size, requires_grad=False)
            ov_model = ov.convert_model(self.network, example_input=input_tensor)
            ov.save_model(ov_model, f"{model_training_output_dir}/model.xml")
            import sys
            sys.exit(0)
        ov_model = core.read_model(f"{model_training_output_dir}/model.xml")
        self.network = core.compile_model(ov_model, "CPU", config=config)
        # self.network is now an OpenVINO compiled model, so the OpenVINO code paths must be used
        self.use_openvino = True
    def predict_from_files(self,
                           list_of_lists_or_source_folder: Union[str, List[List[str]]],
                           output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]],
                           save_probabilities: bool = False,
                           overwrite: bool = True,
                           num_processes_preprocessing: int = 1,
                           num_processes_segmentation_export: int = 1,
                           folder_with_segs_from_prev_stage: str = None,
                           num_parts: int = 1,
                           part_id: int = 0):
        """
        This is nnU-Net's default function for making predictions. It works best for batch predictions
        (predicting many images at once).
        """
        list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \
            self._manage_input_and_output_lists(list_of_lists_or_source_folder,
                                                output_folder_or_list_of_truncated_output_files,
                                                folder_with_segs_from_prev_stage, overwrite, part_id, num_parts,
                                                save_probabilities)
        if len(list_of_lists_or_source_folder) == 0:
            return

        data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder,
                                                                                 seg_from_prev_stage_files,
                                                                                 output_filename_truncated,
                                                                                 num_processes_preprocessing)

        return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export)
    def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:
        mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None
        if self.use_openvino:
            # the OpenVINO compiled model returns its outputs as a dict-like object; [0] selects the first output
            prediction = torch.from_numpy(self.network(x)[0])
        else:
            prediction = self.network(x)

        if mirror_axes is not None:
            # x should be 4D for 2D images and 5D for 3D images
            assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'

            mirror_axes = [m + 2 for m in mirror_axes]
            axes_combinations = [
                c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
            ]
            for axes in axes_combinations:
                if not self.use_openvino:
                    prediction += torch.flip(self.network(torch.flip(x, axes)), axes)
                else:
                    temp_pred = torch.from_numpy(self.network(torch.flip(x, axes))[0])
                    prediction += torch.flip(temp_pred, axes)
            prediction /= (len(axes_combinations) + 1)
        return prediction
    def _internal_predict_sliding_window_return_logits(self,
                                                       data: torch.Tensor,
                                                       slicers,
                                                       do_on_device: bool = True,
                                                       ):
        predicted_logits = n_predictions = prediction = gaussian = workon = None
        results_device = self.device if do_on_device else torch.device('cpu')

        try:
            empty_cache(self.device)

            # move data to device
            if self.verbose:
                print(f'move image to device {results_device}')
            data = data.to(results_device)

            # preallocate arrays
            if self.verbose:
                print(f'preallocating results arrays on device {results_device}')
            predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
                                           dtype=torch.half,
                                           device=results_device)
            n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
            if self.use_gaussian:
                gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
                                            value_scaling_factor=10,
                                            device=results_device)
            else:
                gaussian = 1

            if not self.allow_tqdm and self.verbose:
                print(f'running prediction: {len(slicers)} steps')
            for sl in tqdm(slicers, disable=not self.allow_tqdm):
                workon = data[sl][None]
                workon = workon.to(self.device)

                prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)

                if self.use_gaussian:
                    prediction *= gaussian
                predicted_logits[sl] += prediction
                n_predictions[sl[1:]] += gaussian

            predicted_logits /= n_predictions
            # check for infs
            if torch.any(torch.isinf(predicted_logits)):
                raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
                                   'reduce value_scaling_factor in compute_gaussian or increase the dtype of '
                                   'predicted_logits to fp32')
        except Exception as e:
            del predicted_logits, n_predictions, prediction, gaussian, workon
            empty_cache(self.device)
            empty_cache(results_device)
            raise e
        return predicted_logits
def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
"""
IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE!
RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE.
SEE convert_predicted_logits_to_segmentation_with_correct_shape
"""
prediction = None
if not self.use_openvino:
for params in self.list_of_parameters:
# messing with state dict names...
if not isinstance(self.network, OptimizedModule):
self.network.load_state_dict(params)
else:
self.network._orig_mod.load_state_dict(params)
# why not leave prediction on device if perform_everything_on_device? Because this may cause the
# second iteration to crash due to OOM. Catching that with try/except would bloat the code more than
# it would save in computation time
if prediction is None:
prediction = self.predict_sliding_window_return_logits(data).to('cpu')
else:
prediction += self.predict_sliding_window_return_logits(data).to('cpu')
if len(self.list_of_parameters) > 1:
prediction /= len(self.list_of_parameters)
else:
if prediction is None:
prediction = self.predict_sliding_window_return_logits(data)
else:
prediction += self.predict_sliding_window_return_logits(data)
if self.verbose: print('Prediction done')
return prediction
@torch.inference_mode()
def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
-> Union[np.ndarray, torch.Tensor]:
assert isinstance(input_image, torch.Tensor)
if self.device not in [torch.device('cpu'), 'cpu']:
self.network = self.network.to(self.device)
self.network.eval()
empty_cache(self.device)
# Autocast can be annoying
# If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection)
# and needs to be disabled.
# If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False
# is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
# So autocast will only be active if we have a cuda device.
with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'
if self.verbose:
print(f'Input shape: {input_image.shape}')
print("step_size:", self.tile_step_size)
print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)
# if input_image is smaller than tile_size we need to pad it to tile_size.
data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
'constant', {'value': 0}, True,
None)
slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
if self.perform_everything_on_device and self.device != 'cpu':
# we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device
try:
predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
self.perform_everything_on_device)
except RuntimeError:
print(
'Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU')
empty_cache(self.device)
predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False)
else:
predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
self.perform_everything_on_device)
empty_cache(self.device)
# revert padding
predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
return predicted_logits
def convert_predicted_logits_to_segmentation_with_correct_shape(self, predicted_logits: Union[torch.Tensor, np.ndarray],
plans_manager: PlansManager,
configuration_manager: ConfigurationManager,
label_manager: LabelManager,
properties_dict: dict,
return_probabilities: bool = False,
num_threads_torch: int = default_num_processes):
# resample to original shape
current_spacing = configuration_manager.spacing if \
len(configuration_manager.spacing) == \
len(properties_dict['shape_after_cropping_and_before_resampling']) else \
[properties_dict['spacing'][0], *configuration_manager.spacing]
predicted_logits = configuration_manager.resampling_fn_probabilities(predicted_logits,
properties_dict['shape_after_cropping_and_before_resampling'],
current_spacing,
properties_dict['spacing'])
segmentation = predicted_logits.argmax(0)
del predicted_logits
# segmentation may be torch.Tensor but we continue with numpy
if isinstance(segmentation, torch.Tensor):
segmentation = segmentation.cpu().numpy()
# put segmentation in bbox (revert cropping)
segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'],
dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16)
slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping'])
segmentation_reverted_cropping[slicer] = segmentation
del segmentation
# revert transpose
segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(plans_manager.transpose_backward)
return segmentation_reverted_cropping
def export_prediction_from_logits(self, predicted_array_or_file: Union[np.ndarray, torch.Tensor], properties_dict: dict,
configuration_manager: ConfigurationManager,
plans_manager: PlansManager,
dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str,
save_probabilities: bool = False):
if isinstance(dataset_json_dict_or_file, str):
dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)
label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)
ret = self.convert_predicted_logits_to_segmentation_with_correct_shape(
predicted_array_or_file, plans_manager, configuration_manager, label_manager, properties_dict,
return_probabilities=save_probabilities
)
del predicted_array_or_file
segmentation_final = ret
rw = plans_manager.image_reader_writer_class()
rw.write_seg(segmentation_final, output_file_truncated + dataset_json_dict_or_file['file_ending'],
properties_dict)
def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,
segmentation_previous_stage: np.ndarray = None,
output_file_truncated: str = None,
save_or_return_probabilities: bool = False):
"""
WARNING: SLOW. ONLY USE THIS IF YOU CANNOT GIVE NNUNET MULTIPLE IMAGES AT ONCE FOR SOME REASON.
input_image: Make sure to load the image in the way nnU-Net expects! nnU-Net is trained on a certain axis
ordering which cannot be disturbed in inference,
otherwise you will get bad results. The easiest way to achieve that is to use the same I/O class
for loading images as was used during nnU-Net preprocessing! You can find that class in your
plans.json file under the key "image_reader_writer". If you decide to freestyle, know that the
default axis ordering for medical images is the one from SimpleITK. If you load with nibabel,
you need to transpose your axes AND your spacing from [x,y,z] to [z,y,x]!
image_properties must only have a 'spacing' key!
"""
ppa = PreprocessAdapterFromNpy([input_image], [segmentation_previous_stage], [image_properties],
[output_file_truncated],
self.plans_manager, self.dataset_json, self.configuration_manager,
num_threads_in_multithreaded=1, verbose=self.verbose)
if self.verbose:
print('preprocessing')
dct = next(ppa)
if self.verbose:
print('predicting')
predicted_logits = self.predict_logits_from_preprocessed_data(dct['data']).cpu()
if self.verbose:
print('resampling to original shape')
self.export_prediction_from_logits(predicted_logits, dct['data_properties'], self.configuration_manager,
self.plans_manager, self.dataset_json, output_file_truncated,
save_or_return_probabilities)
def predict_flare(input_dir, output_dir, model_folder, save_model):
input_dir = Path(input_dir)
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True, parents=True)
input_files = list(input_dir.glob("*.nii.gz"))
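# input files are named {CASE_IDENTIFIER}_0000.nii.gz; strip that 12-character suffix to get the truncated output name
output_files = [str(output_dir / f.name[:-len("_0000.nii.gz")]) for f in input_files]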
for input_file, output_file in zip(input_files, output_files):
print(f"Predicting {input_file.name}")
start = time()
predictor = FlarePredictor(tile_step_size=0.5, use_mirroring=False, device=torch.device("cpu"))
predictor.initialize_from_trained_model_folder(model_folder, ("all",), save_model=save_model)
rw = predictor.plans_manager.image_reader_writer_class()
image, props = rw.read_images([input_file,])
_ = predictor.predict_single_npy_array(image, props, None, output_file, False)
print(f"Prediction time: {time() - start:.2f}s")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", default="/workspace/inputs")
parser.add_argument("-o", "--output", default="/workspace/outputs")
parser.add_argument("-m", "--model", default="/opt/app/_trained_model")
parser.add_argument("-save_model", action="store_true")
args = parser.parse_args()
predict_flare(args.input, args.output, args.model, args.save_model)
================================================
FILE: documentation/competitions/FLARE24/Task_2/readme.md
================================================
Authors: \
Yannick Kirchhoff*, Ashis Ravindran*, Maximilian Rouven Rokuss, Benjamin Hamm, Constantin Ulrich, Klaus Maier-Hein†, Fabian Isensee†
*: equal contribution \
†: equal contribution
# Introduction
This document describes our contribution to [Task 2 of the FLARE24 Challenge](https://www.codabench.org/competitions/2320/).
Our model is essentially a default nnU-Net with a custom low resolution setting and OpenVINO optimizations for faster CPU inference.
# Experiment Planning and Preprocessing
Bring the downloaded data into the [nnU-Net format](../../../dataset_format.md) and add the dataset.json file as given here:
```json
{
"name": "Dataset311_FLARE24Task2_labeled",
"description": "Abdominal Organ Segmentation",
"labels": {
"background": 0,
"liver": 1,
"right kidney": 2,
"spleen": 3,
"pancreas": 4,
"aorta": 5,
"ivc": 6,
"rag": 7,
"lag": 8,
"gallbladder": 9,
"esophagus": 10,
"stomach": 11,
"duodenum": 12,
"left kidney": 13
},
"file_ending": ".nii.gz",
"channel_names": {
"0": "CT"
},
"overwrite_image_reader_writer": "NibabelIOWithReorient",
"numTraining": 50
}
```
Afterwards, you can run the default nnU-Net planning and preprocessing:
```bash
nnUNetv2_plan_and_preprocess -d 311 -c 3d_fullres
```
## Edit the plans files
The generated `nnUNetPlans.json` file needs to be edited to incorporate the custom low resolution settings. Add the following entries to its `configurations` section:
```json
"3d_halfres": {
"inherits_from": "3d_fullres",
"data_identifier": "nnUNetPlans_3d_halfres",
"spacing": [
5,
1.6,
1.6
]
},
"3d_halfiso": {
"inherits_from": "3d_fullres",
"data_identifier": "nnUNetPlans_3d_halfiso",
"spacing": [
2.5,
2.5,
2.5
]
},
```
`3d_halfres` is a configuration with exactly half the resolution of `3d_fullres` and was used as an ablation for our submission. `3d_halfiso` is the isotropic configuration we submitted as our final solution.
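If you prefer not to edit the JSON by hand, the same change can be scripted. The following is a minimal sketch, assuming the configurations live under the top-level `configurations` key of `nnUNetPlans.json` (as in nnU-Net v2 plans files, typically found in `nnUNet_preprocessed/Dataset311_FLARE24Task2_labeled/`). `add_custom_configurations` and `patch_plans_file` are hypothetical helpers, not part of nnU-Net.

```python
import copy
import json
from pathlib import Path

# Custom low resolution configurations from this readme. Both inherit everything
# from 3d_fullres and only override the data identifier and the target spacing.
CUSTOM_CONFIGURATIONS = {
    "3d_halfres": {
        "inherits_from": "3d_fullres",
        "data_identifier": "nnUNetPlans_3d_halfres",
        "spacing": [5, 1.6, 1.6],
    },
    "3d_halfiso": {
        "inherits_from": "3d_fullres",
        "data_identifier": "nnUNetPlans_3d_halfiso",
        "spacing": [2.5, 2.5, 2.5],
    },
}


def add_custom_configurations(plans: dict) -> dict:
    """Return a copy of `plans` with the custom configurations added (the input is not modified)."""
    if "3d_fullres" not in plans["configurations"]:
        raise KeyError("the custom configurations inherit from 3d_fullres, which is missing")
    configurations = dict(plans["configurations"])
    for name, config in CUSTOM_CONFIGURATIONS.items():
        if name in configurations:
            raise ValueError(f"configuration {name!r} already exists")
        configurations[name] = copy.deepcopy(config)
    return {**plans, "configurations": configurations}


def patch_plans_file(plans_file: Path) -> None:
    """Add the custom configurations to a plans file on disk, rewriting it in place."""
    plans = json.loads(plans_file.read_text())
    plans_file.write_text(json.dumps(add_custom_configurations(plans), indent=4))
```

Calling `patch_plans_file` twice raises a `ValueError` instead of silently overwriting existing entries; keep a backup of the original plans file before patching it.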
# Model training
Run one of the following commands to train the respective configurations. `3d_halfiso` yielded significantly better results in our experiments as well as on the final test set and is the recommended configuration.
```bash
nnUNetv2_train 311 3d_halfres all
nnUNetv2_train 311 3d_halfiso all
```
# Inference
Our inference script is optimized for efficient single-scan prediction on CPU (it disables test-time mirroring and uses OpenVINO). For best performance outside of these constraints, we strongly recommend running inference using the default `nnUNetv2_predict` command!
Inference using the provided script requires OpenVINO, which can easily be installed via
```bash
pip install openvino
```
To run inference, use the following command. `model_folder` is the folder containing the training results, for example `nnUNetTrainer__nnUNetPlans__3d_halfiso`. Set `-save_model` (at least on the first run) to precompile the model once with OpenVINO. If no precompiled model exists, the inference script will fail!
```bash
python inference_flare_task2.py -i input_folder -o output_folder -m model_folder [-save_model]
```
================================================
FILE: documentation/competitions/FLARE24/__init__.py
================================================
================================================
FILE: documentation/competitions/Toothfairy2/__init__.py
================================================
================================================
FILE: documentation/competitions/Toothfairy2/inference_script_semseg_only_customInf2.py
================================================
import argparse
import gc
import os
from pathlib import Path
from queue import Queue
from threading import Thread
from typing import Union, Tuple
import nnunetv2
import numpy as np
import torch
from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice
from acvl_utils.cropping_and_padding.padding import pad_nd_image
from batchgenerators.utilities.file_and_folder_operations import load_json, join
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.inference.sliding_window_prediction import compute_gaussian
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.helpers import empty_cache, dummy_context
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
from torch._dynamo import OptimizedModule
from torch.backends import cudnn
from tqdm import tqdm
class CustomPredictor(nnUNetPredictor):
def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,
segmentation_previous_stage: np.ndarray = None):
torch.set_num_threads(7)
with torch.no_grad():
self.network = self.network.to(self.device)
self.network.eval()
if self.verbose:
print('preprocessing')
preprocessor = self.configuration_manager.preprocessor_class(verbose=self.verbose)
data, _, image_properties = preprocessor.run_case_npy(input_image, None, image_properties,
self.plans_manager,
self.configuration_manager,
self.dataset_json)
data = torch.from_numpy(data)
del input_image
if self.verbose:
print('predicting')
predicted_logits = self.predict_preprocessed_image(data)
if self.verbose: print('Prediction done')
segmentation = self.convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits,
image_properties)
return segmentation
def initialize_from_trained_model_folder(self, model_training_output_dir: str,
use_folds: Union[Tuple[Union[int, str]], None],
checkpoint_name: str = 'checkpoint_final.pth'):
"""
This is used when making predictions with a trained model
"""
if use_folds is None:
use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name)
dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
plans = load_json(join(model_training_output_dir, 'plans.json'))
plans_manager = PlansManager(plans)
if isinstance(use_folds, str):
use_folds = [use_folds]
parameters = []
for i, f in enumerate(use_folds):
f = int(f) if f != 'all' else f
checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name),
map_location=torch.device('cpu'), weights_only=False)
if i == 0:
trainer_name = checkpoint['trainer_name']
configuration_name = checkpoint['init_args']['configuration']
inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \
'inference_allowed_mirroring_axes' in checkpoint.keys() else None
parameters.append(join(model_training_output_dir, f'fold_{f}', checkpoint_name))
configuration_manager = plans_manager.get_configuration(configuration_name)
# restore network
num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
trainer_name, 'nnunetv2.training.nnUNetTrainer')
if trainer_class is None:
raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
f'Please place it there (in any .py file)!')
network = trainer_class.build_network_architecture(
configuration_manager.network_arch_class_name,
configuration_manager.network_arch_init_kwargs,
configuration_manager.network_arch_init_kwargs_req_import,
num_input_channels,
plans_manager.get_label_manager(dataset_json).num_segmentation_heads,
enable_deep_supervision=False
)
self.plans_manager = plans_manager
self.configuration_manager = configuration_manager
self.list_of_parameters = parameters
self.network = network
self.dataset_json = dataset_json
self.trainer_name = trainer_name
self.allowed_mirroring_axes = inference_allowed_mirroring_axes
self.label_manager = plans_manager.get_label_manager(dataset_json)
if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \
and not isinstance(self.network, OptimizedModule):
print('Using torch.compile')
self.network = torch.compile(self.network)
@torch.inference_mode(mode=True)
def predict_preprocessed_image(self, image):
empty_cache(self.device)
data_device = torch.device('cpu')
predicted_logits_device = torch.device('cpu')
gaussian_device = torch.device('cpu')
compute_device = torch.device('cuda:0')
data, slicer_revert_padding = pad_nd_image(image, self.configuration_manager.patch_size,
'constant', {'value': 0}, True,
None)
del image
slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
empty_cache(self.device)
data = data.to(data_device)
predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
dtype=torch.half,
device=predicted_logits_device)
gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
value_scaling_factor=10,
device=gaussian_device, dtype=torch.float16)
if not self.allow_tqdm and self.verbose:
print(f'running prediction: {len(slicers)} steps')
for p in self.list_of_parameters:
# network weights have to be updated outside autocast!
# we are loading parameters on demand instead of loading them upfront. This reduces memory footprint a lot.
# each set of parameters is only used once on the test set (one image) so run time wise this is almost the
# same
self.network.load_state_dict(torch.load(p, map_location=compute_device, weights_only=False)['network_weights'])
with torch.autocast(self.device.type, enabled=True):
for sl in tqdm(slicers, disable=not self.allow_tqdm):
pred = self._internal_maybe_mirror_and_predict(data[sl][None].to(compute_device))[0].to(
predicted_logits_device)
pred /= (pred.max() / 100)
predicted_logits[sl] += (pred * gaussian)
del pred
empty_cache(self.device)
return predicted_logits
def convert_predicted_logits_to_segmentation_with_correct_shape(self, predicted_logits, props):
old = torch.get_num_threads()
torch.set_num_threads(7)
# resample to original shape
spacing_transposed = [props['spacing'][i] for i in self.plans_manager.transpose_forward]
current_spacing = self.configuration_manager.spacing if \
len(self.configuration_manager.spacing) == \
len(props['shape_after_cropping_and_before_resampling']) else \
[spacing_transposed[0], *self.configuration_manager.spacing]
predicted_logits = self.configuration_manager.resampling_fn_probabilities(predicted_logits,
props[
'shape_after_cropping_and_before_resampling'],
current_spacing,
[props['spacing'][i] for i in
self.plans_manager.transpose_forward])
segmentation = None
pp = None
try:
with torch.no_grad():
pp = predicted_logits.to('cuda:0')
segmentation = pp.argmax(0).cpu()
del pp
except RuntimeError:
del segmentation, pp
torch.cuda.empty_cache()
segmentation = predicted_logits.argmax(0)
del predicted_logits
# segmentation may be torch.Tensor but we continue with numpy
if isinstance(segmentation, torch.Tensor):
segmentation = segmentation.cpu().numpy()
# put segmentation in bbox (revert cropping)
segmentation_reverted_cropping = np.zeros(props['shape_before_cropping'],
dtype=np.uint8 if len(
self.label_manager.foreground_labels) < 255 else np.uint16)
slicer = bounding_box_to_slice(props['bbox_used_for_cropping'])
segmentation_reverted_cropping[slicer] = segmentation
del segmentation
# revert transpose
segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(self.plans_manager.transpose_backward)
torch.set_num_threads(old)
return segmentation_reverted_cropping
def predict_semseg(im, prop, semseg_trained_model, semseg_folds):
# initialize predictors
pred_semseg = CustomPredictor(
tile_step_size=0.5,
use_mirroring=True,
use_gaussian=True,
perform_everything_on_device=False,
allow_tqdm=True
)
pred_semseg.initialize_from_trained_model_folder(
semseg_trained_model,
use_folds=semseg_folds,
checkpoint_name='checkpoint_final.pth'
)
semseg_pred = pred_semseg.predict_single_npy_array(
im, prop, None
)
torch.cuda.empty_cache()
gc.collect()
return semseg_pred
def map_labels_to_toothfairy(predicted_seg: np.ndarray) -> np.ndarray:
# Create an array that maps the labels directly
max_label = 42
mapping = np.arange(max_label + 1)
# Define the specific remapping
remapping = {19: 21, 20: 22, 21: 23, 22: 24, 23: 25, 24: 26, 25: 27, 26: 28,
27: 31, 28: 32, 29: 33, 30: 34, 31: 35, 32: 36, 33: 37, 34: 38,
35: 41, 36: 42, 37: 43, 38: 44, 39: 45, 40: 46, 41: 47, 42: 48}
# Apply the remapping
for k, v in remapping.items():
mapping[k] = v
return mapping[predicted_seg]
def postprocess(prediction_npy, vol_per_voxel, verbose: bool = False):
cutoffs = {1: 0.0,
2: 78411.5,
3: 0.0,
4: 0.0,
5: 2800.0,
6: 1216.5,
7: 0.0,
8: 6222.0,
9: 1573.0,
10: 946.0,
11: 0.0,
12: 6783.5,
13: 9469.5,
14: 0.0,
15: 2260.0,
16: 3566.0,
17: 6321.0,
18: 4221.5,
19: 5829.0,
20: 0.0,
21: 0.0,
22: 468.0,
23: 1555.0,
24: 1291.5,
25: 2834.5,
26: 584.5,
27: 0.0,
28: 0.0,
29: 0.0,
30: 0.0,
31: 1935.5,
32: 0.0,
33: 0.0,
34: 6140.0,
35: 0.0,
36: 0.0,
37: 0.0,
38: 2710.0,
39: 0.0,
40: 0.0,
41: 0.0,
42: 970.0}
vol_per_voxel_cutoffs = 0.3 * 0.3 * 0.3
for c in cutoffs.keys():
co = cutoffs[c]
if co > 0:
mask = prediction_npy == c
pred_vol = np.sum(mask) * vol_per_voxel
if 0 < pred_vol < (co * vol_per_voxel_cutoffs):
prediction_npy[mask] = 0
if verbose:
print(
f'removed label {c} because predicted volume of {pred_vol} is less than the cutoff {co * vol_per_voxel_cutoffs}')
return prediction_npy
if __name__ == '__main__':
os.environ['nnUNet_compile'] = 'f'
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_folder', type=Path, default="/input/images/cbct/")
parser.add_argument('-o', '--output_folder', type=Path, default="/output/images/oral-pharyngeal-segmentation/")
parser.add_argument('-sem_mod', '--semseg_trained_model', type=str,
default="/opt/app/_trained_model/semseg_trained_model")
parser.add_argument('--semseg_folds', type=str, nargs='+', default=[0, 1])
args = parser.parse_args()
args.output_folder.mkdir(exist_ok=True, parents=True)
semseg_folds = [i if i == 'all' else int(i) for i in args.semseg_folds]
semseg_trained_model = args.semseg_trained_model
rw = SimpleITKIO()
input_files = list(args.input_folder.glob('*.nii.gz')) + list(args.input_folder.glob('*.mha'))
for input_fname in input_files:
output_fname = args.output_folder / input_fname.name
# load the test image
im, prop = rw.read_images([input_fname])
with torch.no_grad():
semseg_pred = predict_semseg(im, prop, semseg_trained_model, semseg_folds)
torch.cuda.empty_cache()
gc.collect()
# now postprocess
semseg_pred = postprocess(semseg_pred, np.prod(prop['spacing']), True)
semseg_pred = map_labels_to_toothfairy(semseg_pred)
# now save
rw.write_seg(semseg_pred, output_fname, prop)
================================================
FILE: documentation/competitions/Toothfairy2/readme.md
================================================
Authors: \
Fabian Isensee*, Yannick Kirchhoff*, Lars Kraemer, Max Rokuss, Constantin Ulrich, Klaus H. Maier-Hein
*: equal contribution
Author Affiliations:\
Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg \
Helmholtz Imaging
# Introduction
This document describes our submission to the [Toothfairy2 Challenge](https://toothfairy2.grand-challenge.org/toothfairy2/).
Our model is essentially an nnU-Net ResEnc L with the patch size upscaled to 160x320x320 voxels. We disable left/right
mirroring and train for 1500 instead of the standard 1000 epochs. Training was done either on 2x A100 40GB or on one GH200 96GB.
# Dataset Conversion
Adapt and run the [dataset conversion script](../../../nnunetv2/dataset_conversion/Dataset119_ToothFairy2_All.py).
This script just converts the mha files to nifti (smaller file size) and removes the unused label ids.
# Experiment Planning and Preprocessing
## Extract fingerprint:
`nnUNetv2_extract_fingerprint -d 119 -np 48`
## Run planning:
`nnUNetv2_plan_experiment -d 119 -pl nnUNetPlannerResEncL_torchres`
This planner not only uses the ResEncL configuration but also replaces the default resampling scheme with one that is
faster (but less precise). Since all images in the challenge (train and test) should already have 0.3x0.3x0.3 spacing,
resampling is not required; it is only kept as a safety measure. The speed matters at inference time because Grand
Challenge imposes a limit of 10 minutes per case.
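Whether resampling would actually be a no-op can be checked before training. A minimal sketch, assuming the spacings have already been read (for example from `prop['spacing']` as returned by `SimpleITKIO().read_images`, which the inference script uses). `cases_needing_resampling` and its default tolerance are hypothetical.

```python
import math
from typing import Dict, Sequence, Tuple


def cases_needing_resampling(spacings: Dict[str, Sequence[float]],
                             target: Tuple[float, ...] = (0.3, 0.3, 0.3),
                             rel_tol: float = 1e-3) -> Dict[str, Tuple[float, ...]]:
    """Return the cases whose spacing differs from `target` by more than rel_tol on any axis."""
    off_target = {}
    for case, spacing in spacings.items():
        same_dim = len(spacing) == len(target)
        if not same_dim or not all(math.isclose(s, t, rel_tol=rel_tol) for s, t in zip(spacing, target)):
            off_target[case] = tuple(spacing)
    return off_target
```

An empty result means the faster resampling scheme never actually changes the data for this dataset.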
## Edit the plans files
Add the following configuration to the `configurations` section of the generated plans file (`nnUNetResEncUNetLPlans_torchres.json`):
```json
"3d_fullres_torchres_ps160x320x320_bs2": {
"inherits_from": "3d_fullres",
"data_identifier": "nnUNetPlans_3d_fullres_torchres_ctnorm",
"patch_size": [
160,
320,
320
],
"normalization_schemes": [
"CTNormalization"
],
"architecture": {
"network_class_name": "dynamic_network_architectures.architectures.unet.ResidualEncoderUNet",
"arch_kwargs": {
"n_stages": 7,
"features_per_stage": [
32,
64,
128,
256,
320,
320,
320
],
"conv_op": "torch.nn.modules.conv.Conv3d",
"kernel_sizes": [
[
3,
3,
3
],
[
3,
3,
3
],
[
3,
3,
3
],
[
3,
3,
3
],
[
3,
3,
3
],
[
3,
3,
3
],
[
3,
3,
3
]
],
"strides": [
[
1,
1,
1
],
[
2,
2,
2
],
[
2,
2,
2
],
[
2,
2,
2
],
[
2,
2,
2
],
[
2,
2,
2
],
[
1,
2,
2
]
],
"n_blocks_per_stage": [
1,
3,
4,
6,
6,
6,
6
],
"n_conv_per_stage_decoder": [
1,
1,
1,
1,
1,
1
],
"conv_bias": true,
"norm_op": "torch.nn.modules.instancenorm.InstanceNorm3d",
"norm_op_kwargs": {
"eps": 1e-05,
"affine": true
},
"dropout_op": null,
"dropout_op_kwargs": null,
"nonlin": "torch.nn.LeakyReLU",
"nonlin_kwargs": {
"inplace": true
}
},
"_kw_requires_import": [
"conv_op",
"norm_op",
"dropout_op",
"nonlin"
]
}
}
```
Aside from changing the patch size, this makes the architecture one stage deeper (one more pooling + res blocks), enabling
it to make effective use of the larger input.
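The effect of the strides on the feature map sizes can be checked with a little arithmetic: divide each axis of the patch size by the product of the strides along that axis. The sketch below does this for the `patch_size` and `strides` values of the configuration above, assuming that each stage reduces the spatial size exactly by its stride (which holds when the patch size is divisible by the cumulative stride). `bottleneck_shape` is a hypothetical helper.

```python
from math import prod

# Values copied from the 3d_fullres_torchres_ps160x320x320_bs2 configuration above
patch_size = (160, 320, 320)
strides = [(1, 1, 1), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (1, 2, 2)]


def bottleneck_shape(patch_size, strides):
    """Spatial size of the deepest feature map, assuming each stage divides the size exactly by its stride."""
    shape = []
    for axis, size in enumerate(patch_size):
        total_stride = prod(stage[axis] for stage in strides)
        if size % total_stride != 0:
            raise ValueError(f"patch size {size} is not divisible by total stride {total_stride} on axis {axis}")
        shape.append(size // total_stride)
    return tuple(shape)


print(bottleneck_shape(patch_size, strides))  # (5, 5, 5)
```

With these values, the total stride is 32 along the first axis and 64 along the other two, so the deepest of the seven stages works on a 5x5x5 feature map.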
# Preprocessing
`nnUNetv2_preprocess -d 119 -c 3d_fullres_torchres_ps160x320x320_bs2 -plans_name nnUNetResEncUNetLPlans_torchres -np 48`
# Training
We train two models on all training cases:
```bash
nnUNetv2_train 119 3d_fullres_torchres_ps160x320x320_bs2 all -p nnUNetResEncUNetLPlans_torchres -tr nnUNetTrainer_onlyMirror01_1500ep
nnUNet_results=${nnUNet_results}_2 nnUNetv2_train 119 3d_fullres_torchres_ps160x320x320_bs2 all -p nnUNetResEncUNetLPlans_torchres -tr nnUNetTrainer_onlyMirror01_1500ep
```
Models are trained from scratch.
Note how in the second line we overwrite the nnUNet_results variable in order to be able to train the same model twice without overwriting the results.
We recommend increasing the number of processes used for data augmentation. Otherwise you can run into CPU bottlenecks.
Use `export nnUNet_n_proc_DA=32` or higher (if your system permits!).
# Inference
We ensemble the two models from above. On a technical level, we copy the two fold_all folders into one training output
directory and rename them to fold_0 and fold_1. This lets us use nnU-Net's cross-validation ensembling strategy, which
is more computationally efficient (needed to stay within the time limit on grand-challenge.org).
Run inference with the [inference script](inference_script_semseg_only_customInf2.py).
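The folder merging described above can be scripted. A minimal sketch, assuming each run's training output directory contains `dataset.json`, `plans.json` and a `fold_all` subfolder (the files that the inference script's `initialize_from_trained_model_folder` reads). `merge_fold_all_runs` and all paths passed to it are hypothetical.

```python
import shutil
from pathlib import Path
from typing import Sequence, Union

PathLike = Union[str, Path]


def merge_fold_all_runs(run_dirs: Sequence[PathLike], merged_dir: PathLike) -> None:
    """Copy run_dirs[i]/fold_all to merged_dir/fold_i so the runs can be ensembled like CV folds."""
    run_dirs = [Path(d) for d in run_dirs]
    merged_dir = Path(merged_dir)
    merged_dir.mkdir(parents=True, exist_ok=False)  # refuse to overwrite an existing directory
    # dataset.json and plans.json are identical for both runs; take them from the first one
    for name in ("dataset.json", "plans.json"):
        shutil.copy2(run_dirs[0] / name, merged_dir / name)
    for i, run_dir in enumerate(run_dirs):
        shutil.copytree(run_dir / "fold_all", merged_dir / f"fold_{i}")
```

The merged directory can then be passed to the inference script via `-sem_mod`; the script's default `--semseg_folds 0 1` matches the renamed folders.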
# Postprocessing
If the predicted volume of a class in a test case is smaller than the corresponding cutoff, that class is removed
(replaced with background).
Cutoff values were optimized using five-fold cross-validation on the Toothfairy2 training data. We optimize HD95 and Dice separately.
The final cutoff for each class is then the smaller of the two optimized values. You can find our volume cutoffs in the inference
script as part of our `postprocess` function.
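The selection rule and its application can be summarized as follows. This is a minimal sketch, not the code used for the submission: `combine_cutoffs` and `remove_small_predictions` are hypothetical helpers, and the per-metric cutoff dictionaries stand in for the output of the cross-validation search, which is not shown here. Cutoffs are assumed to be in the same volume unit as `vol_per_voxel` (the script's `postprocess` function instead stores them as voxel counts at 0.3 mm isotropic spacing and multiplies them by `0.3 * 0.3 * 0.3`).

```python
from typing import Dict

import numpy as np


def combine_cutoffs(dice_cutoffs: Dict[int, float], hd95_cutoffs: Dict[int, float]) -> Dict[int, float]:
    """Per class, keep the smaller of the two optimized cutoffs (the less aggressive one)."""
    return {c: min(dice_cutoffs[c], hd95_cutoffs[c]) for c in dice_cutoffs}


def remove_small_predictions(seg: np.ndarray, cutoffs: Dict[int, float], vol_per_voxel: float) -> np.ndarray:
    """Replace a class with background if its predicted volume is non-zero but below its cutoff."""
    seg = seg.copy()
    for c, cutoff in cutoffs.items():
        mask = seg == c
        predicted_volume = mask.sum() * vol_per_voxel
        if 0 < predicted_volume < cutoff:
            seg[mask] = 0
    return seg
```

A cutoff of 0 never removes anything, which is why many classes in the script's `cutoffs` dictionary are effectively disabled.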
================================================
FILE: documentation/competitions/__init__.py
================================================
================================================
FILE: documentation/convert_msd_dataset.md
================================================
Use `nnUNetv2_convert_MSD_dataset`.
Read `nnUNetv2_convert_MSD_dataset -h` for usage instructions.
================================================
FILE: documentation/dataset_format.md
================================================
# nnU-Net dataset format
The only way to bring your data into nnU-Net is by storing it in a specific format. Due to nnU-Net's roots in the
[Medical Segmentation Decathlon](http://medicaldecathlon.com/) (MSD), its dataset format is heavily inspired by, but has since
diverged (see also [here](#how-to-use-decathlon-datasets)) from, the format used in the MSD.
Datasets consist of three components: raw images, corresponding segmentation maps and a dataset.json file specifying
some metadata.
If you are migrating from nnU-Net v1, read [this](#how-to-use-nnu-net-v1-tasks) to convert your existing Tasks.
## What do training cases look like?
Each training case is associated with an identifier, a unique name for that case. nnU-Net uses this identifier to
connect images with the correct segmentation.
A training case consists of images and their corresponding segmentation.
**Images** is plural because nnU-Net supports arbitrarily many input channels. In order to be as flexible as possible,
nnU-Net requires each input channel to be stored in a separate image (with the sole exception being RGB natural
images). So these images could for example be a T1 and a T2 MRI (or whatever else you want). The different input
channels MUST have the same geometry (same shape, spacing (if applicable) etc.) and
must be co-registered (if applicable). Input channels are identified by nnU-Net by their channel identifier: a four-digit integer at the end
of the filename (just before the file ending). Image files must therefore follow this naming convention: {CASE_IDENTIFIER}_{XXXX}.{FILE_ENDING}.
Here, XXXX is the 4-digit modality/channel identifier (should be unique for each modality/channel, e.g., “0000” for T1, “0001” for
T2 MRI, …) and FILE_ENDING is the file extension used by your image format (.png, .nii.gz, ...). See below for concrete examples.
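The naming convention can be expressed as a pair of small helpers. A minimal sketch, assuming `file_ending` is written with its leading dot, as in the `file_ending` entry of dataset.json (e.g. `".nii.gz"`); the function names are hypothetical and not part of nnU-Net.

```python
from typing import Tuple


def image_filename(case_identifier: str, channel: int, file_ending: str) -> str:
    """Build {CASE_IDENTIFIER}_{XXXX}{file_ending}, e.g. ('BRATS_001', 0, '.nii.gz') -> 'BRATS_001_0000.nii.gz'."""
    if not 0 <= channel <= 9999:
        raise ValueError("the channel identifier must fit into four digits")
    return f"{case_identifier}_{channel:04d}{file_ending}"


def parse_image_filename(filename: str, file_ending: str) -> Tuple[str, int]:
    """Split an image file name into (case identifier, channel identifier)."""
    if not filename.endswith(file_ending):
        raise ValueError(f"{filename!r} does not end with {file_ending!r}")
    # the case identifier may itself contain underscores, so split at the last one only
    case_identifier, channel = filename[:-len(file_ending)].rsplit("_", 1)
    if len(channel) != 4 or not channel.isdigit():
        raise ValueError(f"{filename!r} has no four-digit channel identifier")
    return case_identifier, int(channel)
```

Splitting at the last underscore keeps case identifiers such as `BRATS_001` intact.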
The dataset.json file connects channel names with the channel identifiers in the 'channel_names' key (see below for details).
Side note: Typically, each channel/modality needs to be stored in a separate file and is accessed with the XXXX channel identifier.
The exception is natural images (RGB; .png), where the three color channels can all be stored in one file (see the
[road segmentation](../nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py) dataset as an example).
**Segmentations** must share the same geometry with their corresponding images (same shape etc.). Segmentations are
integer maps with each value representing a semantic class. The background must be 0. If there is no background, then
do not use the label 0 for something else! Integer values of your semantic classes must be consecutive (0, 1, 2, 3,
...). Of course, not all labels have to be present in each training case. Segmentations are saved as {CASE_IDENTIFIER}{FILE_ENDING}.
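
A minimal sketch of the label rules above. It works on the `labels` mapping of dataset.json (name -> integer, see below) and on the set of values found in one segmentation, which you could for example obtain with `numpy.unique` on the loaded array. Both functions are hypothetical helpers, not nnU-Net code, and they only cover plain integer labels (region-based training and the ignore label are left out).

```python
def check_labels(labels: dict) -> None:
    """Raise if the label values are not 0, 1, 2, ... without gaps or duplicates."""
    values = sorted(labels.values())
    if values != list(range(len(values))):
        raise ValueError(f"label values must be consecutive integers starting at 0, got {values}")


def check_segmentation_values(present_values, labels: dict) -> None:
    """Raise if a segmentation contains values that are not declared labels.

    Not every label has to appear in every case, so only unknown values are an error.
    """
    unknown = set(present_values) - set(labels.values())
    if unknown:
        raise ValueError(f"segmentation contains undeclared label values {sorted(unknown)}")
```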
Within a training case, all image geometries (input channels, corresponding segmentation) must match. Between training
cases, they can of course differ. nnU-Net takes care of that.
Important: The input channels must be consistent! Concretely, **all images need the same input channels in the same
order and all input channels have to be present every time**. This is also true for inference!
## Supported file formats
nnU-Net expects the same file format for images and segmentations! This format will also be used for inference. For now,
it is thus not possible to train on .png and then run inference on .jpg.
One big change in nnU-Net V2 is the support of multiple input file types. Gone are the days of converting everything to .nii.gz!
This is implemented by abstracting the input and output of images + segmentations through `BaseReaderWriter`. nnU-Net
comes with a broad collection of Readers+Writers and you can even add your own to support your data format!
See [here](../nnunetv2/imageio/readme.md).
As a nice bonus, nnU-Net now also natively supports 2D input images and you no longer have to mess around with
conversions to pseudo 3D niftis. Yuck. That was disgusting.
Note that internally (for storing and accessing preprocessed images) nnU-Net will use its own file format, irrespective
of what the raw data was provided in! This is for performance reasons.
By default, the following file formats are supported:
- NaturalImage2DIO: .png, .bmp, .tif
- NibabelIO: .nii.gz, .nrrd, .mha
- NibabelIOWithReorient: .nii.gz, .nrrd, .mha. This reader will reorient images to RAS!
- SimpleITKIO: .nii.gz, .nrrd, .mha
- Tiff3DIO: .tif, .tiff. 3D tif images! Since TIF does not have a standardized way of storing spacing information,
nnU-Net expects each TIF file to be accompanied by an identically named .json file that contains this information (see
[here](#datasetjson)).
The file extension lists are not exhaustive and depend on what the backend supports. For example, nibabel and SimpleITK
support more than the three given here. The file endings given here are just the ones we tested!
IMPORTANT: nnU-Net can only be used with file formats that use lossless (or no) compression! Because the file
format is defined for an entire dataset (and not separately for images and segmentations, this could be a todo for
the future), we must ensure that there are no compression artifacts that destroy the segmentation maps. So no .jpg and
the like!
## Dataset folder structure
Datasets must be located in the `nnUNet_raw` folder (which you either define when installing nnU-Net or export/set every
time you intend to run nnU-Net commands!).
Each segmentation dataset is stored as a separate 'Dataset'. Datasets are associated with a dataset ID, a three-digit
integer, and a dataset name (which you can freely choose): For example, Dataset005_Prostate has 'Prostate' as dataset name and
the dataset id is 5. Datasets are stored in the `nnUNet_raw` folder like this:
```
nnUNet_raw/
├── Dataset001_BrainTumour
├── Dataset002_Heart
├── Dataset003_Liver
├── Dataset004_Hippocampus
├── Dataset005_Prostate
├── ...
```
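
The folder names follow the pattern `Dataset` + three-digit ID + `_` + name. A hypothetical pair of helpers (not nnU-Net functions) that builds and parses such names could look like this:

```python
import re


def dataset_folder_name(dataset_id: int, name: str) -> str:
    """Build e.g. 'Dataset005_Prostate' from (5, 'Prostate')."""
    if not 0 <= dataset_id <= 999:
        raise ValueError("the dataset ID must fit into three digits")
    return f"Dataset{dataset_id:03d}_{name}"


def parse_dataset_folder_name(folder_name: str) -> tuple:
    """Split e.g. 'Dataset005_Prostate' into (5, 'Prostate')."""
    match = re.fullmatch(r"Dataset(\d{3})_(.+)", folder_name)
    if match is None:
        raise ValueError(f"{folder_name!r} does not match DatasetXXX_Name")
    return int(match.group(1)), match.group(2)
```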
Within each dataset folder, the following structure is expected:
```
Dataset001_BrainTumour/
├── dataset.json
├── imagesTr
├── imagesTs  # optional
└── labelsTr
```
When adding your custom dataset, take a look at the [dataset_conversion](../nnunetv2/dataset_conversion) folder and
pick an id that is not already taken. IDs 001-010 are for the Medical Segmentation Decathlon.
- **imagesTr** contains the images belonging to the training cases. nnU-Net will perform pipeline configuration, training with
cross-validation, as well as finding postprocessing and the best ensemble using this data.
- **imagesTs** (optional) contains the images that belong to the test cases. nnU-Net does not use them! This could just
be a convenient location for you to store these images. Remnant of the Medical Segmentation Decathlon folder structure.
- **labelsTr** contains the images with the ground truth segmentation maps for the training cases.
- **dataset.json** contains metadata of the dataset.
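
If you prefer to set up this skeleton from Python, a sketch with `pathlib` could look like this. `create_dataset_skeleton` is a hypothetical helper: it only creates the empty folders, so you still have to copy your files and write dataset.json yourself.

```python
from pathlib import Path


def create_dataset_skeleton(raw_root, folder_name: str, with_test_images: bool = False) -> Path:
    """Create <raw_root>/<folder_name>/{imagesTr, labelsTr[, imagesTs]} if they do not exist yet."""
    dataset_dir = Path(raw_root) / folder_name
    for sub_folder in ("imagesTr", "labelsTr"):
        (dataset_dir / sub_folder).mkdir(parents=True, exist_ok=True)
    if with_test_images:  # imagesTs is optional and not used by nnU-Net
        (dataset_dir / "imagesTs").mkdir(parents=True, exist_ok=True)
    return dataset_dir
```

For example, `create_dataset_skeleton(os.environ["nnUNet_raw"], "Dataset123_Foo")`, assuming the `nnUNet_raw` environment variable is set.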
The scheme introduced [above](#what-do-training-cases-look-like) results in the following folder structure. Here is
an example for the first dataset of the MSD, BrainTumour. This dataset has four input channels: FLAIR (0000),
T1w (0001), T1gd (0002) and T2w (0003). Note that the imagesTs folder is optional and does not have to be present.
```
nnUNet_raw/Dataset001_BrainTumour/
├── dataset.json
├── imagesTr
│   ├── BRATS_001_0000.nii.gz
│   ├── BRATS_001_0001.nii.gz
│   ├── BRATS_001_0002.nii.gz
│   ├── BRATS_001_0003.nii.gz
│   ├── BRATS_002_0000.nii.gz
│   ├── BRATS_002_0001.nii.gz
│   ├── BRATS_002_0002.nii.gz
│   ├── BRATS_002_0003.nii.gz
│   ├── ...
├── imagesTs
│   ├── BRATS_485_0000.nii.gz
│   ├── BRATS_485_0001.nii.gz
│   ├── BRATS_485_0002.nii.gz
│   ├── BRATS_485_0003.nii.gz
│   ├── BRATS_486_0000.nii.gz
│   ├── BRATS_486_0001.nii.gz
│   ├── BRATS_486_0002.nii.gz
│   ├── BRATS_486_0003.nii.gz
│   ├── ...
└── labelsTr
    ├── BRATS_001.nii.gz
    ├── BRATS_002.nii.gz
    ├── ...
```
Here is another example of the second dataset of the MSD, which has only one input channel:
```
nnUNet_raw/Dataset002_Heart/
├── dataset.json
├── imagesTr
│   ├── la_003_0000.nii.gz
│   ├── la_004_0000.nii.gz
│   ├── ...
├── imagesTs
│   ├── la_001_0000.nii.gz
│   ├── la_002_0000.nii.gz
│   ├── ...
└── labelsTr
    ├── la_003.nii.gz
    ├── la_004.nii.gz
    ├── ...
```
Remember: For each training case, all images must have the same geometry to ensure that their pixel arrays are aligned. Also
make sure that all your data is co-registered!
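
Since images and segmentations are paired purely by file name, it can be worth checking the naming before running preprocessing. The following hypothetical checker (not part of nnU-Net) looks only at file names: it verifies that every case in imagesTr has all channels `0000` up to `num_channels - 1` and a matching file in labelsTr. It does not check geometry or co-registration; nnU-Net's own dataset integrity verification is the place for that.

```python
import re
from collections import defaultdict


def check_training_cases(image_files, label_files, file_ending: str, num_channels: int) -> list:
    """Return a list of naming problems; an empty list means imagesTr and labelsTr look consistent."""
    problems = []
    channels_per_case = defaultdict(set)
    for name in image_files:
        if not name.endswith(file_ending):
            continue  # e.g. the .json spacing files that accompany 3D TIFFs
        match = re.fullmatch(r"(.+)_(\d{4})", name[: -len(file_ending)])
        if match is None:
            problems.append(f"{name}: missing 4-digit channel identifier")
            continue
        channels_per_case[match.group(1)].add(int(match.group(2)))

    # Every case needs all input channels, every time.
    expected = set(range(num_channels))
    for case, channels in sorted(channels_per_case.items()):
        if channels != expected:
            problems.append(f"{case}: has channels {sorted(channels)}, expected {sorted(expected)}")

    # Every case needs exactly one segmentation named {CASE_IDENTIFIER}{FILE_ENDING}.
    label_cases = {name[: -len(file_ending)] for name in label_files if name.endswith(file_ending)}
    for case in sorted(set(channels_per_case) - label_cases):
        problems.append(f"{case}: no segmentation in labelsTr")
    for case in sorted(label_cases - set(channels_per_case)):
        problems.append(f"{case}: segmentation without images in imagesTr")
    return problems
```

In practice you would pass `os.listdir(...)` of imagesTr and labelsTr, together with the number of entries in 'channel_names'.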
See also [dataset format inference](dataset_format_inference.md)!!
## dataset.json
The dataset.json contains metadata that nnU-Net needs for training. We have greatly reduced the number of required
fields since version 1!
Here is what the dataset.json should look like, using Dataset005_Prostate from the MSD as an example. The `#` comments
are annotations for this documentation only; JSON does not allow comments, so leave them out of your actual file:
```
{
    "channel_names": {  # formerly modalities
        "0": "T2",
        "1": "ADC"
    },
    "labels": {  # THIS IS DIFFERENT NOW!
        "background": 0,
        "PZ": 1,
        "TZ": 2
    },
    "numTraining": 32,
    "file_ending": ".nii.gz",
    "overwrite_image_reader_writer": "SimpleITKIO"  # optional! If not provided nnU-Net will automatically determine the ReaderWriter
}
```
The channel_names determine the normalization used by nnU-Net. If a channel is marked as 'CT', then a global
normalization based on the intensities in the foreground pixels will be used. If it is something else, per-channel
z-scoring will be used. Refer to the methods section in [our paper](https://www.nature.com/articles/s41592-020-01008-z)
for more details. nnU-Net v2 introduces a few more normalization schemes to
choose from and allows you to define your own, see [here](explanation_normalization.md) for more information.
Important changes relative to nnU-Net v1:
- "modality" is now called "channel_names" to remove the strong bias toward medical images
- labels are structured differently (name -> int instead of int -> name). This was needed to support [region-based training](region_based_training.md)
- "file_ending" is added to support different input file types
- "overwrite_image_reader_writer" optional! Can be used to specify a certain (custom) ReaderWriter class that should
be used with this dataset. If not provided, nnU-Net will automatically determine the ReaderWriter
- "regions_class_order" only used in [region-based training](region_based_training.md)
There is a utility with which you can generate the dataset.json automatically. You can find it
[here](../nnunetv2/dataset_conversion/generate_dataset_json.py).
See our examples in [dataset_conversion](../nnunetv2/dataset_conversion) for how to use it. And read its documentation!
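
The linked `generate_dataset_json` utility is the recommended way to create this file. Purely to illustrate what ends up in it, here is a minimal standard-library sketch; `write_dataset_json` is a hypothetical helper that writes only the keys shown in the example above.

```python
import json
from pathlib import Path


def write_dataset_json(dataset_dir, channel_names: dict, labels: dict, num_training: int,
                       file_ending: str, reader_writer: str = None) -> Path:
    """Write <dataset_dir>/dataset.json with the fields described above."""
    content = {
        "channel_names": {str(k): v for k, v in channel_names.items()},  # JSON keys are strings
        "labels": labels,  # name -> integer value
        "numTraining": num_training,
        "file_ending": file_ending,
    }
    if reader_writer is not None:  # optional; otherwise nnU-Net picks the ReaderWriter itself
        content["overwrite_image_reader_writer"] = reader_writer
    path = Path(dataset_dir) / "dataset.json"
    path.write_text(json.dumps(content, indent=4))
    return path
```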
As described above, a json file that contains spacing information is required for TIFF files.
Here is an example for a 3D TIFF stack with a spacing of 7.6 in x and y and 80 in z (in whatever physical unit your data uses):
```
{
"spacing": [7.6, 7.6, 80.0]
}
```
Within the dataset folder, this file (named `cell6.json` in this example) would be placed in the following folders:
```
nnUNet_raw/Dataset123_Foo/
├── dataset.json
├── imagesTr
│   ├── cell6.json
│   └── cell6_0000.tif
└── labelsTr
    ├── cell6.json
    └── cell6.tif
```
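
Following the tree above, the spacing file is named after the case identifier (without the `_0000` channel suffix) and placed next to both the images and the segmentation. A small hypothetical helper (not part of nnU-Net) for writing it:

```python
import json
from pathlib import Path


def write_tiff_spacing(folder, case_id: str, spacing) -> Path:
    """Write <folder>/<case_id>.json containing the spacing of a 3D TIFF case."""
    path = Path(folder) / f"{case_id}.json"
    path.write_text(json.dumps({"spacing": [float(s) for s in spacing]}))
    return path
```

You would call it once for imagesTr and once for labelsTr, e.g. `write_tiff_spacing(".../imagesTr", "cell6", [7.6, 7.6, 80])`.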
## How to use nnU-Net v1 Tasks
If you are migrating from the old nnU-Net, convert your existing datasets with `nnUNetv2_convert_old_nnUNet_dataset`!
Example for migrating an nnU-Net v1 Task:
```bash
nnUNetv2_convert_old_nnUNet_dataset /media/isensee/raw_data/nnUNet_raw_data_base/nnUNet_raw_data/Task027_ACDC Dataset027_ACDC
```
Use `nnUNetv2_convert_old_nnUNet_dataset -h` for detailed usage instructions.
## How to use decathlon datasets
See [convert_msd_dataset.md](convert_msd_dataset.md)
## How to use 2D data with nnU-Net
2D is now natively supported (yay!). See [here](#supported-file-formats) as well as the example dataset in this
[script](../nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py).
## How to update an existing dataset
When updating a dataset it is best practice to remove the preprocessed data in `nnUNet_preprocessed/DatasetXXX_NAME`
to ensure a fresh start. Then replace the data in `nnUNet_raw` and rerun `nnUNetv2_plan_and_preprocess`. Optionally,
also remove the results from old trainings.
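
The cleanup step can be scripted. A hypothetical sketch (not an nnU-Net command) that deletes `nnUNet_preprocessed/DatasetXXX_NAME`, reading the root from the `nnUNet_preprocessed` environment variable unless a path is given:

```python
import os
import shutil
from pathlib import Path


def remove_preprocessed(dataset_folder_name: str, preprocessed_root=None) -> bool:
    """Delete the preprocessed data of one dataset. Returns True if something was removed."""
    root = Path(preprocessed_root or os.environ["nnUNet_preprocessed"])
    target = root / dataset_folder_name
    if target.is_dir():
        shutil.rmtree(target)
        return True
    return False
```

Afterwards, replace the data in `nnUNet_raw` and rerun `nnUNetv2_plan_and_preprocess` as described above.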
# Example dataset conversion scripts
In the `dataset_conversion` folder (see [here](../nnunetv2/dataset_conversion)) are multiple example scripts for
converting datasets into nnU-Net format. These scripts cannot be run as they are (you need to open them and change
some paths) but they are excellent examples for you to learn how to convert your own datasets into nnU-Net format.
Just pick the dataset that is closest to yours as a starting point.
The list of dataset conversion scripts is continually updated. If you find that some publicly available dataset is
missing, feel free to open a PR to add it!
================================================
FILE: documentation/dataset_format_inference.md
================================================
# Data format for Inference
Read the documentation on the overall [data format](dataset_format.md) first!
The data format for inference must match the one used for the raw data (**specifically, the images must be in exactly
the same format as in the imagesTr folder**). As before, the filenames must start with a
unique identifier, followed by a 4-digit modality identifier. Here is an example for two different datasets:
1) Dataset005_Prostate:
This dataset has 2 input channels (modalities), so the files in the input folder must look like this:
```
input_folder
├── prostate_03_0000.nii.gz
├── prostate_03_0001.nii.gz
├── prostate_05_0000.nii.gz
├── prostate_05_0001.nii.gz
├── prostate_08_0000.nii.gz
├── prostate_08_0001.nii.gz
├── ...
```
_0000 has to be the T2 image and _0001 has to be the ADC image (as specified by 'channel_names' in the
dataset.json), exactly the same as was used for training.
2) Dataset002_Heart:
```
imagesTs
├── la_001_0000.nii.gz
├── la_002_0000.nii.gz
├── la_006_0000.nii.gz
├── ...
```
Dataset002 only has one input channel, so each case only has one _0000.nii.gz file.
The segmentations in the output folder will be named {CASE_IDENTIFIER}.nii.gz (omitting the modality identifier).
Remember that the file format used for inference (.nii.gz in this example) must be the same as was used for training
(and as was specified in 'file_ending' in the dataset.json)!
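
Because the output names simply drop the channel identifier, you can predict them from the input folder. A hypothetical sketch (not nnU-Net code) based on the naming rules above:

```python
import re


def expected_outputs(input_files, file_ending: str) -> list:
    """Map the files of an inference input folder to the segmentation file names nnU-Net writes."""
    cases = set()
    for name in input_files:
        if not name.endswith(file_ending):
            continue
        match = re.fullmatch(r"(.+)_(\d{4})", name[: -len(file_ending)])
        if match is not None:
            cases.add(match.group(1))  # all channels of a case yield a single output
    return sorted(case + file_ending for case in cases)
```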
================================================
FILE: documentation/explanation_logging.md
================================================
# Logging in nnU-Net v2
## Introduction
Logging in nnU-Net is intentionally simple and centralized in
`nnunetv2/training/logging/nnunet_logger.py`.
The trainer talks to one object, `MetaLogger`, and `MetaLogger` fans out logs to:
- `LocalLogger` (always enabled): the source of truth for training curves, checkpoint logging state, and `progress.png`
- optional external loggers (currently `WandbLogger`)
This keeps training code clean while still allowing external tracking backends.
## Default behaviour
Without any setup, nnU-Net uses only `LocalLogger`.
Per epoch, it stores:
- `mean_fg_dice` and `ema_fg_dice` (EMA is computed automatically)
- `dice_per_class_or_region`
- `train_losses`, `val_losses`
- `lrs`
- `epoch_start_timestamps`, `epoch_end_timestamps`
From these values, `progress.png` is updated in the fold output folder.
On checkpoint save/load, the local logging state is also saved/restored, so curves continue correctly after resume.
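
`ema_fg_dice` is an exponential moving average over the per-epoch `mean_fg_dice`. A minimal sketch of such an update, assuming a decay factor of 0.9 (check `nnunet_logger.py` for the value nnU-Net actually uses); `update_ema` is a hypothetical name:

```python
def update_ema(ema_history: list, new_value: float, decay: float = 0.9) -> float:
    """Append and return the next EMA value; the first epoch simply uses the raw value."""
    if ema_history:
        ema = decay * ema_history[-1] + (1 - decay) * new_value
    else:
        ema = new_value
    ema_history.append(ema)
    return ema
```

A smoothed curve like this is less noisy than the raw per-epoch Dice, which makes it more useful for judging training progress.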
## How to enable W&B
1. Install W&B:
```bash
pip install wandb
```
2. Enable the backend via environment variables:
```bash
export nnUNet_wandb_enabled=1
export nnUNet_wandb_project=nnunet
export nnUNet_wandb_mode=online # or offline
```
3. Run training normally:
```bash
nnUNetv2_train DATASET_NAME_OR_ID 3d_fullres 0
```
Notes:
- `nnUNet_wandb_enabled` accepts `0/1` and `false/true` (case-insensitive). Other values raise an error.
- When resuming (`--c`), W&B resume metadata in `fold_x/wandb/latest-run` is reused and duplicate older steps are skipped.
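
The accepted values for `nnUNet_wandb_enabled` can be expressed as a small parser. This is a sketch of the rule stated above, not nnU-Net's code; treating an unset variable as disabled (and stripping whitespace) are assumptions of this sketch.

```python
import os


def wandb_enabled(env=None) -> bool:
    """Interpret nnUNet_wandb_enabled: 0/1 or false/true (case-insensitive), anything else is an error."""
    env = os.environ if env is None else env
    raw = env.get("nnUNet_wandb_enabled", "0").strip().lower()
    if raw in ("1", "true"):
        return True
    if raw in ("0", "false"):
        return False
    raise ValueError(f"nnUNet_wandb_enabled must be 0/1 or false/true, got {raw!r}")
```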
## How to integrate a custom logger
Add a new logger class with the same minimal interface used by `MetaLogger`:
- `update_config(self, config: dict)`
- `log(self, key, value, step: int)`
- `log_summary(self, key, value)`
Example skeleton:
```python
class MyLogger:
    def __init__(self, output_folder, resume):
        self.output_folder = output_folder
        self.resume = resume

    def update_config(self, config: dict):
        ...

    def log(self, key, value, step: int):
        ...

    def log_summary(self, key, value):
        ...
```
Then register it in `MetaLogger.__init__` (for example behind an env var switch), similar to how `WandbLogger` is added.
Important integration detail:
- `MetaLogger.log(...)` always writes to `LocalLogger` first.
- If you introduce a brand-new per-epoch key, also add that key to `LocalLogger.my_fantastic_logging`, otherwise the local assertion will fail.
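
As a concrete example of the skeleton, here is a sketch of a logger that appends one JSON object per call to a file in the fold output folder. `JsonLinesLogger` and the file name `custom_log.jsonl` are made up for illustration, and you still need to register the class in `MetaLogger.__init__` as described above.

```python
import json
from pathlib import Path


class JsonLinesLogger:
    """Implements the interface MetaLogger expects: update_config, log, log_summary."""

    def __init__(self, output_folder, resume):
        self.path = Path(output_folder) / "custom_log.jsonl"
        if not resume and self.path.exists():
            self.path.unlink()  # fresh run: start with an empty log

    def _write(self, record: dict):
        # default=str turns values json cannot serialize (e.g. numpy scalars) into strings
        with self.path.open("a") as f:
            f.write(json.dumps(record, default=str) + "\n")

    def update_config(self, config: dict):
        self._write({"type": "config", "config": config})

    def log(self, key, value, step: int):
        self._write({"type": "metric", "key": key, "value": value, "step": step})

    def log_summary(self, key, value):
        self._write({"type": "summary", "key": key, "value": value})
```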
================================================
FILE: documentation/explanation_normalization.md
================================================
# Intensity normalization in nnU-Net
The type of intensity normalization applied in nnU-Net can be controlled via the `channel_names` (formerly `modalities`)
entry in the dataset.json. Just like the old nnU-Net, per-channel z-scoring as well as dataset-wide z-scoring based on
foreground intensities are supported. However, there have been a few additions as well.
Reminder: The `channel_names` entry typically looks like this:
```
"channel_names": {
    "0": "T2",
    "1": "ADC"
},
```
It has as many entries as there are input channels for the given dataset.
To tell you a secret, nnU-Net does not really care what your channels are called. We just use this to determine what normalization
scheme will be used for the given dataset. nnU-Net requires you to specify a normalization strategy for each of your input channels!
If you enter a channel name that is not in the following list, the default (`zscore`) will be used.
Here is a list of currently available normalization schemes:
- `CT`: Perform CT normalization. Specifically, collect intensity values from the foreground classes (all but the
background and ignore) from all training cases, compute the mean, standard deviation as well as the 0.5 and
99.5 percentile of the values. Then clip to the percentiles, followed by subtraction of the mean and division by the
standard deviation. The normalization that is applied is the same for each training case (for this input channel).
The values used by nnU-Net for normalization are stored in the `foreground_intensity_properties_per_channel` entry in the
corresponding plans file. This normalization is suitable for modalities presenting physical quantities such as CT
images and ADC maps.
- `noNorm` : do not perform any normalization at all
- `rescale_to_0_1`: rescale the intensities to [0, 1]
- `rgb_to_0_1`: assumes uint8 inputs. Divides by 255 to rescale uint8 to [0, 1]
- `zscore`/anything else: perform z-scoring (subtract the mean and divide by the standard deviation) separately for each training case
**Important:** The nnU-Net default is to perform 'CT' normalization for CT images and 'zscore' for everything else! If
you deviate from that path, make sure to benchmark whether that actually improves results!
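
To make the difference between `zscore` and `CT` concrete, here is a plain-Python sketch of both computations on flat lists of intensities. It is not nnU-Net's implementation (which works on numpy arrays); the small epsilon that guards against division by zero and the dictionary keys are choices of this sketch.

```python
import math


def zscore(values):
    """Per-case z-scoring: statistics come from this case only."""
    n = len(values)
    mean = sum(values) / n
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / n)
    return [(v - mean) / max(std, 1e-8) for v in values]


def percentile(sorted_values, q):
    """Percentile with linear interpolation between neighbouring values."""
    pos = (len(sorted_values) - 1) * q / 100
    lo, hi = math.floor(pos), math.ceil(pos)
    return sorted_values[lo] + (sorted_values[hi] - sorted_values[lo]) * (pos - lo)


def ct_statistics(foreground_values):
    """Dataset-wide statistics from the foreground voxels of all training cases."""
    s = sorted(foreground_values)
    n = len(s)
    mean = sum(s) / n
    std = math.sqrt(sum((v - mean) ** 2 for v in s) / n)
    return {"mean": mean, "std": std, "lower": percentile(s, 0.5), "upper": percentile(s, 99.5)}


def ct_normalize(values, stats):
    """Clip to the 0.5/99.5 percentiles, then subtract the mean and divide by the std.
    The same statistics are applied to every case."""
    lo, hi = stats["lower"], stats["upper"]
    return [(min(max(v, lo), hi) - stats["mean"]) / max(stats["std"], 1e-8) for v in values]
```

The key design difference: `ct_statistics` is computed once over the whole training set, so identical intensities map to identical normalized values in every case, which is what makes it suitable for physically meaningful units.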
# How to implement custom normalization strategies?
- Head over to nnunetv2/preprocessing/normalization
- implement a new image normalization class by deriving from ImageNormalization
- register it in nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py:channel_name_to_normalization_mapping.
This is where you specify a channel name that should be associated with it
- use it by specifying the correct channel_name
Normalization can only be applied to one channel at a time. There is currently no way of implementing a normalization scheme
that gets multiple channels as input to be used jointly!
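
The registration mechanism boils down to a lookup from channel name to normalization class, with z-scoring as the fallback. The following self-contained mock illustrates the idea. The base class here is only a stand-in for the real `ImageNormalization` (which lives in nnunetv2/preprocessing/normalization and operates on numpy arrays), `ClipTo01Normalization` is a hypothetical custom scheme, and case-insensitive matching is an assumption of this sketch.

```python
class ImageNormalization:
    """Stand-in for nnU-Net's base class so that this sketch runs on its own."""

    def run(self, image, seg=None):
        raise NotImplementedError


class ZScoreNormalization(ImageNormalization):
    """Stand-in for the default scheme: per-case z-scoring."""

    def run(self, image, seg=None):
        mean = sum(image) / len(image)
        std = (sum((v - mean) ** 2 for v in image) / len(image)) ** 0.5
        return [(v - mean) / max(std, 1e-8) for v in image]


class ClipTo01Normalization(ImageNormalization):
    """Hypothetical custom scheme: clip intensities to [0, 1]."""

    def run(self, image, seg=None):
        return [min(max(v, 0.0), 1.0) for v in image]


# Plays the role of channel_name_to_normalization_mapping: channel name -> class.
CHANNEL_NAME_TO_NORMALIZATION = {
    "zscore": ZScoreNormalization,
    "clip_to_0_1": ClipTo01Normalization,  # the entry you would add for a new scheme
}


def lookup_scheme(channel_name: str):
    # Unknown channel names fall back to z-scoring, as described above.
    return CHANNEL_NAME_TO_NORMALIZATION.get(channel_name.casefold(), ZScoreNormalization)
```

With this mapping, a channel named `clip_to_0_1` in dataset.json would select the custom class, while a name such as `T2` falls back to z-scoring.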
================================================
FILE: documentation/explanation_plans_files.md
================================================
# Modifying the nnU-Net Configurations
nnU-Net provides unprecedented out-of-the-box segmentation performance for essentially any dataset we have evaluated
it on. That said, there is always room for improvement. A fool-proof strategy for squeezing out the last bit of
performance is to start with the default nnU-Net, and then further tune it manually to a concrete dataset at hand.
**This guide is about changes to the nnU-Net configuration you can make via the plans files. It does not cover code
extensions of nnU-Net. For that, take a look [here](extending_nnunet.md)**
In nnU-Net V2, plans files are SO MUCH MORE powerful than they were in v1. There are a lot more knobs that you can
turn without resorting to hacky solutions or even having to touch the nnU-Net code at all! And as an added bonus:
plans files are now also .json files and no longer require users to fiddle with pickle. Just open them in your text
editor of choice!
If overwhelmed, look at our [Examples](#examples)!
# plans.json structure
Plans have global and local settings. Global settings are applied to all configurations in that plans file while
local settings are attached to a specific configuration.
## Global settings
- `foreground_intensity_properties_per_channel`: Intensity statistics of the foreground regions (all labels except
background and ignore label), computed over all training cases. Used by the [CT normalization scheme](explanation_normalization.md).
- `image_reader_writer`: Name of the image reader/writer class that should be used with this dataset. You might want
to change this if, for example, you would like to run inference with files that have a different file format. The
class that is named here must be located in nnunetv2.imageio!
- `label_manager`: The name of the class that does label handling. Take a look at
nnunetv2.utilities.label_handling.LabelManager to see what it does. If you decide to change it, place your version
in nnunetv2.utilities.label_handling!
- `transpose_forward`: nnU-Net transposes the input data so that the axes with the highest resolution (lowest spacing)
come last. This is because the 2D U-Net operates on the trailing dimensions (more efficient slicing due to internal
memory layout of arrays). Future work might move this setting to affect only individual configurations.
  - `transpose_forward` is what numpy.transpose gets as the new axis ordering.
- `transpose_backward`: the axis ordering that inverts "transpose_forward"
- \[`original_median_shape_after_transp`\]: just here for your information
- \[`original_median_spacing_after_transp`\]: just here for your information
- \[`plans_name`\]: do not change. Used internally
- \[`experiment_planner_used`\]: just here as metadata so that we know what planner originally generated this file
- \[`dataset_name`\]: do not change. This is the dataset these plans are intended for
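
A small sketch of how the two transpose settings relate. The rule for `transpose_forward` used here (move the axis with the largest spacing to the front and keep the others in order, so the highest-resolution axes come last) is an assumption that matches the description above, not a copy of the planner's code. `transpose_backward` is simply the inverse permutation.

```python
def compute_transpose(spacing):
    """Return (transpose_forward, transpose_backward) for a given target spacing."""
    max_axis = max(range(len(spacing)), key=lambda i: spacing[i])  # lowest-resolution axis
    forward = [max_axis] + [i for i in range(len(spacing)) if i != max_axis]
    backward = [forward.index(i) for i in range(len(forward))]  # inverse permutation
    return forward, backward


def transpose_shape(shape, axes):
    """The shape numpy.transpose(array, axes) would produce."""
    return tuple(shape[a] for a in axes)
```

For an anisotropic scan with spacing `[0.7, 0.7, 3.0]` this gives `transpose_forward = [2, 0, 1]` and `transpose_backward = [1, 2, 0]`; applying one after the other restores the original shape.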
## Local settings
Plans also have a `configurations` key in which the actual configurations are stored. `configurations` are again a
dictionary, where the keys are the configuration names and the values are the local settings for each configuration.
To better understand the components describing the network topology in our plans files, please read section 6.2
in the [supplementary information](https://static-content.springer.com/esm/art%3A10.1038%2Fs41592-020-01008-z/MediaObjects/41592_2020_1008_MOESM1_ESM.pdf)
(page 13) of our paper!
Local settings:
- `spacing`: the target spacing used in this configuration
- `patch_size`: the patch size used for training this configuration
- `data_identifier`: the preprocessed data for this configuration will be saved in
nnUNet_preprocessed/DATASET_NAME/_data_identifier_. If you add a new configuration, remember to set a unique
data_identifier in order to not create conflicts with other configurations (unless you plan to reuse the data from
another configuration, for example as is done in the cascade)
- `batch_size`: batch size used for training
- `batch_dice`: whether to use batch dice (pretend all samples in the batch are one image, compute dice loss over that)
or not (each sample in the batch is a separate image, compute dice loss for each sample and average over samples)
- `preprocessor_name`: Name of the preprocessor class used for running preprocessing. Class must be located in
nnunetv2.preprocessing.preprocessors
- `use_mask_for_norm`: whether to use the nonzero mask for normalization or not (relevant for BraTS and the like,
probably False for all other datasets). Interacts with ImageNormalization class
- `normalization_schemes`: mapping of channel identifier to ImageNormalization class name. ImageNormalization
classes must be located in nnunetv2.preprocessing.normalization. Also see [here](explanation_normalization.md)
- `resampling_fn_data`: name of resampling function to be used for resizing image data. resampling function must be
callable(data, current_spacing, new_spacing, **kwargs). It must be located in nnunetv2.preprocessing.resampling
- `resampling_fn_data_kwargs`: kwargs for resampling_fn_data
- `resampling_fn_probabilities`: name of resampling function to be used for resizing predicted class probabilities/logits.
resampling function must be `callable(data: Union[np.ndarray, torch.Tensor], current_spacing, new_spacing, **kwargs)`. It must be located in
nnunetv2.preprocessing.resampling
- `resampling_fn_probabilities_kwargs`: kwargs for resampling_fn_probabilities
- `resampling_fn_seg`: name of resampling function to be used for resizing segmentation maps (integer: 0, 1, 2, 3, etc).
resampling function must be callable(data, current_spacing, new_spacing, **kwargs). It must be located in
nnunetv2.preprocessing.resampling
- `resampling_fn_seg_kwargs`: kwargs for resampling_fn_seg
- `network_arch_class_name`: UNet class name, can be used to integrate custom dynamic architectures
- `UNet_base_num_features`: The number of starting features for the UNet architecture. Default is 32. By default, the
number of features is doubled with each downsampling step
- `unet_max_num_features`: Maximum number of features (default: capped at 320 for 3D and 512 for 2D). The purpose is to
prevent the number of parameters from exploding.
- `conv_kernel_sizes`: the convolutional kernel sizes used by nnU-Net in each stage of the encoder. The decoder
mirrors the encoder and is therefore not explicitly listed here! The list is as long as `n_conv_per_stage_encoder` has
entries
- `n_conv_per_stage_encoder`: number of convolutions used per stage (=at a feature map resolution in the encoder) in the encoder.
Default is 2. The list has as many entries as the encoder has stages
- `n_conv_per_stage_decoder`: number of convolutions used per stage in the decoder. Also see `n_conv_per_stage_encoder`
- `num_pool_per_axis`: number of times each of the spatial axes is pooled in the network. Needed to know how to pad
image sizes during inference (num_pool = 5 means input must be divisible by 2**5=32)
- `pool_op_kernel_sizes`: the pooling kernel sizes (and at the same time strides) for each stage of the encoder
- \[`median_image_size_in_voxels`\]: the median size of the images of the training set at the current target spacing.
Do not modify this as this is not used. It is just here for your information.
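
The divisibility rule from `num_pool_per_axis` generalizes to anisotropic pooling: along each axis, the patch size must be divisible by the product of that axis's pooling strides across all stages (which equals `2**num_pool` when every stride is 2). A hypothetical checker, not nnU-Net code:

```python
from math import prod


def required_divisors(pool_op_kernel_sizes) -> list:
    """Per axis, multiply the pooling strides of all encoder stages."""
    n_axes = len(pool_op_kernel_sizes[0])
    return [prod(stage[axis] for stage in pool_op_kernel_sizes) for axis in range(n_axes)]


def check_patch_size(patch_size, pool_op_kernel_sizes) -> list:
    """Raise if the patch size cannot be pooled cleanly; otherwise return the per-axis divisors."""
    divisors = required_divisors(pool_op_kernel_sizes)
    bad_axes = [axis for axis, (p, d) in enumerate(zip(patch_size, divisors)) if p % d != 0]
    if bad_axes:
        raise ValueError(f"patch size {list(patch_size)} is not divisible by {divisors} along axes {bad_axes}")
    return divisors
```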
Special local settings:
- `inherits_from`: configurations can inherit from each other. This makes it easy to add new configurations that only
differ in a few local settings from another. If using this, remember to set a new `data_identifier` (if needed)!
- `previous_stage`: if this configuration is part of a cascade, we need to know what the previous stage (for example
the low resolution configuration) was. This needs to be specified here.
- `next_stage`: if this configuration is part of a cascade, we need to know what possible subsequent stages are! This
is because we need to export predictions in the correct spacing when running the validation. `next_stage` can either
be a string or a list of strings
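
To illustrate how `inherits_from` can be resolved, here is a sketch that starts from the recursively resolved parent and lets the child's entries override it. It uses a shallow (top-level) override, which is an assumption of this sketch; consult nnU-Net's plans handling code for the exact behaviour. `resolve_configuration` is a hypothetical name.

```python
import copy


def resolve_configuration(configurations: dict, name: str, _seen=None) -> dict:
    """Return the fully resolved settings of one configuration, without 'inherits_from'."""
    seen = set() if _seen is None else _seen
    if name in seen:
        raise ValueError(f"circular inherits_from involving {name!r}")
    seen.add(name)
    config = copy.deepcopy(configurations[name])  # never modify the plans dict itself
    parent = config.pop("inherits_from", None)
    if parent is None:
        return config
    resolved = resolve_configuration(configurations, parent, seen)
    resolved.update(config)  # child entries win
    return resolved
```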
# Examples
## Increasing the batch size for large datasets
If your dataset is large, training can benefit from larger batch sizes. To do this, simply create a new
configuration in the `configurations` dict:
```
"configurations": {
    "3d_fullres_bs40": {
        "inherits_from": "3d_fullres",
        "batch_size": 40
    }
}
```
No need to change the data_identifier. `3d_fullres_bs40` will just use the preprocessed data from `3d_fullres`.
No need to rerun `nnUNetv2_preprocess` because we can use already existing data (if available) from `3d_fullres`.
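
Since plans files are plain JSON, a new configuration can also be added from Python. A minimal sketch; `add_configuration` is a hypothetical helper, and it is wise to keep a backup of the original plans file before editing it.

```python
import json
from pathlib import Path


def add_configuration(plans_path, name: str, settings: dict) -> None:
    """Add (or overwrite) one entry in the 'configurations' dict of a plans .json file."""
    path = Path(plans_path)
    plans = json.loads(path.read_text())
    plans["configurations"][name] = settings
    path.write_text(json.dumps(plans, indent=4))
```

For example: `add_configuration(plans_file, "3d_fullres_bs40", {"inherits_from": "3d_fullres", "batch_size": 40})`, where `plans_file` points to the plans .json of your dataset in `nnUNet_preprocessed`. You can then pass `3d_fullres_bs40` as the configuration to `nnUNetv2_train`.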
## Using custom preprocessors
If you would like to use a different preprocessor class, you can specify it as follows (replace `MY_PREPROCESSOR` with
the name of your class):
```
"configurations": {
    "3d_fullres_my_preprocessor": {
        "inherits_from": "3d_fullres",
        "preprocessor_name": "MY_PREPROCESSOR",
        "data_identifier": "3d_fullres_my_preprocessor"
    }
}
```
You need to run preprocessing for this new configuration:
`nnUNetv2_preprocess -d DATASET_ID -c 3d_fullres_my_preprocessor` because it changes the preprocessing. Remember to
set a unique `data_identifier` whenever you make modifications to the preprocessed data!
## Change target spacing
```
"configurations": {
    "3d_fullres_my_spacing": {
        "inherits_from": "3d_fullres",
        "spacing": [X, Y, Z],
        "data_identifier": "3d_fullres_my_spacing"
    }
}
```
You need to run preprocessing for this new configuration:
`nnUNetv2_preprocess -d DATASET_ID -c 3d_fullres_my_spacing` because it changes the preprocessing. Remember to
set a unique `data_identifier` whenever you make modifications to the preprocessed data!
## Adding a cascade to a dataset where it does not exist
Hippocampus is small. It doesn't have a cascade. It also doesn't really make sense to add a cascade here, but hey, for
the sake of demonstration we can do that.
We change the following things here:
- `spacing`: The lowres stage should operate at a lower resolution
- we modify the `median_image_size_in_voxels` entry as a guide for what original image sizes we deal with
- we set some patch size that is inspired by `median_image_size_in_voxels`
- we need to remember that the patch size must be divisible by 2**num_pool in each axis!
- network parameters such as kernel sizes, pooling operations are changed accordingly
- we need to specify the name of the next stage
- we need to add the highres stage
This is what this would look like (comparisons with 3d_fullres given as reference; remove the `#` comments in an actual
plans file, since JSON does not allow comments):
```
"configurations": {
    "3d_lowres": {
        "inherits_from": "3d_fullres",
        "data_identifier": "3d_lowres",
        "spacing": [2.0, 2.0, 2.0],  # from [1.0, 1.0, 1.0] in 3d_fullres
        "median_image_size_in_voxels": [18, 25, 18],  # from [36, 50, 35]
        "patch_size": [20, 28, 20],  # from [40, 56, 40]
        "n_conv_per_stage_encoder": [2, 2, 2],  # one less entry than 3d_fullres ([2, 2, 2, 2])
        "n_conv_per_stage_decoder": [2, 2],  # one less entry than 3d_fullres
        "num_pool_per_axis": [2, 2, 2],  # one less pooling than 3d_fullres in each dimension (3d_fullres: [3, 3, 3])
        "pool_op_kernel_sizes": [[1, 1, 1], [2, 2, 2], [2, 2, 2]],  # one less [2, 2, 2]
        "conv_kernel_sizes": [[3, 3, 3], [3, 3, 3], [3, 3, 3]],  # one less [3, 3, 3]
        "next_stage": "3d_cascade_fullres"  # name of the next stage in the cascade
    },
    "3d_cascade_fullres": {  # does not need a data_identifier because we can use the data of 3d_fullres
        "inherits_from": "3d_fullres",
        "previous_stage": "3d_lowres"  # name of the previous stage
    }
}
```
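
When trimming stages by hand as above, it is easy to forget one of the lists. The following hypothetical checker (not nnU-Net code) encodes the relationships from the comments: the encoder lists have one entry per stage, the decoder list has one entry fewer, and `num_pool_per_axis` counts the stages whose stride along that axis is larger than 1 (the last rule is an assumption of this sketch).

```python
def check_stage_counts(config: dict) -> list:
    """Return inconsistencies between the per-stage lists of one configuration (empty list = consistent)."""
    n_stages = len(config["n_conv_per_stage_encoder"])
    problems = []
    for key in ("pool_op_kernel_sizes", "conv_kernel_sizes"):
        if len(config[key]) != n_stages:
            problems.append(f"{key} has {len(config[key])} entries, expected {n_stages}")
    if len(config["n_conv_per_stage_decoder"]) != n_stages - 1:
        problems.append(f"n_conv_per_stage_decoder should have {n_stages - 1} entries")
    n_axes = len(config["pool_op_kernel_sizes"][0])
    pools = [sum(1 for stage in config["pool_op_kernel_sizes"] if stage[axis] > 1) for axis in range(n_axes)]
    if pools != list(config["num_pool_per_axis"]):
        problems.append(f"num_pool_per_axis is {config['num_pool_per_axis']}, pooling kernels imply {pools}")
    return problems
```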
================================================
FILE: documentation/extending_nnunet.md
================================================
# Extending nnU-Net
We hope that the new structure of nnU-Net v2 makes it much more intuitive to modify! We cannot give an
extensive tutorial on how each and every bit of it can be modified. It is better for you to search for the position
in the repository where the thing you intend to change is implemented and start working your way through the code from
there. Setting breakpoints and debugging into nnU-Net really helps in understanding it and thus will help you make the
necessary modifications!
Here are some things you might want to read before you start:
- Editing nnU-Net configurations through plans files is really powerful now and allows you to change a lot of things regarding
preprocessing, resampling, network topology etc. Read [this](explanation_plans_files.md)!
- [Image normalization](explanation_normalization.md) and [i/o formats](dataset_format.md#supported-file-formats) are easy to extend!
- Manual data splits can be defined as described [here](manual_data_splits.md)
- You can chain arbitrary configurations together into cascades, see [this again](explanation_plans_files.md)
- Read about our support for [region-based training](region_based_training.md)
- If you intend to modify the training procedure (loss, sampling, data augmentation, lr scheduler, etc) then you need
to implement your own trainer class. Best practice is to create a class that inherits from nnUNetTrainer and
implements the necessary changes. Head over to our [trainer classes folder](../nnunetv2/training/nnUNetTrainer) for
inspiration! There will be similar trainers for what you intend to change and you can take them as a guide. nnUNetTrainer
classes are structured similarly to PyTorch Lightning trainers, which should also make things easier!
- Integrating new network architectures can be done in two ways:
  - Quick and dirty: implement a new nnUNetTrainer class and override its `build_network_architecture` function.
    Make sure your architecture is compatible with deep supervision (if not, use `nnUNetTrainerNoDeepSupervision`
    as basis!) and that it can handle the patch sizes that are thrown at it! Your architecture should NOT apply any
    nonlinearities at the end (softmax, sigmoid etc). nnU-Net does that!
  - The 'proper' (but difficult) way: Build a dynamically configurable architecture such as the `PlainConvUNet` class
    used by default. It needs to have some sort of GPU memory estimation method that can be used to evaluate whether
    certain patch sizes and topologies fit into a specified GPU memory target. Build a new `ExperimentPlanner` that can
    configure your new class and communicate with its memory budget estimation. Run `nnUNetv2_plan_and_preprocess`
    while specifying your custom `ExperimentPlanner` and a custom `plans_name`. Implement an nnUNetTrainer that can use
    the plans generated by your `ExperimentPlanner` to instantiate the network architecture. Specify your plans and
    trainer when running `nnUNetv2_train`.
    It always pays off to first read and understand the corresponding nnU-Net code and use it as a template for your
    implementation!
- Remember that multi-GPU training, region-based training, ignore label and cascaded training are now simply integrated
into one unified nnUNetTrainer class, so no separate classes are needed. Keep this in mind when implementing your own
trainer classes: either ensure support for all of these features or raise `NotImplementedError`.
[//]: # (- Read about our support for [ignore label](ignore_label.md) and [region-based training](region_based_training.md))
================================================
FILE: documentation/how_to_use_nnunet.md
================================================
## **2024-04-18 UPDATE: New residual encoder UNet presets available!**
The recommended nnU-Net presets have changed! See [here](resenc_presets.md) how to unlock them!
## How to run nnU-Net on a new dataset
Given some dataset, nnU-Net fully automatically configures an entire segmentation pipeline that matches its properties.
nnU-Net covers the entire pipeline, from preprocessing to model configuration, model training, postprocessing
all the way to ensembling. After running nnU-Net, the trained model(s) can be applied to the test cases for inference.
### Dataset Format
nnU-Net expects datasets in a structured format. This format is inspired by the data structure of
the [Medical Segmentation Decathlon](http://medicaldecathlon.com/). Please read
[this](dataset_format.md) for information on how to set up datasets to be compatible with nnU-Net.
**Since version 2 we support multiple image file formats (.nii.gz, .png, .tif, ...)! Read the dataset_format
documentation to learn more!**
**Datasets from nnU-Net v1 can be converted to V2 by running `nnUNetv2_convert_old_nnUNet_dataset INPUT_FOLDER
OUTPUT_DATASET_NAME`.** Remember that v2 calls datasets DatasetXXX_Name (not Task) where XXX is a 3-digit number.
Please provide the **path** to the old task, not just the Task name. nnU-Net V2 doesn't know where v1 tasks were!
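As a small illustration of the v2 naming scheme, the helper below builds a dataset folder name from an integer ID and a name. `dataset_folder_name` is a hypothetical helper written for this example, not part of nnU-Net; `nnUNetv2_convert_old_nnUNet_dataset` takes care of naming when you convert v1 tasks.

```python
def dataset_folder_name(dataset_id: int, name: str) -> str:
    """Build a v2-style dataset folder name such as 'Dataset002_Heart' (hypothetical helper)."""
    if not 0 <= dataset_id <= 999:
        raise ValueError("dataset_id must have at most three digits")
    # Zero-pad the id to exactly three digits (2 -> '002')
    return f"Dataset{dataset_id:03d}_{name}"


print(dataset_folder_name(2, "Heart"))  # Dataset002_Heart
```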
### Experiment planning and preprocessing
Given a new dataset, nnU-Net will extract a dataset fingerprint (a set of dataset-specific properties such as
image sizes, voxel spacings, intensity information etc). This information is used to design three U-Net configurations.
Each of these pipelines operates on its own preprocessed version of the dataset.
The easiest way to run fingerprint extraction, experiment planning and preprocessing is to use:
```bash
nnUNetv2_plan_and_preprocess -d DATASET_ID --verify_dataset_integrity
```
Where `DATASET_ID` is the dataset id (duh). We recommend `--verify_dataset_integrity` whenever it's the first time
you run this command. This will check for some of the most common error sources! Fingerprint extraction and
preprocessing show a progress bar by default. Use `--no_pbar` if you want to suppress it, for example in
non-interactive cluster environments.
You can also process several datasets at once by giving `-d 1 2 3 [...]`. If you already know which U-Net configuration
you need, you can also specify it with `-c 3d_fullres` (make sure to adapt `-np` in this case!). For more information
about all the options available to you please run `nnUNetv2_plan_and_preprocess -h`.
nnUNetv2_plan_and_preprocess will create a new subfolder in your nnUNet_preprocessed folder named after the dataset.
Once the command is completed there will be a dataset_fingerprint.json file as well as a nnUNetPlans.json file for you to look at
(in case you are interested!). There will also be subfolders containing the preprocessed data for your UNet configurations.
[Optional]
If you prefer to keep things separate, you can also use `nnUNetv2_extract_fingerprint`, `nnUNetv2_plan_experiment`
and `nnUNetv2_preprocess` (in that order). `nnUNetv2_extract_fingerprint` and `nnUNetv2_preprocess` also support
`--no_pbar`.
### Model training
#### Overview
You pick which configurations (2d, 3d_fullres, 3d_lowres, 3d_cascade_fullres) should be trained! If you have no idea
what performs best on your data, just run all of them and let nnU-Net identify the best one. It's up to you!
nnU-Net trains all configurations in a 5-fold cross-validation over the training cases. This is 1) needed so that
nnU-Net can estimate the performance of each configuration and tell you which one should be used for your
segmentation problem and 2) a natural way of obtaining a good model ensemble (average the output of these 5 models
for prediction) to boost performance.
You can influence the splits nnU-Net uses for 5-fold cross-validation (see [here](manual_data_splits.md)). If you
prefer to train a single model on all training cases, this is also possible (see below).
**Note that not all U-Net configurations are created for all datasets. In datasets with small image sizes, the U-Net
cascade (and with it the 3d_lowres configuration) is omitted because the patch size of the full resolution U-Net
already covers a large part of the input images.**
Training models is done with the `nnUNetv2_train` command. The general structure of the command is:
```bash
nnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD [additional options, see -h]
```
UNET_CONFIGURATION is a string that identifies the requested U-Net configuration (defaults: 2d, 3d_fullres, 3d_lowres,
3d_cascade_fullres). DATASET_NAME_OR_ID specifies what dataset should be trained on and FOLD specifies which fold of
the 5-fold cross-validation is trained.
nnU-Net stores a checkpoint every 50 epochs. If you need to continue a previous training, just add a `--c` to the
training command.
IMPORTANT: If you plan to use `nnUNetv2_find_best_configuration` (see below) add the `--npz` flag. This makes
nnU-Net save the softmax outputs during the final validation. They are needed for that. Exported softmax
predictions are very large and therefore can take up a lot of disk space, which is why this is not enabled by default.
If you ran initially without the `--npz` flag but now require the softmax predictions, simply rerun the validation with:
```bash
nnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD --val --npz
```
You can specify the device nnU-Net should use with `-device DEVICE`. DEVICE can only be cpu, cuda or mps. If
you have multiple GPUs, please select the GPU ID using `CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...]` (requires device to be cuda).
See `nnUNetv2_train -h` for additional options.
### 2D U-Net
For FOLD in [0, 1, 2, 3, 4], run:
```bash
nnUNetv2_train DATASET_NAME_OR_ID 2d FOLD [--npz]
```
### 3D full resolution U-Net
For FOLD in [0, 1, 2, 3, 4], run:
```bash
nnUNetv2_train DATASET_NAME_OR_ID 3d_fullres FOLD [--npz]
```
### 3D U-Net cascade
#### 3D low resolution U-Net
For FOLD in [0, 1, 2, 3, 4], run:
```bash
nnUNetv2_train DATASET_NAME_OR_ID 3d_lowres FOLD [--npz]
```
#### 3D full resolution U-Net
For FOLD in [0, 1, 2, 3, 4], run:
```bash
nnUNetv2_train DATASET_NAME_OR_ID 3d_cascade_fullres FOLD [--npz]
```
**Note that the 3D full resolution U-Net of the cascade requires the five folds of the low resolution U-Net to be
completed!**
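If you script these trainings, the cascade dependency means all `3d_lowres` folds must be trained before `3d_cascade_fullres`. The sketch below (the function `training_commands` is hypothetical, not an nnU-Net API) only builds the command strings in a dependency-respecting order for sequential execution; it does not run anything.

```python
CASCADE = "3d_cascade_fullres"


def training_commands(dataset, configurations, num_folds=5, npz=True):
    """Build nnUNetv2_train calls for every configuration and fold.

    3d_cascade_fullres is moved to the end because it needs all 3d_lowres folds.
    """
    # sorted() is stable: False (non-cascade) sorts before True (cascade)
    ordered = sorted(configurations, key=lambda c: c == CASCADE)
    commands = []
    for config in ordered:
        for fold in range(num_folds):
            cmd = f"nnUNetv2_train {dataset} {config} {fold}"
            if npz:
                cmd += " --npz"  # keep softmax outputs for nnUNetv2_find_best_configuration
            commands.append(cmd)
    return commands


for cmd in training_commands("DATASET_NAME_OR_ID", ["3d_cascade_fullres", "3d_lowres"]):
    print(cmd)
```

If you run these commands in parallel instead, you still have to wait for the five low resolution folds before starting the cascade.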
The trained models will be written to the nnUNet_results folder. Each training obtains an automatically generated
output folder name:

    nnUNet_results/DatasetXXX_MYNAME/TRAINER_CLASS_NAME__PLANS_NAME__CONFIGURATION/FOLD

For Dataset002_Heart (from the MSD), for example, this looks like this:

    nnUNet_results/
    └── Dataset002_Heart
        ├── nnUNetTrainer__nnUNetPlans__2d
        │   ├── fold_0
        │   ├── fold_1
        │   ├── fold_2
        │   ├── fold_3
        │   ├── fold_4
        │   ├── dataset.json
        │   ├── dataset_fingerprint.json
        │   └── plans.json
        └── nnUNetTrainer__nnUNetPlans__3d_fullres
            ├── fold_0
            ├── fold_1
            ├── fold_2
            ├── fold_3
            ├── fold_4
            ├── dataset.json
            ├── dataset_fingerprint.json
            └── plans.json

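To locate a run programmatically, the folder pattern above can be reproduced with `pathlib`. This is a minimal sketch: `fold_output_dir` is a hypothetical helper, and it assumes fold folders are named `fold_<FOLD>` as in the tree above.

```python
from pathlib import PurePosixPath


def fold_output_dir(results_root, dataset, trainer, plans, configuration, fold):
    """Return the expected output folder of one training run (hypothetical helper)."""
    # TRAINER_CLASS_NAME__PLANS_NAME__CONFIGURATION, joined with double underscores
    model_dir = f"{trainer}__{plans}__{configuration}"
    return PurePosixPath(results_root) / dataset / model_dir / f"fold_{fold}"


print(fold_output_dir("nnUNet_results", "Dataset002_Heart", "nnUNetTrainer", "nnUNetPlans", "2d", 0))
# nnUNet_results/Dataset002_Heart/nnUNetTrainer__nnUNetPlans__2d/fold_0
```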
Note that 3d_lowres and 3d_cascade_fullres do not exist here because this dataset did not trigger the cascade. In each
model training output folder (i.e. each fold_x folder), the following files will be created:
- debug.json: Contains a summary of blueprint and inferred parameters used for training this model as well as a
bunch of additional stuff. Not easy to read, but very useful for debugging ;-)
- checkpoint_best.pth: checkpoint files of the best model identified during training. Not used right now unless you
explicitly tell nnU-Net to use it.
- checkpoint_final.pth: checkpoint file of the final model (after training has ended). This is what is used for both
validation and inference.
- network_architecture.pdf (only if hiddenlayer is installed!): a pdf document with a figure of the network architecture in it.
- progress.png: Shows losses, pseudo dice, learning rate and epoch times over the course of the training. At the top is
a plot of the training (blue) and validation (red) loss during training. Also shows an approximation of
the dice (green) as well as a moving average of it (dotted green line). This approximation is the average Dice score
of the foreground classes. **It needs to be taken with a big (!)
grain of salt** because it is computed on randomly drawn patches from the validation
data at the end of each epoch, and the aggregation of TP, FP and FN for the Dice computation treats the patches as if
they all originate from the same volume ('global Dice'; we do not compute a Dice for each validation case and then
average over all cases but pretend that there is only one validation case from which we sample patches). The reason for
this is that the 'global Dice' is easy to compute during training and is still quite useful to evaluate whether a model
is training at all or not. A proper validation takes way too long to be done each epoch. It is run at the end of the training.
- validation: in this folder are the predicted validation cases after the training has finished. The summary.json file in here
contains the validation metrics (a mean over all cases is provided at the start of the file). If `--npz` was set then
the compressed softmax outputs (saved as .npz files) are in here as well.
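To see why the pseudo dice in progress.png can differ a lot from the final validation metrics, compare the 'global Dice' (pool TP, FP and FN over all patches, then compute one Dice) with a mean of per-case Dice scores. The counts below are made up for illustration; the point is that a large, well-segmented structure dominates the pooled counts.

```python
def dice(tp, fp, fn):
    """Dice = 2TP / (2TP + FP + FN); NaN if the class is absent in both."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom > 0 else float("nan")


# Hypothetical (TP, FP, FN) counts of one foreground class in two validation patches
patches = [(90, 10, 0), (1, 0, 9)]

# 'Global Dice': pool the counts of all patches, then compute a single Dice
tp, fp, fn = (sum(counts) for counts in zip(*patches))
global_dice = dice(tp, fp, fn)

# Per-case Dice: one Dice per patch/case, then the mean over cases
mean_dice = sum(dice(*counts) for counts in patches) / len(patches)

print(f"global Dice: {global_dice:.3f}, mean per-case Dice: {mean_dice:.3f}")
# global Dice: 0.905, mean per-case Dice: 0.565
```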
During training it is often useful to watch the progress. We therefore recommend that you have a look at the generated
progress.png when running the first training. It will be updated after each epoch.
Training times largely depend on the GPU. The smallest GPU we recommend for training is the Nvidia RTX 2080ti. With
that all network trainings take less than 2 days. Refer to our [benchmarks](benchmarking.md) to see if your system is
performing as expected.
### Using multiple GPUs for training
If multiple GPUs are at your disposal, the best way of using them is to train multiple nnU-Net trainings at once, one
on each GPU. This is because data parallelism never scales perfectly linearly, especially not with small networks such
as the ones used by nnU-Net.
Example:
```bash
CUDA_VISIBLE_DEVICES=0 nnUNetv2_train DATASET_NAME_OR_ID 2d 0 [--npz] & # train on GPU 0
CUDA_VISIBLE_DEVICES=1 nnUNetv2_train DATASET_NAME_OR_ID 2d 1 [--npz] & # train on GPU 1
CUDA_VISIBLE_DEVICES=2 nnUNetv2_train DATASET_NAME_OR_ID 2d 2 [--npz] & # train on GPU 2
CUDA_VISIBLE_DEVICES=3 nnUNetv2_train DATASET_NAME_OR_ID 2d 3 [--npz] & # train on GPU 3
CUDA_VISIBLE_DEVICES=4 nnUNetv2_train DATASET_NAME_OR_ID 2d 4 [--npz] & # train on GPU 4
...
wait
```
**Important: The first time a training is run, nnU-Net will extract the preprocessed data into uncompressed numpy
arrays for speed reasons! This operation must be completed before starting more than one training of the same
configuration! Wait to start subsequent folds until the first training is using the GPU! Depending on the
dataset size and your system this should only take a couple of minutes at most.**
If you insist on running DDP multi-GPU training, we got you covered:
`nnUNetv2_train DATASET_NAME_OR_ID 2d 0 [--npz] -num_gpus X`
Again, note that this will be slower than running separate trainings on separate GPUs. DDP only makes sense if you have
manually interfered with the nnU-Net configuration and are training larger models with larger patch and/or batch sizes!
Important when using `-num_gpus`:
1) If you train using, say, 2 GPUs but have more GPUs in the system you need to specify which GPUs should be used via
CUDA_VISIBLE_DEVICES=0,1 (or whatever your ids are).
2) You cannot specify more GPUs than you have samples in your minibatches. If the batch size is 2, 2 GPUs is the maximum!
3) Make sure your batch size is divisible by the number of GPUs you use, or you will not make good use of your hardware.
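Rules 2 and 3 can be checked before launching a DDP run. The sketch below is illustrative only: `per_gpu_batch_sizes` is a hypothetical helper, and it assumes the global batch is split as evenly as possible, which is not necessarily how nnU-Net distributes samples internally.

```python
def per_gpu_batch_sizes(batch_size, num_gpus):
    """Split a global batch across GPUs following the rules listed above (illustrative)."""
    if num_gpus < 1:
        raise ValueError("need at least one GPU")
    if num_gpus > batch_size:
        # Rule 2: every GPU needs at least one sample
        raise ValueError(f"{num_gpus} GPUs but only {batch_size} samples per batch")
    base, rest = divmod(batch_size, num_gpus)
    # Rule 3: if batch_size % num_gpus != 0, some GPUs get one sample more than others
    return [base + 1 if i < rest else base for i in range(num_gpus)]


print(per_gpu_batch_sizes(2, 2))  # [1, 1]
print(per_gpu_batch_sizes(5, 2))  # [3, 2] -> unbalanced
```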
In contrast to the old nnU-Net, DDP is now completely hassle free. Enjoy!
### Automatically determine the best configuration
Once the desired configurations have been trained (full cross-validation), you can tell nnU-Net to automatically identify
the best combination for you:
```commandline
nnUNetv2_find_best_configuration DATASET_NAME_OR_ID -c CONFIGURATIONS
```
`CONFIGURATIONS` hereby is the list of configurations you would like to explore. Per default, ensembling is enabled
meaning that nnU-Net will generate all possible combinations of ensembles (2 configurations per ensemble). This requires
the .npz files containing the predicted probabilities of the validation set to be present (use `nnUNetv2_train` with
`--npz` flag, see above). You can disable ensembling by setting the `--disable_ensembling` flag.
See `nnUNetv2_find_best_configuration -h` for more options.
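To get a feeling for how many candidates are compared, the ensembles of two configurations are simply all unordered pairs. This sketch assumes that the single configurations are evaluated as candidates as well; it only enumerates candidates and does not reproduce the evaluation itself.

```python
from itertools import combinations

configurations = ["2d", "3d_fullres", "3d_lowres", "3d_cascade_fullres"]

# Each candidate ensemble pairs exactly two configurations
ensembles = list(combinations(configurations, 2))
candidates = [(c,) for c in configurations] + ensembles

print(len(ensembles), len(candidates))  # 6 10
```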
nnUNetv2_find_best_configuration will also automatically determine the postprocessing that should be used.
Postprocessing in nnU-Net only considers the removal of all but the largest component in the prediction (once for
foreground vs background and once for each label/region).
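The core operation, keeping only the largest connected component, can be sketched in pure Python on a 2D binary mask. This is an illustration, not nnU-Net's implementation: nnU-Net operates on (typically 3D) arrays, and the 4-connectivity used here is an assumption made for simplicity.

```python
from collections import deque


def keep_largest_component(mask):
    """Zero out all but the largest 4-connected foreground component of a 2D 0/1 mask."""
    h, w = len(mask), len(mask[0])
    seen = [[False] * w for _ in range(h)]
    largest = []
    for y in range(h):
        for x in range(w):
            if mask[y][x] and not seen[y][x]:
                # Breadth-first search collects one connected component
                component, queue = [], deque([(y, x)])
                seen[y][x] = True
                while queue:
                    cy, cx = queue.popleft()
                    component.append((cy, cx))
                    for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                        if 0 <= ny < h and 0 <= nx < w and mask[ny][nx] and not seen[ny][nx]:
                            seen[ny][nx] = True
                            queue.append((ny, nx))
                if len(component) > len(largest):
                    largest = component
    out = [[0] * w for _ in range(h)]
    for y, x in largest:
        out[y][x] = 1
    return out


print(keep_largest_component([[1, 1, 0, 0],
                              [1, 0, 0, 1],
                              [0, 0, 0, 1]]))
# [[1, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]]
```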
Once completed, the command will print to your console exactly what commands you need to run to make predictions. It
will also create two files in the `nnUNet_results/DATASET_NAME` folder for you to inspect:
- `inference_instructions.txt` again contains the exact commands you need to use for predictions
- `inference_information.json` can be inspected to see the performance of all configurations and ensembles, as well
as the effect of the postprocessing plus some debug information.
### Run inference
Remember that the data located in the input folder must have the same file endings as the dataset you trained the model on
and must adhere to the nnU-Net naming scheme for image files (see [dataset format](dataset_format.md) and
[inference data format](dataset_format_inference.md)!)
`nnUNetv2_find_best_configuration` (see above) will print a string to the terminal with the inference commands you need to use.
The easiest way to run inference is to simply use these commands.
If you wish to manually specify the configuration(s) used for inference, use the following commands:
#### Run prediction
For each of the desired configurations, run:
```
nnUNetv2_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -d DATASET_NAME_OR_ID -c CONFIGURATION --save_probabilities
```
Only specify `--save_probabilities` if you intend to use ensembling. `--save_probabilities` will make the command save the predicted
probabilities alongside the predicted segmentation masks, which requires a lot of disk space.
Please select a separate `OUTPUT_FOLDER` for each configuration!
Note that per default, inference will be done with all 5 folds from the cross-validation as an ensemble. We very
strongly recommend you use all 5 folds. Thus, all 5 folds must have been trained prior to running inference.
If you wish to make predictions with a single model, train the `all` fold and specify it in `nnUNetv2_predict`
with `-f all`
#### Ensembling multiple configurations
If you wish to ensemble multiple predictions (typically from different configurations), you can do so with the following command:
```bash
nnUNetv2_ensemble -i FOLDER1 FOLDER2 ... -o OUTPUT_FOLDER -np NUM_PROCESSES
```
You can specify an arbitrary number of folders, but remember that each folder needs to contain npz files that were
generated by `nnUNetv2_predict`. Again, `nnUNetv2_ensemble -h` will tell you more about additional options.
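Conceptually, ensembling averages the predicted class probabilities of all members and then picks the most probable class per voxel. The pure-Python sketch below works on nested lists instead of the .npz arrays nnU-Net actually reads; `ensemble` is a hypothetical helper for illustration.

```python
def ensemble(prob_maps):
    """Average class probabilities of several predictions, then take the argmax per voxel.

    prob_maps[m][v][c] is the probability of class c at voxel v predicted by member m.
    """
    num_members = len(prob_maps)
    labels = []
    for voxel_probs in zip(*prob_maps):  # all members' predictions for one voxel
        mean = [sum(p) / num_members for p in zip(*voxel_probs)]
        labels.append(max(range(len(mean)), key=mean.__getitem__))
    return labels


member_a = [[0.9, 0.1], [0.4, 0.6]]
member_b = [[0.2, 0.8], [0.45, 0.55]]
print(ensemble([member_a, member_b]))  # [0, 1]
```

Note how member_b alone would predict class 1 for the first voxel, but the averaged probabilities favor class 0.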
#### Apply postprocessing
Finally, apply the previously determined postprocessing to the (ensembled) predictions:
```commandline
nnUNetv2_apply_postprocessing -i FOLDER_WITH_PREDICTIONS -o OUTPUT_FOLDER --pp_pkl_file POSTPROCESSING_FILE -plans_json PLANS_FILE -dataset_json DATASET_JSON_FILE
```
`nnUNetv2_find_best_configuration` (or its generated `inference_instructions.txt` file) will tell you where to find
the postprocessing file. If not, you can just look for it in your results folder (it's creatively named
`postprocessing.pkl`). If your source folder is from an ensemble, you also need to specify a `-plans_json` file and
a `-dataset_json` file that should be used (for single configuration predictions these are automatically copied
from the respective training). You can pick these files from any of the ensemble members.
## How to run inference with pretrained models
See [here](run_inference_with_pretrained_models.md)
## How to Deploy and Run Inference with YOUR Pretrained Models
To facilitate the use of pretrained models on a different computer for inference purposes, follow these streamlined steps:
1. Exporting the Model: Use the `nnUNetv2_export_model_to_zip` command to package your trained model into a .zip file. This file will contain all necessary model files.
2. Transferring the Model: Transfer the .zip file to the target computer where inference will be performed.
3. Importing the Model: On the target computer, use the `nnUNetv2_install_pretrained_model_from_zip` command to load the pretrained model from the .zip file.
Please note that both computers must have nnU-Net installed along with all its dependencies to ensure compatibility and functionality of the model.
[//]: # (## Examples)
[//]: # ()
[//]: # (To get you started we compiled two simple to follow examples:)
[//]: # (- run a training with the 3d full resolution U-Net on the Hippocampus dataset. See [here](documentation/training_example_Hippocampus.md).)
[//]: # (- run inference with nnU-Net's pretrained models on the Prostate dataset. See [here](documentation/inference_example_Prostate.md).)
[//]: # ()
[//]: # (Usability not good enough? Let us know!)
================================================
FILE: documentation/ignore_label.md
================================================
# Ignore Label
The _ignore label_ can be used to mark regions that should be ignored by nnU-Net. This can be used to
learn from images where only sparse annotations are available, for example in the form of scribbles or a limited
amount of annotated slices. Internally, this is accomplished by using partial losses, i.e. losses that are only
computed on annotated pixels while ignoring the rest. Take a look at our
[`DC_and_BCE_loss` loss](../nnunetv2/training/loss/compound_losses.py) to see how this is done.
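The idea of a partial loss can be illustrated with a masked binary cross-entropy in pure Python: pixels carrying the ignore label are simply excluded from the average. This is a simplified sketch (nnU-Net's actual losses combine Dice and cross-entropy terms on tensors); it assumes a binary problem where background is 0, foreground is 1 and the ignore label is therefore 2.

```python
import math

IGNORE = 2  # background = 0, foreground = 1, ignore = highest label (2)


def partial_bce(probs, targets, eps=1e-7):
    """Mean binary cross-entropy over annotated pixels; ignore-label pixels are skipped."""
    losses = []
    for p, t in zip(probs, targets):
        if t == IGNORE:
            continue  # unannotated pixel: excluded from the loss
        p = min(max(p, eps), 1 - eps)  # avoid log(0)
        losses.append(-(t * math.log(p) + (1 - t) * math.log(1 - p)))
    return sum(losses) / len(losses) if losses else 0.0


# Foreground probabilities for three pixels; the middle pixel is unannotated
print(partial_bce([0.9, 0.5, 0.2], [1, IGNORE, 0]))
```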
During inference (validation and prediction), nnU-Net will always predict dense segmentations. Metric computation in
validation is of course only done on annotated pixels.
Sparse annotations can be used to train a model for application to new, unseen images or to autocomplete the
provided training cases given the sparse labels.
(See our [paper](https://arxiv.org/abs/2403.12834) for more information)
Typical use-cases for the ignore label are:
- Save annotation time through sparse annotation schemes
- Annotation of all or a subset of slices with scribbles (Scribble Supervision)
- Dense annotation of a subset of slices
- Dense annotation of chosen patches/cubes within an image
- Coarsely masking out faulty segmentations in the reference segmentations
- Masking areas for other reasons
If you are using nnU-Net's ignore label, please cite the following paper in addition to the original nnU-Net paper:
```
Gotkowski, K., Lüth, C., Jäger, P. F., Ziegler, S., Krämer, L., Denner, S., Xiao, S., Disch, N., H., K., & Isensee, F.
(2024). Embarrassingly Simple Scribble Supervision for 3D Medical Segmentation. ArXiv. /abs/2403.12834
```
## Usecases
### Scribble Supervision
Scribbles are free-form drawings to coarsely annotate an image. As we have demonstrated in our recent [paper](https://arxiv.org/abs/2403.12834), nnU-Net's partial loss implementation enables state-of-the-art learning from partially annotated data and even surpasses many purpose-built methods for learning from scribbles. As a starting point, an interior and a border scribble should be generated for each image slice and each class (including background):
- Interior Scribble: A scribble placed randomly within the class interior of a class instance
- Border Scribble: A scribble roughly delineating a small part of the class border of a class instance
An example of such scribble annotations is depicted in Figure 1 and an animation in Animation 1.
Depending on the availability of data and their variability, it is also possible to only annotate a subset of selected slices.
Figure 1: Examples of segmentation types with (A) depicting a dense segmentation and (B) a scribble segmentation.
Animation 1: Depiction of a dense segmentation and a scribble annotation. Background scribbles have been excluded for better visualization.
### Dense annotation of a subset of slices
Another form of sparse annotation is the dense annotation of a subset of slices. These slices should be selected by the user either randomly, based on visual class variation between slices or in an active learning setting. An example with only 10% of slices annotated is depicted in Figure 2.
Figure 2: Examples of a dense annotation of a subset of slices. The ignored areas are shown in red.
## Usage within nnU-Net
Usage of the ignore label in nnU-Net is straightforward and only requires the definition of an _ignore_ label in the _dataset.json_.
This ignore label MUST be the highest integer label value in the segmentation. For example, given background and two foreground classes, the ignore label must have the integer value 3. The ignore label must be named _ignore_ in the _dataset.json_. Taking the BraTS dataset as an example, the labels dict of the _dataset.json_ must look like this:
```python
...
"labels": {
    "background": 0,
    "edema": 1,
    "non_enhancing_and_necrosis": 2,
    "enhancing_tumor": 3,
    "ignore": 4
},
...
```
Of course, the ignore label is compatible with [region-based training](region_based_training.md):
```python
...
"labels": {
    "background": 0,
    "whole_tumor": (1, 2, 3),
    "tumor_core": (2, 3),
    "enhancing_tumor": 3,  # or (3, )
    "ignore": 4
},
"regions_class_order": (1, 2, 3),  # don't declare ignore label here! It is not predicted
...
```
Then use the dataset as you would any other.
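Before training, it can be worth checking the labels dict against the rules above (an entry named `ignore` whose value is higher than every other label, with region tuples flattened). `check_ignore_label` is a hypothetical helper written for this example; nnU-Net performs its own checks, which may differ.

```python
def check_ignore_label(labels):
    """Check the ignore-label rules described above and return the ignore value."""
    if "ignore" not in labels:
        raise ValueError("labels must contain an entry named 'ignore'")
    ignore = labels["ignore"]
    if not isinstance(ignore, int):
        raise ValueError("the ignore label must be a single integer")
    # Collect every integer used by the other entries (regions are tuples/lists)
    used = set()
    for name, value in labels.items():
        if name == "ignore":
            continue
        used.update(value if isinstance(value, (tuple, list)) else (value,))
    if used and ignore <= max(used):
        raise ValueError(f"ignore label {ignore} must be higher than all other labels {sorted(used)}")
    return ignore


print(check_ignore_label({"background": 0, "whole_tumor": (1, 2, 3), "tumor_core": (2, 3),
                          "enhancing_tumor": 3, "ignore": 4}))  # 4
```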
Remember that nnU-Net runs a cross-validation. Thus, it will also evaluate on your partially annotated data. This
will of course work! If you wish to compare different sparse annotation strategies (through simulations for example),
we recommend evaluating on densely annotated images by running inference and then using `nnUNetv2_evaluate_folder` or
`nnUNetv2_evaluate_simple`.
================================================
FILE: documentation/installation_instructions.md
================================================
# System requirements
## Operating System
nnU-Net has been tested on Linux (Ubuntu 18.04, 20.04, 22.04; centOS, RHEL), Windows and MacOS! It should work out of the box!
## Hardware requirements
We support GPU (recommended), CPU and Apple M1/M2 as devices (currently Apple mps does not implement 3D
convolutions, so you might have to use the CPU on those devices).
### Hardware requirements for Training
We recommend you use a GPU for training as this will take a really long time on CPU or MPS (Apple M1/M2).
For training a GPU with at least 10 GB (popular non-datacenter options are the RTX 2080ti, RTX 3080/3090 or RTX 4080/4090) is
required. We also recommend a strong CPU to go along with the GPU. 6 cores (12 threads)
are the bare minimum! CPU requirements are mostly related to data augmentation and scale with the number of
input channels and target structures. Plus, the faster the GPU, the better the CPU should be!
### Hardware Requirements for inference
Again we recommend a GPU to make predictions as this will be substantially faster than the other options. However,
inference times are typically still manageable on CPU and MPS (Apple M1/M2). If using a GPU, it should have at least
4 GB of available (unused) VRAM.
### Example hardware configurations
Example workstation configurations for training:
- CPU: Ryzen 5800X - 5900X or 7900X would be even better! We have not yet tested Intel Alder/Raptor lake but they will likely work as well.
- GPU: RTX 3090 or RTX 4090
- RAM: 64GB
- Storage: SSD (M.2 PCIe Gen 3 or better!)
Example Server configuration for training:
- CPU: 2x AMD EPYC 7763 for a total of 128C/256T. 16C/GPU are highly recommended for fast GPUs such as the A100!
- GPU: 8xA100 PCIe (price/performance superior to SXM variant + they use less power)
- RAM: 1 TB
- Storage: local SSD storage (PCIe Gen 3 or better) or ultra fast network storage
(nnU-Net by default uses one GPU per training. The server configuration can run up to 8 model trainings simultaneously)
### Setting the correct number of Workers for data augmentation (training only)
Note that you will need to manually set the number of processes nnU-Net uses for data augmentation according to your
CPU/GPU ratio. For the server above (256 threads for 8 GPUs), a good value would be 24-30. You can do this by
setting the `nnUNet_n_proc_DA` environment variable (`export nnUNet_n_proc_DA=XX`).
Recommended values (assuming a recent CPU with good IPC) are 10-12 for RTX 2080 ti, 12 for a RTX 3090, 16-18 for
RTX 4090, 28-32 for A100. Optimal values may vary depending on the number of input channels/modalities and number of classes.
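If you launch trainings from a Python script, the environment variable can be set per process. The mapping below takes the lower end of each range recommended above; the GPU name keys and the helper `env_with_da_workers` are hypothetical, and you should tune the values for your CPU, number of input channels and number of classes.

```python
import os

# Lower end of the ranges recommended above; tune for your CPU, channels and classes
RECOMMENDED_N_PROC_DA = {"RTX 2080 Ti": 10, "RTX 3090": 12, "RTX 4090": 16, "A100": 28}


def env_with_da_workers(gpu_name, base_env=None):
    """Copy the environment and set nnUNet_n_proc_DA for the given GPU type."""
    env = dict(os.environ if base_env is None else base_env)
    env["nnUNet_n_proc_DA"] = str(RECOMMENDED_N_PROC_DA[gpu_name])
    return env


# Usage (requires `import subprocess` and an installed nnU-Net):
# subprocess.run(["nnUNetv2_train", "DATASET_NAME_OR_ID", "3d_fullres", "0"],
#                env=env_with_da_workers("RTX 3090"), check=True)
```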
# Installation instructions
We strongly recommend that you install nnU-Net in a virtual environment! Pip or anaconda are both fine. If you choose to
compile PyTorch from source (see below), you will need to use conda instead of pip.
Use a recent version of Python! 3.9 or newer is guaranteed to work!
**nnU-Net v2 can coexist with nnU-Net v1! Both can be installed at the same time.**
1) Install [PyTorch](https://pytorch.org/get-started/locally/) as described on their website (conda/pip). Please
install the latest version with support for your hardware (cuda, mps, cpu).
**DO NOT JUST `pip install nnunetv2` WITHOUT PROPERLY INSTALLING PYTORCH FIRST**. For maximum speed, consider
[compiling pytorch yourself](https://github.com/pytorch/pytorch#from-source) (experienced users only!).
2) Install nnU-Net depending on your use case:
1) For use as **standardized baseline**, **out-of-the-box segmentation algorithm** or for running
**inference with pretrained models**:
```pip install nnunetv2```
2) For use as integrative **framework** (this will create a copy of the nnU-Net code on your computer so that you
can modify it as needed):
```bash
git clone https://github.com/MIC-DKFZ/nnUNet.git
cd nnUNet
pip install -e .
```
3) nnU-Net needs to know where you intend to save raw data, preprocessed data and trained models. For this you need to
set a few environment variables. Please follow the instructions [here](setting_up_paths.md).
4) (OPTIONAL) Install [hiddenlayer](https://github.com/waleedka/hiddenlayer). hiddenlayer enables nnU-Net to generate
plots of its network topologies (see [Model training](how_to_use_nnunet.md#model-training)).
To install hiddenlayer, run the following command:
```bash
pip install --upgrade git+https://github.com/FabianIsensee/hiddenlayer.git
```
Installing nnU-Net will add several new commands to your terminal. These commands are used to run the entire nnU-Net
pipeline. You can execute them from any location on your system. All nnU-Net commands have the prefix `nnUNetv2_` for
easy identification.
Note that these commands simply execute python scripts. If you installed nnU-Net in a virtual environment, this
environment must be activated when executing the commands. You can see what scripts/functions are executed by
checking the project.scripts in the [pyproject.toml](../pyproject.toml) file.
All nnU-Net commands have a `-h` option which gives information on how to use them.
================================================
FILE: documentation/manual_data_splits.md
================================================
# How to generate custom splits in nnU-Net
Sometimes, the default 5-fold cross-validation split by nnU-Net does not fit a project. Maybe you want to run 3-fold
cross-validation instead? Or maybe your training cases cannot be split randomly and require careful stratification.
Fear not, for nnU-Net has got you covered (it really can do anything <3).
The splits nnU-Net uses are generated in the `do_split` function of nnUNetTrainer. This function will first look for
existing splits, stored as a file, and if no split exists it will create one. So if you wish to influence the split,
manually creating a split file that will then be recognized and used is the way to go!
The split file is located in the `nnUNet_preprocessed/DATASETXXX_NAME` folder. So it is best practice to first
populate this folder by running `nnUNetv2_plan_and_preprocess`.
Splits are stored as a .json file. They are a simple python list. The length of that list is the number of splits it
contains (so it's 5 in the default nnU-Net). Each list entry is a dictionary with keys 'train' and 'val'. Values are
again simply lists with the case identifiers in each set. To illustrate this, I am just messing with the Dataset002
file as an example:
```commandline
In [1]: from batchgenerators.utilities.file_and_folder_operations import load_json
In [2]: splits = load_json('splits_final.json')
In [3]: len(splits)
Out[3]: 5
In [4]: splits[0].keys()
Out[4]: dict_keys(['train', 'val'])
In [5]: len(splits[0]['train'])
Out[5]: 16
In [6]: len(splits[0]['val'])
Out[6]: 4
In [7]: print(splits[0])
{'train': ['la_003', 'la_004', 'la_005', 'la_009', 'la_010', 'la_011', 'la_014', 'la_017', 'la_018', 'la_019', 'la_020', 'la_022', 'la_023', 'la_026', 'la_029', 'la_030'],
'val': ['la_007', 'la_016', 'la_021', 'la_024']}
```
If you are still not sure what splits are supposed to look like, simply download some reference dataset from the
[Medical Decathlon](http://medicaldecathlon.com/), start some training (to generate the splits) and manually inspect
the .json file with your text editor of choice!
In order to generate your custom splits, all you need to do is reproduce the data structure explained above and save it as
`splits_final.json` in the `nnUNet_preprocessed/DATASETXXX_NAME` folder. Then use `nnUNetv2_train` etc. as usual.
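As a starting point, here is a minimal sketch that creates a random 3-fold split with the structure shown above and writes it as `splits_final.json`. The function names and the seed are arbitrary choices for this example; for stratified splits you would replace the random assignment with your own grouping logic.

```python
import json
import random
from pathlib import Path


def make_splits(case_ids, num_folds=3, seed=12345):
    """Create a list of {'train': [...], 'val': [...]} dicts, one per fold."""
    ids = sorted(case_ids)
    random.Random(seed).shuffle(ids)
    # Deal shuffled cases round-robin into folds so fold sizes differ by at most one
    folds = [ids[i::num_folds] for i in range(num_folds)]
    splits = []
    for k in range(num_folds):
        val = sorted(folds[k])
        train = sorted(c for j, f in enumerate(folds) if j != k for c in f)
        splits.append({"train": train, "val": val})
    return splits


def save_splits(splits, preprocessed_dataset_dir):
    """Write the splits to nnUNet_preprocessed/DATASETXXX_NAME/splits_final.json."""
    out = Path(preprocessed_dataset_dir) / "splits_final.json"
    out.write_text(json.dumps(splits, indent=4))
    return out


# Example: save_splits(make_splits(case_ids), "/path/to/nnUNet_preprocessed/Dataset002_Heart")
```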
================================================
FILE: documentation/pretraining_and_finetuning.md
================================================
# Pretraining with nnU-Net
## Intro
So far nnU-Net only supports supervised pre-training, meaning that you train a regular nnU-Net on some pretraining dataset
and then use the final network weights as initialization for your target dataset.
As a reminder, many training hyperparameters such as patch size and network topology differ between datasets as a
result of the automated dataset analysis and experiment planning nnU-Net is known for. So, out of the box, it is not
possible to simply take the network weights from some dataset and then reuse them for another.
Consequently, the plans need to be aligned between the two tasks. In this README we show how this can be achieved and
how the resulting weights can then be used for initialization.
### Terminology
Throughout this README we use the following terminology:
- `pretraining dataset` is the dataset you intend to run the pretraining on
- `finetuning dataset` is the dataset you are interested in; the one you wish to fine-tune on
## Training on the pretraining dataset
In order to obtain matching network topologies we need to transfer the plans from one dataset to another. Since we are
only interested in the finetuning dataset, we first need to run experiment planning (and preprocessing) for it:
```bash
nnUNetv2_plan_and_preprocess -d FINETUNING_DATASET
```
Then we need to extract the dataset fingerprint of the pretraining dataset, if not yet available:
```bash
nnUNetv2_extract_fingerprint -d PRETRAINING_DATASET
```
Now we can take the plans from the finetuning dataset and transfer them to the pretraining dataset:
```bash
nnUNetv2_move_plans_between_datasets -s FINETUNING_DATASET -t PRETRAINING_DATASET -sp FINETUNING_PLANS_IDENTIFIER -tp PRETRAINING_PLANS_IDENTIFIER
```
`FINETUNING_PLANS_IDENTIFIER` is most likely nnUNetPlans unless you changed the experiment planner in
nnUNetv2_plan_and_preprocess. For `PRETRAINING_PLANS_IDENTIFIER` we recommend you set something custom in order to not
overwrite default plans.
Note that EVERYTHING is transferred between the datasets. Not just the network topology, batch size and patch size but
also the normalization scheme! Therefore, a transfer between datasets that use different normalization schemes may not
work well (but it could, depending on the schemes!).
Note on CT normalization: Yes, also the clip values, mean and std are transferred!
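To see why the transferred normalization values matter, here is a NumPy sketch of a clip-then-z-score scheme like the one nnU-Net applies to CT images. The clip bounds, mean and std are fixed numbers stored in the plans (in nnU-Net they are derived from foreground intensities in the dataset fingerprint). After the transfer, the pretraining images are normalized with the finetuning dataset's numbers. `ct_normalize` is a hypothetical helper and the values in the test are made up.

```python
import numpy as np


def ct_normalize(image, lower, upper, mean, std):
    """Clip intensities to [lower, upper], then z-score with fixed dataset statistics."""
    clipped = np.clip(image.astype(np.float32), lower, upper)
    return (clipped - mean) / max(std, 1e-8)  # guard against a zero std
```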
Now you can run the preprocessing on the pretraining dataset:
```bash
nnUNetv2_preprocess -d PRETRAINING_DATASET -plans_name PRETRAINING_PLANS_IDENTIFIER
```
And run the training as usual:
```bash
nnUNetv2_train PRETRAINING_DATASET CONFIG all -p PRETRAINING_PLANS_IDENTIFIER
```
Note how we use the 'all' fold to train on all available data. For pretraining it does not make sense to split the data.
## Using pretrained weights
Once pretraining is completed (or you obtain compatible weights by other means) you can use them to initialize your model:
```bash
nnUNetv2_train FINETUNING_DATASET CONFIG FOLD -pretrained_weights PATH_TO_CHECKPOINT
```
Specify the checkpoint in PATH_TO_CHECKPOINT.
When loading pretrained weights, all layers except the segmentation layers will be used!
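The idea behind this can be sketched as follows. This is a hypothetical helper, not nnU-Net's loader, and it assumes that segmentation-head parameters can be recognized by a `.seg_layers.` substring in their state-dict keys (check your network's actual key names). Every other parameter must exist in the checkpoint with an identical shape, which is why the plans must match between the two datasets.

```python
import numpy as np


def merge_pretrained_weights(model_state, pretrained_state, skip_substrings=('.seg_layers.',)):
    """Copy pretrained values into model_state, except for segmentation-head parameters.

    Both arguments map parameter names to array-likes with a .shape attribute
    (e.g. a PyTorch state_dict or NumPy arrays).
    """
    def skipped(key):
        return any(s in key for s in skip_substrings)

    merged = dict(model_state)
    for key, value in model_state.items():
        if skipped(key):
            continue  # keep the freshly initialized segmentation head
        if key not in pretrained_state:
            raise KeyError(f'{key} missing from pretrained weights')
        if tuple(pretrained_state[key].shape) != tuple(value.shape):
            raise ValueError(f'shape mismatch for {key}: '
                             f'{tuple(pretrained_state[key].shape)} vs {tuple(value.shape)}')
        merged[key] = pretrained_state[key]
    return merged
```

Skipping the segmentation head means the number of output classes may differ between the pretraining and finetuning datasets.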
So far there are no specific nnU-Net trainers for fine-tuning, so the current recommendation is to just use
nnUNetTrainer. You can, however, easily write your own trainers with a learning rate warm-up, fine-tuning of segmentation
heads or shorter training time.
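As an illustration of the warm-up idea, here is a learning-rate function a custom fine-tuning trainer could use. It combines a hypothetical linear warm-up (`warmup_epochs=50` is an arbitrary choice) with a polynomial decay (exponent 0.9, initial LR 1e-2), similar to the poly schedule nnUNetTrainer uses. No such trainer ships with nnU-Net; wiring this into a trainer subclass is up to you.

```python
def warmup_poly_lr(epoch, max_epochs=1000, initial_lr=1e-2, warmup_epochs=50, exponent=0.9):
    """Linear warm-up for `warmup_epochs`, then polynomial decay towards 0 at `max_epochs`."""
    if epoch < warmup_epochs:
        # ramp up linearly from initial_lr / warmup_epochs to initial_lr
        return initial_lr * (epoch + 1) / warmup_epochs
    # polynomial decay over the remaining epochs
    progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    return initial_lr * (1 - progress) ** exponent
```

A short warm-up keeps large early gradients from the randomly initialized segmentation head from disrupting the pretrained encoder weights.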
================================================
FILE: documentation/region_based_training.md
================================================
# Region-based training
## What is this about?
In some segmentation tasks, most prominently the
[Brain Tumor Segmentation Challenge](http://braintumorsegmentation.org/), the target areas (based on which the metric
will be computed) are different from the labels provided in the training data. This is the case because for some
clinical applications, it is more relevant to detect the whole tumor, tumor core and enhancing tumor instead of the
individual labels (edema, necrosis and non-enhancing tumor, enhancing tumor).
The figure shows an example BraTS case along with label-based representation of the task (top) and region-based
representation (bottom). The challenge evaluation is done on the regions. As we have shown in our
[BraTS 2018 contribution](https://arxiv.org/abs/1809.10483), directly optimizing those
overlapping areas instead of the individual labels yields better-scoring models!
## What can nnU-Net do?
nnU-Net's region-based training allows you to learn areas that are constructed by merging individual labels. For
some segmentation tasks this provides a benefit, as this shifts the importance allocated to different labels during training.
Most prominently, this feature can be used to represent **hierarchical classes**, for example when organs +
substructures are to be segmented. Imagine a liver segmentation problem, where vessels and tumors are also to be
segmented. The first target region could thus be the entire liver (including the substructures), while the remaining
targets are the individual substructures.
Important: nnU-Net still requires integer label maps as input and will produce integer label maps as output!
Region-based training can be used to learn overlapping labels, but there must be a way to model these overlaps
for nnU-Net to work (see below how this is done).
## How do you use it?
When declaring the labels in the `dataset.json` file, BraTS would typically look like this:
```python
...
"labels": {
"background": 0,
"edema": 1,
"non_enhancing_and_necrosis": 2,
"enhancing_tumor": 3
},
...
```
(we use different int values than the challenge because nnU-Net needs consecutive integers!)
This representation corresponds to the upper row in the figure above.
For region-based training, the labels need to be changed to the following:
```python
...
"labels": {
"background": 0,
"whole_tumor": [1, 2, 3],
"tumor_core": [2, 3],
"enhancing_tumor": 3 # or [3]
},
"regions_class_order": [1, 2, 3],
...
```
This corresponds to the bottom row in the figure above. Note how an additional entry in the dataset.json is
required: `regions_class_order`. This tells nnU-Net how to convert the region representations back to an integer map.
It essentially just tells nnU-Net what labels to place for which region in what order. The length of the
list here needs to be the same as the number of regions (excl background). Each element in the list corresponds
to the label that is placed instead of the region into the final segmentation. Later entries will overwrite earlier ones!
Concretely, for the example given here, nnU-Net
will first place the label 1 (edema) where the 'whole_tumor' region was predicted, then place the label 2
(non-enhancing tumor and necrosis) where the 'tumor_core' region was predicted and finally place the label 3 in the
predicted 'enhancing_tumor' area. With each step, part of the previously set pixels
will be overwritten with the new label! So when setting your `regions_class_order`, place encompassing regions
(like whole tumor etc.) first, followed by substructures.
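Both directions of this conversion can be sketched in NumPy for the BraTS example. This is not nnU-Net's internal code. It assumes the network output has already been turned into one boolean mask per region (e.g. by thresholding). The helper names are hypothetical.

```python
import numpy as np

# Regions from the dataset.json example above, in the order matching regions_class_order.
REGIONS = {
    'whole_tumor': [1, 2, 3],
    'tumor_core': [2, 3],
    'enhancing_tumor': [3],
}
REGIONS_CLASS_ORDER = [1, 2, 3]


def labels_to_regions(seg, regions):
    """Integer label map -> stack of (possibly overlapping) boolean region masks."""
    return np.stack([np.isin(seg, labels) for labels in regions.values()])


def regions_to_labels(region_masks, regions_class_order):
    """Boolean region masks -> integer label map; later regions overwrite earlier ones."""
    seg = np.zeros(region_masks.shape[1:], dtype=np.uint8)
    for mask, label in zip(region_masks, regions_class_order):
        seg[mask] = label
    return seg
```

With the encompassing region first, a round trip returns the original label map. If the order is reversed, the whole-tumor label (1) paints over the tumor core and enhancing tumor.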
**IMPORTANT** Because the conversion back to a segmentation map is sensitive to the order in which the regions are
declared ("place label X in the first region") you need to make sure that this order is not perturbed! When
automatically generating the dataset.json, make sure the dictionary keys do not get sorted alphabetically! Set
`sort_keys=False` in `json.dump()`!!!
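The difference is easy to see in a small sketch. Only the fields relevant here are shown (a real dataset.json needs its other required entries as well), and the helper names are hypothetical. Python dicts keep insertion order, and `json.dump` preserves it unless `sort_keys=True` is passed.

```python
import json

# Only the region-related fields; the other required dataset.json entries are omitted.
dataset_json = {
    'labels': {
        'background': 0,
        'whole_tumor': [1, 2, 3],
        'tumor_core': [2, 3],
        'enhancing_tumor': [3],
    },
    'regions_class_order': [1, 2, 3],
}


def dump_dataset_json(data, path):
    # sort_keys=False (also the default) keeps the declared region order
    with open(path, 'w') as f:
        json.dump(data, f, sort_keys=False, indent=4)


def region_names(path):
    """Return the region names in the order they appear in the written file."""
    with open(path) as f:
        labels = json.load(f)['labels']
    return [k for k in labels if k != 'background']
```

With `sort_keys=True` the regions come back as `enhancing_tumor, tumor_core, whole_tumor`, which would no longer line up with `regions_class_order`.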
nnU-Net will perform the evaluation + model selection also on the regions, not the individual labels!
That's all. Easy, huh?
================================================
FILE: documentation/resenc_presets.md
================================================
# Residual Encoder Presets in nnU-Net
When using these presets, please cite our recent paper on the need for rigorous validation in 3D medical image segmentation:
> Isensee, F.* , Wald, T.* , Ulrich, C.* , Baumgartner, M.* , Roy, S., Maier-Hein, K.†, Jaeger, P.† (2024). nnU-Net Revisited: A Call for Rigorous Validation in 3D Medical Image Segmentation. arXiv preprint arXiv:2404.09556.
*: shared first authors\
†: shared last authors
[PAPER LINK](https://arxiv.org/pdf/2404.09556.pdf)
Residual Encoder UNets have been supported by nnU-Net since our participation in KiTS2019, but have flown under the radar.
This is bound to change with our new nnUNetResEncUNet presets :raised_hands:! Especially on large datasets such as KiTS2023 and AMOS2022
they offer improved segmentation performance!
| Model | BTCV | ACDC | LiTS | BraTS | KiTS | AMOS | VRAM | RT | Arch. | nnU |
|------------------------|-------|-------|-------|-------|-------|-------|-------|-----|-------|-----|
| | n=30 | n=200 | n=131 | n=1251| n=489 | n=360 | | | | |
| nnU-Net (org.) [1] | 83.08 | 91.54 | 80.09 | 91.24 | 86.04 | 88.64 | 7.70 | 9 | CNN | Yes |
| nnU-Net ResEnc M | 83.31 | 91.99 | 80.75 | 91.26 | 86.79 | 88.77 | 9.10 | 12 | CNN | Yes |
| nnU-Net ResEnc L | 83.35 | 91.69 | 81.60 | 91.13 | 88.17 | 89.41 | 22.70 | 35 | CNN | Yes |
| nnU-Net ResEnc XL | 83.28 | 91.48 | 81.19 | 91.18 | 88.67 | 89.68 | 36.60 | 66 | CNN | Yes |
| MedNeXt L k3 [2] | 84.70 | 92.65 | 82.14 | 91.35 | 88.25 | 89.62 | 17.30 | 68 | CNN | Yes |
| MedNeXt L k5 [2] | 85.04 | 92.62 | 82.34 | 91.50 | 87.74 | 89.73 | 18.00 | 233 | CNN | Yes |
| STU-Net S [3] | 82.92 | 91.04 | 78.50 | 90.55 | 84.93 | 88.08 | 5.20 | 10 | CNN | Yes |
| STU-Net B [3] | 83.05 | 91.30 | 79.19 | 90.85 | 86.32 | 88.46 | 8.80 | 15 | CNN | Yes |
| STU-Net L [3] | 83.36 | 91.31 | 80.31 | 91.26 | 85.84 | 89.34 | 26.50 | 51 | CNN | Yes |
| SwinUNETR [4] | 78.89 | 91.29 | 76.50 | 90.68 | 81.27 | 83.81 | 13.10 | 15 | TF | Yes |
| SwinUNETRV2 [5] | 80.85 | 92.01 | 77.85 | 90.74 | 84.14 | 86.24 | 13.40 | 15 | TF | Yes |
| nnFormer [6] | 80.86 | 92.40 | 77.40 | 90.22 | 75.85 | 81.55 | 5.70 | 8 | TF | Yes |
| CoTr [7] | 81.95 | 90.56 | 79.10 | 90.73 | 84.59 | 88.02 | 8.20 | 18 | TF | Yes |
| No-Mamba Base | 83.69 | 91.89 | 80.57 | 91.26 | 85.98 | 89.04 | 12.0 | 24 | CNN | Yes |
| U-Mamba Bot [8] | 83.51 | 91.79 | 80.40 | 91.26 | 86.22 | 89.13 | 12.40 | 24 | Mam | Yes |
| U-Mamba Enc [8] | 82.41 | 91.22 | 80.27 | 90.91 | 86.34 | 88.38 | 24.90 | 47 | Mam | Yes |
| A3DS SegResNet [9,11] | 80.69 | 90.69 | 79.28 | 90.79 | 81.11 | 87.27 | 20.00 | 22 | CNN | No |
| A3DS DiNTS [10, 11] | 78.18 | 82.97 | 69.05 | 87.75 | 65.28 | 82.35 | 29.20 | 16 | CNN | No |
| A3DS SwinUNETR [4, 11] | 76.54 | 82.68 | 68.59 | 89.90 | 52.82 | 85.05 | 34.50 | 9 | TF | No |
Results taken from our paper (see above), reported values are Dice scores computed over 5-fold cross-validation on each
dataset. All models trained from scratch.
RT: training run time in hours (measured on 1x Nvidia A100 PCIe 40GB)\
VRAM: GPU VRAM used during training in GB, as reported by nvidia-smi\
Arch.: CNN = convolutional neural network; TF = transformer; Mam = Mamba\
nnU: whether the architecture was integrated and tested with the nnU-Net framework (either by us or the original authors)
## How to use the new presets
We offer three new presets, each targeted for a different GPU VRAM and compute budget:
- **nnU-Net ResEnc M**: similar GPU budget to the standard UNet configuration. Best suited for GPUs with 9-11GB VRAM. Training time: ~12h on A100
- **nnU-Net ResEnc L**: requires a GPU with 24GB VRAM. Training time: ~35h on A100
- **nnU-Net ResEnc XL**: requires a GPU with 40GB VRAM. Training time: ~66h on A100
### **:point_right: We recommend **nnU-Net ResEnc L** as the new default nnU-Net configuration! :point_left:**
The new presets are available as follows ((M/L/XL) = pick one!):
1. Specify the desired configuration when running experiment planning and preprocessing:
`nnUNetv2_plan_and_preprocess -d DATASET -pl nnUNetPlannerResEnc(M/L/XL)`. These planners use the same preprocessed
data folder as the standard 2d and 3d_fullres configurations since the preprocessed data is identical. Only the
3d_lowres differs and will be saved in a different folder to allow all configurations to coexist! If you are only
planning to run 3d_fullres/2d and you already have this data preprocessed, you can just run
`nnUNetv2_plan_experiment -d DATASET -pl nnUNetPlannerResEnc(M/L/XL)` to avoid preprocessing again!
2. Now, just specify the correct plans when running `nnUNetv2_train`, `nnUNetv2_predict` etc. The interface is
consistent across all nnU-Net commands: `-p nnUNetResEncUNet(M/L/XL)Plans`
Training results for the new presets will be stored in a dedicated folder and will not overwrite standard nnU-Net
results! So don't be afraid to give it a go!
## Scaling ResEnc nnU-Net beyond the Presets
The presets differ from `ResEncUNetPlanner` in two ways:
- They set new default values for `gpu_memory_target_in_gb` to target the respective VRAM consumptions
- They remove the batch size cap of 0.05 (previously one batch could not cover more pixels than 5% of the entire dataset; now it can be arbitrarily large)
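To make the 0.05 cap concrete, here is a sketch of the arithmetic. It assumes the dataset size is approximated as median image shape times number of training cases and that the batch size never drops below 2. Both are assumptions; the planner's exact formula may differ, and `capped_batch_size` is a hypothetical helper.

```python
import math


def capped_batch_size(batch_size, patch_size, median_image_shape, num_cases,
                      max_dataset_covered=0.05, min_batch_size=2):
    """Limit the batch so one batch covers at most `max_dataset_covered` of all dataset voxels."""
    dataset_voxels = math.prod(median_image_shape) * num_cases  # approximate dataset size
    cap = round(dataset_voxels * max_dataset_covered / math.prod(patch_size))
    return max(min(batch_size, cap), min_batch_size)
```

For example, a 128³ patch on 20 cases with a 256³ median shape caps a requested batch size of 12 at 8. Setting `max_dataset_covered=1.0` effectively removes the cap, as the presets do.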
The presets are merely there to make life easier, and to provide standardized configurations people can benchmark with.
You can easily adapt the GPU memory target to match your GPU, and to scale beyond 40GB of GPU memory.
Here is an example for how to scale to 80GB VRAM on Dataset003_Liver:
`nnUNetv2_plan_experiment -d 3 -pl nnUNetPlannerResEncM -gpu_memory_target 80 -overwrite_plans_name nnUNetResEncUNetPlans_80G`
Just use `-p nnUNetResEncUNetPlans_80G` moving forward as outlined above! Running the example above will yield a
warning ("You are running nnUNetPlannerM with a non-standard gpu_memory_target_in_gb"). This warning can be ignored here.
**Always change the plans identifier with `-overwrite_plans_name NEW_PLANS_NAME` when messing with the VRAM target in
order to not overwrite preset plans!**
Why not use `ResEncUNetPlanner`? Because that one still has the 5% cap in place!
### Scaling to multiple GPUs
When scaling to multiple GPUs, do not just specify the combined amount of VRAM to `nnUNetv2_plan_experiment` as this
may result in patch sizes that are too large to be processed by individual GPUs. It is best to let this command run for
the VRAM budget of one GPU, and then manually edit the plans file to increase the batch size. You can use [configuration inheritance](explanation_plans_files.md).
In the configurations dictionary of the generated plans JSON file, add the following entry:
```json
"3d_fullres_bsXX": {
"inherits_from": "3d_fullres",
"batch_size": XX
},
```
Where XX is the new batch size. If 3d_fullres has a batch size of 2 for one GPU and you are planning to scale to 8 GPUs, make the new batch size 2x8=16!
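If you prefer not to edit the JSON by hand, a small script can add the entry for you. This is a sketch with a hypothetical helper name. It only assumes what is described above: the plans file (e.g. `nnUNetResEncUNetPlans_80G.json` in your preprocessed dataset folder) has a `configurations` dictionary whose `3d_fullres` entry contains a `batch_size`.

```python
import json


def add_multi_gpu_config(plans_file, num_gpus, base_config='3d_fullres'):
    """Add '<base_config>_bs<N>', inheriting everything from base_config except the batch size."""
    with open(plans_file) as f:
        plans = json.load(f)
    base_bs = plans['configurations'][base_config]['batch_size']
    new_bs = base_bs * num_gpus  # keep the per-GPU batch size unchanged
    new_name = f'{base_config}_bs{new_bs}'
    plans['configurations'][new_name] = {
        'inherits_from': base_config,
        'batch_size': new_bs,
    }
    with open(plans_file, 'w') as f:
        json.dump(plans, f, indent=4, sort_keys=False)
    return new_name
```

For a batch size of 2 and 8 GPUs this creates `3d_fullres_bs16`, which you then pass to `nnUNetv2_train` in place of `3d_fullres_bsXX`.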
You can then train the new configuration using nnU-Net's multi-GPU settings:
```bash
nnUNetv2_train DATASETID 3d_fullres_bsXX FOLD -p nnUNetResEncUNetPlans_80G -num_gpus 8
```
## Proposing a new segmentation method? Benchmark the right way!
When benchmarking new segmentation methods against nnU-Net, we encourage you to benchmark against the residual encoder
variants. For a fair comparison, pick the variant that most closely matches the GPU memory and compute
requirements of your method!
## References
[1] Isensee, Fabian, et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nature methods 18.2 (2021): 203-211.\
[2] Roy, Saikat, et al. "Mednext: transformer-driven scaling of convnets for medical image segmentation." International Conference on Medical Image Computing and Computer-Assisted Intervention. Cham: Springer Nature Switzerland, 2023.\
[3] Huang, Ziyan, et al. "Stu-net: Scalable and transferable medical image segmentation models empowered by large-scale supervised pre-training." arXiv preprint arXiv:2304.06716 (2023).\
[4] Hatamizadeh, Ali, et al. "Swin unetr: Swin transformers for semantic segmentation of brain tumors in mri images." International MICCAI Brainlesion Workshop. Cham: Springer International Publishing, 2021.\
[5] He, Yufan, et al. "Swinunetr-v2: Stronger swin transformers with stagewise convolutions for 3d medical image segmentation." International Conference on Medical Image Computing and Computer-Assisted Intervention. Cham: Springer Nature Switzerland, 2023.\
[6] Zhou, Hong-Yu, et al. "nnformer: Interleaved transformer for volumetric segmentation." arXiv preprint arXiv:2109.03201 (2021).\
[7] Xie, Yutong, et al. "Cotr: Efficiently bridging cnn and transformer for 3d medical image segmentation." Medical Image Computing and Computer Assisted Intervention–MICCAI 2021: 24th International Conference, Strasbourg, France, September 27–October 1, 2021, Proceedings, Part III 24. Springer International Publishing, 2021.\
[8] Ma, Jun, Feifei Li, and Bo Wang. "U-mamba: Enhancing long-range dependency for biomedical image segmentation." arXiv preprint arXiv:2401.04722 (2024).\
[9] Myronenko, Andriy. "3D MRI brain tumor segmentation using autoencoder regularization." Brainlesion: Glioma, Multiple Sclerosis, Stroke and Traumatic Brain Injuries: 4th International Workshop, BrainLes 2018, Held in Conjunction with MICCAI 2018, Granada, Spain, September 16, 2018, Revised Selected Papers, Part II 4. Springer International Publishing, 2019.\
[10] He, Yufan, et al. "Dints: Differentiable neural network topology search for 3d medical image segmentation." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.\
[11] Auto3DSeg, MONAI 1.3.0, [LINK](https://github.com/Project-MONAI/tutorials/tree/ed8854fa19faa49083f48abf25a2c30ab9ac1c6b/auto3dseg)
================================================
FILE: documentation/run_inference_with_pretrained_models.md
================================================
# How to run inference with pretrained models
**Important:** Pretrained weights from nnU-Net v1 are NOT compatible with V2. You will need to retrain with the new
version. But honestly, you already have a fully trained model with which you can run inference (in v1), so
just continue using that!
Not yet available for V2 :-(
If you wish to run inference with pretrained models, check out the old nnU-Net for now. We are working on this full steam!
================================================
FILE: documentation/set_environment_variables.md
================================================
# How to set environment variables
nnU-Net requires some environment variables so that it always knows where the raw data, preprocessed data and trained
models are. Depending on the operating system, these environment variables need to be set in different ways.
Variables can either be set permanently (recommended!) or you can decide to set them every time you call nnU-Net.
# Linux & macOS
## Permanent
Locate the `.bashrc` file in your home folder and add the following lines to the bottom:
```bash
export nnUNet_raw="/media/fabian/nnUNet_raw"
export nnUNet_preprocessed="/media/fabian/nnUNet_preprocessed"
export nnUNet_results="/media/fabian/nnUNet_results"
```
(Of course you need to adapt the paths to the actual folders you intend to use).
If you are using a different shell, such as zsh, you will need to find the correct script for it. For zsh this is `.zshrc`.
## Temporary
Just execute the following lines whenever you run nnU-Net:
```bash
export nnUNet_raw="/media/fabian/nnUNet_raw"
export nnUNet_preprocessed="/media/fabian/nnUNet_preprocessed"
export nnUNet_results="/media/fabian/nnUNet_results"
```
(Of course you need to adapt the paths to the actual folders you intend to use).
Important: These variables will be deleted if you close your terminal! They will also only apply to the current
terminal window and DO NOT transfer to other terminals!
Alternatively you can also just prefix them to your nnU-Net commands:
`nnUNet_results="/media/fabian/nnUNet_results" nnUNet_preprocessed="/media/fabian/nnUNet_preprocessed" nnUNetv2_train[...]`
## Verify that environment parameters are set
You can always execute `echo ${nnUNet_raw}` etc. to print the environment variables. This will return an empty string if
they were not set.
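If you want a check that works the same way on every operating system, a few lines of Python will do. This is a hypothetical helper, not part of nnU-Net. It reports which of the three variables are unset or empty in the current process environment.

```python
import os

REQUIRED_VARS = ('nnUNet_raw', 'nnUNet_preprocessed', 'nnUNet_results')


def missing_nnunet_vars(env=None):
    """Return the names of required nnU-Net variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]


if __name__ == '__main__':
    missing = missing_nnunet_vars()
    print('all set' if not missing else f'missing: {", ".join(missing)}')
```

Run it from the same terminal you use for nnU-Net, because variables set temporarily in one terminal are not visible in others.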
# Windows
Useful links:
- [https://www3.ntu.edu.sg](https://www3.ntu.edu.sg/home/ehchua/programming/howto/Environment_Variables.html#:~:text=To%20set%20(or%20change)%20a,it%20to%20an%20empty%20string.)
- [https://phoenixnap.com](https://phoenixnap.com/kb/windows-set-environment-variable)
## Permanent
See `Set Environment Variable in Windows via GUI` [here](https://phoenixnap.com/kb/windows-set-environment-variable).
Or read about setx (command prompt).
## Temporary
Just execute the following before you run nnU-Net:
(PowerShell)
```PowerShell
$Env:nnUNet_raw = "C:/Users/fabian/nnUNet_raw"
$Env:nnUNet_preprocessed = "C:/Users/fabian/nnUNet_preprocessed"
$Env:nnUNet_results = "C:/Users/fabian/nnUNet_results"
```
(Command Prompt)
```bat
set nnUNet_raw=C:/Users/fabian/nnUNet_raw
set nnUNet_preprocessed=C:/Users/fabian/nnUNet_preprocessed
set nnUNet_results=C:/Users/fabian/nnUNet_results
```
(Of course you need to adapt the paths to the actual folders you intend to use).
Important: These variables will be deleted if you close your session! They will also only apply to the current
window and DO NOT transfer to other sessions!
## Verify that environment parameters are set
Printing in Windows works differently depending on the environment you are in:
PowerShell: `echo $Env:[variable_name]`
Command Prompt: `echo %variable_name%`
================================================
FILE: documentation/setting_up_paths.md
================================================
# Setting up Paths
nnU-Net relies on environment variables to know where raw data, preprocessed data and trained model weights are stored.
To use the full functionality of nnU-Net, the following three environment variables must be set:
1) `nnUNet_raw`: This is where you place the raw datasets. This folder will have one subfolder for each dataset, named
DatasetXXX_YYY where XXX is a 3-digit identifier (such as 001, 002, 043, 999, ...) and YYY is the (unique)
dataset name. The datasets must be in nnU-Net format, see [here](dataset_format.md).
Example tree structure:
```
nnUNet_raw/Dataset001_NAME1
├── dataset.json
├── imagesTr
│ ├── ...
├── imagesTs
│ ├── ...
└── labelsTr
├── ...
nnUNet_raw/Dataset002_NAME2
├── dataset.json
├── imagesTr
│ ├── ...
├── imagesTs
│ ├── ...
└── labelsTr
├── ...
```
2) `nnUNet_preprocessed`: This is the folder where the preprocessed data will be saved. The data will also be read from
this folder during training. It is important that this folder is located on a drive with low access latency and high
throughput (such as an NVMe SSD (PCIe gen 3 is sufficient)).
3) `nnUNet_results`: This specifies where nnU-Net will save the model weights. If pretrained models are downloaded, this
is where it will save them.
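A quick way to check that your raw dataset folders follow the DatasetXXX_YYY convention described above is a regular expression. This is a simple sketch with a hypothetical helper name, not nnU-Net's own parser.

```python
import re

# 'Dataset', a 3-digit identifier, an underscore, then the dataset name
DATASET_FOLDER_RE = re.compile(r'^Dataset(\d{3})_(.+)$')


def parse_dataset_folder(name):
    """Return (dataset_id, dataset_name), or raise ValueError if the name does not match."""
    m = DATASET_FOLDER_RE.match(name)
    if m is None:
        raise ValueError(f'{name!r} does not follow the DatasetXXX_YYY convention')
    return int(m.group(1)), m.group(2)
```

For example, `parse_dataset_folder('Dataset001_NAME1')` returns `(1, 'NAME1')`, while an nnU-Net v1 style name such as `Task01_Heart` is rejected.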
### How to set environment variables
See [here](set_environment_variables.md).
================================================
FILE: documentation/tldr_migration_guide_from_v1.md
================================================
# TLDR Migration Guide from nnU-Net V1
- nnU-Net V2 can be installed simultaneously with V1. They won't get in each other's way
- The environment variables needed for V2 have slightly different names. Read [this](setting_up_paths.md).
- nnU-Net V2 datasets are called DatasetXXX_NAME. Not Task.
- Datasets have the same structure (imagesTr, labelsTr, dataset.json) but we now support more
[file types](dataset_format.md#supported-file-formats). The dataset.json is simplified. Use `generate_dataset_json`
from `nnunetv2/dataset_conversion/generate_dataset_json.py`.
- Careful: labels are now no longer declared as value:name but name:value. This has to do with [hierarchical labels](region_based_training.md).
- nnU-Net v2 commands start with `nnUNetv2...`. They work mostly (but not entirely) the same. Just use the `-h` option.
- You can transfer your V1 raw datasets to V2 with `nnUNetv2_convert_old_nnUNet_dataset`. You cannot transfer trained
models. Continue to use the old nnU-Net version to run inference with those.
- These are the commands you are most likely to be using (in that order)
- `nnUNetv2_plan_and_preprocess`. Example: `nnUNetv2_plan_and_preprocess -d 2`
- `nnUNetv2_train`. Example: `nnUNetv2_train 2 3d_fullres 0`
- `nnUNetv2_find_best_configuration`. Example: `nnUNetv2_find_best_configuration 2 -c 2d 3d_fullres`. This command
will now create an `inference_instructions.txt` file in your `nnUNet_preprocessed/DatasetXXX_NAME/` folder which
tells you exactly how to do inference.
- `nnUNetv2_predict`. Example: `nnUNetv2_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -c 3d_fullres -d 2`
- `nnUNetv2_apply_postprocessing` (see inference_instructions.txt)
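If you rewrite a V1 dataset.json by hand, the label flip from value:name to name:value mentioned above is a one-liner. This is a hypothetical helper. It assumes V1 keys are stringified integers (as in `{"0": "background", "1": "liver"}`) and orders the result by label value.

```python
def convert_v1_labels(v1_labels):
    """Turn nnU-Net V1 'value: name' labels into V2 'name: value' labels, ordered by value."""
    return {name: int(value)
            for value, name in sorted(v1_labels.items(), key=lambda kv: int(kv[0]))}
```

Sorting numerically rather than by string keeps a label like `"10"` after `"2"`.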
================================================
FILE: nnunetv2/__init__.py
================================================
================================================
FILE: nnunetv2/batch_running/__init__.py
================================================
================================================
FILE: nnunetv2/batch_running/benchmarking/__init__.py
================================================
================================================
FILE: nnunetv2/batch_running/benchmarking/generate_benchmarking_commands.py
================================================
if __name__ == '__main__':
    """
    This code probably only works within the DKFZ infrastructure (using LSF). You will need to adapt it to your scheduler!
    """
    gpu_models = [#'NVIDIAA100_PCIE_40GB', 'NVIDIAGeForceRTX2080Ti', 'NVIDIATITANRTX', 'TeslaV100_SXM2_32GB',
                  'NVIDIAA100_SXM4_40GB']#, 'TeslaV100_PCIE_32GB']
    datasets = [2, 3, 4, 5]
    trainers = ['nnUNetTrainerBenchmark_5epochs', 'nnUNetTrainerBenchmark_5epochs_noDataLoading']
    plans = ['nnUNetPlans']
    configs = ['2d', '2d_bs3x', '2d_bs6x', '3d_fullres', '3d_fullres_bs3x', '3d_fullres_bs6x']
    num_gpus = 1
    benchmark_configurations = {d: configs for d in datasets}
    # LSF options: exclude one host, require tensor cores, submit to the gpu queue
    exclude_hosts = "-R \"select[hname!='e230-dgxa100-1']\""
    resources = "-R \"tensorcore\""
    queue = "-q gpu"
    # the preamble opens a double-quoted shell command that is closed when each line is written below
    preamble = "-L /bin/bash \"source ~/load_env_torch210.sh && "
    train_command = 'nnUNet_compile=False nnUNet_results=/dkfz/cluster/gpu/checkpoints/OE0441/isensee/nnUNet_results_remake_benchmark nnUNetv2_train'
    folds = (0, )
    # trainer -> plans identifiers to benchmark
    use_these_modules = {
        tr: plans for tr in trainers
    }
    additional_arguments = f' -num_gpus {num_gpus}'  # ''
    output_file = "/home/isensee/deleteme.txt"
    with open(output_file, 'w') as f:
        for g in gpu_models:
            gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmodel={g}"
            for tr in use_these_modules.keys():
                for p in use_these_modules[tr]:
                    for dataset in benchmark_configurations.keys():
                        for config in benchmark_configurations[dataset]:
                            for fl in folds:
                                command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}'
                                if additional_arguments is not None and len(additional_arguments) > 0:
                                    command += f' {additional_arguments}'
                                f.write(f'{command}\"\n')
================================================
FILE: nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py
================================================
from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.paths import nnUNet_results
from nnunetv2.utilities.file_path_utilities import get_output_folder
if __name__ == '__main__':
    trainers = ['nnUNetTrainerBenchmark_5epochs', 'nnUNetTrainerBenchmark_5epochs_noDataLoading']
    datasets = [2, 3, 4, 5]
    plans = ['nnUNetPlans']
    configs = ['2d', '2d_bs3x', '2d_bs6x', '3d_fullres', '3d_fullres_bs3x', '3d_fullres_bs6x']
    output_file = join(nnUNet_results, 'benchmark_results.csv')
    torch_version = '2.1.0.dev20230330'#"2.0.0"#"2.1.0.dev20230328" #"1.11.0a0+gitbc2c6ed" #
    cudnn_version = 8700  # 8302 #
    num_gpus = 1
    unique_gpus = set()
    # collect results in the most janky way possible. Amazing coding skills!
    all_results = {}
    for tr in trainers:
        all_results[tr] = {}
        for p in plans:
            all_results[tr][p] = {}
            for c in configs:
                all_results[tr][p][c] = {}
                for d in datasets:
                    dataset_name = maybe_convert_to_dataset_name(d)
                    output_folder = get_output_folder(dataset_name, tr, p, c, fold=0)
                    expected_benchmark_file = join(output_folder, 'benchmark_result.json')
                    all_results[tr][p][c][d] = {}
                    if isfile(expected_benchmark_file):
                        # filter results for what we want
                        results = [i for i in load_json(expected_benchmark_file).values()
                                   if i['num_gpus'] == num_gpus and i['cudnn_version'] == cudnn_version and
                                   i['torch_version'] == torch_version]
                        for r in results:
                            all_results[tr][p][c][d][r['gpu_name']] = r
                            unique_gpus.add(r['gpu_name'])
    # haha. Fuck this. Collect GPUs in the code above.
    # unique_gpus = np.unique([i["gpu_name"] for tr in trainers for p in plans for c in configs for d in datasets for i in all_results[tr][p][c][d]])
    unique_gpus = list(unique_gpus)
    unique_gpus.sort()
    with open(output_file, 'w') as f:
        f.write('Dataset,Trainer,Plans,Config')
        for g in unique_gpus:
            f.write(f",{g}")
        f.write("\n")
        for d in datasets:
            for tr in trainers:
                for p in plans:
                    for c in configs:
                        gpu_results = []
                        for g in unique_gpus:
                            if g in all_results[tr][p][c][d].keys():
                                gpu_results.append(round(all_results[tr][p][c][d][g]["fastest_epoch"], ndigits=2))
                            else:
                                gpu_results.append("MISSING")
                        # skip if all are missing
                        if all([i == 'MISSING' for i in gpu_results]):
                            continue
                        f.write(f"{d},{tr},{p},{c}")
                        for g in gpu_results:
                            f.write(f",{g}")
                        f.write("\n")
            f.write("\n")
================================================
FILE: nnunetv2/batch_running/collect_results_custom_Decathlon.py
================================================
from typing import List, Tuple
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.evaluation.evaluate_predictions import load_summary_json
from nnunetv2.paths import nnUNet_results
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name, convert_dataset_name_to_id
from nnunetv2.utilities.file_path_utilities import get_output_folder
def collect_results(trainers: dict, datasets: List, output_file: str,
                    configurations=("2d", "3d_fullres", "3d_lowres", "3d_cascade_fullres"),
                    folds=tuple(np.arange(5))):
    """Write one CSV row per (dataset, configuration, trainer, plans) with per-fold foreground Dice and their mean."""
    results_dirs = (nnUNet_results,)
    datasets_names = [maybe_convert_to_dataset_name(i) for i in datasets]
    with open(output_file, 'w') as f:
        for i, d in zip(datasets, datasets_names):
            for c in configurations:
                for module in trainers.keys():
                    for plans in trainers[module]:
                        for r in results_dirs:
                            expected_output_folder = get_output_folder(d, module, plans, c)
                            if isdir(expected_output_folder):
                                results_folds = []
                                f.write(f"{d},{c},{module},{plans},{r}")
                                for fl in folds:
                                    expected_output_folder_fold = get_output_folder(d, module, plans, c, fl)
                                    expected_summary_file = join(expected_output_folder_fold, "validation",
                                                                 "summary.json")
                                    if not isfile(expected_summary_file):
                                        print('expected output file not found:', expected_summary_file)
                                        f.write(",")
                                        results_folds.append(np.nan)
                                    else:
                                        foreground_mean = load_summary_json(expected_summary_file)['foreground_mean'][
                                            'Dice']
                                        results_folds.append(foreground_mean)
                                        f.write(f",{foreground_mean:02.4f}")
                                f.write(f",{np.nanmean(results_folds):02.4f}\n")
def summarize(input_file, output_file, folds: Tuple[int, ...], configs: Tuple[str, ...], datasets, trainers):
    """Turn the CSV written by collect_results into a trainer x (dataset, config) table of mean Dice over `folds`."""
    txt = np.loadtxt(input_file, dtype=str, delimiter=',')
    num_folds = txt.shape[1] - 6
    valid_configs = {}
    for d in datasets:
        if isinstance(d, int):
            d = maybe_convert_to_dataset_name(d)
        configs_in_txt = np.unique(txt[:, 1][txt[:, 0] == d])
        valid_configs[d] = [i for i in configs_in_txt if i in configs]
    assert max(folds) < num_folds
    with open(output_file, 'w') as f:
        f.write("name")
        for d in valid_configs.keys():
            for c in valid_configs[d]:
                f.write(",%d_%s" % (convert_dataset_name_to_id(d), c[:4]))
        f.write(',mean\n')
        valid_entries = txt[:, 4] == nnUNet_results
        for t in trainers.keys():
            trainer_locs = valid_entries & (txt[:, 2] == t)
            for pl in trainers[t]:
                f.write(f"{t}__{pl}")
                trainer_plan_locs = trainer_locs & (txt[:, 3] == pl)
                r = []
                for d in valid_configs.keys():
                    trainer_plan_d_locs = trainer_plan_locs & (txt[:, 0] == d)
                    for v in valid_configs[d]:
                        trainer_plan_d_config_locs = trainer_plan_d_locs & (txt[:, 1] == v)
                        if np.any(trainer_plan_d_config_locs):
                            # we cannot have more than one row
                            assert np.sum(trainer_plan_d_config_locs) == 1
                            # now check that we have all folds
                            selected_row = txt[np.argwhere(trainer_plan_d_config_locs)[0, 0]]
                            fold_results = selected_row[[i + 5 for i in folds]]
                            if '' in fold_results:
                                print('missing fold in', t, pl, d, v)
                                f.write(",nan")
                                r.append(np.nan)
                            else:
                                mean_dice = np.mean([float(i) for i in fold_results])
                                f.write(f",{mean_dice:02.4f}")
                                r.append(mean_dice)
                        else:
                            print('missing:', t, pl, d, v)
                            f.write(",nan")
                            r.append(np.nan)
                f.write(f",{np.mean(r):02.4f}\n")
if __name__ == '__main__':
use_these_trainers = {
'nnUNetTrainer': ('nnUNetResEncUNetLPlans', ),
}
all_results_file= join(nnUNet_results, 'customDecResults.csv')
datasets = [3, 5, 8, 10, 17, 27, 55, 220, 223, 226, 219]
# datasets = [3, 4, 5, 8, 10, 17, 27, 55, 220, 223]
collect_results(use_these_trainers, datasets, all_results_file)
# folds = (0, 1, 2, 3, 4)
# configs = ("2d", )
# output_file = join(nnUNet_results, 'customDecResults_summary5fold.csv')
# summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)
folds = (0, )
configs = ("3d_fullres", )
output_file = join(nnUNet_results, 'summary_fold0.csv')
summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)
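The column arithmetic in `summarize` follows from the row layout that `collect_results` writes: five identifier columns (dataset, configuration, trainer, plans, results folder), one column per fold, then the mean. That is why `num_folds = ncols - 6` and fold `i` sits in column `i + 5`. The sketch below replays that indexing on a made-up, in-memory CSV (all values are placeholders) and passes `ndmin=2` to `np.loadtxt` so that a file with a single row still comes back as a 2D array.

```python
import io

import numpy as np

# Hypothetical results file in the layout written by collect_results:
# dataset, configuration, trainer, plans, results_dir, fold_0 ... fold_4, mean
CSV = (
    "Dataset027_ACDC,3d_fullres,nnUNetTrainer,nnUNetPlans,/results,0.9100,0.9200,0.9000,0.9300,0.9400,0.9200\n"
    "Dataset027_ACDC,2d,nnUNetTrainer,nnUNetPlans,/results,0.9000,,0.8800,0.8900,0.9100,0.8950\n"
)

txt = np.loadtxt(io.StringIO(CSV), dtype=str, delimiter=",", ndmin=2)
num_folds = txt.shape[1] - 6  # 5 identifier columns + 1 mean column


def mean_over_folds(row: np.ndarray, folds) -> float:
    """Mean Dice over the requested folds, or nan if any of them is missing (empty field)."""
    fold_results = row[[i + 5 for i in folds]]
    if "" in fold_results:
        return float("nan")
    return float(np.mean([float(v) for v in fold_results]))


print(num_folds, mean_over_folds(txt[0], range(5)), mean_over_folds(txt[1], range(5)))
```

As in `summarize`, a single missing fold turns the whole entry into `nan` rather than averaging over the folds that happen to be present.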
================================================
FILE: nnunetv2/batch_running/collect_results_custom_Decathlon_2d.py
================================================
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.batch_running.collect_results_custom_Decathlon import collect_results, summarize
from nnunetv2.paths import nnUNet_results
if __name__ == '__main__':
use_these_trainers = {
'nnUNetTrainer': ('nnUNetPlans', ),
}
all_results_file = join(nnUNet_results, 'hrnet_results.csv')
datasets = [2, 3, 4, 17, 20, 24, 27, 38, 55, 64, 82]
collect_results(use_these_trainers, datasets, all_results_file)
folds = (0, )
configs = ('2d', )
output_file = join(nnUNet_results, 'hrnet_results_summary_fold0.csv')
summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)
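`summarize` names each header column `"%d_%s" % (dataset_id, configuration[:4])`. Truncating to four characters still keeps nnU-Net's four standard configurations apart, as this small check illustrates (dataset ID 27 is just an example):

```python
CONFIGS = ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres")


def summary_column_name(dataset_id: int, configuration: str) -> str:
    # Same formatting as the header line written by summarize()
    return "%d_%s" % (dataset_id, configuration[:4])


names = [summary_column_name(27, c) for c in CONFIGS]
print(names)  # ['27_2d', '27_3d_l', '27_3d_f', '27_3d_c']
```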
================================================
FILE: nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py
================================================
from copy import deepcopy
import numpy as np
def merge(dict1, dict2):
keys = np.unique(list(dict1.keys()) + list(dict2.keys()))
res = {}
for k in keys:
all_configs = []
if dict1.get(k) is not None:
all_configs += list(dict1[k])
if dict2.get(k) is not None:
all_configs += list(dict2[k])
if len(all_configs) > 0:
res[k] = tuple(np.unique(all_configs))
return res
if __name__ == "__main__":
# after the Nature Methods paper we switch our evaluation to a different (more stable/high quality) set of
# datasets for evaluation and future development
configurations_all = {
2: ("3d_fullres", "2d"),
3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
4: ("2d", "3d_fullres"),
5: ("2d", "3d_fullres"),
8: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
10: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
17: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
24: ("2d", "3d_fullres"),
27: ("2d", "3d_fullres"),
38: ("2d", "3d_fullres"),
55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
137: ("2d", "3d_fullres"),
220: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
# 221: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
223: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
219: ("2d", "3d_fullres"),
226: ("2d", "3d_fullres"),
}
configurations_3d_fr_only = {
i: ("3d_fullres", ) for i in configurations_all if "3d_fullres" in configurations_all[i]
}
configurations_3d_c_only = {
i: ("3d_cascade_fullres", ) for i in configurations_all if "3d_cascade_fullres" in configurations_all[i]
}
configurations_3d_lr_only = {
i: ("3d_lowres", ) for i in configurations_all if "3d_lowres" in configurations_all[i]
}
configurations_2d_only = {
i: ("2d", ) for i in configurations_all if "2d" in configurations_all[i]
}
num_gpus = 1
exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\""
resources = ""
gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmem=23G"#gmodel=NVIDIAA100_PCIE_40GB"
queue = "-q gpu-pro"
preamble = "\". /home/isensee/env_loading_scripts/continuous_performance_monitoring/load_env_torch280.sh && " # -L /bin/bash
train_command = 'nnUNet_results=/dkfz/cluster/gpu/checkpoints/OE0441/isensee/results_nnUNet_master nnUNetv2_train'
folds = (0, )
# use_this = configurations_2d_only
use_this = configurations_3d_fr_only
# use_this = merge(use_this, configurations_3d_c_only)
datasets = [3, 5, 8, 10, 17, 27, 55, 219, 220, 223, 226]
use_this = {i: use_this[i] for i in datasets}
use_these_modules = {
# 'nnUNetTrainer_newSpatialAug': ('nnUNetPlans',),
# 'nnUNetTrainerBN': ('nnUNetPlans',),
# 'nnUNetTrainer_newSpatialAug_withElDef_noPref': ('nnUNetPlans',),
# 'nnUNetTrainer': ('nnUNetConvNextEncUNetPlans_smallks_and_shallow',),
# 'nnUNetTrainer_convnextenc_regularconvblock': ('nnUNetPlans_convnext', 'nnUNetConvNextEncUNetPlans_smallks_and_shallow'),
# 'nnUNetTrainer_newSpatialAug_withElDef2': ('nnUNetPlans',),
# 'nnUNetTrainer_newSpatialAug_withElDef3': ('nnUNetPlans',),
# 'nnUNetTrainer_newSpatialAug_noPref': ('nnUNetPlans',),
# 'nnUNetTrainerDiceCELoss_noSmooth': ('nnUNetPlans',),
# 'nnUNetTrainer_DASegOrd0': ('nnUNetPlans',),
# 'nnUNetTrainerAdamW_WDe2': ('nnUNetPlans',),
# 'nnUNetTrainerUMambaBot': ('nnUNetPlans',),
# 'nnUNetTrainerUMambaEnc': ('nnUNetPlans',),
# 'nnUNetTrainer_fasterDA': ('nnUNetPlans', 'nnUNetResEncUNetLPlans'),
# 'nnUNetTrainer_noDummy2DDA': ('nnUNetResEncUNetMPlans', ),
'nnUNetTrainer': ('nnUNetResEncUNetMPlans', ),
'nnUNetTrainerDA5': ('nnUNetResEncUNetMPlans', ),
# 'nnUNetTrainer_probabilisticOversampling_033': ('nnUNetResEncUNetMPlans', ),
# 'nnUNetTrainer_probabilisticOversampling_010': ('nnUNetResEncUNetMPlans',),
# BN
}
additional_arguments = f' -num_gpus {num_gpus} --disable_checkpointing' # ''
output_file = "/home/isensee/deleteme.txt"
with open(output_file, 'w') as f:
for tr in use_these_modules.keys():
for p in use_these_modules[tr]:
for dataset in use_this.keys():
for config in use_this[dataset]:
for fl in folds:
command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}'
if additional_arguments is not None and len(additional_arguments) > 0:
command += f' {additional_arguments}'
f.write(f'{command}\"\n')
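The nested loops expand to one `bsub` line per (trainer, plans, dataset, configuration, fold) combination. The preamble opens a double quote that the final `write` closes, so environment loading and training run as one quoted command string. A compact equivalent using `itertools.product` is sketched below; the `prefix` and `preamble` defaults are placeholders, not the cluster-specific values used above.

```python
from itertools import product
from typing import Dict, List, Sequence, Tuple


def build_commands(
    trainers: Dict[str, Tuple[str, ...]],
    configs_per_dataset: Dict[int, Tuple[str, ...]],
    folds: Sequence[int],
    prefix: str = 'bsub -q gpu',             # placeholder scheduler options
    preamble: str = '". ~/load_env.sh && ',  # placeholder; opens the quoted command string
    additional_arguments: str = '--disable_checkpointing',
) -> List[str]:
    commands = []
    # Same nesting order as the script: trainer -> plans -> dataset -> configuration -> fold
    for tr, plans in trainers.items():
        for p, (dataset, configs) in product(plans, configs_per_dataset.items()):
            for config, fl in product(configs, folds):
                cmd = f'{prefix} {preamble}nnUNetv2_train {dataset} {config} {fl} -tr {tr} -p {p}'
                if additional_arguments:
                    cmd += f' {additional_arguments}'
                commands.append(cmd + '"')  # close the quote opened by the preamble
    return commands


cmds = build_commands({'nnUNetTrainer': ('nnUNetResEncUNetMPlans',)},
                      {3: ('3d_fullres',), 27: ('3d_fullres',)}, folds=(0,))
for c in cmds:
    print(c)
```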
================================================
FILE: nnunetv2/batch_running/jobs.sh
================================================
# lsf22-gpu01
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=3 nnUNetv2_train 3 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=4 nnUNetv2_train 4 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=5 nnUNetv2_train 5 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=6 nnUNetv2_train 8 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=7 nnUNetv2_train 10 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
# lsf22-gpu03
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=0 nnUNetv2_train 17 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=1 nnUNetv2_train 27 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=2 nnUNetv2_train 55 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=6 nnUNetv2_train 220 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
# lsf22-gpu05
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=0 nnUNetv2_train 223 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=1 nnUNetv2_train 3 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=2 nnUNetv2_train 4 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=3 nnUNetv2_train 5 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=4 nnUNetv2_train 8 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=6 nnUNetv2_train 10 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=7 nnUNetv2_train 17 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
# lsf22-gpu06
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=0 nnUNetv2_train 27 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=1 nnUNetv2_train 55 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=2 nnUNetv2_train 220 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224.sh && CUDA_VISIBLE_DEVICES=3 nnUNetv2_train 223 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=4 nnUNetv2_train 3 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=5 nnUNetv2_train 4 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=6 nnUNetv2_train 5 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=7 nnUNetv2_train 8 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
# lsf22-gpu07
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=0 nnUNetv2_train 10 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=1 nnUNetv2_train 17 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=2 nnUNetv2_train 27 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=3 nnUNetv2_train 55 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=4 nnUNetv2_train 220 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=5 nnUNetv2_train 223 3d_fullres 0 -tr nnUNetTrainer_noDummy2DDA -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=6 nnUNetv2_train 3 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
screen -dm bash -c ". ~/load_env_torch224_balintsfix.sh && CUDA_VISIBLE_DEVICES=7 nnUNetv2_train 4 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
# launched as jobs
bsub -R "select[hname!='e230-dgx2-2']" -R "select[hname!='e230-dgx2-1']" -q gpu -gpu num=1:j_exclusive=yes:gmem=1G ". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 5 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
bsub -R "select[hname!='e230-dgx2-2']" -R "select[hname!='e230-dgx2-1']" -q gpu -gpu num=1:j_exclusive=yes:gmem=1G ". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 8 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
bsub -R "select[hname!='e230-dgx2-2']" -R "select[hname!='e230-dgx2-1']" -q gpu -gpu num=1:j_exclusive=yes:gmem=1G ". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 10 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
bsub -R "select[hname!='e230-dgx2-2']" -R "select[hname!='e230-dgx2-1']" -q gpu -gpu num=1:j_exclusive=yes:gmem=1G ". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 17 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
bsub -R "select[hname!='e230-dgx2-2']" -R "select[hname!='e230-dgx2-1']" -q gpu -gpu num=1:j_exclusive=yes:gmem=1G ". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 27 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
bsub -R "select[hname!='e230-dgx2-2']" -R "select[hname!='e230-dgx2-1']" -q gpu -gpu num=1:j_exclusive=yes:gmem=1G ". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 55 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
bsub -R "select[hname!='e230-dgx2-2']" -R "select[hname!='e230-dgx2-1']" -q gpu -gpu num=1:j_exclusive=yes:gmem=1G ". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 220 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
bsub -R "select[hname!='e230-dgx2-2']" -R "select[hname!='e230-dgx2-1']" -q gpu -gpu num=1:j_exclusive=yes:gmem=1G ". ~/load_env_torch224_balintsfix.sh && nnUNetv2_train 223 3d_fullres 0 -tr nnUNetTrainer -p nnUNetResEncUNetMPlans --disable_checkpointing"
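Each `screen` line above starts one detached training run pinned to a single GPU through `CUDA_VISIBLE_DEVICES`. Instead of writing such lists by hand, a hypothetical Python helper could generate them. The round-robin GPU assignment, the `~/load_env.sh` script name and the default arguments are assumptions for illustration, not how this file was produced:

```python
from typing import List, Sequence


def screen_commands(datasets: Sequence[int], gpus: Sequence[int], trainer: str, plans: str,
                    env_script: str = "~/load_env.sh", config: str = "3d_fullres",
                    fold: int = 0) -> List[str]:
    lines = []
    for i, dataset in enumerate(datasets):
        # Round-robin assignment (assumption): with more datasets than GPUs, runs share a GPU
        gpu = gpus[i % len(gpus)]
        train = (f"nnUNetv2_train {dataset} {config} {fold} -tr {trainer} -p {plans} "
                 f"--disable_checkpointing")
        lines.append(f'screen -dm bash -c ". {env_script} && CUDA_VISIBLE_DEVICES={gpu} {train}"')
    return lines


for line in screen_commands([3, 4, 5], gpus=[0, 1], trainer="nnUNetTrainer",
                            plans="nnUNetResEncUNetMPlans"):
    print(line)
```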
================================================
FILE: nnunetv2/batch_running/release_trainings/__init__.py
================================================
================================================
FILE: nnunetv2/batch_running/release_trainings/nnunetv2_v1/__init__.py
================================================
================================================
FILE: nnunetv2/batch_running/release_trainings/nnunetv2_v1/collect_results.py
================================================
from typing import List, Tuple
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.evaluation.evaluate_predictions import load_summary_json
from nnunetv2.paths import nnUNet_results
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name, convert_dataset_name_to_id
from nnunetv2.utilities.file_path_utilities import get_output_folder
def collect_results(trainers: dict, datasets: List, output_file: str,
configurations=("2d", "3d_fullres", "3d_lowres", "3d_cascade_fullres"),
folds=tuple(np.arange(5))):
results_dirs = (nnUNet_results,)
datasets_names = [maybe_convert_to_dataset_name(i) for i in datasets]
with open(output_file, 'w') as f:
for i, d in zip(datasets, datasets_names):
for c in configurations:
for module in trainers.keys():
for plans in trainers[module]:
for r in results_dirs:
expected_output_folder = get_output_folder(d, module, plans, c)
if isdir(expected_output_folder):
results_folds = []
f.write(f"{d},{c},{module},{plans},{r}")
for fl in folds:
expected_output_folder_fold = get_output_folder(d, module, plans, c, fl)
expected_summary_file = join(expected_output_folder_fold, "validation",
"summary.json")
if not isfile(expected_summary_file):
print('expected output file not found:', expected_summary_file)
f.write(",")
results_folds.append(np.nan)
else:
foreground_mean = load_summary_json(expected_summary_file)['foreground_mean'][
'Dice']
results_folds.append(foreground_mean)
f.write(f",{foreground_mean:02.4f}")
f.write(f",{np.nanmean(results_folds):02.4f}\n")
def summarize(input_file, output_file, folds: Tuple[int, ...], configs: Tuple[str, ...], datasets, trainers):
txt = np.loadtxt(input_file, dtype=str, delimiter=',', ndmin=2)  # ndmin=2: a single-row file must still index as txt[:, k]
num_folds = txt.shape[1] - 6
valid_configs = {}
for d in datasets:
if isinstance(d, int):
d = maybe_convert_to_dataset_name(d)
configs_in_txt = np.unique(txt[:, 1][txt[:, 0] == d])
valid_configs[d] = [i for i in configs_in_txt if i in configs]
assert max(folds) < num_folds
with open(output_file, 'w') as f:
f.write("name")
for d in valid_configs.keys():
for c in valid_configs[d]:
f.write(",%d_%s" % (convert_dataset_name_to_id(d), c[:4]))
f.write(',mean\n')
valid_entries = txt[:, 4] == nnUNet_results
for t in trainers.keys():
trainer_locs = valid_entries & (txt[:, 2] == t)
for pl in trainers[t]:
f.write(f"{t}__{pl}")
trainer_plan_locs = trainer_locs & (txt[:, 3] == pl)
r = []
for d in valid_configs.keys():
trainer_plan_d_locs = trainer_plan_locs & (txt[:, 0] == d)
for v in valid_configs[d]:
trainer_plan_d_config_locs = trainer_plan_d_locs & (txt[:, 1] == v)
if np.any(trainer_plan_d_config_locs):
# we cannot have more than one row
assert np.sum(trainer_plan_d_config_locs) == 1
# now check that we have all folds
selected_row = txt[np.argwhere(trainer_plan_d_config_locs)[0,0]]
fold_results = selected_row[[i + 5 for i in folds]]
if '' in fold_results:
print('missing fold in', t, pl, d, v)
f.write(",nan")
r.append(np.nan)
else:
mean_dice = np.mean([float(i) for i in fold_results])
f.write(f",{mean_dice:02.4f}")
r.append(mean_dice)
else:
print('missing:', t, pl, d, v)
f.write(",nan")
r.append(np.nan)
f.write(f",{np.mean(r):02.4f}\n")
if __name__ == '__main__':
use_these_trainers = {
'nnUNetTrainer': ('nnUNetPlans',),
'nnUNetTrainer_v1loss': ('nnUNetPlans',),
}
all_results_file = join(nnUNet_results, 'customDecResults.csv')
datasets = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 20, 24, 27, 35, 38, 48, 55, 64, 82]
collect_results(use_these_trainers, datasets, all_results_file)
folds = (0, 1, 2, 3, 4)
configs = ("3d_fullres", "3d_lowres")
output_file = join(nnUNet_results, 'customDecResults_summary5fold.csv')
summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)
folds = (0, )
configs = ("3d_fullres", "3d_lowres")
output_file = join(nnUNet_results, 'customDecResults_summaryfold0.csv')
summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)
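Every line that `collect_results` writes has the shape `dataset,config,trainer,plans,results_dir,fold_0,...,fold_k,mean`. A missing `validation/summary.json` leaves an empty fold field, and the mean is taken with `np.nanmean` over the folds that do exist. A minimal sketch of that row formatting, with the per-fold Dice values passed in directly instead of being read from `summary.json` (`None` stands for a missing fold; the dataset name is only an example):

```python
from typing import Optional, Sequence

import numpy as np


def format_result_row(dataset: str, config: str, trainer: str, plans: str, results_dir: str,
                      fold_dice: Sequence[Optional[float]]) -> str:
    row = f"{dataset},{config},{trainer},{plans},{results_dir}"
    values = []
    for dice in fold_dice:
        if dice is None:
            row += ","  # empty field marks a missing fold, as in collect_results
            values.append(np.nan)
        else:
            row += f",{dice:02.4f}"
            values.append(dice)
    # nanmean ignores missing folds, so the last column averages only the available ones
    return row + f",{np.nanmean(values):02.4f}\n"


print(format_result_row("Dataset004_Hippocampus", "3d_fullres", "nnUNetTrainer", "nnUNetPlans",
                        "/results", [0.9, None, 0.8]), end="")
```

Note the asymmetry with `summarize`: the stored mean skips missing folds, while `summarize` reports `nan` for any entry with a missing fold.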
================================================
FILE: nnunetv2/batch_running/release_trainings/nnunetv2_v1/generate_lsf_commands.py
================================================
from copy import deepcopy
import numpy as np
def merge(dict1, dict2):
keys = np.unique(list(dict1.keys()) + list(dict2.keys()))
res = {}
for k in keys:
all_configs = []
if dict1.get(k) is not None:
all_configs += list(dict1[k])
if dict2.get(k) is not None:
all_configs += list(dict2[k])
if len(all_configs) > 0:
res[k] = tuple(np.unique(all_configs))
return res
if __name__ == "__main__":
# after the Nature Methods paper we switch our evaluation to a different (more stable/high quality) set of
# datasets for evaluation and future development
configurations_all = {
# 1: ("3d_fullres", "2d"),
2: ("3d_fullres", "2d"),
# 3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
# 4: ("2d", "3d_fullres"),
5: ("2d", "3d_fullres"),
# 6: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
# 7: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
# 8: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
# 9: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
# 10: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
# 17: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
20: ("2d", "3d_fullres"),
24: ("2d", "3d_fullres"),
27: ("2d", "3d_fullres"),
35: ("2d", "3d_fullres"),
38: ("2d", "3d_fullres"),
# 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
# 64: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
# 82: ("2d", "3d_fullres"),
# 83: ("2d", "3d_fullres"),
}
configurations_3d_fr_only = {
i: ("3d_fullres", ) for i in configurations_all if "3d_fullres" in configurations_all[i]
}
configurations_3d_c_only = {
i: ("3d_cascade_fullres", ) for i in configurations_all if "3d_cascade_fullres" in configurations_all[i]
}
configurations_3d_lr_only = {
i: ("3d_lowres", ) for i in configurations_all if "3d_lowres" in configurations_all[i]
}
configurations_2d_only = {
i: ("2d", ) for i in configurations_all if "2d" in configurations_all[i]
}
num_gpus = 1
exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\""
resources = "-R \"tensorcore\""
gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmem=1G"
queue = "-q gpu-lowprio"
preamble = "-L /bin/bash \"source ~/load_env_cluster4.sh && "
train_command = 'nnUNet_keep_files_open=True nnUNet_results=/dkfz/cluster/gpu/data/OE0441/isensee/nnUNet_results_remake_release_normfix nnUNetv2_train'
folds = (0, 1, 2, 3, 4)
# use_this = configurations_2d_only
# use_this = merge(configurations_3d_fr_only, configurations_3d_lr_only)
# use_this = merge(use_this, configurations_3d_c_only)
use_this = configurations_all
use_these_modules = {
'nnUNetTrainer': ('nnUNetPlans',),
}
additional_arguments = f'--disable_checkpointing -num_gpus {num_gpus}' # ''
output_file = "/home/isensee/deleteme.txt"
with open(output_file, 'w') as f:
for tr in use_these_modules.keys():
for p in use_these_modules[tr]:
for dataset in use_this.keys():
for config in use_this[dataset]:
for fl in folds:
command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}'
if additional_arguments is not None and len(additional_arguments) > 0:
command += f' {additional_arguments}'
f.write(f'{command}\"\n')
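For every dataset key found in either dictionary, `merge` returns the sorted, de-duplicated union of its configurations (`np.unique` sorts), and drops keys that end up with no configurations. A NumPy-free sketch giving the same result for these inputs; one difference is that `np.unique` yields NumPy integer keys, whereas this version keeps plain Python ints:

```python
from typing import Dict, Tuple


def merge_configs(dict1: Dict[int, Tuple[str, ...]],
                  dict2: Dict[int, Tuple[str, ...]]) -> Dict[int, Tuple[str, ...]]:
    res = {}
    for k in sorted(set(dict1) | set(dict2)):
        # Union of both configuration tuples; missing keys contribute nothing
        all_configs = set(dict1.get(k) or ()) | set(dict2.get(k) or ())
        if all_configs:
            res[k] = tuple(sorted(all_configs))
    return res


fr = {3: ("3d_fullres",), 27: ("3d_fullres",)}
casc = {3: ("3d_cascade_fullres",)}
print(merge_configs(fr, casc))
# {3: ('3d_cascade_fullres', '3d_fullres'), 27: ('3d_fullres',)}
```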
================================================
FILE: nnunetv2/configuration.py
================================================
import os
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
default_num_processes = 8 if 'nnUNet_def_n_proc' not in os.environ else int(os.environ['nnUNet_def_n_proc'])
ANISO_THRESHOLD = 3 # determines when a sample is considered anisotropic (3 means that the spacing in the low
# resolution axis must be 3x as large as the next largest spacing)
default_n_proc_DA = get_allowed_n_proc_DA()
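`default_num_processes` falls back to 8 unless the `nnUNet_def_n_proc` environment variable is set, and `ANISO_THRESHOLD` is a ratio between voxel spacings. The sketch below expresses both rules as the comments describe them. `is_anisotropic` is a hypothetical helper written from the comment's wording (largest spacing versus the next largest), not a copy of the check nnU-Net applies internally, and treating exactly 3x as anisotropic is an assumption.

```python
import os
from typing import Mapping, Sequence

ANISO_THRESHOLD = 3


def num_processes_from_env(env: Mapping[str, str] = os.environ, default: int = 8) -> int:
    # Same rule as default_num_processes: env override if present, otherwise 8
    return int(env['nnUNet_def_n_proc']) if 'nnUNet_def_n_proc' in env else default


def is_anisotropic(spacing: Sequence[float], threshold: float = ANISO_THRESHOLD) -> bool:
    # Hypothetical helper following the comment: compare the largest spacing
    # (low-resolution axis) with the next largest one. Using >= is an assumption.
    largest, second = sorted(spacing, reverse=True)[:2]
    return largest >= threshold * second


print(num_processes_from_env({'nnUNet_def_n_proc': '12'}))  # 12
print(is_anisotropic((5.0, 0.8, 0.8)), is_anisotropic((2.0, 1.0, 1.0)))  # True False
```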
================================================
FILE: nnunetv2/dataset_conversion/Dataset015_018_RibFrac_RibSeg.py
================================================
import os
from copy import deepcopy
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
import shutil
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw
import SimpleITK as sitk
if __name__ == '__main__':
"""
Download the RibFrac dataset. Links are at https://ribfrac.grand-challenge.org/
Download everything: Part 1, Part 2, validation and test.
Extract EVERYTHING into one folder so that all images and labels are in there. Don't worry, they all have unique
file names.
For RibSeg, also download the dataset from https://github.com/M3DV/RibSeg
(https://drive.google.com/file/d/1ZZGGrhd0y1fLyOZGo_Y-wlVUP4lkHVgm/view?usp=sharing) and extract it into that same
folder (seg only, files end with -rib-seg.nii.gz)
"""
# folder into which all RibFrac (and RibSeg) files were extracted
base = '/home/isensee/Downloads/RibFrac_all'
files = nifti_files(base, join=False)
identifiers = np.unique([i.split('-')[0] for i in files])
# RibFrac
target_dataset_id = 15
target_dataset_name = f'Dataset{target_dataset_id:03.0f}_RibFrac'
maybe_mkdir_p(join(nnUNet_raw, target_dataset_name))
imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr')
imagesTs = join(nnUNet_raw, target_dataset_name, 'imagesTs')
labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr')
maybe_mkdir_p(imagesTr)
maybe_mkdir_p(imagesTs)
maybe_mkdir_p(labelsTr)
n_tr = 0
for c in identifiers:
print(c)
img_file = join(base, c + '-image.nii.gz')
seg_file = join(base, c + '-label.nii.gz')
if not isfile(seg_file):
# test case
shutil.copy(img_file, join(imagesTs, c + '_0000.nii.gz'))
continue
n_tr += 1
shutil.copy(img_file, join(imagesTr, c + '_0000.nii.gz'))
# we must open seg and map -1 to 5
seg_itk = sitk.ReadImage(seg_file)
seg_npy = sitk.GetArrayFromImage(seg_itk)
seg_npy[seg_npy == -1] = 5
seg_itk_out = sitk.GetImageFromArray(seg_npy.astype(np.uint8))
seg_itk_out.SetSpacing(seg_itk.GetSpacing())
seg_itk_out.SetDirection(seg_itk.GetDirection())
seg_itk_out.SetOrigin(seg_itk.GetOrigin())
sitk.WriteImage(seg_itk_out, join(labelsTr, c + '.nii.gz'))
# - 0: it is background
# - 1: it is a displaced rib fracture
# - 2: it is a non-displaced rib fracture
# - 3: it is a buckle rib fracture
# - 4: it is a segmental rib fracture
# - -1: it is a rib fracture, but we could not define its type due to
# ambiguity, diagnosis difficulty, etc. Ignore it in the
# classification task.
generate_dataset_json(
join(nnUNet_raw, target_dataset_name),
channel_names={0: 'CT'},
labels = {
'background': 0,
'fracture': (1, 2, 3, 4, 5),
'displaced rib fracture': 1,
'non-displaced rib fracture': 2,
'buckle rib fracture': 3,
'segmental rib fracture': 4,
},
num_training_cases=n_tr,
file_ending='.nii.gz',
regions_class_order=(5, 1, 2, 3, 4),
dataset_name=target_dataset_name,
reference='https://ribfrac.grand-challenge.org/'
)
# RibSeg
# overall I am not happy with the GT quality here. But eh what can I do
target_dataset_name_ribfrac = deepcopy(target_dataset_name)
target_dataset_id = 18
target_dataset_name = f'Dataset{target_dataset_id:03.0f}_RibSeg'
maybe_mkdir_p(join(nnUNet_raw, target_dataset_name))
imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr')
labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr')
maybe_mkdir_p(imagesTr)
maybe_mkdir_p(labelsTr)
# the authors maintain a Google Sheet in which they list problems with their dataset:
# https://docs.google.com/spreadsheets/d/1lz9liWPy8yHybKCdO3BCA9K76QH8a54XduiZS_9fK70/edit?gid=1416415020#gid=1416415020
# we exclude the cases marked in red. They have unannotated ribs
skip_identifiers = [
'RibFrac452',
'RibFrac485',
'RibFrac490',
'RibFrac471',
'RibFrac462',
'RibFrac487',
]
n_tr = 0
dataset = {}
for c in identifiers:
if c in skip_identifiers:
continue
print(c)
tr_file = join('$nnUNet_raw', target_dataset_name_ribfrac, 'imagesTr', c + '_0000.nii.gz')
ts_file = join('$nnUNet_raw', target_dataset_name_ribfrac, 'imagesTs', c + '_0000.nii.gz')
if isfile(os.path.expandvars(tr_file)):
img_file = tr_file
elif isfile(os.path.expandvars(ts_file)):
img_file = ts_file
else:
raise RuntimeError(f'Missing image file for identifier {c}')
seg_file = join(base, c + '-rib-seg.nii.gz')
n_tr += 1
shutil.copy(seg_file, join(labelsTr, c + '.nii.gz'))
dataset[c] = {
'images': [img_file],
'label': join('labelsTr', c + '.nii.gz')
}
generate_dataset_json(
join(nnUNet_raw, target_dataset_name),
channel_names={0: 'CT'},
labels = {
'background': 0,
**{'rib%02d' % i: i for i in range(1, 25)}
},
num_training_cases=n_tr,
file_ending='.nii.gz',
dataset_name=target_dataset_name,
reference='https://github.com/M3DV/RibSeg, https://ribfrac.grand-challenge.org/',
dataset=dataset
)
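Two label conventions in this script can be illustrated on a toy array: RibFrac's ambiguous fractures (`-1`) are remapped to `5` before writing, and the `fracture` entry in `labels` is a region covering values 1 to 5. The decoding half assumes that region predictions are turned back into a label map by painting the regions in `regions_class_order`, with later regions overwriting earlier ones. That is our reading of nnU-Net's region-based training, sketched here with plain NumPy boolean masks rather than nnU-Net's own code:

```python
import numpy as np

# Regions in the order they appear in `labels` (background excluded)
REGIONS = [(1, 2, 3, 4, 5), (1,), (2,), (3,), (4,)]
REGIONS_CLASS_ORDER = (5, 1, 2, 3, 4)


def remap_ambiguous(seg: np.ndarray) -> np.ndarray:
    seg = seg.copy()
    seg[seg == -1] = 5  # "rib fracture, type unknown" becomes label 5
    return seg.astype(np.uint8)


def labels_to_regions(seg: np.ndarray) -> np.ndarray:
    # One boolean mask per region: a voxel is inside if its label is any of the region's values
    return np.stack([np.isin(seg, r) for r in REGIONS])


def regions_to_labels(region_masks: np.ndarray) -> np.ndarray:
    out = np.zeros(region_masks.shape[1:], dtype=np.uint8)
    for mask, value in zip(region_masks, REGIONS_CLASS_ORDER):
        out[mask] = value  # later regions overwrite earlier ones
    return out


raw = np.array([0, 1, 2, -1, 4, 3], dtype=np.int16)
seg = remap_ambiguous(raw)  # [0, 1, 2, 5, 4, 3]
print(regions_to_labels(labels_to_regions(seg)))
```

Because the `fracture` region is painted first with value 5, voxels that belong to no fracture subtype keep label 5, which is exactly where the ambiguous `-1` voxels ended up.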
================================================
FILE: nnunetv2/dataset_conversion/Dataset021_CTAAorta.py
================================================
from batchgenerators.utilities.file_and_folder_operations import *
import shutil
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw
import SimpleITK as sitk
if __name__ == '__main__':
"""
"""
# extracted training.zip file is here
base = '/home/isensee/Downloads/'
target_dataset_id = 21
target_dataset_name = f'Dataset{target_dataset_id:03.0f}_CTAAorta'
maybe_mkdir_p(join(nnUNet_raw, target_dataset_name))
imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr')
labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr')
maybe_mkdir_p(imagesTr)
maybe_mkdir_p(labelsTr)
cases = subfiles(join(base, 'images'), join=False, prefix='subject')
for case in cases:
outname = case.replace('_CTA.mha', '')
im = sitk.ReadImage(join(base, 'images', case))
sitk.WriteImage(im, join(imagesTr, outname + '_0000.nii.gz'))
seg = sitk.ReadImage(join(base, 'masks', case.replace('_CTA.mha', '_label.mha')))
sitk.WriteImage(seg, join(labelsTr, outname + '.nii.gz'))
labels = {
"background": 0,
"Zone_0": 1,
"Innominate": 2,
"Zone_1": 3,
"Left_Common_Carotid": 4,
"Zone_2": 5,
"Left_Subclavian_Artery": 6,
"Zone_3": 7,
"Zone_4": 8,
"Zone_5": 9,
"Zone_6": 10,
"Celiac_Artery": 11,
"Zone_7": 12,
"SMA": 13,
"Zone_8": 14,
"Right_Renal_Artery": 15,
"Left_Renal_Artery": 16,
"Zone_9": 17,
"Zone_10_R_(Right_Common_Iliac_Artery)": 18,
"Zone_10_L_(Left_Common_Iliac_Artery)": 19,
"Right_Internal_Iliac_Artery_Dice_Score": 20,
"Left_Internal_Iliac_Artery_Dice_Score": 21,
"Zone_11_R_(Right_External_Iliac_Artery)": 22,
"Zone_11_L_(Left_External_Iliac_Artery)": 23
}
generate_dataset_json(
join(nnUNet_raw, target_dataset_name),
channel_names={0: 'CTA'},
labels=labels,
num_training_cases=len(cases),
file_ending='.nii.gz',
dataset_name=target_dataset_name,
overwrite_image_reader_writer='NibabelIOWithReorient',
reference='https://aortaseg24.grand-challenge.org/',
license='see ref'
)
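The conversion relies on a naming convention: an image `<case>_CTA.mha` in `images/` has its label at `masks/<case>_label.mha`, and nnU-Net's raw format expects `<case>_0000.nii.gz` in `imagesTr` (channel 0) and `<case>.nii.gz` in `labelsTr`. A small pure-string sketch of that mapping; `subject001` is only an example identifier:

```python
from typing import Dict


def aortaseg_target_names(image_file: str) -> Dict[str, str]:
    assert image_file.endswith("_CTA.mha"), image_file
    case_id = image_file.replace("_CTA.mha", "")
    return {
        "source_label": image_file.replace("_CTA.mha", "_label.mha"),
        "target_image": f"{case_id}_0000.nii.gz",  # _0000 = channel 0 ('CTA')
        "target_label": f"{case_id}.nii.gz",
    }


print(aortaseg_target_names("subject001_CTA.mha"))
```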
================================================
FILE: nnunetv2/dataset_conversion/Dataset023_AbdomenAtlas1_1Mini.py
================================================
from batchgenerators.utilities.file_and_folder_operations import *
import shutil
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw
if __name__ == '__main__':
"""
Download the dataset from huggingface:
https://huggingface.co/datasets/AbdomenAtlas/_AbdomenAtlas1.1Mini#3--download-the-dataset
IMPORTANT
cases 5196-9262 currently do not have images, just the segmentation. This seems to be a mistake
"""
base = '/home/isensee/Downloads/AbdomenAtlas/uncompressed'
target_dataset_id = 23
target_dataset_name = f'Dataset{target_dataset_id:03.0f}_AbdomenAtlas1.1Mini'
cases = subdirs(base, join=False, prefix='BDMAP')
maybe_mkdir_p(join(nnUNet_raw, target_dataset_name))
imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr')
labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr')
maybe_mkdir_p(imagesTr)
maybe_mkdir_p(labelsTr)
n_tr = 0
for case in cases:
if not isfile(join(base, case, 'ct.nii.gz')):
print(f'Skipping case {case} due to missing image')
continue
shutil.copy(join(base, case, 'ct.nii.gz'), join(imagesTr, case + '_0000.nii.gz'))
shutil.copy(join(base, case, 'combined_labels.nii.gz'), join(labelsTr, case + '.nii.gz'))
n_tr += 1
class_map = {1: 'aorta', 2: 'gall_bladder', 3: 'kidney_left', 4: 'kidney_right', 5: 'liver',
6: 'pancreas', 7: 'postcava', 8: 'spleen', 9: 'stomach', 10: 'adrenal_gland_left',
11: 'adrenal_gland_right', 12: 'bladder', 13: 'celiac_trunk', 14: 'colon', 15: 'duodenum',
16: 'esophagus', 17: 'femur_left', 18: 'femur_right', 19: 'hepatic_vessel', 20: 'intestine',
21: 'lung_left', 22: 'lung_right', 23: 'portal_vein_and_splenic_vein',
24: 'prostate', 25: 'rectum'}
labels = {
j: i for i, j in class_map.items()
}
labels['background'] = 0
generate_dataset_json(
join(nnUNet_raw, target_dataset_name),
channel_names={0: 'CT'},
labels=labels,
num_training_cases=n_tr,  # only the cases that were actually copied (skipped cases have no image)
file_ending='.nii.gz',
dataset_name=target_dataset_name,
overwrite_image_reader_writer='NibabelIOWithReorient',
reference='https://huggingface.co/datasets/AbdomenAtlas/_AbdomenAtlas1.1Mini',
license='Creative Commons Attribution Non Commercial Share Alike 4.0; see reference'
)
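`labels` is built by inverting `class_map` and adding `background`. The sketch below performs that inversion plus a sanity check that the values run 0, 1, 2, ... without gaps, which, as we understand nnU-Net's dataset format, is expected of label values. The three-class map is a truncated example, not the full AbdomenAtlas map, and `build_labels` is a hypothetical helper:

```python
from typing import Dict


def build_labels(class_map: Dict[int, str]) -> Dict[str, int]:
    labels = {"background": 0}
    labels.update({name: value for value, name in sorted(class_map.items())})
    values = sorted(labels.values())
    if values != list(range(len(values))):
        raise ValueError(f"label values are not consecutive: {values}")
    return labels


print(build_labels({1: "aorta", 2: "gall_bladder", 3: "kidney_left"}))
# {'background': 0, 'aorta': 1, 'gall_bladder': 2, 'kidney_left': 3}
```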
================================================
FILE: nnunetv2/dataset_conversion/Dataset027_ACDC.py
================================================
import os
import shutil
from pathlib import Path
from typing import List
from batchgenerators.utilities.file_and_folder_operations import nifti_files, join, maybe_mkdir_p, save_json
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
import numpy as np
def make_out_dirs(dataset_id: int, task_name="ACDC"):
dataset_name = f"Dataset{dataset_id:03d}_{task_name}"
out_dir = Path(nnUNet_raw.replace('"', "")) / dataset_name
out_train_dir = out_dir / "imagesTr"
out_labels_dir = out_dir / "labelsTr"
out_test_dir = out_dir / "imagesTs"
os.makedirs(out_dir, exist_ok=True)
os.makedirs(out_train_dir, exist_ok=True)
os.makedirs(out_labels_dir, exist_ok=True)
os.makedirs(out_test_dir, exist_ok=True)
return out_dir, out_train_dir, out_labels_dir, out_test_dir
def create_ACDC_split(labelsTr_folder: str, seed: int = 1234) -> List[dict[str, List]]:
# labelsTr_folder = '/home/isensee/drives/gpu_data_root/OE0441/isensee/nnUNet_raw/nnUNet_raw_remake/Dataset027_ACDC/labelsTr'
nii_files = nifti_files(labelsTr_folder, join=False)
patients = np.unique([i[:len('patient000')] for i in nii_files])
rs = np.random.RandomState(seed)
rs.shuffle(patients)
splits = []
for fold in range(5):
val_patients = patients[fold::5]
train_patients = [i for i in patients if i not in val_patients]
val_cases = [i[:-7] for i in nii_files for j in val_patients if i.startswith(j)]
train_cases = [i[:-7] for i in nii_files for j in train_patients if i.startswith(j)]
splits.append({'train': train_cases, 'val': val_cases})
return splits
def copy_files(src_data_folder: Path, train_dir: Path, labels_dir: Path, test_dir: Path):
"""Copy files from the ACDC dataset to the nnUNet dataset folder. Returns the number of training cases."""
patients_train = sorted([f for f in (src_data_folder / "training").iterdir() if f.is_dir()])
patients_test = sorted([f for f in (src_data_folder / "testing").iterdir() if f.is_dir()])
num_training_cases = 0
# Copy training files and corresponding labels.
for patient_dir in patients_train:
for file in patient_dir.iterdir():
if file.suffix == ".gz" and "_gt" not in file.name and "_4d" not in file.name:
# The stem is e.g. 'patient001_frame01.nii' and the suffix is '.gz'.
# We drop the remaining '.nii' and append the channel suffix _0000.
shutil.copy(file, train_dir / f"{file.stem.split('.')[0]}_0000.nii.gz")
num_training_cases += 1
elif file.suffix == ".gz" and "_gt" in file.name:
shutil.copy(file, labels_dir / file.name.replace("_gt", ""))
# Copy test files.
for patient_dir in patients_test:
for file in patient_dir.iterdir():
if file.suffix == ".gz" and "_gt" not in file.name and "_4d" not in file.name:
shutil.copy(file, test_dir / f"{file.stem.split('.')[0]}_0000.nii.gz")
return num_training_cases
def convert_acdc(src_data_folder: str, dataset_id=27):
out_dir, train_dir, labels_dir, test_dir = make_out_dirs(dataset_id=dataset_id)
num_training_cases = copy_files(Path(src_data_folder), train_dir, labels_dir, test_dir)
generate_dataset_json(
str(out_dir),
channel_names={
0: "cineMRI",
},
labels={
"background": 0,
"RV": 1,
"MLV": 2,
"LVC": 3,
},
file_ending=".nii.gz",
num_training_cases=num_training_cases,
)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input_folder",
type=str,
help="The downloaded ACDC dataset dir. Should contain extracted 'training' and 'testing' folders.",
)
parser.add_argument(
"-d", "--dataset_id", required=False, type=int, default=27, help="nnU-Net Dataset ID, default: 27"
)
args = parser.parse_args()
print("Converting...")
convert_acdc(args.input_folder, args.dataset_id)
dataset_name = f"Dataset{args.dataset_id:03d}_ACDC"
labelsTr = join(nnUNet_raw, dataset_name, 'labelsTr')
preprocessed_folder = join(nnUNet_preprocessed, dataset_name)
maybe_mkdir_p(preprocessed_folder)
split = create_ACDC_split(labelsTr)
save_json(split, join(preprocessed_folder, 'splits_final.json'), sort_keys=False)
print("Done!")
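`create_ACDC_split` assigns whole patients to folds, so the ED and ES frames of one patient never end up on both sides of a split. The following standard-library sketch shows the same idea using a fixed-length patient prefix. It is an illustration under assumptions, not a drop-in replacement. The function name is hypothetical, and because `random.Random` and `np.random.RandomState` produce different permutations for the same seed, it will not reproduce the exact folds the script writes to `splits_final.json`.

```python
import random
from typing import Dict, List


def patient_level_splits(case_ids: List[str], patient_prefix_len: int,
                         n_folds: int = 5, seed: int = 1234) -> List[Dict[str, List[str]]]:
    """Split case IDs into folds such that all cases of one patient share a fold."""
    patients = sorted({c[:patient_prefix_len] for c in case_ids})
    random.Random(seed).shuffle(patients)
    splits = []
    for fold in range(n_folds):
        # every n_folds-th shuffled patient goes to validation for this fold
        val_patients = set(patients[fold::n_folds])
        val = [c for c in case_ids if c[:patient_prefix_len] in val_patients]
        train = [c for c in case_ids if c[:patient_prefix_len] not in val_patients]
        splits.append({"train": train, "val": val})
    return splits
```

With ACDC-style IDs such as `patient001_frame01`, the prefix length would be `len('patient000')`, mirroring the slicing in the script.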
================================================
FILE: nnunetv2/dataset_conversion/Dataset042_BraTS18.py
================================================
import multiprocessing
import shutil
import SimpleITK as sitk
import numpy as np
from tqdm import tqdm
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw
def copy_BraTS_segmentation_and_convert_labels_to_nnUNet(in_file: str, out_file: str) -> None:
# use this for segmentation only!!!
# nnUNet wants the labels to be continuous. BraTS is 0, 1, 2, 4 -> we make that into 0, 1, 2, 3
img = sitk.ReadImage(in_file)
img_npy = sitk.GetArrayFromImage(img)
uniques = np.unique(img_npy)
for u in uniques:
if u not in [0, 1, 2, 4]:
raise RuntimeError('unexpected label')
seg_new = np.zeros_like(img_npy)
seg_new[img_npy == 4] = 3
seg_new[img_npy == 2] = 1
seg_new[img_npy == 1] = 2
img_corr = sitk.GetImageFromArray(seg_new)
img_corr.CopyInformation(img)
sitk.WriteImage(img_corr, out_file)
def convert_labels_back_to_BraTS(seg: np.ndarray):
new_seg = np.zeros_like(seg)
new_seg[seg == 1] = 2
new_seg[seg == 3] = 4
new_seg[seg == 2] = 1
return new_seg
def load_convert_labels_back_to_BraTS(filename, input_folder, output_folder):
a = sitk.ReadImage(join(input_folder, filename))
b = sitk.GetArrayFromImage(a)
c = convert_labels_back_to_BraTS(b)
d = sitk.GetImageFromArray(c)
d.CopyInformation(a)
sitk.WriteImage(d, join(output_folder, filename))
def convert_folder_with_preds_back_to_BraTS_labeling_convention(input_folder: str, output_folder: str,
num_processes: int = 12):
"""
Reads all prediction files (nifti) in the input folder, converts the labels back to the BraTS convention and saves
them to the output folder.
"""
maybe_mkdir_p(output_folder)
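# this dataset uses file_ending '.nii', so accept both nifti endings
nii = [i for i in subfiles(input_folder, join=False) if i.endswith('.nii') or i.endswith('.nii.gz')]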
with multiprocessing.get_context("spawn").Pool(num_processes) as p:
p.starmap(load_convert_labels_back_to_BraTS, zip(nii, [input_folder] * len(nii), [output_folder] * len(nii)))
if __name__ == '__main__':
brats_data_dir = ...
task_id = 42
task_name = "BraTS2018"
foldername = "Dataset%03.0d_%s" % (task_id, task_name)
# setting up nnU-Net folders
out_base = join(nnUNet_raw, foldername)
imagestr = join(out_base, "imagesTr")
labelstr = join(out_base, "labelsTr")
maybe_mkdir_p(imagestr)
maybe_mkdir_p(labelstr)
case_ids_hgg = subdirs(join(brats_data_dir, "HGG"), prefix='Brats', join=False)
case_ids_lgg = subdirs(join(brats_data_dir, "LGG"), prefix="Brats", join=False)
print("copying hggs")
for c in tqdm(case_ids_hgg):
shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1.nii"), join(imagestr, c + '_0000.nii'))
shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1ce.nii"), join(imagestr, c + '_0001.nii'))
shutil.copy(join(brats_data_dir, "HGG", c, c + "_t2.nii"), join(imagestr, c + '_0002.nii'))
shutil.copy(join(brats_data_dir, "HGG", c, c + "_flair.nii"), join(imagestr, c + '_0003.nii'))
copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "HGG", c, c + "_seg.nii"),
join(labelstr, c + '.nii'))
print("copying lggs")
for c in tqdm(case_ids_lgg):
shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1.nii"), join(imagestr, c + '_0000.nii'))
shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1ce.nii"), join(imagestr, c + '_0001.nii'))
shutil.copy(join(brats_data_dir, "LGG", c, c + "_t2.nii"), join(imagestr, c + '_0002.nii'))
shutil.copy(join(brats_data_dir, "LGG", c, c + "_flair.nii"), join(imagestr, c + '_0003.nii'))
copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "LGG", c, c + "_seg.nii"),
join(labelstr, c + '.nii'))
generate_dataset_json(out_base,
channel_names={0: 'T1', 1: 'T1ce', 2: 'T2', 3: 'Flair'},
labels={
'background': 0,
'whole tumor': (1, 2, 3),
'tumor core': (2, 3),
'enhancing tumor': (3,)
},
num_training_cases=(len(case_ids_lgg) + len(case_ids_hgg)),
file_ending='.nii',
regions_class_order=(1, 2, 3),
license='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',
reference='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',
dataset_release='1.0')
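The label conversion in `copy_BraTS_segmentation_and_convert_labels_to_nnUNet` and `convert_labels_back_to_BraTS` is a fixed lookup table. In the BraTS convention, 1 is necrotic/non-enhancing tumor core, 2 is edema and 4 is enhancing tumor. The following sketch applies the same table to a flat list of voxel values instead of SimpleITK/NumPy arrays, which makes the round trip easy to check. The names are illustrative only.

```python
from typing import Dict, List

# BraTS -> nnU-Net (consecutive) labels, as in the script:
# 0 -> 0 background, 2 -> 1 edema, 1 -> 2 necrosis/non-enhancing core, 4 -> 3 enhancing tumor
BRATS_TO_NNUNET: Dict[int, int] = {0: 0, 2: 1, 1: 2, 4: 3}
NNUNET_TO_BRATS: Dict[int, int] = {v: k for k, v in BRATS_TO_NNUNET.items()}


def remap(values: List[int], mapping: Dict[int, int]) -> List[int]:
    """Remap voxel labels; unknown labels raise, like the script's 'unexpected label' check."""
    out = []
    for v in values:
        if v not in mapping:
            raise RuntimeError(f"unexpected label {v}")
        out.append(mapping[v])
    return out
```

Because the table is a bijection, converting predictions back with `NNUNET_TO_BRATS` restores the original BraTS values exactly.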
================================================
FILE: nnunetv2/dataset_conversion/Dataset043_BraTS19.py
================================================
import multiprocessing
import shutil
import SimpleITK as sitk
import numpy as np
from tqdm import tqdm
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw
def copy_BraTS_segmentation_and_convert_labels_to_nnUNet(in_file: str, out_file: str) -> None:
# use this for segmentation only!!!
# nnUNet wants the labels to be continuous. BraTS is 0, 1, 2, 4 -> we make that into 0, 1, 2, 3
img = sitk.ReadImage(in_file)
img_npy = sitk.GetArrayFromImage(img)
uniques = np.unique(img_npy)
for u in uniques:
if u not in [0, 1, 2, 4]:
raise RuntimeError('unexpected label')
seg_new = np.zeros_like(img_npy)
seg_new[img_npy == 4] = 3
seg_new[img_npy == 2] = 1
seg_new[img_npy == 1] = 2
img_corr = sitk.GetImageFromArray(seg_new)
img_corr.CopyInformation(img)
sitk.WriteImage(img_corr, out_file)
def convert_labels_back_to_BraTS(seg: np.ndarray):
new_seg = np.zeros_like(seg)
new_seg[seg == 1] = 2
new_seg[seg == 3] = 4
new_seg[seg == 2] = 1
return new_seg
def load_convert_labels_back_to_BraTS(filename, input_folder, output_folder):
a = sitk.ReadImage(join(input_folder, filename))
b = sitk.GetArrayFromImage(a)
c = convert_labels_back_to_BraTS(b)
d = sitk.GetImageFromArray(c)
d.CopyInformation(a)
sitk.WriteImage(d, join(output_folder, filename))
def convert_folder_with_preds_back_to_BraTS_labeling_convention(input_folder: str, output_folder: str,
num_processes: int = 12):
"""
Reads all prediction files (nifti) in the input folder, converts the labels back to the BraTS convention and saves
them to the output folder.
"""
maybe_mkdir_p(output_folder)
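# this dataset uses file_ending '.nii', so accept both nifti endings
nii = [i for i in subfiles(input_folder, join=False) if i.endswith('.nii') or i.endswith('.nii.gz')]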
with multiprocessing.get_context("spawn").Pool(num_processes) as p:
p.starmap(load_convert_labels_back_to_BraTS, zip(nii, [input_folder] * len(nii), [output_folder] * len(nii)))
if __name__ == '__main__':
brats_data_dir = ...
task_id = 43
task_name = "BraTS2019"
foldername = "Dataset%03.0d_%s" % (task_id, task_name)
# setting up nnU-Net folders
out_base = join(nnUNet_raw, foldername)
imagestr = join(out_base, "imagesTr")
labelstr = join(out_base, "labelsTr")
maybe_mkdir_p(imagestr)
maybe_mkdir_p(labelstr)
case_ids_hgg = subdirs(join(brats_data_dir, "HGG"), prefix='BraTS', join=False)
case_ids_lgg = subdirs(join(brats_data_dir, "LGG"), prefix="BraTS", join=False)
print("copying hggs")
for c in tqdm(case_ids_hgg):
shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1.nii"), join(imagestr, c + '_0000.nii'))
shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1ce.nii"), join(imagestr, c + '_0001.nii'))
shutil.copy(join(brats_data_dir, "HGG", c, c + "_t2.nii"), join(imagestr, c + '_0002.nii'))
shutil.copy(join(brats_data_dir, "HGG", c, c + "_flair.nii"), join(imagestr, c + '_0003.nii'))
copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "HGG", c, c + "_seg.nii"),
join(labelstr, c + '.nii'))
print("copying lggs")
for c in tqdm(case_ids_lgg):
shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1.nii"), join(imagestr, c + '_0000.nii'))
shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1ce.nii"), join(imagestr, c + '_0001.nii'))
shutil.copy(join(brats_data_dir, "LGG", c, c + "_t2.nii"), join(imagestr, c + '_0002.nii'))
shutil.copy(join(brats_data_dir, "LGG", c, c + "_flair.nii"), join(imagestr, c + '_0003.nii'))
copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "LGG", c, c + "_seg.nii"),
join(labelstr, c + '.nii'))
generate_dataset_json(out_base,
channel_names={0: 'T1', 1: 'T1ce', 2: 'T2', 3: 'Flair'},
labels={
'background': 0,
'whole tumor': (1, 2, 3),
'tumor core': (2, 3),
'enhancing tumor': (3,)
},
num_training_cases=(len(case_ids_hgg) + len(case_ids_lgg)),
file_ending='.nii',
regions_class_order=(1, 2, 3),
license='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',
reference='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',
dataset_release='1.0')
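Because `labels` maps names to tuples of label values, these datasets are set up for region-based training. The model learns the nested regions whole tumor, tumor core and enhancing tumor, and `regions_class_order` specifies how region predictions are turned back into a single label map. The following sketch illustrates the concept on flat lists. A voxel belongs to a region if its label is in that region's tuple, and the regions are painted in `regions_class_order`, so each nested region overwrites the one containing it. This is a conceptual illustration under those assumptions, not nnU-Net's own implementation, which operates on predicted probabilities.

```python
from typing import Dict, List, Sequence, Tuple

REGIONS: Dict[str, Tuple[int, ...]] = {
    "whole tumor": (1, 2, 3),
    "tumor core": (2, 3),
    "enhancing tumor": (3,),
}
REGIONS_CLASS_ORDER = (1, 2, 3)


def labels_to_region_masks(seg: Sequence[int]) -> List[List[bool]]:
    """One boolean mask per region, in the insertion order of REGIONS."""
    return [[v in region for v in seg] for region in REGIONS.values()]


def region_masks_to_labels(masks: List[List[bool]]) -> List[int]:
    """Paint regions in order; later (smaller, nested) regions overwrite earlier ones."""
    seg = [0] * len(masks[0])
    for mask, label in zip(masks, REGIONS_CLASS_ORDER):
        for i, inside in enumerate(mask):
            if inside:
                seg[i] = label
    return seg
```

For nested regions like these, the round trip labels -> region masks -> labels is lossless, which is why the order `(1, 2, 3)` goes from the largest region to the smallest.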
================================================
FILE: nnunetv2/dataset_conversion/Dataset073_Fluo_C3DH_A549_SIM.py
================================================
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
import tifffile
from batchgenerators.utilities.file_and_folder_operations import *
import shutil
if __name__ == '__main__':
"""
This is going to be my test dataset for working with tif as input and output images.
All we do here is copy the files and rename them. No file conversions take place.
"""
dataset_name = 'Dataset073_Fluo_C3DH_A549_SIM'
imagestr = join(nnUNet_raw, dataset_name, 'imagesTr')
imagests = join(nnUNet_raw, dataset_name, 'imagesTs')
labelstr = join(nnUNet_raw, dataset_name, 'labelsTr')
maybe_mkdir_p(imagestr)
maybe_mkdir_p(imagests)
maybe_mkdir_p(labelstr)
# we extract the downloaded train and test datasets to two separate folders and name them Fluo-C3DH-A549-SIM_train
# and Fluo-C3DH-A549-SIM_test
train_source = '/home/fabian/Downloads/Fluo-C3DH-A549-SIM_train'
test_source = '/home/fabian/Downloads/Fluo-C3DH-A549-SIM_test'
# with the old nnU-Net we had to convert all the files to nifti. This is no longer required. We can just copy the
# tif files
# tif is broken when it comes to spacing. No standards. Grr. So when we use tif nnU-Net expects a separate file
# that specifies the spacing. This file needs to exist for EVERY training/test case to allow for different spacings
# between files. Important! The spacing must align with the axes.
# Here when we do print(tifffile.imread('IMAGE').shape) we get (29, 300, 350). The low resolution axis is the first.
# The spacing on the website is given in the wrong axis order. Great.
spacing = (1, 0.126, 0.126)
# train set
for seq in ['01', '02']:
images_dir = join(train_source, seq)
seg_dir = join(train_source, seq + '_GT', 'SEG')
# if we were to be super clean we would go by IDs but here we just trust the files are sorted the correct way.
# Simpler filenames in the cell tracking challenge would be soooo nice.
images = subfiles(images_dir, suffix='.tif', sort=True, join=False)
segs = subfiles(seg_dir, suffix='.tif', sort=True, join=False)
for i, (im, se) in enumerate(zip(images, segs)):
target_name = f'{seq}_image_{i:03d}'
# we still need the '_0000' suffix for images! Otherwise we would not be able to support multiple input
# channels distributed over separate files
shutil.copy(join(images_dir, im), join(imagestr, target_name + '_0000.tif'))
# spacing file!
save_json({'spacing': spacing}, join(imagestr, target_name + '.json'))
shutil.copy(join(seg_dir, se), join(labelstr, target_name + '.tif'))
# spacing file!
save_json({'spacing': spacing}, join(labelstr, target_name + '.json'))
# test set, same as train just without the segmentations
for seq in ['01', '02']:
images_dir = join(test_source, seq)
images = subfiles(images_dir, suffix='.tif', sort=True, join=False)
for i, im in enumerate(images):
target_name = f'{seq}_image_{i:03d}'
shutil.copy(join(images_dir, im), join(imagests, target_name + '_0000.tif'))
# spacing file!
save_json({'spacing': spacing}, join(imagests, target_name + '.json'))
# now we generate the dataset json
generate_dataset_json(
join(nnUNet_raw, dataset_name),
{0: 'fluorescence_microscopy'},
{'background': 0, 'cell': 1},
60,
'.tif'
)
# custom split to ensure we are stratifying properly. This dataset only has 2 folds
caseids = [i[:-4] for i in subfiles(labelstr, suffix='.tif', join=False)]
splits = []
splits.append(
{'train': [i for i in caseids if i.startswith('01_')], 'val': [i for i in caseids if i.startswith('02_')]}
)
splits.append(
{'train': [i for i in caseids if i.startswith('02_')], 'val': [i for i in caseids if i.startswith('01_')]}
)
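# the preprocessed folder may not exist yet if planning/preprocessing has not been run
maybe_mkdir_p(join(nnUNet_preprocessed, dataset_name))
save_json(splits, join(nnUNet_preprocessed, dataset_name, 'splits_final.json'))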
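The custom split above trains on one image sequence and validates on the other, so frames from the same time-lapse never appear on both sides. The following sketch generalizes this to any number of sequences by grouping case IDs on the text before the first underscore and holding out one group per fold. The function name is hypothetical. With the sequences `01` and `02`, it yields the same two folds as the script, but in the opposite order.

```python
from typing import Dict, List


def leave_one_sequence_out(case_ids: List[str]) -> List[Dict[str, List[str]]]:
    """One fold per sequence prefix (text before the first '_'): validate on it, train on the rest."""
    groups: Dict[str, List[str]] = {}
    for c in sorted(case_ids):
        groups.setdefault(c.split("_", 1)[0], []).append(c)
    splits = []
    for held_out in sorted(groups):
        train = [c for g, cs in groups.items() if g != held_out for c in cs]
        splits.append({"train": train, "val": list(groups[held_out])})
    return splits
```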
================================================
FILE: nnunetv2/dataset_conversion/Dataset114_MNMs.py
================================================
import csv
import os
import random
from pathlib import Path
import nibabel as nib
from batchgenerators.utilities.file_and_folder_operations import load_json, save_json
from nnunetv2.dataset_conversion.Dataset027_ACDC import make_out_dirs
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_preprocessed
def read_csv(csv_file: str):
patient_info = {}
with open(csv_file) as csvfile:
reader = csv.reader(csvfile)
headers = next(reader)
patient_index = headers.index("External code")
ed_index = headers.index("ED")
es_index = headers.index("ES")
vendor_index = headers.index("Vendor")
for row in reader:
patient_info[row[patient_index]] = {
"ed": int(row[ed_index]),
"es": int(row[es_index]),
"vendor": row[vendor_index],
}
return patient_info
# ------------------------------------------------------------------------------
# Conversion to nnUNet format
# ------------------------------------------------------------------------------
def convert_mnms(src_data_folder: Path, csv_file_name: str, dataset_id: int):
out_dir, out_train_dir, out_labels_dir, out_test_dir = make_out_dirs(dataset_id, task_name="MNMs")
patients_train = [f for f in (src_data_folder / "Training" / "Labeled").iterdir() if f.is_dir()]
patients_test = [f for f in (src_data_folder / "Testing").iterdir() if f.is_dir()]
patient_info = read_csv(str(src_data_folder / csv_file_name))
save_cardiac_phases(patients_train, patient_info, out_train_dir, out_labels_dir)
save_cardiac_phases(patients_test, patient_info, out_test_dir)
# There are non-orthonormal direction cosines in the test and validation data.
# Not sure if the data should be fixed, or we should skip the problematic data.
# patients_val = [f for f in (src_data_folder / "Validation").iterdir() if f.is_dir()]
# save_cardiac_phases(patients_val, patient_info, out_train_dir, out_labels_dir)
generate_dataset_json(
str(out_dir),
channel_names={
0: "cineMRI",
},
labels={"background": 0, "LVBP": 1, "LVM": 2, "RV": 3},
file_ending=".nii.gz",
num_training_cases=len(patients_train) * 2, # 2 since we have ED and ES for each patient
)
def save_cardiac_phases(
patients: list[Path], patient_info: dict[str, dict[str, int]], out_dir: Path, labels_dir: Path = None
):
for patient in patients:
print(f"Processing patient: {patient.name}")
image = nib.load(patient / f"{patient.name}_sa.nii.gz")
ed_frame = patient_info[patient.name]["ed"]
es_frame = patient_info[patient.name]["es"]
save_extracted_nifti_slice(image, ed_frame=ed_frame, es_frame=es_frame, out_dir=out_dir, patient=patient)
if labels_dir:
label = nib.load(patient / f"{patient.name}_sa_gt.nii.gz")
save_extracted_nifti_slice(label, ed_frame=ed_frame, es_frame=es_frame, out_dir=labels_dir, patient=patient)
def save_extracted_nifti_slice(image, ed_frame: int, es_frame: int, out_dir: Path, patient: Path):
# Save only extracted diastole and systole slices from the 4D H x W x D x time volume.
image_ed = nib.Nifti1Image(image.dataobj[..., ed_frame], image.affine)
image_es = nib.Nifti1Image(image.dataobj[..., es_frame], image.affine)
# Labels do not have modality identifiers. Labels always end with 'gt'.
suffix = ".nii.gz" if image.get_filename().endswith("_gt.nii.gz") else "_0000.nii.gz"
nib.save(image_ed, str(out_dir / f"{patient.name}_frame{ed_frame:02d}{suffix}"))
nib.save(image_es, str(out_dir / f"{patient.name}_frame{es_frame:02d}{suffix}"))
# ------------------------------------------------------------------------------
# Create custom splits
# ------------------------------------------------------------------------------
def create_custom_splits(src_data_folder: Path, csv_file: str, dataset_id: int, num_val_patients: int = 25):
existing_splits = os.path.join(nnUNet_preprocessed, f"Dataset{dataset_id:03d}_MNMs", "splits_final.json")
splits = load_json(existing_splits)
patients_train = [f.name for f in (src_data_folder / "Training" / "Labeled").iterdir() if f.is_dir()]
# Filter out any patients not in the training set
patient_info = {
patient: data
for patient, data in read_csv(str(src_data_folder / csv_file)).items()
if patient in patients_train
}
# Get train and validation patients for both vendors
patients_a = [patient for patient, patient_data in patient_info.items() if patient_data["vendor"] == "A"]
patients_b = [patient for patient, patient_data in patient_info.items() if patient_data["vendor"] == "B"]
train_a, val_a = get_vendor_split(patients_a, num_val_patients)
train_b, val_b = get_vendor_split(patients_b, num_val_patients)
# Build filenames from corresponding patient frames
train_a = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in train_a for frame in ["es", "ed"]]
train_b = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in train_b for frame in ["es", "ed"]]
train_a_mix_1, train_a_mix_2 = train_a[: len(train_a) // 2], train_a[len(train_a) // 2 :]
train_b_mix_1, train_b_mix_2 = train_b[: len(train_b) // 2], train_b[len(train_b) // 2 :]
val_a = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in val_a for frame in ["es", "ed"]]
val_b = [f"{patient}_frame{patient_info[patient][frame]:02d}" for patient in val_b for frame in ["es", "ed"]]
for train_set in [train_a, train_b, train_a_mix_1 + train_b_mix_1, train_a_mix_2 + train_b_mix_2]:
# For each train set, we evaluate on A, B and (A + B) respectively
# See table 3 from the original paper for more details.
splits.append({"train": train_set, "val": val_a})
splits.append({"train": train_set, "val": val_b})
splits.append({"train": train_set, "val": val_a + val_b})
save_json(splits, existing_splits)
def get_vendor_split(patients: list[str], num_val_patients: int):
random.shuffle(patients)
total_patients = len(patients)
num_training_patients = total_patients - num_val_patients
return patients[:num_training_patients], patients[num_training_patients:]
if __name__ == "__main__":
import argparse
class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter):
pass
parser = argparse.ArgumentParser(add_help=False, formatter_class=RawTextArgumentDefaultsHelpFormatter)
parser.add_argument(
"-h",
"--help",
action="help",
default=argparse.SUPPRESS,
help="MNMs conversion utility helper. This script can be used to convert MNMs data into the expected nnUNet "
"format. It can also be used to create additional custom splits, for explicitly training on combinations "
"of vendors A and B (see `--custom_splits`).\n"
"If you wish to generate the custom splits, run the following pipeline:\n\n"
"(1) Run `Dataset114_MNMs -i <input_folder>`\n"
"(2) Run `nnUNetv2_plan_and_preprocess -d 114 --verify_dataset_integrity`\n"
"(3) Start training, but stop after initial splits are created: `nnUNetv2_train 114 2d 0`\n"
"(4) Re-run `Dataset114_MNMs`, with `-s True`.\n"
"(5) Re-run training.\n",
)
parser.add_argument(
"-i",
"--input_folder",
type=str,
default="./data/M&Ms/OpenDataset/",
help="The downloaded MNMs dataset dir. Should contain a csv file, as well as Training, Validation and Testing "
"folders.",
)
parser.add_argument(
"-c",
"--csv_file_name",
type=str,
default="211230_M&Ms_Dataset_information_diagnosis_opendataset.csv",
help="The csv file containing the dataset information.",
)
parser.add_argument("-d", "--dataset_id", type=int, default=114, help="nnUNet Dataset ID.")
parser.add_argument(
"-s",
"--custom_splits",
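# type=bool would turn any non-empty string (including "False") into True
type=lambda s: s.strip().lower() in ("true", "1", "yes"),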
default=False,
help="Whether to append custom splits for training and testing on different vendors. If True, will create "
"splits for training on patients from vendors A, B or a mix of A and B. Splits are evaluated on hold-out "
"validation sets of patients from A, B or A and B combined. See section 2.4 and table 3 from "
"https://arxiv.org/abs/2011.07592 for more info.",
)
args = parser.parse_args()
args.input_folder = Path(args.input_folder)
if args.custom_splits:
print("Appending custom splits...")
create_custom_splits(args.input_folder, args.csv_file_name, args.dataset_id)
else:
print("Converting...")
convert_mnms(args.input_folder, args.csv_file_name, args.dataset_id)
print("Done!")
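`get_vendor_split` shuffles the caller's list in place using the global `random` state, so the appended custom splits change from run to run. The following sketch is a reproducible variant. The names `vendor_split` and `frame_case_ids` are hypothetical, and the `seed` parameter is an assumption not present in the script. It leaves the input list untouched and builds case IDs the same way `create_custom_splits` does.

```python
import random
from typing import Dict, List, Tuple


def vendor_split(patients: List[str], num_val: int, seed: int = 0) -> Tuple[List[str], List[str]]:
    """Shuffle a sorted copy with a fixed seed and hold out the last `num_val` patients."""
    if not 0 <= num_val <= len(patients):
        raise ValueError(f"num_val must be between 0 and {len(patients)}")
    shuffled = sorted(patients)  # copy; sorting first makes the result independent of input order
    random.Random(seed).shuffle(shuffled)
    cut = len(shuffled) - num_val
    return shuffled[:cut], shuffled[cut:]


def frame_case_ids(patients: List[str], info: Dict[str, Dict[str, int]]) -> List[str]:
    """Case IDs as written by save_extracted_nifti_slice: <patient>_frame<NN> for ES, then ED."""
    return [f"{p}_frame{info[p][phase]:02d}" for p in patients for phase in ("es", "ed")]
```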
================================================
FILE: nnunetv2/dataset_conversion/Dataset115_EMIDEC.py
================================================
import shutil
from pathlib import Path
from nnunetv2.dataset_conversion.Dataset027_ACDC import make_out_dirs
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
def copy_files(src_data_dir: Path, src_test_dir: Path, train_dir: Path, labels_dir: Path, test_dir: Path):
"""Copy files from the EMIDEC dataset to the nnUNet dataset folder. Returns the number of training cases."""
patients_train = sorted([f for f in src_data_dir.iterdir() if f.is_dir()])
patients_test = sorted([f for f in src_test_dir.iterdir() if f.is_dir()])
# Copy training files and corresponding labels.
for patient in patients_train:
train_file = patient / "Images" / f"{patient.name}.nii.gz"
label_file = patient / "Contours" / f"{patient.name}.nii.gz"
shutil.copy(train_file, train_dir / f"{train_file.stem.split('.')[0]}_0000.nii.gz")
shutil.copy(label_file, labels_dir)
# Copy test files.
for patient in patients_test:
test_file = patient / "Images" / f"{patient.name}.nii.gz"
shutil.copy(test_file, test_dir / f"{test_file.stem.split('.')[0]}_0000.nii.gz")
return len(patients_train)
def convert_emidec(src_data_dir: str, src_test_dir: str, dataset_id=115):
out_dir, train_dir, labels_dir, test_dir = make_out_dirs(dataset_id=dataset_id, task_name="EMIDEC")
num_training_cases = copy_files(Path(src_data_dir), Path(src_test_dir), train_dir, labels_dir, test_dir)
generate_dataset_json(
str(out_dir),
channel_names={
0: "cineMRI",
},
labels={
"background": 0,
"cavity": 1,
"normal_myocardium": 2,
"myocardial_infarction": 3,
"no_reflow": 4,
},
file_ending=".nii.gz",
num_training_cases=num_training_cases,
)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input_dir", type=str, help="The EMIDEC dataset directory.")
parser.add_argument("-t", "--test_dir", type=str, help="The EMIDEC test set directory.")
parser.add_argument(
"-d", "--dataset_id", required=False, type=int, default=115, help="nnU-Net Dataset ID, default: 115"
)
args = parser.parse_args()
print("Converting...")
convert_emidec(args.input_dir, args.test_dir, args.dataset_id)
print("Done!")
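Both `copy_files` functions (ACDC and EMIDEC) derive the image file name with `file.stem.split('.')[0]` plus `_0000`. The four-digit suffix is the channel index that nnU-Net uses to match image files to `channel_names`. The following sketch strips the known file ending instead of splitting at the first dot, so case IDs that contain dots are preserved. The helper name is hypothetical.

```python
from pathlib import Path
from typing import Union


def nnunet_image_name(src: Union[str, Path], channel: int = 0, file_ending: str = ".nii.gz") -> str:
    """Build '<case>_<XXXX><file_ending>' from a source file such as 'Case_P001.nii.gz'."""
    name = Path(src).name
    if not name.endswith(file_ending):
        raise ValueError(f"{name} does not end with {file_ending}")
    case_id = name[: -len(file_ending)]
    return f"{case_id}_{channel:04d}{file_ending}"
```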
================================================
FILE: nnunetv2/dataset_conversion/Dataset119_ToothFairy2_All.py
================================================
from typing import Dict, Any
import os
from os.path import join
import json
import random
import multiprocessing
import SimpleITK as sitk
import numpy as np
from tqdm import tqdm
def mapping_DS119() -> Dict[int, int]:
"""Remove all NA Classes and make Class IDs continuous"""
mapping = {}
mapping.update({i: i for i in range(1, 19)}) # [1-10]->[1-10] | [11-18]->[11-18]
mapping.update({i: i - 2 for i in range(21, 29)}) # [21-28]->[19-26]
mapping.update({i: i - 4 for i in range(31, 39)}) # [31-38]->[27-34]
mapping.update({i: i - 6 for i in range(41, 49)}) # [41-48]->[35-42]
return mapping
def mapping_DS120() -> Dict[int, int]:
"""Keep only the jaw and teeth classes and make class IDs continuous"""
mapping = {}
mapping.update({i: i for i in range(1, 3)}) # [1-2]->[1-2]
mapping.update({i: i - 8 for i in range(11, 19)}) # [11-18]->[3-10]
mapping.update({i: i - 10 for i in range(21, 29)}) # [21-28]->[11-18]
mapping.update({i: i - 12 for i in range(31, 39)}) # [31-38]->[19-26]
mapping.update({i: i - 14 for i in range(41, 49)}) # [41-48]->[27-34]
return mapping
def mapping_DS121() -> Dict[int, int]:
"""Keep only the teeth classes and make class IDs continuous"""
mapping = {}
mapping.update({i: i - 10 for i in range(11, 19)}) # [11-18]->[1-8]
mapping.update({i: i - 12 for i in range(21, 29)}) # [21-28]->[9-16]
mapping.update({i: i - 14 for i in range(31, 39)}) # [31-38]->[17-24]
mapping.update({i: i - 16 for i in range(41, 49)}) # [41-48]->[25-32]
return mapping
def load_json(json_file: str) -> Any:
with open(json_file, "r") as f:
data = json.load(f)
return data
def write_json(json_file: str, data: Any, indent: int = 4) -> None:
with open(json_file, "w") as f:
json.dump(data, f, indent=indent)
def image_to_nifi(input_path: str, output_path: str) -> None:
image_sitk = sitk.ReadImage(input_path)
sitk.WriteImage(image_sitk, output_path)
def label_mapping(input_path: str, output_path: str, mapping: Dict[int, int] = None) -> None:
label_sitk = sitk.ReadImage(input_path)
if mapping is not None:
label_np = sitk.GetArrayFromImage(label_sitk)
label_np_new = np.zeros_like(label_np, dtype=np.uint8)
for org_id, new_id in mapping.items():
label_np_new[label_np == org_id] = new_id
label_sitk_new = sitk.GetImageFromArray(label_np_new)
label_sitk_new.CopyInformation(label_sitk)
sitk.WriteImage(label_sitk_new, output_path)
else:
sitk.WriteImage(label_sitk, output_path)
def process_images(files: list[str], img_dir_in: str, img_dir_out: str, n_processes: int = 12):
os.makedirs(img_dir_out, exist_ok=True)
iterable = [
{
"input_path": join(img_dir_in, file),
"output_path": join(img_dir_out, file.replace(".mha", ".nii.gz")),
}
for file in files
]
with multiprocessing.Pool(processes=n_processes) as pool:
jobs = [pool.apply_async(image_to_nifi, kwds={**args}) for args in iterable]
_ = [job.get() for job in tqdm(jobs, desc="Process Images")]
def process_labels(
files: list[str], lbl_dir_in: str, lbl_dir_out: str, mapping: Dict[int, int], n_processes: int = 12
) -> None:
os.makedirs(lbl_dir_out, exist_ok=True)
iterable = [
{
"input_path": join(lbl_dir_in, file),
"output_path": join(lbl_dir_out, file.replace(".mha", ".nii.gz")),
"mapping": mapping,
}
for file in files
]
with multiprocessing.Pool(processes=n_processes) as pool:
jobs = [pool.apply_async(label_mapping, kwds={**args}) for args in iterable]
_ = [job.get() for job in tqdm(jobs, desc="Process Labels...")]
def process_ds(
root: str, input_ds: str, output_ds: str, mapping: dict, image_link: str = None
) -> None:
os.makedirs(join(root, output_ds), exist_ok=True)
os.makedirs(join(root, output_ds, "labelsTr"), exist_ok=True)
# --- Handle Labels --- #
lbl_files = os.listdir(join(root, input_ds, "labelsTr"))
lbl_dir_in = join(root, input_ds, "labelsTr")
lbl_dir_out = join(root, output_ds, "labelsTr")
process_labels(lbl_files, lbl_dir_in, lbl_dir_out, mapping, n_processes=12)
# --- Handle Images --- #
img_files = os.listdir(join(root, input_ds, "imagesTr"))
dataset = {}
if image_link is None:
img_dir_in = join(root, input_ds, "imagesTr")
img_dir_out = join(root, output_ds, "imagesTr")
process_images(img_files, img_dir_in, img_dir_out, n_processes=12)
else:
base_name = [file.replace("_0000.mha", "") for file in img_files]
for name in base_name:
dataset[name] = {
"images": [join("..", image_link, "imagesTr", name + "_0000.nii.gz")],
"label": join("labelsTr", name + ".nii.gz"),
}
# --- Generate dataset.json --- #
dataset_json = load_json(join(root, input_ds, "dataset.json"))
dataset_json["file_ending"] = ".nii.gz"
dataset_json["name"] = output_ds
dataset_json["numTraining"] = len(lbl_files)
if dataset != {}:
dataset_json["dataset"] = dataset
label_dict = dataset_json["labels"]
label_dict_new = {"background": 0}
for k, v in label_dict.items():
if v in mapping.keys():
label_dict_new[k] = mapping[v]
dataset_json["labels"] = label_dict_new
write_json(join(root, output_ds, "dataset.json"), dataset_json)
# --- Generate splits_final.json --- #
img_names = [file.replace("_0000.mha", "") for file in img_files]
random_seed = 42
random.seed(random_seed)
random.shuffle(img_names)
split_index = int(len(img_names) * 0.7) # 70:30 split
train_files = img_names[:split_index]
val_files = img_names[split_index:]
train_files.sort()
val_files.sort()
split = [{"train": train_files, "val": val_files}]
write_json(join(root, output_ds, "splits_final.json"), split)
if __name__ == "__main__":
# Different nnUNet Datasets
# Dataset 112: Raw
# Dataset 119: Remove NA classes and make class IDs continuous
# Dataset 120: Only Teeth + Jaw Classes
# Dataset 121: Only Teeth Classes
root = "/media/l727r/data/Teeth_Data/ToothFairy2_Dataset"
process_ds(root, "Dataset112_ToothFairy2", "Dataset119_ToothFairy2_All", mapping_DS119(), None)
# process_ds(
# root,
# "Dataset112_ToothFairy2",
# "Dataset120_ToothFairy2_JawTeeth",
# mapping_DS120(),
# "Dataset119_ToothFairy2_All",
# )
# process_ds(
# root,
# "Dataset112_ToothFairy2",
# "Dataset121_ToothFairy2_Teeth",
# mapping_DS121(),
# "Dataset119_ToothFairy2_All",
# )
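All three `mapping_DS*` functions shift blocks of tooth IDs (11-18, 21-28, 31-38, 41-48) downward so that the remaining labels become consecutive. The following sketch builds such a mapping from `(start, stop, shift)` blocks and checks that the target IDs form `1..N` without gaps or collisions. A check like this would also catch stale range comments. The helper names are hypothetical.

```python
from typing import Dict, Iterable, Tuple


def shifted_mapping(blocks: Iterable[Tuple[int, int, int]]) -> Dict[int, int]:
    """Build {old_id: old_id - shift} from (start, stop, shift) blocks; stop is exclusive like range()."""
    mapping: Dict[int, int] = {}
    for start, stop, shift in blocks:
        mapping.update({i: i - shift for i in range(start, stop)})
    return mapping


def assert_consecutive(mapping: Dict[int, int]) -> None:
    """Target IDs must be exactly 1..N (background 0 stays implicit)."""
    new_ids = sorted(mapping.values())
    if new_ids != list(range(1, len(new_ids) + 1)):
        raise ValueError(f"non-consecutive target IDs: {new_ids}")


# same blocks as mapping_DS121 above (teeth only)
DS121 = shifted_mapping([(11, 19, 10), (21, 29, 12), (31, 39, 14), (41, 49, 16)])
assert_consecutive(DS121)
```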
================================================
FILE: nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py
================================================
import multiprocessing
import shutil
from multiprocessing import Pool
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw
from skimage import io
from acvl_utils.morphology.morphology_helper import generic_filter_components
from scipy.ndimage import binary_fill_holes
def load_and_convert_case(input_image: str, input_seg: str, output_image: str, output_seg: str,
min_component_size: int = 50):
seg = io.imread(input_seg)
seg[seg == 255] = 1
image = io.imread(input_image)
image = image.sum(2)
mask = image == (3 * 255)
# the dataset has large white areas in which road segmentations can exist but no image information is available.
# Remove the road label in these areas
mask = generic_filter_components(mask, filter_fn=lambda ids, sizes: [i for j, i in enumerate(ids) if
sizes[j] > min_component_size])
mask = binary_fill_holes(mask)
seg[mask] = 0
io.imsave(output_seg, seg, check_contrast=False)
shutil.copy(input_image, output_image)
if __name__ == "__main__":
# extracted archive from https://www.kaggle.com/datasets/insaff/massachusetts-roads-dataset?resource=download
source = '/media/fabian/data/raw_datasets/Massachussetts_road_seg/road_segmentation_ideal'
dataset_name = 'Dataset120_RoadSegmentation'
imagestr = join(nnUNet_raw, dataset_name, 'imagesTr')
imagests = join(nnUNet_raw, dataset_name, 'imagesTs')
labelstr = join(nnUNet_raw, dataset_name, 'labelsTr')
labelsts = join(nnUNet_raw, dataset_name, 'labelsTs')
maybe_mkdir_p(imagestr)
maybe_mkdir_p(imagests)
maybe_mkdir_p(labelstr)
maybe_mkdir_p(labelsts)
train_source = join(source, 'training')
test_source = join(source, 'testing')
with multiprocessing.get_context("spawn").Pool(8) as p:
# not all training images have a segmentation
valid_ids = subfiles(join(train_source, 'output'), join=False, suffix='png')
num_train = len(valid_ids)
r = []
for v in valid_ids:
r.append(
p.starmap_async(
load_and_convert_case,
((
join(train_source, 'input', v),
join(train_source, 'output', v),
join(imagestr, v[:-4] + '_0000.png'),
join(labelstr, v),
50
),)
)
)
# test set
valid_ids = subfiles(join(test_source, 'output'), join=False, suffix='png')
for v in valid_ids:
r.append(
p.starmap_async(
load_and_convert_case,
((
join(test_source, 'input', v),
join(test_source, 'output', v),
join(imagests, v[:-4] + '_0000.png'),
join(labelsts, v),
50
),)
)
)
_ = [i.get() for i in r]
generate_dataset_json(join(nnUNet_raw, dataset_name), {0: 'R', 1: 'G', 2: 'B'}, {'background': 0, 'road': 1},
num_train, '.png', dataset_name=dataset_name)
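The masking step in `load_and_convert_case` works as follows. Pure-white pixels (all three channels 255) carry no image information. Small white specks are discarded, holes in the remaining white areas are filled, and the road label is removed in those areas. The sketch below reproduces that logic with plain `scipy.ndimage` as an illustration. `scipy.ndimage.label` stands in for acvl_utils' `generic_filter_components`, and its default 4-connectivity may differ from that helper. `mask_white_regions` is hypothetical.

```python
import numpy as np
from scipy.ndimage import binary_fill_holes, label


def mask_white_regions(image_rgb: np.ndarray, seg: np.ndarray, min_component_size: int = 50) -> np.ndarray:
    """Hypothetical helper: set seg to 0 inside large pure-white areas of image_rgb."""
    # pure white means R + G + B == 3 * 255; cast first so the sum cannot overflow
    white = image_rgb.astype(np.int64).sum(axis=2) == 3 * 255
    # keep only white connected components larger than min_component_size pixels
    labeled, num_components = label(white)
    sizes = np.bincount(labeled.ravel())
    keep = np.zeros(num_components + 1, dtype=bool)
    keep[1:] = sizes[1:] > min_component_size
    white = keep[labeled]
    # small non-white holes enclosed by a white area are treated as white, too
    white = binary_fill_holes(white)
    seg = seg.copy()
    seg[white] = 0
    return seg
```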
================================================
FILE: nnunetv2/dataset_conversion/Dataset137_BraTS21.py
================================================
import multiprocessing
import shutil
import SimpleITK as sitk
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw
def copy_BraTS_segmentation_and_convert_labels_to_nnUNet(in_file: str, out_file: str) -> None:
# use this for segmentation only!!!
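# nnU-Net wants consecutive labels. BraTS uses 0, 1, 2, 4; we remap 2 -> 1, 1 -> 2 and 4 -> 3 so that the nested regions (whole tumor, tumor core, enhancing tumor) become (1, 2, 3), (2, 3) and (3, )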
img = sitk.ReadImage(in_file)
img_npy = sitk.GetArrayFromImage(img)
uniques = np.unique(img_npy)
for u in uniques:
if u not in [0, 1, 2, 4]:
raise RuntimeError('unexpected label')
seg_new = np.zeros_like(img_npy)
seg_new[img_npy == 4] = 3
seg_new[img_npy == 2] = 1
seg_new[img_npy == 1] = 2
img_corr = sitk.GetImageFromArray(seg_new)
img_corr.CopyInformation(img)
sitk.WriteImage(img_corr, out_file)
def convert_labels_back_to_BraTS(seg: np.ndarray):
new_seg = np.zeros_like(seg)
new_seg[seg == 1] = 2
new_seg[seg == 3] = 4
new_seg[seg == 2] = 1
return new_seg
def load_convert_labels_back_to_BraTS(filename, input_folder, output_folder):
a = sitk.ReadImage(join(input_folder, filename))
b = sitk.GetArrayFromImage(a)
c = convert_labels_back_to_BraTS(b)
d = sitk.GetImageFromArray(c)
d.CopyInformation(a)
sitk.WriteImage(d, join(output_folder, filename))
def convert_folder_with_preds_back_to_BraTS_labeling_convention(input_folder: str, output_folder: str, num_processes: int = 12):
"""
reads all prediction files (nifti) in the input folder, converts the labels back to BraTS convention and saves the
results under the same file names in output_folder
"""
maybe_mkdir_p(output_folder)
nii = subfiles(input_folder, suffix='.nii.gz', join=False)
with multiprocessing.get_context("spawn").Pool(num_processes) as p:
p.starmap(load_convert_labels_back_to_BraTS, zip(nii, [input_folder] * len(nii), [output_folder] * len(nii)))
if __name__ == '__main__':
brats_data_dir = '/home/isensee/drives/E132-Rohdaten/BraTS_2021/training'
task_id = 137
task_name = "BraTS2021"
foldername = "Dataset%03.0d_%s" % (task_id, task_name)
# setting up nnU-Net folders
out_base = join(nnUNet_raw, foldername)
imagestr = join(out_base, "imagesTr")
labelstr = join(out_base, "labelsTr")
maybe_mkdir_p(imagestr)
maybe_mkdir_p(labelstr)
case_ids = subdirs(brats_data_dir, prefix='BraTS', join=False)
for c in case_ids:
shutil.copy(join(brats_data_dir, c, c + "_t1.nii.gz"), join(imagestr, c + '_0000.nii.gz'))
shutil.copy(join(brats_data_dir, c, c + "_t1ce.nii.gz"), join(imagestr, c + '_0001.nii.gz'))
shutil.copy(join(brats_data_dir, c, c + "_t2.nii.gz"), join(imagestr, c + '_0002.nii.gz'))
shutil.copy(join(brats_data_dir, c, c + "_flair.nii.gz"), join(imagestr, c + '_0003.nii.gz'))
copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, c, c + "_seg.nii.gz"),
join(labelstr, c + '.nii.gz'))
generate_dataset_json(out_base,
channel_names={0: 'T1', 1: 'T1ce', 2: 'T2', 3: 'Flair'},
labels={
'background': 0,
'whole tumor': (1, 2, 3),
'tumor core': (2, 3),
'enhancing tumor': (3, )
},
num_training_cases=len(case_ids),
file_ending='.nii.gz',
regions_class_order=(1, 2, 3),
license='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',
reference='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',
release='1.0')
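The BraTS 2021 labels are 1 (necrotic tumor core), 2 (peritumoral edema) and 4 (enhancing tumor). The conversion above renumbers them so that edema becomes 1, necrotic core 2 and enhancing tumor 3. With that ordering, the nested regions in dataset.json map onto (1, 2, 3), (2, 3) and (3,). The numpy-only sketch below checks that the forward and backward mappings round-trip and that the regions line up. The function names are hypothetical and mirror `copy_BraTS_segmentation_and_convert_labels_to_nnUNet` and `convert_labels_back_to_BraTS` without the file I/O.

```python
import numpy as np


def brats_to_nnunet(seg: np.ndarray) -> np.ndarray:
    """Hypothetical helper: BraTS labels (0, 1, 2, 4) -> nnU-Net labels (0, 1, 2, 3)."""
    if not np.isin(seg, (0, 1, 2, 4)).all():
        raise ValueError("unexpected label")
    out = np.zeros_like(seg)
    out[seg == 2] = 1  # edema
    out[seg == 1] = 2  # necrotic tumor core
    out[seg == 4] = 3  # enhancing tumor
    return out


def nnunet_to_brats(seg: np.ndarray) -> np.ndarray:
    """Hypothetical helper: inverse of brats_to_nnunet."""
    out = np.zeros_like(seg)
    out[seg == 1] = 2
    out[seg == 2] = 1
    out[seg == 3] = 4
    return out
```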
================================================
FILE: nnunetv2/dataset_conversion/Dataset218_Amos2022_task1.py
================================================
from batchgenerators.utilities.file_and_folder_operations import *
import shutil
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw
def convert_amos_task1(amos_base_dir: str, nnunet_dataset_id: int = 218):
"""
AMOS doesn't say anything about how the validation set is supposed to be used. So we just incorporate that into
the train set. Having a 5-fold cross-validation is superior to a single train:val split
"""
task_name = "AMOS2022_postChallenge_task1"
foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name)
# setting up nnU-Net folders
out_base = join(nnUNet_raw, foldername)
imagestr = join(out_base, "imagesTr")
imagests = join(out_base, "imagesTs")
labelstr = join(out_base, "labelsTr")
maybe_mkdir_p(imagestr)
maybe_mkdir_p(imagests)
maybe_mkdir_p(labelstr)
dataset_json_source = load_json(join(amos_base_dir, 'dataset.json'))
training_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['training']]
tr_ctr = 0
for tr in training_identifiers:
if int(tr.split("_")[-1]) <= 410: # these are the CT images
tr_ctr += 1
shutil.copy(join(amos_base_dir, 'imagesTr', tr + '.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz'))
shutil.copy(join(amos_base_dir, 'labelsTr', tr + '.nii.gz'), join(labelstr, f'{tr}.nii.gz'))
test_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['test']]
for ts in test_identifiers:
if int(ts.split("_")[-1]) <= 500: # these are the CT images
shutil.copy(join(amos_base_dir, 'imagesTs', ts + '.nii.gz'), join(imagests, f'{ts}_0000.nii.gz'))
val_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['validation']]
for vl in val_identifiers:
if int(vl.split("_")[-1]) <= 409: # these are the CT images
tr_ctr += 1
shutil.copy(join(amos_base_dir, 'imagesVa', vl + '.nii.gz'), join(imagestr, f'{vl}_0000.nii.gz'))
shutil.copy(join(amos_base_dir, 'labelsVa', vl + '.nii.gz'), join(labelstr, f'{vl}.nii.gz'))
generate_dataset_json(out_base, {0: "CT"}, labels={v: int(k) for k,v in dataset_json_source['labels'].items()},
num_training_cases=tr_ctr, file_ending='.nii.gz',
dataset_name=task_name, reference='https://amos22.grand-challenge.org/',
release='https://zenodo.org/record/7262581',
overwrite_image_reader_writer='NibabelIOWithReorient',
description="This is the dataset as released AFTER the challenge event. It has the "
"validation set gt in it! We just use the validation images as additional "
"training cases because AMOS doesn't specify how they should be used. nnU-Net's"
" 5-fold CV is better than some random train:val split.")
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('input_folder', type=str,
help="The downloaded and extracted AMOS2022 (https://amos22.grand-challenge.org/) data. "
"Use this link: https://zenodo.org/record/7262581."
" You need to specify the folder with the imagesTr, imagesVa, labelsTr etc. subfolders here!")
parser.add_argument('-d', required=False, type=int, default=218, help='nnU-Net Dataset ID, default: 218')
args = parser.parse_args()
amos_base = args.input_folder
convert_amos_task1(amos_base, args.d)
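Task 1 uses AMOS case numbers to tell CT from MRI. Identifiers are taken from the image paths in the source dataset.json. Cases whose numeric suffix is at or below a threshold are treated as CT: 410 for training, 409 for validation and 500 for test in the code above. The sketch below shows that parsing. The example paths are made up for illustration, and `select_ct_cases` is a hypothetical helper.

```python
def case_identifier(image_path: str) -> str:
    # './imagesTr/amos_0001.nii.gz' -> 'amos_0001' (drop the directory and the '.nii.gz' ending)
    return image_path.split("/")[-1][: -len(".nii.gz")]


def select_ct_cases(entries: list, max_ct_id: int) -> list:
    """Hypothetical helper: keep identifiers whose trailing case number is <= max_ct_id."""
    identifiers = [case_identifier(e["image"]) for e in entries]
    return [i for i in identifiers if int(i.split("_")[-1]) <= max_ct_id]
```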
================================================
FILE: nnunetv2/dataset_conversion/Dataset219_Amos2022_task2.py
================================================
from batchgenerators.utilities.file_and_folder_operations import *
import shutil
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw
def convert_amos_task2(amos_base_dir: str, nnunet_dataset_id: int = 219):
"""
AMOS doesn't say anything about how the validation set is supposed to be used. So we just incorporate that into
the train set. Having a 5-fold cross-validation is superior to a single train:val split
"""
task_name = "AMOS2022_postChallenge_task2"
foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name)
# setting up nnU-Net folders
out_base = join(nnUNet_raw, foldername)
imagestr = join(out_base, "imagesTr")
imagests = join(out_base, "imagesTs")
labelstr = join(out_base, "labelsTr")
maybe_mkdir_p(imagestr)
maybe_mkdir_p(imagests)
maybe_mkdir_p(labelstr)
dataset_json_source = load_json(join(amos_base_dir, 'dataset.json'))
training_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['training']]
for tr in training_identifiers:
shutil.copy(join(amos_base_dir, 'imagesTr', tr + '.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz'))
shutil.copy(join(amos_base_dir, 'labelsTr', tr + '.nii.gz'), join(labelstr, f'{tr}.nii.gz'))
test_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['test']]
for ts in test_identifiers:
shutil.copy(join(amos_base_dir, 'imagesTs', ts + '.nii.gz'), join(imagests, f'{ts}_0000.nii.gz'))
val_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['validation']]
for vl in val_identifiers:
shutil.copy(join(amos_base_dir, 'imagesVa', vl + '.nii.gz'), join(imagestr, f'{vl}_0000.nii.gz'))
shutil.copy(join(amos_base_dir, 'labelsVa', vl + '.nii.gz'), join(labelstr, f'{vl}.nii.gz'))
generate_dataset_json(out_base, {0: "either_CT_or_MR"}, labels={v: int(k) for k,v in dataset_json_source['labels'].items()},
num_training_cases=len(training_identifiers) + len(val_identifiers), file_ending='.nii.gz',
dataset_name=task_name, reference='https://amos22.grand-challenge.org/',
release='https://zenodo.org/record/7262581',
overwrite_image_reader_writer='NibabelIOWithReorient',
description="This is the dataset as released AFTER the challenge event. It has the "
"validation set gt in it! We just use the validation images as additional "
"training cases because AMOS doesn't specify how they should be used. nnU-Net's"
" 5-fold CV is better than some random train:val split.")
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('input_folder', type=str,
help="The downloaded and extracted AMOS2022 (https://amos22.grand-challenge.org/) data. "
"Use this link: https://zenodo.org/record/7262581."
" You need to specify the folder with the imagesTr, imagesVa, labelsTr etc. subfolders here!")
parser.add_argument('-d', required=False, type=int, default=219, help='nnU-Net Dataset ID, default: 219')
args = parser.parse_args()
amos_base = args.input_folder
convert_amos_task2(amos_base, args.d)
# /home/isensee/Downloads/amos22/amos22/
================================================
FILE: nnunetv2/dataset_conversion/Dataset220_KiTS2023.py
================================================
from batchgenerators.utilities.file_and_folder_operations import *
import shutil
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw
def convert_kits2023(kits_base_dir: str, nnunet_dataset_id: int = 220):
task_name = "KiTS2023"
foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name)
# setting up nnU-Net folders
out_base = join(nnUNet_raw, foldername)
imagestr = join(out_base, "imagesTr")
labelstr = join(out_base, "labelsTr")
maybe_mkdir_p(imagestr)
maybe_mkdir_p(labelstr)
cases = subdirs(kits_base_dir, prefix='case_', join=False)
for tr in cases:
shutil.copy(join(kits_base_dir, tr, 'imaging.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz'))
shutil.copy(join(kits_base_dir, tr, 'segmentation.nii.gz'), join(labelstr, f'{tr}.nii.gz'))
generate_dataset_json(out_base, {0: "CT"},
labels={
"background": 0,
"kidney": (1, 2, 3),
"masses": (2, 3),
"tumor": 2
},
regions_class_order=(1, 3, 2),
num_training_cases=len(cases), file_ending='.nii.gz',
dataset_name=task_name, reference='none',
release='0.1.3',
overwrite_image_reader_writer='NibabelIOWithReorient',
description="KiTS2023")
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('input_folder', type=str,
help="The downloaded and extracted KiTS2023 dataset (must have case_XXXXX subfolders)")
parser.add_argument('-d', required=False, type=int, default=220, help='nnU-Net Dataset ID, default: 220')
args = parser.parse_args()
kits_base = args.input_folder
convert_kits2023(kits_base, args.d)
# /media/isensee/raw_data/raw_datasets/kits23/dataset
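KiTS labels are 1 (kidney), 2 (tumor) and 3 (cyst). The regions above are kidney (1, 2, 3), masses (2, 3) and tumor (2). `regions_class_order=(1, 3, 2)` says which label to write when a region is painted back into a label map. The sketch below illustrates why that order restores the original labels while the naive (1, 2, 3) does not. It paints thresholded region masks one after another, so later, nested regions overwrite earlier ones. It is meant to mirror the idea behind nnU-Net's region-to-label conversion, not to reproduce its code, and all names are hypothetical.

```python
import numpy as np

# region definitions as in the dataset.json above: kidney, masses, tumor
REGIONS = [(1, 2, 3), (2, 3), (2,)]


def labels_to_regions(seg: np.ndarray) -> np.ndarray:
    # one boolean mask per region, in declaration order
    return np.stack([np.isin(seg, r) for r in REGIONS])


def regions_to_labels(region_masks: np.ndarray, order=(1, 3, 2)) -> np.ndarray:
    # paint regions in declaration order; smaller nested regions overwrite larger ones
    seg = np.zeros(region_masks.shape[1:], dtype=np.uint8)
    for mask, label_value in zip(region_masks, order):
        seg[mask] = label_value
    return seg
```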
================================================
FILE: nnunetv2/dataset_conversion/Dataset221_AutoPETII_2023.py
================================================
from batchgenerators.utilities.file_and_folder_operations import *
import shutil
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
def convert_autopet(autopet_base_dir:str = '/media/isensee/My Book1/AutoPET/nifti/FDG-PET-CT-Lesions',
nnunet_dataset_id: int = 221):
task_name = "AutoPETII_2023"
foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name)
# setting up nnU-Net folders
out_base = join(nnUNet_raw, foldername)
imagestr = join(out_base, "imagesTr")
labelstr = join(out_base, "labelsTr")
maybe_mkdir_p(imagestr)
maybe_mkdir_p(labelstr)
patients = subdirs(autopet_base_dir, prefix='PETCT', join=False)
n = 0
identifiers = []
for pat in patients:
patient_acquisitions = subdirs(join(autopet_base_dir, pat), join=False)
for pa in patient_acquisitions:
n += 1
identifier = f"{pat}_{pa}"
identifiers.append(identifier)
if not isfile(join(imagestr, f'{identifier}_0000.nii.gz')):
shutil.copy(join(autopet_base_dir, pat, pa, 'CTres.nii.gz'), join(imagestr, f'{identifier}_0000.nii.gz'))
if not isfile(join(imagestr, f'{identifier}_0001.nii.gz')):
shutil.copy(join(autopet_base_dir, pat, pa, 'SUV.nii.gz'), join(imagestr, f'{identifier}_0001.nii.gz'))
if not isfile(join(labelstr, f'{identifier}.nii.gz')):
shutil.copy(join(autopet_base_dir, pat, pa, 'SEG.nii.gz'), join(labelstr, f'{identifier}.nii.gz'))
generate_dataset_json(out_base, {0: "CT", 1:"CT"},
labels={
"background": 0,
"tumor": 1
},
num_training_cases=n, file_ending='.nii.gz',
dataset_name=task_name, reference='https://autopet-ii.grand-challenge.org/',
release='release',
# overwrite_image_reader_writer='NibabelIOWithReorient',
description=task_name)
# manual split
splits = []
for fold in range(5):
val_patients = patients[fold :: 5]
splits.append(
{
'train': [i for i in identifiers if not any([i.startswith(v) for v in val_patients])],
'val': [i for i in identifiers if any([i.startswith(v) for v in val_patients])],
}
)
pp_out_dir = join(nnUNet_preprocessed, foldername)
maybe_mkdir_p(pp_out_dir)
save_json(splits, join(pp_out_dir, 'splits_final.json'), sort_keys=False)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('input_folder', type=str,
help="The downloaded and extracted autopet dataset (must have PETCT_XXX subfolders)")
parser.add_argument('-d', required=False, type=int, default=221, help='nnU-Net Dataset ID, default: 221')
args = parser.parse_args()
autopet_base = args.input_folder
convert_autopet(autopet_base, args.d)
================================================
FILE: nnunetv2/dataset_conversion/Dataset223_AMOS2022postChallenge.py
================================================
import shutil
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.paths import nnUNet_raw
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
if __name__ == '__main__':
downloaded_amos_dir = '/home/isensee/amos22/amos22' # downloaded and extracted from https://zenodo.org/record/7155725#.Y0OOCOxBztM
target_dataset_id = 223
target_dataset_name = f'Dataset{target_dataset_id:03d}_AMOS2022postChallenge'
maybe_mkdir_p(join(nnUNet_raw, target_dataset_name))
imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr')
imagesTs = join(nnUNet_raw, target_dataset_name, 'imagesTs')
labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr')
maybe_mkdir_p(imagesTr)
maybe_mkdir_p(imagesTs)
maybe_mkdir_p(labelsTr)
train_identifiers = []
# copy images
source = join(downloaded_amos_dir, 'imagesTr')
source_files = nifti_files(source, join=False)
train_identifiers += source_files
for s in source_files:
shutil.copy(join(source, s), join(imagesTr, s[:-7] + '_0000.nii.gz'))
source = join(downloaded_amos_dir, 'imagesVa')
source_files = nifti_files(source, join=False)
train_identifiers += source_files
for s in source_files:
shutil.copy(join(source, s), join(imagesTr, s[:-7] + '_0000.nii.gz'))
source = join(downloaded_amos_dir, 'imagesTs')
source_files = nifti_files(source, join=False)
for s in source_files:
shutil.copy(join(source, s), join(imagesTs, s[:-7] + '_0000.nii.gz'))
# copy labels
source = join(downloaded_amos_dir, 'labelsTr')
source_files = nifti_files(source, join=False)
for s in source_files:
shutil.copy(join(source, s), join(labelsTr, s))
source = join(downloaded_amos_dir, 'labelsVa')
source_files = nifti_files(source, join=False)
for s in source_files:
shutil.copy(join(source, s), join(labelsTr, s))
old_dataset_json = load_json(join(downloaded_amos_dir, 'dataset.json'))
new_labels = {v: int(k) for k, v in old_dataset_json['labels'].items()}
generate_dataset_json(join(nnUNet_raw, target_dataset_name), {0: 'nonCT'}, new_labels,
num_training_cases=len(train_identifiers), file_ending='.nii.gz', regions_class_order=None,
dataset_name=target_dataset_name, reference='https://zenodo.org/record/7155725#.Y0OOCOxBztM',
license=old_dataset_json['licence'], # typo in OG dataset.json
description=old_dataset_json['description'],
release=old_dataset_json['release'])
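JSON object keys are always strings. Inverting the AMOS `labels` dict (`{"0": "background", ...}`) without a cast therefore yields string values such as `{"background": "0"}`. nnU-Net v2 expects integer label values, which is why the task 1 and task 2 scripts use `int(k)`. The sketch below shows the inversion with the cast. The label names are a toy example, and `invert_label_dict` is a hypothetical helper.

```python
import json


def invert_label_dict(labels: dict) -> dict:
    """Hypothetical helper: {"0": "background", ...} -> {"background": 0, ...}."""
    # JSON keys come back as str, so cast the label values to int
    inverted = {name: int(value) for value, name in labels.items()}
    if len(inverted) != len(labels):
        raise ValueError("label names must be unique")
    return inverted


source = json.loads('{"labels": {"0": "background", "1": "spleen", "2": "liver"}}')
print(invert_label_dict(source["labels"]))  # {'background': 0, 'spleen': 1, 'liver': 2}
```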
================================================
FILE: nnunetv2/dataset_conversion/Dataset224_AbdomenAtlas1.0.py
================================================
from batchgenerators.utilities.file_and_folder_operations import *
import shutil
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
if __name__ == '__main__':
"""
How to train our submission to the JHU benchmark
1. Execute this script here to convert the dataset into nnU-Net format. Adapt the paths to your system!
2. Run planning and preprocessing: `nnUNetv2_plan_and_preprocess -d 224 -npfp 64 -np 64 -c 3d_fullres -pl
nnUNetPlannerResEncL_torchres`. Adapt the number of processes to your system (-np; -npfp)! Note that each process
will again spawn 4 threads for resampling. This custom planner replaces the nnU-Net default resampling scheme with
a torch-based implementation which is faster but less accurate. This is needed to satisfy the inference speed
constraints.
3. Run training with `nnUNetv2_train 224 3d_fullres all -p nnUNetResEncUNetLPlans_torchres`. 24GB VRAM required,
training will take ~28-30h.
"""
base = '/home/isensee/Downloads/AbdomenAtlas1.0Mini'
cases = subdirs(base, join=False, prefix='BDMAP')
target_dataset_id = 224
target_dataset_name = f'Dataset{target_dataset_id:03d}_AbdomenAtlas1.0'
raw_dir = '/home/isensee/drives/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/2024_JHU_benchmark'
maybe_mkdir_p(join(raw_dir, target_dataset_name))
imagesTr = join(raw_dir, target_dataset_name, 'imagesTr')
labelsTr = join(raw_dir, target_dataset_name, 'labelsTr')
maybe_mkdir_p(imagesTr)
maybe_mkdir_p(labelsTr)
for case in cases:
shutil.copy(join(base, case, 'ct.nii.gz'), join(imagesTr, case + '_0000.nii.gz'))
shutil.copy(join(base, case, 'combined_labels.nii.gz'), join(labelsTr, case + '.nii.gz'))
labels = {
"background": 0,
"aorta": 1,
"gall_bladder": 2,
"kidney_left": 3,
"kidney_right": 4,
"liver": 5,
"pancreas": 6,
"postcava": 7,
"spleen": 8,
"stomach": 9
}
generate_dataset_json(
join(raw_dir, target_dataset_name),
{0: 'nonCT'}, # this was a mistake we did at the beginning and we keep it like that here for consistency
labels,
len(cases),
'.nii.gz',
None,
target_dataset_name,
overwrite_image_reader_writer='NibabelIOWithReorient'
)
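Formatting the dataset ID with `:3.0f` pads with spaces rather than zeros: `f'{4:3.0f}'` gives `'  4'`. That only produces a valid `DatasetXXX_Name` folder name when the ID already has three digits. `:03d` zero-pads reliably. `dataset_folder_name` below is a hypothetical helper that shows the difference.

```python
def dataset_folder_name(dataset_id: int, name: str) -> str:
    # ':03d' zero-pads to three digits: 4 -> '004', 224 -> '224'
    return f"Dataset{dataset_id:03d}_{name}"


# ':3.0f' pads with spaces, which breaks the naming pattern for IDs below 100
print(repr(f"{4:3.0f}"))  # '  4'
print(dataset_folder_name(4, "Hippocampus"))  # Dataset004_Hippocampus
```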
================================================
FILE: nnunetv2/dataset_conversion/Dataset226_BraTS2024-BraTS-GLI.py
================================================
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from batchgenerators.utilities.file_and_folder_operations import join, subdirs, subfiles, maybe_mkdir_p
from nnunetv2.paths import nnUNet_raw
if __name__ == '__main__':
"""
this dataset does not copy the data into nnunet format and just links to existing data. The dataset can only be
used from one machine because the paths in the dataset.json are hard coded
"""
extracted_BraTS2024_GLI_dir = '/home/isensee/BraTS2024_traindata/training_data1'
nnunet_dataset_name = 'BraTS2024-BraTS-GLI'
nnunet_dataset_id = 226
dataset_name = f'Dataset{nnunet_dataset_id:03d}_{nnunet_dataset_name}'
dataset_dir = join(nnUNet_raw, dataset_name)
maybe_mkdir_p(dataset_dir)
dataset = {}
casenames = subdirs(extracted_BraTS2024_GLI_dir, join=False)
for c in casenames:
dataset[c] = {
'label': join(extracted_BraTS2024_GLI_dir, c, c + '-seg.nii.gz'),
'images': [
join(extracted_BraTS2024_GLI_dir, c, c + '-t1n.nii.gz'),
join(extracted_BraTS2024_GLI_dir, c, c + '-t1c.nii.gz'),
join(extracted_BraTS2024_GLI_dir, c, c + '-t2w.nii.gz'),
join(extracted_BraTS2024_GLI_dir, c, c + '-t2f.nii.gz')
]
}
labels = {
'background': 0,
'NETC': 1,
'SNFH': 2,
'ET': 3,
'RC': 4,
}
generate_dataset_json(
dataset_dir,
{
0: 'T1',
1: "T1C",
2: "T2W",
3: "T2F"
},
labels,
num_training_cases=len(dataset),
file_ending='.nii.gz',
regions_class_order=None,
dataset_name=dataset_name,
reference='https://www.synapse.org/Synapse:syn53708249/wiki/627500',
license='see https://www.synapse.org/Synapse:syn53708249/wiki/627508',
dataset=dataset,
description='This dataset does not copy the data into nnunet format and just links to existing data. '
'The dataset can only be used from one machine because the paths in the dataset.json are hard coded'
)
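This dataset.json points only at absolute, machine-specific paths. A quick check that every referenced file exists on the current machine can save a failed planning run. The sketch below assumes the `dataset` layout built above (one `label` path and a list of `images` per case). `missing_files` is a hypothetical helper, not part of nnU-Net.

```python
import os


def missing_files(dataset: dict) -> list:
    """Hypothetical check: return every label/image path in `dataset` that is not an existing file."""
    missing = []
    for entry in dataset.values():
        for path in [entry["label"], *entry["images"]]:
            if not os.path.isfile(path):
                missing.append(path)
    return missing
```

The image list order must also match `channel_names` (here T1, T1C, T2W, T2F); an existence check cannot catch a wrong order.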
================================================
FILE: nnunetv2/dataset_conversion/Dataset227_TotalSegmentatorMRI.py
================================================
import SimpleITK
import nibabel
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
import shutil
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw
if __name__ == '__main__':
base = '/home/isensee/Downloads/TotalsegmentatorMRI_dataset_v100'
cases = subdirs(base, join=False)
target_dataset_id = 227
target_dataset_name = f'Dataset{target_dataset_id:03d}_TotalSegmentatorMRI'
maybe_mkdir_p(join(nnUNet_raw, target_dataset_name))
imagesTr = join(nnUNet_raw, target_dataset_name, 'imagesTr')
labelsTr = join(nnUNet_raw, target_dataset_name, 'labelsTr')
maybe_mkdir_p(imagesTr)
maybe_mkdir_p(labelsTr)
# discover labels
label_fnames = nifti_files(join(base, cases[0], 'segmentations'), join=False)
label_dict = {i[:-7]: j + 1 for j, i in enumerate(label_fnames)}
labelnames = list(label_dict.keys())
label_dict['background'] = 0
for case in cases:
img = nibabel.load(join(base, case, 'mri.nii.gz'))
nibabel.save(img, join(imagesTr, case + '_0000.nii.gz'))
seg_nib = nibabel.load(join(base, case, 'segmentations', labelnames[0] + '.nii.gz'))
init_seg_npy = np.asanyarray(seg_nib.dataobj)
init_seg_npy[init_seg_npy > 0] = label_dict[labelnames[0]]
for labelname in labelnames[1:]:
seg = nibabel.load(join(base, case, 'segmentations', labelname + '.nii.gz'))
seg = np.asanyarray(seg.dataobj)
init_seg_npy[seg > 0] = label_dict[labelname]
out = nibabel.Nifti1Image(init_seg_npy, affine=seg_nib.affine, header=seg_nib.header)
nibabel.save(out, join(labelsTr, case + '.nii.gz'))
generate_dataset_json(
join(nnUNet_raw, target_dataset_name),
{0: 'MRI'},
label_dict,
len(cases),
'.nii.gz',
None,
target_dataset_name,
overwrite_image_reader_writer='NibabelIOWithReorient',
release='1.0.0',
reference='https://zenodo.org/records/11367005',
license='see reference'
)
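The conversion above merges one binary mask file per structure into a single label map. Label IDs come from enumerating the mask file names of the first case, starting at 1, with background 0. This assumes every case ships the same set of mask files. Where masks overlap, the structure written last wins. The sketch below shows the same merge on in-memory arrays. `merge_binary_masks` is hypothetical, and `uint8` assumes at most 255 labels.

```python
import numpy as np


def merge_binary_masks(masks: dict, label_ids: dict) -> np.ndarray:
    """Hypothetical helper: combine {name: binary mask} into one label map.

    Masks are written in dict order, so a later structure overwrites an earlier one where they overlap.
    """
    names = list(masks)
    out = np.zeros(masks[names[0]].shape, dtype=np.uint8)
    for name in names:
        out[masks[name] > 0] = label_ids[name]
    return out
```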
================================================
FILE: nnunetv2/dataset_conversion/Dataset987_dummyDataset4.py
================================================
import os
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.paths import nnUNet_raw
from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
if __name__ == '__main__':
# creates a dummy dataset without any files of its own in imagesTr and labelsTr: its dataset.json points to the files of the source dataset instead
source_dataset = 'Dataset004_Hippocampus'
target_dataset = 'Dataset987_dummyDataset4'
target_dataset_dir = join(nnUNet_raw, target_dataset)
maybe_mkdir_p(target_dataset_dir)
dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, source_dataset))
# the returned dataset will have absolute paths. We should use relative paths so that you can freely copy
# datasets around between systems. As long as the source dataset is there it will continue working even if
# nnUNet_raw is in different locations
# paths must be relative to target_dataset_dir!!!
for k in dataset.keys():
dataset[k]['label'] = os.path.relpath(dataset[k]['label'], target_dataset_dir)
dataset[k]['images'] = [os.path.relpath(i, target_dataset_dir) for i in dataset[k]['images']]
# load old dataset.json
dataset_json = load_json(join(nnUNet_raw, source_dataset, 'dataset.json'))
dataset_json['dataset'] = dataset
# save
save_json(dataset_json, join(target_dataset_dir, 'dataset.json'), sort_keys=False)
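According to the comments above, the paths stored in dataset.json must be relative to the dataset folder (`target_dataset_dir`). The sketch below shows how `os.path.relpath` produces such a path and how joining it back onto the dataset folder recovers the file. It also shows that the relative path stays valid when the whole nnUNet_raw tree moves. The helper names and paths are hypothetical, and the hard-coded separators assume a POSIX system.

```python
import os


def to_relative(path: str, dataset_dir: str) -> str:
    # express `path` relative to the folder that contains dataset.json
    return os.path.relpath(path, dataset_dir)


def resolve(rel_path: str, dataset_dir: str) -> str:
    # turn a stored relative path back into an absolute one
    return os.path.normpath(os.path.join(dataset_dir, rel_path))
```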
================================================
FILE: nnunetv2/dataset_conversion/Dataset989_dummyDataset4_2.py
================================================
import os
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.paths import nnUNet_raw
from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
if __name__ == '__main__':
# creates a dummy dataset without any files of its own in imagesTr and labelsTr: its dataset.json points to the files of the source dataset instead
source_dataset = 'Dataset004_Hippocampus'
target_dataset = 'Dataset989_dummyDataset4_2'
target_dataset_dir = join(nnUNet_raw, target_dataset)
maybe_mkdir_p(target_dataset_dir)
dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, source_dataset))
# the returned dataset will have absolute paths. Here we instead store paths relative to nnUNet_raw, prefixed with
# '$nnUNet_raw' (to be resolved from the nnUNet_raw environment variable). That way datasets can be copied between
# systems and keep working even if nnUNet_raw is in a different location, as long as the source dataset is there
for k in dataset.keys():
dataset[k]['label'] = join('$nnUNet_raw', os.path.relpath(dataset[k]['label'], nnUNet_raw))
dataset[k]['images'] = [join('$nnUNet_raw', os.path.relpath(i, nnUNet_raw)) for i in dataset[k]['images']]
# load old dataset.json
dataset_json = load_json(join(nnUNet_raw, source_dataset, 'dataset.json'))
dataset_json['dataset'] = dataset
# save
save_json(dataset_json, join(target_dataset_dir, 'dataset.json'), sort_keys=False)
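This variant stores paths as `$nnUNet_raw/<path relative to nnUNet_raw>`. The sketch below assumes, as that prefix suggests, that such entries are resolved by expanding environment variables (for example with `os.path.expandvars`) when dataset.json is read. The entry then follows wherever the nnUNet_raw environment variable points on the current machine. The helper names and paths are hypothetical.

```python
import os


def to_env_relative(path: str, raw_dir: str) -> str:
    # store '$nnUNet_raw/<path relative to raw_dir>' so the entry survives moving nnUNet_raw
    return os.path.join("$nnUNet_raw", os.path.relpath(path, raw_dir))


def expand(entry: str) -> str:
    # replace $nnUNet_raw with the current value of the environment variable of that name
    return os.path.expandvars(entry)
```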
================================================
FILE: nnunetv2/dataset_conversion/__init__.py
================================================
================================================
FILE: nnunetv2/dataset_conversion/convert_MSD_dataset.py
================================================
import argparse
import multiprocessing
import shutil
from typing import Optional
import SimpleITK as sitk
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.paths import nnUNet_raw
from nnunetv2.utilities.dataset_name_id_conversion import find_candidate_datasets
from nnunetv2.configuration import default_num_processes
import numpy as np
def split_4d_nifti(filename, output_folder):
img_itk = sitk.ReadImage(filename)
dim = img_itk.GetDimension()
file_base = os.path.basename(filename)
if dim == 3:
shutil.copy(filename, join(output_folder, file_base[:-7] + "_0000.nii.gz"))
return
elif dim != 4:
raise RuntimeError("Unexpected dimensionality: %d of file %s, cannot split" % (dim, filename))
else:
img_npy = sitk.GetArrayFromImage(img_itk)
spacing = img_itk.GetSpacing()
origin = img_itk.GetOrigin()
direction = np.array(img_itk.GetDirection()).reshape(4,4)
# now modify these to remove the fourth dimension
spacing = tuple(spacing[:-1])
origin = tuple(origin[:-1])
direction = tuple(direction[:-1, :-1].reshape(-1))
for i in range(img_npy.shape[0]):
img = img_npy[i]
img_itk_new = sitk.GetImageFromArray(img)
img_itk_new.SetSpacing(spacing)
img_itk_new.SetOrigin(origin)
img_itk_new.SetDirection(direction)
sitk.WriteImage(img_itk_new, join(output_folder, file_base[:-7] + "_%04.0d.nii.gz" % i))
def convert_msd_dataset(source_folder: str, overwrite_target_id: Optional[int] = None,
num_processes: int = default_num_processes) -> None:
if source_folder.endswith('/') or source_folder.endswith('\\'):
source_folder = source_folder[:-1]
labelsTr = join(source_folder, 'labelsTr')
imagesTs = join(source_folder, 'imagesTs')
imagesTr = join(source_folder, 'imagesTr')
assert isdir(labelsTr), f"labelsTr subfolder missing in source folder"
assert isdir(imagesTs), f"imagesTs subfolder missing in source folder"
assert isdir(imagesTr), f"imagesTr subfolder missing in source folder"
dataset_json = join(source_folder, 'dataset.json')
assert isfile(dataset_json), f"dataset.json missing in source_folder"
# infer source dataset id and name
task, dataset_name = os.path.basename(source_folder).split('_')
task_id = int(task[4:])
# check if target dataset id is taken
target_id = task_id if overwrite_target_id is None else overwrite_target_id
existing_datasets = find_candidate_datasets(target_id)
assert len(existing_datasets) == 0, f"Target dataset id {target_id} is already taken, please consider changing " \
f"it using overwrite_target_id. Conflicting dataset: {existing_datasets} (check nnUNet_results, nnUNet_preprocessed and nnUNet_raw!)"
target_dataset_name = f"Dataset{target_id:03d}_{dataset_name}"
target_folder = join(nnUNet_raw, target_dataset_name)
target_imagesTr = join(target_folder, 'imagesTr')
target_imagesTs = join(target_folder, 'imagesTs')
target_labelsTr = join(target_folder, 'labelsTr')
maybe_mkdir_p(target_imagesTr)
maybe_mkdir_p(target_imagesTs)
maybe_mkdir_p(target_labelsTr)
with multiprocessing.get_context("spawn").Pool(num_processes) as p:
results = []
# convert 4d train images
source_images = [i for i in subfiles(imagesTr, suffix='.nii.gz', join=False) if
not i.startswith('.') and not i.startswith('_')]
source_images = [join(imagesTr, i) for i in source_images]
results.append(
p.starmap_async(
split_4d_nifti, zip(source_images, [target_imagesTr] * len(source_images))
)
)
# convert 4d test images
source_images = [i for i in subfiles(imagesTs, suffix='.nii.gz', join=False) if
not i.startswith('.') and not i.startswith('_')]
source_images = [join(imagesTs, i) for i in source_images]
results.append(
p.starmap_async(
split_4d_nifti, zip(source_images, [target_imagesTs] * len(source_images))
)
)
# copy segmentations
source_images = [i for i in subfiles(labelsTr, suffix='.nii.gz', join=False) if
not i.startswith('.') and not i.startswith('_')]
for s in source_images:
shutil.copy(join(labelsTr, s), join(target_labelsTr, s))
[i.get() for i in results]
dataset_json = load_json(dataset_json)
dataset_json['labels'] = {j: int(i) for i, j in dataset_json['labels'].items()}
dataset_json['file_ending'] = ".nii.gz"
dataset_json["channel_names"] = dataset_json["modality"]
del dataset_json["modality"]
del dataset_json["training"]
del dataset_json["test"]
save_json(dataset_json, join(nnUNet_raw, target_dataset_name, 'dataset.json'), sort_keys=False)
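# Illustrative sketch (hypothetical helper, not part of nnU-Net): the dataset.json rewrite done at the end of
# convert_msd_dataset, applied to a plain dict. MSD stores labels as {"0": "background", ...} and channels under
# "modality"; the converted file maps label names to int values, stores channels under "channel_names" and drops
# the "training"/"test" file lists, as convert_msd_dataset does. Unlike the original, the input is not mutated.
from copy import deepcopy


def _sketch_msd_json_to_v2(msd_json: dict) -> dict:
    out = deepcopy(msd_json)
    # invert the label mapping: index (string) -> name becomes name -> index (int)
    out['labels'] = {name: int(idx) for idx, name in msd_json['labels'].items()}
    out['file_ending'] = '.nii.gz'
    out['channel_names'] = out.pop('modality')
    out.pop('training', None)
    out.pop('test', None)
    return out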
def entry_point():
parser = argparse.ArgumentParser()
parser.add_argument('-i', type=str, required=True,
help='Downloaded and extracted MSD dataset folder. CANNOT be nnUNetv1 dataset! Example: '
'/home/fabian/Downloads/Task05_Prostate')
parser.add_argument('-overwrite_id', type=int, required=False, default=None,
help='Overwrite the dataset id. If not set we use the id of the MSD task (inferred from '
'folder name). Only use this if you already have an equivalently numbered dataset!')
parser.add_argument('-np', type=int, required=False, default=default_num_processes,
help=f'Number of processes used. Default: {default_num_processes}')
args = parser.parse_args()
convert_msd_dataset(args.i, args.overwrite_id, args.np)
if __name__ == '__main__':
entry_point()
================================================
FILE: nnunetv2/dataset_conversion/convert_raw_dataset_from_old_nnunet_format.py
================================================
import shutil
from copy import deepcopy
from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, isdir, load_json, save_json
from nnunetv2.paths import nnUNet_raw
def convert(source_folder, target_dataset_name):
"""
remember that old tasks were called TaskXXX_YYY and new ones are called DatasetXXX_YYY
source_folder
"""
if isdir(join(nnUNet_raw, target_dataset_name)):
raise RuntimeError(f'Target dataset name {target_dataset_name} already exists. Aborting... '
f'(we might break something). If you are sure you want to proceed, please manually '
f'delete {join(nnUNet_raw, target_dataset_name)}')
maybe_mkdir_p(join(nnUNet_raw, target_dataset_name))
shutil.copytree(join(source_folder, 'imagesTr'), join(nnUNet_raw, target_dataset_name, 'imagesTr'))
shutil.copytree(join(source_folder, 'labelsTr'), join(nnUNet_raw, target_dataset_name, 'labelsTr'))
if isdir(join(source_folder, 'imagesTs')):
shutil.copytree(join(source_folder, 'imagesTs'), join(nnUNet_raw, target_dataset_name, 'imagesTs'))
if isdir(join(source_folder, 'labelsTs')):
shutil.copytree(join(source_folder, 'labelsTs'), join(nnUNet_raw, target_dataset_name, 'labelsTs'))
if isdir(join(source_folder, 'imagesVal')):
shutil.copytree(join(source_folder, 'imagesVal'), join(nnUNet_raw, target_dataset_name, 'imagesVal'))
if isdir(join(source_folder, 'labelsVal')):
shutil.copytree(join(source_folder, 'labelsVal'), join(nnUNet_raw, target_dataset_name, 'labelsVal'))
shutil.copy(join(source_folder, 'dataset.json'), join(nnUNet_raw, target_dataset_name))
dataset_json = load_json(join(nnUNet_raw, target_dataset_name, 'dataset.json'))
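# v1-only keys; pop with a default so that a missing key does not raise a KeyError
dataset_json.pop('tensorImageSize', None)
dataset_json.pop('numTest', None)
dataset_json.pop('training', None)
dataset_json.pop('test', None)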
dataset_json['channel_names'] = deepcopy(dataset_json['modality'])
del dataset_json['modality']
dataset_json['labels'] = {j: int(i) for i, j in dataset_json['labels'].items()}
dataset_json['file_ending'] = ".nii.gz"
save_json(dataset_json, join(nnUNet_raw, target_dataset_name, 'dataset.json'), sort_keys=False)
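# Illustrative sketch (hypothetical helper, not part of nnU-Net): the TaskXXX_YYY -> DatasetXXX_YYY naming
# convention mentioned in the docstring of convert(). convert() itself takes the new name from the user; this
# only shows how a v1 name maps onto the v2 pattern. Splitting on the first underscore only keeps dataset names
# that themselves contain underscores intact.
def _sketch_task_to_dataset_name(task_name: str) -> str:
    prefix, name = task_name.split('_', 1)
    if not prefix.startswith('Task'):
        raise ValueError(f'Expected a TaskXXX_YYY name, got {task_name}')
    task_id = int(prefix[len('Task'):])
    return f'Dataset{task_id:03d}_{name}'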
def convert_entry_point():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("input_folder", type=str,
help='Raw old nnUNet dataset. This must be the folder with imagesTr,labelsTr etc subfolders! '
'Please provide the PATH to the old Task, not just the task name. nnU-Net V2 does not '
'know where v1 tasks are.')
parser.add_argument("output_dataset_name", type=str,
help='New dataset NAME (not path!). Must follow the DatasetXXX_NAME convention!')
args = parser.parse_args()
convert(args.input_folder, args.output_dataset_name)
================================================
FILE: nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py
================================================
import SimpleITK as sitk
import shutil
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json, nifti_files
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.paths import nnUNet_raw
from nnunetv2.utilities.label_handling.label_handling import LabelManager
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
def sparsify_segmentation(seg: np.ndarray, label_manager: LabelManager, percent_of_slices: float) -> np.ndarray:
assert label_manager.has_ignore_label, "This function only works with datasets that have an ignore label!"
seg_new = np.ones_like(seg) * label_manager.ignore_label
x, y, z = seg.shape
# x
num_slices = max(1, round(x * percent_of_slices))
selected_slices = np.random.choice(x, num_slices, replace=False)
seg_new[selected_slices] = seg[selected_slices]
# y
num_slices = max(1, round(y * percent_of_slices))
selected_slices = np.random.choice(y, num_slices, replace=False)
seg_new[:, selected_slices] = seg[:, selected_slices]
# z
num_slices = max(1, round(z * percent_of_slices))
selected_slices = np.random.choice(z, num_slices, replace=False)
seg_new[:, :, selected_slices] = seg[:, :, selected_slices]
return seg_new
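# Illustrative sketch (hypothetical, not part of nnU-Net): the same sparsification idea as sparsify_segmentation,
# but taking the ignore label as a plain int and using a seeded np.random.Generator instead of the global NumPy
# RNG, so it can be checked without a LabelManager. Every voxel starts as ignore; a random subset of slices along
# each axis (at least one per axis) keeps the reference labels.
import numpy as np


def _sketch_sparsify(seg: np.ndarray, ignore_label: int, fraction: float, seed: int = 0) -> np.ndarray:
    rng = np.random.default_rng(seed)
    out = np.full_like(seg, ignore_label)
    for axis, size in enumerate(seg.shape):
        n = max(1, round(size * fraction))
        chosen = rng.choice(size, n, replace=False)
        # index that selects the chosen slices along this axis and everything along the others
        index = [slice(None)] * seg.ndim
        index[axis] = chosen
        out[tuple(index)] = seg[tuple(index)]
    return out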
if __name__ == '__main__':
dataset_name = 'IntegrationTest_Hippocampus_regions_ignore'
dataset_id = 996
dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}"
try:
existing_dataset_name = maybe_convert_to_dataset_name(dataset_id)
if existing_dataset_name != dataset_name:
raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If "
f"you intent to delete it, remember to also remove it in nnUNet_preprocessed and "
f"nnUNet_results!")
except RuntimeError:
pass
if isdir(join(nnUNet_raw, dataset_name)):
shutil.rmtree(join(nnUNet_raw, dataset_name))
source_dataset = maybe_convert_to_dataset_name(4)
shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name))
# additionally optimize entire hippocampus region, remove Posterior
dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json'))
dj['labels'] = {
'background': 0,
'hippocampus': (1, 2),
'anterior': 1,
'ignore': 3
}
dj['regions_class_order'] = (2, 1)
save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False)
# now add ignore label to segmentation images
np.random.seed(1234)
lm = LabelManager(label_dict=dj['labels'], regions_class_order=dj.get('regions_class_order'))
segs = nifti_files(join(nnUNet_raw, dataset_name, 'labelsTr'))
for s in segs:
seg_itk = sitk.ReadImage(s)
seg_npy = sitk.GetArrayFromImage(seg_itk)
seg_npy = sparsify_segmentation(seg_npy, lm, 0.1 / 3)
seg_itk_new = sitk.GetImageFromArray(seg_npy)
seg_itk_new.CopyInformation(seg_itk)
sitk.WriteImage(seg_itk_new, s)
================================================
FILE: nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py
================================================
import shutil
from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.paths import nnUNet_raw
if __name__ == '__main__':
dataset_name = 'IntegrationTest_Hippocampus_regions'
dataset_id = 997
dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}"
try:
existing_dataset_name = maybe_convert_to_dataset_name(dataset_id)
if existing_dataset_name != dataset_name:
raise FileExistsError(
f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If "
f"you intent to delete it, remember to also remove it in nnUNet_preprocessed and "
f"nnUNet_results!")
except RuntimeError:
pass
if isdir(join(nnUNet_raw, dataset_name)):
shutil.rmtree(join(nnUNet_raw, dataset_name))
source_dataset = maybe_convert_to_dataset_name(4)
shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name))
# additionally optimize entire hippocampus region, remove Posterior
dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json'))
dj['labels'] = {
'background': 0,
'hippocampus': (1, 2),
'anterior': 1
}
dj['regions_class_order'] = (2, 1)
save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False)
================================================
FILE: nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py
================================================
import shutil
from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.paths import nnUNet_raw
if __name__ == '__main__':
dataset_name = 'IntegrationTest_Hippocampus_ignore'
dataset_id = 998
dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}"
try:
existing_dataset_name = maybe_convert_to_dataset_name(dataset_id)
if existing_dataset_name != dataset_name:
raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If "
f"you intent to delete it, remember to also remove it in nnUNet_preprocessed and "
f"nnUNet_results!")
except RuntimeError:
pass
if isdir(join(nnUNet_raw, dataset_name)):
shutil.rmtree(join(nnUNet_raw, dataset_name))
source_dataset = maybe_convert_to_dataset_name(4)
shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name))
# set class 2 to ignore label
dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json'))
dj['labels']['ignore'] = 2
del dj['labels']['Posterior']
save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False)
================================================
FILE: nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py
================================================
import shutil
from batchgenerators.utilities.file_and_folder_operations import isdir, join
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.paths import nnUNet_raw
if __name__ == '__main__':
dataset_name = 'IntegrationTest_Hippocampus'
dataset_id = 999
dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}"
try:
existing_dataset_name = maybe_convert_to_dataset_name(dataset_id)
if existing_dataset_name != dataset_name:
raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If "
f"you intent to delete it, remember to also remove it in nnUNet_preprocessed and "
f"nnUNet_results!")
except RuntimeError:
pass
if isdir(join(nnUNet_raw, dataset_name)):
shutil.rmtree(join(nnUNet_raw, dataset_name))
source_dataset = maybe_convert_to_dataset_name(4)
shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name))
================================================
FILE: nnunetv2/dataset_conversion/datasets_for_integration_tests/__init__.py
================================================
================================================
FILE: nnunetv2/dataset_conversion/generate_dataset_json.py
================================================
from typing import Tuple, Union, List
from batchgenerators.utilities.file_and_folder_operations import save_json, join
def generate_dataset_json(output_folder: str,
channel_names: dict,
labels: dict,
num_training_cases: int,
file_ending: str,
citation: Union[List[str], str] = None,
regions_class_order: Tuple[int, ...] = None,
dataset_name: str = None,
reference: str = None,
release: str = None,
description: str = None,
overwrite_image_reader_writer: str = None,
license: str = 'Whoever converted this dataset was lazy and didn\'t look it up!',
converted_by: str = "Please enter your name, especially when sharing datasets with others in a common infrastructure!",
**kwargs):
"""
Generates a dataset.json file in the output folder
channel_names:
Channel names must map the index to the name of the channel, example:
{
0: 'T1',
1: 'CT'
}
Note that the channel names may influence the normalization scheme!! Learn more in the documentation.
labels:
This will tell nnU-Net what labels to expect. Important: This will also determine whether you use region-based training or not.
Example regular labels:
{
'background': 0,
'left atrium': 1,
'some other label': 2
}
Example region-based training:
{
'background': 0,
'whole tumor': (1, 2, 3),
'tumor core': (2, 3),
'enhancing tumor': 3
}
Remember that nnU-Net expects consecutive values for labels! nnU-Net also expects 0 to be background!
num_training_cases: is used to double check all cases are there!
file_ending: needed for finding the files correctly. IMPORTANT! File endings must match between images and
segmentations!
dataset_name, reference, release, license, description: self-explanatory and not used by nnU-Net. Just for
completeness and as a reminder that these would be great!
overwrite_image_reader_writer: If you need a special IO class for your dataset you can derive it from
BaseReaderWriter, place it into nnunetv2.imageio and reference it here by name
kwargs: whatever you put here will be placed in the dataset.json as well
"""
has_regions: bool = any([isinstance(i, (tuple, list)) and len(i) > 1 for i in labels.values()])
if has_regions:
assert regions_class_order is not None, f"You have defined regions but regions_class_order is not set. " \
f"You need that."
# channel names need strings as keys
keys = list(channel_names.keys())
for k in keys:
if not isinstance(k, str):
channel_names[str(k)] = channel_names[k]
del channel_names[k]
# labels need ints as values
for l in labels.keys():
value = labels[l]
if isinstance(value, (tuple, list)):
value = tuple([int(i) for i in value])
labels[l] = value
else:
labels[l] = int(labels[l])
dataset_json = {
'channel_names': channel_names, # previously this was called 'modality'. I didn't like this so this is
# channel_names now. Live with it.
'labels': labels,
'numTraining': num_training_cases,
'file_ending': file_ending,
'licence': license,
'converted_by': converted_by
}
if dataset_name is not None:
dataset_json['name'] = dataset_name
if reference is not None:
dataset_json['reference'] = reference
if release is not None:
dataset_json['release'] = release
if citation is not None:
dataset_json['citation'] = citation
if description is not None:
dataset_json['description'] = description
if overwrite_image_reader_writer is not None:
dataset_json['overwrite_image_reader_writer'] = overwrite_image_reader_writer
if regions_class_order is not None:
dataset_json['regions_class_order'] = regions_class_order
dataset_json.update(kwargs)
save_json(dataset_json, join(output_folder, 'dataset.json'), sort_keys=False)
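# Illustrative sketch (hypothetical helper, not part of nnU-Net): the input normalization generate_dataset_json
# performs, written without mutating the caller's dicts. JSON object keys must be strings, so channel indices are
# stringified; label values become ints, or tuples of ints for regions; and any label mapped to more than one value
# signals region-based training, which additionally requires regions_class_order.
from typing import Dict, Tuple, Union


def _sketch_normalize(channel_names: dict, labels: dict):
    channels = {str(k): v for k, v in channel_names.items()}
    norm_labels: Dict[str, Union[int, Tuple[int, ...]]] = {}
    for name, value in labels.items():
        if isinstance(value, (tuple, list)):
            norm_labels[name] = tuple(int(i) for i in value)
        else:
            norm_labels[name] = int(value)
    has_regions = any(isinstance(v, tuple) and len(v) > 1 for v in norm_labels.values())
    return channels, norm_labels, has_regions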
================================================
FILE: nnunetv2/ensembling/__init__.py
================================================
================================================
FILE: nnunetv2/ensembling/ensemble.py
================================================
import argparse
import multiprocessing
import shutil
from copy import deepcopy
from typing import List, Union, Tuple
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import load_json, join, subfiles, \
maybe_mkdir_p, isdir, save_pickle, load_pickle, isfile
from nnunetv2.configuration import default_num_processes
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from nnunetv2.utilities.label_handling.label_handling import LabelManager
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
def average_probabilities(list_of_files: List[str]) -> np.ndarray:
assert len(list_of_files), 'At least one file must be given in list_of_files'
avg = None
for f in list_of_files:
if avg is None:
avg = np.load(f)['probabilities']
# maybe increase precision to prevent rounding errors
if avg.dtype != np.float32:
avg = avg.astype(np.float32)
else:
avg += np.load(f)['probabilities']
avg /= len(list_of_files)
return avg
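# Illustrative sketch (hypothetical helper, not part of nnU-Net): what average_probabilities computes, on
# in-memory arrays instead of .npz files. The running sum is accumulated in float32 regardless of the input dtype
# (matching the cast in average_probabilities), which limits rounding error for float16 inputs.
import numpy as np


def _sketch_average(arrays):
    assert len(arrays) > 0, 'need at least one array'
    total = np.zeros_like(arrays[0], dtype=np.float32)
    for a in arrays:
        total += a  # each input is cast into the float32 buffer as it is added
    return total / len(arrays)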
def merge_files(list_of_files,
output_filename_truncated: str,
output_file_ending: str,
image_reader_writer: BaseReaderWriter,
label_manager: LabelManager,
save_probabilities: bool = False):
# load the pkl file associated with the first file in list_of_files
properties = load_pickle(list_of_files[0][:-4] + '.pkl')
# load and average predictions
probabilities = average_probabilities(list_of_files)
segmentation = label_manager.convert_logits_to_segmentation(probabilities)
image_reader_writer.write_seg(segmentation, output_filename_truncated + output_file_ending, properties)
if save_probabilities:
np.savez_compressed(output_filename_truncated + '.npz', probabilities=probabilities)
save_pickle(properties, output_filename_truncated + '.pkl')
def ensemble_folders(list_of_input_folders: List[str],
output_folder: str,
save_merged_probabilities: bool = False,
num_processes: int = default_num_processes,
dataset_json_file_or_dict: Union[str, dict] = None,
plans_json_file_or_dict: Union[str, dict] = None):
"""we need too much shit for this function. Problem is that we now have to support region-based training plus
multiple input/output formats so there isn't really a way around this.
If plans and dataset json are not specified, we assume each of the folders has a corresponding plans.json
and/or dataset.json in it. These are usually copied into those folders by nnU-Net during prediction.
We just pick the dataset.json and plans.json from the first of the folders and we DO NOT check whether the
folders contain the same plans etc.! This can be a feature if results from different datasets are to be merged (only
works if the label dict in dataset.json is the same between these datasets!)"""
if dataset_json_file_or_dict is not None:
if isinstance(dataset_json_file_or_dict, str):
dataset_json = load_json(dataset_json_file_or_dict)
else:
dataset_json = dataset_json_file_or_dict
else:
dataset_json = load_json(join(list_of_input_folders[0], 'dataset.json'))
if plans_json_file_or_dict is not None:
if isinstance(plans_json_file_or_dict, str):
plans = load_json(plans_json_file_or_dict)
else:
plans = plans_json_file_or_dict
else:
plans = load_json(join(list_of_input_folders[0], 'plans.json'))
plans_manager = PlansManager(plans)
# now collect the files in each of the folders and enforce that all files are present in all folders
files_per_folder = [set(subfiles(i, suffix='.npz', join=False)) for i in list_of_input_folders]
# first build a set with all files
s = deepcopy(files_per_folder[0])
for f in files_per_folder[1:]:
s.update(f)
for f in files_per_folder:
assert len(s.difference(f)) == 0, "Not all folders contain the same files for ensembling. Please only " \
"provide folders that contain the predictions"
lists_of_lists_of_files = [[join(fl, fi) for fl in list_of_input_folders] for fi in s]
output_files_truncated = [join(output_folder, fi[:-4]) for fi in s]
image_reader_writer = plans_manager.image_reader_writer_class()
label_manager = plans_manager.get_label_manager(dataset_json)
maybe_mkdir_p(output_folder)
shutil.copy(join(list_of_input_folders[0], 'dataset.json'), output_folder)
with multiprocessing.get_context("spawn").Pool(num_processes) as pool:
num_preds = len(s)
_ = pool.starmap(
merge_files,
zip(
lists_of_lists_of_files,
output_files_truncated,
[dataset_json['file_ending']] * num_preds,
[image_reader_writer] * num_preds,
[label_manager] * num_preds,
[save_merged_probabilities] * num_preds
)
)
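# Illustrative sketch (hypothetical helper, not part of nnU-Net): the consistency check ensemble_folders runs
# before merging. The union of all file names is built first; any folder lacking a name from the union cannot be
# ensembled case by case. Returning the missing names per folder (instead of asserting) makes failures easier to
# diagnose.
from typing import Dict, List, Set


def _sketch_missing_per_folder(files_per_folder: Dict[str, List[str]]) -> Dict[str, Set[str]]:
    union: Set[str] = set()
    for files in files_per_folder.values():
        union.update(files)
    # keep only folders that actually miss something
    return {folder: union - set(files) for folder, files in files_per_folder.items()
            if union - set(files)}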
def entry_point_ensemble_folders():
parser = argparse.ArgumentParser()
parser.add_argument('-i', nargs='+', type=str, required=True,
help='list of input folders')
parser.add_argument('-o', type=str, required=True, help='output folder')
parser.add_argument('-np', type=int, required=False, default=default_num_processes,
help=f"Numbers of processes used for ensembling. Default: {default_num_processes}")
parser.add_argument('--save_npz', action='store_true', required=False, help='Set this flag to store output '
'probabilities in separate .npz files')
args = parser.parse_args()
ensemble_folders(args.i, args.o, args.save_npz, args.np)
def ensemble_crossvalidations(list_of_trained_model_folders: List[str],
output_folder: str,
folds: Union[Tuple[int, ...], List[int]] = (0, 1, 2, 3, 4),
num_processes: int = default_num_processes,
overwrite: bool = True) -> None:
"""
Feature: different configurations can now have different splits
"""
dataset_json = load_json(join(list_of_trained_model_folders[0], 'dataset.json'))
plans_manager = PlansManager(join(list_of_trained_model_folders[0], 'plans.json'))
# first collect all unique filenames
files_per_folder = {}
unique_filenames = set()
for tr in list_of_trained_model_folders:
files_per_folder[tr] = {}
for f in folds:
if not isdir(join(tr, f'fold_{f}', 'validation')):
raise RuntimeError(f'Expected model output directory does not exist. You must train all requested '
f'folds of the specified model.\nModel: {tr}\nFold: {f}')
files_here = subfiles(join(tr, f'fold_{f}', 'validation'), suffix='.npz', join=False)
if len(files_here) == 0:
raise RuntimeError(f"No .npz files found in folder {join(tr, f'fold_{f}', 'validation')}. Rerun your "
f"validation with the --npz flag. Use nnUNetv2_train [...] --val --npz.")
files_per_folder[tr][f] = files_here
unique_filenames.update(files_per_folder[tr][f])
# verify that all trained_model_folders have all predictions
ok = True
for tr, fi in files_per_folder.items():
all_files_here = set()
for f in folds:
all_files_here.update(fi[f])
diff = unique_filenames.difference(all_files_here)
if len(diff) > 0:
ok = False
print(f'model {tr} does not seem to contain all predictions. Missing: {diff}')
if not ok:
raise RuntimeError('There were missing files, see print statements above this one')
# now we need to collect where these files are
file_mapping = []
for tr in list_of_trained_model_folders:
file_mapping.append({})
for f in folds:
for fi in files_per_folder[tr][f]:
# check for duplicates
assert fi not in file_mapping[-1].keys(), f"Duplicate detected. Case {fi} is present in more than " \
f"one fold of model {tr}."
file_mapping[-1][fi] = join(tr, f'fold_{f}', 'validation', fi)
lists_of_lists_of_files = [[fm[i] for fm in file_mapping] for i in unique_filenames]
output_files_truncated = [join(output_folder, fi[:-4]) for fi in unique_filenames]
image_reader_writer = plans_manager.image_reader_writer_class()
maybe_mkdir_p(output_folder)
label_manager = plans_manager.get_label_manager(dataset_json)
if not overwrite:
tmp = [isfile(i + dataset_json['file_ending']) for i in output_files_truncated]
lists_of_lists_of_files = [lists_of_lists_of_files[i] for i in range(len(tmp)) if not tmp[i]]
output_files_truncated = [output_files_truncated[i] for i in range(len(tmp)) if not tmp[i]]
with multiprocessing.get_context("spawn").Pool(num_processes) as pool:
num_preds = len(lists_of_lists_of_files)
_ = pool.starmap(
merge_files,
zip(
lists_of_lists_of_files,
output_files_truncated,
[dataset_json['file_ending']] * num_preds,
[image_reader_writer] * num_preds,
[label_manager] * num_preds,
[False] * num_preds
)
)
shutil.copy(join(list_of_trained_model_folders[0], 'plans.json'), join(output_folder, 'plans.json'))
shutil.copy(join(list_of_trained_model_folders[0], 'dataset.json'), join(output_folder, 'dataset.json'))
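# Illustrative sketch (hypothetical helper, not part of nnU-Net): how ensemble_crossvalidations maps each case to
# the fold that validated it. In a proper cross-validation every case is in exactly one fold's validation set, so
# a name seen twice within one model indicates overlapping splits and is rejected.
from typing import Dict, List


def _sketch_case_to_fold(files_per_fold: Dict[int, List[str]]) -> Dict[str, int]:
    mapping: Dict[str, int] = {}
    for fold, files in files_per_fold.items():
        for f in files:
            if f in mapping:
                raise ValueError(f'{f} is in fold {mapping[f]} and fold {fold}')
            mapping[f] = fold
    return mapping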
================================================
FILE: nnunetv2/evaluation/__init__.py
================================================
================================================
FILE: nnunetv2/evaluation/accumulate_cv_results.py
================================================
import shutil
from typing import Union, List, Tuple
from batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, maybe_mkdir_p, subfiles, isfile
from nnunetv2.configuration import default_num_processes
from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder
from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
def accumulate_cv_results(trained_model_folder,
merged_output_folder: str,
folds: Union[List[int], Tuple[int, ...]],
num_processes: int = default_num_processes,
overwrite: bool = True):
"""
There are a lot of things that can get fucked up, so the simplest way to deal with potential problems is to
collect the cv results into a separate folder and then evaluate them again. No messing with summary_json files!
"""
if overwrite and isdir(merged_output_folder):
shutil.rmtree(merged_output_folder)
maybe_mkdir_p(merged_output_folder)
dataset_json = load_json(join(trained_model_folder, 'dataset.json'))
plans_manager = PlansManager(join(trained_model_folder, 'plans.json'))
rw = plans_manager.image_reader_writer_class()
shutil.copy(join(trained_model_folder, 'dataset.json'), join(merged_output_folder, 'dataset.json'))
shutil.copy(join(trained_model_folder, 'plans.json'), join(merged_output_folder, 'plans.json'))
did_we_copy_something = False
for f in folds:
expected_validation_folder = join(trained_model_folder, f'fold_{f}', 'validation')
if not isdir(expected_validation_folder):
raise RuntimeError(f"fold {f} of model {trained_model_folder} is missing. Please train it!")
predicted_files = subfiles(expected_validation_folder, suffix=dataset_json['file_ending'], join=False)
for pf in predicted_files:
if overwrite and isfile(join(merged_output_folder, pf)):
raise RuntimeError(f'More than one of your folds has a prediction for case {pf}')
if overwrite or not isfile(join(merged_output_folder, pf)):
shutil.copy(join(expected_validation_folder, pf), join(merged_output_folder, pf))
did_we_copy_something = True
if did_we_copy_something or not isfile(join(merged_output_folder, 'summary.json')):
label_manager = plans_manager.get_label_manager(dataset_json)
gt_folder = join(nnUNet_raw, plans_manager.dataset_name, 'labelsTr')
if not isdir(gt_folder):
gt_folder = join(nnUNet_preprocessed, plans_manager.dataset_name, 'gt_segmentations')
compute_metrics_on_folder(gt_folder,
merged_output_folder,
join(merged_output_folder, 'summary.json'),
rw,
dataset_json['file_ending'],
label_manager.foreground_regions if label_manager.has_regions else
label_manager.foreground_labels,
label_manager.ignore_label,
num_processes)
================================================
FILE: nnunetv2/evaluation/evaluate_predictions.py
================================================
import multiprocessing
import os
from copy import deepcopy
from typing import Tuple, List, Union
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import subfiles, join, save_json, load_json, \
isfile
from nnunetv2.configuration import default_num_processes
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json, \
determine_reader_writer_from_file_ending
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
# the Evaluator class of the previous nnU-Net was great and all but man was it overengineered. Keep it simple
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
def label_or_region_to_key(label_or_region: Union[int, Tuple[int]]):
return str(label_or_region)
def key_to_label_or_region(key: str):
try:
return int(key)
except ValueError:
key = key.replace('(', '')
key = key.replace(')', '')
split = key.split(',')
return tuple([int(i) for i in split if len(i) > 0])
def save_summary_json(results: dict, output_file: str):
"""
json does not support tuples as keys (why does it have to be so shitty) so we need to convert that shit
ourselves
"""
results_converted = deepcopy(results)
# convert keys in mean metrics
results_converted['mean'] = {label_or_region_to_key(k): results['mean'][k] for k in results['mean'].keys()}
# convert metric_per_case
for i in range(len(results_converted["metric_per_case"])):
results_converted["metric_per_case"][i]['metrics'] = \
{label_or_region_to_key(k): results["metric_per_case"][i]['metrics'][k]
for k in results["metric_per_case"][i]['metrics'].keys()}
# sort_keys=True will make foreground_mean the first entry and thus easy to spot
save_json(results_converted, output_file, sort_keys=True)
def load_summary_json(filename: str):
results = load_json(filename)
# convert keys in mean metrics
results['mean'] = {key_to_label_or_region(k): results['mean'][k] for k in results['mean'].keys()}
# convert metric_per_case
for i in range(len(results["metric_per_case"])):
results["metric_per_case"][i]['metrics'] = \
{key_to_label_or_region(k): results["metric_per_case"][i]['metrics'][k]
for k in results["metric_per_case"][i]['metrics'].keys()}
return results
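# Illustrative sketch (hypothetical, not part of nnU-Net): why save_summary_json/load_summary_json convert keys.
# json.dumps rejects tuple keys, so regions are stored in their str() form, e.g. '(1, 2)', and parsed back on load.
# _sketch_parse_key mirrors key_to_label_or_region; the round trip restores ints and tuples alike.
import json


def _sketch_parse_key(key: str):
    try:
        return int(key)
    except ValueError:
        parts = key.replace('(', '').replace(')', '').split(',')
        return tuple(int(p) for p in parts if len(p.strip()) > 0)


def _sketch_roundtrip(d: dict) -> dict:
    text = json.dumps({str(k): v for k, v in d.items()})
    return {_sketch_parse_key(k): v for k, v in json.loads(text).items()}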
def labels_to_list_of_regions(labels: List[int]):
return [(i,) for i in labels]
def region_or_label_to_mask(segmentation: np.ndarray, region_or_label: Union[int, Tuple[int, ...]]) -> np.ndarray:
if np.isscalar(region_or_label):
return segmentation == region_or_label
else:
mask = np.zeros_like(segmentation, dtype=bool)
for r in region_or_label:
mask[segmentation == r] = True
return mask
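# Illustrative sketch (hypothetical, not part of nnU-Net): region_or_label_to_mask expressed with np.isin. A region
# is the union of its labels, so the mask is True wherever the voxel value is any of them; a single label reduces
# to an equality test. np.isin replaces the Python loop over labels with one vectorized call.
import numpy as np


def _sketch_region_mask(segmentation: np.ndarray, region_or_label) -> np.ndarray:
    values = np.atleast_1d(region_or_label)
    return np.isin(segmentation, values)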
def compute_tp_fp_fn_tn(mask_ref: np.ndarray, mask_pred: np.ndarray, ignore_mask: np.ndarray = None):
if ignore_mask is None:
use_mask = np.ones_like(mask_ref, dtype=bool)
else:
use_mask = ~ignore_mask
tp = np.sum((mask_ref & mask_pred) & use_mask)
fp = np.sum(((~mask_ref) & mask_pred) & use_mask)
fn = np.sum((mask_ref & (~mask_pred)) & use_mask)
tn = np.sum(((~mask_ref) & (~mask_pred)) & use_mask)
return tp, fp, fn, tn
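# Illustrative sketch (hypothetical helper, not part of nnU-Net): the counts from compute_tp_fp_fn_tn and the
# formulas compute_metrics derives from them, Dice = 2TP / (2TP + FP + FN) and IoU = TP / (TP + FP + FN).
# Voxels under the ignore mask are excluded from every count; if both masks are empty, the metrics are NaN.
import numpy as np


def _sketch_dice_iou(ref: np.ndarray, pred: np.ndarray, ignore: np.ndarray = None):
    use = np.ones_like(ref, dtype=bool) if ignore is None else ~ignore
    tp = int(np.sum(ref & pred & use))
    fp = int(np.sum(~ref & pred & use))
    fn = int(np.sum(ref & ~pred & use))
    if tp + fp + fn == 0:
        # nothing to compare outside the ignore region: undefined, as in compute_metrics
        return float('nan'), float('nan')
    return 2 * tp / (2 * tp + fp + fn), tp / (tp + fp + fn)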
def compute_metrics(reference_file: str, prediction_file: str, image_reader_writer: BaseReaderWriter,
labels_or_regions: Union[List[int], List[Union[int, Tuple[int, ...]]]],
ignore_label: int = None) -> dict:
# load images
seg_ref, seg_ref_dict = image_reader_writer.read_seg(reference_file)
seg_pred, seg_pred_dict = image_reader_writer.read_seg(prediction_file)
ignore_mask = seg_ref == ignore_label if ignore_label is not None else None
results = {}
results['reference_file'] = reference_file
results['prediction_file'] = prediction_file
results['metrics'] = {}
for r in labels_or_regions:
results['metrics'][r] = {}
mask_ref = region_or_label_to_mask(seg_ref, r)
mask_pred = region_or_label_to_mask(seg_pred, r)
tp, fp, fn, tn = compute_tp_fp_fn_tn(mask_ref, mask_pred, ignore_mask)
if tp + fp + fn == 0:
results['metrics'][r]['Dice'] = np.nan
results['metrics'][r]['IoU'] = np.nan
else:
results['metrics'][r]['Dice'] = 2 * tp / (2 * tp + fp + fn)
results['metrics'][r]['IoU'] = tp / (tp + fp + fn)
results['metrics'][r]['FP'] = fp
results['metrics'][r]['TP'] = tp
results['metrics'][r]['FN'] = fn
results['metrics'][r]['TN'] = tn
results['metrics'][r]['n_pred'] = fp + tp
results['metrics'][r]['n_ref'] = fn + tp
return results
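# Worked example (hypothetical helper name) for the formulas used in compute_metrics:
#   Dice = 2*TP / (2*TP + FP + FN)    IoU = TP / (TP + FP + FN)
# Both are NaN when a class is absent from reference and prediction (TP + FP + FN == 0). This lets the later
# np.nanmean skip such cases. The two metrics are linked by Dice = 2*IoU / (1 + IoU).
import math


def sketch_dice_iou(tp: int, fp: int, fn: int):
    if tp + fp + fn == 0:
        # nothing to compare: undefined rather than perfect
        return math.nan, math.nan
    return 2 * tp / (2 * tp + fp + fn), tp / (tp + fp + fn)

# usage: sketch_dice_iou(3, 1, 2) gives Dice 6/9 (about 0.667) and IoU 3/6 = 0.5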
def compute_metrics_on_folder(folder_ref: str, folder_pred: str, output_file: str,
image_reader_writer: BaseReaderWriter,
file_ending: str,
regions_or_labels: Union[List[int], List[Union[int, Tuple[int, ...]]]],
ignore_label: int = None,
num_processes: int = default_num_processes,
chill: bool = True) -> dict:
"""
output_file must end with .json; can be None
"""
if output_file is not None:
assert output_file.endswith('.json'), 'output_file should end with .json'
files_pred = subfiles(folder_pred, suffix=file_ending, join=False)
files_ref = subfiles(folder_ref, suffix=file_ending, join=False)
if not chill:
present = [isfile(join(folder_pred, i)) for i in files_ref]
assert all(present), "Not all files in folder_ref exist in folder_pred"
files_ref = [join(folder_ref, i) for i in files_pred]
files_pred = [join(folder_pred, i) for i in files_pred]
with multiprocessing.get_context("spawn").Pool(num_processes) as pool:
# for i in list(zip(files_ref, files_pred, [image_reader_writer] * len(files_pred), [regions_or_labels] * len(files_pred), [ignore_label] * len(files_pred))):
# compute_metrics(*i)
results = pool.starmap(
compute_metrics,
list(zip(files_ref, files_pred, [image_reader_writer] * len(files_pred), [regions_or_labels] * len(files_pred),
[ignore_label] * len(files_pred)))
)
# mean metric per class
metric_list = list(results[0]['metrics'][regions_or_labels[0]].keys())
means = {}
for r in regions_or_labels:
means[r] = {}
for m in metric_list:
means[r][m] = np.nanmean([i['metrics'][r][m] for i in results])
# foreground mean
foreground_mean = {}
for m in metric_list:
values = []
for k in means.keys():
if k == 0 or k == '0':
continue
values.append(means[k][m])
foreground_mean[m] = np.mean(values)
[recursive_fix_for_json_export(i) for i in results]
recursive_fix_for_json_export(means)
recursive_fix_for_json_export(foreground_mean)
result = {'metric_per_case': results, 'mean': means, 'foreground_mean': foreground_mean}
if output_file is not None:
save_summary_json(result, output_file)
return result
# print('DONE')
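# Sketch (hypothetical helper, toy numbers) of the aggregation in compute_metrics_on_folder. Per class,
# np.nanmean over cases skips NaN entries (empty reference and prediction). The foreground mean then averages
# the class means, skipping key 0 / '0' if present. The nested layout mimics
# results[i]['metrics'][class][metric]. Note that the source uses np.mean (not np.nanmean) for the foreground
# mean, so a class that is NaN in every case makes the foreground mean NaN.
import numpy as np


def sketch_aggregate(per_case_metrics, classes, metric='Dice'):
    means = {c: float(np.nanmean([case[c][metric] for case in per_case_metrics])) for c in classes}
    foreground = [means[c] for c in classes if c not in (0, '0')]
    return means, float(np.mean(foreground))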
def compute_metrics_on_folder2(folder_ref: str, folder_pred: str, dataset_json_file: str, plans_file: str,
output_file: str = None,
num_processes: int = default_num_processes,
chill: bool = False):
dataset_json = load_json(dataset_json_file)
# get file ending
file_ending = dataset_json['file_ending']
# get reader writer class
example_file = subfiles(folder_ref, suffix=file_ending, join=True)[0]
rw = determine_reader_writer_from_dataset_json(dataset_json, example_file)()
# maybe auto set output file
if output_file is None:
output_file = join(folder_pred, 'summary.json')
lm = PlansManager(plans_file).get_label_manager(dataset_json)
compute_metrics_on_folder(folder_ref, folder_pred, output_file, rw, file_ending,
lm.foreground_regions if lm.has_regions else lm.foreground_labels, lm.ignore_label,
num_processes, chill=chill)
def compute_metrics_on_folder_simple(folder_ref: str, folder_pred: str, labels: Union[Tuple[int, ...], List[int]],
output_file: str = None,
num_processes: int = default_num_processes,
ignore_label: int = None,
chill: bool = False):
example_file = subfiles(folder_ref, join=True)[0]
file_ending = os.path.splitext(example_file)[-1]
rw = determine_reader_writer_from_file_ending(file_ending, example_file, allow_nonmatching_filename=True,
verbose=False)()
# maybe auto set output file
if output_file is None:
output_file = join(folder_pred, 'summary.json')
compute_metrics_on_folder(folder_ref, folder_pred, output_file, rw, file_ending,
labels, ignore_label=ignore_label, num_processes=num_processes, chill=chill)
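# Sketch: os.path.splitext only strips the last suffix. For 'case_000.nii.gz',
# compute_metrics_on_folder_simple therefore infers '.gz' rather than '.nii.gz' as file_ending. Listing
# with subfiles(..., suffix='.gz') still picks up the NIfTI files, and reader/writer detection is called
# with allow_nonmatching_filename=True. If the full compound ending is needed, a helper like the hypothetical
# one below recovers it. The list of compound endings is an assumption.
import os

_COMPOUND_ENDINGS = ('.nii.gz',)


def sketch_full_file_ending(path: str) -> str:
    name = os.path.basename(path)
    for ending in _COMPOUND_ENDINGS:
        if name.endswith(ending):
            return ending
    # fall back to the single last suffix
    return os.path.splitext(name)[1]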
def evaluate_folder_entry_point():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('gt_folder', type=str, help='folder with gt segmentations')
parser.add_argument('pred_folder', type=str, help='folder with predicted segmentations')
parser.add_argument('-djfile', type=str, required=True,
help='dataset.json file')
parser.add_argument('-pfile', type=str, required=True,
help='plans.json file')
parser.add_argument('-o', type=str, required=False, default=None,
help='Output file. Optional. Default: pred_folder/summary.json')
parser.add_argument('-np', type=int, required=False, default=default_num_processes,
help=f'number of processes used. Optional. Default: {default_num_processes}')
parser.add_argument('--chill', action='store_true', help="don't crash if folder_pred does not have all files that are present in folder_gt")
args = parser.parse_args()
compute_metrics_on_folder2(args.gt_folder, args.pred_folder, args.djfile, args.pfile, args.o, args.np, chill=args.chill)
def evaluate_simple_entry_point():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('gt_folder', type=str, help='folder with gt segmentations')
parser.add_argument('pred_folder', type=str, help='folder with predicted segmentations')
parser.add_argument('-l', type=int, nargs='+', required=True,
help='list of labels')
parser.add_argument('-il', type=int, required=False, default=None,
help='ignore label')
parser.add_argument('-o', type=str, required=False, default=None,
help='Output file. Optional. Default: pred_folder/summary.json')
parser.add_argument('-np', type=int, required=False, default=default_num_processes,
help=f'number of processes used. Optional. Default: {default_num_processes}')
parser.add_argument('--chill', action='store_true', help="don't crash if folder_pred does not have all files that are present in folder_gt")
args = parser.parse_args()
compute_metrics_on_folder_simple(args.gt_folder, args.pred_folder, args.l, args.o, args.np, args.il, chill=args.chill)
if __name__ == '__main__':
folder_ref = '/media/fabian/data/nnUNet_raw/Dataset004_Hippocampus/labelsTr'
folder_pred = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation'
output_file = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation/summary.json'
image_reader_writer = SimpleITKIO()
file_ending = '.nii.gz'
regions = labels_to_list_of_regions([1, 2])
ignore_label = None
num_processes = 12
compute_metrics_on_folder(folder_ref, folder_pred, output_file, image_reader_writer, file_ending, regions, ignore_label,
num_processes)
================================================
FILE: nnunetv2/evaluation/find_best_configuration.py
================================================
import argparse
import os.path
from copy import deepcopy
from typing import Union, List, Tuple
from batchgenerators.utilities.file_and_folder_operations import (
load_json, join, isdir, listdir, save_json
)
from nnunetv2.configuration import default_num_processes
from nnunetv2.ensembling.ensemble import ensemble_crossvalidations
from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results
from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder, load_summary_json
from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw, nnUNet_results
from nnunetv2.postprocessing.remove_connected_components import determine_postprocessing
from nnunetv2.utilities.file_path_utilities import maybe_convert_to_dataset_name, get_output_folder, \
convert_identifier_to_trainer_plans_config, get_ensemble_name, folds_tuple_to_string
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
default_trained_models = tuple([
{'plans': 'nnUNetPlans', 'configuration': '2d', 'trainer': 'nnUNetTrainer'},
{'plans': 'nnUNetPlans', 'configuration': '3d_fullres', 'trainer': 'nnUNetTrainer'},
{'plans': 'nnUNetPlans', 'configuration': '3d_lowres', 'trainer': 'nnUNetTrainer'},
{'plans': 'nnUNetPlans', 'configuration': '3d_cascade_fullres', 'trainer': 'nnUNetTrainer'},
])
def filter_available_models(model_dict: Union[List[dict], Tuple[dict, ...]], dataset_name_or_id: Union[str, int]):
valid = []
for trained_model in model_dict:
plans_manager = PlansManager(join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id),
trained_model['plans'] + '.json'))
# check if configuration exists
# 3d_cascade_fullres and 3d_lowres do not exist for each dataset so we allow them to be absent IF they are not
# specified in the plans file
if trained_model['configuration'] not in plans_manager.available_configurations:
print(f"Configuration {trained_model['configuration']} not found in plans {trained_model['plans']}.\n"
f"Inferred plans file: {join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id), trained_model['plans'] + '.json')}.")
continue
# check if trained model output folder exists. This is a requirement. No mercy here.
expected_output_folder = get_output_folder(dataset_name_or_id, trained_model['trainer'], trained_model['plans'],
trained_model['configuration'], fold=None)
if not isdir(expected_output_folder):
raise RuntimeError(f"Trained model {trained_model} does not have an output folder. "
f"Expected: {expected_output_folder}. Please run the training for this model! (don't forget "
f"the --npz flag if you want to ensemble multiple configurations)")
valid.append(trained_model)
return valid
def generate_inference_command(dataset_name_or_id: Union[int, str], configuration_name: str,
plans_identifier: str = 'nnUNetPlans', trainer_name: str = 'nnUNetTrainer',
folds: Union[List[int], Tuple[int, ...]] = (0, 1, 2, 3, 4),
folder_with_segs_from_prev_stage: str = None,
input_folder: str = 'INPUT_FOLDER',
output_folder: str = 'OUTPUT_FOLDER',
save_npz: bool = False):
fold_str = ' '.join(str(f) for f in folds)
predict_command = ''
trained_model_folder = get_output_folder(dataset_name_or_id, trainer_name, plans_identifier, configuration_name, fold=None)
plans_manager = PlansManager(join(trained_model_folder, 'plans.json'))
configuration_manager = plans_manager.get_configuration(configuration_name)
if configuration_manager.previous_stage_name is not None:
prev_stage = configuration_manager.previous_stage_name
predict_command += generate_inference_command(dataset_name_or_id, prev_stage, plans_identifier, trainer_name,
folds, None, input_folder=input_folder,
output_folder='OUTPUT_FOLDER_PREV_STAGE') + '\n'
folder_with_segs_from_prev_stage = 'OUTPUT_FOLDER_PREV_STAGE'
predict_command += f'nnUNetv2_predict -d {dataset_name_or_id} -i {input_folder} -o {output_folder} -f {fold_str} ' \
f'-tr {trainer_name} -c {configuration_name} -p {plans_identifier}'
if folder_with_segs_from_prev_stage is not None:
predict_command += f' -prev_stage_predictions {folder_with_segs_from_prev_stage}'
if save_npz:
predict_command += ' --save_probabilities'
return predict_command
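# Sketch (hypothetical helper, not nnU-Net API) of how cascade inference commands are meant to be chained.
# The previous stage (e.g. 3d_lowres) is predicted first into its own folder. That folder is then passed to
# the next stage via -prev_stage_predictions. Only the nnUNetv2_predict flags used in
# generate_inference_command above appear here.
from typing import Optional, Sequence


def sketch_cascade_predict_commands(dataset: str, folds: Sequence[int], trainer: str, plans: str,
                                    configuration: str, previous_stage: Optional[str] = None,
                                    input_folder: str = 'INPUT_FOLDER',
                                    output_folder: str = 'OUTPUT_FOLDER') -> str:
    fold_str = ' '.join(str(f) for f in folds)
    lines = []
    prev_folder = None
    if previous_stage is not None:
        prev_folder = 'OUTPUT_FOLDER_PREV_STAGE'
        # the previous stage must run first so that its segmentations exist
        lines.append(f'nnUNetv2_predict -d {dataset} -i {input_folder} -o {prev_folder} -f {fold_str} '
                     f'-tr {trainer} -c {previous_stage} -p {plans}')
    cmd = (f'nnUNetv2_predict -d {dataset} -i {input_folder} -o {output_folder} -f {fold_str} '
           f'-tr {trainer} -c {configuration} -p {plans}')
    if prev_folder is not None:
        cmd += f' -prev_stage_predictions {prev_folder}'
    lines.append(cmd)
    return '\n'.join(lines)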
def find_best_configuration(dataset_name_or_id,
allowed_trained_models: Union[List[dict], Tuple[dict, ...]] = default_trained_models,
allow_ensembling: bool = True,
num_processes: int = default_num_processes,
overwrite: bool = True,
folds: Union[List[int], Tuple[int, ...]] = (0, 1, 2, 3, 4),
strict: bool = False):
dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
all_results = {}
allowed_trained_models = filter_available_models(deepcopy(allowed_trained_models), dataset_name_or_id)
for m in allowed_trained_models:
output_folder = get_output_folder(dataset_name_or_id, m['trainer'], m['plans'], m['configuration'], fold=None)
if not isdir(output_folder) and strict:
raise RuntimeError(f'{dataset_name}: The output folder of plans {m["plans"]} configuration '
f'{m["configuration"]} is missing. Please train the model (all requested folds!) first!')
identifier = os.path.basename(output_folder)
merged_output_folder = join(output_folder, f'crossval_results_folds_{folds_tuple_to_string(folds)}')
accumulate_cv_results(output_folder, merged_output_folder, folds, num_processes, overwrite)
all_results[identifier] = {
'source': merged_output_folder,
'result': load_summary_json(join(merged_output_folder, 'summary.json'))['foreground_mean']['Dice']
}
if allow_ensembling:
for i in range(len(allowed_trained_models)):
for j in range(i + 1, len(allowed_trained_models)):
m1, m2 = allowed_trained_models[i], allowed_trained_models[j]
output_folder_1 = get_output_folder(dataset_name_or_id, m1['trainer'], m1['plans'], m1['configuration'], fold=None)
output_folder_2 = get_output_folder(dataset_name_or_id, m2['trainer'], m2['plans'], m2['configuration'], fold=None)
identifier = get_ensemble_name(output_folder_1, output_folder_2, folds)
output_folder_ensemble = join(nnUNet_results, dataset_name, 'ensembles', identifier)
ensemble_crossvalidations([output_folder_1, output_folder_2], output_folder_ensemble, folds,
num_processes, overwrite=overwrite)
# evaluate ensembled predictions
plans_manager = PlansManager(join(output_folder_1, 'plans.json'))
dataset_json = load_json(join(output_folder_1, 'dataset.json'))
label_manager = plans_manager.get_label_manager(dataset_json)
rw = plans_manager.image_reader_writer_class()
compute_metrics_on_folder(join(nnUNet_preprocessed, dataset_name, 'gt_segmentations'),
output_folder_ensemble,
join(output_folder_ensemble, 'summary.json'),
rw,
dataset_json['file_ending'],
label_manager.foreground_regions if label_manager.has_regions else
label_manager.foreground_labels,
label_manager.ignore_label,
num_processes)
all_results[identifier] = \
{
'source': output_folder_ensemble,
'result': load_summary_json(join(output_folder_ensemble, 'summary.json'))['foreground_mean']['Dice']
}
# pick best and report inference command
best_score = max([i['result'] for i in all_results.values()])
best_keys = [k for k in all_results.keys() if all_results[k]['result'] == best_score] # may never happen but theoretically
# there can be a tie. Let's pick the first model in this case because it's going to be the simpler one (ensembles
# come after single configs)
best_key = best_keys[0]
print()
print('***All results:***')
for k, v in all_results.items():
print(f'{k}: {v["result"]}')
print(f'\n*Best*: {best_key}: {all_results[best_key]["result"]}')
print()
print('***Determining postprocessing for best model/ensemble***')
determine_postprocessing(all_results[best_key]['source'], join(nnUNet_preprocessed, dataset_name, 'gt_segmentations'),
plans_file_or_dict=join(all_results[best_key]['source'], 'plans.json'),
dataset_json_file_or_dict=join(all_results[best_key]['source'], 'dataset.json'),
num_processes=num_processes, keep_postprocessed_files=True)
# in addition to just reading the console output (how it was previously) we should return the information
# needed to run the full inference via API
return_dict = {
'folds': folds,
'dataset_name_or_id': dataset_name_or_id,
'considered_models': allowed_trained_models,
'ensembling_allowed': allow_ensembling,
'all_results': {i: j['result'] for i, j in all_results.items()},
'best_model_or_ensemble': {
'result_on_crossval_pre_pp': all_results[best_key]["result"],
'result_on_crossval_post_pp': load_json(join(all_results[best_key]['source'], 'postprocessed', 'summary.json'))['foreground_mean']['Dice'],
'postprocessing_file': join(all_results[best_key]['source'], 'postprocessing.pkl'),
'some_plans_file': join(all_results[best_key]['source'], 'plans.json'),
# just needed for label handling, can
# come from any of the ensemble members (if any)
'selected_model_or_models': []
}
}
# convert best key to inference command:
if best_key.startswith('ensemble___'):
prefix, m1, m2, folds_string = best_key.split('___')
tr1, pl1, c1 = convert_identifier_to_trainer_plans_config(m1)
tr2, pl2, c2 = convert_identifier_to_trainer_plans_config(m2)
return_dict['best_model_or_ensemble']['selected_model_or_models'].append(
{
'configuration': c1,
'trainer': tr1,
'plans_identifier': pl1,
})
return_dict['best_model_or_ensemble']['selected_model_or_models'].append(
{
'configuration': c2,
'trainer': tr2,
'plans_identifier': pl2,
})
else:
tr, pl, c = convert_identifier_to_trainer_plans_config(best_key)
return_dict['best_model_or_ensemble']['selected_model_or_models'].append(
{
'configuration': c,
'trainer': tr,
'plans_identifier': pl,
})
save_json(return_dict, join(nnUNet_results, dataset_name, 'inference_information.json'))
# save this so that we don't have to rerun everything whenever someone wants to be reminded of the inference
# commands. They can just load this file and pass it to print_inference_instructions
# print it
print_inference_instructions(return_dict, instructions_file=join(nnUNet_results, dataset_name, 'inference_instructions.txt'))
return return_dict
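# Sketch (hypothetical helper, made-up scores) of the selection logic in find_best_configuration. Every
# single configuration is a candidate. If ensembling is allowed, so is every unordered pair; the nested
# i < j loop enumerates these pairs, and itertools.combinations yields them in the same order. max() returns
# the first maximal element, so on a tie the earlier candidate wins, and single configurations come before
# ensembles. The 'm1+m2' naming and the ensemble_score callable are illustrative assumptions, not the real
# ensemble identifiers.
from itertools import combinations


def sketch_pick_best(single_scores: dict, ensemble_score, allow_ensembling: bool = True):
    candidates = dict(single_scores)  # insertion order: single models first
    if allow_ensembling:
        for m1, m2 in combinations(list(single_scores), 2):
            candidates[f'{m1}+{m2}'] = ensemble_score(m1, m2)
    best = max(candidates, key=candidates.get)
    return best, candidates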
def print_inference_instructions(inference_info_dict: dict, instructions_file: str = None):
def _print_and_maybe_write_to_file(string):
print(string)
if f_handle is not None:
f_handle.write(f'{string}\n')
f_handle = open(instructions_file, 'w') if instructions_file is not None else None
print()
_print_and_maybe_write_to_file('***Run inference like this:***\n')
output_folders = []
dataset_name_or_id = inference_info_dict['dataset_name_or_id']
if len(inference_info_dict['best_model_or_ensemble']['selected_model_or_models']) > 1:
is_ensemble = True
_print_and_maybe_write_to_file('An ensemble won! What a surprise! Run the following commands to run predictions with the ensemble members:\n')
else:
is_ensemble = False
for j, i in enumerate(inference_info_dict['best_model_or_ensemble']['selected_model_or_models']):
tr, c, pl = i['trainer'], i['configuration'], i['plans_identifier']
if is_ensemble:
output_folder_name = f"OUTPUT_FOLDER_MODEL_{j+1}"
else:
output_folder_name = f"OUTPUT_FOLDER"
output_folders.append(output_folder_name)
_print_and_maybe_write_to_file(generate_inference_command(dataset_name_or_id, c, pl, tr, inference_info_dict['folds'],
save_npz=is_ensemble, output_folder=output_folder_name))
if is_ensemble:
output_folder_str = output_folders[0]
for o in output_folders[1:]:
output_folder_str += f' {o}'
output_ensemble = "OUTPUT_FOLDER"
_print_and_maybe_write_to_file('\nThen run ensembling with:\n')
_print_and_maybe_write_to_file(f"nnUNetv2_ensemble -i {output_folder_str} -o {output_ensemble} -np {default_num_processes}")
_print_and_maybe_write_to_file("\n***Once inference is completed, run postprocessing like this:***\n")
_print_and_maybe_write_to_file(f"nnUNetv2_apply_postprocessing -i OUTPUT_FOLDER -o OUTPUT_FOLDER_PP "
f"-pp_pkl_file {inference_info_dict['best_model_or_ensemble']['postprocessing_file']} -np {default_num_processes} "
f"-plans_json {inference_info_dict['best_model_or_ensemble']['some_plans_file']}")
if f_handle is not None:
f_handle.close()
def dumb_trainer_config_plans_to_trained_models_dict(trainers: List[str], configs: List[str], plans: List[str]):
"""
function is called dumb because it's dumb: it just returns every (trainer, configuration, plans) combination as a trained-model dict
"""
ret = []
for t in trainers:
for c in configs:
for p in plans:
ret.append(
{'plans': p, 'configuration': c, 'trainer': t}
)
return tuple(ret)
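# Sketch: dumb_trainer_config_plans_to_trained_models_dict is a Cartesian product of trainers, configurations
# and plans. itertools.product iterates in the same nested order (trainer outermost, plans innermost), so this
# hypothetical one-liner returns an identical tuple.
from itertools import product
from typing import List


def sketch_trained_models_dict(trainers: List[str], configs: List[str], plans: List[str]):
    return tuple({'plans': p, 'configuration': c, 'trainer': t}
                 for t, c, p in product(trainers, configs, plans))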
def find_best_configuration_entry_point():
parser = argparse.ArgumentParser()
parser.add_argument('dataset_name_or_id', type=str, help='Dataset Name or id')
parser.add_argument('-p', nargs='+', required=False, default=['nnUNetPlans'],
help='List of plan identifiers. Default: nnUNetPlans')
parser.add_argument('-c', nargs='+', required=False, default=['2d', '3d_fullres', '3d_lowres', '3d_cascade_fullres'],
help="List of configurations. Default: ['2d', '3d_fullres', '3d_lowres', '3d_cascade_fullres']")
parser.add_argument('-tr', nargs='+', required=False, default=['nnUNetTrainer'],
help='List of trainers. Default: nnUNetTrainer')
parser.add_argument('-np', required=False, default=default_num_processes, type=int,
help='Number of processes to use for ensembling, postprocessing etc')
parser.add_argument('-f', nargs='+', type=int, default=(0, 1, 2, 3, 4),
help='Folds to use. Default: 0 1 2 3 4')
parser.add_argument('--disable_ensembling', action='store_true', required=False,
help='Set this flag to disable ensembling')
parser.add_argument('--no_overwrite', action='store_true',
help='If set we will not overwrite already ensembled files etc. May speed up consecutive '
'runs of this command (why would you want to do that?) at the risk of not updating '
'outdated results.')
args = parser.parse_args()
model_dict = dumb_trainer_config_plans_to_trained_models_dict(args.tr, args.c, args.p)
dataset_name = maybe_convert_to_dataset_name(args.dataset_name_or_id)
find_best_configuration(dataset_name, model_dict, allow_ensembling=not args.disable_ensembling,
num_processes=args.np, overwrite=not args.no_overwrite, folds=args.f,
strict=False)
def accumulate_crossval_results_entry_point():
parser = argparse.ArgumentParser('Copies all predicted segmentations from the individual folds into one joint '
'folder and evaluates them')
parser.add_argument('dataset_name_or_id', type=str, help='Dataset Name or id')
parser.add_argument('-c', type=str, required=True,
default='3d_fullres',
help="Configuration")
parser.add_argument('-o', type=str, required=False, default=None,
help="Output folder. If not specified, the output folder will be located in the trained " \
"model directory (named crossval_results_folds_XXX).")
parser.add_argument('-f', nargs='+', type=int, default=(0, 1, 2, 3, 4),
help='Folds to use. Default: 0 1 2 3 4')
parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
help='Plan identifier in which to search for the specified configuration. Default: nnUNetPlans')
parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',
help='Trainer class. Default: nnUNetTrainer')
args = parser.parse_args()
trained_model_folder = get_output_folder(args.dataset_name_or_id, args.tr, args.p, args.c)
if args.o is None:
merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(args.f)}')
else:
merged_output_folder = args.o
if isdir(merged_output_folder) and len(listdir(merged_output_folder)) > 0:
raise FileExistsError(
f"Output folder {merged_output_folder} exists and is not empty. "
f"To avoid data loss, nnUNet requires an empty output folder."
)
accumulate_cv_results(trained_model_folder, merged_output_folder, args.f)
if __name__ == '__main__':
find_best_configuration(4,
default_trained_models,
True,
8,
False,
(0, 1, 2, 3, 4))
================================================
FILE: nnunetv2/experiment_planning/__init__.py
================================================
================================================
FILE: nnunetv2/experiment_planning/dataset_fingerprint/__init__.py
================================================
================================================
FILE: nnunetv2/experiment_planning/dataset_fingerprint/fingerprint_extractor.py
================================================
import multiprocessing
import os
from time import sleep
from typing import List, Type, Union
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p
from tqdm import tqdm
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json
from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
class DatasetFingerprintExtractor(object):
def __init__(self, dataset_name_or_id: Union[str, int], num_processes: int = 8, verbose: bool = False):
"""
extracts the dataset fingerprint used for experiment planning. The dataset fingerprint will be saved as a
json file (dataset_fingerprint.json) in the dataset's nnUNet_preprocessed folder, not in the input_folder
Philosophy here is to do only what we really need. Don't store stuff that we can easily read from somewhere
else. Don't compute stuff we don't need (except for intensity_statistics_per_channel)
"""
dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
self.verbose = verbose
self.show_progress_bar = True
self.dataset_name = dataset_name
self.input_folder = join(nnUNet_raw, dataset_name)
self.num_processes = num_processes
self.dataset_json = load_json(join(self.input_folder, 'dataset.json'))
self.dataset = get_filenames_of_train_images_and_targets(self.input_folder, self.dataset_json)
# We don't want to use all foreground voxels because that can accumulate a lot of data (out of memory). It is
# also not critically important to get all pixels as long as there are enough. Let's use 10e7 (= 1e8) voxels in total
# (for the entire dataset)
self.num_foreground_voxels_for_intensitystats = 10e7
@staticmethod
def collect_foreground_intensities(segmentation: np.ndarray, images: np.ndarray, seed: int = 1234,
num_samples: int = 10000):
"""
images: image with multiple channels, shape (c, x, y, z); segmentation: shape (1, x, y, z)
"""
assert images.ndim == 4 and segmentation.ndim == 4
assert not np.any(np.isnan(segmentation)), "Segmentation contains NaN values. grrrr.... :-("
assert not np.any(np.isnan(images)), "Images contain NaN values. grrrr.... :-("
rs = np.random.RandomState(seed)
intensities_per_channel = []
# we don't use the intensity_statistics_per_channel at all, it's just something that might be nice to have
intensity_statistics_per_channel = []
# segmentation is 4d: 1,x,y,z. We need to remove the empty dimension for the following code to work
foreground_mask = segmentation[0] > 0
percentiles = np.array((0.5, 50.0, 99.5))
for i in range(len(images)):
foreground_pixels = images[i][foreground_mask]
num_fg = len(foreground_pixels)
# sample with replacement so that we don't get issues with cases that have fewer than num_samples
# foreground_pixels. We could also just sample fewer in those cases, but that would then cause these
# training cases to be underrepresented
intensities_per_channel.append(
rs.choice(foreground_pixels, num_samples, replace=True) if num_fg > 0 else [])
mean, median, mini, maxi, percentile_99_5, percentile_00_5 = np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
if num_fg > 0:
percentile_00_5, median, percentile_99_5 = np.percentile(foreground_pixels, percentiles)
mean = np.mean(foreground_pixels)
mini = np.min(foreground_pixels)
maxi = np.max(foreground_pixels)
intensity_statistics_per_channel.append({
'mean': mean,
'median': median,
'min': mini,
'max': maxi,
'percentile_99_5': percentile_99_5,
'percentile_00_5': percentile_00_5,
})
return intensities_per_channel, intensity_statistics_per_channel
@staticmethod
def analyze_case(image_files: List[str], segmentation_file: str, reader_writer_class: Type[BaseReaderWriter],
num_samples: int = 10000):
rw = reader_writer_class()
images, properties_images = rw.read_images(image_files)
segmentation, properties_seg = rw.read_seg(segmentation_file)
# we no longer crop and save the cropped images before this is run. Instead we run the cropping on the fly.
# Downside is that we need to do this twice (once here and once during preprocessing). Upside is that we don't
# need to save the cropped data anymore. Given that cropping is not too expensive it makes sense to do it this
# way. This is only possible because we are now using our new input/output interface.
data_cropped, seg_cropped, bbox = crop_to_nonzero(images, segmentation)
foreground_intensities_per_channel, foreground_intensity_stats_per_channel = \
DatasetFingerprintExtractor.collect_foreground_intensities(seg_cropped, data_cropped,
num_samples=num_samples)
spacing = properties_images['spacing']
shape_before_crop = images.shape[1:]
shape_after_crop = data_cropped.shape[1:]
relative_size_after_cropping = np.prod(shape_after_crop) / np.prod(shape_before_crop)
return shape_after_crop, spacing, foreground_intensities_per_channel, foreground_intensity_stats_per_channel, \
relative_size_after_cropping
def run(self, overwrite_existing: bool = False) -> dict:
# we do not save the properties file in self.input_folder because that folder might be read-only. We can only
# reliably write in nnUNet_preprocessed and nnUNet_results, so nnUNet_preprocessed it is
preprocessed_output_folder = join(nnUNet_preprocessed, self.dataset_name)
maybe_mkdir_p(preprocessed_output_folder)
properties_file = join(preprocessed_output_folder, 'dataset_fingerprint.json')
if not isfile(properties_file) or overwrite_existing:
reader_writer_class = determine_reader_writer_from_dataset_json(self.dataset_json,
# use the first image of the first training case as example file
self.dataset[next(iter(self.dataset))]['images'][0])
# determine how many foreground voxels we need to sample per training case
num_foreground_samples_per_case = int(self.num_foreground_voxels_for_intensitystats //
len(self.dataset))
r = []
with multiprocessing.get_context("spawn").Pool(self.num_processes) as p:
for k in self.dataset.keys():
r.append(p.starmap_async(DatasetFingerprintExtractor.analyze_case,
((self.dataset[k]['images'], self.dataset[k]['label'], reader_writer_class,
num_foreground_samples_per_case),)))
remaining = list(range(len(self.dataset)))
# p is pretty nifty. If we kill workers they just respawn but don't do any work.
# So we need to store the original pool of workers.
workers = [j for j in p._pool]
with tqdm(desc="Extracting dataset fingerprint", total=len(self.dataset),
disable=not getattr(self, 'show_progress_bar', True)) as pbar:
while len(remaining) > 0:
all_alive = all([j.is_alive() for j in workers])
if not all_alive:
raise RuntimeError('Some background worker is 6 feet under. Yuck. \n'
'OK jokes aside.\n'
'One of your background processes is missing. This could be because of '
'an error (look for an error message) or because it was killed '
'by your OS due to running out of RAM. If you don\'t see '
'an error message, out of RAM is likely the problem. In that case '
'reducing the number of workers might help')
done = [i for i in remaining if r[i].ready()]
for _ in done:
pbar.update()
remaining = [i for i in remaining if i not in done]
sleep(0.1)
# results = ptqdm(DatasetFingerprintExtractor.analyze_case,
# (training_images_per_case, training_labels_per_case),
# processes=self.num_processes, zipped=True, reader_writer_class=reader_writer_class,
# num_samples=num_foreground_samples_per_case, disable=self.verbose)
results = [i.get()[0] for i in r]
shapes_after_crop = [r[0] for r in results]
spacings = [r[1] for r in results]
foreground_intensities_per_channel = [np.concatenate([r[2][i] for r in results]) for i in
range(len(results[0][2]))]
foreground_intensities_per_channel = np.array(foreground_intensities_per_channel)
# we drop this so that the json file is somewhat human readable
# foreground_intensity_stats_by_case_and_modality = [r[3] for r in results]
median_relative_size_after_cropping = np.median([r[4] for r in results], 0)
num_channels = len(self.dataset_json['channel_names'].keys()
if 'channel_names' in self.dataset_json.keys()
else self.dataset_json['modality'].keys())
intensity_statistics_per_channel = {}
percentiles = np.array((0.5, 50.0, 99.5))
for i in range(num_channels):
percentile_00_5, median, percentile_99_5 = np.percentile(foreground_intensities_per_channel[i],
percentiles)
intensity_statistics_per_channel[i] = {
'mean': float(np.mean(foreground_intensities_per_channel[i])),
'median': float(median),
'std': float(np.std(foreground_intensities_per_channel[i])),
'min': float(np.min(foreground_intensities_per_channel[i])),
'max': float(np.max(foreground_intensities_per_channel[i])),
'percentile_99_5': float(percentile_99_5),
'percentile_00_5': float(percentile_00_5),
}
fingerprint = {
"spacings": spacings,
"shapes_after_crop": shapes_after_crop,
'foreground_intensity_properties_per_channel': intensity_statistics_per_channel,
"median_relative_size_after_cropping": median_relative_size_after_cropping
}
try:
save_json(fingerprint, properties_file)
except Exception as e:
if isfile(properties_file):
os.remove(properties_file)
raise e
else:
fingerprint = load_json(properties_file)
return fingerprint
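# Sketch (hypothetical helper) of the foreground sampling behind the intensity statistics. A global budget
# (num_foreground_voxels_for_intensitystats = 10e7, i.e. 1e8 voxels) is split evenly across training cases.
# Each case then draws that many foreground intensities with replacement, so small cases contribute as many
# samples as large ones. For simplicity this assumes a single-channel image with the same shape as the
# segmentation. Seed 1234 mirrors the default of collect_foreground_intensities.
import numpy as np


def sketch_sample_foreground(image: np.ndarray, seg: np.ndarray, total_budget: float, num_cases: int,
                             seed: int = 1234) -> np.ndarray:
    num_samples = int(total_budget // num_cases)
    fg = image[seg > 0]
    if fg.size == 0:
        # no foreground: this case contributes nothing
        return np.empty(0, dtype=image.dtype)
    rs = np.random.RandomState(seed)
    return rs.choice(fg, num_samples, replace=True)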
if __name__ == '__main__':
dfe = DatasetFingerprintExtractor(2, 8)
dfe.run(overwrite_existing=False)
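# Sketch (hypothetical helper) of the per-channel reduction at the end of DatasetFingerprintExtractor.run.
# The pooled foreground samples of one channel become mean, median, std, min, max and the 0.5 / 99.5
# percentiles (np.percentile with its default linear interpolation). Each value is stored as a plain Python
# float for the fingerprint json.
import numpy as np


def sketch_channel_stats(samples: np.ndarray) -> dict:
    p005, median, p995 = np.percentile(samples, (0.5, 50.0, 99.5))
    return {
        'mean': float(np.mean(samples)),
        'median': float(median),
        'std': float(np.std(samples)),
        'min': float(np.min(samples)),
        'max': float(np.max(samples)),
        'percentile_99_5': float(p995),
        'percentile_00_5': float(p005),
    }

# usage: for samples 0, 1, ..., 1000 the 0.5 / 50 / 99.5 percentiles are 5, 500 and 995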
================================================
FILE: nnunetv2/experiment_planning/experiment_planners/__init__.py
================================================
================================================
FILE: nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py
================================================
import shutil
from copy import deepcopy
from typing import List, Union, Tuple
import numpy as np
import torch
from batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p
from dynamic_network_architectures.architectures.unet import PlainConvUNet
from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm
from nnunetv2.configuration import ANISO_THRESHOLD
from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props
from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json
from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
from nnunetv2.preprocessing.normalization.map_channel_name_to_normalization import get_normalization_scheme
from nnunetv2.preprocessing.resampling.default_resampling import resample_data_or_seg_to_shape, compute_new_shape
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
class ExperimentPlanner(object):
def __init__(self, dataset_name_or_id: Union[str, int],
gpu_memory_target_in_gb: float = 8,
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetPlans',
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
suppress_transpose: bool = False):
"""
overwrite_target_spacing only affects 3d_fullres! (By extension, 3d_lowres, whose spacing search starts from the
fullres spacing, may also be affected.)
"""
self.dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
self.suppress_transpose = suppress_transpose
self.raw_dataset_folder = join(nnUNet_raw, self.dataset_name)
preprocessed_folder = join(nnUNet_preprocessed, self.dataset_name)
self.dataset_json = load_json(join(self.raw_dataset_folder, 'dataset.json'))
self.dataset = get_filenames_of_train_images_and_targets(self.raw_dataset_folder, self.dataset_json)
# load dataset fingerprint
if not isfile(join(preprocessed_folder, 'dataset_fingerprint.json')):
raise RuntimeError('Fingerprint missing for this dataset. Please run nnUNetv2_extract_fingerprint')
self.dataset_fingerprint = load_json(join(preprocessed_folder, 'dataset_fingerprint.json'))
self.anisotropy_threshold = ANISO_THRESHOLD
self.UNet_base_num_features = 32
self.UNet_class = PlainConvUNet
# the following two numbers are really arbitrary and were set to reproduce nnU-Net v1's configurations as
# much as possible
self.UNet_reference_val_3d = 560000000 # 455600128 550000000
self.UNet_reference_val_2d = 85000000 # 83252480
self.UNet_reference_com_nfeatures = 32
self.UNet_reference_val_corresp_GB = 8
self.UNet_reference_val_corresp_bs_2d = 12
self.UNet_reference_val_corresp_bs_3d = 2
self.UNet_featuremap_min_edge_length = 4
self.UNet_blocks_per_stage_encoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
self.UNet_blocks_per_stage_decoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
self.UNet_min_batch_size = 2
self.UNet_max_features_2d = 512
self.UNet_max_features_3d = 320
self.max_dataset_covered = 0.05 # we limit the batch size so that no more than 5% of the dataset can be seen
# in a single forward/backward pass
self.UNet_vram_target_GB = gpu_memory_target_in_gb
self.lowres_creation_threshold = 0.25 # if the patch size of fullres is less than 25% of the voxels in the
# median shape then we need a lowres config as well
self.preprocessor_name = preprocessor_name
self.plans_identifier = plans_name
self.overwrite_target_spacing = overwrite_target_spacing
assert overwrite_target_spacing is None or len(overwrite_target_spacing) == 3, 'if overwrite_target_spacing is ' \
'used then three floats must be ' \
'given (as list or tuple)'
assert overwrite_target_spacing is None or all([isinstance(i, float) for i in overwrite_target_spacing]), \
'if overwrite_target_spacing is used then three floats must be given (as list or tuple)'
self.plans = None
if isfile(join(self.raw_dataset_folder, 'splits_final.json')):
_maybe_copy_splits_file(join(self.raw_dataset_folder, 'splits_final.json'),
join(preprocessed_folder, 'splits_final.json'))
def determine_reader_writer(self):
example_image = self.dataset[next(iter(self.dataset.keys()))]['images'][0]
return determine_reader_writer_from_dataset_json(self.dataset_json, example_image)
@staticmethod
def static_estimate_VRAM_usage(patch_size: Tuple[int, ...],
input_channels: int,
output_channels: int,
arch_class_name: str,
arch_kwargs: dict,
arch_kwargs_req_import: Tuple[str, ...]):
"""
Works for PlainConvUNet, ResidualEncoderUNet
"""
a = torch.get_num_threads()
torch.set_num_threads(get_allowed_n_proc_DA())
# print(f'instantiating network, patch size {patch_size}, pool op: {arch_kwargs["strides"]}')
net = get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels,
output_channels,
allow_init=False)
ret = net.compute_conv_feature_map_size(patch_size)
torch.set_num_threads(a)
return ret
def determine_resampling(self, *args, **kwargs):
"""
returns the functions (and their kwargs) to use for resampling data and segmentations, respectively.
A resampling function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs).
determine_resampling is called within get_plans_for_configuration so that each configuration can use different
functions.
"""
resampling_data = resample_data_or_seg_to_shape
resampling_data_kwargs = {
"is_seg": False,
"order": 3,
"order_z": 0,
"force_separate_z": None,
}
resampling_seg = resample_data_or_seg_to_shape
resampling_seg_kwargs = {
"is_seg": True,
"order": 1,
"order_z": 0,
"force_separate_z": None,
}
return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs
def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
"""
function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). new_shape is the target shape;
current_spacing and new_spacing are only passed along in case an implementation needs them.
determine_segmentation_softmax_export_fn is called within get_plans_for_configuration so that each configuration
can use a different function.
"""
resampling_fn = resample_data_or_seg_to_shape
resampling_fn_kwargs = {
"is_seg": False,
"order": 1,
"order_z": 0,
"force_separate_z": None,
}
return resampling_fn, resampling_fn_kwargs
def determine_fullres_target_spacing(self) -> np.ndarray:
"""
By default we use the 50th percentile (median) of the spacings as target spacing. A larger spacing results in
smaller images and thus faster, easier training; a smaller spacing results in larger images and thus longer,
harder training.
For some datasets the median is not a good choice. These are datasets with very anisotropic spacing (for example
ACDC with (10, 1.5, 1.5)) that still contain cases with a spacing of only 5 or 6 mm along the low-resolution axis.
Choosing the median here results in bad interpolation artifacts that can substantially impact performance (due to
the low number of slices). For such datasets we use the 10th percentile of the spacings along the low-resolution
axis instead.
"""
if self.overwrite_target_spacing is not None:
return np.array(self.overwrite_target_spacing)
spacings = np.vstack(self.dataset_fingerprint['spacings'])
sizes = self.dataset_fingerprint['shapes_after_crop']
target = np.percentile(spacings, 50, 0)
# todo sizes_after_resampling = [compute_new_shape(j, i, target) for i, j in zip(spacings, sizes)]
target_size = np.percentile(np.vstack(sizes), 50, 0)
# we need to identify datasets for which a different target spacing could be beneficial. These datasets have
# the following properties:
# - one axis with much lower resolution than the others
# - the lowres axis has far fewer voxels than the others
# - (the size in mm of the lowres axis is also reduced)
worst_spacing_axis = np.argmax(target)
other_axes = [i for i in range(len(target)) if i != worst_spacing_axis]
other_spacings = [target[i] for i in other_axes]
other_sizes = [target_size[i] for i in other_axes]
has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * max(other_spacings))
has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes)
if has_aniso_spacing and has_aniso_voxels:
spacings_of_that_axis = spacings[:, worst_spacing_axis]
target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10)
# don't let the spacing of that axis drop below that of the other axes: it must remain the largest spacing
if target_spacing_of_that_axis < max(other_spacings):
target_spacing_of_that_axis = max(max(other_spacings), target_spacing_of_that_axis) + 1e-5
target[worst_spacing_axis] = target_spacing_of_that_axis
return target
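The anisotropy heuristic in `determine_fullres_target_spacing` can be tried out in isolation. Below is a minimal standalone sketch that reimplements it as a free function `fullres_target_spacing` (a hypothetical name). It assumes `aniso_threshold=3.0` as a stand-in for `nnunetv2.configuration.ANISO_THRESHOLD`; check the value in your installed version. The example spacings and shapes are made up, loosely resembling the ACDC case from the docstring.

```python
import numpy as np


def fullres_target_spacing(spacings, shapes, aniso_threshold=3.0):
    """Sketch of determine_fullres_target_spacing, ignoring overwrite_target_spacing."""
    spacings = np.vstack(spacings).astype(np.float64)
    target = np.percentile(spacings, 50, axis=0)  # median spacing per axis
    target_size = np.percentile(np.vstack(shapes), 50, axis=0)  # median shape per axis
    worst_axis = int(np.argmax(target))
    other_axes = [i for i in range(len(target)) if i != worst_axis]
    max_other_spacing = max(target[i] for i in other_axes)
    min_other_size = min(target_size[i] for i in other_axes)
    has_aniso_spacing = target[worst_axis] > aniso_threshold * max_other_spacing
    has_aniso_voxels = target_size[worst_axis] * aniso_threshold < min_other_size
    if has_aniso_spacing and has_aniso_voxels:
        # use the 10th percentile instead of the median along the low-resolution axis ...
        spacing_of_axis = np.percentile(spacings[:, worst_axis], 10)
        # ... but keep it the largest spacing of all axes
        if spacing_of_axis < max_other_spacing:
            spacing_of_axis = max_other_spacing + 1e-5
        target[worst_axis] = spacing_of_axis
    return target


acdc_like_spacings = [(10.0, 1.5, 1.5), (5.0, 1.5, 1.5), (10.0, 1.25, 1.25), (10.0, 1.5, 1.5)]
acdc_like_shapes = [(10, 256, 216), (18, 256, 256), (10, 232, 256), (9, 256, 216)]
# the 10th percentile (6.5) replaces the median (10) on axis 0
print(fullres_target_spacing(acdc_like_spacings, acdc_like_shapes))
```

Both conditions must hold: a dataset with coarse slices but many of them keeps the plain median.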
def determine_normalization_scheme_and_whether_mask_is_used_for_norm(self) -> Tuple[List[str], List[bool]]:
if 'channel_names' not in self.dataset_json.keys():
print('WARNING: "modality" should be renamed to "channel_names" in dataset.json. This will be '
'enforced soon!')
modalities = self.dataset_json['channel_names'] if 'channel_names' in self.dataset_json.keys() else \
self.dataset_json['modality']
normalization_schemes = [get_normalization_scheme(m) for m in modalities.values()]
if self.dataset_fingerprint['median_relative_size_after_cropping'] < (3 / 4.):
use_nonzero_mask_for_norm = [i.leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true for i in
normalization_schemes]
else:
use_nonzero_mask_for_norm = [False] * len(normalization_schemes)
assert all([i in (True, False) for i in use_nonzero_mask_for_norm]), 'use_nonzero_mask_for_norm must be ' \
'True or False and cannot be None'
normalization_schemes = [i.__name__ for i in normalization_schemes]
return normalization_schemes, use_nonzero_mask_for_norm
def determine_transpose(self):
if self.suppress_transpose:
return [0, 1, 2], [0, 1, 2]
# todo we should use shapes for that as well. Not quite sure how yet
target_spacing = self.determine_fullres_target_spacing()
max_spacing_axis = np.argmax(target_spacing)
remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis]
transpose_forward = [max_spacing_axis] + remaining_axes
transpose_backward = [np.argwhere(np.array(transpose_forward) == i)[0][0] for i in range(3)]
return transpose_forward, transpose_backward
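`determine_transpose` moves the axis with the largest target spacing to the front, and `transpose_backward` is the inverse permutation. A minimal sketch with the hypothetical helper `transpose_axes`; it uses `np.argsort` to invert the permutation, which gives the same result as the `np.argwhere` lookup in the planner:

```python
import numpy as np


def transpose_axes(target_spacing):
    """Sketch of determine_transpose: move the axis with the largest spacing to the front."""
    max_spacing_axis = int(np.argmax(target_spacing))
    transpose_forward = [max_spacing_axis] + [i for i in range(3) if i != max_spacing_axis]
    # argsort of a permutation yields the position of each axis in transpose_forward, i.e. its inverse
    transpose_backward = [int(i) for i in np.argsort(transpose_forward)]
    return transpose_forward, transpose_backward


forward, backward = transpose_axes([1.5, 1.5, 10.0])
print(forward, backward)  # [2, 0, 1] [1, 2, 0]
image = np.zeros((320, 300, 20))
print(image.transpose(forward).shape)  # (20, 320, 300)
print(image.transpose(forward).transpose(backward).shape)  # (320, 300, 20)
```

Putting the low-resolution axis first is what lets the 2d configuration slice along it (note `fullres_spacing_transposed[1:]` in `plan_experiment`).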
def get_plans_for_configuration(self,
spacing: Union[np.ndarray, Tuple[float, ...], List[float]],
median_shape: Union[np.ndarray, Tuple[int, ...]],
data_identifier: str,
approximate_n_voxels_dataset: float,
_cache: dict) -> dict:
def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]:
return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for
i in range(num_stages)])
def _keygen(patch_size, strides):
return str(patch_size) + '_' + str(strides)
assert all([i > 0 for i in spacing]), f"Spacing must be > 0! Spacing: {spacing}"
num_input_channels = len(self.dataset_json['channel_names'].keys()
if 'channel_names' in self.dataset_json.keys()
else self.dataset_json['modality'].keys())
max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d
unet_conv_op = convert_dim_to_conv_op(len(spacing))
# print(spacing, median_shape, approximate_n_voxels_dataset)
# find an initial patch size
# we first use the spacing to get an aspect ratio
tmp = 1 / np.array(spacing)
# we then upscale it so that it is initially certainly larger than what we need (rescale it to the same number of
# voxels as a patch of 256 ** 3 in 3d or 2048 ** 2 in 2d)
# this may need to be adapted for very large GPU memory targets. We don't increase it by default because larger
# initial patch sizes can require more iterations of the while loop further down, which costs computation time.
if len(spacing) == 3:
initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)]
elif len(spacing) == 2:
initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)]
else:
raise RuntimeError(f'Only 2d and 3d spacings are supported, got {len(spacing)} values: {spacing}')
# clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that
# this is different from how nnU-Net v1 does it!
# todo patch size can still get too large because we pad the patch size to a multiple of 2**n
initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])])
# use that to get the network topology. Note that this changes the patch_size depending on the number of
# pooling operations (must be divisible by 2**num_pool in each axis)
network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \
shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size,
self.UNet_featuremap_min_edge_length,
999999)
num_stages = len(pool_op_kernel_sizes)
norm = get_matching_instancenorm(unet_conv_op)
architecture_kwargs = {
'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__,
'arch_kwargs': {
'n_stages': num_stages,
'features_per_stage': _features_per_stage(num_stages, max_num_features),
'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__,
'kernel_sizes': conv_kernel_sizes,
'strides': pool_op_kernel_sizes,
'n_conv_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages],
'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1],
'conv_bias': True,
'norm_op': norm.__module__ + '.' + norm.__name__,
'norm_op_kwargs': {'eps': 1e-5, 'affine': True},
'dropout_op': None,
'dropout_op_kwargs': None,
'nonlin': 'torch.nn.LeakyReLU',
'nonlin_kwargs': {'inplace': True},
},
'_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'),
}
# now estimate vram consumption
if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys():
estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)]
else:
estimate = self.static_estimate_VRAM_usage(patch_size,
num_input_channels,
len(self.dataset_json['labels'].keys()),
architecture_kwargs['network_class_name'],
architecture_kwargs['arch_kwargs'],
architecture_kwargs['_kw_requires_import'],
)
_cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate
# how large is the reference for us here (batch size etc)?
# adapt for our vram target
reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \
(self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB)
ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d
# we enforce a batch size of at least two, but the reference values may have been computed for a different batch
# size (ref_bs). The while condition corrects for that: it only accepts a patch if reference / estimate * ref_bs >= 2
while (estimate / ref_bs * 2) > reference:
# print(patch_size, estimate, reference)
# patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the
# aspect ratio the most (that is the largest relative to median shape)
axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1]
# we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this
# may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256.
# If we subtracted that we would end up with 192, skipping 224, which is also a valid patch size
# (224 / 2**5 = 7 and 7 < 2 * self.UNet_featuremap_min_edge_length (= 8), so that axis is pooled only 5 times
# and 224 only needs to be divisible by 2**5 = 32). So we first subtract shape_must_be_divisible_by, then
# recompute it for the smaller shape and subtract the recomputed shape_must_be_divisible_by instead. Annoying.
patch_size = list(patch_size)
tmp = deepcopy(patch_size)
tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
_, _, _, _, shape_must_be_divisible_by = \
get_pool_and_conv_props(spacing, tmp,
self.UNet_featuremap_min_edge_length,
999999)
patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
# now recompute topology
network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \
shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size,
self.UNet_featuremap_min_edge_length,
999999)
num_stages = len(pool_op_kernel_sizes)
architecture_kwargs['arch_kwargs'].update({
'n_stages': num_stages,
'kernel_sizes': conv_kernel_sizes,
'strides': pool_op_kernel_sizes,
'features_per_stage': _features_per_stage(num_stages, max_num_features),
'n_conv_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages],
'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1],
})
if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys():
estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)]
else:
estimate = self.static_estimate_VRAM_usage(
patch_size,
num_input_channels,
len(self.dataset_json['labels'].keys()),
architecture_kwargs['network_class_name'],
architecture_kwargs['arch_kwargs'],
architecture_kwargs['_kw_requires_import'],
)
_cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate
# alright now let's determine the batch size. This will give (approximately) self.UNet_min_batch_size if the while
# loop was executed. If not, additional vram headroom is used to increase batch size
batch_size = round((reference / estimate) * ref_bs)
# we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. We cannot
# go smaller than self.UNet_min_batch_size though
bs_corresponding_to_5_percent = round(
approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64))
batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size)
resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling()
resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn()
normalization_schemes, mask_is_used_for_norm = \
self.determine_normalization_scheme_and_whether_mask_is_used_for_norm()
plan = {
'data_identifier': data_identifier,
'preprocessor_name': self.preprocessor_name,
'batch_size': batch_size,
'patch_size': patch_size,
'median_image_size_in_voxels': median_shape,
'spacing': spacing,
'normalization_schemes': normalization_schemes,
'use_mask_for_norm': mask_is_used_for_norm,
'resampling_fn_data': resampling_data.__name__,
'resampling_fn_seg': resampling_seg.__name__,
'resampling_fn_data_kwargs': resampling_data_kwargs,
'resampling_fn_seg_kwargs': resampling_seg_kwargs,
'resampling_fn_probabilities': resampling_softmax.__name__,
'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs,
'architecture': architecture_kwargs
}
return plan
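Two pieces of `get_plans_for_configuration` are plain arithmetic and can be checked without instantiating a network: the initial patch size (aspect ratio from the inverse spacing, volume of 256**3 or 2048**2, clipped to the median shape) and the final batch size rule. A minimal sketch; `initial_patch_size` and `batch_size` are hypothetical helper names, and `estimate`/`reference` are made-up numbers standing in for the VRAM estimate and the scaled reference value.

```python
import numpy as np


def initial_patch_size(spacing, median_shape):
    """Sketch of the initial patch size computed in get_plans_for_configuration."""
    tmp = 1 / np.array(spacing, dtype=np.float64)  # aspect ratio follows the inverse spacing
    if len(spacing) == 3:
        size = [int(round(i)) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)]
    elif len(spacing) == 2:
        size = [int(round(i)) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)]
    else:
        raise ValueError(f'expected 2 or 3 spacing values, got {len(spacing)}')
    # never start larger than the median image shape
    return np.array([min(i, j) for i, j in zip(size, median_shape[:len(spacing)])])


def batch_size(estimate, reference, ref_bs, n_voxels_dataset, patch_size, max_dataset_covered=0.05, min_bs=2):
    """Sketch of the batch size rule: use spare VRAM, but cover at most 5% of the dataset per step."""
    bs = round((reference / estimate) * ref_bs)
    bs_cap = round(n_voxels_dataset * max_dataset_covered / np.prod(patch_size, dtype=np.float64))
    return max(min(bs, bs_cap), min_bs)


print(initial_patch_size((1.0, 1.0, 1.0), (200, 300, 300)))  # [200 256 256]
print(initial_patch_size((3.0, 1.0, 1.0), (1000, 1000, 1000)))  # [123 369 369]
print(batch_size(1.0, 2.0, 2, 100 * 128 ** 3, (128, 128, 128)))  # 4
print(batch_size(1.0, 2.0, 2, 10 * 128 ** 3, (128, 128, 128)))  # 2
```

In the second batch size call the 5% cap would allow less than one patch, so the minimum batch size of 2 takes over.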
def plan_experiment(self):
"""
MOVE EVERYTHING INTO THE PLANS. MAXIMUM FLEXIBILITY
Ideally I would like to move transpose_forward/backward into the configurations so that this can also be done
differently for each configuration but this would cause problems with identifying the correct axes for 2d. There
surely is a way around that but eh. I'm feeling lazy and featuritis must also not be pushed to the extremes.
So for now if you want a different transpose_forward/backward you need to create a new planner. Also not too
hard.
"""
# we use this as a cache to prevent having to instantiate the architecture too often. Saves computation time
_tmp = {}
# first get transpose
transpose_forward, transpose_backward = self.determine_transpose()
# get fullres spacing and transpose it
fullres_spacing = self.determine_fullres_target_spacing()
fullres_spacing_transposed = fullres_spacing[transpose_forward]
# get transposed new median shape (what we would have after resampling)
new_shapes = [compute_new_shape(j, i, fullres_spacing) for i, j in
zip(self.dataset_fingerprint['spacings'], self.dataset_fingerprint['shapes_after_crop'])]
new_median_shape = np.median(new_shapes, 0)
new_median_shape_transposed = new_median_shape[transpose_forward]
approximate_n_voxels_dataset = float(np.prod(new_median_shape_transposed, dtype=np.float64) *
self.dataset_json['numTraining'])
# only run 3d if this is a 3d dataset
if new_median_shape_transposed[0] != 1:
plan_3d_fullres = self.get_plans_for_configuration(fullres_spacing_transposed,
new_median_shape_transposed,
self.generate_data_identifier('3d_fullres'),
approximate_n_voxels_dataset, _tmp)
# maybe add 3d_lowres as well
patch_size_fullres = plan_3d_fullres['patch_size']
median_num_voxels = np.prod(new_median_shape_transposed, dtype=np.float64)
num_voxels_in_patch = np.prod(patch_size_fullres, dtype=np.float64)
plan_3d_lowres = None
lowres_spacing = deepcopy(plan_3d_fullres['spacing'])
spacing_increase_factor = 1.03 # used to be 1.01 but that is slow with new GPU memory estimation!
while num_voxels_in_patch / median_num_voxels < self.lowres_creation_threshold:
# we incrementally increase the target spacing. Axes whose spacing is more than a factor of 2 smaller than the
# largest spacing are increased first; once all axes are within a factor of 2 of each other, all are increased.
max_spacing = max(lowres_spacing)
if np.any((max_spacing / lowres_spacing) > 2):
lowres_spacing[(max_spacing / lowres_spacing) > 2] *= spacing_increase_factor
else:
lowres_spacing *= spacing_increase_factor
median_num_voxels = np.prod(plan_3d_fullres['spacing'] / lowres_spacing * new_median_shape_transposed,
dtype=np.float64)
# print(lowres_spacing)
plan_3d_lowres = self.get_plans_for_configuration(lowres_spacing,
tuple([round(i) for i in plan_3d_fullres['spacing'] /
lowres_spacing * new_median_shape_transposed]),
self.generate_data_identifier('3d_lowres'),
float(np.prod(median_num_voxels) *
self.dataset_json['numTraining']), _tmp)
num_voxels_in_patch = np.prod(plan_3d_lowres['patch_size'], dtype=np.int64)
print(f'Attempting to find 3d_lowres config. '
f'\nCurrent spacing: {lowres_spacing}. '
f'\nCurrent patch size: {plan_3d_lowres["patch_size"]}. '
f'\nCurrent median shape: {plan_3d_fullres["spacing"] / lowres_spacing * new_median_shape_transposed}')
if np.prod(new_median_shape_transposed, dtype=np.float64) / median_num_voxels < 2:
print(f'Dropping 3d_lowres config because the image size difference to 3d_fullres is too small. '
f'3d_fullres: {new_median_shape_transposed}, '
f'3d_lowres: {[round(i) for i in plan_3d_fullres["spacing"] / lowres_spacing * new_median_shape_transposed]}')
plan_3d_lowres = None
if plan_3d_lowres is not None:
plan_3d_lowres['batch_dice'] = False
plan_3d_fullres['batch_dice'] = True
else:
plan_3d_fullres['batch_dice'] = False
else:
plan_3d_fullres = None
plan_3d_lowres = None
# 2D configuration
plan_2d = self.get_plans_for_configuration(fullres_spacing_transposed[1:],
new_median_shape_transposed[1:],
self.generate_data_identifier('2d'), approximate_n_voxels_dataset,
_tmp)
plan_2d['batch_dice'] = True
print('2D U-Net configuration:')
print(plan_2d)
print()
# median spacing and shape, just for reference when printing the plans
median_spacing = np.median(self.dataset_fingerprint['spacings'], 0)[transpose_forward]
median_shape = np.median(self.dataset_fingerprint['shapes_after_crop'], 0)[transpose_forward]
# instead of writing all that into the plans we just copy the original file. More files, but less crowded
# per file.
shutil.copy(join(self.raw_dataset_folder, 'dataset.json'),
join(nnUNet_preprocessed, self.dataset_name, 'dataset.json'))
# json cannot serialize numpy types ("Object of type int64 is not JSON serializable"), so we cast to Python types
plans = {
'dataset_name': self.dataset_name,
'plans_name': self.plans_identifier,
'original_median_spacing_after_transp': [float(i) for i in median_spacing],
'original_median_shape_after_transp': [int(round(i)) for i in median_shape],
'image_reader_writer': self.determine_reader_writer().__name__,
'transpose_forward': [int(i) for i in transpose_forward],
'transpose_backward': [int(i) for i in transpose_backward],
'configurations': {'2d': plan_2d},
'experiment_planner_used': self.__class__.__name__,
'label_manager': 'LabelManager',
'foreground_intensity_properties_per_channel': self.dataset_fingerprint[
'foreground_intensity_properties_per_channel']
}
if plan_3d_lowres is not None:
plans['configurations']['3d_lowres'] = plan_3d_lowres
if plan_3d_fullres is not None:
plans['configurations']['3d_lowres']['next_stage'] = '3d_cascade_fullres'
print('3D lowres U-Net configuration:')
print(plan_3d_lowres)
print()
if plan_3d_fullres is not None:
plans['configurations']['3d_fullres'] = plan_3d_fullres
print('3D fullres U-Net configuration:')
print(plan_3d_fullres)
print()
if plan_3d_lowres is not None:
plans['configurations']['3d_cascade_fullres'] = {
'inherits_from': '3d_fullres',
'previous_stage': '3d_lowres'
}
self.plans = plans
self.save_plans(plans)
return plans
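The 3d_lowres search in `plan_experiment` repeatedly applies one spacing update and re-plans until the patch covers enough of the (shrinking) median image. A minimal sketch of just that update step, with the hypothetical helper `increase_lowres_spacing`; the network/VRAM part of the loop is left out, and the spacings and shapes are made-up example values.

```python
import numpy as np


def increase_lowres_spacing(spacing, factor=1.03):
    """One step of the 3d_lowres spacing search in plan_experiment."""
    spacing = np.array(spacing, dtype=np.float64)
    ratio = spacing.max() / spacing
    if np.any(ratio > 2):
        # only coarsen the axes that are more than 2x finer than the coarsest axis
        spacing[ratio > 2] *= factor
    else:
        # all axes are within a factor of 2 of each other: coarsen all of them
        spacing *= factor
    return spacing


fullres_spacing = np.array([3.0, 0.8, 0.8])
median_shape = np.array([100, 512, 512])
lowres_spacing = increase_lowres_spacing(fullres_spacing)
print(lowres_spacing)  # axis 0 unchanged, axes 1 and 2 multiplied by 1.03
# the median shape the lowres configuration would see after resampling
print(fullres_spacing / lowres_spacing * median_shape)
```

Starting with the fine axes makes the lowres data more isotropic before it is downsampled uniformly.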
def save_plans(self, plans):
recursive_fix_for_json_export(plans)
plans_file = join(nnUNet_preprocessed, self.dataset_name, self.plans_identifier + '.json')
# we don't want to overwrite potentially existing custom configurations every time this is executed. So let's
# read the plans file if it already exists and keep any non-default configurations
if isfile(plans_file):
old_plans = load_json(plans_file)
old_configurations = old_plans['configurations']
for c in plans['configurations'].keys():
if c in old_configurations.keys():
del old_configurations[c]
plans['configurations'].update(old_configurations)
maybe_mkdir_p(join(nnUNet_preprocessed, self.dataset_name))
save_json(plans, plans_file, sort_keys=False)
print(f"Plans were saved to {join(nnUNet_preprocessed, self.dataset_name, self.plans_identifier + '.json')}")
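`save_plans` keeps configurations from an existing plans file unless they were just re-planned under the same name. A minimal sketch of that merge as a pure function (`merge_configurations` is a hypothetical name; unlike `save_plans`, it returns a new dict instead of updating `plans` in place). The custom configuration name `3d_fullres_bs4` is made up for illustration.

```python
def merge_configurations(new_configurations, old_configurations):
    """Sketch of how save_plans preserves custom configurations from an existing plans file."""
    # freshly planned configurations replace old ones with the same name ...
    kept = {name: cfg for name, cfg in old_configurations.items() if name not in new_configurations}
    merged = dict(new_configurations)
    # ... while configurations that only exist in the old file (e.g. hand-written ones) are preserved
    merged.update(kept)
    return merged


new = {'2d': {'batch_size': 12}, '3d_fullres': {'batch_size': 2}}
old = {'2d': {'batch_size': 40}, '3d_fullres_bs4': {'inherits_from': '3d_fullres', 'batch_size': 4}}
merged = merge_configurations(new, old)
print(merged)
```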
def generate_data_identifier(self, configuration_name: str) -> str:
"""
configurations are unique within each plans file but different plans files can have configurations with the
same name. In order to distinguish the associated data we need a data identifier that reflects not just the
config but also the plans it originates from
"""
return self.plans_identifier + '_' + configuration_name
def load_plans(self, fname: str):
self.plans = load_json(fname)
def _maybe_copy_splits_file(splits_file: str, target_fname: str):
if not isfile(target_fname):
shutil.copy(splits_file, target_fname)
else:
# split already exists, do not copy, but check that the splits match.
# This code allows target_fname to contain more splits than splits_file. This is OK.
splits_source = load_json(splits_file)
splits_target = load_json(target_fname)
# all folds in the source file must match the target file
for i in range(len(splits_source)):
train_source = set(splits_source[i]['train'])
train_target = set(splits_target[i]['train'])
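assert train_target == train_source, f'Training cases of fold {i} differ between {splits_file} and {target_fname}'
val_source = set(splits_source[i]['val'])
val_target = set(splits_target[i]['val'])
assert val_source == val_target, f'Validation cases of fold {i} differ between {splits_file} and {target_fname}'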
if __name__ == '__main__':
ExperimentPlanner(2, 8).plan_experiment()
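The consistency check in `_maybe_copy_splits_file` compares train/val case sets fold by fold and tolerates extra folds in the target. A minimal sketch that returns a bool instead of asserting (`splits_are_consistent` is a hypothetical name). One deliberate difference: the original raises an IndexError when the target has fewer folds than the source, while this sketch simply returns False.

```python
def splits_are_consistent(splits_source, splits_target):
    """Every fold in splits_source must have identical train/val sets in splits_target.

    splits_target may contain additional folds.
    """
    if len(splits_target) < len(splits_source):
        return False
    for src, tgt in zip(splits_source, splits_target):
        # order of case identifiers does not matter, only set membership
        if set(src['train']) != set(tgt['train']) or set(src['val']) != set(tgt['val']):
            return False
    return True


source = [{'train': ['a', 'b'], 'val': ['c']}]
target_ok = [{'train': ['b', 'a'], 'val': ['c']}, {'train': ['a', 'c'], 'val': ['b']}]
target_bad = [{'train': ['a', 'c'], 'val': ['b']}]
print(splits_are_consistent(source, target_ok), splits_are_consistent(source, target_bad))  # True False
```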
================================================
FILE: nnunetv2/experiment_planning/experiment_planners/network_topology.py
================================================
from copy import deepcopy
import numpy as np
def get_shape_must_be_divisible_by(net_numpool_per_axis):
return 2 ** np.array(net_numpool_per_axis)
def pad_shape(shape, must_be_divisible_by):
"""
pads shape so that it is divisible by must_be_divisible_by
:param shape:
:param must_be_divisible_by:
:return:
"""
if not isinstance(must_be_divisible_by, (tuple, list, np.ndarray)):
must_be_divisible_by = [must_be_divisible_by] * len(shape)
else:
assert len(must_be_divisible_by) == len(shape)
new_shp = [shape[i] + must_be_divisible_by[i] - shape[i] % must_be_divisible_by[i] for i in range(len(shape))]
for i in range(len(shape)):
if shape[i] % must_be_divisible_by[i] == 0:
new_shp[i] -= must_be_divisible_by[i]
new_shp = np.array(new_shp).astype(int)
return new_shp
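`pad_shape` rounds each axis up to the next multiple of its divisor, and `get_shape_must_be_divisible_by` makes that divisor `2 ** num_pool` per axis. The same result can be written with integer ceiling division; a minimal sketch using the hypothetical helper `pad_to_multiple`:

```python
import numpy as np


def pad_to_multiple(shape, must_be_divisible_by):
    """Round every axis of shape up to the next multiple of must_be_divisible_by (scalar or per axis)."""
    shape = np.asarray(shape, dtype=int)
    divisor = np.broadcast_to(np.asarray(must_be_divisible_by, dtype=int), shape.shape)
    # -(-a // b) is integer ceiling division
    return -(-shape // divisor) * divisor


num_pool_per_axis = [4, 5, 6]
must_be_divisible_by = 2 ** np.array(num_pool_per_axis)  # same as get_shape_must_be_divisible_by
print(pad_to_multiple((100, 130, 64), must_be_divisible_by).tolist())  # [112, 160, 64]
```

Shapes that are already divisible stay unchanged, which is what the `% ... == 0` correction in `pad_shape` handles.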
def get_pool_and_conv_props(spacing, patch_size, min_feature_map_size, max_numpool):
"""
this is the same as get_pool_and_conv_props_v2 from old nnunet
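:param spacing: voxel spacing of the patch, one value per axis
:param patch_size: initial patch size (will be padded so that it is divisible by 2 ** num_pool_per_axis)
:param min_feature_map_size: min edge length of feature maps in bottleneck
:param max_numpool: maximum number of pooling operations per axis
:return: num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, padded patch_size, must_be_divisible_by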
"""
# todo review this code
dim = len(spacing)
current_spacing = deepcopy(list(spacing))
current_size = deepcopy(list(patch_size))
pool_op_kernel_sizes = [[1] * len(spacing)]
conv_kernel_sizes = []
num_pool_per_axis = [0] * dim
kernel_size = [1] * dim
while True:
# exclude axes that we cannot pool further because of min_feature_map_size constraint
valid_axes_for_pool = [i for i in range(dim) if current_size[i] >= 2*min_feature_map_size]
if len(valid_axes_for_pool) < 1:
break
spacings_of_axes = [current_spacing[i] for i in valid_axes_for_pool]
# keep only axes whose spacing is within a factor of 2 of the smallest spacing
min_spacing_of_valid = min(spacings_of_axes)
valid_axes_for_pool = [i for i in valid_axes_for_pool if current_spacing[i] / min_spacing_of_valid < 2]
# max_numpool constraint
valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool]
if len(valid_axes_for_pool) == 1:
if current_size[valid_axes_for_pool[0]] >= 3 * min_feature_map_size:
pass
else:
break
if len(valid_axes_for_pool) < 1:
break
# now we need to find kernel sizes
# kernel sizes are initialized to 1. They are successively set to 3 when their associated axis becomes within
# factor 2 of min_spacing. Once they are 3 they remain 3
for d in range(dim):
if kernel_size[d] == 3:
continue
else:
if current_spacing[d] / min(current_spacing) < 2:
kernel_size[d] = 3
other_axes = [i for i in range(dim) if i not in valid_axes_for_pool]
pool_kernel_sizes = [0] * dim
for v in valid_axes_for_pool:
pool_kernel_sizes[v] = 2
num_pool_per_axis[v] += 1
current_spacing[v] *= 2
current_size[v] = np.ceil(current_size[v] / 2)
for nv in other_axes:
pool_kernel_sizes[nv] = 1
pool_op_kernel_sizes.append(pool_kernel_sizes)
conv_kernel_sizes.append(deepcopy(kernel_size))
#print(conv_kernel_sizes)
must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
patch_size = pad_shape(patch_size, must_be_divisible_by)
def _to_tuple(lst):
return tuple(_to_tuple(i) if isinstance(i, list) else i for i in lst)
# we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here
conv_kernel_sizes.append([3]*dim)
return num_pool_per_axis, _to_tuple(pool_op_kernel_sizes), _to_tuple(conv_kernel_sizes), tuple(patch_size), must_be_divisible_by
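The pooling loop in `get_pool_and_conv_props` explains the 224-vs-256 example in the planner's patch size reduction comment. Below is a simplified sketch (`num_pool_per_axis_isotropic` is a hypothetical name) that assumes isotropic spacing and no `max_numpool` limit, so the spacing-based axis filter never excludes an axis; kernel sizes are not computed.

```python
import math


def num_pool_per_axis_isotropic(patch_size, min_feature_map_size=4):
    """Simplified sketch of the pooling loop in get_pool_and_conv_props (isotropic spacing only)."""
    size = list(patch_size)
    num_pool = [0] * len(size)
    while True:
        # an axis can only be pooled while its feature map is at least 2 * min_feature_map_size long
        valid = [i for i in range(len(size)) if size[i] >= 2 * min_feature_map_size]
        if not valid:
            break
        # a single remaining axis needs at least 3 * min_feature_map_size
        if len(valid) == 1 and size[valid[0]] < 3 * min_feature_map_size:
            break
        for i in valid:
            num_pool[i] += 1
            size[i] = math.ceil(size[i] / 2)
    return num_pool


print(num_pool_per_axis_isotropic((256, 256, 256)))  # [6, 6, 6]
print(num_pool_per_axis_isotropic((224, 256, 256)))  # [5, 6, 6]
```

With 5 poolings, axis 0 only has to be divisible by 2**5 = 32, so 224 is a valid size even though 256 requires divisibility by 64.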
================================================
FILE: nnunetv2/experiment_planning/experiment_planners/resampling/__init__.py
================================================
================================================
FILE: nnunetv2/experiment_planning/experiment_planners/resampling/planners_no_resampling.py
================================================
from typing import Union, List, Tuple
from nnunetv2.experiment_planning.experiment_planners.residual_unets.residual_encoder_unet_planners import \
nnUNetPlannerResEncL
from nnunetv2.preprocessing.resampling.no_resampling import no_resampling_hack
class nnUNetPlannerResEncL_noResampling(nnUNetPlannerResEncL):
"""
This planner will generate 3d_lowres as well, but don't trust it: everything remains in its original shape
because no resampling is ever done.
"""
def __init__(self, dataset_name_or_id: Union[str, int],
gpu_memory_target_in_gb: float = 24,
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_noResampling',
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
suppress_transpose: bool = False):
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
overwrite_target_spacing, suppress_transpose)
def generate_data_identifier(self, configuration_name: str) -> str:
"""
configurations are unique within each plans file but different plans files can have configurations with the
same name. In order to distinguish the associated data we need a data identifier that reflects not just the
config but also the plans it originates from
"""
return self.plans_identifier + '_' + configuration_name
def determine_resampling(self, *args, **kwargs):
"""
returns the functions (and their kwargs) to use for resampling data and segmentations, respectively.
A resampling function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs).
determine_resampling is called within get_plans_for_configuration so that each configuration can use different
functions.
"""
resampling_data = no_resampling_hack
resampling_data_kwargs = {}
resampling_seg = no_resampling_hack
resampling_seg_kwargs = {}
return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs
def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
"""
function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
used as target. current_spacing and new_spacing are merely there in case we want to use them somehow
determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
functions for each configuration
"""
resampling_fn = no_resampling_hack
resampling_fn_kwargs = {}
return resampling_fn, resampling_fn_kwargs
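Both hooks in this planner hand back a function that must be callable as `fn(data, new_shape, current_spacing, new_spacing, **kwargs)`. Below is a minimal sketch of a no-op resampler that satisfies this contract. It assumes NumPy arrays shaped `(c, x, y, z)`. `identity_resampling` is a hypothetical name for illustration and is not the source of `no_resampling_hack`.

```python
import numpy as np


def identity_resampling(data: np.ndarray, new_shape, current_spacing, new_spacing, **kwargs) -> np.ndarray:
    """Hypothetical no-op resampler that follows the calling convention described above."""
    # new_shape, current_spacing, new_spacing and any kwargs (e.g. is_seg) are accepted only so that the
    # function can be called like any other resampling function; no interpolation takes place.
    return data


volume = np.random.rand(1, 40, 64, 64).astype(np.float32)
# request an upsampled shape; the array is returned untouched anyway
out = identity_resampling(volume, (80, 128, 128), (2.0, 1.0, 1.0), (1.0, 0.5, 0.5))
print(out.shape)  # (1, 40, 64, 64)
```

Because the array never changes shape, the target spacing a configuration declares (including the coarser one of 3d_lowres) has no effect on the stored data. That is presumably why the docstring advises not to trust that configuration.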
================================================
FILE: nnunetv2/experiment_planning/experiment_planners/resampling/resample_with_torch.py
================================================
from typing import Union, List, Tuple
from nnunetv2.configuration import ANISO_THRESHOLD
from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
from nnunetv2.experiment_planning.experiment_planners.residual_unets.residual_encoder_unet_planners import \
nnUNetPlannerResEncL
from nnunetv2.preprocessing.resampling.resample_torch import resample_torch_fornnunet
class nnUNetPlannerResEncL_torchres(nnUNetPlannerResEncL):
def __init__(self, dataset_name_or_id: Union[str, int],
gpu_memory_target_in_gb: float = 24,
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres',
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
suppress_transpose: bool = False):
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
overwrite_target_spacing, suppress_transpose)
def generate_data_identifier(self, configuration_name: str) -> str:
"""
configurations are unique within each plans file but different plans files can have configurations with the
same name. In order to distinguish the associated data we need a data identifier that reflects not just the
config but also the plans it originates from
"""
return self.plans_identifier + '_' + configuration_name
def determine_resampling(self, *args, **kwargs):
"""
returns what functions to use for resampling data and seg, respectively. Also returns kwargs
resampling function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs)
determine_resampling is called within get_plans_for_configuration to allow for different functions for each
configuration
"""
resampling_data = resample_torch_fornnunet
resampling_data_kwargs = {
"is_seg": False,
'force_separate_z': False,
'memefficient_seg_resampling': False
}
resampling_seg = resample_torch_fornnunet
resampling_seg_kwargs = {
"is_seg": True,
'force_separate_z': False,
'memefficient_seg_resampling': False
}
return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs
def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
"""
function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
used as target. current_spacing and new_spacing are merely there in case we want to use them somehow
determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
functions for each configuration
"""
resampling_fn = resample_torch_fornnunet
resampling_fn_kwargs = {
"is_seg": False,
'force_separate_z': False,
'memefficient_seg_resampling': False
}
return resampling_fn, resampling_fn_kwargs
class nnUNetPlannerResEncL_torchres_sepz(nnUNetPlannerResEncL):
def __init__(self, dataset_name_or_id: Union[str, int],
gpu_memory_target_in_gb: float = 24,
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres_sepz',
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
suppress_transpose: bool = False):
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
overwrite_target_spacing, suppress_transpose)
def generate_data_identifier(self, configuration_name: str) -> str:
"""
configurations are unique within each plans file but different plans files can have configurations with the
same name. In order to distinguish the associated data we need a data identifier that reflects not just the
config but also the plans it originates from
"""
return self.plans_identifier + '_' + configuration_name
def determine_resampling(self, *args, **kwargs):
"""
returns what functions to use for resampling data and seg, respectively. Also returns kwargs
resampling function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs)
determine_resampling is called within get_plans_for_configuration to allow for different functions for each
configuration
"""
resampling_data = resample_torch_fornnunet
resampling_data_kwargs = {
"is_seg": False,
'force_separate_z': None,
'memefficient_seg_resampling': False,
'separate_z_anisotropy_threshold': ANISO_THRESHOLD
}
resampling_seg = resample_torch_fornnunet
resampling_seg_kwargs = {
"is_seg": True,
'force_separate_z': None,
'memefficient_seg_resampling': False,
'separate_z_anisotropy_threshold': ANISO_THRESHOLD
}
return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs
def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
"""
function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
used as target. current_spacing and new_spacing are merely there in case we want to use them somehow
determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
functions for each configuration
"""
resampling_fn = resample_torch_fornnunet
resampling_fn_kwargs = {
"is_seg": False,
'force_separate_z': None,
'memefficient_seg_resampling': False,
'separate_z_anisotropy_threshold': ANISO_THRESHOLD
}
return resampling_fn, resampling_fn_kwargs
class nnUNetPlanner_torchres(ExperimentPlanner):
def __init__(self, dataset_name_or_id: Union[str, int],
gpu_memory_target_in_gb: float = 8,
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetPlans_torchres',
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
suppress_transpose: bool = False):
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
overwrite_target_spacing, suppress_transpose)
def generate_data_identifier(self, configuration_name: str) -> str:
"""
configurations are unique within each plans file but different plans files can have configurations with the
same name. In order to distinguish the associated data we need a data identifier that reflects not just the
config but also the plans it originates from
"""
return self.plans_identifier + '_' + configuration_name
def determine_resampling(self, *args, **kwargs):
"""
returns what functions to use for resampling data and seg, respectively. Also returns kwargs
resampling function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs)
determine_resampling is called within get_plans_for_configuration to allow for different functions for each
configuration
"""
resampling_data = resample_torch_fornnunet
resampling_data_kwargs = {
"is_seg": False,
'force_separate_z': False,
'memefficient_seg_resampling': False
}
resampling_seg = resample_torch_fornnunet
resampling_seg_kwargs = {
"is_seg": True,
'force_separate_z': False,
'memefficient_seg_resampling': False
}
return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs
def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
"""
function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
used as target. current_spacing and new_spacing are merely there in case we want to use them somehow
determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
functions for each configuration
"""
resampling_fn = resample_torch_fornnunet
resampling_fn_kwargs = {
"is_seg": False,
'force_separate_z': False,
'memefficient_seg_resampling': False
}
return resampling_fn, resampling_fn_kwargs
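`nnUNetPlannerResEncL_torchres_sepz` passes `force_separate_z=None` together with `separate_z_anisotropy_threshold=ANISO_THRESHOLD`, which leaves the choice to the resampling function. The other torchres planners force joint resampling with `force_separate_z=False`.

The sketch below illustrates one plausible decision rule, modeled on nnU-Net's default (non-torch) resampling. The coarsest axis is treated separately when the largest-to-smallest spacing ratio exceeds the threshold. The current spacing is checked first, then the target spacing. `decide_separate_z` is a hypothetical helper, and the threshold of 3 is an assumption. Check `ANISO_THRESHOLD` in `nnunetv2/configuration.py` and the source of `resample_torch_fornnunet` for the actual behavior.

```python
from typing import Optional, Sequence, Tuple

import numpy as np


def decide_separate_z(current_spacing: Sequence[float],
                      new_spacing: Sequence[float],
                      force_separate_z: Optional[bool] = None,
                      anisotropy_threshold: float = 3.0) -> Tuple[bool, Optional[int]]:
    """Hypothetical helper returning (do_separate_z, low-resolution axis or None)."""

    def is_anisotropic(spacing: Sequence[float]) -> bool:
        s = np.asarray(spacing, dtype=float)
        return bool(s.max() / s.min() > anisotropy_threshold)

    def lowres_axis(spacing: Sequence[float]) -> Optional[int]:
        s = np.asarray(spacing, dtype=float)
        coarsest = np.flatnonzero(s == s.max())
        # separate resampling only makes sense if exactly one axis is the coarse one
        return int(coarsest[0]) if len(coarsest) == 1 else None

    if force_separate_z is not None:
        # an explicit setting overrides the anisotropy check
        if not force_separate_z:
            return False, None
        axis = lowres_axis(current_spacing)
    elif is_anisotropic(current_spacing):
        axis = lowres_axis(current_spacing)
    elif is_anisotropic(new_spacing):
        axis = lowres_axis(new_spacing)
    else:
        return False, None
    return axis is not None, axis


# thick-slice CT: 5 mm slices, 0.8 mm in-plane -> ratio 6.25 > 3, so axis 0 is handled separately
print(decide_separate_z((5.0, 0.8, 0.8), (3.0, 0.8, 0.8)))  # (True, 0)
# same data with force_separate_z=False (the non-sepz torchres planners): always resampled jointly
print(decide_separate_z((5.0, 0.8, 0.8), (3.0, 0.8, 0.8), force_separate_z=False))  # (False, None)
```

Under this rule, forcing `False` trades some fidelity on thick-slice data for a single joint interpolation pass.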
================================================
FILE: nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py
================================================
import numpy as np
from copy import deepcopy
from typing import Union, List, Tuple
from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet
from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm
from torch import nn
from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props
class ResEncUNetPlanner(ExperimentPlanner):
def __init__(self, dataset_name_or_id: Union[str, int],
gpu_memory_target_in_gb: float = 8,
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetPlans',
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
suppress_transpose: bool = False):
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
overwrite_target_spacing, suppress_transpose)
self.UNet_class = ResidualEncoderUNet
# the following two numbers are really arbitrary and were set to reproduce default nnU-Net's configurations as
# much as possible
self.UNet_reference_val_3d = 680000000
self.UNet_reference_val_2d = 135000000
self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6)
self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
def generate_data_identifier(self, configuration_name: str) -> str:
"""
configurations are unique within each plans file but different plans files can have configurations with the
same name. In order to distinguish the associated data we need a data identifier that reflects not just the
config but also the plans it originates from
"""
if configuration_name == '2d' or configuration_name == '3d_fullres':
# we do not deviate from ExperimentPlanner so we can reuse its data
return 'nnUNetPlans' + '_' + configuration_name
else:
return self.plans_identifier + '_' + configuration_name
def get_plans_for_configuration(self,
spacing: Union[np.ndarray, Tuple[float, ...], List[float]],
median_shape: Union[np.ndarray, Tuple[int, ...]],
data_identifier: str,
approximate_n_voxels_dataset: float,
_cache: dict) -> dict:
def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]:
return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for
i in range(num_stages)])
def _keygen(patch_size, strides):
return str(patch_size) + '_' + str(strides)
assert all([i > 0 for i in spacing]), f"Spacing must be > 0! Spacing: {spacing}"
num_input_channels = len(self.dataset_json['channel_names'].keys()
if 'channel_names' in self.dataset_json.keys()
else self.dataset_json['modality'].keys())
max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d
unet_conv_op = convert_dim_to_conv_op(len(spacing))
# print(spacing, median_shape, approximate_n_voxels_dataset)
# find an initial patch size
# we first use the spacing to get an aspect ratio
tmp = 1 / np.array(spacing)
# we then upscale it so that it initially is certainly larger than what we need (rescale to have the same
# volume as a patch of size 256 ** 3 in 3d or 2048 ** 2 in 2d)
# this may need to be adapted when using absurdly large GPU memory targets. Increasing it now would not be
# ideal because large initial patch sizes increase computation time: more iterations of the while loop
# further down may be required.
if len(spacing) == 3:
initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)]
elif len(spacing) == 2:
initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)]
else:
raise RuntimeError()
# clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that
# this is different from how nnU-Net v1 does it!
# todo patch size can still get too large because we pad the patch size to a multiple of 2**n
initial_patch_size = np.minimum(initial_patch_size, median_shape[:len(spacing)])
# use that to get the network topology. Note that this changes the patch_size depending on the number of
# pooling operations (must be divisible by 2**num_pool in each axis)
network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \
shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size,
self.UNet_featuremap_min_edge_length,
999999)
num_stages = len(pool_op_kernel_sizes)
norm = get_matching_instancenorm(unet_conv_op)
architecture_kwargs = {
'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__,
'arch_kwargs': {
'n_stages': num_stages,
'features_per_stage': _features_per_stage(num_stages, max_num_features),
'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__,
'kernel_sizes': conv_kernel_sizes,
'strides': pool_op_kernel_sizes,
'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages],
'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1],
'conv_bias': True,
'norm_op': norm.__module__ + '.' + norm.__name__,
'norm_op_kwargs': {'eps': 1e-5, 'affine': True},
'dropout_op': None,
'dropout_op_kwargs': None,
'nonlin': 'torch.nn.LeakyReLU',
'nonlin_kwargs': {'inplace': True},
},
'_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'),
}
# now estimate vram consumption
if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys():
estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)]
else:
estimate = self.static_estimate_VRAM_usage(patch_size,
num_input_channels,
len(self.dataset_json['labels'].keys()),
architecture_kwargs['network_class_name'],
architecture_kwargs['arch_kwargs'],
architecture_kwargs['_kw_requires_import'],
)
_cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate
# how large is the reference for us here (batch size etc)?
# adapt for our vram target
reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \
(self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB)
while estimate > reference:
# print(patch_size)
# patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the
# aspect ratio the most (that is the largest relative to median shape)
axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1]
# we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this
# may cause us to skip some valid sizes. For example, shape_must_be_divisible_by is 64 for a shape of 256.
# If we subtracted that we would end up with 192, skipping 224, which is also a valid patch size
# (224 / 2**5 = 7 and 7 < 2 * self.UNet_featuremap_min_edge_length (= 2 * 4), so no further pooling is possible
# and 224 only needs to be divisible by 2**5 = 32). So we need to first subtract shape_must_be_divisible_by,
# then recompute it and then subtract the recomputed shape_must_be_divisible_by. Annoying.
patch_size = list(patch_size)
tmp = deepcopy(patch_size)
tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
_, _, _, _, shape_must_be_divisible_by = \
get_pool_and_conv_props(spacing, tmp,
self.UNet_featuremap_min_edge_length,
999999)
patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
# now recompute topology
network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \
shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size,
self.UNet_featuremap_min_edge_length,
999999)
num_stages = len(pool_op_kernel_sizes)
architecture_kwargs['arch_kwargs'].update({
'n_stages': num_stages,
'kernel_sizes': conv_kernel_sizes,
'strides': pool_op_kernel_sizes,
'features_per_stage': _features_per_stage(num_stages, max_num_features),
'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages],
'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1],
})
if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys():
estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)]
else:
estimate = self.static_estimate_VRAM_usage(
patch_size,
num_input_channels,
len(self.dataset_json['labels'].keys()),
architecture_kwargs['network_class_name'],
architecture_kwargs['arch_kwargs'],
architecture_kwargs['_kw_requires_import'],
)
_cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate
# alright now let's determine the batch size. If the while loop was executed, estimate ended up at or somewhat
# below reference, so this yields roughly the reference batch size (ref_bs). If not, additional vram headroom
# is used to increase batch size
ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d
batch_size = round((reference / estimate) * ref_bs)
# we need to cap the batch size so that a batch covers at most self.max_dataset_covered (5% by default) of the
# entire dataset. Overfitting precaution. We cannot go smaller than self.UNet_min_batch_size though
bs_corresponding_to_5_percent = round(
approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64))
batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size)
resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling()
resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn()
normalization_schemes, mask_is_used_for_norm = \
self.determine_normalization_scheme_and_whether_mask_is_used_for_norm()
plan = {
'data_identifier': data_identifier,
'preprocessor_name': self.preprocessor_name,
'batch_size': batch_size,
'patch_size': patch_size,
'median_image_size_in_voxels': median_shape,
'spacing': spacing,
'normalization_schemes': normalization_schemes,
'use_mask_for_norm': mask_is_used_for_norm,
'resampling_fn_data': resampling_data.__name__,
'resampling_fn_seg': resampling_seg.__name__,
'resampling_fn_data_kwargs': resampling_data_kwargs,
'resampling_fn_seg_kwargs': resampling_seg_kwargs,
'resampling_fn_probabilities': resampling_softmax.__name__,
'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs,
'architecture': architecture_kwargs
}
return plan
if __name__ == '__main__':
# we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively
net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320),
conv_op=nn.Conv3d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2),
n_blocks_per_stage=(1, 3, 4, 6, 6, 6), num_classes=3,
n_conv_per_stage_decoder=(1, 1, 1, 1, 1),
conv_bias=True, norm_op=nn.InstanceNorm3d, norm_op_kwargs={}, dropout_op=None,
nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True)
print(net.compute_conv_feature_map_size((128, 128, 128))) # -> 558319104. The value you see above was fine-tuned
# from this one to match the regular nnUNetPlans more closely
net = ResidualEncoderUNet(input_channels=1, n_stages=7, features_per_stage=(32, 64, 128, 256, 512, 512, 512),
conv_op=nn.Conv2d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2, 2),
n_blocks_per_stage=(1, 3, 4, 6, 6, 6, 6), num_classes=3,
n_conv_per_stage_decoder=(1, 1, 1, 1, 1, 1),
conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={}, dropout_op=None,
nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True)
print(net.compute_conv_feature_map_size((512, 512))) # -> 129793792
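The initial patch size step in `get_plans_for_configuration` can be computed in isolation. It takes the aspect ratio from 1/spacing, rescales it to a fixed voxel budget, then clips it to the median shape. Below is a minimal sketch mirroring that arithmetic. `initial_patch_size` is a hypothetical standalone function, not part of nnU-Net. It stops before `get_pool_and_conv_props` adjusts the patch to the network topology.

```python
from typing import List, Sequence

import numpy as np


def initial_patch_size(spacing: Sequence[float], median_shape: Sequence[int]) -> List[int]:
    """Hypothetical standalone version of the initial patch size step in get_plans_for_configuration."""
    # aspect ratio: coarser spacing -> fewer voxels along that axis
    tmp = 1 / np.array(spacing, dtype=float)
    # rescale to the volume of a 256**3 (3d) or 2048**2 (2d) patch
    if len(spacing) == 3:
        patch = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)]
    elif len(spacing) == 2:
        patch = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)]
    else:
        raise RuntimeError("only 2d and 3d spacings are supported")
    # clip to the median image shape; a patch larger than a typical image makes little sense
    return [int(min(p, m)) for p, m in zip(patch, median_shape[:len(spacing)])]


print(initial_patch_size((3.0, 1.0, 1.0), (100, 512, 512)))  # [100, 369, 369]
print(initial_patch_size((0.5, 0.5), (600, 600)))  # [600, 600]
```

In the 3 mm slice example, the spacing-derived aspect ratio first gives roughly 123 x 369 x 369 voxels. The median shape then caps the first axis at 100.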
================================================
FILE: nnunetv2/experiment_planning/experiment_planners/residual_unets/__init__.py
================================================
================================================
FILE: nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py
================================================
import warnings
import numpy as np
from copy import deepcopy
from typing import Union, List, Tuple
from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet
from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm
from nnunetv2.preprocessing.resampling.resample_torch import resample_torch_fornnunet
from torch import nn
from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props
class ResEncUNetPlanner(ExperimentPlanner):
def __init__(self, dataset_name_or_id: Union[str, int],
gpu_memory_target_in_gb: float = 8,
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetPlans',
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
suppress_transpose: bool = False):
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
overwrite_target_spacing, suppress_transpose)
self.UNet_class = ResidualEncoderUNet
# the following two numbers are really arbitrary and were set to reproduce default nnU-Net's configurations as
# much as possible
self.UNet_reference_val_3d = 680000000
self.UNet_reference_val_2d = 135000000
self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6)
self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
def generate_data_identifier(self, configuration_name: str) -> str:
"""
configurations are unique within each plans file but different plans files can have configurations with the
same name. In order to distinguish the associated data we need a data identifier that reflects not just the
config but also the plans it originates from
"""
if configuration_name == '2d' or configuration_name == '3d_fullres':
# we do not deviate from ExperimentPlanner so we can reuse its data
return 'nnUNetPlans' + '_' + configuration_name
else:
return self.plans_identifier + '_' + configuration_name
def get_plans_for_configuration(self,
spacing: Union[np.ndarray, Tuple[float, ...], List[float]],
median_shape: Union[np.ndarray, Tuple[int, ...]],
data_identifier: str,
approximate_n_voxels_dataset: float,
_cache: dict) -> dict:
def _features_per_stage(num_stages, max_num_features) -> Tuple[int, ...]:
return tuple([min(max_num_features, self.UNet_base_num_features * 2 ** i) for
i in range(num_stages)])
def _keygen(patch_size, strides):
return str(patch_size) + '_' + str(strides)
assert all([i > 0 for i in spacing]), f"Spacing must be > 0! Spacing: {spacing}"
num_input_channels = len(self.dataset_json['channel_names'].keys()
if 'channel_names' in self.dataset_json.keys()
else self.dataset_json['modality'].keys())
max_num_features = self.UNet_max_features_2d if len(spacing) == 2 else self.UNet_max_features_3d
unet_conv_op = convert_dim_to_conv_op(len(spacing))
# print(spacing, median_shape, approximate_n_voxels_dataset)
# find an initial patch size
# we first use the spacing to get an aspect ratio
tmp = 1 / np.array(spacing)
# we then upscale it so that it initially is certainly larger than what we need (rescale to have the same
# volume as a patch of size 256 ** 3 in 3d or 2048 ** 2 in 2d)
# this may need to be adapted when using absurdly large GPU memory targets. Increasing it now would not be
# ideal because large initial patch sizes increase computation time: more iterations of the while loop
# further down may be required.
if len(spacing) == 3:
initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)]
elif len(spacing) == 2:
initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)]
else:
raise RuntimeError()
# clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that
# this is different from how nnU-Net v1 does it!
# todo patch size can still get too large because we pad the patch size to a multiple of 2**n
initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])])
# use that to get the network topology. Note that this changes the patch_size depending on the number of
# pooling operations (must be divisible by 2**num_pool in each axis)
network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \
shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size,
self.UNet_featuremap_min_edge_length,
999999)
num_stages = len(pool_op_kernel_sizes)
norm = get_matching_instancenorm(unet_conv_op)
architecture_kwargs = {
'network_class_name': self.UNet_class.__module__ + '.' + self.UNet_class.__name__,
'arch_kwargs': {
'n_stages': num_stages,
'features_per_stage': _features_per_stage(num_stages, max_num_features),
'conv_op': unet_conv_op.__module__ + '.' + unet_conv_op.__name__,
'kernel_sizes': conv_kernel_sizes,
'strides': pool_op_kernel_sizes,
'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages],
'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1],
'conv_bias': True,
'norm_op': norm.__module__ + '.' + norm.__name__,
'norm_op_kwargs': {'eps': 1e-5, 'affine': True},
'dropout_op': None,
'dropout_op_kwargs': None,
'nonlin': 'torch.nn.LeakyReLU',
'nonlin_kwargs': {'inplace': True},
},
'_kw_requires_import': ('conv_op', 'norm_op', 'dropout_op', 'nonlin'),
}
# now estimate vram consumption
if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys():
estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)]
else:
estimate = self.static_estimate_VRAM_usage(patch_size,
num_input_channels,
len(self.dataset_json['labels'].keys()),
architecture_kwargs['network_class_name'],
architecture_kwargs['arch_kwargs'],
architecture_kwargs['_kw_requires_import'],
)
_cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate
# how large is the reference for us here (batch size etc)?
# adapt for our vram target
reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \
(self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB)
while estimate > reference:
# print(patch_size)
# patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the
# aspect ratio the most (that is the largest relative to median shape)
axis_to_be_reduced = np.argsort([i / j for i, j in zip(patch_size, median_shape[:len(spacing)])])[-1]
# we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this
# may cause us to skip some valid sizes. For example, shape_must_be_divisible_by is 64 for a shape of 256.
# If we subtracted that we would end up with 192, skipping 224, which is also a valid patch size
# (224 / 2**5 = 7 and 7 < 2 * self.UNet_featuremap_min_edge_length (= 2 * 4), so no further pooling is possible
# and 224 only needs to be divisible by 2**5 = 32). So we need to first subtract shape_must_be_divisible_by,
# then recompute it and then subtract the recomputed shape_must_be_divisible_by. Annoying.
patch_size = list(patch_size)
tmp = deepcopy(patch_size)
tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
_, _, _, _, shape_must_be_divisible_by = \
get_pool_and_conv_props(spacing, tmp,
self.UNet_featuremap_min_edge_length,
999999)
patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
# now recompute topology
network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \
shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size,
self.UNet_featuremap_min_edge_length,
999999)
num_stages = len(pool_op_kernel_sizes)
architecture_kwargs['arch_kwargs'].update({
'n_stages': num_stages,
'kernel_sizes': conv_kernel_sizes,
'strides': pool_op_kernel_sizes,
'features_per_stage': _features_per_stage(num_stages, max_num_features),
'n_blocks_per_stage': self.UNet_blocks_per_stage_encoder[:num_stages],
'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1],
})
if _keygen(patch_size, pool_op_kernel_sizes) in _cache.keys():
estimate = _cache[_keygen(patch_size, pool_op_kernel_sizes)]
else:
estimate = self.static_estimate_VRAM_usage(
patch_size,
num_input_channels,
len(self.dataset_json['labels'].keys()),
architecture_kwargs['network_class_name'],
architecture_kwargs['arch_kwargs'],
architecture_kwargs['_kw_requires_import'],
)
_cache[_keygen(patch_size, pool_op_kernel_sizes)] = estimate
# alright now let's determine the batch size. If the while loop was executed, estimate ended up at or somewhat
# below reference, so this yields roughly the reference batch size (ref_bs). If not, additional vram headroom
# is used to increase batch size
ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d
batch_size = round((reference / estimate) * ref_bs)
# we need to cap the batch size so that a batch covers at most self.max_dataset_covered of the entire dataset
# (5% by default; the M/L/XL presets below set it to 1). Overfitting precaution. We cannot go smaller than
# self.UNet_min_batch_size though
bs_corresponding_to_5_percent = round(
approximate_n_voxels_dataset * self.max_dataset_covered / np.prod(patch_size, dtype=np.float64))
batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size)
resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling()
resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn()
normalization_schemes, mask_is_used_for_norm = \
self.determine_normalization_scheme_and_whether_mask_is_used_for_norm()
plan = {
'data_identifier': data_identifier,
'preprocessor_name': self.preprocessor_name,
'batch_size': batch_size,
'patch_size': patch_size,
'median_image_size_in_voxels': median_shape,
'spacing': spacing,
'normalization_schemes': normalization_schemes,
'use_mask_for_norm': mask_is_used_for_norm,
'resampling_fn_data': resampling_data.__name__,
'resampling_fn_seg': resampling_seg.__name__,
'resampling_fn_data_kwargs': resampling_data_kwargs,
'resampling_fn_seg_kwargs': resampling_seg_kwargs,
'resampling_fn_probabilities': resampling_softmax.__name__,
'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs,
'architecture': architecture_kwargs
}
return plan
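The comment inside the while loop argues that subtracting `shape_must_be_divisible_by` directly would skip valid sizes. A simplified single-axis sketch makes the 256 to 224 example concrete. It assumes isotropic pooling by 2. It also assumes that an axis is pooled again only while its feature map edge is at least `2 * UNet_featuremap_min_edge_length` (4 here). `num_pool_and_divisor` and `reduce_axis_once` are hypothetical stand-ins; the real `get_pool_and_conv_props` also accounts for spacing and the other axes.

```python
from typing import Tuple


def num_pool_and_divisor(edge_length: int, min_edge_length: int = 4) -> Tuple[int, int]:
    """Hypothetical single-axis rule: pool by 2 while the feature map edge is >= 2 * min_edge_length."""
    num_pool = 0
    current = float(edge_length)
    while current >= 2 * min_edge_length:
        current /= 2
        num_pool += 1
    return num_pool, 2 ** num_pool


def reduce_axis_once(edge_length: int, min_edge_length: int = 4) -> int:
    """One shrinking step: subtract the divisor, recompute it, then subtract the recomputed divisor instead."""
    _, divisor = num_pool_and_divisor(edge_length, min_edge_length)
    _, recomputed = num_pool_and_divisor(edge_length - divisor, min_edge_length)
    return edge_length - recomputed


print(num_pool_and_divisor(256))  # (6, 64)
print(num_pool_and_divisor(224))  # (5, 32)
print(256 - num_pool_and_divisor(256)[1])  # 192: the naive step skips 224
print(reduce_axis_once(256))  # 224
```

The two-step update costs one extra topology computation per iteration. In exchange, the loop visits every size the network can actually handle.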
class nnUNetPlannerResEncM(ResEncUNetPlanner):
"""
Target is ~9-11 GB VRAM max -> older Titan, RTX 2080 Ti
"""
def __init__(self, dataset_name_or_id: Union[str, int],
gpu_memory_target_in_gb: float = 8,
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetMPlans',
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
suppress_transpose: bool = False):
if gpu_memory_target_in_gb != 8:
warnings.warn("WARNING: You are running nnUNetPlannerResEncM with a non-standard gpu_memory_target_in_gb. "
f"Expected 8, got {gpu_memory_target_in_gb}. "
"You should only see this warning if you modified this value intentionally!!")
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
overwrite_target_spacing, suppress_transpose)
self.UNet_class = ResidualEncoderUNet
self.UNet_vram_target_GB = gpu_memory_target_in_gb
self.UNet_reference_val_corresp_GB = 8
# this is supposed to give the same GPU memory requirement as the default nnU-Net
self.UNet_reference_val_3d = 680000000
self.UNet_reference_val_2d = 135000000
self.max_dataset_covered = 1
class nnUNetPlannerResEncL(ResEncUNetPlanner):
"""
Target is ~24 GB VRAM max -> RTX 4090, Titan RTX, Quadro RTX 6000
"""
def __init__(self, dataset_name_or_id: Union[str, int],
gpu_memory_target_in_gb: float = 24,
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans',
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
suppress_transpose: bool = False):
if gpu_memory_target_in_gb != 24:
warnings.warn("WARNING: You are running nnUNetPlannerResEncL with a non-standard gpu_memory_target_in_gb. "
f"Expected 24, got {gpu_memory_target_in_gb}. "
"You should only see this warning if you modified this value intentionally!!")
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
overwrite_target_spacing, suppress_transpose)
self.UNet_class = ResidualEncoderUNet
self.UNet_vram_target_GB = gpu_memory_target_in_gb
self.UNet_reference_val_corresp_GB = 24
self.UNet_reference_val_3d = 2100000000 # 1840000000
self.UNet_reference_val_2d = 380000000 # 352666667
self.max_dataset_covered = 1
class nnUNetPlannerResEncXL(ResEncUNetPlanner):
"""
Target is 40 GB VRAM max -> A100 40GB, RTX 6000 Ada Generation
"""
def __init__(self, dataset_name_or_id: Union[str, int],
gpu_memory_target_in_gb: float = 40,
preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetXLPlans',
overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
suppress_transpose: bool = False):
if gpu_memory_target_in_gb != 40:
warnings.warn("WARNING: You are running nnUNetPlannerResEncXL with a non-standard gpu_memory_target_in_gb. "
f"Expected 40, got {gpu_memory_target_in_gb}. "
"You should only see this warning if you modified this value intentionally!!")
super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
overwrite_target_spacing, suppress_transpose)
self.UNet_class = ResidualEncoderUNet
self.UNet_vram_target_GB = gpu_memory_target_in_gb
self.UNet_reference_val_corresp_GB = 40
self.UNet_reference_val_3d = 3600000000
self.UNet_reference_val_2d = 560000000
self.max_dataset_covered = 1
if __name__ == '__main__':
# we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively
net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320),
conv_op=nn.Conv3d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2),
n_blocks_per_stage=(1, 3, 4, 6, 6, 6), num_classes=3,
n_conv_per_stage_decoder=(1, 1, 1, 1, 1),
conv_bias=True, norm_op=nn.InstanceNorm3d, norm_op_kwargs={}, dropout_op=None,
nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True)
print(net.compute_conv_feature_map_size((128, 128, 128))) # -> 558319104. The UNet_reference_val_3d of
# nnUNetPlannerResEncM above was fine-tuned from this value to match the regular nnUNetPlans more closely
net = ResidualEncoderUNet(input_channels=1, n_stages=7, features_per_stage=(32, 64, 128, 256, 512, 512, 512),
conv_op=nn.Conv2d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2, 2),
n_blocks_per_stage=(1, 3, 4, 6, 6, 6, 6), num_classes=3,
n_conv_per_stage_decoder=(1, 1, 1, 1, 1, 1),
conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={}, dropout_op=None,
nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True)
print(net.compute_conv_feature_map_size((512, 512))) # -> 129793792
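The three ResEnc presets above differ only in their VRAM target and the reference values stored in `UNet_reference_val_3d`/`UNet_reference_val_2d`. The sketch below tabulates those values and shows how a feature-map budget could be derived for a non-standard `gpu_memory_target_in_gb`, assuming the base planner (not shown in this excerpt) scales the reference linearly by `gpu_memory_target_in_gb / UNet_reference_val_corresp_GB`. `RESENC_PRESETS` and `scaled_reference` are hypothetical names for illustration, not nnU-Net API.

```python
from typing import Dict, Tuple

# Values copied from the preset classes above:
# planner name -> (GB the reference corresponds to, 3d reference, 2d reference)
RESENC_PRESETS: Dict[str, Tuple[float, int, int]] = {
    "nnUNetPlannerResEncM": (8, 680_000_000, 135_000_000),
    "nnUNetPlannerResEncL": (24, 2_100_000_000, 380_000_000),
    "nnUNetPlannerResEncXL": (40, 3_600_000_000, 560_000_000),
}


def scaled_reference(planner: str, gpu_memory_target_in_gb: float, is_3d: bool = True) -> float:
    """Hypothetical helper: feature-map budget for a given VRAM target.

    ASSUMPTION: the budget scales linearly with the VRAM target relative to the
    GPU memory the preset's reference value corresponds to.
    """
    corresp_gb, ref_3d, ref_2d = RESENC_PRESETS[planner]
    reference = ref_3d if is_3d else ref_2d
    return reference * (gpu_memory_target_in_gb / corresp_gb)


if __name__ == "__main__":
    # At the default target, the budget equals the stored reference value.
    print(scaled_reference("nnUNetPlannerResEncL", 24))  # 2100000000.0
    # Halving the target halves the budget (and would trigger the non-standard-target warning).
    print(scaled_reference("nnUNetPlannerResEncL", 12))  # 1050000000.0
```

Under this assumption, lowering `-gpu_memory_target` shrinks the budget proportionally, which is why the planners warn whenever the default is changed.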
================================================
FILE: nnunetv2/experiment_planning/plan_and_preprocess_api.py
================================================
import warnings
from typing import List, Type, Optional, Tuple, Union
from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, load_json
import nnunetv2
from nnunetv2.configuration import default_num_processes
from nnunetv2.experiment_planning.dataset_fingerprint.fingerprint_extractor import DatasetFingerprintExtractor
from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
from nnunetv2.experiment_planning.verify_dataset_integrity import verify_dataset_integrity
from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
def extract_fingerprint_dataset(dataset_id: int,
fingerprint_extractor_class: Type[
DatasetFingerprintExtractor] = DatasetFingerprintExtractor,
num_processes: int = default_num_processes, check_dataset_integrity: bool = False,
clean: bool = True, verbose: bool = True,
show_progress_bar: bool = True):
"""
Returns the fingerprint as a dictionary (in addition to saving it)
"""
dataset_name = convert_id_to_dataset_name(dataset_id)
print(dataset_name)
if check_dataset_integrity:
verify_dataset_integrity(join(nnUNet_raw, dataset_name), num_processes)
fpe = fingerprint_extractor_class(dataset_id, num_processes, verbose=verbose)
if hasattr(fpe, 'show_progress_bar'):
fpe.show_progress_bar = show_progress_bar
return fpe.run(overwrite_existing=clean)
def extract_fingerprints(dataset_ids: List[int], fingerprint_extractor_class_name: str = 'DatasetFingerprintExtractor',
num_processes: int = default_num_processes, check_dataset_integrity: bool = False,
clean: bool = True, verbose: bool = True,
show_progress_bar: bool = True):
"""
With clean = False, fingerprint extraction is skipped for datasets that already have a fingerprint (it is only
recomputed with clean = True). This switch is used by nnUNetv2_plan_and_preprocess so that we don't rerun
fingerprint extraction every time.
"""
fingerprint_extractor_class = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"),
fingerprint_extractor_class_name,
current_module="nnunetv2.experiment_planning")
for d in dataset_ids:
extract_fingerprint_dataset(d, fingerprint_extractor_class, num_processes, check_dataset_integrity, clean,
verbose, show_progress_bar)
def plan_experiment_dataset(dataset_id: int,
experiment_planner_class: Type[ExperimentPlanner] = ExperimentPlanner,
gpu_memory_target_in_gb: float = None, preprocess_class_name: str = 'DefaultPreprocessor',
overwrite_target_spacing: Optional[Tuple[float, ...]] = None,
overwrite_plans_name: Optional[str] = None) -> Tuple[dict, str]:
"""
overwrite_target_spacing ONLY applies to 3d_fullres and 3d_cascade_fullres!
"""
kwargs = {}
if overwrite_plans_name is not None:
kwargs['plans_name'] = overwrite_plans_name
if gpu_memory_target_in_gb is not None:
kwargs['gpu_memory_target_in_gb'] = gpu_memory_target_in_gb
planner = experiment_planner_class(dataset_id,
preprocessor_name=preprocess_class_name,
overwrite_target_spacing=[float(i) for i in overwrite_target_spacing] if
overwrite_target_spacing is not None else overwrite_target_spacing,
suppress_transpose=False, # might expose this later,
**kwargs
)
ret = planner.plan_experiment()
return ret, planner.plans_identifier
def plan_experiments(dataset_ids: List[int], experiment_planner_class_name: str = 'ExperimentPlanner',
gpu_memory_target_in_gb: float = None, preprocess_class_name: str = 'DefaultPreprocessor',
overwrite_target_spacing: Optional[Tuple[float, ...]] = None,
overwrite_plans_name: Optional[str] = None):
"""
overwrite_target_spacing ONLY applies to 3d_fullres and 3d_cascade_fullres!
"""
if experiment_planner_class_name == 'ExperimentPlanner':
print("\n############################\n"
"INFO: You are using the old nnU-Net default planner. We have updated our recommendations. "
"Please consider using those instead! "
"Read more here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md"
"\n############################\n")
experiment_planner = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"),
experiment_planner_class_name,
current_module="nnunetv2.experiment_planning")
plans_identifier = None
for d in dataset_ids:
_, plans_identifier = plan_experiment_dataset(d, experiment_planner, gpu_memory_target_in_gb,
preprocess_class_name,
overwrite_target_spacing, overwrite_plans_name)
return plans_identifier
def preprocess_dataset(dataset_id: int,
plans_identifier: str = 'nnUNetPlans',
configurations: Union[Tuple[str], List[str]] = ('2d', '3d_fullres', '3d_lowres'),
num_processes: Union[int, Tuple[int, ...], List[int]] = (8, 4, 8),
verbose: bool = False,
show_progress_bar: bool = True) -> None:
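# the type hint allows a plain int, which list() cannot convert
if isinstance(num_processes, int):
num_processes = [num_processes]
elif not isinstance(num_processes, list):
num_processes = list(num_processes)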
if len(num_processes) == 1:
num_processes = num_processes * len(configurations)
if len(num_processes) != len(configurations):
raise RuntimeError(
f'The list provided with num_processes must either have len 1 or as many elements as there are '
f'configurations (see --help). Number of configurations: {len(configurations)}, length '
f'of num_processes: '
f'{len(num_processes)}')
dataset_name = convert_id_to_dataset_name(dataset_id)
print(f'Preprocessing dataset {dataset_name}')
plans_file = join(nnUNet_preprocessed, dataset_name, plans_identifier + '.json')
plans_manager = PlansManager(plans_file)
for n, c in zip(num_processes, configurations):
print(f'Configuration: {c}...')
if c not in plans_manager.available_configurations:
print(
f"INFO: Configuration {c} not found in plans file {plans_identifier + '.json'} of "
f"dataset {dataset_name}. Skipping.")
continue
configuration_manager = plans_manager.get_configuration(c)
print(configuration_manager)
preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
if hasattr(preprocessor, 'show_progress_bar'):
preprocessor.show_progress_bar = show_progress_bar
preprocessor.run(dataset_id, c, plans_identifier, num_processes=n)
# copy the gt to a folder in the nnUNet_preprocessed so that we can do validation even if the raw data is no
# longer there (useful for compute cluster where only the preprocessed data is available)
import os
import shutil
maybe_mkdir_p(join(nnUNet_preprocessed, dataset_name, 'gt_segmentations'))
dataset_json = load_json(join(nnUNet_raw, dataset_name, 'dataset.json'))
dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, dataset_name), dataset_json)
# only copy files that are newer than the ones already present (same behavior as distutils' copy_file with
# update=True; distutils was removed from the standard library in Python 3.12)
for k in dataset:
target_file = join(nnUNet_preprocessed, dataset_name, 'gt_segmentations', k + dataset_json['file_ending'])
if not os.path.isfile(target_file) or os.path.getmtime(dataset[k]['label']) > os.path.getmtime(target_file):
shutil.copy2(dataset[k]['label'], target_file)
def preprocess(dataset_ids: List[int],
plans_identifier: str = 'nnUNetPlans',
configurations: Union[Tuple[str], List[str]] = ('2d', '3d_fullres', '3d_lowres'),
num_processes: Union[int, Tuple[int, ...], List[int]] = (8, 4, 8),
verbose: bool = False,
show_progress_bar: bool = True):
for d in dataset_ids:
preprocess_dataset(d, plans_identifier, configurations, num_processes, verbose, show_progress_bar)
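`preprocess_dataset` accepts either one process count for all configurations or one count per configuration. The following is a minimal, standalone sketch of that broadcasting rule. `normalize_num_processes` is a hypothetical helper, not part of nnU-Net; it also accepts a plain `int`, as the `Union[int, ...]` type hint suggests.

```python
from typing import List, Sequence, Union


def normalize_num_processes(num_processes: Union[int, Sequence[int]],
                            configurations: Sequence[str]) -> List[int]:
    """Hypothetical helper mirroring the checks in preprocess_dataset.

    A single value is broadcast to every configuration; otherwise exactly one
    value per configuration is required.
    """
    if isinstance(num_processes, int):
        num_processes = [num_processes]
    else:
        num_processes = list(num_processes)
    if len(num_processes) == 1:
        num_processes = num_processes * len(configurations)
    if len(num_processes) != len(configurations):
        raise RuntimeError(
            f"num_processes must have length 1 or {len(configurations)} (one per configuration), "
            f"got {len(num_processes)}")
    return num_processes


if __name__ == "__main__":
    configs = ("2d", "3d_fullres", "3d_lowres")
    print(normalize_num_processes(4, configs))          # [4, 4, 4]
    print(normalize_num_processes((8, 4, 8), configs))  # [8, 4, 8]
```

Failing early on a length mismatch avoids `zip(num_processes, configurations)` silently dropping configurations.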
================================================
FILE: nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py
================================================
from nnunetv2.configuration import default_num_processes
from nnunetv2.experiment_planning.plan_and_preprocess_api import extract_fingerprints, plan_experiments, preprocess
def _add_logging_args(parser):
parser.add_argument('--verbose', required=False, action='store_true',
help='Set this to print a lot of stuff. Useful for debugging.')
parser.add_argument('--no_pbar', required=False, action='store_true',
help='Set this flag to disable the progress bar. Recommended for cluster/HPC environments.')
def extract_fingerprint_entry():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-d', nargs='+', type=int,
help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction for these "
"datasets. Can of course also be just one dataset")
parser.add_argument('-fpe', type=str, required=False, default='DatasetFingerprintExtractor',
help='[OPTIONAL] Name of the Dataset Fingerprint Extractor class that should be used. Default is '
'\'DatasetFingerprintExtractor\'.')
parser.add_argument('-np', type=int, default=default_num_processes, required=False,
help=f'[OPTIONAL] Number of processes used for fingerprint extraction. '
f'Default: {default_num_processes}')
parser.add_argument("--verify_dataset_integrity", required=False, default=False, action="store_true",
help="[RECOMMENDED] set this flag to check the dataset integrity. This is useful and should be done once for "
"each dataset!")
parser.add_argument("--clean", required=False, default=False, action="store_true",
help='[OPTIONAL] Set this flag to overwrite existing fingerprints. If this flag is not set and a '
'fingerprint already exists, the fingerprint extractor will not run.')
_add_logging_args(parser)
args, unrecognized_args = parser.parse_known_args()
extract_fingerprints(args.d, args.fpe, args.np, args.verify_dataset_integrity, args.clean, args.verbose,
show_progress_bar=not args.no_pbar)
def plan_experiment_entry():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-d', nargs='+', type=int,
help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run experiment planning for these "
"datasets. Can of course also be just one dataset")
parser.add_argument('-pl', type=str, default='ExperimentPlanner', required=False,
help='[OPTIONAL] Name of the Experiment Planner class that should be used. Default is '
'\'ExperimentPlanner\'. Note: There is no longer a distinction between 2d and 3d planner. '
'It\'s an all in one solution now. Wuch. Such amazing.')
parser.add_argument('-gpu_memory_target', default=None, type=float, required=False,
help='[OPTIONAL] DANGER ZONE! Sets a custom GPU memory target (in GB). Default: None (=Planner '
'class default is used). Changing this will '
'affect patch and batch size and will '
'definitely affect your model\'s performance! Only use this if you really know what you '
'are doing and NEVER use this without running the default nnU-Net first as a baseline.')
parser.add_argument('-preprocessor_name', default='DefaultPreprocessor', type=str, required=False,
help='[OPTIONAL] DANGER ZONE! Sets a custom preprocessor class. This class must be located in '
'nnunetv2.preprocessing. Default: \'DefaultPreprocessor\'. Changing this may affect your '
'model\'s performance! Only use this if you really know what you '
'are doing and NEVER use this without running the default nnU-Net first (as a baseline).')
parser.add_argument('-overwrite_target_spacing', default=None, nargs='+', required=False,
help='[OPTIONAL] DANGER ZONE! Sets a custom target spacing for the 3d_fullres and 3d_cascade_fullres '
'configurations. Default: None [no changes]. Changing this will affect image size and '
'potentially patch and batch '
'size. This will definitely affect your model\'s performance! Only use this if you really '
'know what you are doing and NEVER use this without running the default nnU-Net first '
'(as a baseline). Changing the target spacing for the other configurations is currently '
'not implemented. New target spacing must be a list of three numbers!')
parser.add_argument('-overwrite_plans_name', default=None, required=False,
help='[OPTIONAL] DANGER ZONE! If you used -gpu_memory_target, -preprocessor_name or '
'-overwrite_target_spacing it is best practice to use -overwrite_plans_name to generate a '
'differently named plans file such that the nnunet default plans are not '
'overwritten. You will then need to specify your custom plans file with -p whenever '
'running other nnunet commands (training, inference etc)')
args, unrecognized_args = parser.parse_known_args()
plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name, args.overwrite_target_spacing,
args.overwrite_plans_name)
def preprocess_entry():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-d', nargs='+', type=int,
help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run preprocessing for these "
"datasets. Can of course also be just one dataset")
parser.add_argument('-plans_name', default='nnUNetPlans', required=False,
help='[OPTIONAL] You can use this to specify a custom plans file that you may have generated')
parser.add_argument('-c', required=False, default=['2d', '3d_fullres', '3d_lowres'], nargs='+',
help='[OPTIONAL] Configurations for which the preprocessing should be run. Default: 2d 3d_fullres '
'3d_lowres. 3d_cascade_fullres does not need to be specified because it uses the data '
'from 3d_fullres. Configurations that do not exist for some dataset will be skipped.')
parser.add_argument('-np', type=int, nargs='+', default=None, required=False,
help="[OPTIONAL] Use this to define how many processes are to be used. If this is just one number then "
"this number of processes is used for all configurations specified with -c. If it's a "
"list of numbers this list must have as many elements as there are configurations. We "
"then iterate over zip(configs, num_processes) to determine the number of processes "
"used for each configuration. More processes are generally faster (up to the number of "
"threads your PC can support, so 8 for a 4 core CPU with hyperthreading. If you don't "
"know what that is then don't touch it, or at least don't increase it!). DANGER: More "
"often than not the number of processes that can be used is limited by the amount of "
"RAM available. Image resampling takes up a lot of RAM. MONITOR RAM USAGE AND "
"DECREASE -np IF YOUR RAM FILLS UP TOO MUCH! Default: 8 processes for 2d, 4 "
"for 3d_fullres, 8 for 3d_lowres and 4 for everything else")
_add_logging_args(parser)
args, unrecognized_args = parser.parse_known_args()
if args.np is None:
default_np = {"2d": 8, "3d_fullres": 4, "3d_lowres": 8}
np = [default_np[c] if c in default_np.keys() else 4 for c in args.c]
else:
np = args.np
preprocess(args.d, args.plans_name, configurations=args.c, num_processes=np, verbose=args.verbose,
show_progress_bar=not args.no_pbar)
def plan_and_preprocess_entry():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-d', nargs='+', type=int,
help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction, experiment "
"planning and preprocessing for these datasets. Can of course also be just one dataset")
parser.add_argument('-fpe', type=str, required=False, default='DatasetFingerprintExtractor',
help='[OPTIONAL] Name of the Dataset Fingerprint Extractor class that should be used. Default is '
'\'DatasetFingerprintExtractor\'.')
parser.add_argument('-npfp', type=int, default=8, required=False,
help='[OPTIONAL] Number of processes used for fingerprint extraction. Default: 8')
parser.add_argument("--verify_dataset_integrity", required=False, default=False, action="store_true",
help="[RECOMMENDED] set this flag to check the dataset integrity. This is useful and should be done once for "
"each dataset!")
parser.add_argument('--no_pp', default=False, action='store_true', required=False,
help='[OPTIONAL] Set this to only run fingerprint extraction and experiment planning (no '
'preprocessing). Useful for debugging.')
parser.add_argument("--clean", required=False, default=False, action="store_true",
help='[OPTIONAL] Set this flag to overwrite existing fingerprints. If this flag is not set and a '
'fingerprint already exists, the fingerprint extractor will not run. REQUIRED IF YOU '
'CHANGE THE DATASET FINGERPRINT EXTRACTOR OR MAKE CHANGES TO THE DATASET!')
parser.add_argument('-pl', type=str, default='ExperimentPlanner', required=False,
help='[OPTIONAL] Name of the Experiment Planner class that should be used. Default is '
'\'ExperimentPlanner\'. Note: There is no longer a distinction between 2d and 3d planner. '
'It\'s an all in one solution now. Wuch. Such amazing.')
parser.add_argument('-gpu_memory_target', default=None, type=float, required=False,
help='[OPTIONAL] DANGER ZONE! Sets a custom GPU memory target (in GB). Default: None (=Planner '
'class default is used). Changing this will '
'affect patch and batch size and will '
'definitely affect your model\'s performance! Only use this if you really know what you '
'are doing and NEVER use this without running the default nnU-Net first as a baseline.')
parser.add_argument('-preprocessor_name', default='DefaultPreprocessor', type=str, required=False,
help='[OPTIONAL] DANGER ZONE! Sets a custom preprocessor class. This class must be located in '
'nnunetv2.preprocessing. Default: \'DefaultPreprocessor\'. Changing this may affect your '
'model\'s performance! Only use this if you really know what you '
'are doing and NEVER use this without running the default nnU-Net first (as a baseline).')
parser.add_argument('-overwrite_target_spacing', default=None, nargs='+', required=False,
help='[OPTIONAL] DANGER ZONE! Sets a custom target spacing for the 3d_fullres and 3d_cascade_fullres '
'configurations. Default: None [no changes]. Changing this will affect image size and '
'potentially patch and batch '
'size. This will definitely affect your model\'s performance! Only use this if you really '
'know what you are doing and NEVER use this without running the default nnU-Net first '
'(as a baseline). Changing the target spacing for the other configurations is currently '
'not implemented. New target spacing must be a list of three numbers!')
parser.add_argument('-overwrite_plans_name', default=None, required=False,
help='[OPTIONAL] Use a custom plans identifier. If you used -gpu_memory_target, '
'-preprocessor_name or '
'-overwrite_target_spacing it is best practice to use -overwrite_plans_name to generate a '
'differently named plans file such that the nnunet default plans are not '
'overwritten. You will then need to specify your custom plans file with -p whenever '
'running other nnunet commands (training, inference etc)')
parser.add_argument('-c', required=False, default=['2d', '3d_fullres', '3d_lowres'], nargs='+',
help='[OPTIONAL] Configurations for which the preprocessing should be run. Default: 2d 3d_fullres '
'3d_lowres. 3d_cascade_fullres does not need to be specified because it uses the data '
'from 3d_fullres. Configurations that do not exist for some dataset will be skipped.')
parser.add_argument('-np', type=int, nargs='+', default=None, required=False,
help="[OPTIONAL] Use this to define how many processes are to be used. If this is just one number then "
"this number of processes is used for all configurations specified with -c. If it's a "
"list of numbers this list must have as many elements as there are configurations. We "
"then iterate over zip(configs, num_processes) to determine the number of processes "
"used for each configuration. More processes are generally faster (up to the number of "
"threads your PC can support, so 8 for a 4 core CPU with hyperthreading. If you don't "
"know what that is then don't touch it, or at least don't increase it!). DANGER: More "
"often than not the number of processes that can be used is limited by the amount of "
"RAM available. Image resampling takes up a lot of RAM. MONITOR RAM USAGE AND "
"DECREASE -np IF YOUR RAM FILLS UP TOO MUCH! Default: 8 processes for 2d, 4 "
"for 3d_fullres, 8 for 3d_lowres and 4 for everything else")
_add_logging_args(parser)
args = parser.parse_args()
# fingerprint extraction
print("Fingerprint extraction...")
extract_fingerprints(args.d, args.fpe, args.npfp, args.verify_dataset_integrity, args.clean, args.verbose,
show_progress_bar=not args.no_pbar)
# experiment planning
print('Experiment planning...')
plans_identifier = plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name,
args.overwrite_target_spacing, args.overwrite_plans_name)
# manage default np
if args.np is None:
default_np = {"2d": 8, "3d_fullres": 4, "3d_lowres": 8}
np = [default_np[c] if c in default_np.keys() else 4 for c in args.c]
else:
np = args.np
# preprocessing
if not args.no_pp:
print('Preprocessing...')
preprocess(args.d, plans_identifier, args.c, np, args.verbose, show_progress_bar=not args.no_pbar)
if __name__ == '__main__':
plan_and_preprocess_entry()
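When `-np` is omitted, `preprocess_entry` and `plan_and_preprocess_entry` fall back to a per-configuration default (8 for 2d, 4 for 3d_fullres, 8 for 3d_lowres, 4 for anything else). A minimal sketch of that resolution with a reduced argparse parser follows; `build_parser`, `resolve_num_processes` and `DEFAULT_NP` are hypothetical names, and the parser only mirrors the `-c` and `-np` options of the real entry points.

```python
import argparse
from typing import List, Optional, Sequence

# Defaults taken from the -np help text; any other configuration falls back to 4.
DEFAULT_NP = {"2d": 8, "3d_fullres": 4, "3d_lowres": 8}
FALLBACK_NP = 4


def resolve_num_processes(configs: Sequence[str], np_arg: Optional[List[int]]) -> List[int]:
    """Return the process counts the entry points would pass on for each configuration."""
    if np_arg is None:
        return [DEFAULT_NP.get(c, FALLBACK_NP) for c in configs]
    return list(np_arg)


def build_parser() -> argparse.ArgumentParser:
    # Only the two options relevant here; the real entry points define many more.
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', nargs='+', default=['2d', '3d_fullres', '3d_lowres'])
    parser.add_argument('-np', type=int, nargs='+', default=None)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(['-c', '2d', '3d_fullres', '3d_cascade_fullres'])
    print(resolve_num_processes(args.c, args.np))  # [8, 4, 4]
```

A single explicit `-np` value stays a one-element list at this stage; broadcasting it to all configurations happens later, inside `preprocess_dataset`.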
================================================
FILE: nnunetv2/experiment_planning/plans_for_pretraining/__init__.py
================================================
================================================
FILE: nnunetv2/experiment_planning/plans_for_pretraining/move_plans_between_datasets.py
================================================
import argparse
from typing import Union
from batchgenerators.utilities.file_and_folder_operations import join, isdir, isfile, load_json, save_json
from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json
from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw
from nnunetv2.utilities.file_path_utilities import maybe_convert_to_dataset_name
from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
def move_plans_between_datasets(
source_dataset_name_or_id: Union[int, str],
target_dataset_name_or_id: Union[int, str],
source_plans_identifier: str,
target_plans_identifier: str = None):
source_dataset_name = maybe_convert_to_dataset_name(source_dataset_name_or_id)
target_dataset_name = maybe_convert_to_dataset_name(target_dataset_name_or_id)
if target_plans_identifier is None:
target_plans_identifier = source_plans_identifier
source_folder = join(nnUNet_preprocessed, source_dataset_name)
assert isdir(source_folder), f"Cannot move plans because preprocessed directory of source dataset is missing. " \
f"Run nnUNetv2_plan_and_preprocess for source dataset first!"
source_plans_file = join(source_folder, source_plans_identifier + '.json')
assert isfile(source_plans_file), f"Source plans are missing. Run the corresponding experiment planning first! " \
f"Expected file: {source_plans_file}"
source_plans = load_json(source_plans_file)
source_plans['dataset_name'] = target_dataset_name
# we need to change data_identifier to use target_plans_identifier
if target_plans_identifier != source_plans_identifier:
for c in source_plans['configurations'].keys():
if 'data_identifier' in source_plans['configurations'][c].keys():
old_identifier = source_plans['configurations'][c]["data_identifier"]
if old_identifier.startswith(source_plans_identifier):
new_identifier = target_plans_identifier + old_identifier[len(source_plans_identifier):]
else:
new_identifier = target_plans_identifier + '_' + old_identifier
source_plans['configurations'][c]["data_identifier"] = new_identifier
# we need to change the reader writer class!
target_raw_data_dir = join(nnUNet_raw, target_dataset_name)
target_dataset_json = load_json(join(target_raw_data_dir, 'dataset.json'))
# we may need to change the reader/writer
# pick any file from the target dataset
dataset = get_filenames_of_train_images_and_targets(target_raw_data_dir, target_dataset_json)
example_image = dataset[next(iter(dataset))]['images'][0]
rw = determine_reader_writer_from_dataset_json(target_dataset_json, example_image, allow_nonmatching_filename=True,
verbose=False)
source_plans["image_reader_writer"] = rw.__name__
if target_plans_identifier is not None:
source_plans["plans_name"] = target_plans_identifier
save_json(source_plans, join(nnUNet_preprocessed, target_dataset_name, target_plans_identifier + '.json'),
sort_keys=False)
def entry_point_move_plans_between_datasets():
parser = argparse.ArgumentParser()
parser.add_argument('-s', type=str, required=True,
help='Source dataset name or id')
parser.add_argument('-t', type=str, required=True,
help='Target dataset name or id')
parser.add_argument('-sp', type=str, required=True,
help='Source plans identifier. If your plans are named "nnUNetPlans.json" then the '
'identifier would be nnUNetPlans')
parser.add_argument('-tp', type=str, required=False, default=None,
help='Target plans identifier. Default is None meaning the source plans identifier will '
'be kept. Not recommended if the source plans identifier is a default nnU-Net identifier '
'such as nnUNetPlans!!!')
args = parser.parse_args()
move_plans_between_datasets(args.s, args.t, args.sp, args.tp)
if __name__ == '__main__':
entry_point_move_plans_between_datasets()
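When the target plans identifier differs from the source, `move_plans_between_datasets` rewrites each configuration's `data_identifier`: a matching prefix is swapped, otherwise the target identifier is prepended. Since preprocessed data is stored in folders named after the `data_identifier`, this presumably keeps data for the transferred plans from colliding with data produced by the target dataset's own plans. Below is a standalone sketch of just the renaming rule; `rename_data_identifiers` is a hypothetical helper operating on a plain dict, not nnU-Net API.

```python
from typing import Dict


def rename_data_identifiers(configurations: Dict[str, dict], source_plans_identifier: str,
                            target_plans_identifier: str) -> Dict[str, dict]:
    """Hypothetical helper reproducing the data_identifier rewrite on a copy of the configurations."""
    renamed = {c: dict(cfg) for c, cfg in configurations.items()}  # leave the input untouched
    if target_plans_identifier == source_plans_identifier:
        return renamed
    for cfg in renamed.values():
        if 'data_identifier' not in cfg:
            continue  # configurations without their own data_identifier are left as-is
        old = cfg['data_identifier']
        if old.startswith(source_plans_identifier):
            # swap the prefix, keep the configuration suffix (e.g. '_3d_fullres')
            cfg['data_identifier'] = target_plans_identifier + old[len(source_plans_identifier):]
        else:
            cfg['data_identifier'] = target_plans_identifier + '_' + old
    return renamed


if __name__ == "__main__":
    configs = {'3d_fullres': {'data_identifier': 'nnUNetPlans_3d_fullres'},
               'custom': {'data_identifier': 'myData'}}
    print(rename_data_identifiers(configs, 'nnUNetPlans', 'nnUNetPlansFrom2'))
```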
================================================
FILE: nnunetv2/experiment_planning/verify_dataset_integrity.py
================================================
# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
# (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
from typing import Type
import numpy as np
import pandas as pd
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json
from nnunetv2.paths import nnUNet_raw
from nnunetv2.utilities.label_handling.label_handling import LabelManager
from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
def verify_labels(label_file: str, readerclass: Type[BaseReaderWriter], expected_labels: List[int]) -> bool:
rw = readerclass()
seg, properties = rw.read_seg(label_file)
found_labels = np.sort(pd.unique(seg.ravel())) # np.unique(seg)
unexpected_labels = [i for i in found_labels if i not in expected_labels]
if len(found_labels) == 1 and found_labels[0] == 0:
print('WARNING: File %s only has label 0 (which should be background). This may be intentional or not, '
'up to you.' % label_file)
if len(unexpected_labels) > 0:
print("Error: Unexpected labels found in file %s.\nExpected: %s\nFound: %s" % (label_file, expected_labels,
found_labels))
return False
return True
def check_cases(image_files: List[str], label_file: str, expected_num_channels: int,
readerclass: Type[BaseReaderWriter]) -> bool:
rw = readerclass()
ret = True
images, properties_image = rw.read_images(image_files)
segmentation, properties_seg = rw.read_seg(label_file)
# check for nans
if np.any(np.isnan(images)):
print(f'Images contain NaN pixel values. You need to fix that by '
f'replacing NaN values with something that makes sense for your images!\nImages:\n{image_files}')
ret = False
if np.any(np.isnan(segmentation)):
print(f'Segmentation contains NaN pixel values. You need to fix that.\nSegmentation:\n{label_file}')
ret = False
# check shapes
shape_image = images.shape[1:]
shape_seg = segmentation.shape[1:]
if shape_image != shape_seg:
print('Error: Shape mismatch between segmentation and corresponding images. \nShape images: %s. '
'\nShape seg: %s. \nImage files: %s. \nSeg file: %s\n' %
(shape_image, shape_seg, image_files, label_file))
ret = False
# check spacings
spacing_images = properties_image['spacing']
spacing_seg = properties_seg['spacing']
if not np.allclose(spacing_seg, spacing_images):
print('Error: Spacing mismatch between segmentation and corresponding images. \nSpacing images: %s. '
'\nSpacing seg: %s. \nImage files: %s. \nSeg file: %s\n' %
(spacing_images, spacing_seg, image_files, label_file))
ret = False
# check modalities
if not len(images) == expected_num_channels:
print('Error: Unexpected number of modalities. \nExpected: %d. \nGot: %d. \nImages: %s\n'
% (expected_num_channels, len(images), image_files))
ret = False
# nibabel checks
if 'nibabel_stuff' in properties_image.keys():
# this image was read with NibabelIO
affine_image = properties_image['nibabel_stuff']['original_affine']
affine_seg = properties_seg['nibabel_stuff']['original_affine']
if not np.allclose(affine_image, affine_seg):
print('WARNING: Affine is not the same for image and seg! \nAffine image: %s \nAffine seg: %s\n'
'Image files: %s. \nSeg file: %s.\nThis can be a problem but doesn\'t have to be. Please run '
'nnUNetv2_plot_overlay_pngs to verify if everything is OK!\n'
% (affine_image, affine_seg, image_files, label_file))
# sitk checks
if 'sitk_stuff' in properties_image.keys():
# this image was read with SimpleITKIO
# spacing has already been checked, only check direction and origin
origin_image = properties_image['sitk_stuff']['origin']
origin_seg = properties_seg['sitk_stuff']['origin']
if not np.allclose(origin_image, origin_seg):
print('Warning: Origin mismatch between segmentation and corresponding images. \nOrigin images: %s. '
'\nOrigin seg: %s. \nImage files: %s. \nSeg file: %s\n' %
(origin_image, origin_seg, image_files, label_file))
direction_image = properties_image['sitk_stuff']['direction']
direction_seg = properties_seg['sitk_stuff']['direction']
if not np.allclose(direction_image, direction_seg):
print('Warning: Direction mismatch between segmentation and corresponding images. \nDirection images: %s. '
'\nDirection seg: %s. \nImage files: %s. \nSeg file: %s\n' %
(direction_image, direction_seg, image_files, label_file))
return ret
def verify_dataset_integrity(folder: str, num_processes: int = 8) -> None:
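    """
    Verifies the integrity of a raw nnU-Net dataset.

    folder needs a dataset.json. Unless dataset.json contains a 'dataset' key that references the files explicitly,
    folder also needs the imagesTr and labelsTr subfolders.
    Checks whether the expected number of training cases and labels are present, checks for each case (if possible)
    whether the pixel grids of images and labels are aligned, and checks whether the labels only contain the values
    they should.
    :param folder: path to the raw dataset folder
    :param num_processes: number of worker processes used for the per-case checks
    :return: None. Raises an error if a critical problem is found; non-critical issues are printed as warnings.
    """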
assert isfile(join(folder, "dataset.json")), f"There needs to be a dataset.json file in folder, folder={folder}"
dataset_json = load_json(join(folder, "dataset.json"))
    if 'dataset' not in dataset_json.keys():
        assert isdir(join(folder, "imagesTr")), f"There needs to be an imagesTr subfolder in folder, folder={folder}"
        assert isdir(join(folder, "labelsTr")), f"There needs to be a labelsTr subfolder in folder, folder={folder}"
# make sure all required keys are there
dataset_keys = list(dataset_json.keys())
required_keys = ['labels', "channel_names", "numTraining", "file_ending"]
assert all([i in dataset_keys for i in required_keys]), 'not all required keys are present in dataset.json.' \
'\n\nRequired: \n%s\n\nPresent: \n%s\n\nMissing: ' \
'\n%s\n\nUnused by nnU-Net:\n%s' % \
(str(required_keys),
str(dataset_keys),
str([i for i in required_keys if i not in dataset_keys]),
str([i for i in dataset_keys if i not in required_keys]))
expected_num_training = dataset_json['numTraining']
num_modalities = len(dataset_json['channel_names'].keys()
if 'channel_names' in dataset_json.keys()
else dataset_json['modality'].keys())
file_ending = dataset_json['file_ending']
dataset = get_filenames_of_train_images_and_targets(folder, dataset_json)
# check if the right number of training cases is present
assert len(dataset) == expected_num_training, 'Did not find the expected number of training cases ' \
'(%d). Found %d instead.\nExamples: %s' % \
(expected_num_training, len(dataset),
list(dataset.keys())[:5])
# check if corresponding labels are present
if 'dataset' in dataset_json.keys():
# just check if everything is there
ok = True
missing_images = []
missing_labels = []
for k in dataset:
for i in dataset[k]['images']:
if not isfile(i):
missing_images.append(i)
ok = False
if not isfile(dataset[k]['label']):
missing_labels.append(dataset[k]['label'])
ok = False
if not ok:
raise FileNotFoundError(f"Some expected files were missing. Make sure you are properly referencing them "
f"in the dataset.json. Or use imagesTr & labelsTr folders!\nMissing images:"
f"\n{missing_images}\n\nMissing labels:\n{missing_labels}")
else:
        # old code path that uses the imagesTr and labelsTr folders
labelfiles = subfiles(join(folder, 'labelsTr'), suffix=file_ending, join=False)
label_identifiers = [i[:-len(file_ending)] for i in labelfiles]
labels_present = [i in label_identifiers for i in dataset.keys()]
missing = [i for j, i in enumerate(dataset.keys()) if not labels_present[j]]
assert all(labels_present), f'not all training cases have a label file in labelsTr. Fix that. Missing: {missing}'
labelfiles = [v['label'] for v in dataset.values()]
image_files = [v['images'] for v in dataset.values()]
# no plans exist yet, so we can't use PlansManager and gotta roll with the default. It's unlikely to cause
# problems anyway
label_manager = LabelManager(dataset_json['labels'], regions_class_order=dataset_json.get('regions_class_order'))
    # copy so that appending the ignore label does not modify the LabelManager's internal list
    expected_labels = list(label_manager.all_labels)
    if label_manager.has_ignore_label:
        expected_labels.append(label_manager.ignore_label)
labels_valid_consecutive = np.ediff1d(expected_labels) == 1
assert all(
labels_valid_consecutive), f'Labels must be in consecutive order (0, 1, 2, ...). The labels {np.array(expected_labels)[1:][~labels_valid_consecutive]} do not satisfy this restriction'
    # determine reader/writer class, using the first image of the first training case as example file
    reader_writer_class = determine_reader_writer_from_dataset_json(dataset_json, next(iter(dataset.values()))['images'][0])
# check whether only the desired labels are present
with multiprocessing.get_context("spawn").Pool(num_processes) as p:
result = p.starmap(
verify_labels,
zip(labelfiles, [reader_writer_class] * len(labelfiles), [expected_labels] * len(labelfiles))
)
if not all(result):
raise RuntimeError(
'Some segmentation images contained unexpected labels. Please check text output above to see which one(s).')
# check whether shapes and spacings match between images and labels
result = p.starmap(
check_cases,
zip(image_files, labelfiles, [num_modalities] * expected_num_training,
[reader_writer_class] * expected_num_training)
)
if not all(result):
raise RuntimeError(
'Some images have errors. Please check text output above to see which one(s) and what\'s going on.')
# check for nans
# check all same orientation nibabel
print('\n####################')
print('verify_dataset_integrity Done. \nIf you didn\'t see any error messages then your dataset is most likely OK!')
print('####################\n')
if __name__ == "__main__":
# investigate geometry issues
example_folder = join(nnUNet_raw, 'Dataset250_COMPUTING_it0')
num_processes = 6
verify_dataset_integrity(example_folder, num_processes)
================================================
FILE: nnunetv2/imageio/__init__.py
================================================
================================================
FILE: nnunetv2/imageio/base_reader_writer.py
================================================
# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
# (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Tuple, Union, List
import numpy as np
class BaseReaderWriter(ABC):
@staticmethod
def _check_all_same(input_list):
if len(input_list) == 1:
return True
else:
# compare all entries to the first
return np.allclose(input_list[0], input_list[1:])
@staticmethod
def _check_all_same_array(input_list):
# compare all entries to the first
for i in input_list[1:]:
if i.shape != input_list[0].shape or not np.allclose(i, input_list[0]):
return False
return True
@abstractmethod
def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
"""
Reads a sequence of images and returns a 4d (!) np.ndarray along with a dictionary. The 4d array must have the
modalities (or color channels, or however you would like to call them) in its first axis, followed by the
spatial dimensions (so shape must be c,x,y,z where c is the number of modalities (can be 1)).
Use the dictionary to store necessary meta information that is lost when converting to numpy arrays, for
example the Spacing, Orientation and Direction of the image. This dictionary will be handed over to write_seg
for exporting the predicted segmentations, so make sure you have everything you need in there!
IMPORTANT: dict MUST have a 'spacing' key with a tuple/list of length 3 with the voxel spacing of the np.ndarray.
Example: my_dict = {'spacing': (3, 0.5, 0.5), ...}. This is needed for planning and
preprocessing. The ordering of the numbers must correspond to the axis ordering in the returned numpy array. So
if the array has shape c,x,y,z and the spacing is (a,b,c) then a must be the spacing of x, b the spacing of y
and c the spacing of z.
In the case of 2D images, the returned array should have shape (c, 1, x, y) and the spacing should be
(999, sp_x, sp_y). Make sure 999 is larger than sp_x and sp_y! Example: shape=(3, 1, 224, 224),
spacing=(999, 1, 1)
For images that don't have a spacing, set the spacing to 1 (2d exception with 999 for the first axis still applies!)
:param image_fnames:
:return:
1) a np.ndarray of shape (c, x, y, z) where c is the number of image channels (can be 1) and x, y, z are
the spatial dimensions (set x=1 for 2D! Example: (3, 1, 224, 224) for RGB image).
2) a dictionary with metadata. This can be anything. BUT it HAS to include a {'spacing': (a, b, c)} where a
is the spacing of x, b of y and c of z! If an image doesn't have spacing, just set this to 1. For 2D, set
a=999 (largest spacing value! Make it larger than b and c)
"""
pass
@abstractmethod
def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
"""
        Same requirements as BaseReaderWriter.read_images. Returned segmentations must have shape 1,x,y,z. Multiple
        segmentations are not (yet?) allowed.
        If images and segmentations can be read the same way you can just `return self.read_images((seg_fname,))`
:param seg_fname:
:return:
1) a np.ndarray of shape (1, x, y, z) where x, y, z are
the spatial dimensions (set x=1 for 2D! Example: (1, 1, 224, 224) for 2D segmentation).
2) a dictionary with metadata. This can be anything. BUT it HAS to include a {'spacing': (a, b, c)} where a
is the spacing of x, b of y and c of z! If an image doesn't have spacing, just set this to 1. For 2D, set
a=999 (largest spacing value! Make it larger than b and c)
"""
pass
@abstractmethod
def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
"""
Export the predicted segmentation to the desired file format. The given seg array will have the same shape and
orientation as the corresponding image data, so you don't need to do any resampling or whatever. Just save :-)
properties is the same dictionary you created during read_images/read_seg so you can use the information here
to restore metadata
IMPORTANT: Segmentations are always 3D! If your input images were 2d then the segmentation will have shape
1,x,y. You need to catch that and export accordingly (for 2d images you need to convert the 3d segmentation
to 2d via seg = seg[0])!
        :param seg: A segmentation (np.ndarray, integer) of shape (x, y, z). For 2D segmentations this will be (1, x, y)!
:param output_fname:
:param properties: the dictionary that you created in read_images (the ones this segmentation is based on).
Use this to restore metadata
:return:
"""
pass
================================================
FILE: nnunetv2/imageio/natural_image_reader_writer.py
================================================
# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
# (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union, List
import numpy as np
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from skimage import io
class NaturalImage2DIO(BaseReaderWriter):
"""
ONLY SUPPORTS 2D IMAGES!!!
"""
# there are surely more we could add here. Everything that can be read by skimage.io should be supported
supported_file_endings = [
'.png',
# '.jpg',
# '.jpeg', # jpg not supported because we cannot allow lossy compression! segmentation maps!
'.bmp',
'.tif'
]
def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
images = []
for f in image_fnames:
npy_img = io.imread(f)
if npy_img.ndim == 3:
# rgb image, last dimension should be the color channel and the size of that channel should be 3
# (or 4 if we have alpha)
                assert npy_img.shape[-1] == 3 or npy_img.shape[-1] == 4, "If image has three dimensions then the last " \
                                                                         "dimension must have size 3 or 4 " \
                                                                         f"(RGB or RGBA). Image shape here is {npy_img.shape}"
# move RGB(A) to front, add additional dim so that we have shape (c, 1, X, Y), where c is either 3 or 4
images.append(npy_img.transpose((2, 0, 1))[:, None])
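            elif npy_img.ndim == 2:
                # grayscale image
                images.append(npy_img[None, None])
            else:
                raise RuntimeError(f"Unexpected number of dimensions ({npy_img.ndim}) in file {f}. NaturalImage2DIO "
                                   f"only supports 2D grayscale, RGB or RGBA images.")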
if not self._check_all_same([i.shape for i in images]):
print('ERROR! Not all input images have the same shape!')
print('Shapes:')
print([i.shape for i in images])
print('Image files:')
print(image_fnames)
raise RuntimeError()
return np.vstack(images, dtype=np.float32, casting='unsafe'), {'spacing': (999, 1, 1)}
def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
return self.read_images((seg_fname, ))
def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
io.imsave(output_fname, seg[0].astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False), check_contrast=False)
if __name__ == '__main__':
images = ('/media/fabian/data/nnUNet_raw/Dataset120_RoadSegmentation/imagesTr/img-11_0000.png',)
segmentation = '/media/fabian/data/nnUNet_raw/Dataset120_RoadSegmentation/labelsTr/img-11.png'
imgio = NaturalImage2DIO()
img, props = imgio.read_images(images)
seg, segprops = imgio.read_seg(segmentation)
================================================
FILE: nnunetv2/imageio/nibabel_reader_writer.py
================================================
# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
# (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Tuple, Union, List
import numpy as np
from nibabel.orientations import io_orientation, axcodes2ornt, ornt_transform
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
import nibabel
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
class NibabelIO(BaseReaderWriter):
"""
Nibabel loads the images in a different order than sitk. We convert the axes to the sitk order to be
consistent. This is of course considered properly in segmentation export as well.
IMPORTANT: Run nnUNetv2_plot_overlay_pngs to verify that this did not destroy the alignment of data and seg!
"""
supported_file_endings = [
'.nii',
'.nii.gz',
]
def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
images = []
original_affines = []
spacings_for_nnunet = []
for f in image_fnames:
nib_image = nibabel.load(f)
assert nib_image.ndim == 3, 'only 3d images are supported by NibabelIO'
original_affine = nib_image.affine
original_affines.append(original_affine)
# spacing is taken in reverse order to be consistent with SimpleITK axis ordering (confusing, I know...)
spacings_for_nnunet.append(
[float(i) for i in nib_image.header.get_zooms()[::-1]]
)
            # transpose image to be consistent with the way SimpleITK reads images. Yeah. Annoying.
images.append(nib_image.get_fdata().transpose((2, 1, 0))[None])
if not self._check_all_same([i.shape for i in images]):
print('ERROR! Not all input images have the same shape!')
print('Shapes:')
print([i.shape for i in images])
print('Image files:')
print(image_fnames)
raise RuntimeError()
if not self._check_all_same_array(original_affines):
print('WARNING! Not all input images have the same original_affines!')
print('Affines:')
print(original_affines)
print('Image files:')
print(image_fnames)
print(
'It is up to you to decide whether that\'s a problem. You should run nnUNetv2_plot_overlay_pngs to verify '
'that segmentations and data overlap.')
if not self._check_all_same(spacings_for_nnunet):
print('ERROR! Not all input images have the same spacing_for_nnunet! This might be caused by them not '
'having the same affine')
print('spacings_for_nnunet:')
print(spacings_for_nnunet)
print('Image files:')
print(image_fnames)
raise RuntimeError()
dict = {
'nibabel_stuff': {
'original_affine': original_affines[0],
},
'spacing': spacings_for_nnunet[0]
}
return np.vstack(images, dtype=np.float32, casting='unsafe'), dict
def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
return self.read_images((seg_fname,))
def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
# revert transpose
seg = seg.transpose((2, 1, 0)).astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False)
seg_nib = nibabel.Nifti1Image(seg, affine=properties['nibabel_stuff']['original_affine'])
nibabel.save(seg_nib, output_fname)
class NibabelIOWithReorient(BaseReaderWriter):
"""
Reorients images to RAS
Nibabel loads the images in a different order than sitk. We convert the axes to the sitk order to be
consistent. This is of course considered properly in segmentation export as well.
IMPORTANT: Run nnUNetv2_plot_overlay_pngs to verify that this did not destroy the alignment of data and seg!
"""
supported_file_endings = [
'.nii',
'.nii.gz',
]
def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
images = []
original_affines = []
reoriented_affines = []
spacings_for_nnunet = []
for f in image_fnames:
nib_image = nibabel.load(f)
            assert nib_image.ndim == 3, 'only 3d images are supported by NibabelIOWithReorient'
original_affine = nib_image.affine
reoriented_image = nib_image.as_reoriented(io_orientation(original_affine))
reoriented_affine = reoriented_image.affine
original_affines.append(original_affine)
reoriented_affines.append(reoriented_affine)
# spacing is taken in reverse order to be consistent with SimpleITK axis ordering (confusing, I know...)
spacings_for_nnunet.append(
[float(i) for i in reoriented_image.header.get_zooms()[::-1]]
)
            # transpose image to be consistent with the way SimpleITK reads images. Yeah. Annoying.
images.append(reoriented_image.get_fdata().transpose((2, 1, 0))[None])
if not self._check_all_same([i.shape for i in images]):
print('ERROR! Not all input images have the same shape!')
print('Shapes:')
print([i.shape for i in images])
print('Image files:')
print(image_fnames)
raise RuntimeError()
if not self._check_all_same_array(reoriented_affines):
print('WARNING! Not all input images have the same reoriented_affines!')
print('Affines:')
print(reoriented_affines)
print('Image files:')
print(image_fnames)
print(
'It is up to you to decide whether that\'s a problem. You should run nnUNetv2_plot_overlay_pngs to verify '
'that segmentations and data overlap.')
if not self._check_all_same(spacings_for_nnunet):
print('ERROR! Not all input images have the same spacing_for_nnunet! This might be caused by them not '
'having the same affine')
print('spacings_for_nnunet:')
print(spacings_for_nnunet)
print('Image files:')
print(image_fnames)
raise RuntimeError()
dict = {
'nibabel_stuff': {
'original_affine': original_affines[0],
'reoriented_affine': reoriented_affines[0],
},
'spacing': spacings_for_nnunet[0]
}
return np.vstack(images, dtype=np.float32, casting='unsafe'), dict
def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
return self.read_images((seg_fname,))
def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
# revert transpose
seg = seg.transpose((2, 1, 0)).astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False)
seg_nib = nibabel.Nifti1Image(seg, affine=properties['nibabel_stuff']['reoriented_affine'])
# Solution from https://github.com/nipy/nibabel/issues/1063#issuecomment-967124057
img_ornt = io_orientation(properties['nibabel_stuff']['original_affine'])
ras_ornt = axcodes2ornt("RAS")
from_canonical = ornt_transform(ras_ornt, img_ornt)
seg_nib_reoriented = seg_nib.as_reoriented(from_canonical)
if not np.allclose(properties['nibabel_stuff']['original_affine'], seg_nib_reoriented.affine):
print(f'WARNING: Restored affine does not match original affine. File: {output_fname}')
            print('Original affine\n', properties['nibabel_stuff']['original_affine'])
            print('Restored affine\n', seg_nib_reoriented.affine)
nibabel.save(seg_nib_reoriented, output_fname)
if __name__ == '__main__':
img_file = '/media/isensee/raw_data/nnUNet_raw/Dataset220_KiTS2023/imagesTr/case_00004_0000.nii.gz'
seg_file = '/media/isensee/raw_data/nnUNet_raw/Dataset220_KiTS2023/labelsTr/case_00004.nii.gz'
nibio = NibabelIO()
# images, dct = nibio.read_images([img_file])
seg, dctseg = nibio.read_seg(seg_file)
nibio_r = NibabelIOWithReorient()
# images_r, dct_r = nibio_r.read_images([img_file])
seg_r, dctseg_r = nibio_r.read_seg(seg_file)
sitkio = SimpleITKIO()
# images_sitk, dct_sitk = sitkio.read_images([img_file])
seg_sitk, dctseg_sitk = sitkio.read_seg(seg_file)
# write reoriented and original segmentation
nibio.write_seg(seg[0], '/home/isensee/seg_nibio.nii.gz', dctseg)
nibio_r.write_seg(seg_r[0], '/home/isensee/seg_nibio_r.nii.gz', dctseg_r)
sitkio.write_seg(seg_sitk[0], '/home/isensee/seg_nibio_sitk.nii.gz', dctseg_sitk)
    # now load all with sitk to make sure no shapes got messed up
a, d1 = sitkio.read_seg('/home/isensee/seg_nibio.nii.gz')
b, d2 = sitkio.read_seg('/home/isensee/seg_nibio_r.nii.gz')
c, d3 = sitkio.read_seg('/home/isensee/seg_nibio_sitk.nii.gz')
assert a.shape == b.shape
assert b.shape == c.shape
assert np.all(a == b)
assert np.all(b == c)
================================================
FILE: nnunetv2/imageio/reader_writer_registry.py
================================================
import traceback
from typing import Type
from batchgenerators.utilities.file_and_folder_operations import join
import nnunetv2
from nnunetv2.imageio.natural_image_reader_writer import NaturalImage2DIO
from nnunetv2.imageio.nibabel_reader_writer import NibabelIO, NibabelIOWithReorient
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
from nnunetv2.imageio.tif_reader_writer import Tiff3DIO
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
LIST_OF_IO_CLASSES = [
NaturalImage2DIO,
SimpleITKIO,
Tiff3DIO,
NibabelIO,
NibabelIOWithReorient
]
def determine_reader_writer_from_dataset_json(dataset_json_content: dict, example_file: str = None,
allow_nonmatching_filename: bool = False, verbose: bool = True
) -> Type[BaseReaderWriter]:
if 'overwrite_image_reader_writer' in dataset_json_content.keys() and \
dataset_json_content['overwrite_image_reader_writer'] != 'None':
ioclass_name = dataset_json_content['overwrite_image_reader_writer']
# trying to find that class in the nnunetv2.imageio module
try:
ret = recursive_find_reader_writer_by_name(ioclass_name)
if verbose: print(f'Using {ret} reader/writer')
return ret
except RuntimeError:
if verbose: print(f'Warning: Unable to find ioclass specified in dataset.json: {ioclass_name}')
if verbose: print('Trying to automatically determine desired class')
return determine_reader_writer_from_file_ending(dataset_json_content['file_ending'], example_file,
allow_nonmatching_filename, verbose)
def determine_reader_writer_from_file_ending(file_ending: str, example_file: str = None, allow_nonmatching_filename: bool = False,
verbose: bool = True):
for rw in LIST_OF_IO_CLASSES:
if file_ending.lower() in rw.supported_file_endings:
if example_file is not None:
# if an example file is provided, try if we can actually read it. If not move on to the next reader
try:
tmp = rw()
_ = tmp.read_images((example_file,))
if verbose: print(f'Using {rw} as reader/writer')
return rw
                except Exception:
                    if verbose: print(f'Failed to open file {example_file} with reader {rw}:')
                    if verbose: traceback.print_exc()
else:
if verbose: print(f'Using {rw} as reader/writer')
return rw
else:
if allow_nonmatching_filename and example_file is not None:
try:
tmp = rw()
_ = tmp.read_images((example_file,))
if verbose: print(f'Using {rw} as reader/writer')
return rw
                except Exception:
                    if verbose: print(f'Failed to open file {example_file} with reader {rw}:')
                    if verbose: traceback.print_exc()
raise RuntimeError(f"Unable to determine a reader for file ending {file_ending} and file {example_file} (file None means no file provided).")
def recursive_find_reader_writer_by_name(rw_class_name: str) -> Type[BaseReaderWriter]:
ret = recursive_find_python_class(join(nnunetv2.__path__[0], "imageio"), rw_class_name, 'nnunetv2.imageio')
if ret is None:
raise RuntimeError("Unable to find reader writer class '%s'. Please make sure this class is located in the "
"nnunetv2.imageio module." % rw_class_name)
else:
return ret
================================================
FILE: nnunetv2/imageio/readme.md
================================================
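- Derive your adapter from `BaseReaderWriter`.
- Implement all abstract methods (`read_images`, `read_seg` and `write_seg`).
- Define a `supported_file_endings` class attribute listing the file endings your adapter can read.
- Make sure to support both 2D and 3D input images (or raise an informative error).
- Place it in this folder (`nnunetv2/imageio`), otherwise nnU-Net won't find it!
- Add it to `LIST_OF_IO_CLASSES` in `reader_writer_registry.py` so that it can be selected automatically based on the
  file ending. Alternatively, reference it by class name via `overwrite_image_reader_writer` in the dataset.json.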
Bam, you're done!
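A minimal sketch of such an adapter, assuming a hypothetical `NpyIO` class for plain `.npy` arrays (2D or 3D, one channel per file) that carry no geometry information. Because there is no spacing to read, it sets the spacing to 1, and to 999 for the dummy first axis of 2D images, as required by `BaseReaderWriter.read_images`. The `try`/`except` import fallback and the `demo` function exist only so the sketch runs standalone; a real adapter would import `BaseReaderWriter` directly and live in this folder.

```python
import os
import tempfile
from typing import List, Tuple, Union

import numpy as np

try:
    from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
except ImportError:
    # fallback so that this sketch also runs without nnU-Net installed
    BaseReaderWriter = object


class NpyIO(BaseReaderWriter):
    """Hypothetical adapter for .npy files (2D or 3D arrays, one channel per file, no spacing information)."""
    supported_file_endings = ['.npy']

    def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
        images = []
        ndims = set()
        for f in image_fnames:
            arr = np.load(f)
            ndims.add(arr.ndim)
            if arr.ndim == 2:
                images.append(arr[None, None])  # 2D -> (1, 1, x, y)
            elif arr.ndim == 3:
                images.append(arr[None])  # 3D -> (1, x, y, z)
            else:
                raise RuntimeError(f'Unsupported number of dimensions ({arr.ndim}) in file {f}')
        # also compare ndims: a 3D array with x=1 would otherwise look like a 2D one
        if len(ndims) != 1 or len({i.shape for i in images}) != 1:
            raise RuntimeError(f'Not all input images have the same shape: {[i.shape for i in images]}')
        is_2d = ndims.pop() == 2
        # no spacing stored in .npy files: use 1, and 999 for the dummy first axis of 2D images
        spacing = (999, 1, 1) if is_2d else (1, 1, 1)
        data = np.concatenate(images, axis=0).astype(np.float32)  # channels stacked along axis 0
        return data, {'spacing': spacing, 'npy_stuff': {'is_2d': is_2d}}

    def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
        return self.read_images((seg_fname,))

    def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
        # seg is always 3D; for 2D images it arrives with shape (1, x, y)
        if properties['npy_stuff']['is_2d']:
            seg = seg[0]
        np.save(output_fname, seg.astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False))


def demo() -> Tuple[Tuple[int, ...], tuple, Tuple[int, ...]]:
    """Round trip: read the same 2D file as two channels, then write and reload a segmentation."""
    rw = NpyIO()
    with tempfile.TemporaryDirectory() as d:
        img_file = os.path.join(d, 'case_0000.npy')
        np.save(img_file, np.zeros((4, 5)))
        data, props = rw.read_images([img_file, img_file])
        seg_file = os.path.join(d, 'case_seg.npy')
        rw.write_seg(np.ones((1, 4, 5), dtype=np.int64), seg_file, props)
        reloaded = np.load(seg_file)
    return data.shape, props['spacing'], reloaded.shape
```

`write_seg` needs the `is_2d` flag stored in `properties` because the segmentation it receives is always 3D. `SimpleITKIO` follows the same idea and derives the output dimensionality from the stored sitk spacing.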
================================================
FILE: nnunetv2/imageio/simpleitk_reader_writer.py
================================================
# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
# (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union, List
import numpy as np
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
import SimpleITK as sitk
class SimpleITKIO(BaseReaderWriter):
supported_file_endings = [
'.nii.gz',
'.nrrd',
'.mha',
'.gipl'
]
def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
images = []
spacings = []
origins = []
directions = []
spacings_for_nnunet = []
for f in image_fnames:
itk_image = sitk.ReadImage(f)
spacings.append(itk_image.GetSpacing())
origins.append(itk_image.GetOrigin())
directions.append(itk_image.GetDirection())
npy_image = sitk.GetArrayFromImage(itk_image)
if npy_image.ndim == 2:
# 2d
npy_image = npy_image[None, None]
max_spacing = max(spacings[-1])
spacings_for_nnunet.append((max_spacing * 999, *list(spacings[-1])[::-1]))
elif npy_image.ndim == 3:
# 3d, as in original nnunet
npy_image = npy_image[None]
spacings_for_nnunet.append(list(spacings[-1])[::-1])
elif npy_image.ndim == 4:
# 4d, multiple modalities in one file
spacings_for_nnunet.append(list(spacings[-1])[::-1][1:])
pass
else:
raise RuntimeError(f"Unexpected number of dimensions: {npy_image.ndim} in file {f}")
images.append(npy_image)
spacings_for_nnunet[-1] = list(np.abs(spacings_for_nnunet[-1]))
if not self._check_all_same([i.shape for i in images]):
print('ERROR! Not all input images have the same shape!')
print('Shapes:')
print([i.shape for i in images])
print('Image files:')
print(image_fnames)
raise RuntimeError()
if not self._check_all_same(spacings):
print('ERROR! Not all input images have the same spacing!')
print('Spacings:')
print(spacings)
print('Image files:')
print(image_fnames)
raise RuntimeError()
if not self._check_all_same(origins):
print('WARNING! Not all input images have the same origin!')
print('Origins:')
print(origins)
print('Image files:')
print(image_fnames)
print('It is up to you to decide whether that\'s a problem. You should run nnUNetv2_plot_overlay_pngs to verify '
'that segmentations and data overlap.')
if not self._check_all_same(directions):
print('WARNING! Not all input images have the same direction!')
print('Directions:')
print(directions)
print('Image files:')
print(image_fnames)
print('It is up to you to decide whether that\'s a problem. You should run nnUNetv2_plot_overlay_pngs to verify '
'that segmentations and data overlap.')
if not self._check_all_same(spacings_for_nnunet):
            print('ERROR! Not all input images have the same spacing_for_nnunet! (This should not happen and must be a '
                  'bug. Please report!)')
print('spacings_for_nnunet:')
print(spacings_for_nnunet)
print('Image files:')
print(image_fnames)
raise RuntimeError()
dict = {
'sitk_stuff': {
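                # this saves the sitk geometry information. It is not used for planning or preprocessing, but it is
                # needed to restore the image geometry when exporting segmentations (write_seg)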
'spacing': spacings[0],
'origin': origins[0],
'direction': directions[0]
},
            # the spacing is inverted with [::-1] because sitk returns the spacing in x,y,z order while
            # GetArrayFromImage returns arrays with axes ordered z,y,x. The spacing must match the array axes.
'spacing': spacings_for_nnunet[0]
}
return np.vstack(images, dtype=np.float32, casting='unsafe'), dict
def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
return self.read_images((seg_fname, ))
def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
assert seg.ndim == 3, 'segmentation must be 3d. If you are exporting a 2d segmentation, please provide it as shape 1,x,y'
output_dimension = len(properties['sitk_stuff']['spacing'])
assert 1 < output_dimension < 4
if output_dimension == 2:
seg = seg[0]
itk_image = sitk.GetImageFromArray(seg.astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False))
itk_image.SetSpacing(properties['sitk_stuff']['spacing'])
itk_image.SetOrigin(properties['sitk_stuff']['origin'])
itk_image.SetDirection(properties['sitk_stuff']['direction'])
sitk.WriteImage(itk_image, output_fname, True)
class SimpleITKIOWithReorient(SimpleITKIO):
    def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]], orientation: str = "RAS") -> Tuple[np.ndarray, dict]:
images = []
spacings = []
origins = []
directions = []
spacings_for_nnunet = []
for f in image_fnames:
itk_image = sitk.ReadImage(f)
original_orientation = sitk.DICOMOrientImageFilter_GetOrientationFromDirectionCosines(itk_image.GetDirection())
itk_image = sitk.DICOMOrient(itk_image, orientation)
# print(sitk.DICOMOrientImageFilter_GetOrientationFromDirectionCosines(itk_image.GetDirection()))
spacings.append(itk_image.GetSpacing())
origins.append(itk_image.GetOrigin())
directions.append(itk_image.GetDirection())
npy_image = sitk.GetArrayFromImage(itk_image)
if npy_image.ndim == 2:
# 2d
npy_image = npy_image[None, None]
max_spacing = max(spacings[-1])
spacings_for_nnunet.append((max_spacing * 999, *list(spacings[-1])[::-1]))
elif npy_image.ndim == 3:
# 3d, as in original nnunet
npy_image = npy_image[None]
spacings_for_nnunet.append(list(spacings[-1])[::-1])
elif npy_image.ndim == 4:
# 4d, multiple modalities in one file
spacings_for_nnunet.append(list(spacings[-1])[::-1][1:])
pass
else:
raise RuntimeError(f"Unexpected number of dimensions: {npy_image.ndim} in file {f}")
images.append(npy_image)
spacings_for_nnunet[-1] = list(np.abs(spacings_for_nnunet[-1]))
if not self._check_all_same([i.shape for i in images]):
print('ERROR! Not all input images have the same shape!')
print('Shapes:')
print([i.shape for i in images])
print('Image files:')
print(image_fnames)
raise RuntimeError()
if not self._check_all_same(spacings):
print('ERROR! Not all input images have the same spacing!')
print('Spacings:')
print(spacings)
print('Image files:')
print(image_fnames)
raise RuntimeError()
if not self._check_all_same(origins):
print('WARNING! Not all input images have the same origin!')
print('Origins:')
print(origins)
print('Image files:')
print(image_fnames)
print('It is up to you to decide whether that\'s a problem. You should run nnUNetv2_plot_overlay_pngs to verify '
'that segmentations and data overlap.')
if not self._check_all_same(directions):
print('WARNING! Not all input images have the same direction!')
print('Directions:')
print(directions)
print('Image files:')
print(image_fnames)
print('It is up to you to decide whether that\'s a problem. You should run nnUNetv2_plot_overlay_pngs to verify '
'that segmentations and data overlap.')
if not self._check_all_same(spacings_for_nnunet):
print('ERROR! Not all input images have the same spacing_for_nnunet! (This should not happen and must be a '
'bug. Please report!)')
print('spacings_for_nnunet:')
print(spacings_for_nnunet)
print('Image files:')
print(image_fnames)
raise RuntimeError()
properties = {
'sitk_stuff': {
# this saves the sitk geometry information. It is NOT used by nnU-Net for training, only by write_seg to
# restore the original geometry and orientation
'spacing': spacings[0],
'origin': origins[0],
'direction': directions[0],
'original_orientation': original_orientation
},
# the spacing is inverted with [::-1] because sitk returns the spacing in x,y,z order while
# sitk.GetArrayFromImage returns arrays in z,y,x order
'spacing': spacings_for_nnunet[0]
}
return np.vstack(images, dtype=np.float32, casting='unsafe'), properties
def write_seg(self, seg, output_fname, properties):
assert seg.ndim == 3, 'segmentation must be 3d. If you are exporting a 2d segmentation, please provide it as shape 1,x,y'
output_dimension = len(properties['sitk_stuff']['spacing'])
assert 1 < output_dimension < 4
if output_dimension == 2:
seg = seg[0]
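# same dtype rule as SimpleITKIO.write_seg: an unconditional uint8 cast would wrap label values above 255
itk_image = sitk.GetImageFromArray(seg.astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False))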
itk_image.SetSpacing(properties['sitk_stuff']['spacing'])
itk_image.SetOrigin(properties['sitk_stuff']['origin'])
itk_image.SetDirection(properties['sitk_stuff']['direction'])
itk_image = sitk.DICOMOrient(itk_image, properties['sitk_stuff']['original_orientation'])
sitk.WriteImage(itk_image, output_fname, True)
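The spacing bookkeeping in `read_images` is easy to get wrong: SimpleITK reports spacing as (x, y, z) while the numpy array is (z, y, x), 2D images get a large dummy spacing for the missing axis, and 4D files drop the channel entry. The sketch below reproduces that logic without SimpleITK so it can be checked in isolation; `sitk_spacing_to_nnunet` is a hypothetical helper, not part of nnU-Net.

```python
from typing import List, Sequence


def sitk_spacing_to_nnunet(sitk_spacing: Sequence[float], array_ndim: int) -> List[float]:
    """Hypothetical helper mirroring how read_images derives spacings_for_nnunet."""
    if array_ndim == 2:
        # 2D image: prepend a very large dummy spacing for the missing out-of-plane axis
        spacing = [max(sitk_spacing) * 999, *list(sitk_spacing)[::-1]]
    elif array_ndim == 3:
        # sitk spacing is (x, y, z), the numpy array is (z, y, x)
        spacing = list(sitk_spacing)[::-1]
    elif array_ndim == 4:
        # 4D file: after reversing, the first entry belongs to the channel axis and is dropped
        spacing = list(sitk_spacing)[::-1][1:]
    else:
        raise RuntimeError(f"Unexpected number of dimensions: {array_ndim}")
    # read_images applies np.abs to the result, so do the same here
    return [abs(s) for s in spacing]


print(sitk_spacing_to_nnunet((0.8, 0.7, 3.0), 3))       # [3.0, 0.7, 0.8]
print(sitk_spacing_to_nnunet((0.5, 0.25), 2))           # [499.5, 0.25, 0.5]
print(sitk_spacing_to_nnunet((1.0, 2.0, 3.0, 1.0), 4))  # [3.0, 2.0, 1.0]
```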
================================================
FILE: nnunetv2/imageio/tif_reader_writer.py
================================================
# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
# (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
from typing import Tuple, Union, List
import numpy as np
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
import tifffile
from batchgenerators.utilities.file_and_folder_operations import isfile, load_json, save_json, split_path, join
class Tiff3DIO(BaseReaderWriter):
"""
reads and writes 3D tif(f) images. Uses tifffile package. Ignores metadata (for now)!
If you have 2D tiffs, use NaturalImage2DIO
Supports the use of auxiliary files for spacing information. If used, the auxiliary files are expected to end
with .json and omit the channel identifier. For example, the auxiliary file corresponding to image1_0000.tif is
expected to be image1.json!
"""
supported_file_endings = [
'.tif',
'.tiff',
]
def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
# figure out file ending used here
ending = '.' + image_fnames[0].split('.')[-1]
assert ending.lower() in self.supported_file_endings, f'Ending {ending} not supported by {self.__class__.__name__}'
ending_length = len(ending)
truncate_length = ending_length + 5 # 5 comes from len(_0000)
images = []
for f in image_fnames:
image = tifffile.imread(f)
if image.ndim != 3:
raise RuntimeError(f"Only 3D images are supported! File: {f}")
images.append(image[None])
# see if aux file can be found
expected_aux_file = image_fnames[0][:-truncate_length] + '.json'
if isfile(expected_aux_file):
spacing = load_json(expected_aux_file)['spacing']
assert len(spacing) == 3, f'spacing must have 3 entries, one for each dimension of the image. File: {expected_aux_file}'
else:
print(f'WARNING no spacing file found for images {image_fnames}\nAssuming spacing (1, 1, 1).')
spacing = (1, 1, 1)
if not self._check_all_same([i.shape for i in images]):
print('ERROR! Not all input images have the same shape!')
print('Shapes:')
print([i.shape for i in images])
print('Image files:')
print(image_fnames)
raise RuntimeError()
return np.vstack(images, dtype=np.float32, casting='unsafe'), {'spacing': spacing}
def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
# spacing/resolution is not stored in the tif itself (setting it properly is not straightforward), so it is written
# to an auxiliary .json file next to the segmentation instead
tifffile.imwrite(output_fname, data=seg.astype(np.uint8 if np.max(seg) < 255 else np.uint16, copy=False), compression='zlib')
file = os.path.basename(output_fname)
out_dir = os.path.dirname(output_fname)
ending = file.split('.')[-1]
save_json({'spacing': properties['spacing']}, join(out_dir, file[:-(len(ending) + 1)] + '.json'))
def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
# figure out file ending used here
ending = '.' + seg_fname.split('.')[-1]
assert ending.lower() in self.supported_file_endings, f'Ending {ending} not supported by {self.__class__.__name__}'
ending_length = len(ending)
seg = tifffile.imread(seg_fname)
if seg.ndim != 3:
raise RuntimeError(f"Only 3D images are supported! File: {seg_fname}")
seg = seg[None]
# see if aux file can be found
expected_aux_file = seg_fname[:-ending_length] + '.json'
if isfile(expected_aux_file):
spacing = load_json(expected_aux_file)['spacing']
assert len(spacing) == 3, f'spacing must have 3 entries, one for each dimension of the image. File: {expected_aux_file}'
assert all([i > 0 for i in spacing]), f"Spacing must be > 0, spacing: {spacing}"
else:
print(f'WARNING no spacing file found for segmentation {seg_fname}\nAssuming spacing (1, 1, 1).')
spacing = (1, 1, 1)
return seg.astype(np.float32, copy=False), {'spacing': spacing}
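`Tiff3DIO` locates spacing in an auxiliary `.json` file: for images the file ending and the 5-character `_XXXX` channel identifier are stripped, for segmentations only the ending is stripped, and a missing file falls back to spacing (1, 1, 1). A standalone sketch of those naming and fallback rules; `aux_json_for_image`, `aux_json_for_seg` and `load_spacing` are hypothetical helpers, not nnU-Net functions.

```python
import json
import os
import tempfile
from typing import Tuple


def aux_json_for_image(image_fname: str) -> str:
    # image1_0000.tif -> image1.json: drop the file ending and the 5-character '_XXXX' channel id
    ending = '.' + image_fname.split('.')[-1]
    return image_fname[:-(len(ending) + 5)] + '.json'


def aux_json_for_seg(seg_fname: str) -> str:
    # segmentations have no channel id: case1.tif -> case1.json
    ending = '.' + seg_fname.split('.')[-1]
    return seg_fname[:-len(ending)] + '.json'


def load_spacing(aux_file: str) -> Tuple[float, ...]:
    # fall back to (1, 1, 1) when no auxiliary file exists, like Tiff3DIO does
    if not os.path.isfile(aux_file):
        return (1, 1, 1)
    with open(aux_file) as f:
        spacing = json.load(f)['spacing']
    assert len(spacing) == 3, f'spacing must have 3 entries, got {spacing}'
    return tuple(spacing)


print(aux_json_for_image('/data/image1_0000.tif'))  # /data/image1.json
print(aux_json_for_seg('/data/image1.tiff'))        # /data/image1.json

with tempfile.TemporaryDirectory() as d:
    aux = os.path.join(d, 'image1.json')
    with open(aux, 'w') as f:
        json.dump({'spacing': [2.5, 0.8, 0.8]}, f)
    print(load_spacing(aux))                        # (2.5, 0.8, 0.8)
```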
================================================
FILE: nnunetv2/inference/JHU_inference.py
================================================
import argparse
import multiprocessing
import os
from time import sleep
from typing import Union
import numpy as np
import torch
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.utilities.file_and_folder_operations import load_json, save_pickle, join, maybe_mkdir_p, subdirs
from nnunetv2.configuration import default_num_processes
from nnunetv2.inference.export_prediction import convert_predicted_logits_to_segmentation_with_correct_shape
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.inference.sliding_window_prediction import compute_gaussian
from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy
from nnunetv2.utilities.helpers import empty_cache
from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager
def export_prediction_from_logits_singleFiles(
predicted_array_or_file: Union[np.ndarray, torch.Tensor],
properties_dict: dict,
configuration_manager: ConfigurationManager,
plans_manager: PlansManager,
dataset_json_dict_or_file: Union[dict, str],
output_file_truncated: str,
save_probabilities: bool = False):
"""
This function generates the output structure expected by the JHU benchmark. output_file_truncated is interpreted
as the output folder of the case: a 'predictions' subfolder is created inside it and populated with one binary
mask per foreground label, named after that label.
"""
if isinstance(dataset_json_dict_or_file, str):
dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)
label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)
ret = convert_predicted_logits_to_segmentation_with_correct_shape(
predicted_array_or_file, plans_manager, configuration_manager, label_manager, properties_dict,
return_probabilities=save_probabilities
)
del predicted_array_or_file
# save
if save_probabilities:
segmentation_final, probabilities_final = ret
np.savez_compressed(output_file_truncated + '.npz', probabilities=probabilities_final)
save_pickle(properties_dict, output_file_truncated + '.pkl')
del probabilities_final, ret
else:
segmentation_final = ret
del ret
rw = plans_manager.image_reader_writer_class()
output_folder = join(output_file_truncated, 'predictions')
maybe_mkdir_p(output_folder)
label_name_dict = {j: i for i, j in label_manager.label_dict.items()}
for l in label_manager.foreground_labels:
label_name = label_name_dict[l]
rw.write_seg(
(segmentation_final == l).astype(np.uint8, copy=False),
join(output_folder, label_name + dataset_json_dict_or_file['file_ending']),
properties_dict
)
class JHUPredictor(nnUNetPredictor):
def predict_from_data_iterator(self,
data_iterator,
save_probabilities: bool = False,
num_processes_segmentation_export: int = default_num_processes):
"""
We replace export_prediction_from_logits with export_prediction_from_logits_singleFiles to comply with JHU
benchmark output format expectations
"""
with multiprocessing.get_context("spawn").Pool(num_processes_segmentation_export) as export_pool:
worker_list = [i for i in export_pool._pool]
r = []
for preprocessed in data_iterator:
data = preprocessed['data']
if isinstance(data, str):
delfile = data
data = torch.from_numpy(np.load(data))
os.remove(delfile)
ofile = preprocessed['ofile']
if ofile is not None:
print(f'\nPredicting {os.path.basename(ofile)}:')
else:
print(f'\nPredicting image of shape {data.shape}:')
print(f'perform_everything_on_device: {self.perform_everything_on_device}')
properties = preprocessed['data_properties']
# let's not get into a runaway situation where the GPU predicts so fast that the disk gets swamped with
# npy files
proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
while not proceed:
sleep(0.1)
proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
prediction = self.predict_logits_from_preprocessed_data(data).cpu()
if ofile is not None:
# this needs to go into background processes
# export_prediction_from_logits(prediction, properties, self.configuration_manager, self.plans_manager,
# self.dataset_json, ofile, save_probabilities)
print('sending off prediction to background worker for resampling and export')
r.append(
export_pool.starmap_async(
export_prediction_from_logits_singleFiles,
((prediction, properties, self.configuration_manager, self.plans_manager,
self.dataset_json, ofile, save_probabilities),)
)
)
else:
# convert_predicted_logits_to_segmentation_with_correct_shape(
# prediction, self.plans_manager,
# self.configuration_manager, self.label_manager,
# properties,
# save_probabilities)
print('sending off prediction to background worker for resampling')
r.append(
export_pool.starmap_async(
convert_predicted_logits_to_segmentation_with_correct_shape, (
(prediction, self.plans_manager,
self.configuration_manager, self.label_manager,
properties,
save_probabilities),)
)
)
if ofile is not None:
print(f'done with {os.path.basename(ofile)}')
else:
print(f'\nDone with image of shape {data.shape}:')
ret = [i.get()[0] for i in r]
if isinstance(data_iterator, MultiThreadedAugmenter):
data_iterator._finish()
# clear lru cache
compute_gaussian.cache_clear()
# clear device cache
empty_cache(self.device)
return ret
if __name__ == '__main__':
# usage:
# python nnunetv2/inference/JHU_inference.py INPUT_DIR OUTPUT_DIR -model TRAINED_MODEL_FOLDER
# INPUT_DIR must contain one subfolder per case with a ct.nii.gz inside
# disable torch.compile for inference
os.environ['nnUNet_compile'] = 'f'
parser = argparse.ArgumentParser()
parser.add_argument('input_dir', type=str)
parser.add_argument('output_dir', type=str)
parser.add_argument('-model', required=True, type=str)
parser.add_argument('--disable_tqdm', required=False, action='store_true', default=False)
args = parser.parse_args()
predictor = JHUPredictor(
tile_step_size=0.5,
use_gaussian=True,
use_mirroring=True,
perform_everything_on_device=True,
device=torch.device('cuda', 0),
verbose=False,
verbose_preprocessing=False,
allow_tqdm=not args.disable_tqdm
)
predictor.initialize_from_trained_model_folder(
args.model,
('all', ),
'checkpoint_final.pth'
)
# we need to create a list of lists of input files (one inner list per case, one file per channel)
input_caseids = subdirs(args.input_dir, join=False)
input_files = [[join(args.input_dir, i, 'ct.nii.gz')] for i in input_caseids]
output_folders = [join(args.output_dir, i) for i in input_caseids]
predictor.predict_from_files(
input_files,
output_folders,
save_probabilities=False,
overwrite=True,
num_processes_preprocessing=2,
num_processes_segmentation_export=3,
folder_with_segs_from_prev_stage=None,
num_parts=1,
part_id=0
)
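`export_prediction_from_logits_singleFiles` writes one binary mask per foreground label instead of a single label map. The sketch below shows that splitting step with numpy only. It assumes plain integer labels (no region-based training) with 0 as background, and iterates the name-to-value dict directly, whereas the original inverts `label_dict` and loops over `foreground_labels`. `split_into_binary_masks` is a hypothetical helper.

```python
from typing import Dict

import numpy as np


def split_into_binary_masks(segmentation: np.ndarray, label_dict: Dict[str, int]) -> Dict[str, np.ndarray]:
    """Hypothetical helper: one uint8 mask per foreground label, keyed by label name."""
    masks = {}
    for name, value in label_dict.items():
        if value == 0:  # background gets no file
            continue
        masks[name] = (segmentation == value).astype(np.uint8, copy=False)
    return masks


seg = np.array([[0, 1, 1], [2, 0, 1]])
masks = split_into_binary_masks(seg, {'background': 0, 'liver': 1, 'spleen': 2})
for name, mask in masks.items():
    # the JHU exporter would write these as predictions/liver<file_ending>, predictions/spleen<file_ending>
    print(name, int(mask.sum()))
# liver 3
# spleen 1
```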
================================================
FILE: nnunetv2/inference/__init__.py
================================================
================================================
FILE: nnunetv2/inference/data_iterators.py
================================================
import multiprocessing
import queue
from torch.multiprocessing import Event, Queue, Manager
from time import sleep
from typing import Union, List
import numpy as np
import torch
from batchgenerators.dataloading.data_loader import DataLoader
from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor
from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
def preprocess_fromfiles_save_to_queue(list_of_lists: List[List[str]],
list_of_segs_from_prev_stage_files: Union[None, List[str]],
output_filenames_truncated: Union[None, List[str]],
plans_manager: PlansManager,
dataset_json: dict,
configuration_manager: ConfigurationManager,
target_queue: Queue,
done_event: Event,
abort_event: Event,
verbose: bool = False):
try:
label_manager = plans_manager.get_label_manager(dataset_json)
preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
for idx in range(len(list_of_lists)):
data, seg, data_properties = preprocessor.run_case(list_of_lists[idx],
list_of_segs_from_prev_stage_files[
idx] if list_of_segs_from_prev_stage_files is not None else None,
plans_manager,
configuration_manager,
dataset_json)
if list_of_segs_from_prev_stage_files is not None and list_of_segs_from_prev_stage_files[idx] is not None:
seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)
data = np.vstack((data, seg_onehot))
data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)
item = {'data': data, 'data_properties': data_properties,
'ofile': output_filenames_truncated[idx] if output_filenames_truncated is not None else None}
success = False
while not success:
try:
if abort_event.is_set():
return
target_queue.put(item, timeout=0.01)
success = True
except queue.Full:
pass
done_event.set()
except Exception as e:
# print(Exception, e)
abort_event.set()
raise e
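Each preprocessing worker pushes its result into a queue of size 1 with a short timeout and retries on `queue.Full`, checking the shared abort event before every attempt so that it stops as soon as another worker has failed. The sketch below shows the same pattern with a `queue.Queue` and a `threading.Event` substituted for the `torch.multiprocessing` Manager objects used above; `put_until_abort` is a hypothetical helper.

```python
import queue
import threading


def put_until_abort(q: queue.Queue, item, abort_event: threading.Event, timeout: float = 0.01) -> bool:
    """Hypothetical helper: retry a bounded put until it succeeds or abort_event is set."""
    while True:
        if abort_event.is_set():
            return False  # another worker failed, stop producing
        try:
            q.put(item, timeout=timeout)
            return True
        except queue.Full:
            pass  # consumer has not picked up the previous item yet, try again


q = queue.Queue(maxsize=1)
abort = threading.Event()
print(put_until_abort(q, 'case_0', abort))  # True, queue had room
abort.set()
print(put_until_abort(q, 'case_1', abort))  # False, aborted while the queue is full
```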
def preprocessing_iterator_fromfiles(list_of_lists: List[List[str]],
list_of_segs_from_prev_stage_files: Union[None, List[str]],
output_filenames_truncated: Union[None, List[str]],
plans_manager: PlansManager,
dataset_json: dict,
configuration_manager: ConfigurationManager,
num_processes: int,
pin_memory: bool = False,
verbose: bool = False):
context = multiprocessing.get_context('spawn')
manager = Manager()
num_processes = min(len(list_of_lists), num_processes)
assert num_processes >= 1
processes = []
done_events = []
target_queues = []
abort_event = manager.Event()
for i in range(num_processes):
event = manager.Event()
queue = manager.Queue(maxsize=1)
pr = context.Process(target=preprocess_fromfiles_save_to_queue,
args=(
list_of_lists[i::num_processes],
list_of_segs_from_prev_stage_files[
i::num_processes] if list_of_segs_from_prev_stage_files is not None else None,
output_filenames_truncated[
i::num_processes] if output_filenames_truncated is not None else None,
plans_manager,
dataset_json,
configuration_manager,
queue,
event,
abort_event,
verbose
), daemon=True)
pr.start()
target_queues.append(queue)
done_events.append(event)
processes.append(pr)
worker_ctr = 0
while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()):
if not target_queues[worker_ctr].empty():
item = target_queues[worker_ctr].get()
worker_ctr = (worker_ctr + 1) % num_processes
else:
all_ok = all(
[i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set()
if not all_ok:
raise RuntimeError('Background workers died. Look for the error message further up! If there is '
'none then your RAM was full and the worker was killed by the OS. Use fewer '
'workers or get more RAM in that case!')
sleep(0.01)
continue
if pin_memory:
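# Tensor.pin_memory() returns a pinned copy instead of pinning in place, so the result must be kept
item = {k: v.pin_memory() if isinstance(v, torch.Tensor) else v for k, v in item.items()}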
yield item
[p.join() for p in processes]
class PreprocessAdapter(DataLoader):
def __init__(self, list_of_lists: List[List[str]],
list_of_segs_from_prev_stage_files: Union[None, List[str]],
preprocessor: DefaultPreprocessor,
output_filenames_truncated: Union[None, List[str]],
plans_manager: PlansManager,
dataset_json: dict,
configuration_manager: ConfigurationManager,
num_threads_in_multithreaded: int = 1):
self.preprocessor, self.plans_manager, self.configuration_manager, self.dataset_json = \
preprocessor, plans_manager, configuration_manager, dataset_json
self.label_manager = plans_manager.get_label_manager(dataset_json)
if list_of_segs_from_prev_stage_files is None:
list_of_segs_from_prev_stage_files = [None] * len(list_of_lists)
if output_filenames_truncated is None:
output_filenames_truncated = [None] * len(list_of_lists)
super().__init__(list(zip(list_of_lists, list_of_segs_from_prev_stage_files, output_filenames_truncated)),
1, num_threads_in_multithreaded,
seed_for_shuffle=1, return_incomplete=True,
shuffle=False, infinite=False, sampling_probabilities=None)
self.indices = list(range(len(list_of_lists)))
def generate_train_batch(self):
idx = self.get_indices()[0]
files, seg_prev_stage, ofile = self._data[idx]
# if we have a segmentation from the previous stage we have to process it together with the images so that we
# can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after
# preprocessing and then there might be misalignments
data, seg, data_properties = self.preprocessor.run_case(files, seg_prev_stage, self.plans_manager,
self.configuration_manager,
self.dataset_json)
if seg_prev_stage is not None:
seg_onehot = convert_labelmap_to_one_hot(seg[0], self.label_manager.foreground_labels, data.dtype)
data = np.vstack((data, seg_onehot))
data = torch.from_numpy(data)
return {'data': data, 'data_properties': data_properties, 'ofile': ofile}
class PreprocessAdapterFromNpy(DataLoader):
def __init__(self, list_of_images: List[np.ndarray],
list_of_segs_from_prev_stage: Union[List[np.ndarray], None],
list_of_image_properties: List[dict],
truncated_ofnames: Union[List[str], None],
plans_manager: PlansManager, dataset_json: dict, configuration_manager: ConfigurationManager,
num_threads_in_multithreaded: int = 1, verbose: bool = False):
preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
self.preprocessor, self.plans_manager, self.configuration_manager, self.dataset_json, self.truncated_ofnames = \
preprocessor, plans_manager, configuration_manager, dataset_json, truncated_ofnames
self.label_manager = plans_manager.get_label_manager(dataset_json)
if list_of_segs_from_prev_stage is None:
list_of_segs_from_prev_stage = [None] * len(list_of_images)
if truncated_ofnames is None:
truncated_ofnames = [None] * len(list_of_images)
super().__init__(
list(zip(list_of_images, list_of_segs_from_prev_stage, list_of_image_properties, truncated_ofnames)),
1, num_threads_in_multithreaded,
seed_for_shuffle=1, return_incomplete=True,
shuffle=False, infinite=False, sampling_probabilities=None)
self.indices = list(range(len(list_of_images)))
def generate_train_batch(self):
idx = self.get_indices()[0]
image, seg_prev_stage, props, ofname = self._data[idx]
# if we have a segmentation from the previous stage we have to process it together with the images so that we
# can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after
# preprocessing and then there might be misalignments
data, seg, props = self.preprocessor.run_case_npy(image, seg_prev_stage, props,
self.plans_manager,
self.configuration_manager,
self.dataset_json)
if seg_prev_stage is not None:
seg_onehot = convert_labelmap_to_one_hot(seg[0], self.label_manager.foreground_labels, data.dtype)
data = np.vstack((data, seg_onehot))
data = torch.from_numpy(data)
return {'data': data, 'data_properties': props, 'ofile': ofname}
def preprocess_fromnpy_save_to_queue(list_of_images: List[np.ndarray],
list_of_segs_from_prev_stage: Union[List[np.ndarray], None],
list_of_image_properties: List[dict],
truncated_ofnames: Union[List[str], None],
plans_manager: PlansManager,
dataset_json: dict,
configuration_manager: ConfigurationManager,
target_queue: Queue,
done_event: Event,
abort_event: Event,
verbose: bool = False):
try:
label_manager = plans_manager.get_label_manager(dataset_json)
preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
for idx in range(len(list_of_images)):
data, seg, props = preprocessor.run_case_npy(list_of_images[idx],
list_of_segs_from_prev_stage[
idx] if list_of_segs_from_prev_stage is not None else None,
list_of_image_properties[idx],
plans_manager,
configuration_manager,
dataset_json)
list_of_image_properties[idx] = props
if list_of_segs_from_prev_stage is not None and list_of_segs_from_prev_stage[idx] is not None:
seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)
data = np.vstack((data, seg_onehot))
data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)
item = {'data': data, 'data_properties': list_of_image_properties[idx],
'ofile': truncated_ofnames[idx] if truncated_ofnames is not None else None}
success = False
while not success:
try:
if abort_event.is_set():
return
target_queue.put(item, timeout=0.01)
success = True
except queue.Full:
pass
done_event.set()
except Exception as e:
abort_event.set()
raise e
def preprocessing_iterator_fromnpy(list_of_images: List[np.ndarray],
list_of_segs_from_prev_stage: Union[List[np.ndarray], None],
list_of_image_properties: List[dict],
truncated_ofnames: Union[List[str], None],
plans_manager: PlansManager,
dataset_json: dict,
configuration_manager: ConfigurationManager,
num_processes: int,
pin_memory: bool = False,
verbose: bool = False):
context = multiprocessing.get_context('spawn')
manager = Manager()
num_processes = min(len(list_of_images), num_processes)
assert num_processes >= 1
target_queues = []
processes = []
done_events = []
abort_event = manager.Event()
for i in range(num_processes):
event = manager.Event()
queue = manager.Queue(maxsize=1)
pr = context.Process(target=preprocess_fromnpy_save_to_queue,
args=(
list_of_images[i::num_processes],
list_of_segs_from_prev_stage[
i::num_processes] if list_of_segs_from_prev_stage is not None else None,
list_of_image_properties[i::num_processes],
truncated_ofnames[i::num_processes] if truncated_ofnames is not None else None,
plans_manager,
dataset_json,
configuration_manager,
queue,
event,
abort_event,
verbose
), daemon=True)
pr.start()
done_events.append(event)
processes.append(pr)
target_queues.append(queue)
worker_ctr = 0
while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()):
if not target_queues[worker_ctr].empty():
item = target_queues[worker_ctr].get()
worker_ctr = (worker_ctr + 1) % num_processes
else:
all_ok = all(
[i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set()
if not all_ok:
raise RuntimeError('Background workers died. Look for the error message further up! If there is '
'none then your RAM was full and the worker was killed by the OS. Use fewer '
'workers or get more RAM in that case!')
sleep(0.01)
continue
if pin_memory:
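# Tensor.pin_memory() returns a pinned copy instead of pinning in place, so the result must be kept
item = {k: v.pin_memory() if isinstance(v, torch.Tensor) else v for k, v in item.items()}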
yield item
[p.join() for p in processes]
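Both preprocessing iterators split the cases across workers with `list[i::num_processes]` and then read from the worker queues in turn (`worker_ctr`), which yields the cases in their original order. A single-process sketch of that interleaving, with plain lists standing in for the Manager queues; `split_round_robin` and `collect_round_robin` are hypothetical names.

```python
from typing import List, TypeVar

T = TypeVar('T')


def split_round_robin(items: List[T], num_workers: int) -> List[List[T]]:
    # worker i receives items i, i + num_workers, i + 2 * num_workers, ...
    return [items[i::num_workers] for i in range(num_workers)]


def collect_round_robin(per_worker: List[List[T]]) -> List[T]:
    # take one item per worker in turn, like the worker_ctr loop in the iterators above
    iters = [iter(w) for w in per_worker]
    total = sum(len(w) for w in per_worker)
    out: List[T] = []
    worker_ctr = 0
    while len(out) < total:
        out.append(next(iters[worker_ctr]))
        worker_ctr = (worker_ctr + 1) % len(per_worker)
    return out


chunks = split_round_robin(list(range(7)), 3)
print(chunks)                       # [[0, 3, 6], [1, 4], [2, 5]]
print(collect_round_robin(chunks))  # [0, 1, 2, 3, 4, 5, 6]
```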
================================================
FILE: nnunetv2/inference/examples.py
================================================
if __name__ == '__main__':
from nnunetv2.paths import nnUNet_results, nnUNet_raw
import torch
from batchgenerators.utilities.file_and_folder_operations import join
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
# nnUNetv2_predict -d 3 -f 0 -c 3d_lowres -i imagesTs -o imagesTs_predlowres --continue_prediction
# instantiate the nnUNetPredictor
predictor = nnUNetPredictor(
tile_step_size=0.5,
use_gaussian=True,
use_mirroring=True,
perform_everything_on_device=True,
device=torch.device('cuda', 0),
verbose=False,
verbose_preprocessing=False,
allow_tqdm=True
)
# initializes the network architecture, loads the checkpoint
predictor.initialize_from_trained_model_folder(
join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'),
use_folds=(0,),
checkpoint_name='checkpoint_final.pth',
)
# variant 1: give input and output folders
predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),
join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'),
save_probabilities=False, overwrite=False,
num_processes_preprocessing=2, num_processes_segmentation_export=2,
folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
# variant 2, use list of files as inputs. Note how we use nested lists!!!
indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs')
outdir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres')
predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')],
[join(indir, 'liver_142_0000.nii.gz')]],
[join(outdir, 'liver_152.nii.gz'),
join(outdir, 'liver_142.nii.gz')],
save_probabilities=False, overwrite=True,
num_processes_preprocessing=2, num_processes_segmentation_export=2,
folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
# variant 2.5, returns segmentations
indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs')
predicted_segmentations = predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')],
[join(indir, 'liver_142_0000.nii.gz')]],
None,
save_probabilities=True, overwrite=True,
num_processes_preprocessing=2,
num_processes_segmentation_export=2,
folder_with_segs_from_prev_stage=None, num_parts=1,
part_id=0)
# predict several npy images
img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])
img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])
img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])
img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])
# we do not set output files so that the segmentations will be returned. You can of course also specify output
# files instead (no return value in that case)
ret = predictor.predict_from_list_of_npy_arrays([img, img2, img3, img4],
None,
[props, props2, props3, props4],
None, 2, save_probabilities=False,
num_processes_segmentation_export=2)
# predict a single numpy array
img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])
ret = predictor.predict_single_npy_array(img, props, None, None, True)
# custom iterator
img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])
img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])
img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])
img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])
# each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properties' keys!
# If 'ofile' is None, the result will be returned instead of written to a file
# the iterator is responsible for performing the correct preprocessing!
# note how the iterator here does not use multiprocessing -> preprocessing will be done in the main thread!
# take a look at the default iterators for predict_from_files and predict_from_list_of_npy_arrays
# (they both use predictor.predict_from_data_iterator) for inspiration!
def my_iterator(list_of_input_arrs, list_of_input_props):
preprocessor = predictor.configuration_manager.preprocessor_class(verbose=predictor.verbose)
for a, p in zip(list_of_input_arrs, list_of_input_props):
data, seg, p = preprocessor.run_case_npy(a,
None,
p,
predictor.plans_manager,
predictor.configuration_manager,
predictor.dataset_json)
yield {'data': torch.from_numpy(data).contiguous().pin_memory(), 'data_properties': p, 'ofile': None}
ret = predictor.predict_from_data_iterator(my_iterator([img, img2, img3, img4], [props, props2, props3, props4]),
save_probabilities=False, num_processes_segmentation_export=3)
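`predict_from_files` expects a nested list: one inner list per case, holding that case's channel files. A sketch, assuming the `CASE_XXXX<file_ending>` naming used in the examples above (e.g. `liver_152_0000.nii.gz`), of how such nested lists and matching output names could be built from a flat list of filenames. `group_channels_by_case` is a hypothetical helper, not part of nnU-Net.

```python
import re
from collections import defaultdict
from typing import Dict, List


def group_channels_by_case(filenames: List[str], file_ending: str = '.nii.gz') -> Dict[str, List[str]]:
    """Hypothetical helper: map case identifier -> channel files sorted by their _XXXX suffix."""
    pattern = re.compile(r'^(.*)_(\d{4})' + re.escape(file_ending) + r'$')
    cases: Dict[str, List[str]] = defaultdict(list)
    for f in filenames:
        m = pattern.match(f)
        if m is None:
            raise ValueError(f'{f} does not follow the CASE_XXXX{file_ending} naming convention')
        cases[m.group(1)].append(f)
    # within one case only the 4-digit channel id differs, so a plain sort orders the channels
    return {case: sorted(files) for case, files in sorted(cases.items())}


files = ['liver_152_0000.nii.gz', 'brain_7_0001.nii.gz', 'brain_7_0000.nii.gz']
grouped = group_channels_by_case(files)
list_of_lists = list(grouped.values())                 # input for predict_from_files
output_files = [case + '.nii.gz' for case in grouped]  # one output file per case
print(list_of_lists)
# [['brain_7_0000.nii.gz', 'brain_7_0001.nii.gz'], ['liver_152_0000.nii.gz']]
print(output_files)
# ['brain_7.nii.gz', 'liver_152.nii.gz']
```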
================================================
FILE: nnunetv2/inference/export_prediction.py
================================================
from typing import Union, List
import numpy as np
import torch
from acvl_utils.cropping_and_padding.bounding_boxes import insert_crop_into_image
from batchgenerators.utilities.file_and_folder_operations import load_json, save_pickle
from nnunetv2.configuration import default_num_processes
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDatasetBlosc2
from nnunetv2.utilities.label_handling.label_handling import LabelManager
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
def convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits: Union[torch.Tensor, np.ndarray],
plans_manager: PlansManager,
configuration_manager: ConfigurationManager,
label_manager: LabelManager,
properties_dict: dict,
return_probabilities: bool = False,
num_threads_torch: int = default_num_processes):
old_threads = torch.get_num_threads()
torch.set_num_threads(num_threads_torch)
# resample to original shape
spacing_transposed = [properties_dict['spacing'][i] for i in plans_manager.transpose_forward]
current_spacing = configuration_manager.spacing if \
len(configuration_manager.spacing) == \
len(properties_dict['shape_after_cropping_and_before_resampling']) else \
[spacing_transposed[0], *configuration_manager.spacing]
predicted_logits = configuration_manager.resampling_fn_probabilities(predicted_logits,
properties_dict['shape_after_cropping_and_before_resampling'],
current_spacing,
spacing_transposed)
# return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because
# apply_inference_nonlin will convert to torch
if not return_probabilities:
# this has a faster computation path because we can skip the softmax in regular (not region-based) training
segmentation = label_manager.convert_logits_to_segmentation(predicted_logits)
else:
predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits)
segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities)
del predicted_logits
# put segmentation in bbox (revert cropping)
segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'],
dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16)
segmentation_reverted_cropping = insert_crop_into_image(segmentation_reverted_cropping, segmentation, properties_dict['bbox_used_for_cropping'])
del segmentation
# segmentation may be torch.Tensor but we continue with numpy
if isinstance(segmentation_reverted_cropping, torch.Tensor):
segmentation_reverted_cropping = segmentation_reverted_cropping.cpu().numpy()
# revert transpose
segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(plans_manager.transpose_backward)
if return_probabilities:
# revert cropping
predicted_probabilities = label_manager.revert_cropping_on_probabilities(predicted_probabilities,
properties_dict[
'bbox_used_for_cropping'],
properties_dict[
'shape_before_cropping'])
predicted_probabilities = predicted_probabilities.cpu().numpy()
# revert transpose
predicted_probabilities = predicted_probabilities.transpose([0] + [i + 1 for i in
plans_manager.transpose_backward])
torch.set_num_threads(old_threads)
return segmentation_reverted_cropping, predicted_probabilities
else:
torch.set_num_threads(old_threads)
return segmentation_reverted_cropping
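# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the steps of
# convert_predicted_logits_to_segmentation_with_correct_shape that follow resampling, written with
# plain NumPy. Assumptions: standard (not region-based) training, so the channel axis holds one logit
# per class including background; bbox is a list of [start, end) pairs per spatial axis, like
# properties_dict['bbox_used_for_cropping']; transpose_backward undoes transpose_forward.
import numpy as np


def _sketch_logits_to_full_size_segmentation(logits, shape_before_cropping, bbox, transpose_backward):
    # softmax is monotonic, so argmax over the raw logits gives the same labels as argmax over the
    # probabilities; this is why the softmax can be skipped when no probabilities are requested
    seg = logits.argmax(axis=0)
    # same dtype rule as above: uint8 while there are fewer than 255 foreground labels, else uint16
    num_foreground_labels = logits.shape[0] - 1
    dtype = np.uint8 if num_foreground_labels < 255 else np.uint16
    # revert cropping: paste the prediction into an all-background array of the uncropped shape
    full = np.zeros(shape_before_cropping, dtype=dtype)
    full[tuple(slice(start, end) for start, end in bbox)] = seg
    # revert transpose
    return full.transpose(transpose_backward)

# Example: a (3, 2, 3, 4) logit array placed into a (5, 6, 7) image with transpose_backward
# [2, 1, 0] yields a uint8 segmentation of shape (7, 6, 5).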
def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, torch.Tensor], properties_dict: dict,
configuration_manager: ConfigurationManager,
plans_manager: PlansManager,
dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str,
save_probabilities: bool = False,
num_threads_torch: int = default_num_processes):
# if isinstance(predicted_array_or_file, str):
# tmp = deepcopy(predicted_array_or_file)
# if predicted_array_or_file.endswith('.npy'):
# predicted_array_or_file = np.load(predicted_array_or_file)
# elif predicted_array_or_file.endswith('.npz'):
# predicted_array_or_file = np.load(predicted_array_or_file)['softmax']
# os.remove(tmp)
if isinstance(dataset_json_dict_or_file, str):
dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)
label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)
ret = convert_predicted_logits_to_segmentation_with_correct_shape(
predicted_array_or_file, plans_manager, configuration_manager, label_manager, properties_dict,
return_probabilities=save_probabilities, num_threads_torch=num_threads_torch
)
del predicted_array_or_file
# save
if save_probabilities:
segmentation_final, probabilities_final = ret
np.savez_compressed(output_file_truncated + '.npz', probabilities=probabilities_final)
save_pickle(properties_dict, output_file_truncated + '.pkl')
del probabilities_final, ret
else:
segmentation_final = ret
del ret
rw = plans_manager.image_reader_writer_class()
rw.write_seg(segmentation_final, output_file_truncated + dataset_json_dict_or_file['file_ending'],
properties_dict)
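# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): with save_probabilities=True,
# export_prediction_from_logits writes the class probabilities to <output_file_truncated>.npz under
# the key 'probabilities' and the properties dict to <output_file_truncated>.pkl. The .npz part can be
# read back with NumPy alone.
import numpy as np


def _sketch_load_saved_probabilities(output_file_truncated):
    # np.load on an .npz returns a lazily loading NpzFile; indexing it by key reads that array
    with np.load(output_file_truncated + '.npz') as f:
        return f['probabilities']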
def resample_and_save(predicted: Union[torch.Tensor, np.ndarray], target_shape: List[int], output_file: str,
plans_manager: PlansManager, configuration_manager: ConfigurationManager, properties_dict: dict,
dataset_json_dict_or_file: Union[dict, str], num_threads_torch: int = default_num_processes,
dataset_class=None) \
-> None:
old_threads = torch.get_num_threads()
torch.set_num_threads(num_threads_torch)
if isinstance(dataset_json_dict_or_file, str):
dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)
spacing_transposed = [properties_dict['spacing'][i] for i in plans_manager.transpose_forward]
# resample to original shape
current_spacing = configuration_manager.spacing if \
len(configuration_manager.spacing) == len(properties_dict['shape_after_cropping_and_before_resampling']) else \
[spacing_transposed[0], *configuration_manager.spacing]
target_spacing = configuration_manager.spacing if len(configuration_manager.spacing) == \
len(properties_dict['shape_after_cropping_and_before_resampling']) else \
[spacing_transposed[0], *configuration_manager.spacing]
predicted_array_or_file = configuration_manager.resampling_fn_probabilities(predicted,
target_shape,
current_spacing,
target_spacing)
# create segmentation (argmax, regions, etc)
label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)
segmentation = label_manager.convert_logits_to_segmentation(predicted_array_or_file)
# segmentation may be torch.Tensor but we continue with numpy
if isinstance(segmentation, torch.Tensor):
segmentation = segmentation.cpu().numpy()
if dataset_class is None:
nnUNetDatasetBlosc2.save_seg(segmentation.astype(dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16), output_file)
else:
dataset_class.save_seg(segmentation.astype(dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16), output_file)
torch.set_num_threads(old_threads)
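# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): how the spacing used for
# resampling is assembled above. For a 2D configuration applied to a 3D image, the configuration
# spacing has one entry fewer than the image has spatial axes, so the out-of-plane spacing is taken
# from the (transposed) image spacing. n_spatial_dims stands in for
# len(properties_dict['shape_after_cropping_and_before_resampling']).
def _sketch_effective_spacing(image_spacing, transpose_forward, config_spacing, n_spatial_dims):
    spacing_transposed = [image_spacing[i] for i in transpose_forward]
    if len(config_spacing) == n_spatial_dims:
        # configuration and image have the same number of spatial axes: use the configuration spacing
        return list(config_spacing)
    # 2D configuration on 3D data: keep the image's own spacing along the first (slice) axis
    return [spacing_transposed[0], *config_spacing]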
================================================
FILE: nnunetv2/inference/predict_from_raw_data.py
================================================
import inspect
import itertools
import multiprocessing
import os
from copy import deepcopy
from queue import Queue
from threading import Thread
from time import sleep
from typing import Tuple, Union, List, Optional
import numpy as np
import torch
from acvl_utils.cropping_and_padding.padding import pad_nd_image
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir, subdirs, \
save_json
from torch import nn
from torch._dynamo import OptimizedModule
from torch.nn.parallel import DistributedDataParallel
from tqdm import tqdm
import nnunetv2
from nnunetv2.configuration import default_num_processes
from nnunetv2.inference.data_iterators import PreprocessAdapterFromNpy, preprocessing_iterator_fromfiles, \
preprocessing_iterator_fromnpy
from nnunetv2.inference.export_prediction import export_prediction_from_logits, \
convert_predicted_logits_to_segmentation_with_correct_shape
from nnunetv2.inference.sliding_window_prediction import compute_gaussian, \
compute_steps_for_sliding_window
from nnunetv2.utilities.file_path_utilities import get_output_folder, check_workers_alive_and_busy
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.helpers import empty_cache, dummy_context
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.utilities.utils import create_lists_from_splitted_dataset_folder
class nnUNetPredictor(object):
def __init__(self,
tile_step_size: float = 0.5,
use_gaussian: bool = True,
use_mirroring: bool = True,
perform_everything_on_device: bool = True,
device: torch.device = torch.device('cuda'),
verbose: bool = False,
verbose_preprocessing: bool = False,
allow_tqdm: bool = True):
self.verbose = verbose
self.verbose_preprocessing = verbose_preprocessing
self.allow_tqdm = allow_tqdm
self.plans_manager, self.configuration_manager, self.list_of_parameters, self.network, self.dataset_json, \
self.trainer_name, self.allowed_mirroring_axes, self.label_manager = None, None, None, None, None, None, None, None
self.tile_step_size = tile_step_size
self.use_gaussian = use_gaussian
self.use_mirroring = use_mirroring
if device.type == 'cuda':
torch.backends.cudnn.benchmark = True
elif perform_everything_on_device:
print('perform_everything_on_device=True is only supported for cuda devices! Setting this to False')
perform_everything_on_device = False
self.device = device
self.perform_everything_on_device = perform_everything_on_device
def initialize_from_trained_model_folder(self, model_training_output_dir: str,
use_folds: Union[Tuple[Union[int, str]], None],
checkpoint_name: str = 'checkpoint_final.pth'):
"""
This is used when making predictions with a trained model
"""
if use_folds is None:
use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name)
dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
plans = load_json(join(model_training_output_dir, 'plans.json'))
plans_manager = PlansManager(plans)
if isinstance(use_folds, str):
use_folds = [use_folds]
parameters = []
for i, f in enumerate(use_folds):
f = int(f) if f != 'all' else f
checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name),
map_location=torch.device('cpu'), weights_only=False)
if i == 0:
trainer_name = checkpoint['trainer_name']
configuration_name = checkpoint['init_args']['configuration']
inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \
'inference_allowed_mirroring_axes' in checkpoint.keys() else None
parameters.append(checkpoint['network_weights'])
configuration_manager = plans_manager.get_configuration(configuration_name)
# restore network
num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
trainer_name, 'nnunetv2.training.nnUNetTrainer')
if trainer_class is None:
raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
f'Please place it there (in any .py file)!')
network = trainer_class.build_network_architecture(
configuration_manager.network_arch_class_name,
configuration_manager.network_arch_init_kwargs,
configuration_manager.network_arch_init_kwargs_req_import,
num_input_channels,
plans_manager.get_label_manager(dataset_json).num_segmentation_heads,
enable_deep_supervision=False
)
self.plans_manager = plans_manager
self.configuration_manager = configuration_manager
self.list_of_parameters = parameters
# initialize network with first set of parameters, also see https://github.com/MIC-DKFZ/nnUNet/issues/2520
network.load_state_dict(parameters[0])
self.network = network
self.dataset_json = dataset_json
self.trainer_name = trainer_name
self.allowed_mirroring_axes = inference_allowed_mirroring_axes
self.label_manager = plans_manager.get_label_manager(dataset_json)
if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \
and not isinstance(self.network, OptimizedModule):
print('Using torch.compile')
self.network = torch.compile(self.network)
def manual_initialization(self, network: nn.Module, plans_manager: PlansManager,
configuration_manager: ConfigurationManager, parameters: Optional[List[dict]],
dataset_json: dict, trainer_name: str,
inference_allowed_mirroring_axes: Optional[Tuple[int, ...]]):
"""
This is used by the nnUNetTrainer to initialize nnUNetPredictor for the final validation
"""
self.plans_manager = plans_manager
self.configuration_manager = configuration_manager
self.list_of_parameters = parameters
self.network = network
self.dataset_json = dataset_json
self.trainer_name = trainer_name
self.allowed_mirroring_axes = inference_allowed_mirroring_axes
self.label_manager = plans_manager.get_label_manager(dataset_json)
allow_compile = True
allow_compile = allow_compile and ('nnUNet_compile' in os.environ.keys()) and (
os.environ['nnUNet_compile'].lower() in ('true', '1', 't'))
allow_compile = allow_compile and not isinstance(self.network, OptimizedModule)
if isinstance(self.network, DistributedDataParallel):
allow_compile = allow_compile and isinstance(self.network.module, OptimizedModule)
if allow_compile:
print('Using torch.compile')
self.network = torch.compile(self.network)
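# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): both initializers above enable
# torch.compile only if the nnUNet_compile environment variable is set to 'true', '1' or 't'
# (case-insensitive). The same check, isolated so that it can be tested against any mapping.
import os


def _sketch_compile_requested(environ=None):
    environ = os.environ if environ is None else environ
    # a missing variable counts as "not requested"
    return environ.get('nnUNet_compile', '').lower() in ('true', '1', 't')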
@staticmethod
def auto_detect_available_folds(model_training_output_dir, checkpoint_name):
print('use_folds is None, attempting to auto detect available folds')
fold_folders = subdirs(model_training_output_dir, prefix='fold_', join=False)
fold_folders = [i for i in fold_folders if i != 'fold_all']
fold_folders = [i for i in fold_folders if isfile(join(model_training_output_dir, i, checkpoint_name))]
use_folds = [int(i.split('_')[-1]) for i in fold_folders]
print(f'found the following folds: {use_folds}')
return use_folds
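# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the fold auto-detection above
# using only the standard library instead of batchgenerators' subdirs. A folder counts as a fold if it
# is named fold_<int> (fold_all is skipped) and contains the requested checkpoint. Folder names are
# sorted here for a deterministic result.
import os


def _sketch_detect_folds(model_dir, checkpoint_name='checkpoint_final.pth'):
    folds = []
    for name in sorted(os.listdir(model_dir)):
        if not name.startswith('fold_') or name == 'fold_all':
            continue
        if not os.path.isdir(os.path.join(model_dir, name)):
            continue
        # only folds whose checkpoint exists can be used for prediction
        if os.path.isfile(os.path.join(model_dir, name, checkpoint_name)):
            folds.append(int(name.split('_')[-1]))
    return folds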
def _manage_input_and_output_lists(self, list_of_lists_or_source_folder: Union[str, List[List[str]]],
output_folder_or_list_of_truncated_output_files: Union[None, str, List[str]],
folder_with_segs_from_prev_stage: str = None,
overwrite: bool = True,
part_id: int = 0,
num_parts: int = 1,
save_probabilities: bool = False):
if isinstance(list_of_lists_or_source_folder, str):
list_of_lists_or_source_folder = create_lists_from_splitted_dataset_folder(list_of_lists_or_source_folder,
self.dataset_json['file_ending'])
print(f'There are {len(list_of_lists_or_source_folder)} cases in the source folder')
list_of_lists_or_source_folder = list_of_lists_or_source_folder[part_id::num_parts]
caseids = [os.path.basename(i[0])[:-(len(self.dataset_json['file_ending']) + 5)] for i in
list_of_lists_or_source_folder]
print(
f'I am processing {part_id} out of {num_parts} (max process ID is {num_parts - 1}, we start counting with 0!)')
print(f'There are {len(caseids)} cases that I would like to predict')
if isinstance(output_folder_or_list_of_truncated_output_files, str):
output_filename_truncated = [join(output_folder_or_list_of_truncated_output_files, i) for i in caseids]
elif isinstance(output_folder_or_list_of_truncated_output_files, list):
output_filename_truncated = output_folder_or_list_of_truncated_output_files[part_id::num_parts]
else:
output_filename_truncated = None
seg_from_prev_stage_files = [join(folder_with_segs_from_prev_stage, i + self.dataset_json['file_ending']) if
folder_with_segs_from_prev_stage is not None else None for i in caseids]
# remove already predicted files from the lists
if not overwrite and output_filename_truncated is not None:
tmp = [isfile(i + self.dataset_json['file_ending']) for i in output_filename_truncated]
if save_probabilities:
tmp2 = [isfile(i + '.npz') for i in output_filename_truncated]
tmp = [i and j for i, j in zip(tmp, tmp2)]
not_existing_indices = [i for i, j in enumerate(tmp) if not j]
output_filename_truncated = [output_filename_truncated[i] for i in not_existing_indices]
list_of_lists_or_source_folder = [list_of_lists_or_source_folder[i] for i in not_existing_indices]
seg_from_prev_stage_files = [seg_from_prev_stage_files[i] for i in not_existing_indices]
print(f'overwrite was set to {overwrite}, so I am only working on cases that haven\'t been predicted yet. '
f'That\'s {len(not_existing_indices)} cases.')
return list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files
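# Illustrative sketch (hypothetical helpers, not part of nnU-Net's API): the list bookkeeping in
# _manage_input_and_output_lists. Assumes nnU-Net's input naming <case>_<XXXX><file_ending>, where XXXX
# is a 4-digit channel index, so '_XXXX' accounts for the extra 5 characters that are stripped.
import os


def _sketch_case_id(first_channel_file, file_ending):
    return os.path.basename(first_channel_file)[:-(len(file_ending) + 5)]


def _sketch_cases_for_part(case_ids, part_id, num_parts):
    # every num_parts-th case starting at part_id; valid part IDs are 0 .. num_parts - 1
    assert 0 <= part_id < num_parts
    return case_ids[part_id::num_parts]


def _sketch_drop_existing(case_ids, existing_outputs):
    # overwrite=False: only keep cases whose output does not exist yet
    return [c for c in case_ids if c not in existing_outputs]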
def predict_from_files(self,
list_of_lists_or_source_folder: Union[str, List[List[str]]],
output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]],
save_probabilities: bool = False,
overwrite: bool = True,
num_processes_preprocessing: int = default_num_processes,
num_processes_segmentation_export: int = default_num_processes,
folder_with_segs_from_prev_stage: str = None,
num_parts: int = 1,
part_id: int = 0):
"""
This is nnU-Net's default function for making predictions. It works best for batch predictions
(predicting many images at once).
"""
assert part_id < num_parts, ("Part ID must be smaller than num_parts. Remember that we start counting with 0. "
"So if there are 3 parts then valid part IDs are 0, 1, 2")
if isinstance(output_folder_or_list_of_truncated_output_files, str):
output_folder = output_folder_or_list_of_truncated_output_files
elif isinstance(output_folder_or_list_of_truncated_output_files, list):
output_folder = os.path.dirname(output_folder_or_list_of_truncated_output_files[0])
else:
output_folder = None
########################
# let's store the input arguments so that it's clear what was used to generate the prediction
if output_folder is not None:
my_init_kwargs = {}
for k in inspect.signature(self.predict_from_files).parameters.keys():
my_init_kwargs[k] = locals()[k]
# deepcopy so that we don't unintentionally change anything in-place
my_init_kwargs = deepcopy(my_init_kwargs)
recursive_fix_for_json_export(my_init_kwargs)
maybe_mkdir_p(output_folder)
save_json(my_init_kwargs, join(output_folder, 'predict_from_raw_data_args.json'))
# we need these two if we want to do things with the predictions like for example apply postprocessing
save_json(self.dataset_json, join(output_folder, 'dataset.json'), sort_keys=False)
save_json(self.plans_manager.plans, join(output_folder, 'plans.json'), sort_keys=False)
#######################
# check if we need a prediction from the previous stage
if self.configuration_manager.previous_stage_name is not None:
assert folder_with_segs_from_prev_stage is not None, \
f'The requested configuration is a cascaded network. It requires the segmentations of the previous ' \
f'stage ({self.configuration_manager.previous_stage_name}) as input. Please provide the folder where' \
f' they are located via folder_with_segs_from_prev_stage'
# sort out input and output filenames
list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \
self._manage_input_and_output_lists(list_of_lists_or_source_folder,
output_folder_or_list_of_truncated_output_files,
folder_with_segs_from_prev_stage, overwrite, part_id, num_parts,
save_probabilities)
if len(list_of_lists_or_source_folder) == 0:
return
data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder,
seg_from_prev_stage_files,
output_filename_truncated,
num_processes_preprocessing)
return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export)
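# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): predict_from_files records its
# own arguments by looking up each parameter name from inspect.signature in locals(). A standalone
# version; locals() is snapshotted first because, inside a comprehension, it would refer to the
# comprehension's own scope on Python versions before 3.12.
import inspect
import json


def _sketch_record_own_args(image_folder, output_folder, save_probabilities=False, num_parts=1):
    call_locals = dict(locals())
    args = {k: call_locals[k] for k in inspect.signature(_sketch_record_own_args).parameters}
    # nnU-Net additionally runs recursive_fix_for_json_export and writes the result with save_json
    return json.dumps(args, sort_keys=True)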
def _internal_get_data_iterator_from_lists_of_filenames(self,
input_list_of_lists: List[List[str]],
seg_from_prev_stage_files: Union[List[str], None],
output_filenames_truncated: Union[List[str], None],
num_processes: int):
return preprocessing_iterator_fromfiles(input_list_of_lists, seg_from_prev_stage_files,
output_filenames_truncated, self.plans_manager, self.dataset_json,
self.configuration_manager, num_processes, self.device.type == 'cuda',
self.verbose_preprocessing)
# preprocessor = self.configuration_manager.preprocessor_class(verbose=self.verbose_preprocessing)
# # hijack batchgenerators, yo
# # we use the multiprocessing of the batchgenerators dataloader to handle all the background worker stuff. This
# # way we don't have to reinvent the wheel here.
# num_processes = max(1, min(num_processes, len(input_list_of_lists)))
# ppa = PreprocessAdapter(input_list_of_lists, seg_from_prev_stage_files, preprocessor,
# output_filenames_truncated, self.plans_manager, self.dataset_json,
# self.configuration_manager, num_processes)
# if num_processes == 0:
# mta = SingleThreadedAugmenter(ppa, None)
# else:
# mta = MultiThreadedAugmenter(ppa, None, num_processes, 1, None, pin_memory=pin_memory)
# return mta
def get_data_iterator_from_raw_npy_data(self,
image_or_list_of_images: Union[np.ndarray, List[np.ndarray]],
segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None,
np.ndarray,
List[
np.ndarray]],
properties_or_list_of_properties: Union[dict, List[dict]],
truncated_ofname: Union[str, List[str], None],
num_processes: int = 3):
list_of_images = [image_or_list_of_images] if not isinstance(image_or_list_of_images, list) else \
image_or_list_of_images
if isinstance(segs_from_prev_stage_or_list_of_segs_from_prev_stage, np.ndarray):
segs_from_prev_stage_or_list_of_segs_from_prev_stage = [
segs_from_prev_stage_or_list_of_segs_from_prev_stage]
if isinstance(truncated_ofname, str):
truncated_ofname = [truncated_ofname]
if isinstance(properties_or_list_of_properties, dict):
properties_or_list_of_properties = [properties_or_list_of_properties]
num_processes = min(num_processes, len(list_of_images))
pp = preprocessing_iterator_fromnpy(
list_of_images,
segs_from_prev_stage_or_list_of_segs_from_prev_stage,
properties_or_list_of_properties,
truncated_ofname,
self.plans_manager,
self.dataset_json,
self.configuration_manager,
num_processes,
self.device.type == 'cuda',
self.verbose_preprocessing
)
return pp
def predict_from_list_of_npy_arrays(self,
image_or_list_of_images: Union[np.ndarray, List[np.ndarray]],
segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None,
np.ndarray,
List[
np.ndarray]],
properties_or_list_of_properties: Union[dict, List[dict]],
truncated_ofname: Union[str, List[str], None],
num_processes: int = 3,
save_probabilities: bool = False,
num_processes_segmentation_export: int = default_num_processes):
iterator = self.get_data_iterator_from_raw_npy_data(image_or_list_of_images,
segs_from_prev_stage_or_list_of_segs_from_prev_stage,
properties_or_list_of_properties,
truncated_ofname,
num_processes)
return self.predict_from_data_iterator(iterator, save_probabilities, num_processes_segmentation_export)
def predict_from_data_iterator(self,
data_iterator,
save_probabilities: bool = False,
num_processes_segmentation_export: int = default_num_processes):
"""
each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properties' keys!
If 'ofile' is None, the result will be returned instead of written to a file
"""
with multiprocessing.get_context("spawn").Pool(num_processes_segmentation_export) as export_pool:
worker_list = [i for i in export_pool._pool]
r = []
for preprocessed in data_iterator:
data = preprocessed['data']
if isinstance(data, str):
delfile = data
data = torch.from_numpy(np.load(data))
os.remove(delfile)
ofile = preprocessed['ofile']
if ofile is not None:
print(f'\nPredicting {os.path.basename(ofile)}:')
else:
print(f'\nPredicting image of shape {data.shape}:')
print(f'perform_everything_on_device: {self.perform_everything_on_device}')
properties = preprocessed['data_properties']
# let's not get into a runaway situation where the GPU predicts so fast that the disk has to be swamped with
# npy files
proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
while not proceed:
sleep(0.1)
proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
# convert to numpy to prevent uncatchable memory alignment errors from multiprocessing serialization of torch tensors
prediction = self.predict_logits_from_preprocessed_data(data).cpu().detach().numpy()
if ofile is not None:
print('sending off prediction to background worker for resampling and export')
r.append(
export_pool.starmap_async(
export_prediction_from_logits,
((prediction, properties, self.configuration_manager, self.plans_manager,
self.dataset_json, ofile, save_probabilities),)
)
)
else:
print('sending off prediction to background worker for resampling')
r.append(
export_pool.starmap_async(
convert_predicted_logits_to_segmentation_with_correct_shape, (
(prediction, self.plans_manager,
self.configuration_manager, self.label_manager,
properties,
save_probabilities),)
)
)
if ofile is not None:
print(f'done with {os.path.basename(ofile)}')
else:
print(f'\nDone with image of shape {data.shape}')
ret = [i.get()[0] for i in r]
if isinstance(data_iterator, MultiThreadedAugmenter):
data_iterator._finish()
# clear lru cache
compute_gaussian.cache_clear()
# clear device cache
empty_cache(self.device)
return ret
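# Illustrative sketch (hypothetical helper, not nnU-Net's implementation): the back-pressure idea in
# predict_from_data_iterator. Before handing another prediction to the export workers, wait while too
# many export jobs are still unfinished, so that fast GPU inference cannot pile up results. For brevity
# this swaps in a concurrent.futures thread pool for nnU-Net's spawn-based multiprocessing pool and
# check_workers_alive_and_busy; the "busy" threshold below is an assumption.
import time
from concurrent.futures import ThreadPoolExecutor


def _sketch_export_with_backpressure(items, export_fn, num_workers=2, allowed_num_queued=2):
    futures = []
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        for item in items:
            # wait while all workers are busy and allowed_num_queued further jobs are already waiting
            while sum(not f.done() for f in futures) >= num_workers + allowed_num_queued:
                time.sleep(0.01)
            futures.append(pool.submit(export_fn, item))
        # collect results in submission order, like ret = [i.get()[0] for i in r]
        return [f.result() for f in futures]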
def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,
segmentation_previous_stage: np.ndarray = None,
output_file_truncated: str = None,
save_or_return_probabilities: bool = False):
"""
WARNING: SLOW. ONLY USE THIS IF YOU CANNOT GIVE NNUNET MULTIPLE IMAGES AT ONCE FOR SOME REASON.
input_image: Make sure to load the image in the way nnU-Net expects! nnU-Net is trained on a certain axis
ordering which cannot be disturbed in inference,
otherwise you will get bad results. The easiest way to achieve that is to use the same I/O class
for loading images as was used during nnU-Net preprocessing! You can find that class in your
plans.json file under the key "image_reader_writer". If you decide to freestyle, know that the
default axis ordering for medical images is the one from SimpleITK. If you load with nibabel,
you need to transpose your axes AND your spacing from [x,y,z] to [z,y,x]!
image_properties only needs to contain a 'spacing' key!
"""
ppa = PreprocessAdapterFromNpy([input_image], [segmentation_previous_stage], [image_properties],
[output_file_truncated],
self.plans_manager, self.dataset_json, self.configuration_manager,
num_threads_in_multithreaded=1, verbose=self.verbose)
if self.verbose:
print('preprocessing')
dct = next(ppa)
if self.verbose:
print('predicting')
predicted_logits = self.predict_logits_from_preprocessed_data(dct['data']).cpu()
if self.verbose:
print('resampling to original shape')
if output_file_truncated is not None:
export_prediction_from_logits(predicted_logits, dct['data_properties'], self.configuration_manager,
self.plans_manager, self.dataset_json, output_file_truncated,
save_or_return_probabilities)
else:
ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
self.configuration_manager,
self.label_manager,
dct['data_properties'],
return_probabilities=
save_or_return_probabilities)
if save_or_return_probabilities:
return ret[0], ret[1]
else:
return ret
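# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the docstring of
# predict_single_npy_array says that images loaded with nibabel need both their axes and their spacing
# flipped from [x, y, z] to [z, y, x]. Assuming a 4D array with the channel axis first, (c, x, y, z):
import numpy as np


def _sketch_xyz_to_zyx(image_cxyz, spacing_xyz):
    # reverse the three spatial axes, keep the channel axis in front
    image_czyx = np.transpose(image_cxyz, (0, 3, 2, 1))
    # the spacing must be reversed the same way, otherwise resampling uses the wrong voxel sizes
    spacing_zyx = list(spacing_xyz)[::-1]
    return image_czyx, spacing_zyx

# The returned spacing would be passed as image_properties={'spacing': spacing_zyx}.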
@torch.inference_mode()
def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
"""
IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE!
RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE.
SEE convert_predicted_logits_to_segmentation_with_correct_shape
"""
n_threads = torch.get_num_threads()
torch.set_num_threads(default_num_processes if default_num_processes < n_threads else n_threads)
prediction = None
for params in self.list_of_parameters:
# messing with state dict names...
if not isinstance(self.network, OptimizedModule):
self.network.load_state_dict(params)
else:
self.network._orig_mod.load_state_dict(params)
# why not leave prediction on device if perform_everything_on_device? Because this may cause the
# second iteration to crash due to OOM. Catching that with try/except would bloat the code far more
# than it would save in computation time
if prediction is None:
prediction = self.predict_sliding_window_return_logits(data).to('cpu')
else:
prediction += self.predict_sliding_window_return_logits(data).to('cpu')
if len(self.list_of_parameters) > 1:
prediction /= len(self.list_of_parameters)
if self.verbose: print('Prediction done')
torch.set_num_threads(n_threads)
return prediction
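# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the fold ensembling in
# predict_logits_from_preprocessed_data, written with NumPy. Each entry of predict_fns stands in for
# "load one fold's weights, then run the sliding-window prediction"; the logits are summed and divided
# by the number of folds (a single fold is returned unchanged).
import numpy as np


def _sketch_ensemble_logits(predict_fns, data):
    prediction = None
    for predict in predict_fns:
        logits = predict(data)
        # copy the first result so the running sum never aliases a model's output
        prediction = np.array(logits, copy=True) if prediction is None else prediction + logits
    return prediction / len(predict_fns)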
def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]):
slicers = []
if len(self.configuration_manager.patch_size) < len(image_size):
assert len(self.configuration_manager.patch_size) == len(
image_size) - 1, 'if tile_size has fewer entries than image_size, ' \
'len(tile_size) ' \
'must be one shorter than len(image_size) ' \
'(only dimension ' \
'discrepancy of 1 allowed).'
steps = compute_steps_for_sliding_window(image_size[1:], self.configuration_manager.patch_size,
self.tile_step_size)
if self.verbose: print(f'n_steps {image_size[0] * len(steps[0]) * len(steps[1])}, image size is'
f' {image_size}, tile_size {self.configuration_manager.patch_size}, '
f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}')
for d in range(image_size[0]):
for sx in steps[0]:
for sy in steps[1]:
slicers.append(
tuple([slice(None), d, *[slice(si, si + ti) for si, ti in
zip((sx, sy), self.configuration_manager.patch_size)]]))
else:
steps = compute_steps_for_sliding_window(image_size, self.configuration_manager.patch_size,
self.tile_step_size)
if self.verbose: print(
f'n_steps {np.prod([len(i) for i in steps])}, image size is {image_size}, tile_size {self.configuration_manager.patch_size}, '
f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}')
for sx in steps[0]:
for sy in steps[1]:
for sz in steps[2]:
slicers.append(
tuple([slice(None), *[slice(si, si + ti) for si, ti in
zip((sx, sy, sz), self.configuration_manager.patch_size)]]))
return slicers
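# Illustrative sketch (hypothetical helpers, not nnU-Net's compute_steps_for_sliding_window): one way to
# place sliding-window tiles. Per axis, consecutive tiles are at most tile_size * tile_step_size voxels
# apart, the first tile starts at 0, the last one ends exactly at the image border, and the tiles in
# between are spread evenly. The 3D slicers are built like in the else-branch above.
import itertools
import math


def _sketch_steps(image_size, tile_size, tile_step_size):
    steps = []
    for img, tile in zip(image_size, tile_size):
        assert img >= tile, 'pad the image to at least the tile size first'
        num_steps = math.ceil((img - tile) / (tile * tile_step_size)) + 1
        max_start = img - tile
        stride = max_start / (num_steps - 1) if num_steps > 1 else 0
        steps.append([int(round(stride * i)) for i in range(num_steps)])
    return steps


def _sketch_slicers(image_size, tile_size, tile_step_size):
    steps = _sketch_steps(image_size, tile_size, tile_step_size)
    # slice(None) keeps all channels; the remaining slices select one tile per combination of starts
    return [tuple([slice(None), *[slice(s, s + t) for s, t in zip(starts, tile_size)]])
            for starts in itertools.product(*steps)]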
@torch.inference_mode()
def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:
mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None
prediction = self.network(x)
if mirror_axes is not None:
# check for invalid numbers in mirror_axes
# x should be 5D for 3D images and 4D for 2D images, so the max value of mirror_axes cannot exceed x.ndim - 3
assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'
mirror_axes = [m + 2 for m in mirror_axes]
axes_combinations = [
c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
]
for axes in axes_combinations:
prediction += torch.flip(self.network(torch.flip(x, axes)), axes)
prediction /= (len(axes_combinations) + 1)
return prediction
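# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): mirroring test-time augmentation
# as in _internal_maybe_mirror_and_predict, written with NumPy. For every non-empty combination of the
# allowed spatial axes, the input is flipped, predicted, and flipped back; all predictions (including
# the unflipped one) are averaged. mirror_axes index spatial axes, hence the +2 offset for (b, c).
import itertools
import numpy as np


def _sketch_mirror_tta(predict, x, mirror_axes):
    axes = [m + 2 for m in mirror_axes]
    prediction = predict(x)
    combos = [c for i in range(len(axes)) for c in itertools.combinations(axes, i + 1)]
    for c in combos:
        # flip the input, predict, then undo the flip on the output
        prediction = prediction + np.flip(predict(np.flip(x, c)), c)
    return prediction / (len(combos) + 1)

# With three mirror axes this runs the network 1 + 7 = 8 times per tile.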
@torch.inference_mode()
def _internal_predict_sliding_window_return_logits(self,
data: torch.Tensor,
slicers,
do_on_device: bool = True,
):
predicted_logits = n_predictions = prediction = gaussian = workon = None
results_device = self.device if do_on_device else torch.device('cpu')
def producer(d, slh, q):
for s in slh:
q.put((torch.clone(d[s][None], memory_format=torch.contiguous_format).to(self.device), s))
q.put('end')
try:
empty_cache(self.device)
# move data to device
if self.verbose:
print(f'move image to device {results_device}')
data = data.to(results_device)
queue = Queue(maxsize=2)
t = Thread(target=producer, args=(data, slicers, queue))
t.start()
# preallocate arrays
if self.verbose:
print(f'preallocating results arrays on device {results_device}')
predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
dtype=torch.half,
device=results_device)
n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
if self.use_gaussian:
gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
value_scaling_factor=10,
device=results_device)
else:
gaussian = 1
if not self.allow_tqdm and self.verbose:
print(f'running prediction: {len(slicers)} steps')
with tqdm(desc=None, total=len(slicers), disable=not self.allow_tqdm) as pbar:
while True:
item = queue.get()
if item == 'end':
queue.task_done()
break
workon, sl = item
prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)
if self.use_gaussian:
prediction *= gaussian
predicted_logits[sl] += prediction
n_predictions[sl[1:]] += gaussian
queue.task_done()
pbar.update()
queue.join()
# predicted_logits /= n_predictions
torch.div(predicted_logits, n_predictions, out=predicted_logits)
# check for infs
if torch.any(torch.isinf(predicted_logits)):
raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
'reduce value_scaling_factor in compute_gaussian or increase the dtype of '
'predicted_logits to fp32')
except Exception as e:
del predicted_logits, n_predictions, prediction, gaussian, workon
empty_cache(self.device)
empty_cache(results_device)
raise e
return predicted_logits
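# Illustrative sketch (hypothetical helpers, not nnU-Net's compute_gaussian): Gaussian-weighted
# stitching as in _internal_predict_sliding_window_return_logits. Each tile's logits are multiplied by
# an importance map that peaks at the tile centre and summed into the output, and the summed weights
# (n_predictions) are divided out at the end. The importance map here is an analytic Gaussian with
# sigma = tile size * sigma_scale, normalised to a maximum of 1, rather than nnU-Net's own construction.
# Unlike the method above, the slicers here index spatial axes only.
import numpy as np


def _sketch_importance_map(patch_size, sigma_scale=1. / 8):
    grids = np.meshgrid(*[np.arange(p, dtype=float) for p in patch_size], indexing='ij')
    weight = np.ones(patch_size, dtype=float)
    for grid, p in zip(grids, patch_size):
        sigma = p * sigma_scale
        weight *= np.exp(-((grid - (p - 1) / 2) ** 2) / (2 * sigma ** 2))
    weight /= weight.max()
    # tile corners must keep a nonzero weight, otherwise the final division could hit 0
    weight[weight == 0] = weight[weight > 0].min()
    return weight


def _sketch_stitch(image_shape, tile_logits, slicers, weight):
    num_classes = tile_logits[0].shape[0]
    logits = np.zeros((num_classes, *image_shape))
    n_predictions = np.zeros(image_shape)
    for tile, sl in zip(tile_logits, slicers):
        logits[(slice(None), *sl)] += tile * weight
        n_predictions[sl] += weight
    return logits / n_predictions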
@torch.inference_mode()
def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
-> Union[np.ndarray, torch.Tensor]:
assert isinstance(input_image, torch.Tensor)
self.network = self.network.to(self.device)
self.network.eval()
empty_cache(self.device)
# Autocast can be annoying
# If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection)
# and needs to be disabled.
# If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False
# is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
# So autocast will only be active if we have a cuda device.
with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'
if self.verbose:
print(f'Input shape: {input_image.shape}')
print("step_size:", self.tile_step_size)
print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)
# if input_image is smaller than tile_size we need to pad it to tile_size.
data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
'constant', {'value': 0}, True,
None)
slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
if self.perform_everything_on_device and self.device != 'cpu':
# we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device
try:
predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
self.perform_everything_on_device)
except RuntimeError:
print(
'Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU')
empty_cache(self.device)
predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False)
else:
predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers,
self.perform_everything_on_device)
empty_cache(self.device)
# revert padding
predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
return predicted_logits
def predict_from_files_sequential(self,
list_of_lists_or_source_folder: Union[str, List[List[str]]],
output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]],
save_probabilities: bool = False,
overwrite: bool = True,
folder_with_segs_from_prev_stage: str = None):
"""
Just like predict_from_files but doesn't use any multiprocessing. Slow, but sometimes necessary
"""
if isinstance(output_folder_or_list_of_truncated_output_files, str):
output_folder = output_folder_or_list_of_truncated_output_files
elif isinstance(output_folder_or_list_of_truncated_output_files, list):
output_folder = os.path.dirname(output_folder_or_list_of_truncated_output_files[0])
if len(output_folder) == 0: # just a file was given without a folder
output_folder = os.path.curdir
else:
output_folder = None
########################
# let's store the input arguments so that it's clear what was used to generate the prediction
if output_folder is not None:
my_init_kwargs = {}
for k in inspect.signature(self.predict_from_files_sequential).parameters.keys():
my_init_kwargs[k] = locals()[k]
my_init_kwargs = deepcopy(
my_init_kwargs) # let's not unintentionally change anything in-place.
recursive_fix_for_json_export(my_init_kwargs)
save_json(my_init_kwargs, join(output_folder, 'predict_from_raw_data_args.json'))
# we need these two if we want to do things with the predictions like for example apply postprocessing
save_json(self.dataset_json, join(output_folder, 'dataset.json'), sort_keys=False)
save_json(self.plans_manager.plans, join(output_folder, 'plans.json'), sort_keys=False)
#######################
# check if we need a prediction from the previous stage
if self.configuration_manager.previous_stage_name is not None:
assert folder_with_segs_from_prev_stage is not None, \
f'The requested configuration is a cascaded network. It requires the segmentations of the previous ' \
f'stage ({self.configuration_manager.previous_stage_name}) as input. Please provide the folder where' \
f' they are located via folder_with_segs_from_prev_stage'
# sort out input and output filenames
list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \
self._manage_input_and_output_lists(list_of_lists_or_source_folder,
output_folder_or_list_of_truncated_output_files,
folder_with_segs_from_prev_stage, overwrite, 0, 1,
save_probabilities)
if len(list_of_lists_or_source_folder) == 0:
return
label_manager = self.plans_manager.get_label_manager(self.dataset_json)
preprocessor = self.configuration_manager.preprocessor_class(verbose=self.verbose)
if output_filename_truncated is None:
output_filename_truncated = [None] * len(list_of_lists_or_source_folder)
if seg_from_prev_stage_files is None:
seg_from_prev_stage_files = [None] * len(list_of_lists_or_source_folder)
ret = []
for li, of, sps in zip(list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files):
data, seg, data_properties = preprocessor.run_case(
li,
sps,
self.plans_manager,
self.configuration_manager,
self.dataset_json
)
print(f'perform_everything_on_device: {self.perform_everything_on_device}')
prediction = self.predict_logits_from_preprocessed_data(torch.from_numpy(data)).cpu()
if of is not None:
export_prediction_from_logits(prediction, data_properties, self.configuration_manager, self.plans_manager,
self.dataset_json, of, save_probabilities)
else:
ret.append(convert_predicted_logits_to_segmentation_with_correct_shape(prediction, self.plans_manager,
self.configuration_manager, self.label_manager,
data_properties,
save_probabilities))
# clear lru cache
compute_gaussian.cache_clear()
# clear device cache
empty_cache(self.device)
return ret
def _getDefaultValue(env: str, dtype: type, default: any,) -> any:
try:
val = dtype(os.environ.get(env) or default)
except (TypeError, ValueError):  # env var (or default) cannot be converted to dtype
val = default
return val
def predict_entry_point_modelfolder():
import argparse
parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when '
'you want to manually specify a folder containing a trained nnU-Net '
'model. This is useful when the nnunet environment variables '
'(nnUNet_results) are not set.')
parser.add_argument('-i', type=str, required=True,
help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). '
'File endings must be the same as the training dataset!')
parser.add_argument('-o', type=str, required=True,
help='Output folder. If it does not exist it will be created. Predicted segmentations will '
'have the same name as their source images.')
parser.add_argument('-m', type=str, required=True,
help='Folder in which the trained model is. Must have subfolders fold_X for the different '
'folds you trained')
parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4),
help='Specify the folds of the trained model that should be used for prediction. '
'Default: (0, 1, 2, 3, 4)')
parser.add_argument('-step_size', type=float, required=False, default=0.5,
help='Step size for sliding window prediction. The larger it is the faster but less accurate '
'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.')
parser.add_argument('--disable_tta', action='store_true', required=False, default=False,
help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '
'but less accurate inference. Not recommended.')
parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have "
"to be a good listener/reader.")
parser.add_argument('--save_probabilities', action='store_true',
help='Set this to export predicted class "probabilities". Required if you want to ensemble '
'multiple configurations.')
parser.add_argument('--continue_prediction', '--c', action='store_true',
help='Continue an aborted previous prediction (will not overwrite existing files)')
parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',
help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')
parser.add_argument('-npp', type=int, required=False, default=3,
help='Number of processes used for preprocessing. More is not always better. Beware of '
'out-of-RAM issues. Default: 3')
parser.add_argument('-nps', type=int, required=False, default=3,
help='Number of processes used for segmentation export. More is not always better. Beware of '
'out-of-RAM issues. Default: 3')
parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,
help='Folder containing the predictions of the previous stage. Required for cascaded models.')
parser.add_argument('-device', type=str, default='cuda', required=False,
help="Use this to set the device the inference should run with. Available options are 'cuda' "
"(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
"Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!")
parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False,
help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive '
'jobs)')
parser.add_argument('--not_on_device', action='store_true', required=False, default=False,
help="Set this flag to disable perform_everything_on_device. Recommended for large cases that "
"occupy more VRAM than available")
print(
"\n#######################################################################\nPlease cite the following paper "
"when using nnU-Net:\n"
"Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). "
"nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. "
"Nature methods, 18(2), 203-211.\n#######################################################################\n")
args = parser.parse_args()
args.f = [i if i == 'all' else int(i) for i in args.f]
if not isdir(args.o):
maybe_mkdir_p(args.o)
assert args.device in ['cpu', 'cuda',
'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
if args.device == 'cpu':
# let's allow torch to use hella threads
import multiprocessing
torch.set_num_threads(multiprocessing.cpu_count())
device = torch.device('cpu')
elif args.device == 'cuda':
# multithreading in torch doesn't help nnU-Net if run on GPU
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
device = torch.device('cuda')
else:
device = torch.device('mps')
predictor = nnUNetPredictor(tile_step_size=args.step_size,
use_gaussian=True,
use_mirroring=not args.disable_tta,
perform_everything_on_device=not args.not_on_device,
device=device,
verbose=args.verbose,
allow_tqdm=not args.disable_progress_bar,
verbose_preprocessing=args.verbose)
predictor.initialize_from_trained_model_folder(args.m, args.f, args.chk)
predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities,
overwrite=not args.continue_prediction,
num_processes_preprocessing=args.npp,
num_processes_segmentation_export=args.nps,
folder_with_segs_from_prev_stage=args.prev_stage_predictions,
num_parts=1, part_id=0)
def predict_entry_point():
import argparse
parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. The trained model is located '
'via the dataset (-d), trainer (-tr), plans (-p) and configuration (-c), '
'so the nnU-Net environment variables (nnUNet_results) must be set. If '
'they are not, use predict_entry_point_modelfolder instead.')
parser.add_argument('-i', type=str, required=True,
help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). '
'File endings must be the same as the training dataset!')
parser.add_argument('-o', type=str, required=True,
help='Output folder. If it does not exist it will be created. Predicted segmentations will '
'have the same name as their source images.')
parser.add_argument('-d', type=str, required=True,
help='Dataset with which you would like to predict. You can specify either dataset name or id')
parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
help='Plans identifier. Specify the plans in which the desired configuration is located. '
'Default: nnUNetPlans')
parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',
help='What nnU-Net trainer class was used for training? Default: nnUNetTrainer')
parser.add_argument('-c', type=str, required=True,
help='nnU-Net configuration that should be used for prediction. Config must be located '
'in the plans specified with -p')
parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4),
help='Specify the folds of the trained model that should be used for prediction. '
'Default: (0, 1, 2, 3, 4)')
parser.add_argument('-step_size', type=float, required=False, default=0.5,
help='Step size for sliding window prediction. The larger it is the faster but less accurate '
'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.')
parser.add_argument('--disable_tta', action='store_true', required=False, default=False,
help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '
'but less accurate inference. Not recommended.')
parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have "
"to be a good listener/reader.")
parser.add_argument('--save_probabilities', action='store_true',
help='Set this to export predicted class "probabilities". Required if you want to ensemble '
'multiple configurations.')
parser.add_argument('--continue_prediction', action='store_true',
help='Continue an aborted previous prediction (will not overwrite existing files)')
parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',
help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')
parser.add_argument('-npp', type=int, required=False, default=_getDefaultValue('nnUNet_npp', int, 3),
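help='Number of processes used for preprocessing. More is not always better. Beware of '
'out-of-RAM issues. Default: 3, or the value of the nnUNet_npp environment variable if set. '
'Setting both -npp and -nps to 0 runs everything sequentially without multiprocessing.')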
parser.add_argument('-nps', type=int, required=False, default=_getDefaultValue('nnUNet_nps', int, 3),
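help='Number of processes used for segmentation export. More is not always better. Beware of '
'out-of-RAM issues. Default: 3, or the value of the nnUNet_nps environment variable if set. '
'Setting both -npp and -nps to 0 runs everything sequentially without multiprocessing.')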
parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,
help='Folder containing the predictions of the previous stage. Required for cascaded models.')
parser.add_argument('-num_parts', type=int, required=False, default=1,
help='Number of separate nnUNetv2_predict calls that you will be making. Default: 1 (= this one '
'call predicts everything)')
parser.add_argument('-part_id', type=int, required=False, default=0,
help='If multiple nnUNetv2_predict calls exist, which one is this? IDs start with 0 and end with '
'num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set -num_parts '
'5 and use -part_id 0, 1, 2, 3 and 4. Simple, right? Note: You are yourself responsible '
'for making these run on separate GPUs! Use CUDA_VISIBLE_DEVICES (google, yo!)')
parser.add_argument('-device', type=str, default='cuda', required=False,
help="Use this to set the device the inference should run with. Available options are 'cuda' "
"(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
"Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!")
parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False,
help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive '
'jobs)')
parser.add_argument('--not_on_device', action='store_true', required=False, default=False,
help="Set this flag to disable perform_everything_on_device. Recommended for large cases that "
"occupy more VRAM than available")
print(
"\n#######################################################################\nPlease cite the following paper "
"when using nnU-Net:\n"
"Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). "
"nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. "
"Nature methods, 18(2), 203-211.\n#######################################################################\n")
args = parser.parse_args()
args.f = [i if i == 'all' else int(i) for i in args.f]
model_folder = get_output_folder(args.d, args.tr, args.p, args.c)
if not isdir(args.o):
maybe_mkdir_p(args.o)
# slightly passive aggressive haha
assert args.part_id < args.num_parts, 'Do you even read the documentation? See nnUNetv2_predict -h.'
assert args.device in ['cpu', 'cuda',
'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
if args.device == 'cpu':
# let's allow torch to use hella threads
import multiprocessing
torch.set_num_threads(multiprocessing.cpu_count())
device = torch.device('cpu')
elif args.device == 'cuda':
# multithreading in torch doesn't help nnU-Net if run on GPU
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
device = torch.device('cuda')
else:
device = torch.device('mps')
predictor = nnUNetPredictor(tile_step_size=args.step_size,
use_gaussian=True,
use_mirroring=not args.disable_tta,
perform_everything_on_device=not args.not_on_device,
device=device,
verbose=args.verbose,
verbose_preprocessing=args.verbose,
allow_tqdm=not args.disable_progress_bar)
predictor.initialize_from_trained_model_folder(
model_folder,
args.f,
checkpoint_name=args.chk
)
run_sequential = args.nps == 0 and args.npp == 0
if run_sequential:
print("Running in non-multiprocessing mode")
predictor.predict_from_files_sequential(args.i, args.o, save_probabilities=args.save_probabilities,
overwrite=not args.continue_prediction,
folder_with_segs_from_prev_stage=args.prev_stage_predictions)
else:
predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities,
overwrite=not args.continue_prediction,
num_processes_preprocessing=args.npp,
num_processes_segmentation_export=args.nps,
folder_with_segs_from_prev_stage=args.prev_stage_predictions,
num_parts=args.num_parts,
part_id=args.part_id)
# r = predict_from_raw_data(args.i,
# args.o,
# model_folder,
# args.f,
# args.step_size,
# use_gaussian=True,
# use_mirroring=not args.disable_tta,
# perform_everything_on_device=True,
# verbose=args.verbose,
# save_probabilities=args.save_probabilities,
# overwrite=not args.continue_prediction,
# checkpoint_name=args.chk,
# num_processes_preprocessing=args.npp,
# num_processes_segmentation_export=args.nps,
# folder_with_segs_from_prev_stage=args.prev_stage_predictions,
# num_parts=args.num_parts,
# part_id=args.part_id,
# device=device)
if __name__ == '__main__':
########################## predict a bunch of files
from nnunetv2.paths import nnUNet_results, nnUNet_raw
predictor = nnUNetPredictor(
tile_step_size=0.5,
use_gaussian=True,
use_mirroring=True,
perform_everything_on_device=True,
device=torch.device('cuda', 0),
verbose=False,
verbose_preprocessing=False,
allow_tqdm=True
)
predictor.initialize_from_trained_model_folder(
join(nnUNet_results, 'Dataset004_Hippocampus/nnUNetTrainer_5epochs__nnUNetPlans__3d_fullres'),
use_folds=(0,),
checkpoint_name='checkpoint_final.pth',
)
# predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),
# join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'),
# save_probabilities=False, overwrite=False,
# num_processes_preprocessing=2, num_processes_segmentation_export=2,
# folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
#
# # predict a numpy array
# from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
#
# img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')])
# ret = predictor.predict_single_npy_array(img, props, None, None, False)
#
# iterator = predictor.get_data_iterator_from_raw_npy_data([img], None, [props], None, 1)
# ret = predictor.predict_from_data_iterator(iterator, False, 1)
ret = predictor.predict_from_files_sequential(
[['/media/isensee/raw_data/nnUNet_raw/Dataset004_Hippocampus/imagesTs/hippocampus_002_0000.nii.gz'], ['/media/isensee/raw_data/nnUNet_raw/Dataset004_Hippocampus/imagesTs/hippocampus_005_0000.nii.gz']],
'/home/isensee/temp/tmp', False, True, None
)
================================================
FILE: nnunetv2/inference/readme.md
================================================
The nnU-Net inference is now much more dynamic than before, allowing you to more seamlessly integrate nnU-Net into
your existing workflows.
This readme will give you a quick rundown of your options. This is not a complete guide. Look into the code to learn
all the details!
# Preface
In terms of speed, the most efficient inference strategy is the one used by the nnU-Net defaults! Images are read on
the fly and preprocessed in background workers. The main process takes the preprocessed images, predicts them and
sends the prediction off to another set of background workers which will resize the resulting logits, convert
them to a segmentation and export the segmentation.
The default setup is the best option because
1) loading and preprocessing as well as segmentation export are interlaced with the prediction. The main process can
focus on communicating with the compute device (i.e. your GPU) and does not have to do any other processing.
This makes the best possible use of your resources!
2) only the images and segmentations that are currently needed are stored in RAM! Imagine predicting many images
and having to hold all of them plus the results in your system memory.
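The interlacing described above can be illustrated with a toy, self-contained sketch. It is not nnU-Net's implementation: nnU-Net uses background processes, whereas this sketch swaps in `concurrent.futures` threads for brevity. `run_pipeline`, `preprocess`, `predict` and `export` are hypothetical stand-ins for reading/preprocessing, the GPU forward pass, and resampling/export.

```python
from collections import deque
from concurrent.futures import ThreadPoolExecutor


def run_pipeline(cases, preprocess, predict, export, num_workers=2, max_in_flight=2):
    """Toy interlaced pipeline: preprocess and export run in worker threads,
    while the calling thread only runs predict()."""
    case_iter = iter(enumerate(cases))
    pending = deque()  # (index, future) of cases being preprocessed or waiting
    exports = []       # (index, future) of cases being exported
    results = [None] * len(cases)
    with ThreadPoolExecutor(num_workers) as pre_pool, ThreadPoolExecutor(num_workers) as export_pool:
        def submit_next():
            item = next(case_iter, None)
            if item is not None:
                idx, case = item
                pending.append((idx, pre_pool.submit(preprocess, case)))

        # limit how many cases are being preprocessed (or waiting) at once to bound RAM usage
        for _ in range(max_in_flight):
            submit_next()
        while pending:
            idx, fut = pending.popleft()
            data = fut.result()
            submit_next()  # refill so the workers preprocess ahead while we predict
            exports.append((idx, export_pool.submit(export, predict(data))))
        for idx, fut in exports:
            results[idx] = fut.result()
    return results


out = run_pipeline([1, 2, 3, 4, 5], lambda x: x * 2, lambda x: x + 1, lambda x: f'seg_{x}')
```

The main loop never preprocesses or exports anything itself. It only waits for the next preprocessed case, hands it to `predict`, and immediately passes the result on. This mirrors point 1, and the `max_in_flight` bound mirrors point 2.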
# nnUNetPredictor
The new nnUNetPredictor class encapsulates the inference code and makes it simple to switch between modes. Your
code can hold an nnUNetPredictor instance and perform predictions on the fly. Previously this was not possible and each
new prediction request resulted in reloading the parameters and reinstantiating the network architecture. Not ideal.
The nnUNetPredictor must be initialized manually! You will want to use the
`predictor.initialize_from_trained_model_folder` function for 99% of use cases!
New feature: If you do not specify an output folder / output files, the predicted segmentations will be
returned.
## Recommended nnU-Net default: predict from source files
tldr:
- loads images on the fly
- performs preprocessing in background workers
- main process focuses only on making predictions
- results are again given to background workers for resampling and (optional) export
pros:
- best suited for predicting a large number of images
- nicer to your RAM
cons:
- not ideal when single images are to be predicted
- requires images to be present as files
Example:
```python
from nnunetv2.paths import nnUNet_results, nnUNet_raw
import torch
from batchgenerators.utilities.file_and_folder_operations import join
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
# instantiate the nnUNetPredictor
predictor = nnUNetPredictor(
tile_step_size=0.5,
use_gaussian=True,
use_mirroring=True,
perform_everything_on_device=True,
device=torch.device('cuda', 0),
verbose=False,
verbose_preprocessing=False,
allow_tqdm=True
)
# initializes the network architecture, loads the checkpoint
predictor.initialize_from_trained_model_folder(
join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'),
use_folds=(0,),
checkpoint_name='checkpoint_final.pth',
)
# variant 1: give input and output folders
predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),
join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'),
save_probabilities=False, overwrite=False,
num_processes_preprocessing=2, num_processes_segmentation_export=2,
folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
```
Instead of giving input and output folders you can also give concrete files. If you give concrete files, there is no
need for the _0000 suffix anymore! This can be useful in situations where you have no control over the filenames!
Remember that the files must be given as 'list of lists' where each entry in the outer list is a case to be predicted
and the inner list contains all the files belonging to that case. There is just one file for datasets with just one
input modality (such as CT) but may be more files for others (such as MRI where there is sometimes T1, T2, Flair etc).
IMPORTANT: the order in which the files for each case are given must match the order of the channels as defined in the
dataset.json!
If you give files as input, you need to give individual output files as output!
```python
# variant 2, use list of files as inputs. Note how we use nested lists!!!
indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs')
outdir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres')
predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')],
[join(indir, 'liver_142_0000.nii.gz')]],
[join(outdir, 'liver_152'),
join(outdir, 'liver_142')],
save_probabilities=False, overwrite=False,
num_processes_preprocessing=2, num_processes_segmentation_export=2,
folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
```
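The nested-list structure and the channel-ordering rule can be sketched as follows. `build_inputs` is a hypothetical helper and is not part of nnU-Net. It assumes your cases are described as `{case_id: {channel_name: path}}` with arbitrary filenames, and that `channel_names` mirrors the `channel_names` entry of your dataset.json (e.g. `{"0": "T1", "1": "T2"}`).

```python
import os


def build_inputs(cases, channel_names, output_dir):
    """Hypothetical helper: turn {case_id: {channel_name: path}} into nnU-Net's
    'list of lists' input plus the matching truncated output files."""
    # dataset.json-style channel_names, e.g. {"0": "T1", "1": "T2"}: sort by the integer key
    order = [channel_names[k] for k in sorted(channel_names, key=int)]
    list_of_lists, output_files = [], []
    for case_id in sorted(cases):
        files = cases[case_id]
        missing = [name for name in order if name not in files]
        if missing:
            raise ValueError(f'{case_id}: missing channels {missing}')
        # the inner list follows the channel order, not the dict order or the filenames
        list_of_lists.append([files[name] for name in order])
        # output files are given without a file ending
        output_files.append(os.path.join(output_dir, case_id))
    return list_of_lists, output_files


cases = {'patient_b': {'T2': '/data/b/t2.nii.gz', 'T1': '/data/b/t1.nii.gz'},
         'patient_a': {'T1': '/data/a/scan1.nii.gz', 'T2': '/data/a/scan2.nii.gz'}}
inputs, outputs = build_inputs(cases, {'0': 'T1', '1': 'T2'}, '/data/predictions')
```

The two returned lists can then be passed as the first two arguments of `predictor.predict_from_files`, just like in variant 2.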
Did you know? If you do not specify output files, the predicted segmentations will be returned:
```python
# variant 2.5, returns segmentations
indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs')
predicted_segmentations = predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')],
[join(indir, 'liver_142_0000.nii.gz')]],
None,
save_probabilities=False, overwrite=True,
num_processes_preprocessing=2, num_processes_segmentation_export=2,
folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
```
## Prediction from npy arrays
tldr:
- you give images as a list of npy arrays
- performs preprocessing in background workers
- main process focuses only on making predictions
- results are again given to background workers for resampling and (optional) export
pros:
- the correct variant for when you have images in RAM already
- well suited for predicting multiple images
cons:
- uses more RAM than the default
- unsuited for large number of images as all images must be held in RAM
```python
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])
img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])
img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])
img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])
# we do not set output files so that the segmentations will be returned. You can of course also specify output
# files instead (no return value in that case)
ret = predictor.predict_from_list_of_npy_arrays([img, img2, img3, img4],
None,
[props, props2, props3, props4],
None, 2, save_probabilities=False,
num_processes_segmentation_export=2)
```
## Predicting a single npy array
tldr:
- you give one image as npy array
- axes ordering must match the corresponding training data. The easiest way to achieve that is to use the same I/O class
for loading images as was used during nnU-Net preprocessing! You can find that class in your
plans.json file under the key "image_reader_writer". If you decide to freestyle, know that the
default axis ordering for medical images is the one from SimpleITK. If you load with nibabel,
you need to transpose your axes AND your spacing from [x,y,z] to [z,y,x]!
- everything is done in the main process: preprocessing, prediction, resampling, (export)
- no interlacing, slowest variant!
- ONLY USE THIS IF YOU CANNOT GIVE NNUNET MULTIPLE IMAGES AT ONCE FOR SOME REASON
pros:
- no messing with multiprocessing
- no messing with data iterator blabla
cons:
- slow as heck, yo
- never the right choice unless you can only give a single image at a time to nnU-Net
```python
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
from nnunetv2.imageio.nibabel_reader_writer import NibabelIO
import nibabel as nib
import numpy as np
# predict a single numpy array (SimpleITKIO)
img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')])
ret = predictor.predict_single_npy_array(img, props, None, None, False)
# predict a single numpy array (NibabelIO)
img, props = NibabelIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')])
ret = predictor.predict_single_npy_array(img, props, None, None, False)
# The following IS NOT RECOMMENDED. Use nnunetv2.imageio!
# nibabel, we need to transpose axes and spacing to match the training axes ordering for the nnU-Net default:
img_nii = nib.load(join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz'))
img = np.asanyarray(img_nii.dataobj).transpose([2, 1, 0])[None]  # reverse axis order to match SITK, add channel axis
props = {'spacing': img_nii.header.get_zooms()[::-1]}  # reverse axis order to match SITK
ret = predictor.predict_single_npy_array(img, props, None, None, False)
```
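The axis and spacing reordering in the not-recommended nibabel snippet above can be checked in isolation with plain NumPy. This is a sketch under stated assumptions. `to_nnunet_layout` is a hypothetical helper. It assumes a single-channel 3D volume in nibabel's [x, y, z] order and a spacing tuple in the same order, as `get_zooms()` reports for a 3D NIfTI. Prefer the `nnunetv2.imageio` classes whenever possible.

```python
import numpy as np


def to_nnunet_layout(volume_xyz, spacing_xyz):
    """Sketch: reorder a nibabel-style [x, y, z] array and its spacing to the
    SimpleITK-style [z, y, x] ordering and prepend a channel axis."""
    volume_xyz = np.asarray(volume_xyz)
    if volume_xyz.ndim != 3:
        raise ValueError(f'expected a 3D single-channel volume, got shape {volume_xyz.shape}')
    volume_zyx = volume_xyz.transpose([2, 1, 0])  # reverse the axis order
    spacing_zyx = tuple(float(s) for s in spacing_xyz[::-1])  # the spacing must be reversed too
    return volume_zyx[None], {'spacing': spacing_zyx}  # [None] adds the channel axis -> (c, z, y, x)


vol = np.arange(24).reshape(2, 3, 4)  # toy [x, y, z] volume
img, props = to_nnunet_layout(vol, (0.8, 0.8, 2.5))
```

Forgetting to reverse the spacing together with the axes is an easy mistake. The preprocessing would then resample each axis with the wrong voxel size, even though the array shape looks plausible.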
## Predicting with a custom data iterator
tldr:
- highly flexible
- not for newbies
pros:
- you can do everything yourself
- you have all the freedom you want
- really fast if you remember to use multiprocessing in your iterator
cons:
- you need to do everything yourself
- harder than you might think
```python
img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])
img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])
img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])
img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])
# each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properties' keys!
# If 'ofile' is None, the result will be returned instead of written to a file
# the iterator is responsible for performing the correct preprocessing!
# note how the iterator here does not use multiprocessing -> preprocessing will be done in the main thread!
# take a look at the default iterators for predict_from_files and predict_from_list_of_npy_arrays
# (they both use predictor.predict_from_data_iterator) for inspiration!
def my_iterator(list_of_input_arrs, list_of_input_props):
preprocessor = predictor.configuration_manager.preprocessor_class(verbose=predictor.verbose)
for a, p in zip(list_of_input_arrs, list_of_input_props):
data, seg = preprocessor.run_case_npy(a,
None,
p,
predictor.plans_manager,
predictor.configuration_manager,
predictor.dataset_json)
yield {'data': torch.from_numpy(data).contiguous().pin_memory(), 'data_properties': p, 'ofile': None}
ret = predictor.predict_from_data_iterator(my_iterator([img, img2, img3, img4], [props, props2, props3, props4]),
save_probabilities=False, num_processes_segmentation_export=3)
```
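If your iterator does expensive work in the main thread, one way to overlap that work with prediction is to prefetch items in the background. The sketch below is an assumption and not an nnU-Net utility. `prefetch` is hypothetical, and it uses a thread rather than the multiprocessing recommended above, which only helps when the work releases the GIL (file I/O, many NumPy operations). It keeps the item order and passes the dicts through unchanged, so the `'data'`, `'data_properties'` and `'ofile'` contract is preserved.

```python
import queue
import threading


def prefetch(iterator, max_prefetch=2):
    """Hypothetical wrapper: consume `iterator` in a background thread and yield
    its items in the original order, buffering at most `max_prefetch` of them."""
    buf = queue.Queue(maxsize=max_prefetch)

    def worker():
        try:
            for item in iterator:
                buf.put(('item', item))  # blocks while the buffer is full
        except BaseException as exc:  # hand errors over to the consuming thread
            buf.put(('error', exc))
        else:
            buf.put(('done', None))

    threading.Thread(target=worker, daemon=True).start()
    while True:
        kind, payload = buf.get()
        if kind == 'item':
            yield payload
        elif kind == 'error':
            raise payload
        else:
            return
```

Usage would then look like `predictor.predict_from_data_iterator(prefetch(my_iterator([img, img2], [props, props2])), save_probabilities=False, num_processes_segmentation_export=3)`. Because `prefetch` is itself a generator, the background thread only starts once iteration begins.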
================================================
FILE: nnunetv2/inference/sliding_window_prediction.py
================================================
from functools import lru_cache
import numpy as np
import torch
from typing import Union, Tuple, List
from acvl_utils.cropping_and_padding.padding import pad_nd_image
from scipy.ndimage import gaussian_filter
@lru_cache(maxsize=2)
def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale: float = 1. / 8,
value_scaling_factor: float = 1, dtype=torch.float16, device=torch.device('cuda', 0)) \
-> torch.Tensor:
tmp = np.zeros(tile_size)
center_coords = [i // 2 for i in tile_size]
sigmas = [i * sigma_scale for i in tile_size]
tmp[tuple(center_coords)] = 1
gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
gaussian_importance_map = torch.from_numpy(gaussian_importance_map)
gaussian_importance_map /= (torch.max(gaussian_importance_map) / value_scaling_factor)
gaussian_importance_map = gaussian_importance_map.to(device=device, dtype=dtype)
# gaussian_importance_map cannot be 0, otherwise we may end up with nans!
mask = gaussian_importance_map == 0
gaussian_importance_map[mask] = torch.min(gaussian_importance_map[~mask])
return gaussian_importance_map
def compute_steps_for_sliding_window(image_size: Tuple[int, ...], tile_size: Tuple[int, ...], tile_step_size: float) -> \
List[List[int]]:
assert all([i >= j for i, j in zip(image_size, tile_size)]), "image size must be as large or larger than patch_size"
assert 0 < tile_step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1'
# our step width is patch_size*step_size at most, but can be narrower. For example if we have image size of
# 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps starting at coordinate 0, 23, 46
target_step_sizes_in_voxels = [i * tile_step_size for i in tile_size]
num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, tile_size)]
steps = []
for dim in range(len(tile_size)):
# the highest step value for this dimension is
max_step_value = image_size[dim] - tile_size[dim]
if num_steps[dim] > 1:
actual_step_size = max_step_value / (num_steps[dim] - 1)
else:
actual_step_size = 99999999999 # does not matter because there is only one step at 0
steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])]
steps.append(steps_here)
return steps
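compute_steps_for_sliding_window spreads the tile start coordinates evenly between 0 and `image_size - tile_size`, so the actual stride never exceeds `tile_size * tile_step_size`. A plain-Python port (the name `sliding_window_starts` is hypothetical) reproduces the 110 / 64 / 0.5 example from the comment. Python's built-in `round`, like `np.round`, rounds halves to even, so the rounding behaviour matches.

```python
import math
from typing import List, Sequence


def sliding_window_starts(image_size: Sequence[int], tile_size: Sequence[int],
                          tile_step_size: float) -> List[List[int]]:
    """Plain-Python port of compute_steps_for_sliding_window (same rounding rules)."""
    assert all(i >= t for i, t in zip(image_size, tile_size)), "image must be at least as large as the tile"
    assert 0 < tile_step_size <= 1
    steps = []
    for size, tile in zip(image_size, tile_size):
        target_step = tile * tile_step_size
        num_steps = int(math.ceil((size - tile) / target_step)) + 1
        max_start = size - tile
        # evenly distribute num_steps start coordinates over [0, max_start]
        actual_step = max_start / (num_steps - 1) if num_steps > 1 else 0
        steps.append([int(round(actual_step * k)) for k in range(num_steps)])
    return steps


print(sliding_window_starts((110, 64), (64, 64), 0.5))  # [[0, 23, 46], [0]]
```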
if __name__ == '__main__':
a = torch.rand((4, 2, 32, 23))
a_npy = a.numpy()
a_padded = pad_nd_image(a, new_shape=(48, 27))
a_npy_padded = pad_nd_image(a_npy, new_shape=(48, 27))
assert all([i == j for i, j in zip(a_padded.shape, (4, 2, 48, 27))])
assert all([i == j for i, j in zip(a_npy_padded.shape, (4, 2, 48, 27))])
assert np.all(a_padded.numpy() == a_npy_padded)
================================================
FILE: nnunetv2/model_sharing/__init__.py
================================================
================================================
FILE: nnunetv2/model_sharing/entry_points.py
================================================
from nnunetv2.model_sharing.model_download import download_and_install_from_url
from nnunetv2.model_sharing.model_export import export_pretrained_model
from nnunetv2.model_sharing.model_import import install_model_from_zip_file
def print_license_warning():
print('')
print('######################################################')
print('!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!')
print('######################################################')
print("Using the pretrained model weights is subject to the license of the dataset they were trained on. Some "
"allow commercial use, others don't. It is your responsibility to make sure you use them appropriately! Use "
"nnUNet_print_pretrained_model_info(task_name) to see a summary of the dataset and where to find its license!")
print('######################################################')
print('')
def download_by_url():
import argparse
parser = argparse.ArgumentParser(
description="Use this to download pretrained models. This script is intended to download models via url only. "
"CAREFUL: This script will overwrite "
"existing models (if they share the same trainer class and plans as "
"the pretrained model).")
parser.add_argument("url", type=str, help='URL of the pretrained model')
args = parser.parse_args()
url = args.url
download_and_install_from_url(url)
def install_from_zip_entry_point():
import argparse
parser = argparse.ArgumentParser(
description="Use this to install a zip file containing a pretrained model.")
parser.add_argument("zip", type=str, help='zip file')
args = parser.parse_args()
zip_file = args.zip
install_model_from_zip_file(zip_file)
def export_pretrained_model_entry():
import argparse
parser = argparse.ArgumentParser(
description="Use this to export a trained model as a zip file.")
parser.add_argument('-d', type=str, required=True, help='Dataset name or id')
parser.add_argument('-o', type=str, required=True, help='Output file name')
parser.add_argument('-c', nargs='+', type=str, required=False,
default=('3d_lowres', '3d_fullres', '2d', '3d_cascade_fullres'),
help="List of configuration names")
parser.add_argument('-tr', required=False, type=str, default='nnUNetTrainer', help='Trainer class')
parser.add_argument('-p', required=False, type=str, default='nnUNetPlans', help='plans identifier')
parser.add_argument('-f', required=False, nargs='+', type=str, default=(0, 1, 2, 3, 4), help='list of fold ids')
parser.add_argument('-chk', required=False, nargs='+', type=str, default=('checkpoint_final.pth', ),
help='List of checkpoint names to export. Default: checkpoint_final.pth')
parser.add_argument('--not_strict', action='store_true', required=False, help='Set this to allow missing folds and/or configurations')
parser.add_argument('--exp_cv_preds', action='store_true', required=False, help='Set this to export the cross-validation predictions as well')
args = parser.parse_args()
export_pretrained_model(dataset_name_or_id=args.d, output_file=args.o, configurations=args.c, trainer=args.tr,
plans_identifier=args.p, folds=args.f, strict=not args.not_strict, save_checkpoints=args.chk,
export_crossval_predictions=args.exp_cv_preds)
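For an opt-in switch such as `--not_strict`, argparse's `action='store_true'` yields False when the flag is absent and True when it is given. By contrast, `action='store_false'` combined with `default=False` yields False in both cases, so `strict=not args.not_strict` could never become False. The toy parser below (hypothetical `parse_strictness`) shows the resulting `strict` value for both invocations.

```python
import argparse
from typing import List, Optional


def parse_strictness(argv: Optional[List[str]] = None) -> bool:
    parser = argparse.ArgumentParser(description="toy parser mirroring the --not_strict flag")
    # store_true: the attribute is False unless the flag is given on the command line
    parser.add_argument('--not_strict', action='store_true', required=False,
                        help='Set this to allow missing folds and/or configurations')
    args = parser.parse_args(argv)
    return not args.not_strict  # the value that would be passed on as strict=...


print(parse_strictness([]))                # True
print(parse_strictness(['--not_strict']))  # False
```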
================================================
FILE: nnunetv2/model_sharing/model_download.py
================================================
from typing import Optional
import requests
from batchgenerators.utilities.file_and_folder_operations import *
from time import time
from nnunetv2.model_sharing.model_import import install_model_from_zip_file
from nnunetv2.paths import nnUNet_results
from tqdm import tqdm
def download_and_install_from_url(url):
assert nnUNet_results is not None, "Cannot install model because nnUNet_results is not " \
"set (nnUNet_results missing as environment variable, see " \
"documentation/setting_up_paths.md)"
print('Downloading pretrained model from url:', url)
import http.client
http.client.HTTPConnection._http_vsn = 10
http.client.HTTPConnection._http_vsn_str = 'HTTP/1.0'
import os
home = os.path.expanduser('~')
random_number = int(time() * 1e7)
tempfile = join(home, f'.nnunetdownload_{str(random_number)}')
try:
download_file(url=url, local_filename=tempfile, chunk_size=8192 * 16)
print("Download finished. Extracting...")
install_model_from_zip_file(tempfile)
print("Done")
except Exception as e:
raise e
finally:
if isfile(tempfile):
os.remove(tempfile)
def download_file(url: str, local_filename: str, chunk_size: Optional[int] = 8192 * 16) -> str:
# borrowed from https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests
# NOTE the stream=True parameter below
with requests.get(url, stream=True, timeout=100) as r:
r.raise_for_status()
with tqdm.wrapattr(open(local_filename, 'wb'), "write", total=int(r.headers.get("Content-Length", 0)) or None) as f:
for chunk in r.iter_content(chunk_size=chunk_size):
f.write(chunk)
return local_filename
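download_and_install_from_url streams the archive to a temporary file in chunks and removes that file in a `finally` block, regardless of whether installation succeeds. The sketch below shows the same pattern with an in-memory stream standing in for the HTTP response, so it runs offline. `copy_in_chunks` and `install_via_temp_file` are hypothetical helpers, and `tempfile.mkstemp` is a swapped-in choice replacing the time-based file name in the home directory.

```python
import io
import os
import tempfile
from typing import BinaryIO, Callable, Tuple, TypeVar

T = TypeVar('T')


def copy_in_chunks(source: BinaryIO, destination: str, chunk_size: int = 8192 * 16) -> int:
    """Copy a binary stream to a file without holding it in memory; returns bytes written."""
    written = 0
    with open(destination, 'wb') as f:
        while True:
            chunk = source.read(chunk_size)
            if not chunk:
                break
            f.write(chunk)
            written += len(chunk)
    return written


def install_via_temp_file(source: BinaryIO, install: Callable[[str], T]) -> T:
    """Write the stream to a temp file, hand its path to `install`, always clean up."""
    fd, temp_path = tempfile.mkstemp(prefix='.nnunetdownload_', suffix='.zip')
    os.close(fd)
    try:
        copy_in_chunks(source, temp_path)
        return install(temp_path)
    finally:
        if os.path.isfile(temp_path):
            os.remove(temp_path)


def read_back(path: str) -> Tuple[str, bytes]:
    with open(path, 'rb') as f:
        return path, f.read()


payload = b'\x00\x01' * 200_000  # 400 kB, i.e. several chunks
used_path, content = install_via_temp_file(io.BytesIO(payload), read_back)
print(content == payload, os.path.exists(used_path))  # True False
```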
================================================
FILE: nnunetv2/model_sharing/model_export.py
================================================
import zipfile
from nnunetv2.utilities.file_path_utilities import *
def export_pretrained_model(dataset_name_or_id: Union[int, str], output_file: str,
configurations: Tuple[str] = ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
trainer: str = 'nnUNetTrainer',
plans_identifier: str = 'nnUNetPlans',
folds: Tuple[int, ...] = (0, 1, 2, 3, 4),
strict: bool = True,
save_checkpoints: Tuple[str, ...] = ('checkpoint_final.pth',),
export_crossval_predictions: bool = False) -> None:
dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
with zipfile.ZipFile(output_file, 'w', zipfile.ZIP_DEFLATED) as zipf:
for c in configurations:
print(f"Configuration {c}")
trainer_output_dir = get_output_folder(dataset_name, trainer, plans_identifier, c)
if not isdir(trainer_output_dir):
if strict:
raise RuntimeError(f"{dataset_name} is missing the trained model of configuration {c}")
else:
continue
expected_fold_folder = [f"fold_{i}" if i != 'all' else 'fold_all' for i in folds]
assert all([isdir(join(trainer_output_dir, i)) for i in expected_fold_folder]), \
f"not all requested folds are present; {dataset_name} {c}; requested folds: {folds}"
assert isfile(join(trainer_output_dir, "plans.json")), f"plans.json missing, {dataset_name} {c}"
for fold_folder in expected_fold_folder:
print(f"Exporting {fold_folder}")
# debug.json, does not exist yet
source_file = join(trainer_output_dir, fold_folder, "debug.json")
if isfile(source_file):
zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))
# all requested checkpoints
for chk in save_checkpoints:
source_file = join(trainer_output_dir, fold_folder, chk)
zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))
# progress.png
source_file = join(trainer_output_dir, fold_folder, "progress.png")
zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))
# if it exists, network_architecture.pdf
source_file = join(trainer_output_dir, fold_folder, "network_architecture.pdf")
if isfile(source_file):
zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))
# validation folder with all predicted segmentations etc
if export_crossval_predictions:
source_folder = join(trainer_output_dir, fold_folder, "validation")
files = [i for i in subfiles(source_folder, join=False) if not i.endswith('.npz') and not i.endswith('.pkl')]
for f in files:
zipf.write(join(source_folder, f), os.path.relpath(join(source_folder, f), nnUNet_results))
# just the summary.json file from the validation
else:
source_file = join(trainer_output_dir, fold_folder, "validation", "summary.json")
zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))
source_folder = join(trainer_output_dir, f'crossval_results_folds_{folds_tuple_to_string(folds)}')
if isdir(source_folder):
if export_crossval_predictions:
source_files = subfiles(source_folder, join=True)
else:
source_files = [
join(trainer_output_dir, f'crossval_results_folds_{folds_tuple_to_string(folds)}', i) for i in
['summary.json', 'postprocessing.pkl', 'postprocessing.json']
]
for s in source_files:
if isfile(s):
zipf.write(s, os.path.relpath(s, nnUNet_results))
# plans
source_file = join(trainer_output_dir, "plans.json")
zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))
# fingerprint
source_file = join(trainer_output_dir, "dataset_fingerprint.json")
zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))
# dataset
source_file = join(trainer_output_dir, "dataset.json")
zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))
ensemble_dir = join(nnUNet_results, dataset_name, 'ensembles')
if not isdir(ensemble_dir):
print("No ensemble directory found for dataset", dataset_name_or_id)
return
subd = subdirs(ensemble_dir, join=False)
# figure out whether the models in the ensemble are all within the exported models here
for ens in subd:
identifiers, _ = convert_ensemble_folder_to_model_identifiers_and_folds(ens) # do not shadow the folds argument
ok = True
for i in identifiers:
tr, pl, c = convert_identifier_to_trainer_plans_config(i)
if tr == trainer and pl == plans_identifier and c in configurations:
pass
else:
ok = False
if ok:
print(f'found matching ensemble: {ens}')
source_folder = join(ensemble_dir, ens)
if export_crossval_predictions:
source_files = subfiles(source_folder, join=True)
else:
source_files = [
join(source_folder, i) for i in
['summary.json', 'postprocessing.pkl', 'postprocessing.json'] if isfile(join(source_folder, i))
]
for s in source_files:
zipf.write(s, os.path.relpath(s, nnUNet_results))
inference_information_file = join(nnUNet_results, dataset_name, 'inference_information.json')
if isfile(inference_information_file):
zipf.write(inference_information_file, os.path.relpath(inference_information_file, nnUNet_results))
inference_information_txt_file = join(nnUNet_results, dataset_name, 'inference_information.txt')
if isfile(inference_information_txt_file):
zipf.write(inference_information_txt_file, os.path.relpath(inference_information_txt_file, nnUNet_results))
print('Done')
if __name__ == '__main__':
export_pretrained_model(2, '/home/fabian/temp/dataset2.zip', strict=False, export_crossval_predictions=True, folds=(0, ))
================================================
FILE: nnunetv2/model_sharing/model_import.py
================================================
import zipfile
from nnunetv2.paths import nnUNet_results
def install_model_from_zip_file(zip_file: str):
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
zip_ref.extractall(nnUNet_results)
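Export and import are symmetric because every archive member is written with a path relative to `nnUNet_results`. Running `extractall(nnUNet_results)` on another machine therefore recreates the same `DatasetXXX/trainer__plans__configuration/...` layout. The self-contained round trip below uses two temporary directories as stand-ins for two `nnUNet_results` folders. The helper names `export_relative` and `install` are hypothetical.

```python
import os
import tempfile
import zipfile
from typing import List


def export_relative(results_root: str, files: List[str], zip_path: str) -> None:
    """Store files under paths relative to results_root, as model_export does."""
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for f in files:
            zipf.write(f, os.path.relpath(f, results_root))


def install(zip_path: str, results_root: str) -> None:
    """Extract into another results root, as install_model_from_zip_file does."""
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(results_root)


with tempfile.TemporaryDirectory() as src_root, tempfile.TemporaryDirectory() as dst_root:
    rel = os.path.join('Dataset004_Hippocampus', 'nnUNetTrainer__nnUNetPlans__3d_fullres', 'plans.json')
    src_file = os.path.join(src_root, rel)
    os.makedirs(os.path.dirname(src_file))
    with open(src_file, 'w') as f:
        f.write('{}')
    archive = os.path.join(src_root, 'model.zip')
    export_relative(src_root, [src_file], archive)
    install(archive, dst_root)
    installed_ok = os.path.isfile(os.path.join(dst_root, rel))
    with zipfile.ZipFile(archive) as z:
        members = z.namelist()  # zipfile always stores '/' as separator

print(installed_ok, members)
```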
================================================
FILE: nnunetv2/paths.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
"""
PLEASE READ documentation/setting_up_paths.md FOR INFORMATION ON HOW TO SET THIS UP
"""
nnUNet_raw = os.environ.get('nnUNet_raw')
nnUNet_preprocessed = os.environ.get('nnUNet_preprocessed')
nnUNet_results = os.environ.get('nnUNet_results')
if nnUNet_raw is None:
print("nnUNet_raw is not defined and nnU-Net can only be used on data for which preprocessed files "
"are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like "
"this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set "
"this up properly.")
if nnUNet_preprocessed is None:
print("nnUNet_preprocessed is not defined and nnU-Net cannot be used for preprocessing "
"or training. If this is not intended, please read documentation/setting_up_paths.md for information on how "
"to set this up.")
if nnUNet_results is None:
print("nnUNet_results is not defined and nnU-Net cannot be used for training or "
"inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information "
"on how to set this up.")
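The three warnings above amount to a mapping from environment variable to the stages that need it. The hypothetical helper below (not part of nnU-Net) encodes that mapping, following the wording of the warnings, and reports which stages would be unavailable for a given environment.

```python
import os
from typing import Mapping, Set

# which stages need which folder, following the warnings printed in paths.py
REQUIRED_BY = {
    'nnUNet_raw': {'experiment planning', 'preprocessing'},
    'nnUNet_preprocessed': {'preprocessing', 'training'},
    'nnUNet_results': {'training', 'inference'},
}


def unavailable_stages(env: Mapping[str, str] = os.environ) -> Set[str]:
    """Return the set of stages that cannot run because a required variable is unset."""
    missing = set()
    for variable, stages in REQUIRED_BY.items():
        if env.get(variable) is None:
            missing |= stages
    return missing


print(sorted(unavailable_stages({'nnUNet_raw': '/data/raw', 'nnUNet_preprocessed': '/data/pre'})))
# ['inference', 'training']
```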
================================================
FILE: nnunetv2/postprocessing/__init__.py
================================================
================================================
FILE: nnunetv2/postprocessing/remove_connected_components.py
================================================
import argparse
import multiprocessing
import shutil
from typing import Union, Tuple, List, Callable
import numpy as np
from acvl_utils.morphology.morphology_helper import remove_all_but_largest_component
from batchgenerators.utilities.file_and_folder_operations import load_json, subfiles, maybe_mkdir_p, join, isfile, \
isdir, save_pickle, load_pickle, save_json
from nnunetv2.configuration import default_num_processes
from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results
from nnunetv2.evaluation.evaluate_predictions import region_or_label_to_mask, compute_metrics_on_folder, \
load_summary_json, label_or_region_to_key
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from nnunetv2.paths import nnUNet_raw
from nnunetv2.utilities.file_path_utilities import folds_tuple_to_string
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
def remove_all_but_largest_component_from_segmentation(segmentation: np.ndarray,
labels_or_regions: Union[int, Tuple[int, ...],
List[Union[int, Tuple[int, ...]]]],
background_label: int = 0) -> np.ndarray:
mask = np.zeros_like(segmentation, dtype=bool)
if not isinstance(labels_or_regions, list):
labels_or_regions = [labels_or_regions]
for l_or_r in labels_or_regions:
mask |= region_or_label_to_mask(segmentation, l_or_r)
mask_keep = remove_all_but_largest_component(mask)
ret = np.copy(segmentation) # do not modify the input!
ret[mask & ~mask_keep] = background_label
return ret
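remove_all_but_largest_component_from_segmentation merges all requested labels into one mask, keeps only the largest connected component of that mask, and resets the other masked voxels to the background label on a copy. Labels inside the kept component are preserved. The pure-Python sketch below illustrates the same idea on a 2D list. It swaps in a breadth-first search for the labeling done by `acvl_utils`, assumes 4-connectivity, and handles plain labels only (no regions). The name `keep_largest_component` is hypothetical.

```python
from collections import deque
from typing import List, Set, Tuple


def keep_largest_component(seg: List[List[int]], labels: Set[int],
                           background_label: int = 0) -> List[List[int]]:
    """Union the given labels into one mask and keep only its largest 4-connected component."""
    h, w = len(seg), len(seg[0])
    mask = [[seg[y][x] in labels for x in range(w)] for y in range(h)]
    seen = [[False] * w for _ in range(h)]
    largest: List[Tuple[int, int]] = []
    for y0 in range(h):
        for x0 in range(w):
            if not mask[y0][x0] or seen[y0][x0]:
                continue
            component = []
            queue = deque([(y0, x0)])
            seen[y0][x0] = True
            while queue:  # breadth-first flood fill of one component
                y, x = queue.popleft()
                component.append((y, x))
                for dy, dx in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                    ny, nx = y + dy, x + dx
                    if 0 <= ny < h and 0 <= nx < w and mask[ny][nx] and not seen[ny][nx]:
                        seen[ny][nx] = True
                        queue.append((ny, nx))
            if len(component) > len(largest):  # on ties the first component found wins
                largest = component
    keep = set(largest)
    out = [row[:] for row in seg]  # do not modify the input
    for y in range(h):
        for x in range(w):
            if mask[y][x] and (y, x) not in keep:
                out[y][x] = background_label
    return out


seg = [[1, 1, 0, 0],
       [1, 2, 0, 1],
       [0, 0, 0, 1]]
print(keep_largest_component(seg, {1, 2}))
# [[1, 1, 0, 0], [1, 2, 0, 0], [0, 0, 0, 0]]
```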
def apply_postprocessing(segmentation: np.ndarray, pp_fns: List[Callable], pp_fn_kwargs: List[dict]):
for fn, kwargs in zip(pp_fns, pp_fn_kwargs):
segmentation = fn(segmentation, **kwargs)
return segmentation
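apply_postprocessing treats postprocessing as an ordered list of functions paired with a list of keyword-argument dicts, and feeds the output of each function into the next. The toy functions below (`replace_label`, `clip_max`) are hypothetical and only illustrate the chaining. determine_postprocessing saves exactly such a `(functions, kwargs)` pair to postprocessing.pkl. Because pickle stores functions by module and qualified name rather than by code, those functions must be importable wherever the pickle is loaded.

```python
from typing import Callable, List


def replace_label(seg: List[int], old: int, new: int) -> List[int]:
    return [new if v == old else v for v in seg]


def clip_max(seg: List[int], max_label: int) -> List[int]:
    return [min(v, max_label) for v in seg]


def apply_chain(seg: List[int], fns: List[Callable], kwargs_list: List[dict]) -> List[int]:
    # same loop as apply_postprocessing: each function receives the previous result
    for fn, kwargs in zip(fns, kwargs_list):
        seg = fn(seg, **kwargs)
    return seg


fns = [replace_label, clip_max]
kwargs_list = [{'old': 3, 'new': 1}, {'max_label': 1}]
print(apply_chain([0, 3, 2, 1], fns, kwargs_list))  # [0, 1, 1, 1]
```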
def load_postprocess_save(segmentation_file: str,
output_fname: str,
image_reader_writer: BaseReaderWriter,
pp_fns: List[Callable],
pp_fn_kwargs: List[dict]):
seg, props = image_reader_writer.read_seg(segmentation_file)
seg = apply_postprocessing(seg[0], pp_fns, pp_fn_kwargs)
image_reader_writer.write_seg(seg, output_fname, props)
def determine_postprocessing(folder_predictions: str,
folder_ref: str,
plans_file_or_dict: Union[str, dict],
dataset_json_file_or_dict: Union[str, dict],
num_processes: int = default_num_processes,
keep_postprocessed_files: bool = True):
"""
Determines nnUNet postprocessing. Its output is a postprocessing.pkl file in folder_predictions which can be
used with apply_postprocessing_to_folder.
Postprocessed files are saved in folder_predictions/postprocessed. Set
keep_postprocessed_files=False to delete these files after this function is done (temp files will be created
and deleted regardless).
If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for them in folder_predictions
"""
output_folder = join(folder_predictions, 'postprocessed')
if plans_file_or_dict is None:
expected_plans_file = join(folder_predictions, 'plans.json')
if not isfile(expected_plans_file):
raise RuntimeError(f"Expected plans file missing: {expected_plans_file}. The plans files should have been "
f"created while running nnUNetv2_predict. Sadge.")
plans_file_or_dict = load_json(expected_plans_file)
plans_manager = PlansManager(plans_file_or_dict)
if dataset_json_file_or_dict is None:
expected_dataset_json_file = join(folder_predictions, 'dataset.json')
if not isfile(expected_dataset_json_file):
raise RuntimeError(
f"Expected dataset.json file missing: {expected_dataset_json_file}. The dataset.json file should have been "
f"copied while running nnUNetv2_predict. Sadge.")
dataset_json_file_or_dict = load_json(expected_dataset_json_file)
if not isinstance(dataset_json_file_or_dict, dict):
dataset_json = load_json(dataset_json_file_or_dict)
else:
dataset_json = dataset_json_file_or_dict
rw = plans_manager.image_reader_writer_class()
label_manager = plans_manager.get_label_manager(dataset_json)
labels_or_regions = label_manager.foreground_regions if label_manager.has_regions else label_manager.foreground_labels
predicted_files = subfiles(folder_predictions, suffix=dataset_json['file_ending'], join=False)
ref_files = subfiles(folder_ref, suffix=dataset_json['file_ending'], join=False)
# we should print a warning if not all files from folder_ref are present in folder_predictions
if not all([i in predicted_files for i in ref_files]):
print(f'WARNING: Not all files in folder_ref were found in folder_predictions. Determining postprocessing '
f'should always be done on the entire dataset!')
# before we start we should evaluate the images in the source folder
if not isfile(join(folder_predictions, 'summary.json')):
compute_metrics_on_folder(folder_ref,
folder_predictions,
join(folder_predictions, 'summary.json'),
rw,
dataset_json['file_ending'],
labels_or_regions,
label_manager.ignore_label,
num_processes)
# we save the postprocessing functions in here
pp_fns = []
pp_fn_kwargs = []
# pool party!
with multiprocessing.get_context("spawn").Pool(num_processes) as pool:
# now let's see whether removing all but the largest foreground region improves the scores
output_here = join(output_folder, 'temp', 'keep_largest_fg')
maybe_mkdir_p(output_here)
pp_fn = remove_all_but_largest_component_from_segmentation
kwargs = {
'labels_or_regions': label_manager.foreground_labels,
}
pool.starmap(
load_postprocess_save,
zip(
[join(folder_predictions, i) for i in predicted_files],
[join(output_here, i) for i in predicted_files],
[rw] * len(predicted_files),
[[pp_fn]] * len(predicted_files),
[[kwargs]] * len(predicted_files)
)
)
compute_metrics_on_folder(folder_ref,
output_here,
join(output_here, 'summary.json'),
rw,
dataset_json['file_ending'],
labels_or_regions,
label_manager.ignore_label,
num_processes)
# now we need to figure out if doing this improved the dice scores. We will implement that defensively in so far
# that if a single class got worse as a result we won't do this. We can change this in the future but right now I
# prefer to do it this way
baseline_results = load_summary_json(join(folder_predictions, 'summary.json'))
pp_results = load_summary_json(join(output_here, 'summary.json'))
do_this = pp_results['foreground_mean']['Dice'] > baseline_results['foreground_mean']['Dice']
if do_this:
for class_id in pp_results['mean'].keys():
if pp_results['mean'][class_id]['Dice'] < baseline_results['mean'][class_id]['Dice']:
do_this = False
break
if do_this:
print(f'Results were improved by removing all but the largest foreground region. '
f'Mean dice before: {round(baseline_results["foreground_mean"]["Dice"], 5)} '
f'after: {round(pp_results["foreground_mean"]["Dice"], 5)}')
source = output_here
pp_fns.append(pp_fn)
pp_fn_kwargs.append(kwargs)
else:
print(f'Removing all but the largest foreground region did not improve results!')
source = folder_predictions
# in the old nnU-Net we could just apply all-but-largest component removal to all classes at the same time and
# then evaluate for each class whether this improved results. This is no longer possible because we now support
# region-based predictions and regions can overlap, causing interactions
# in principle the order in which the postprocessing is applied to the regions matters as well and should be
# investigated, but due to some things that I am too lazy to explain right now it's going to be alright (I think)
# to stick to the order in which they are declared in dataset.json (if you want to think about it then think about
# region_class_order)
# 2023_02_06: I hate myself for the comment above. Thanks past me
if len(labels_or_regions) > 1:
for label_or_region in labels_or_regions:
pp_fn = remove_all_but_largest_component_from_segmentation
kwargs = {
'labels_or_regions': label_or_region,
}
output_here = join(output_folder, 'temp', 'keep_largest_perClassOrRegion')
maybe_mkdir_p(output_here)
pool.starmap(
load_postprocess_save,
zip(
[join(source, i) for i in predicted_files],
[join(output_here, i) for i in predicted_files],
[rw] * len(predicted_files),
[[pp_fn]] * len(predicted_files),
[[kwargs]] * len(predicted_files)
)
)
compute_metrics_on_folder(folder_ref,
output_here,
join(output_here, 'summary.json'),
rw,
dataset_json['file_ending'],
labels_or_regions,
label_manager.ignore_label,
num_processes)
baseline_results = load_summary_json(join(source, 'summary.json'))
pp_results = load_summary_json(join(output_here, 'summary.json'))
do_this = pp_results['mean'][label_or_region]['Dice'] > baseline_results['mean'][label_or_region]['Dice']
if do_this:
print(f'Results were improved by removing all but the largest component for {label_or_region}. '
f'Dice before: {round(baseline_results["mean"][label_or_region]["Dice"], 5)} '
f'after: {round(pp_results["mean"][label_or_region]["Dice"], 5)}')
if isdir(join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest')):
shutil.rmtree(join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest'))
shutil.move(output_here, join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest'), )
source = join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest')
pp_fns.append(pp_fn)
pp_fn_kwargs.append(kwargs)
else:
print(f'Removing all but the largest component for {label_or_region} did not improve results! '
f'Dice before: {round(baseline_results["mean"][label_or_region]["Dice"], 5)} '
f'after: {round(pp_results["mean"][label_or_region]["Dice"], 5)}')
for i in subfiles(source, join=False): shutil.copy(join(source, i), join(output_folder, i))
save_pickle((pp_fns, pp_fn_kwargs), join(folder_predictions, 'postprocessing.pkl'))
baseline_results = load_summary_json(join(folder_predictions, 'summary.json'))
final_results = load_summary_json(join(output_folder, 'summary.json'))
tmp = {
'input_folder': {i: baseline_results[i] for i in ['foreground_mean', 'mean']},
'postprocessed': {i: final_results[i] for i in ['foreground_mean', 'mean']},
'postprocessing_fns': [i.__name__ for i in pp_fns],
'postprocessing_kwargs': pp_fn_kwargs,
}
# json is very annoying. Can't handle tuples as dict keys.
tmp['input_folder']['mean'] = {label_or_region_to_key(k): tmp['input_folder']['mean'][k] for k in
tmp['input_folder']['mean'].keys()}
tmp['postprocessed']['mean'] = {label_or_region_to_key(k): tmp['postprocessed']['mean'][k] for k in
tmp['postprocessed']['mean'].keys()}
# did I already say that I hate json? "TypeError: Object of type int64 is not JSON serializable"
recursive_fix_for_json_export(tmp)
save_json(tmp, join(folder_predictions, 'postprocessing.json'))
shutil.rmtree(join(output_folder, 'temp'))
if not keep_postprocessed_files:
shutil.rmtree(output_folder)
return pp_fns, pp_fn_kwargs
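The first decision in determine_postprocessing is deliberately defensive. Removing all but the largest foreground region is accepted only if the foreground-mean Dice improves and no individual class gets worse. (The later per-class/region step compares only the Dice of the label or region being tested.) The hypothetical function below restates that first rule on summary-like dicts with the same `foreground_mean` / `mean` keys. The numbers are toy values for illustration, not results.

```python
from typing import Dict


def accept_postprocessing(baseline: Dict, postprocessed: Dict) -> bool:
    """Accept only if foreground-mean Dice improves and no single class gets worse."""
    if not postprocessed['foreground_mean']['Dice'] > baseline['foreground_mean']['Dice']:
        return False
    return all(postprocessed['mean'][k]['Dice'] >= baseline['mean'][k]['Dice']
               for k in postprocessed['mean'])


baseline = {'foreground_mean': {'Dice': 0.80}, 'mean': {1: {'Dice': 0.85}, 2: {'Dice': 0.75}}}
better = {'foreground_mean': {'Dice': 0.82}, 'mean': {1: {'Dice': 0.87}, 2: {'Dice': 0.77}}}
mixed = {'foreground_mean': {'Dice': 0.81}, 'mean': {1: {'Dice': 0.90}, 2: {'Dice': 0.72}}}
print(accept_postprocessing(baseline, better), accept_postprocessing(baseline, mixed))  # True False
```

The `mixed` case is rejected even though its mean improves, because class 2 drops.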
def apply_postprocessing_to_folder(input_folder: str,
output_folder: str,
pp_fns: List[Callable],
pp_fn_kwargs: List[dict],
plans_file_or_dict: Union[str, dict] = None,
dataset_json_file_or_dict: Union[str, dict] = None,
num_processes=8) -> None:
"""
If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for them in input_folder
"""
if plans_file_or_dict is None:
expected_plans_file = join(input_folder, 'plans.json')
if not isfile(expected_plans_file):
raise RuntimeError(f"Expected plans file missing: {expected_plans_file}. The plans file should have been "
f"created while running nnUNetv2_predict. Sadge. If the folder you want to apply "
f"postprocessing to was created from an ensemble then just specify one of the "
f"plans files of the ensemble members in plans_file_or_dict")
plans_file_or_dict = load_json(expected_plans_file)
plans_manager = PlansManager(plans_file_or_dict)
if dataset_json_file_or_dict is None:
expected_dataset_json_file = join(input_folder, 'dataset.json')
if not isfile(expected_dataset_json_file):
raise RuntimeError(
f"Expected dataset.json file missing: {expected_dataset_json_file}. The dataset.json should have been "
f"copied while running nnUNetv2_predict/nnUNetv2_ensemble. Sadge.")
dataset_json_file_or_dict = load_json(expected_dataset_json_file)
if not isinstance(dataset_json_file_or_dict, dict):
dataset_json = load_json(dataset_json_file_or_dict)
else:
dataset_json = dataset_json_file_or_dict
rw = plans_manager.image_reader_writer_class()
maybe_mkdir_p(output_folder)
with multiprocessing.get_context("spawn").Pool(num_processes) as p:
files = subfiles(input_folder, suffix=dataset_json['file_ending'], join=False)
_ = p.starmap(load_postprocess_save,
zip(
[join(input_folder, i) for i in files],
[join(output_folder, i) for i in files],
[rw] * len(files),
[pp_fns] * len(files),
[pp_fn_kwargs] * len(files)
)
)
def entry_point_determine_postprocessing_folder():
parser = argparse.ArgumentParser(description='Writes postprocessing.pkl and postprocessing.json in input_folder.')
parser.add_argument('-i', type=str, required=True, help='Input folder')
parser.add_argument('-ref', type=str, required=True, help='Folder with gt labels')
parser.add_argument('-plans_json', type=str, required=False, default=None,
help="plans file to use. If not specified we will look for the plans.json file in the "
"input folder (input_folder/plans.json)")
parser.add_argument('-dataset_json', type=str, required=False, default=None,
help="dataset.json file to use. If not specified we will look for the dataset.json file in the "
"input folder (input_folder/dataset.json)")
parser.add_argument('-np', type=int, required=False, default=default_num_processes,
help=f"number of processes to use. Default: {default_num_processes}")
parser.add_argument('--remove_postprocessed', action='store_true', required=False,
help='set this if you don\'t want to keep the postprocessed files')
args = parser.parse_args()
determine_postprocessing(args.i, args.ref, args.plans_json, args.dataset_json, args.np,
not args.remove_postprocessed)
def entry_point_apply_postprocessing():
parser = argparse.ArgumentParser(description='Applies postprocessing specified in pp_pkl_file to input folder.')
parser.add_argument('-i', type=str, required=True, help='Input folder')
parser.add_argument('-o', type=str, required=True, help='Output folder')
parser.add_argument('-pp_pkl_file', type=str, required=True, help='postprocessing.pkl file')
parser.add_argument('-np', type=int, required=False, default=default_num_processes,
help=f"number of processes to use. Default: {default_num_processes}")
parser.add_argument('-plans_json', type=str, required=False, default=None,
help="plans file to use. If not specified we will look for the plans.json file in the "
"input folder (input_folder/plans.json)")
parser.add_argument('-dataset_json', type=str, required=False, default=None,
help="dataset.json file to use. If not specified we will look for the dataset.json file in the "
"input folder (input_folder/dataset.json)")
args = parser.parse_args()
pp_fns, pp_fn_kwargs = load_pickle(args.pp_pkl_file)
apply_postprocessing_to_folder(args.i, args.o, pp_fns, pp_fn_kwargs, args.plans_json, args.dataset_json, args.np)
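The first positional parameter of `argparse.ArgumentParser` is `prog` (the program name shown in usage lines), not `description`. Passing a sentence positionally therefore renames the program instead of adding a description to the `--help` output. The quick check below makes the difference visible.

```python
import argparse

TEXT = 'Writes postprocessing.pkl and postprocessing.json in input_folder.'

positional = argparse.ArgumentParser(TEXT)          # TEXT becomes the program name
keyword = argparse.ArgumentParser(description=TEXT)  # TEXT becomes the --help description

print(positional.prog)         # the sentence, used as program name
print(positional.description)  # None
print(keyword.description)     # the sentence
```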
if __name__ == '__main__':
trained_model_folder = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetTrainer__nnUNetPlans__3d_fullres'
labelstr = join(nnUNet_raw, 'Dataset004_Hippocampus', 'labelsTr')
plans_manager = PlansManager(join(trained_model_folder, 'plans.json'))
dataset_json = load_json(join(trained_model_folder, 'dataset.json'))
folds = (0, 1, 2, 3, 4)
label_manager = plans_manager.get_label_manager(dataset_json)
merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(folds)}')
accumulate_cv_results(trained_model_folder, merged_output_folder, folds, 8, False)
fns, kwargs = determine_postprocessing(merged_output_folder, labelstr, plans_manager.plans,
dataset_json, 8, keep_postprocessed_files=True)
save_pickle((fns, kwargs), join(trained_model_folder, 'postprocessing.pkl'))
fns, kwargs = load_pickle(join(trained_model_folder, 'postprocessing.pkl'))
apply_postprocessing_to_folder(merged_output_folder, merged_output_folder + '_pp', fns, kwargs,
plans_manager.plans, dataset_json,
8)
compute_metrics_on_folder(labelstr,
merged_output_folder + '_pp',
join(merged_output_folder + '_pp', 'summary.json'),
plans_manager.image_reader_writer_class(),
dataset_json['file_ending'],
label_manager.foreground_regions if label_manager.has_regions else label_manager.foreground_labels,
label_manager.ignore_label,
8)
================================================
FILE: nnunetv2/preprocessing/__init__.py
================================================
================================================
FILE: nnunetv2/preprocessing/cropping/__init__.py
================================================
================================================
FILE: nnunetv2/preprocessing/cropping/cropping.py
================================================
import numpy as np
from scipy.ndimage import binary_fill_holes
from acvl_utils.cropping_and_padding.bounding_boxes import get_bbox_from_mask, bounding_box_to_slice
def create_nonzero_mask(data):
"""
:param data:
:return: the mask is True where the data is nonzero
"""
assert data.ndim in (3, 4), "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
nonzero_mask = data[0] != 0
for c in range(1, data.shape[0]):
nonzero_mask |= data[c] != 0
return binary_fill_holes(nonzero_mask)
def crop_to_nonzero(data, seg=None, nonzero_label=-1):
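"""
Crops data (and seg, if given) to the bounding box of the nonzero region of data.
:param data: image of shape (C, X, Y, Z) or (C, X, Y)
:param seg: optional segmentation with the same spatial shape as data
:param nonzero_label: value written into the segmentation map at background (0) voxels outside the nonzero mask
:return: cropped data, cropped (or newly created) seg and the bounding box
"""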
nonzero_mask = create_nonzero_mask(data)
bbox = get_bbox_from_mask(nonzero_mask)
slicer = bounding_box_to_slice(bbox)
nonzero_mask = nonzero_mask[slicer][None]
slicer = (slice(None), ) + slicer
data = data[slicer]
if seg is not None:
seg = seg[slicer]
seg[(seg == 0) & (~nonzero_mask)] = nonzero_label
else:
seg = np.where(nonzero_mask, np.int8(0), np.int8(nonzero_label))
return data, seg, bbox
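crop_to_nonzero computes a mask of voxels that are nonzero in any channel, crops every channel to that mask's bounding box, and marks the outside of the mask in the segmentation with `nonzero_label` (-1 by default). The plain-Python 2D sketch below follows the same steps with a half-open `(start, stop)` box per axis. It has two simplifications: it omits the `binary_fill_holes` step, so interior zeros are also marked -1 here, and it assumes at least one nonzero voxel. The helper names are hypothetical.

```python
from typing import List, Tuple

Image = List[List[float]]  # one 2D channel


def nonzero_bbox(channels: List[Image]) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Half-open (start, stop) per axis of voxels that are nonzero in any channel."""
    h, w = len(channels[0]), len(channels[0][0])
    rows = [y for y in range(h) if any(c[y][x] != 0 for c in channels for x in range(w))]
    cols = [x for x in range(w) if any(c[y][x] != 0 for c in channels for y in range(h))]
    return (rows[0], rows[-1] + 1), (cols[0], cols[-1] + 1)


def crop_with_outside_label(channels: List[Image], outside_label: int = -1):
    """Crop all channels to the bbox; seg is outside_label where every channel is 0."""
    (y0, y1), (x0, x1) = nonzero_bbox(channels)
    cropped = [[row[x0:x1] for row in c[y0:y1]] for c in channels]
    seg = [[0 if any(c[y][x] != 0 for c in cropped) else outside_label
            for x in range(x1 - x0)] for y in range(y1 - y0)]
    return cropped, seg, ((y0, y1), (x0, x1))


image = [[[0, 0, 0, 0],
          [0, 5, 0, 0],
          [0, 0, 0, 7],
          [0, 0, 0, 0]]]
data, seg, bbox = crop_with_outside_label(image)
print(bbox)  # ((1, 3), (1, 4))
print(seg)   # [[0, -1, -1], [-1, -1, 0]]
```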
================================================
FILE: nnunetv2/preprocessing/normalization/__init__.py
================================================
================================================
FILE: nnunetv2/preprocessing/normalization/default_normalization_schemes.py
================================================
from abc import ABC, abstractmethod
from typing import Type
import numpy as np
from numpy import number
class ImageNormalization(ABC):
leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = None
def __init__(self, use_mask_for_norm: bool = None, intensityproperties: dict = None,
target_dtype: Type[number] = np.float32):
assert use_mask_for_norm is None or isinstance(use_mask_for_norm, bool)
self.use_mask_for_norm = use_mask_for_norm
assert isinstance(intensityproperties, dict)
self.intensityproperties = intensityproperties
self.target_dtype = target_dtype
@abstractmethod
def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
"""
Image and seg must have the same shape. Seg is not always used
"""
pass
class ZScoreNormalization(ImageNormalization):
leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = True
def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
"""
seg is only used if use_mask_for_norm is True. It encodes the region outside the nonzero mask with negative values
(-1 by default, see crop_to_nonzero); mean and std are then computed on seg >= 0 only.
"""
eps = 1e-8 if not self.target_dtype == np.float16 else 1e-4
image = image.astype(self.target_dtype, copy=False)
if self.use_mask_for_norm:
assert seg is not None, ("use_mask_for_norm is set, please provide a mask for the nonzero areas of the "
"image via seg. The mask will be computed as `mask = seg >= 0`. You can use "
"create_nonzero_mask from nnunetv2/preprocessing/cropping")
if seg is not None and self.use_mask_for_norm:
# negative values in the segmentation encode the 'outside' region (think zero values around the brain as
# in BraTS). We want to run the normalization only in the brain region, so we need to mask the image.
# The default nnU-net sets use_mask_for_norm to True if cropping to the nonzero region substantially
# reduced the image size.
mask = seg >= 0
mean = image[mask].mean()
std = image[mask].std()
image[mask] = (image[mask] - mean) / (max(std, eps))
else:
mean = image.mean()
std = image.std()
image -= mean
image /= (max(std, eps))
return image
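# Illustrative sketch, not used by nnU-Net: the masked branch of ZScoreNormalization in isolation.
# `_sketch_masked_zscore` is a hypothetical name. As in the preprocessor, seg is assumed to mark the outside
# region with negative values; statistics come from `seg >= 0` only and only those voxels are rewritten,
# so voxels outside the mask keep their original value (zero after cropping).
import numpy as np


def _sketch_masked_zscore(image: np.ndarray, seg: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    image = image.astype(np.float32, copy=True)
    mask = seg >= 0
    mean = image[mask].mean()
    std = image[mask].std()
    image[mask] = (image[mask] - mean) / max(std, eps)  # outside voxels are left untouched
    return image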
class CTNormalization(ImageNormalization):
leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False
def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
assert self.intensityproperties is not None, "CTNormalization requires intensity properties"
eps = 1e-8 if not self.target_dtype == np.float16 else 1e-4
mean_intensity = self.intensityproperties['mean']
std_intensity = self.intensityproperties['std']
lower_bound = self.intensityproperties['percentile_00_5']
upper_bound = self.intensityproperties['percentile_99_5']
image = image.astype(self.target_dtype, copy=False)
np.clip(image, lower_bound, upper_bound, out=image)
image -= mean_intensity
image /= max(std_intensity, eps)
return image
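# Illustrative sketch, not used by nnU-Net: CTNormalization relies on dataset-level statistics rather than
# per-image ones. `_sketch_ct_intensity_properties` is hypothetical and assumes the properties are derived from
# pooled foreground voxels (mean, std, 0.5th and 99.5th percentile), as the keys read by CTNormalization
# suggest. `_sketch_ct_normalize` mirrors CTNormalization.run: clip to the percentiles, subtract the mean,
# divide by the std.
import numpy as np


def _sketch_ct_intensity_properties(foreground_voxels: np.ndarray) -> dict:
    v = np.asarray(foreground_voxels, dtype=np.float64)
    return {'mean': float(v.mean()), 'std': float(v.std()),
            'percentile_00_5': float(np.percentile(v, 0.5)),
            'percentile_99_5': float(np.percentile(v, 99.5))}


def _sketch_ct_normalize(image: np.ndarray, props: dict, eps: float = 1e-8) -> np.ndarray:
    image = np.clip(image.astype(np.float32), props['percentile_00_5'], props['percentile_99_5'])
    return (image - props['mean']) / max(props['std'], eps)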
class NoNormalization(ImageNormalization):
leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False
def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
return image.astype(self.target_dtype, copy=False)
class RescaleTo01Normalization(ImageNormalization):
leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False
def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
eps = 1e-8 if not self.target_dtype == np.float16 else 1e-4
image = image.astype(self.target_dtype, copy=False)
image -= image.min()
image /= np.clip(image.max(), a_min=eps, a_max=None)
return image
class RGBTo01Normalization(ImageNormalization):
leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False
def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
assert image.min() >= 0, "RGB images are expected to be uint8, but pixel values smaller than 0 were found. " \
"Your images do not seem to be RGB images"
assert image.max() <= 255, "RGB images are expected to be uint8, but pixel values greater than 255 were found. " \
"Your images do not seem to be RGB images"
image = image.astype(self.target_dtype, copy=False)
image /= 255.
return image
================================================
FILE: nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py
================================================
from typing import Type
from nnunetv2.preprocessing.normalization.default_normalization_schemes import CTNormalization, NoNormalization, \
ZScoreNormalization, RescaleTo01Normalization, RGBTo01Normalization, ImageNormalization
channel_name_to_normalization_mapping = {
'ct': CTNormalization,
'nonorm': NoNormalization,
'zscore': ZScoreNormalization,
'rescale_to_0_1': RescaleTo01Normalization,
'rgb_to_0_1': RGBTo01Normalization
}
def get_normalization_scheme(channel_name: str) -> Type[ImageNormalization]:
"""
If we find the channel_name in channel_name_to_normalization_mapping return the corresponding normalization. If it is
not found, use the default (ZScoreNormalization)
"""
norm_scheme = channel_name_to_normalization_mapping.get(channel_name.casefold())
if norm_scheme is None:
norm_scheme = ZScoreNormalization
# print('Using %s for image normalization' % norm_scheme.__name__)
return norm_scheme
================================================
FILE: nnunetv2/preprocessing/normalization/readme.md
================================================
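The channel_names entry in dataset.json is only used to determine the normalization scheme. So if you want to use
something different, you can just
- create a new subclass of ImageNormalization in nnunetv2/preprocessing/normalization (the preprocessor looks up the normalization class by name in that folder)
- map your custom channel identifier to that subclass in channel_name_to_normalization_mapping and use that identifier in the channel_names of your dataset.json
- run plan and preprocess again with your custom normalization scheme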
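
A minimal sketch of the first step, assuming the ImageNormalization constructor and `run(image, seg)` interface from default_normalization_schemes.py. `ClipTo99ZScoreNormalization` and the channel identifier `clip99_zscore` are hypothetical names chosen for illustration. The fallback base class only exists so the snippet runs without nnU-Net installed; inside nnU-Net you would import ImageNormalization directly.

```python
import numpy as np

try:
    from nnunetv2.preprocessing.normalization.default_normalization_schemes import ImageNormalization
except ImportError:  # stand-in with the same constructor/run() interface, only for running this sketch
    from abc import ABC, abstractmethod

    class ImageNormalization(ABC):
        leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = None

        def __init__(self, use_mask_for_norm=None, intensityproperties=None, target_dtype=np.float32):
            self.use_mask_for_norm = use_mask_for_norm
            self.intensityproperties = intensityproperties
            self.target_dtype = target_dtype

        @abstractmethod
        def run(self, image, seg=None):
            pass


class ClipTo99ZScoreNormalization(ImageNormalization):
    """Hypothetical scheme: clip each image to its own 1st/99th percentile, then z-score it."""
    leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False

    def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
        image = image.astype(self.target_dtype, copy=False)
        lo, hi = np.percentile(image, [1, 99])
        np.clip(image, lo, hi, out=image)
        image -= image.mean()
        image /= max(image.std(), 1e-8)
        return image

# Step 2 (in map_channel_name_to_normalization.py), shown as a comment:
# channel_name_to_normalization_mapping['clip99_zscore'] = ClipTo99ZScoreNormalization
```

Keep the mapping key lowercase: get_normalization_scheme looks up `channel_name.casefold()`, so a key containing uppercase letters would never match.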
================================================
FILE: nnunetv2/preprocessing/preprocessors/__init__.py
================================================
================================================
FILE: nnunetv2/preprocessing/preprocessors/default_preprocessor.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import shutil
from time import sleep
from typing import List, Tuple, Union
import SimpleITK
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from tqdm import tqdm
import nnunetv2
from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw
from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero
from nnunetv2.preprocessing.resampling.default_resampling import compute_new_shape
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDatasetBlosc2
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
class DefaultPreprocessor(object):
def __init__(self, verbose: bool = True):
self.verbose = verbose
self.show_progress_bar = True
"""
Everything we need is in the plans. Those are given when run() is called
"""
def run_case_npy(self, data: np.ndarray, seg: Union[np.ndarray, None], properties: dict,
plans_manager: PlansManager, configuration_manager: ConfigurationManager,
dataset_json: Union[dict, str]):
# let's not mess up the inputs!
data = data.astype(np.float32) # this creates a copy
if seg is not None:
assert data.shape[1:] == seg.shape[1:], "Shape mismatch between image and segmentation. Please fix your dataset and make use of the --verify_dataset_integrity flag to ensure everything is correct"
seg = np.copy(seg)
has_seg = seg is not None
# apply transpose_forward, this also needs to be applied to the spacing!
data = data.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]])
if seg is not None:
seg = seg.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]])
original_spacing = [properties['spacing'][i] for i in plans_manager.transpose_forward]
# crop, remember to store size before cropping!
shape_before_cropping = data.shape[1:]
properties['shape_before_cropping'] = shape_before_cropping
# this command will generate a segmentation. This is important because of the nonzero mask which we may need
data, seg, bbox = crop_to_nonzero(data, seg)
properties['bbox_used_for_cropping'] = bbox
# print(data.shape, seg.shape)
properties['shape_after_cropping_and_before_resampling'] = data.shape[1:]
# resample
target_spacing = configuration_manager.spacing # this should already be transposed
if len(target_spacing) < len(data.shape[1:]):
# target spacing for 2d has 2 entries but the data and original_spacing have three because everything is 3d
# in 2d configuration we do not change the spacing between slices
target_spacing = [original_spacing[0]] + target_spacing
new_shape = compute_new_shape(data.shape[1:], original_spacing, target_spacing)
# normalize
# normalization MUST happen before resampling or we get huge problems with resampled nonzero masks no
# longer fitting the images perfectly!
data = self._normalize(data, seg, configuration_manager,
plans_manager.foreground_intensity_properties_per_channel)
# print('current shape', data.shape[1:], 'current_spacing', original_spacing,
# '\ntarget shape', new_shape, 'target_spacing', target_spacing)
old_shape = data.shape[1:]
data = configuration_manager.resampling_fn_data(data, new_shape, original_spacing, target_spacing)
seg = configuration_manager.resampling_fn_seg(seg, new_shape, original_spacing, target_spacing)
if self.verbose:
print(f'old shape: {old_shape}, new_shape: {new_shape}, old_spacing: {original_spacing}, '
f'new_spacing: {target_spacing}, fn_data: {configuration_manager.resampling_fn_data}')
# if we have a segmentation, sample foreground locations for oversampling and add those to properties
if has_seg:
# reinstantiating LabelManager for each case is not ideal. We could replace the dataset_json argument
# with a LabelManager instance in this function because that's all it's used for. Dunno what's better.
# LabelManager is pretty light computation-wise.
label_manager = plans_manager.get_label_manager(dataset_json)
collect_for_this = label_manager.foreground_regions if label_manager.has_regions \
else label_manager.foreground_labels
# when using the ignore label we want to sample only from annotated regions. Therefore we also need to
# collect samples uniformly from all classes (incl background)
if label_manager.has_ignore_label:
collect_for_this.append([-1] + label_manager.all_labels)
# no need to filter background in regions because it is already filtered in handle_labels
# print(all_labels, regions)
properties['class_locations'] = self._sample_foreground_locations(seg, collect_for_this,
verbose=self.verbose)
seg = self.modify_seg_fn(seg, plans_manager, dataset_json, configuration_manager)
if np.max(seg) > 127:
seg = seg.astype(np.int16)
else:
seg = seg.astype(np.int8)
return data, seg, properties
def run_case(self, image_files: List[str], seg_file: Union[str, None], plans_manager: PlansManager,
configuration_manager: ConfigurationManager,
dataset_json: Union[dict, str]):
"""
seg file can be none (test cases)
order of operations is: transpose -> crop -> resample
so when we export we need to run the following order: resample -> crop -> transpose (we could also run
transpose at a different place, but reverting the order of operations done during preprocessing seems cleaner)
"""
if isinstance(dataset_json, str):
dataset_json = load_json(dataset_json)
rw = plans_manager.image_reader_writer_class()
# load image(s)
data, data_properties = rw.read_images(image_files)
# if possible, load seg
if seg_file is not None:
seg, _ = rw.read_seg(seg_file)
else:
seg = None
if self.verbose:
print(seg_file)
data, seg, data_properties = self.run_case_npy(data, seg, data_properties, plans_manager, configuration_manager,
dataset_json)
return data, seg, data_properties
def run_case_save(self, output_filename_truncated: str, image_files: List[str], seg_file: str,
plans_manager: PlansManager, configuration_manager: ConfigurationManager,
dataset_json: Union[dict, str]):
data, seg, properties = self.run_case(image_files, seg_file, plans_manager, configuration_manager, dataset_json)
data = data.astype(np.float32, copy=False)
seg = seg.astype(np.int16, copy=False)
# print('dtypes', data.dtype, seg.dtype)
block_size_data, chunk_size_data = nnUNetDatasetBlosc2.comp_blosc2_params(
data.shape,
tuple(configuration_manager.patch_size),
data.itemsize)
block_size_seg, chunk_size_seg = nnUNetDatasetBlosc2.comp_blosc2_params(
seg.shape,
tuple(configuration_manager.patch_size),
seg.itemsize)
nnUNetDatasetBlosc2.save_case(data, seg, properties, output_filename_truncated,
chunks=chunk_size_data, blocks=block_size_data,
chunks_seg=chunk_size_seg, blocks_seg=block_size_seg)
@staticmethod
def _sample_foreground_locations(
seg: np.ndarray,
classes_or_regions: Union[List[int], List[Tuple[int, ...]]],
seed: int = 1234,
verbose: bool = False,
min_num_samples=10000,
min_percent_coverage=0.01
):
rndst = np.random.RandomState(seed)
class_locs = {}
# Normalize requested labels and compute the set of all labels we might need
normalized = []
requested_labels = set()
for c in classes_or_regions:
if isinstance(c, (tuple, list)):
labs = tuple(int(x) for x in c)
normalized.append(labs)
requested_labels.update(labs)
else:
lab = int(c)
normalized.append(lab)
requested_labels.add(lab)
# Create mask for all requested labels (this includes 0 if requested)
requested_labels_arr = np.fromiter(requested_labels, dtype=np.int32)
valid_mask = np.isin(seg, requested_labels_arr)
coords = np.argwhere(valid_mask)
seg_sel = seg[valid_mask]
del valid_mask
n = seg_sel.size
if n == 0:
for c in classes_or_regions:
k = tuple(c) if isinstance(c, (tuple, list)) else int(c)
class_locs[k] = []
return class_locs
# sort once, then compute label blocks
order = np.argsort(seg_sel, kind="stable")
lab_sorted = seg_sel[order]
coords_sorted = coords[order]
change = np.flatnonzero(lab_sorted[1:] != lab_sorted[:-1]) + 1
starts = np.r_[0, change]
ends = np.r_[change, n]
labels_present = lab_sorted[starts]
label_to_range = {int(l): (int(s), int(e)) for l, s, e in zip(labels_present, starts, ends)}
present_labels = set(label_to_range.keys())
for c in classes_or_regions:
is_region = isinstance(c, (tuple, list))
labs = tuple(int(x) for x in c) if is_region else (int(c),)
k = labs if is_region else labs[0]
# Skip if none of the labels are present
if not any(lab in present_labels for lab in labs):
class_locs[k] = []
continue
# Collect ranges for present labels in this class/region
ranges = []
counts = []
for lab in labs:
r = label_to_range.get(lab)
if r is None:
continue
s, e = r
cnt = e - s
if cnt > 0:
ranges.append((s, e))
counts.append(cnt)
if len(counts) == 0:
class_locs[k] = []
continue
total = int(np.sum(counts))
target_num_samples = min(min_num_samples, total)
target_num_samples = max(target_num_samples, int(np.ceil(total * min_percent_coverage)))
# Sample uniformly without replacement from the union of ranges, without building an n-sized mask
# Draw target_num_samples unique offsets in [0, total)
offsets = rndst.choice(total, target_num_samples, replace=False)
# Map offsets -> (range index, in-range offset) using cumulative counts
cum = np.cumsum(counts)
which = np.searchsorted(cum, offsets, side="right")
prev = np.concatenate(([0], cum[:-1]))
in_range = offsets - prev[which]
# Convert to indices in coords_sorted
starts_for_pick = np.fromiter((ranges[i][0] for i in which), dtype=np.int64, count=which.size)
picked_idx = starts_for_pick + in_range.astype(np.int64)
selected = coords_sorted[picked_idx]
class_locs[k] = selected
if verbose:
print(c, target_num_samples)
return class_locs
# @staticmethod
# def _sample_foreground_locations(seg: np.ndarray, classes_or_regions: Union[List[int], List[Tuple[int, ...]]],
# seed: int = 1234, verbose: bool = False):
# num_samples = 10000
# min_percent_coverage = 0.01 # at least 1% of the class voxels need to be selected, otherwise it may be too
# # sparse
# rndst = np.random.RandomState(seed)
# class_locs = {}
# foreground_mask = seg != 0
# foreground_coords = np.argwhere(foreground_mask)
# seg = seg[foreground_mask]
# del foreground_mask
# unique_labels = pd.unique(seg.ravel())
#
# # We don't need more than 1e7 foreground samples. That's insanity. Cap here
# if len(foreground_coords) > 1e7:
# take_every = math.floor(len(foreground_coords) / 1e7)
# # keep computation time reasonable
# if verbose:
# print(f'Subsampling foreground pixels 1:{take_every} for computational reasons')
# foreground_coords = foreground_coords[::take_every]
# seg = seg[::take_every]
#
# for c in classes_or_regions:
# k = c if not isinstance(c, list) else tuple(c)
#
# # check if any of the labels are in seg, if not skip c
# if isinstance(c, (tuple, list)):
# if not any([ci in unique_labels for ci in c]):
# class_locs[k] = []
# continue
# else:
# if c not in unique_labels:
# class_locs[k] = []
# continue
#
# if isinstance(c, (tuple, list)):
# mask = seg == c[0]
# for cc in c[1:]:
# mask = mask | (seg == cc)
# all_locs = foreground_coords[mask]
# else:
# mask = seg == c
# all_locs = foreground_coords[mask]
# if len(all_locs) == 0:
# class_locs[k] = []
# continue
# target_num_samples = min(num_samples, len(all_locs))
# target_num_samples = max(target_num_samples, int(np.ceil(len(all_locs) * min_percent_coverage)))
#
# selected = all_locs[rndst.choice(len(all_locs), target_num_samples, replace=False)]
# class_locs[k] = selected
# if verbose:
# print(c, target_num_samples)
# seg = seg[~mask]
# foreground_coords = foreground_coords[~mask]
# return class_locs
def _normalize(self, data: np.ndarray, seg: np.ndarray, configuration_manager: ConfigurationManager,
foreground_intensity_properties_per_channel: dict) -> np.ndarray:
for c in range(data.shape[0]):
scheme = configuration_manager.normalization_schemes[c]
normalizer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing", "normalization"),
scheme,
'nnunetv2.preprocessing.normalization')
if normalizer_class is None:
raise RuntimeError(f'Unable to locate class \'{scheme}\' for normalization')
normalizer = normalizer_class(use_mask_for_norm=configuration_manager.use_mask_for_norm[c],
intensityproperties=foreground_intensity_properties_per_channel[str(c)])
data[c] = normalizer.run(data[c], seg[0])
return data
def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plans_identifier: str,
num_processes: int):
"""
data identifier = configuration name in plans. EZ.
"""
dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
assert isdir(join(nnUNet_raw, dataset_name)), "The requested dataset could not be found in nnUNet_raw"
plans_file = join(nnUNet_preprocessed, dataset_name, plans_identifier + '.json')
assert isfile(plans_file), "Expected plans file (%s) not found. Run corresponding nnUNetv2_plan_experiment " \
"first." % plans_file
plans = load_json(plans_file)
plans_manager = PlansManager(plans)
configuration_manager = plans_manager.get_configuration(configuration_name)
dataset_json_file = join(nnUNet_preprocessed, dataset_name, 'dataset.json')
dataset_json = load_json(dataset_json_file)
output_directory = join(nnUNet_preprocessed, dataset_name, configuration_manager.data_identifier)
if isdir(output_directory):
shutil.rmtree(output_directory)
maybe_mkdir_p(output_directory)
dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, dataset_name), dataset_json)
# identifiers = [os.path.basename(i[:-len(dataset_json['file_ending'])]) for i in seg_fnames]
# output_filenames_truncated = [join(output_directory, i) for i in identifiers]
# multiprocessing magic.
r = []
with multiprocessing.get_context("spawn").Pool(num_processes) as p:
remaining = list(range(len(dataset)))
# p is pretty nifti. If we kill workers they just respawn but don't do any work.
# So we need to store the original pool of workers.
workers = [j for j in p._pool]
for k in dataset.keys():
r.append(p.starmap_async(self.run_case_save,
((join(output_directory, k), dataset[k]['images'], dataset[k]['label'],
plans_manager, configuration_manager,
dataset_json),)))
with tqdm(desc="Preprocessing cases", total=len(dataset),
disable=not getattr(self, 'show_progress_bar', True)) as pbar:
while len(remaining) > 0:
all_alive = all([j.is_alive() for j in workers])
if not all_alive:
raise RuntimeError('Some background worker is 6 feet under. Yuck. \n'
'OK jokes aside.\n'
'One of your background processes is missing. This could be because of '
'an error (look for an error message) or because it was killed '
'by your OS due to running out of RAM. If you don\'t see '
'an error message, out of RAM is likely the problem. In that case '
'reducing the number of workers might help')
done = [i for i in remaining if r[i].ready()]
for i in done:
r[i].get() # trigger any errors from worker (single call, no duplicate)
pbar.update()
remaining = [i for i in remaining if i not in done]
sleep(0.1)
def modify_seg_fn(self, seg: np.ndarray, plans_manager: PlansManager, dataset_json: dict,
configuration_manager: ConfigurationManager) -> np.ndarray:
# this function will be called at the end of self.run_case_npy. Can be used to change the segmentation
# after resampling. Useful for experimenting with sparse annotations: I can introduce sparsity after resampling
# and don't have to create a new dataset each time I modify my experiments
return seg
def example_test_case_preprocessing():
# (paths to files may need adaptations)
plans_file = '/home/isensee/drives/gpu_data/nnUNet_preprocessed/Dataset219_AMOS2022_postChallenge_task2/nnUNetPlans.json'
dataset_json_file = '/home/isensee/drives/gpu_data/nnUNet_preprocessed/Dataset219_AMOS2022_postChallenge_task2/dataset.json'
input_images = ['/home/isensee/drives/e132-rohdaten/nnUNetv2/Dataset219_AMOS2022_postChallenge_task2/imagesTr/amos_0600_0000.nii.gz', ] # if you only have one channel, you still need a list: ['case000_0000.nii.gz']
configuration = '3d_fullres'
pp = DefaultPreprocessor()
# _ because this position would be the segmentation if seg_file was not None (training case)
# even if you have the segmentation, don't put the file there! You should always evaluate in the original
# resolution. What comes out of the preprocessor might have been resampled to some other image resolution (as
# specified by plans)
plans_manager = PlansManager(plans_file)
data, _, properties = pp.run_case(input_images, seg_file=None, plans_manager=plans_manager,
configuration_manager=plans_manager.get_configuration(configuration),
dataset_json=dataset_json_file)
# voila. Now plug data into your prediction function of choice. We of course recommend nnU-Net's default (TODO)
return data
def _verify_class_locations(shape, outfile, class_locs):
import numpy as np
import SimpleITK as sitk
out = np.zeros(shape, dtype=np.uint16) # allow many labels safely
for i, k in enumerate(class_locs.keys()):
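class_coords = class_locs[k]
if class_coords is None:
continue
class_coords = np.asarray(class_coords)
if class_coords.size == 0:
continue
# drop the leading channel coordinate (locations were computed on seg with shape (1, z, y, x))
class_coords = class_coords[:, 1:]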
# Expect coords in (N, 3) as (z, y, x)
if class_coords.ndim != 2 or class_coords.shape[1] != 3:
raise ValueError(f"class_locs[{k}] must have shape (N, 3), got {class_coords.shape}")
z = class_coords[:, 0].astype(np.int64)
y = class_coords[:, 1].astype(np.int64)
x = class_coords[:, 2].astype(np.int64)
# Optional bounds check (cheap and prevents hard-to-debug indexing errors)
if (z.min() < 0 or y.min() < 0 or x.min() < 0 or
z.max() >= shape[0] or y.max() >= shape[1] or x.max() >= shape[2]):
raise ValueError(f"Coordinates for {k} are out of bounds for shape={shape}")
out[z, y, x] = i + 1 # label 1..K
img = sitk.GetImageFromArray(out) # SimpleITK assumes array is z,y,x
sitk.WriteImage(img, outfile)
if __name__ == '__main__':
# example_test_case_preprocessing()
# pp = DefaultPreprocessor()
# pp.run(2, '2d', 'nnUNetPlans', 8)
###########################################################################################################
# how to process a test case? This is an example:
# example_test_case_preprocessing()
seg = SimpleITK.GetArrayFromImage(SimpleITK.ReadImage('/home/isensee/temp/H-mito-val-v2.nii.gz'))[None]
a = DefaultPreprocessor._sample_foreground_locations(seg, np.arange(1, np.max(seg) + 1), min_percent_coverage=0.50)
_verify_class_locations(seg.shape[1:], '/home/isensee/temp/deleteme.nii.gz', a)
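# Illustrative sketch, not used by nnU-Net: the core trick of _sample_foreground_locations above.
# After sorting voxels by label, each label occupies one contiguous block [start, end) of the sorted
# coordinate array. To sample uniformly without replacement from a region made of several labels, draw
# offsets in [0, total) and map each offset to its block with a cumulative-count search, so no per-region
# mask over all voxels is needed. `_sketch_sample_from_ranges` is a hypothetical helper name.
import numpy as np


def _sketch_sample_from_ranges(ranges, num_samples, seed=1234):
    counts = np.array([e - s for s, e in ranges], dtype=np.int64)
    total = int(counts.sum())
    rng = np.random.RandomState(seed)
    offsets = rng.choice(total, min(num_samples, total), replace=False)
    cum = np.cumsum(counts)
    which = np.searchsorted(cum, offsets, side="right")  # block that each offset falls into
    prev = np.concatenate(([0], cum[:-1]))  # number of elements in all earlier blocks
    starts = np.array([s for s, _ in ranges], dtype=np.int64)
    return starts[which] + (offsets - prev[which])  # indices into the sorted coordinate array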
================================================
FILE: nnunetv2/preprocessing/resampling/__init__.py
================================================
================================================
FILE: nnunetv2/preprocessing/resampling/default_resampling.py
================================================
from collections import OrderedDict
from copy import deepcopy
from typing import Union, Tuple, List
import numpy as np
import pandas as pd
import torch
from batchgenerators.augmentations.utils import resize_segmentation
from scipy.ndimage import map_coordinates
from skimage.transform import resize
from nnunetv2.configuration import ANISO_THRESHOLD
def get_do_separate_z(spacing: Union[Tuple[float, ...], List[float], np.ndarray], anisotropy_threshold=ANISO_THRESHOLD):
do_separate_z = (np.max(spacing) / np.min(spacing)) > anisotropy_threshold
return do_separate_z
def get_lowres_axis(new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]):
axis = np.where(max(new_spacing) / np.array(new_spacing) == 1)[0] # find which axis is anisotropic
return axis
def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray],
old_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> np.ndarray:
assert len(old_spacing) == len(old_shape)
assert len(old_shape) == len(new_spacing)
new_shape = np.array([int(round(i / j * k)) for i, j, k in zip(old_spacing, new_spacing, old_shape)])
return new_shape
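# Illustrative worked example of the compute_new_shape arithmetic: every axis keeps its physical extent
# (voxels * spacing), so new_size = round(old_size * old_spacing / new_spacing). `_sketch_new_shape` is a
# hypothetical helper; like compute_new_shape it uses Python's round(), which rounds exact .5 ties to even.
def _sketch_new_shape(old_shape, old_spacing, new_spacing):
    return [int(round(s * o / n)) for s, o, n in zip(old_shape, old_spacing, new_spacing)]

# e.g. 231 voxels at 0.8 mm cover 184.8 mm, i.e. 185 voxels at 1 mm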
def determine_do_sep_z_and_axis(
force_separate_z: Union[bool, None],
current_spacing,
new_spacing,
separate_z_anisotropy_threshold: float = ANISO_THRESHOLD) -> Tuple[bool, Union[int, None]]:
if force_separate_z is not None:
do_separate_z = force_separate_z
if force_separate_z:
axis = get_lowres_axis(current_spacing)
else:
axis = None
else:
if get_do_separate_z(current_spacing, separate_z_anisotropy_threshold):
do_separate_z = True
axis = get_lowres_axis(current_spacing)
elif get_do_separate_z(new_spacing, separate_z_anisotropy_threshold):
do_separate_z = True
axis = get_lowres_axis(new_spacing)
else:
do_separate_z = False
axis = None
if axis is not None:
if len(axis) == 3:
do_separate_z = False
axis = None
elif len(axis) == 2:
# this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample
# separately in the out of plane axis
do_separate_z = False
axis = None
else:
axis = axis[0]
return do_separate_z, axis
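# Illustrative sketch, not used by nnU-Net, of the rule determine_do_sep_z_and_axis applies when
# force_separate_z is None: a spacing counts as anisotropic if max/min exceeds the threshold (the current
# spacing is checked first, then the target spacing), and separate-z resampling is only used if exactly one
# axis has the largest spacing. The default threshold of 3 is an assumption in this sketch; nnU-Net reads
# the value from ANISO_THRESHOLD. `_sketch_sep_z_axis` is a hypothetical name.
import numpy as np


def _sketch_sep_z_axis(current_spacing, new_spacing, threshold=3.0):
    for spacing in (current_spacing, new_spacing):
        spacing = np.asarray(spacing, dtype=float)
        if spacing.max() / spacing.min() > threshold:
            lowres_axes = np.where(spacing == spacing.max())[0]
            if len(lowres_axes) == 1:
                return True, int(lowres_axes[0])
            return False, None  # e.g. (0.24, 1.25, 1.25): two low-res axes, no separate z
    return False, None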
def resample_data_or_seg_to_spacing(data: np.ndarray,
current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
is_seg: bool = False,
order: int = 3, order_z: int = 0,
force_separate_z: Union[bool, None] = False,
separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
separate_z_anisotropy_threshold)
if data is not None:
assert data.ndim == 4, "data must be c x y z"
shape = np.array(data.shape)
new_shape = compute_new_shape(shape[1:], current_spacing, new_spacing)
data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z)
return data_reshaped
def resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray],
new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
is_seg: bool = False,
order: int = 3, order_z: int = 0,
force_separate_z: Union[bool, None] = False,
separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
"""
needed for segmentation export. Stupid, I know
"""
if isinstance(data, torch.Tensor):
data = data.numpy()
do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
separate_z_anisotropy_threshold)
if data is not None:
assert data.ndim == 4, "data must be c x y z"
data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z)
return data_reshaped
def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray],
is_seg: bool = False, axis: Union[None, int] = None, order: int = 3,
do_separate_z: bool = False, order_z: int = 0, dtype_out = None):
"""
separate_z=True will resample with order 0 along z
:param data:
:param new_shape:
:param is_seg:
:param axis:
:param order:
:param do_separate_z:
:param order_z: only applies if do_separate_z is True
:return:
"""
assert data.ndim == 4, "data must be (c, x, y, z)"
assert len(new_shape) == data.ndim - 1
if is_seg:
resize_fn = resize_segmentation
kwargs = OrderedDict()
else:
resize_fn = resize
kwargs = {'mode': 'edge', 'anti_aliasing': False}
shape = np.array(data[0].shape)
new_shape = np.array(new_shape)
if dtype_out is None:
dtype_out = data.dtype
reshaped_final = np.zeros((data.shape[0], *new_shape), dtype=dtype_out)
if np.any(shape != new_shape):
data = data.astype(float, copy=False)
if do_separate_z:
assert axis is not None, 'If do_separate_z, we need to know what axis is anisotropic'
if axis == 0:
new_shape_2d = new_shape[1:]
elif axis == 1:
new_shape_2d = new_shape[[0, 2]]
else:
new_shape_2d = new_shape[:-1]
for c in range(data.shape[0]):
tmp = deepcopy(new_shape)
tmp[axis] = shape[axis]
reshaped_here = np.zeros(tmp)
for slice_id in range(shape[axis]):
if axis == 0:
reshaped_here[slice_id] = resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs)
elif axis == 1:
reshaped_here[:, slice_id] = resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs)
else:
reshaped_here[:, :, slice_id] = resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs)
if shape[axis] != new_shape[axis]:
# The following few lines are blatantly copied and modified from skimage's resize()
rows, cols, dim = new_shape[0], new_shape[1], new_shape[2]
orig_rows, orig_cols, orig_dim = reshaped_here.shape
# align_corners=False
row_scale = float(orig_rows) / rows
col_scale = float(orig_cols) / cols
dim_scale = float(orig_dim) / dim
map_rows, map_cols, map_dims = np.mgrid[:rows, :cols, :dim]
map_rows = row_scale * (map_rows + 0.5) - 0.5
map_cols = col_scale * (map_cols + 0.5) - 0.5
map_dims = dim_scale * (map_dims + 0.5) - 0.5
coord_map = np.array([map_rows, map_cols, map_dims])
if not is_seg or order_z == 0:
reshaped_final[c] = map_coordinates(reshaped_here, coord_map, order=order_z, mode='nearest')[None]
else:
unique_labels = np.sort(pd.unique(reshaped_here.ravel())) # np.unique(reshaped_data)
for i, cl in enumerate(unique_labels):
reshaped_final[c][np.round(
map_coordinates((reshaped_here == cl).astype(float), coord_map, order=order_z,
mode='nearest')) > 0.5] = cl
else:
reshaped_final[c] = reshaped_here
else:
for c in range(data.shape[0]):
reshaped_final[c] = resize_fn(data[c], new_shape, order, **kwargs)
return reshaped_final
else:
# print("no resampling necessary")
return data
if __name__ == '__main__':
input_array = np.random.random((1, 42, 231, 142))
output_shape = (52, 256, 256)
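# axis indexes the spatial dims (0, 1 or 2) of the (c, x, y, z) input; here the first spatial axis is treated as out-of-plane
out = resample_data_or_seg(input_array, output_shape, is_seg=False, axis=0, order=1, order_z=0, do_separate_z=True)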
print(out.shape, input_array.shape)
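The separate-z branch above maps every output index back to a source coordinate with the half-pixel (`align_corners=False`) convention, `src = scale * (dst + 0.5) - 0.5`, and resamples segmentations label by label, thresholding each interpolated binary mask at 0.5. The following is a minimal 1D NumPy sketch of those two ideas. `half_pixel_coords` and `resample_labels_1d` are hypothetical helpers, and `np.interp` (which clamps out-of-range coordinates, similar to `mode='nearest'`) stands in for `scipy.ndimage.map_coordinates` with `order=1`.

```python
import numpy as np


def half_pixel_coords(in_len: int, out_len: int) -> np.ndarray:
    """Source coordinate for each output index (align_corners=False convention)."""
    scale = in_len / out_len
    return scale * (np.arange(out_len) + 0.5) - 0.5


def resample_labels_1d(seg: np.ndarray, out_len: int) -> np.ndarray:
    """Resample a 1D label map by interpolating one binary mask per label."""
    coords = half_pixel_coords(len(seg), out_len)
    src = np.arange(len(seg))
    result = np.zeros(out_len, dtype=seg.dtype)
    for label in np.unique(seg):
        # linear interpolation of the binary mask; np.interp clamps at the borders
        mask = np.interp(coords, src, (seg == label).astype(float))
        result[mask > 0.5] = label
    return result


print(half_pixel_coords(4, 2))  # [0.5 2.5]
print(resample_labels_1d(np.array([0, 0, 1, 1, 2, 2]), 3))  # [0 1 2]
```

Interpolating per-label masks instead of the label values themselves avoids producing intermediate values (e.g. a 1.5 between labels 1 and 2) that do not correspond to any class.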
================================================
FILE: nnunetv2/preprocessing/resampling/no_resampling.py
================================================
from typing import Union, Tuple, List
import numpy as np
import torch
def no_resampling_hack(
data: Union[torch.Tensor, np.ndarray],
new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]
):
return data
================================================
FILE: nnunetv2/preprocessing/resampling/resample_torch.py
================================================
from copy import deepcopy
from typing import Union, Tuple, List
import numpy as np
import torch
from einops import rearrange
from torch.nn import functional as F
from nnunetv2.configuration import ANISO_THRESHOLD
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
from nnunetv2.preprocessing.resampling.default_resampling import determine_do_sep_z_and_axis
def resample_torch_simple(
data: Union[torch.Tensor, np.ndarray],
new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
is_seg: bool = False,
num_threads: int = 4,
device: torch.device = torch.device('cpu'),
memefficient_seg_resampling: bool = False,
mode='linear'
):
if mode == 'linear':
if data.ndim == 4:
torch_mode = 'trilinear'
elif data.ndim == 3:
torch_mode = 'bilinear'
else:
raise RuntimeError(f"mode='linear' requires 3D (c, x, y) or 4D (c, x, y, z) data, got data.ndim={data.ndim}")
else:
torch_mode = mode
if isinstance(new_shape, np.ndarray):
new_shape = [int(i) for i in new_shape]
if all([i == j for i, j in zip(new_shape, data.shape[1:])]):
return data
else:
n_threads = torch.get_num_threads()
torch.set_num_threads(num_threads)
new_shape = tuple(new_shape)
with torch.no_grad():
input_was_numpy = isinstance(data, np.ndarray)
if input_was_numpy:
data = torch.from_numpy(data).to(device)
else:
orig_device = deepcopy(data.device)
data = data.to(device)
if is_seg:
unique_values = torch.unique(data)
result_dtype = torch.int8 if max(unique_values) < 127 else torch.int16
result = torch.zeros((data.shape[0], *new_shape), dtype=result_dtype, device=device)
if not memefficient_seg_resampling:
# believe it or not, the implementation below is 3x as fast (at least on Liver CT and on CPU)
# Why? Because argmax is slow. The implementation below immediately sets most locations and only lets the
# uncertain ones be determined by argmax
# unique_values = torch.unique(data)
# result = torch.zeros((len(unique_values), data.shape[0], *new_shape), dtype=torch.float16)
# for i, u in enumerate(unique_values):
# result[i] = F.interpolate((data[None] == u).float() * 1000, new_shape, mode='trilinear', antialias=False)[0]
# result = unique_values[result.argmax(0)]
result_tmp = torch.empty((len(unique_values), data.shape[0], *new_shape), dtype=torch.float16,
device=device)
scale_factor = 1000
done_mask = torch.zeros_like(result, dtype=torch.bool, device=device)
for i, u in enumerate(unique_values):
result_tmp[i] = \
F.interpolate((data[None] == u).float() * scale_factor, new_shape, mode=torch_mode,
antialias=False)[0]
mask = result_tmp[i] > (0.7 * scale_factor)
result[mask] = u.item()
done_mask |= mask
if not torch.all(done_mask):
# print('resolving argmax', torch.sum(~done_mask), "voxels to go")
result[~done_mask] = unique_values[result_tmp[:, ~done_mask].argmax(0)].to(result_dtype)
else:
for i, u in enumerate(unique_values):
if u == 0:
continue  # result is zero-initialized, so label 0 needs no explicit write
result[F.interpolate((data[None] == u).float(), new_shape, mode=torch_mode, antialias=False)[
0] > 0.5] = u
else:
result = F.interpolate(data[None].float(), new_shape, mode=torch_mode, antialias=False)[0]
if input_was_numpy:
result = result.cpu().numpy()
else:
result = result.to(orig_device)
torch.set_num_threads(n_threads)
return result
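The default (non memory-efficient) segmentation branch of `resample_torch_simple` avoids a full argmax: every label whose interpolated mask exceeds 0.7 (relative to the 1000x scale factor) is written directly, and only the remaining ambiguous voxels are resolved by argmax over all labels. Below is a 1D NumPy sketch of that strategy, assuming that `np.interp` on half-pixel coordinates is an acceptable stand-in for `F.interpolate(..., align_corners=False)`; `resample_seg_confident_first` is a hypothetical name.

```python
import numpy as np


def resample_seg_confident_first(seg: np.ndarray, out_len: int, confident: float = 0.7) -> np.ndarray:
    labels = np.unique(seg)
    coords = (len(seg) / out_len) * (np.arange(out_len) + 0.5) - 0.5
    src = np.arange(len(seg))
    # one interpolated "probability" row per label, shape (n_labels, out_len)
    probs = np.stack([np.interp(coords, src, (seg == u).astype(float)) for u in labels])
    result = np.zeros(out_len, dtype=seg.dtype)
    done = np.zeros(out_len, dtype=bool)
    for i, u in enumerate(labels):
        mask = probs[i] > confident  # unambiguous voxels are assigned immediately
        result[mask] = u
        done |= mask
    if not done.all():
        # only the leftover, uncertain voxels pay for the argmax
        result[~done] = labels[probs[:, ~done].argmax(0)]
    return result


seg = np.array([0, 0, 1, 1])
print(resample_seg_confident_first(seg, 8))                 # [0 0 0 0 1 1 1 1]
print(resample_seg_confident_first(seg, 8, confident=0.8))  # [0 0 0 0 1 1 1 1], voxels 3 and 4 via argmax
```

Because linearly interpolated one-hot masks sum to 1, at most one label can exceed a threshold above 0.5, so the direct writes never conflict.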
def resample_torch_fornnunet(
data: Union[torch.Tensor, np.ndarray],
new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
is_seg: bool = False,
num_threads: int = 4,
device: torch.device = torch.device('cpu'),
memefficient_seg_resampling: bool = False,
force_separate_z: Union[bool, None] = None,
separate_z_anisotropy_threshold: float = ANISO_THRESHOLD,
mode='linear',
aniso_axis_mode='nearest-exact'
):
"""
data must be c, x, y, z
"""
assert data.ndim == 4, "data must be c, x, y, z"
new_shape = [int(i) for i in new_shape]
orig_shape = data.shape
do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
separate_z_anisotropy_threshold)
if not isinstance(axis, (tuple, list)):
axis = (axis,)
# print('shape', data.shape, 'current_spacing', current_spacing, 'new_spacing', new_spacing, 'do_separate_z', do_separate_z, 'axis', axis)
if do_separate_z:
was_numpy = isinstance(data, np.ndarray)
if was_numpy:
data = torch.from_numpy(data)
assert len(axis) == 1
axis = axis[0]
tmp = "xyz"
axis_letter = tmp[axis]
others_int = [i for i in range(3) if i != axis]
others = [tmp[i] for i in others_int]
# fold the anisotropic axis into the channel dimension so that every slice is resampled as its own 2D image
data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}")
# resample in-plane
tmp_new_shape = [new_shape[i] for i in others_int]
data = resample_torch_simple(data, tmp_new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
memefficient_seg_resampling=memefficient_seg_resampling, mode=mode)
data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z",
**{
axis_letter: orig_shape[axis + 1],
others[0]: tmp_new_shape[0],
others[1]: tmp_new_shape[1]
}
)
# resample along the anisotropic axis (nearest-exact by default, see aniso_axis_mode)
data = resample_torch_simple(data, new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
memefficient_seg_resampling=memefficient_seg_resampling, mode=aniso_axis_mode)
if was_numpy:
data = data.numpy()
return data
else:
return resample_torch_simple(data, new_shape, is_seg, num_threads, device, memefficient_seg_resampling, mode=mode)
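When `do_separate_z` is set, `resample_torch_fornnunet` uses einops to fold the anisotropic axis into the channel dimension, resamples every slice in-plane as a batch of 2D images, and unfolds again before the out-of-plane pass. The same fold and unfold can be written with plain NumPy `moveaxis` and `reshape`. The helpers below are hypothetical and only illustrate that the `(c a) o1 o2` layout round-trips, even after the in-plane sizes have changed.

```python
import numpy as np


def fold_axis_into_channels(data: np.ndarray, axis: int) -> np.ndarray:
    """(c, x, y, z) -> (c * n_axis, other1, other2), like 'c x y z -> (c x) y z' for axis=0."""
    moved = np.moveaxis(data, axis + 1, 1)  # bring the anisotropic axis next to c
    c, a, o1, o2 = moved.shape
    return moved.reshape(c * a, o1, o2)


def unfold_channels(folded: np.ndarray, num_channels: int, axis: int) -> np.ndarray:
    """Inverse of fold_axis_into_channels; the in-plane sizes may have changed in between."""
    ca, o1, o2 = folded.shape
    unfolded = folded.reshape(num_channels, ca // num_channels, o1, o2)
    return np.moveaxis(unfolded, 1, axis + 1)


data = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
folded = fold_axis_into_channels(data, axis=1)
print(folded.shape)  # (8, 3, 5)
```

Grouping `c` as the outer factor matches einops' `(c y)` ordering, so folded row `c * n_axis + i` is exactly slice `i` of channel `c`.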
if __name__ == '__main__':
torch.set_num_threads(16)
img_file = '/media/isensee/raw_data/nnUNet_raw/Dataset027_ACDC/imagesTr/patient041_frame01_0000.nii.gz'
seg_file = '/media/isensee/raw_data/nnUNet_raw/Dataset027_ACDC/labelsTr/patient041_frame01.nii.gz'
io = SimpleITKIO()
data, pkl = io.read_images((img_file, ))
seg, pkl = io.read_seg(seg_file)
target_shape = (15, 256, 312)
spacing = pkl['spacing']
use = data
is_seg = False
ret_nosep = resample_torch_fornnunet(use, target_shape, spacing, spacing, is_seg, force_separate_z=False)
ret_sep = resample_torch_fornnunet(use, target_shape, spacing, spacing, is_seg, force_separate_z=True)
================================================
FILE: nnunetv2/preprocessing/resampling/utils.py
================================================
from typing import Callable
import nnunetv2
from batchgenerators.utilities.file_and_folder_operations import join
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
def recursive_find_resampling_fn_by_name(resampling_fn: str) -> Callable:
ret = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing", "resampling"), resampling_fn,
'nnunetv2.preprocessing.resampling')
if ret is None:
raise RuntimeError("Unable to find resampling function named '%s'. Please make sure this fn is located in the "
"nnunetv2.preprocessing.resampling module." % resampling_fn)
else:
return ret
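`recursive_find_resampling_fn_by_name` resolves a resampling function from its name by searching the `nnunetv2.preprocessing.resampling` package, so that a function can be referenced by name (for example from a plans file). The general idea can be sketched with the standard library alone. `find_attr_in_package` is a hypothetical stand-in, not the actual implementation of `recursive_find_python_class`.

```python
import importlib
import pkgutil
from types import ModuleType
from typing import Any, Optional


def find_attr_in_package(package_name: str, attr_name: str) -> Optional[Any]:
    """Return the first attribute called attr_name found in a package or any of its submodules."""
    package: ModuleType = importlib.import_module(package_name)
    if hasattr(package, attr_name):
        return getattr(package, attr_name)
    # walk_packages recurses into subpackages; the prefix makes yielded names importable
    for module_info in pkgutil.walk_packages(package.__path__, prefix=package_name + "."):
        module = importlib.import_module(module_info.name)
        if hasattr(module, attr_name):
            return getattr(module, attr_name)
    return None


print(find_attr_in_package("json", "py_make_scanner").__module__)  # json.scanner
```

Note that this imports every submodule it visits, so any import-time side effects in the searched package will run.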
================================================
FILE: nnunetv2/run/__init__.py
================================================
================================================
FILE: nnunetv2/run/load_pretrained_weights.py
================================================
import torch
from torch._dynamo import OptimizedModule
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
def load_pretrained_weights(network, fname, verbose=False):
"""
Transfers all weights between matching keys in state_dicts. Matching is done by name, and every key of the network
(except the segmentation layers) must be present in the pretrained weights with the same shape, otherwise an
AssertionError is raised. Segmentation layers (the 1x1(x1) layers that produce the segmentation maps, identified by
keys containing '.seg_layers.') are not transferred!
If the pretrained weights were obtained with a training outside nnU-Net and DDP or torch.compile was used,
you need to change the keys of the pretrained state_dict. DDP adds a 'module.' prefix and torch.compile adds
'_orig_mod.'. You DO NOT need to worry about this if pretraining was done with nnU-Net as
nnUNetTrainer.save_checkpoint takes care of that!
"""
if dist.is_initialized():
saved_model = torch.load(fname, map_location=torch.device('cuda', dist.get_rank()), weights_only=False)
else:
saved_model = torch.load(fname, weights_only=False)
pretrained_dict = saved_model['network_weights']
skip_strings_in_pretrained = [
'.seg_layers.',
]
if isinstance(network, DDP):
mod = network.module
else:
mod = network
if isinstance(mod, OptimizedModule):
mod = mod._orig_mod
model_dict = mod.state_dict()
# verify that all but the segmentation layers have the same shape
for key, _ in model_dict.items():
if all([i not in key for i in skip_strings_in_pretrained]):
assert key in pretrained_dict, \
f"Key {key} is missing in the pretrained model weights. The pretrained weights do not seem to be " \
f"compatible with your network."
assert model_dict[key].shape == pretrained_dict[key].shape, \
f"The shape of the parameters of key {key} is not the same. Pretrained model: " \
f"{pretrained_dict[key].shape}; your network: {model_dict[key].shape}. The pretrained model " \
f"does not seem to be compatible with your network."
# fun fact: in principle this allows loading from parameters that do not cover the entire network. For example pretrained
# encoders. Not supported by this function though (see assertions above)
# commenting out this abomination of a dict comprehension for preservation in the archives of 'what not to do'
# pretrained_dict = {'module.' + k if is_ddp else k: v
# for k, v in pretrained_dict.items()
# if (('module.' + k if is_ddp else k) in model_dict) and
# all([i not in k for i in skip_strings_in_pretrained])}
pretrained_dict = {k: v for k, v in pretrained_dict.items()
if k in model_dict.keys() and all([i not in k for i in skip_strings_in_pretrained])}
model_dict.update(pretrained_dict)
print("################### Loading pretrained weights from file ", fname, '###################')
if verbose:
print("Below is the list of overlapping blocks in pretrained model and nnUNet architecture:")
for key, value in pretrained_dict.items():
print(key, 'shape', value.shape)
print("################### Done ###################")
mod.load_state_dict(model_dict)
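The docstring notes that weights pretrained outside nnU-Net may carry a `module.` prefix (DDP) and/or an `_orig_mod.` prefix (torch.compile). Below is a minimal sketch of cleaning such keys before calling `load_pretrained_weights`, followed by the same key filtering the function applies. `strip_wrapper_prefixes` and `select_transferable` are hypothetical helpers, and tensor values are replaced by shape tuples so the example runs without PyTorch.

```python
from typing import Dict, Iterable, Tuple

Shape = Tuple[int, ...]


def strip_wrapper_prefixes(state_dict: Dict[str, Shape],
                           prefixes: Iterable[str] = ("module.", "_orig_mod.")) -> Dict[str, Shape]:
    """Remove DDP ('module.') and torch.compile ('_orig_mod.') prefixes, in any order and combination."""
    prefixes = tuple(prefixes)
    cleaned = {}
    for key, value in state_dict.items():
        stripped = True
        while stripped:
            stripped = False
            for p in prefixes:
                if key.startswith(p):
                    key = key[len(p):]
                    stripped = True
        cleaned[key] = value
    return cleaned


def select_transferable(pretrained: Dict[str, Shape], model: Dict[str, Shape],
                        skip: Iterable[str] = (".seg_layers.",)) -> Dict[str, Shape]:
    """Keep keys that exist in the model and are not segmentation layers."""
    skip = tuple(skip)
    return {k: v for k, v in pretrained.items() if k in model and not any(s in k for s in skip)}


pretrained = strip_wrapper_prefixes({
    "module._orig_mod.encoder.conv.weight": (32, 1, 3, 3, 3),
    "_orig_mod.module.decoder.seg_layers.0.weight": (3, 32, 1, 1, 1),
})
model = {"encoder.conv.weight": (32, 1, 3, 3, 3), "decoder.seg_layers.0.weight": (4, 32, 1, 1, 1)}
print(select_transferable(pretrained, model))  # {'encoder.conv.weight': (32, 1, 3, 3, 3)}
```

The cleaned dict would still have to be stored under the `'network_weights'` key of the checkpoint, since that is where the function reads it from. Blind prefix stripping assumes no genuine top-level submodule is named `module`.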
================================================
FILE: nnunetv2/run/run_training.py
================================================
import multiprocessing
import os
import socket
from typing import Union, Optional
import nnunetv2
import torch.cuda
import torch.distributed as dist
import torch.multiprocessing as mp
from batchgenerators.utilities.file_and_folder_operations import join, isfile, load_json
from nnunetv2.paths import nnUNet_preprocessed
from nnunetv2.run.load_pretrained_weights import load_pretrained_weights
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from torch.backends import cudnn
def find_free_network_port() -> int:
"""Finds a free port on localhost.
It is useful in single-node training when we don't want to connect to a real main node but have to set the
`MASTER_PORT` environment variable.
"""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
port = s.getsockname()[1]
s.close()
return port
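For single-node DDP, `run_training` sets `MASTER_ADDR` to localhost and only picks a free port when `MASTER_PORT` is not already set. A sketch of that logic, written against a plain mapping so it can be tested without touching the real environment (in the script the mapping would be `os.environ`); `find_free_port` and `ensure_master_addr_and_port` are hypothetical names.

```python
import socket
from typing import MutableMapping


def find_free_port() -> int:
    # binding to port 0 asks the OS for any currently unused port
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def ensure_master_addr_and_port(env: MutableMapping[str, str]) -> str:
    """Point MASTER_ADDR at localhost and pick a port unless one was provided."""
    env["MASTER_ADDR"] = "localhost"
    if "MASTER_PORT" not in env:
        env["MASTER_PORT"] = str(find_free_port())
    return env["MASTER_PORT"]


env = {}
print(ensure_master_addr_and_port(env), env["MASTER_ADDR"])
```

As in the original function, the socket is closed before DDP binds the port, so another process could in principle grab it in between; setting `MASTER_PORT` explicitly avoids that race.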
def get_trainer_from_args(dataset_name_or_id: Union[int, str],
configuration: str,
fold: int,
trainer_name: str = 'nnUNetTrainer',
plans_identifier: str = 'nnUNetPlans',
continue_training: bool = False,
device: torch.device = torch.device('cuda')):
# load nnunet class and do sanity checks
nnunet_trainer = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
trainer_name, 'nnunetv2.training.nnUNetTrainer')
if nnunet_trainer is None:
raise RuntimeError(f'Could not find requested nnunet trainer {trainer_name} in '
f'nnunetv2.training.nnUNetTrainer ('
f'{join(nnunetv2.__path__[0], "training", "nnUNetTrainer")}). If it is located somewhere '
f'else, please move it there.')
assert issubclass(nnunet_trainer, nnUNetTrainer), 'The requested nnunet trainer class must inherit from ' \
'nnUNetTrainer'
# handle dataset input. If it's an ID we need to convert to int (it may arrive as a string from the CLI)
if isinstance(dataset_name_or_id, str) and dataset_name_or_id.startswith('Dataset'):
pass
else:
try:
dataset_name_or_id = int(dataset_name_or_id)
except ValueError:
raise ValueError(f'dataset_name_or_id must either be an integer or a valid dataset name with the pattern '
f'DatasetXXX_YYY where XXX are the three(!) task ID digits. Your '
f'input: {dataset_name_or_id}')
# initialize nnunet trainer
preprocessed_dataset_folder_base = join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id))
plans_file = join(preprocessed_dataset_folder_base, plans_identifier + '.json')
plans = load_json(plans_file)
plans["continue_training"] = continue_training
dataset_json = load_json(join(preprocessed_dataset_folder_base, 'dataset.json'))
nnunet_trainer = nnunet_trainer(plans=plans, configuration=configuration, fold=fold,
dataset_json=dataset_json, device=device)
return nnunet_trainer
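`get_trainer_from_args` accepts either a full dataset name or a numeric ID: names starting with `Dataset` pass through, anything else must parse as an integer. The sketch below isolates that normalization and the zero-padded three-digit convention from the `DatasetXXX_YYY` error message. Both helpers are hypothetical; the actual ID-to-folder lookup is done by `maybe_convert_to_dataset_name`, which is not reproduced here.

```python
from typing import Union


def normalize_dataset_arg(dataset_name_or_id: Union[int, str]) -> Union[int, str]:
    """Keep full dataset names, convert numeric IDs (int or str) to int."""
    if isinstance(dataset_name_or_id, str) and dataset_name_or_id.startswith("Dataset"):
        return dataset_name_or_id
    try:
        return int(dataset_name_or_id)
    except ValueError:
        raise ValueError(f"Expected an integer ID or a name like DatasetXXX_YYY, "
                         f"got {dataset_name_or_id!r}") from None


def dataset_name_prefix(dataset_id: int) -> str:
    """Three-digit, zero-padded prefix as used in DatasetXXX_YYY folder names."""
    return f"Dataset{dataset_id:03d}"


print(normalize_dataset_arg("4"), dataset_name_prefix(4))  # 4 Dataset004
```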
def maybe_load_checkpoint(nnunet_trainer: nnUNetTrainer, continue_training: bool, validation_only: bool,
pretrained_weights_file: Optional[str] = None):
if continue_training and pretrained_weights_file is not None:
raise RuntimeError('Cannot both continue a training AND load pretrained weights. Pretrained weights can only '
'be used at the beginning of the training.')
if continue_training:
# checkpoint_final comes first: special case where --c is used to rerun a previously aborted validation
expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
if not isfile(expected_checkpoint_file):
expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_latest.pth')
# last resort: fall back to checkpoint_best
if not isfile(expected_checkpoint_file):
expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_best.pth')
if not isfile(expected_checkpoint_file):
print(f"WARNING: Cannot continue training because there seems to be no checkpoint available to "
f"continue from. Starting a new training...")
expected_checkpoint_file = None
elif validation_only:
expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
if not isfile(expected_checkpoint_file):
raise RuntimeError(f"Cannot run validation because the training is not finished yet!")
else:
if pretrained_weights_file is not None:
if not nnunet_trainer.was_initialized:
nnunet_trainer.initialize()
load_pretrained_weights(nnunet_trainer.network, pretrained_weights_file, verbose=True)
expected_checkpoint_file = None
if expected_checkpoint_file is not None:
nnunet_trainer.load_checkpoint(expected_checkpoint_file)
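With `--c`, `maybe_load_checkpoint` tries `checkpoint_final.pth`, then `checkpoint_latest.pth`, then `checkpoint_best.pth`, and starts a fresh training if none exists. That decision can be isolated as a pure function over the set of files present in the output folder, which makes the order easy to test; `pick_checkpoint_to_continue` is a hypothetical helper, not part of nnU-Net.

```python
from typing import Optional, Set


def pick_checkpoint_to_continue(existing: Set[str]) -> Optional[str]:
    """Mirror the --c fallback order: final, then latest, then best; None means start from scratch."""
    for name in ("checkpoint_final.pth", "checkpoint_latest.pth", "checkpoint_best.pth"):
        if name in existing:
            return name
    return None


print(pick_checkpoint_to_continue({"checkpoint_latest.pth", "checkpoint_best.pth"}))  # checkpoint_latest.pth
```

In practice `existing` would be something like `set(os.listdir(output_folder))`.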
def setup_ddp(rank, world_size):
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup_ddp():
dist.destroy_process_group()
def run_ddp(rank, dataset_name_or_id, configuration, fold, tr, p, disable_checkpointing, c, val,
pretrained_weights, npz, val_with_best, world_size):
setup_ddp(rank, world_size)
torch.cuda.set_device(torch.device('cuda', dist.get_rank()))
nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, tr, p, c)
if disable_checkpointing:
nnunet_trainer.disable_checkpointing = disable_checkpointing
assert not (c and val), f'Cannot set --c and --val flag at the same time. Dummy.'
maybe_load_checkpoint(nnunet_trainer, c, val, pretrained_weights)
if torch.cuda.is_available():
cudnn.deterministic = False
cudnn.benchmark = True
if not val:
nnunet_trainer.run_training()
if val_with_best:
nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth'))
nnunet_trainer.perform_actual_validation(npz)
cleanup_ddp()
def run_training(dataset_name_or_id: Union[str, int],
configuration: str, fold: Union[int, str],
trainer_class_name: str = 'nnUNetTrainer',
plans_identifier: str = 'nnUNetPlans',
pretrained_weights: Optional[str] = None,
num_gpus: int = 1,
export_validation_probabilities: bool = False,
continue_training: bool = False,
only_run_validation: bool = False,
disable_checkpointing: bool = False,
val_with_best: bool = False,
device: torch.device = torch.device('cuda')):
if plans_identifier == 'nnUNetPlans':
print("\n############################\n"
"INFO: You are using the old nnU-Net default plans. We have updated our recommendations. "
"Please consider using those instead! "
"Read more here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md"
"\n############################\n")
if isinstance(fold, str):
if fold != 'all':
try:
fold = int(fold)
except ValueError as e:
print(f'Unable to convert given value for fold to int: {fold}. fold must be either "all" or an integer!')
raise e
if val_with_best:
assert not disable_checkpointing, '--val_best is not compatible with --disable_checkpointing'
if num_gpus > 1:
assert device.type == 'cuda', f"DDP training (triggered by num_gpus > 1) is only implemented for cuda devices. Your device: {device}"
os.environ['MASTER_ADDR'] = 'localhost'
if 'MASTER_PORT' not in os.environ.keys():
port = str(find_free_network_port())
print(f"using port {port}")
os.environ['MASTER_PORT'] = port # str(port)
mp.spawn(run_ddp,
args=(
dataset_name_or_id,
configuration,
fold,
trainer_class_name,
plans_identifier,
disable_checkpointing,
continue_training,
only_run_validation,
pretrained_weights,
export_validation_probabilities,
val_with_best,
num_gpus),
nprocs=num_gpus,
join=True)
else:
nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name,
plans_identifier, continue_training, device=device)
if disable_checkpointing:
nnunet_trainer.disable_checkpointing = disable_checkpointing
assert not (continue_training and only_run_validation), f'Cannot set --c and --val flag at the same time. Dummy.'
maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)
if torch.cuda.is_available():
cudnn.deterministic = False
cudnn.benchmark = True
if not only_run_validation:
nnunet_trainer.run_training()
if val_with_best:
nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth'))
nnunet_trainer.perform_actual_validation(export_validation_probabilities)
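`run_training` accepts the fold either as an integer or as the string `'all'`, converting numeric strings that come from the command line. A stand-alone sketch of that conversion (the name `parse_fold` is hypothetical), which raises a `ValueError` with the explanation attached instead of printing it separately:

```python
from typing import Union


def parse_fold(fold: Union[int, str]) -> Union[int, str]:
    """Accept 'all' or anything int() can parse, as run_training does."""
    if isinstance(fold, str) and fold != "all":
        try:
            return int(fold)
        except ValueError as e:
            raise ValueError(f'fold must be either "all" or an integer, got {fold!r}') from e
    return fold


print(parse_fold("3"), parse_fold("all"))  # 3 all
```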
def run_training_entry():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('dataset_name_or_id', type=str,
help="Dataset name or ID to train with")
parser.add_argument('configuration', type=str,
help="Configuration that should be trained")
parser.add_argument('fold', type=str,
help='Fold of the 5-fold cross-validation. Should be an int between 0 and 4, or "all".')
parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',
help='[OPTIONAL] Use this flag to specify a custom trainer. Default: nnUNetTrainer')
parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
help='[OPTIONAL] Use this flag to specify a custom plans identifier. Default: nnUNetPlans')
parser.add_argument('-pretrained_weights', type=str, required=False, default=None,
help='[OPTIONAL] path to nnU-Net checkpoint file to be used as pretrained model. Will only '
'be used when actually training. Beta. Use with caution.')
parser.add_argument('-num_gpus', type=int, default=1, required=False,
help='Specify the number of GPUs to use for training')
parser.add_argument('--npz', action='store_true', required=False,
help='[OPTIONAL] Save softmax predictions from final validation as npz files (in addition to predicted '
'segmentations). Needed for finding the best ensemble.')
parser.add_argument('--c', action='store_true', required=False,
help='[OPTIONAL] Continue training from latest checkpoint')
parser.add_argument('--val', action='store_true', required=False,
help='[OPTIONAL] Set this flag to only run the validation. Requires training to have finished.')
parser.add_argument('--val_best', action='store_true', required=False,
help='[OPTIONAL] If set, the validation will be performed with the checkpoint_best instead '
'of checkpoint_final. NOT COMPATIBLE with --disable_checkpointing! '
'WARNING: This will use the same \'validation\' folder as the regular validation '
'with no way of distinguishing the two!')
parser.add_argument('--disable_checkpointing', action='store_true', required=False,
help='[OPTIONAL] Set this flag to disable checkpointing. Ideal for testing things out when '
'you do not want to flood your hard drive with checkpoints.')
parser.add_argument('-device', type=str, default='cuda', required=False,
help="Use this to set the device the training should run with. Available options are 'cuda' "
"(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to select the GPU ID! "
"Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] instead!")
args = parser.parse_args()
assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
if args.device == 'cpu':
# let's allow torch to use hella threads
torch.set_num_threads(multiprocessing.cpu_count())
device = torch.device('cpu')
elif args.device == 'cuda':
# multithreading in torch doesn't help nnU-Net if run on GPU
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
device = torch.device('cuda')
else:
device = torch.device('mps')
run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,
args.num_gpus, args.npz, args.c, args.val, args.disable_checkpointing, args.val_best,
device=device)
if __name__ == '__main__':
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
# reduces the number of threads used for compiling. More threads don't help and can cause problems
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
# multiprocessing.set_start_method("spawn")
run_training_entry()
================================================
FILE: nnunetv2/tests/__init__.py
================================================
================================================
FILE: nnunetv2/tests/integration_tests/__init__.py
================================================
================================================
FILE: nnunetv2/tests/integration_tests/add_lowres_and_cascade.py
================================================
from copy import deepcopy
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.paths import nnUNet_preprocessed
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-d', nargs='+', type=int, help='List of dataset ids')
args = parser.parse_args()
for d in args.d:
dataset_name = maybe_convert_to_dataset_name(d)
plans = load_json(join(nnUNet_preprocessed, dataset_name, 'nnUNetPlans.json'))
plans['configurations']['3d_lowres'] = {
"data_identifier": "nnUNetPlans_3d_lowres", # do not be a dumbo and forget this. I was a dumbo. And I paid dearly with ~10 min debugging time
'inherits_from': '3d_fullres',
"patch_size": [20, 28, 20],
"median_image_size_in_voxels": [18.0, 25.0, 18.0],
"spacing": [2.0, 2.0, 2.0],
"architecture": deepcopy(plans['configurations']['3d_fullres']["architecture"]),
"next_stage": "3d_cascade_fullres"
}
plans['configurations']['3d_lowres']['architecture']["arch_kwargs"]['n_conv_per_stage'] = [2, 2, 2]
plans['configurations']['3d_lowres']['architecture']["arch_kwargs"]['n_conv_per_stage_decoder'] = [2, 2]
plans['configurations']['3d_lowres']['architecture']["arch_kwargs"]['strides'] = [[1, 1, 1], [2, 2, 2], [2, 2, 2]]
plans['configurations']['3d_lowres']['architecture']["arch_kwargs"]['kernel_sizes'] = [[3, 3, 3], [3, 3, 3], [3, 3, 3]]
plans['configurations']['3d_lowres']['architecture']["arch_kwargs"]['n_stages'] = 3
plans['configurations']['3d_lowres']['architecture']["arch_kwargs"]['features_per_stage'] = [
32,
64,
128
]
plans['configurations']['3d_cascade_fullres'] = {
'inherits_from': '3d_fullres',
"previous_stage": "3d_lowres"
}
save_json(plans, join(nnUNet_preprocessed, dataset_name, 'nnUNetPlans.json'), sort_keys=False)
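The added configurations rely on `inherits_from`: `3d_lowres` and `3d_cascade_fullres` only list the keys that differ from `3d_fullres`. This is presumably also why the script warns about `data_identifier`: if `3d_lowres` did not override it, it would inherit the fullres identifier and point at the fullres preprocessed data, whereas the cascade can share it (the prepare script skips preprocessing it for that reason). Below is a sketch of how such inheritance could be resolved, assuming a shallow override where child keys replace parent keys. `resolve_configuration` is hypothetical and the config values are toy values.

```python
from copy import deepcopy
from typing import Any, Dict, Optional, Set


def resolve_configuration(configurations: Dict[str, Dict[str, Any]], name: str,
                          _seen: Optional[Set[str]] = None) -> Dict[str, Any]:
    """Merge a configuration with its 'inherits_from' parent; child keys override parent keys."""
    seen = set() if _seen is None else _seen
    if name in seen:
        raise RuntimeError(f"circular inheritance involving {name!r}")
    seen.add(name)
    config = configurations[name]
    parent_name = config.get("inherits_from")
    if parent_name is None:
        return deepcopy(config)
    merged = resolve_configuration(configurations, parent_name, seen)
    merged.update(deepcopy({k: v for k, v in config.items() if k != "inherits_from"}))
    return merged


configs = {
    "3d_fullres": {"data_identifier": "nnUNetPlans_3d_fullres", "spacing": [1.0, 1.0, 1.0], "batch_size": 2},
    "3d_lowres": {"inherits_from": "3d_fullres", "data_identifier": "nnUNetPlans_3d_lowres",
                  "spacing": [2.0, 2.0, 2.0], "next_stage": "3d_cascade_fullres"},
    "3d_cascade_fullres": {"inherits_from": "3d_fullres", "previous_stage": "3d_lowres"},
}
print(resolve_configuration(configs, "3d_lowres")["data_identifier"])  # nnUNetPlans_3d_lowres
```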
================================================
FILE: nnunetv2/tests/integration_tests/cleanup_integration_test.py
================================================
import shutil
from batchgenerators.utilities.file_and_folder_operations import isdir, join
from nnunetv2.paths import nnUNet_raw, nnUNet_results, nnUNet_preprocessed
if __name__ == '__main__':
# deletes everything belonging to the integration test datasets (raw, preprocessed and results folders)!
dataset_names = [
'Dataset996_IntegrationTest_Hippocampus_regions_ignore',
'Dataset997_IntegrationTest_Hippocampus_regions',
'Dataset998_IntegrationTest_Hippocampus_ignore',
'Dataset999_IntegrationTest_Hippocampus',
]
for fld in [nnUNet_raw, nnUNet_preprocessed, nnUNet_results]:
for d in dataset_names:
if isdir(join(fld, d)):
shutil.rmtree(join(fld, d))
================================================
FILE: nnunetv2/tests/integration_tests/lsf_commands.sh
================================================
bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 996"
bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 997"
bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 998"
bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 999"
bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 996"
bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 997"
bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 998"
bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 999"
================================================
FILE: nnunetv2/tests/integration_tests/prepare_integration_tests.sh
================================================
# assumes you are in the nnunet repo!
# prepare raw datasets
python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py
python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py
python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py
python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py
# now run experiment planning without preprocessing
nnUNetv2_plan_and_preprocess -d 996 997 998 999 --no_pp
# now add 3d lowres and cascade
python nnunetv2/tests/integration_tests/add_lowres_and_cascade.py -d 996 997 998 999
# now preprocess everything
nnUNetv2_preprocess -d 996 997 998 999 -c 2d 3d_lowres 3d_fullres -np 8 8 8 # no need to preprocess cascade as it's the same data as 3d_fullres
# done
================================================
FILE: nnunetv2/tests/integration_tests/readme.md
================================================
# Preface
I am just a mortal with many tasks and limited time. Ain't nobody got time for unit tests.
HOWEVER, at least some integration tests that exercise nnU-Net from start to finish should be performed.
# Introduction - What the heck is happening?
This test covers all possible labeling scenarios (standard labels, regions, ignore labels and regions with
ignore labels). It runs the entire nnU-Net pipeline from start to finish:
- fingerprint extraction
- experiment planning
- preprocessing
- train all 4 configurations (2d, 3d_lowres, 3d_fullres, 3d_cascade_fullres) as 5-fold CV
- automatically find the best model or ensemble
- determine the postprocessing used for this
- predict a test set
- apply postprocessing to the test set
To speed things up, we do the following:
- pick Dataset004_Hippocampus because it is "quadratisch, praktisch, gut" (square, practical, good): the MNIST of medical image segmentation
- by default this dataset does not have 3d_lowres or cascade. We just manually add them (cool new feature, eh?). See `add_lowres_and_cascade.py` to learn more!
- we use nnUNetTrainer_5epochs for a short training
# How to run it?
Set your pwd to be the nnunet repo folder (the one where the `nnunetv2` folder and the `setup.py` are located!)
Now generate the 4 dummy datasets (ids 996, 997, 998, 999) from dataset 4. This will crash if you don't have Dataset004!
```commandline
bash nnunetv2/tests/integration_tests/prepare_integration_tests.sh
```
Now you can run the integration test for each of the datasets:
```commandline
bash nnunetv2/tests/integration_tests/run_integration_test.sh DATASET_ID
```
Use DATASET_ID 996, 997, 998 and 999. You can run these independently on different GPUs/systems to speed things up.
This will take, I dunno, like 10-30 minutes!?
Also run
```commandline
bash nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh DATASET_ID
```
to verify DDP is working (needs 2 GPUs!)
# How to check if the test was successful?
If I were not as lazy as I am, I would have programmed an automated check that Dice scores etc. are in an acceptable range.
So you need to do the following:
1) check that none of your runs crashed (duh)
2) for each run, navigate to `nnUNet_results/DATASET_NAME` and take a look at the `inference_information.json` file.
Does it make sense? If so: NICE!
Once the integration test is completed you can delete all the temporary files associated with it by running:
```commandline
python nnunetv2/tests/integration_tests/cleanup_integration_test.py
```
================================================
FILE: nnunetv2/tests/integration_tests/run_integration_test.sh
================================================
nnUNetv2_train $1 3d_fullres 0 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 3d_fullres 1 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 3d_fullres 2 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 3d_fullres 3 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 3d_fullres 4 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 2d 0 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 2d 1 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 2d 2 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 2d 3 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 2d 4 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 3d_lowres 0 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 3d_lowres 1 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 3d_lowres 2 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 3d_lowres 3 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 3d_lowres 4 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 3d_cascade_fullres 0 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 3d_cascade_fullres 1 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 3d_cascade_fullres 2 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 3d_cascade_fullres 3 -tr nnUNetTrainer_5epochs --npz
nnUNetv2_train $1 3d_cascade_fullres 4 -tr nnUNetTrainer_5epochs --npz
python nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py -d $1
================================================
FILE: nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py
================================================
import argparse
import torch
from batchgenerators.utilities.file_and_folder_operations import join, load_pickle
from nnunetv2.ensembling.ensemble import ensemble_folders
from nnunetv2.evaluation.find_best_configuration import find_best_configuration, \
    dumb_trainer_config_plans_to_trained_models_dict
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.paths import nnUNet_raw, nnUNet_results
from nnunetv2.postprocessing.remove_connected_components import apply_postprocessing_to_folder
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.file_path_utilities import get_output_folder
if __name__ == '__main__':
    """
    Predicts the imagesTs folder with the best configuration and applies postprocessing
    """
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', type=int, help='dataset id')
    args = parser.parse_args()
    d = args.d
    dataset_name = maybe_convert_to_dataset_name(d)
    source_dir = join(nnUNet_raw, dataset_name, 'imagesTs')
    target_dir_base = join(nnUNet_results, dataset_name)
    models = dumb_trainer_config_plans_to_trained_models_dict(['nnUNetTrainer_5epochs'],
                                                              ['2d',
                                                               '3d_lowres',
                                                               '3d_cascade_fullres',
                                                               '3d_fullres'],
                                                              ['nnUNetPlans'])
    ret = find_best_configuration(d, models, allow_ensembling=True, num_processes=8, overwrite=True,
                                  folds=(0, 1, 2, 3, 4), strict=True)
    has_ensemble = len(ret['best_model_or_ensemble']['selected_model_or_models']) > 1
    # we don't use all folds to speed stuff up
    used_folds = (0, 3)
    output_folders = []
    for im in ret['best_model_or_ensemble']['selected_model_or_models']:
        output_dir = join(target_dir_base, f"pred_{im['configuration']}")
        model_folder = get_output_folder(d, im['trainer'], im['plans_identifier'], im['configuration'])
        # note that if the best model is the ensemble of 3d_lowres and 3d cascade then 3d_lowres will be predicted
        # twice (once standalone and once to generate the predictions for the cascade) because we don't reuse the
        # prediction here. The proper way would be to check for that and
        # then give the output of 3d_lowres inference to the folder_with_segs_from_prev_stage kwarg in
        # predict_from_raw_data. Since we allow for
        # dynamically setting 'previous_stage' in the plans I am too lazy to implement this here. This is just an
        # integration test after all. Take a closer look at how this is handled in predict_from_raw_data
        predictor = nnUNetPredictor(verbose=False, allow_tqdm=False)
        predictor.initialize_from_trained_model_folder(model_folder, used_folds)
        predictor.predict_from_files(source_dir, output_dir, has_ensemble, overwrite=True)
        # predict_from_raw_data(list_of_lists_or_source_folder=source_dir, output_folder=output_dir,
        #                       model_training_output_dir=model_folder, use_folds=used_folds,
        #                       save_probabilities=has_ensemble, verbose=False, overwrite=True)
        output_folders.append(output_dir)
    # if we have an ensemble, we need to ensemble the results
    if has_ensemble:
        ensemble_folders(output_folders, join(target_dir_base, 'ensemble_predictions'), save_merged_probabilities=False)
        folder_for_pp = join(target_dir_base, 'ensemble_predictions')
    else:
        folder_for_pp = output_folders[0]
    # apply postprocessing
    pp_fns, pp_fn_kwargs = load_pickle(ret['best_model_or_ensemble']['postprocessing_file'])
    apply_postprocessing_to_folder(folder_for_pp, join(target_dir_base, 'ensemble_predictions_postprocessed'),
                                   pp_fns,
                                   pp_fn_kwargs, plans_file_or_dict=ret['best_model_or_ensemble']['some_plans_file'])
================================================
FILE: nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh
================================================
nnUNetv2_train $1 3d_fullres 0 -tr nnUNetTrainer_10epochs -num_gpus 2
================================================
FILE: nnunetv2/tests/integration_tests/run_nnunet_inference.py
================================================
import os
import shutil
import subprocess
from pathlib import Path
import nibabel as nib
import numpy as np
def dice_score(y_true, y_pred):
    # Note: y_true * y_pred equals the overlap only for binary masks. For segmentations with several
    # integer labels this is not a true Dice coefficient.
    intersect = np.sum(y_true * y_pred)
    denominator = np.sum(y_true) + np.sum(y_pred)
    f1 = (2 * intersect) / (denominator + 1e-6)
    return f1
def run_tests_and_exit_on_failure():
    """
    Runs inference of a simple nnU-Net for CT body segmentation on a small example CT image
    and checks if the output is correct.
    """
    # Set nnUNet_results env var
    weights_dir = Path.home() / "github_actions_nnunet" / "results"
    os.environ["nnUNet_results"] = str(weights_dir)
    # Copy example file
    os.makedirs("nnunetv2/tests/github_actions_output", exist_ok=True)
    shutil.copy("nnunetv2/tests/example_data/example_ct_sm.nii.gz", "nnunetv2/tests/github_actions_output/example_ct_sm_0000.nii.gz")
    # Run nnunet
    subprocess.call(f"nnUNetv2_predict -i nnunetv2/tests/github_actions_output -o nnunetv2/tests/github_actions_output -d 300 -tr nnUNetTrainer -c 3d_fullres -f 0 -device cpu", shell=True)
    # Check if the nnunet segmentation is correct
    img_gt = nib.load(f"nnunetv2/tests/example_data/example_ct_sm_T300_output.nii.gz").get_fdata()
    img_pred = nib.load(f"nnunetv2/tests/github_actions_output/example_ct_sm.nii.gz").get_fdata()
    dice = dice_score(img_gt, img_pred)
    images_equal = dice > 0.99  # allow for a small difference in the segmentation, otherwise the test will fail often
    assert images_equal, f"The nnunet segmentation is not correct (dice: {dice:.5f})."
    # Clean up
    shutil.rmtree("nnunetv2/tests/github_actions_output")
    shutil.rmtree(Path.home() / "github_actions_nnunet")
if __name__ == "__main__":
    run_tests_and_exit_on_failure()
================================================
FILE: nnunetv2/training/__init__.py
================================================
================================================
FILE: nnunetv2/training/data_augmentation/__init__.py
================================================
================================================
FILE: nnunetv2/training/data_augmentation/compute_initial_patch_size.py
================================================
import numpy as np
def get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range):
    """
    Estimates the (larger) patch size that needs to be sampled so that, after rotation by up to rot_x/rot_y/rot_z
    (in radians, each capped at 90 degrees; tuples/lists are reduced to their largest absolute value) and scaling
    with the smallest factor in scale_range, the final_patch_size region is still covered by image data.
    """
    if isinstance(rot_x, (tuple, list)):
        rot_x = max(np.abs(rot_x))
    if isinstance(rot_y, (tuple, list)):
        rot_y = max(np.abs(rot_y))
    if isinstance(rot_z, (tuple, list)):
        rot_z = max(np.abs(rot_z))
    rot_x = min(90 / 360 * 2. * np.pi, rot_x)
    rot_y = min(90 / 360 * 2. * np.pi, rot_y)
    rot_z = min(90 / 360 * 2. * np.pi, rot_z)
    from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d
    coords = np.array(final_patch_size)
    final_shape = np.copy(coords)
    if len(coords) == 3:
        final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, rot_x, 0, 0)), final_shape)), 0)
        final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, rot_y, 0)), final_shape)), 0)
        final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, 0, rot_z)), final_shape)), 0)
    elif len(coords) == 2:
        final_shape = np.max(np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0)
    final_shape /= min(scale_range)
    return final_shape.astype(int)
================================================
FILE: nnunetv2/training/data_augmentation/custom_transforms/__init__.py
================================================
================================================
FILE: nnunetv2/training/data_augmentation/custom_transforms/cascade_transforms.py
================================================
from typing import Union, List, Tuple, Callable
import numpy as np
from acvl_utils.morphology.morphology_helper import label_with_component_sizes
from batchgenerators.transforms.abstract_transforms import AbstractTransform
from skimage.morphology import ball
from skimage.morphology.binary import binary_erosion, binary_dilation, binary_closing, binary_opening
class MoveSegAsOneHotToData(AbstractTransform):
    def __init__(self, index_in_origin: int, all_labels: Union[Tuple[int, ...], List[int]],
                 key_origin="seg", key_target="data", remove_from_origin=True):
        """
        Takes data_dict[seg][:, index_in_origin], converts it to one hot encoding and appends it to
        data_dict[key_target]. Optionally removes index_in_origin from data_dict[seg].
        """
        self.remove_from_origin = remove_from_origin
        self.all_labels = all_labels
        self.key_target = key_target
        self.key_origin = key_origin
        self.index_in_origin = index_in_origin
    def __call__(self, **data_dict):
        seg = data_dict[self.key_origin][:, self.index_in_origin:self.index_in_origin+1]
        seg_onehot = np.zeros((seg.shape[0], len(self.all_labels), *seg.shape[2:]),
                              dtype=data_dict[self.key_target].dtype)
        for i, l in enumerate(self.all_labels):
            seg_onehot[:, i][seg[:, 0] == l] = 1
        data_dict[self.key_target] = np.concatenate((data_dict[self.key_target], seg_onehot), 1)
        if self.remove_from_origin:
            remaining_channels = [i for i in range(data_dict[self.key_origin].shape[1]) if i != self.index_in_origin]
            data_dict[self.key_origin] = data_dict[self.key_origin][:, remaining_channels]
        return data_dict
class RemoveRandomConnectedComponentFromOneHotEncodingTransform(AbstractTransform):
    def __init__(self, channel_idx: Union[int, List[int]], key: str = "data", p_per_sample: float = 0.2,
                 fill_with_other_class_p: float = 0.25,
                 dont_do_if_covers_more_than_x_percent: float = 0.25, p_per_label: float = 1):
        """
        Randomly removes connected components in the specified channel_idx of data_dict[key]. Only considers components
        smaller than dont_do_if_covers_more_than_x_percent (a fraction, e.g. 0.25 = 25%) of the sample. Also has the
        option of simulating misclassification as another class (fill_with_other_class_p)
        """
        self.p_per_label = p_per_label
        self.dont_do_if_covers_more_than_x_percent = dont_do_if_covers_more_than_x_percent
        self.fill_with_other_class_p = fill_with_other_class_p
        self.p_per_sample = p_per_sample
        self.key = key
        if not isinstance(channel_idx, (list, tuple)):
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx
    def __call__(self, **data_dict):
        data = data_dict.get(self.key)
        for b in range(data.shape[0]):
            if np.random.uniform() < self.p_per_sample:
                for c in self.channel_idx:
                    if np.random.uniform() < self.p_per_label:
                        # print(np.unique(data[b, c])) ## should be [0, 1]
                        workon = data[b, c].astype(bool)
                        if not np.any(workon):
                            continue
                        num_voxels = np.prod(workon.shape, dtype=np.uint64)
                        lab, component_sizes = label_with_component_sizes(workon.astype(bool))
                        if len(component_sizes) > 0:
                            valid_component_ids = [i for i, j in component_sizes.items() if j <
                                                   num_voxels*self.dont_do_if_covers_more_than_x_percent]
                            # print('RemoveRandomConnectedComponentFromOneHotEncodingTransform', c,
                            #       np.unique(data[b, c]), len(component_sizes), valid_component_ids,
                            #       len(valid_component_ids))
                            if len(valid_component_ids) > 0:
                                random_component = np.random.choice(valid_component_ids)
                                data[b, c][lab == random_component] = 0
                                if np.random.uniform() < self.fill_with_other_class_p:
                                    other_ch = [i for i in self.channel_idx if i != c]
                                    if len(other_ch) > 0:
                                        other_class = np.random.choice(other_ch)
                                        data[b, other_class][lab == random_component] = 1
        data_dict[self.key] = data
        return data_dict
class ApplyRandomBinaryOperatorTransform(AbstractTransform):
    def __init__(self,
                 channel_idx: Union[int, List[int], Tuple[int, ...]],
                 p_per_sample: float = 0.3,
                 any_of_these: Tuple[Callable] = (binary_dilation, binary_erosion, binary_closing, binary_opening),
                 key: str = "data",
                 strel_size: Tuple[int, int] = (1, 10),
                 p_per_label: float = 1):
        """
        Applies random binary operations (specified by any_of_these) with random ball size (radius is uniformly sampled
        from interval strel_size) to specified channels. Expects the channel_idx to correspond to a one-hot encoded
        segmentation (see for example MoveSegAsOneHotToData)
        """
        self.p_per_label = p_per_label
        self.strel_size = strel_size
        self.key = key
        self.any_of_these = any_of_these
        self.p_per_sample = p_per_sample
        if not isinstance(channel_idx, (list, tuple)):
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx
    def __call__(self, **data_dict):
        for b in range(data_dict[self.key].shape[0]):
            if np.random.uniform() < self.p_per_sample:
                # this needs to be applied in random order to the channels
                np.random.shuffle(self.channel_idx)
                for c in self.channel_idx:
                    if np.random.uniform() < self.p_per_label:
                        operation = np.random.choice(self.any_of_these)
                        selem = ball(np.random.uniform(*self.strel_size))
                        workon = data_dict[self.key][b, c].astype(bool)
                        if not np.any(workon):
                            continue
                        # print(np.unique(workon))
                        res = operation(workon, selem).astype(data_dict[self.key].dtype)
                        # print('ApplyRandomBinaryOperatorTransform', c, operation, np.sum(workon), np.sum(res))
                        data_dict[self.key][b, c] = res
                        # if class was added, we need to remove it in ALL other channels to keep one hot encoding
                        # properties
                        other_ch = [i for i in self.channel_idx if i != c]
                        if len(other_ch) > 0:
                            was_added_mask = (res - workon) > 0
                            for oc in other_ch:
                                data_dict[self.key][b, oc][was_added_mask] = 0
                            # if class was removed, leave it at background
        return data_dict
================================================
FILE: nnunetv2/training/data_augmentation/custom_transforms/deep_supervision_donwsampling.py
================================================
from typing import Tuple, Union, List
from batchgenerators.augmentations.utils import resize_segmentation
from batchgenerators.transforms.abstract_transforms import AbstractTransform
import numpy as np
class DownsampleSegForDSTransform2(AbstractTransform):
    '''
    data_dict[output_key] will be a list of segmentations scaled according to ds_scales
    '''
    def __init__(self, ds_scales: Union[List, Tuple],
                 order: int = 0, input_key: str = "seg",
                 output_key: str = "seg", axes: Tuple[int] = None):
        """
        Downscales data_dict[input_key] according to ds_scales. Each entry in ds_scales specifies one deep supervision
        output and its resolution relative to the original data, for example 0.25 specifies 1/4 of the original shape.
        ds_scales can also be a tuple of tuples, for example ((1, 1, 1), (0.5, 0.5, 0.5)) to specify the downsampling
        for each axis independently
        """
        self.axes = axes
        self.output_key = output_key
        self.input_key = input_key
        self.order = order
        self.ds_scales = ds_scales
    def __call__(self, **data_dict):
        if self.axes is None:
            axes = list(range(2, data_dict[self.input_key].ndim))
        else:
            axes = self.axes
        output = []
        for s in self.ds_scales:
            if not isinstance(s, (tuple, list)):
                s = [s] * len(axes)
            else:
                assert len(s) == len(axes), f'If ds_scales is a tuple for each resolution (one downsampling factor ' \
                                            f'for each axis) then the number of entries in that tuple (here ' \
                                            f'{len(s)}) must be the same as the number of axes (here {len(axes)}).'
            if all([i == 1 for i in s]):
                output.append(data_dict[self.input_key])
            else:
                new_shape = np.array(data_dict[self.input_key].shape).astype(float)
                for i, a in enumerate(axes):
                    new_shape[a] *= s[i]
                new_shape = np.round(new_shape).astype(int)
                out_seg = np.zeros(new_shape, dtype=data_dict[self.input_key].dtype)
                for b in range(data_dict[self.input_key].shape[0]):
                    for c in range(data_dict[self.input_key].shape[1]):
                        out_seg[b, c] = resize_segmentation(data_dict[self.input_key][b, c], new_shape[2:], self.order)
                output.append(out_seg)
        data_dict[self.output_key] = output
        return data_dict
================================================
FILE: nnunetv2/training/data_augmentation/custom_transforms/masking.py
================================================
from typing import List
from batchgenerators.transforms.abstract_transforms import AbstractTransform
class MaskTransform(AbstractTransform):
    def __init__(self, apply_to_channels: List[int], mask_idx_in_seg: int = 0, set_outside_to: int = 0,
                 data_key: str = "data", seg_key: str = "seg"):
        """
        Sets everything outside the mask to set_outside_to (default 0). CAREFUL! Outside is defined as < 0 in the
        mask, not == 0!
        """
        self.apply_to_channels = apply_to_channels
        self.seg_key = seg_key
        self.data_key = data_key
        self.set_outside_to = set_outside_to
        self.mask_idx_in_seg = mask_idx_in_seg
    def __call__(self, **data_dict):
        mask = data_dict[self.seg_key][:, self.mask_idx_in_seg] < 0
        for c in self.apply_to_channels:
            data_dict[self.data_key][:, c][mask] = self.set_outside_to
        return data_dict
================================================
FILE: nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py
================================================
from typing import List, Tuple, Union
from batchgenerators.transforms.abstract_transforms import AbstractTransform
import numpy as np
class ConvertSegmentationToRegionsTransform(AbstractTransform):
    def __init__(self, regions: Union[List, Tuple],
                 seg_key: str = "seg", output_key: str = "seg", seg_channel: int = 0):
        """
        regions is a tuple of tuples where each inner tuple holds the class indices that are merged into one region,
        example:
        regions= ((1, 2), (2, )) will result in 2 regions: one covering the region of labels 1&2 and the other just 2
        :param regions: tuple of tuples of label values, one inner tuple per output region
        :param seg_key: key of the segmentation in data_dict
        :param output_key: key under which the region masks (one channel per region) are stored
        :param seg_channel: channel of the segmentation that is converted
        """
        self.seg_channel = seg_channel
        self.output_key = output_key
        self.seg_key = seg_key
        self.regions = regions
    def __call__(self, **data_dict):
        seg = data_dict.get(self.seg_key)
        if seg is not None:
            b, c, *shape = seg.shape
            region_output = np.zeros((b, len(self.regions), *shape), dtype=bool)
            for region_id, region_labels in enumerate(self.regions):
                region_output[:, region_id] |= np.isin(seg[:, self.seg_channel], region_labels)
            data_dict[self.output_key] = region_output.astype(np.uint8, copy=False)
        return data_dict
================================================
FILE: nnunetv2/training/data_augmentation/custom_transforms/transforms_for_dummy_2d.py
================================================
from typing import Tuple, Union, List
from batchgenerators.transforms.abstract_transforms import AbstractTransform
class Convert3DTo2DTransform(AbstractTransform):
    def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')):
        """
        Transforms a 5D array (b, c, x, y, z) to a 4D array (b, c * x, y, z) by overloading the color channel
        """
        self.apply_to_keys = apply_to_keys
    def __call__(self, **data_dict):
        for k in self.apply_to_keys:
            shp = data_dict[k].shape
            assert len(shp) == 5, 'This transform only works on 3D data, so expects 5D tensor (b, c, x, y, z) as input.'
            data_dict[k] = data_dict[k].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4]))
            shape_key = f'orig_shape_{k}'
            assert shape_key not in data_dict.keys(), f'Convert3DTo2DTransform needs to store the original shape. ' \
                                                      f'It does that using the {shape_key} key. That key is ' \
                                                      f'already taken. Bummer.'
            data_dict[shape_key] = shp
        return data_dict
class Convert2DTo3DTransform(AbstractTransform):
    def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')):
        """
        Reverts Convert3DTo2DTransform by transforming a 4D array (b, c * x, y, z) back to 5D (b, c, x, y, z)
        """
        self.apply_to_keys = apply_to_keys
    def __call__(self, **data_dict):
        for k in self.apply_to_keys:
            shape_key = f'orig_shape_{k}'
            assert shape_key in data_dict.keys(), f'Did not find key {shape_key} in data_dict. Shitty. ' \
                                                  f'Convert2DTo3DTransform only works in tandem with ' \
                                                  f'Convert3DTo2DTransform and you probably forgot to add ' \
                                                  f'Convert3DTo2DTransform to your pipeline. (Convert3DTo2DTransform ' \
                                                  f'is where the missing key is generated)'
            original_shape = data_dict[shape_key]
            current_shape = data_dict[k].shape
            data_dict[k] = data_dict[k].reshape((original_shape[0], original_shape[1], original_shape[2],
                                                 current_shape[-2], current_shape[-1]))
        return data_dict
================================================
FILE: nnunetv2/training/dataloading/__init__.py
================================================
================================================
FILE: nnunetv2/training/dataloading/data_loader.py
================================================
import os
import warnings
from typing import Union, Tuple, List
import numpy as np
import torch
from batchgenerators.dataloading.data_loader import DataLoader
from batchgenerators.utilities.file_and_folder_operations import join, load_json
from threadpoolctl import threadpool_limits
from nnunetv2.paths import nnUNet_preprocessed
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetBaseDataset
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDatasetBlosc2
from nnunetv2.utilities.label_handling.label_handling import LabelManager
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
from acvl_utils.cropping_and_padding.bounding_boxes import crop_and_pad_nd
class nnUNetDataLoader(DataLoader):
    def __init__(self,
                 data: nnUNetBaseDataset,
                 batch_size: int,
                 patch_size: Union[List[int], Tuple[int, ...], np.ndarray],
                 final_patch_size: Union[List[int], Tuple[int, ...], np.ndarray],
                 label_manager: LabelManager,
                 oversample_foreground_percent: float = 0.0,
                 sampling_probabilities: Union[List[int], Tuple[int, ...], np.ndarray] = None,
                 pad_sides: Union[List[int], Tuple[int, ...]] = None,
                 probabilistic_oversampling: bool = False,
                 transforms=None):
        """
        If we get a 2D patch size, make it pseudo 3D and remember to remove the singleton dimension before
        returning the batch
        """
        super().__init__(data, batch_size, 1, None, True,
                         False, True, sampling_probabilities)
        if len(patch_size) == 2:
            final_patch_size = (1, *final_patch_size)
            patch_size = (1, *patch_size)
            self.patch_size_was_2d = True
        else:
            self.patch_size_was_2d = False
        # this is used by DataLoader for sampling train cases!
        self.indices = data.identifiers
        self.oversample_foreground_percent = oversample_foreground_percent
        self.final_patch_size = final_patch_size
        self.patch_size = patch_size
        # need_to_pad denotes by how much we need to pad the data so that if we sample a patch of size final_patch_size
        # (which is what the network will get) these patches will also cover the border of the images
        self.need_to_pad = (np.array(patch_size) - np.array(final_patch_size)).astype(int)
        if pad_sides is not None:
            if self.patch_size_was_2d:
                pad_sides = (0, *pad_sides)
            for d in range(len(self.need_to_pad)):
                self.need_to_pad[d] += pad_sides[d]
        self.num_channels = None
        self.pad_sides = pad_sides
        self.sampling_probabilities = sampling_probabilities
        self.annotated_classes_key = tuple([-1] + label_manager.all_labels)
        self.has_ignore = label_manager.has_ignore_label
        self.get_do_oversample = self._oversample_last_XX_percent if not probabilistic_oversampling \
            else self._probabilistic_oversampling
        self.transforms = transforms
        self.data_shape, self.seg_shape = self.determine_shapes()
    def _oversample_last_XX_percent(self, sample_idx: int) -> bool:
        """
        determines whether sample sample_idx in a minibatch needs to be guaranteed foreground
        """
        return not sample_idx < round(self.batch_size * (1 - self.oversample_foreground_percent))
    def _probabilistic_oversampling(self, sample_idx: int) -> bool:
        # print('YEAH BOIIIIII')
        return np.random.uniform() < self.oversample_foreground_percent
    def determine_shapes(self):
        # load one case
        data, seg, seg_prev, properties = self._data.load_case(self._data.identifiers[0])
        num_color_channels = data.shape[0]
        if self.patch_size_was_2d:
            spatial_shape = self.final_patch_size[1:] if self.transforms is not None else self.patch_size[1:]
        else:
            spatial_shape = self.final_patch_size if self.transforms is not None else self.patch_size
        data_shape = (self.batch_size, num_color_channels, *spatial_shape)
        channels_seg = seg.shape[0]
        if seg_prev is not None:
            channels_seg += 1
        seg_shape = (self.batch_size, channels_seg, *spatial_shape)
        return data_shape, seg_shape
    def get_bbox(self, data_shape: np.ndarray, force_fg: bool, class_locations: Union[dict, None],
                 overwrite_class: Union[int, Tuple[int, ...]] = None, verbose: bool = False):
        # in dataloader 2d we need to select the slice prior to this and also modify the class_locations to only have
        # locations for the given slice
        need_to_pad = self.need_to_pad.copy()
        dim = len(data_shape)
        for d in range(dim):
            # if case_all_data.shape + need_to_pad is still < patch size we need to pad more! We pad on both sides
            # always
            if need_to_pad[d] + data_shape[d] < self.patch_size[d]:
                need_to_pad[d] = self.patch_size[d] - data_shape[d]
        # we can now choose the bbox from -need_to_pad // 2 to shape - patch_size + need_to_pad // 2. Here we
        # define what the upper and lower bound can be to then sample from them with np.random.randint
        lbs = [- need_to_pad[i] // 2 for i in range(dim)]
        ubs = [data_shape[i] + need_to_pad[i] // 2 + need_to_pad[i] % 2 - self.patch_size[i] for i in range(dim)]
        # if not force_fg then we can just sample the bbox randomly from lb and ub. Else we need to make sure we get
        # at least one of the foreground classes in the patch
        if not force_fg and not self.has_ignore:
            bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)]
            # print('I want a random location')
        else:
            if not force_fg and self.has_ignore:
                selected_class = self.annotated_classes_key
                if len(class_locations[selected_class]) == 0:
                    # no annotated pixels in this case. Not good. But we can hardly skip it here
                    warnings.warn('Warning! No annotated pixels in image!')
                    selected_class = None
            elif force_fg:
                assert class_locations is not None, 'if force_fg is set class_locations cannot be None'
                if overwrite_class is not None:
                    assert overwrite_class in class_locations.keys(), 'desired class ("overwrite_class") does not ' \
                                                                      'have class_locations (missing key)'
                # this saves us a np.unique. Preprocessing already did that for all cases. Neat.
                # class_locations keys can also be tuple
                eligible_classes_or_regions = [i for i in class_locations.keys() if len(class_locations[i]) > 0]
                # if we have annotated_classes_key locations and other classes are present, remove the annotated_classes_key from the list
                # strange formulation needed to circumvent
                # ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
                tmp = [i == self.annotated_classes_key if isinstance(i, tuple) else False for i in eligible_classes_or_regions]
                if any(tmp):
                    if len(eligible_classes_or_regions) > 1:
                        eligible_classes_or_regions.pop(np.where(tmp)[0][0])
                if len(eligible_classes_or_regions) == 0:
                    # this only happens if some image does not contain foreground voxels at all
                    selected_class = None
                    if verbose:
                        print('case does not contain any foreground classes')
                else:
                    # I hate myself. Future me aint gonna be happy to read this
                    # 2022_11_25: had to read it today. Wasn't too bad
                    selected_class = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))] if \
                        (overwrite_class is None or (overwrite_class not in eligible_classes_or_regions)) else overwrite_class
                # print(f'I want to have foreground, selected class: {selected_class}')
            else:
                raise RuntimeError('lol what!?')
            if selected_class is not None:
                voxels_of_that_class = class_locations[selected_class]
                selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))]
                # selected voxel is center voxel. Subtract half the patch size to get lower bbox voxel.
                # Make sure it is within the bounds of lb and ub
                # i + 1 because we have first dimension 0!
                bbox_lbs = [max(lbs[i], selected_voxel[i + 1] - self.patch_size[i] // 2) for i in range(dim)]
            else:
                # If the image does not contain any foreground classes, we fall back to random cropping
                bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)]
        bbox_ubs = [bbox_lbs[i] + self.patch_size[i] for i in range(dim)]
        return bbox_lbs, bbox_ubs
    def generate_train_batch(self):
        selected_keys = self.get_indices()
        # preallocate output tensors in final patch size and write transformed samples directly
        data_all = torch.empty(self.data_shape, dtype=torch.float32)
        seg_all = None
        with torch.no_grad():
            with threadpool_limits(limits=1, user_api=None):
                for j, i in enumerate(selected_keys):
                    # oversampling foreground will improve stability of model training, especially if many patches are empty
                    # (Lung for example)
                    force_fg = self.get_do_oversample(j)
                    data, seg, seg_prev, properties = self._data.load_case(i)
                    # If we are doing the cascade then the segmentation from the previous stage will already have been loaded by
                    # self._data.load_case(i) (see nnUNetDataset.load_case)
                    shape = data.shape[1:]
                    bbox_lbs, bbox_ubs = self.get_bbox(shape, force_fg, properties['class_locations'])
                    bbox = [[lb, ub] for lb, ub in zip(bbox_lbs, bbox_ubs)]
                    data_cropped = torch.from_numpy(crop_and_pad_nd(data, bbox, 0)).float()
                    seg_cropped = torch.from_numpy(crop_and_pad_nd(seg, bbox, -1)).to(torch.int16)
                    if seg_prev is not None:
                        seg_prev_cropped = torch.from_numpy(crop_and_pad_nd(seg_prev, bbox, -1)).to(torch.int16)
                        seg_cropped = torch.cat((seg_cropped, seg_prev_cropped[None]), dim=0)
                    if self.patch_size_was_2d:
                        data_cropped = data_cropped[:, 0]
                        seg_cropped = seg_cropped[:, 0]
                    if self.transforms is not None:
                        transformed = self.transforms(**{'image': data_cropped, 'segmentation': seg_cropped})
                        data_sample = transformed['image']
                        seg_sample = transformed['segmentation']
                    else:
                        data_sample = data_cropped
                        seg_sample = seg_cropped
                    data_all[j] = data_sample
                    if isinstance(seg_sample, list):
                        if seg_all is None:
                            seg_all = [torch.empty((self.batch_size, *s.shape), dtype=s.dtype) for s in seg_sample]
                        for s_idx, s in enumerate(seg_sample):
                            seg_all[s_idx][j] = s
                    else:
                        if seg_all is None:
                            seg_all = torch.empty((self.batch_size, *seg_sample.shape), dtype=seg_sample.dtype)
                        seg_all[j] = seg_sample
        return {'data': data_all, 'target': seg_all, 'keys': selected_keys}
if __name__ == '__main__':
    folder = join(nnUNet_preprocessed, 'Dataset002_Heart', 'nnUNetPlans_3d_fullres')
    ds = nnUNetDatasetBlosc2(folder)  # this should not load the properties!
    pm = PlansManager(join(folder, os.pardir, 'nnUNetPlans.json'))
    lm = pm.get_label_manager(load_json(join(folder, os.pardir, 'dataset.json')))
    dl = nnUNetDataLoader(ds, 5, (16, 16, 16), (16, 16, 16), lm,
                          0.33, None, None)
    a = next(dl)
================================================
FILE: nnunetv2/training/dataloading/nnunet_dataset.py
================================================
import os
import warnings
from abc import ABC, abstractmethod
from copy import deepcopy
from functools import lru_cache
from typing import List, Union, Type, Tuple
import numpy as np
import blosc2
import shutil
from blosc2 import Filter, Codec
from batchgenerators.utilities.file_and_folder_operations import join, load_pickle, isfile, write_pickle, subfiles
from nnunetv2.configuration import default_num_processes
from nnunetv2.training.dataloading.utils import unpack_dataset
import math
class nnUNetBaseDataset(ABC):
"""
Defines the interface
"""
def __init__(self, folder: str, identifiers: List[str] = None,
folder_with_segs_from_previous_stage: str = None):
super().__init__()
# print('loading dataset')
if identifiers is None:
identifiers = self.get_identifiers(folder)
identifiers.sort()
self.source_folder = folder
self.folder_with_segs_from_previous_stage = folder_with_segs_from_previous_stage
self.identifiers = identifiers
def __getitem__(self, identifier):
return self.load_case(identifier)
@abstractmethod
def load_case(self, identifier):
pass
@staticmethod
@abstractmethod
def save_case(
data: np.ndarray,
seg: np.ndarray,
properties: dict,
output_filename_truncated: str
):
pass
@staticmethod
@abstractmethod
def get_identifiers(folder: str) -> List[str]:
pass
@staticmethod
def unpack_dataset(folder: str, overwrite_existing: bool = False,
num_processes: int = default_num_processes,
verify: bool = True):
pass
class nnUNetDatasetNumpy(nnUNetBaseDataset):
def load_case(self, identifier):
data_npy_file = join(self.source_folder, identifier + '.npy')
if not isfile(data_npy_file):
data = np.load(join(self.source_folder, identifier + '.npz'))['data']
else:
data = np.load(data_npy_file, mmap_mode='r')
seg_npy_file = join(self.source_folder, identifier + '_seg.npy')
if not isfile(seg_npy_file):
seg = np.load(join(self.source_folder, identifier + '.npz'))['seg']
else:
seg = np.load(seg_npy_file, mmap_mode='r')
if self.folder_with_segs_from_previous_stage is not None:
prev_seg_npy_file = join(self.folder_with_segs_from_previous_stage, identifier + '.npy')
if isfile(prev_seg_npy_file):
                seg_prev = np.load(prev_seg_npy_file, mmap_mode='r')
else:
seg_prev = np.load(join(self.folder_with_segs_from_previous_stage, identifier + '.npz'))['seg']
else:
seg_prev = None
properties = load_pickle(join(self.source_folder, identifier + '.pkl'))
return data, seg, seg_prev, properties
@staticmethod
def save_case(
data: np.ndarray,
seg: np.ndarray,
properties: dict,
output_filename_truncated: str
):
np.savez_compressed(output_filename_truncated + '.npz', data=data, seg=seg)
write_pickle(properties, output_filename_truncated + '.pkl')
@staticmethod
def save_seg(
seg: np.ndarray,
output_filename_truncated: str
):
np.savez_compressed(output_filename_truncated + '.npz', seg=seg)
@staticmethod
def get_identifiers(folder: str) -> List[str]:
"""
returns all identifiers in the preprocessed data folder
"""
        case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith(".npz")]
return case_identifiers
@staticmethod
def unpack_dataset(folder: str, overwrite_existing: bool = False,
num_processes: int = default_num_processes,
verify: bool = True):
return unpack_dataset(folder, True, overwrite_existing, num_processes, verify)
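# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): it mirrors the file lookup order of
# nnUNetDatasetNumpy.load_case above without touching the disk. Unpacked '.npy' files are preferred because they
# can be memory-mapped; otherwise the array is read from the compressed '.npz' archive under the 'data' or 'seg'
# key. `existing_files` is an assumed stand-in for the folder listing.
def _sketch_numpy_case_sources(identifier: str, existing_files: set) -> dict:
    """Return (filename, npz_key) for 'data' and 'seg'; npz_key is None when a .npy file is used."""
    sources = {}
    for key, npy_name in (('data', identifier + '.npy'), ('seg', identifier + '_seg.npy')):
        if npy_name in existing_files:
            sources[key] = (npy_name, None)  # memory-mappable, no archive key needed
        else:
            sources[key] = (identifier + '.npz', key)  # fall back to the compressed archive
    return sources

# Usage: a case whose data was unpacked but whose segmentation was not:
# _sketch_numpy_case_sources('case_001', {'case_001.npz', 'case_001.npy', 'case_001.pkl'})
# -> {'data': ('case_001.npy', None), 'seg': ('case_001.npz', 'seg')}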
class nnUNetDatasetBlosc2(nnUNetBaseDataset):
def __init__(self, folder: str, identifiers: List[str] = None,
folder_with_segs_from_previous_stage: str = None):
super().__init__(folder, identifiers, folder_with_segs_from_previous_stage)
blosc2.set_nthreads(1)
def __getitem__(self, identifier):
return self.load_case(identifier)
def load_case(self, identifier):
dparams = {
'nthreads': 1
}
data_b2nd_file = join(self.source_folder, identifier + '.b2nd')
# mmap does not work with Windows -> https://github.com/MIC-DKFZ/nnUNet/issues/2723
mmap_kwargs = {} if os.name == "nt" else {'mmap_mode': 'r'}
data = blosc2.open(urlpath=data_b2nd_file, mode='r', dparams=dparams, **mmap_kwargs)
seg_b2nd_file = join(self.source_folder, identifier + '_seg.b2nd')
seg = blosc2.open(urlpath=seg_b2nd_file, mode='r', dparams=dparams, **mmap_kwargs)
if self.folder_with_segs_from_previous_stage is not None:
prev_seg_b2nd_file = join(self.folder_with_segs_from_previous_stage, identifier + '.b2nd')
seg_prev = blosc2.open(urlpath=prev_seg_b2nd_file, mode='r', dparams=dparams, **mmap_kwargs)
else:
seg_prev = None
properties = load_pickle(join(self.source_folder, identifier + '.pkl'))
return data, seg, seg_prev, properties
@staticmethod
def save_case(
data: np.ndarray,
seg: np.ndarray,
properties: dict,
output_filename_truncated: str,
chunks=None,
blocks=None,
chunks_seg=None,
blocks_seg=None,
clevel: int = 8,
codec=blosc2.Codec.ZSTD
):
blosc2.set_nthreads(1)
if chunks_seg is None:
chunks_seg = chunks
if blocks_seg is None:
blocks_seg = blocks
cparams = {
'codec': codec,
# 'filters': [blosc2.Filter.SHUFFLE],
# 'splitmode': blosc2.SplitMode.ALWAYS_SPLIT,
'clevel': clevel,
}
# print(output_filename_truncated, data.shape, seg.shape, blocks, chunks, blocks_seg, chunks_seg, data.dtype, seg.dtype)
blosc2.asarray(np.ascontiguousarray(data), urlpath=output_filename_truncated + '.b2nd', chunks=chunks,
blocks=blocks, cparams=cparams)
blosc2.asarray(np.ascontiguousarray(seg), urlpath=output_filename_truncated + '_seg.b2nd', chunks=chunks_seg,
blocks=blocks_seg, cparams=cparams)
write_pickle(properties, output_filename_truncated + '.pkl')
@staticmethod
def save_seg(
seg: np.ndarray,
output_filename_truncated: str,
chunks_seg=None,
blocks_seg=None
):
blosc2.asarray(seg, urlpath=output_filename_truncated + '.b2nd', chunks=chunks_seg, blocks=blocks_seg)
@staticmethod
def get_identifiers(folder: str) -> List[str]:
"""
returns all identifiers in the preprocessed data folder
"""
case_identifiers = [i[:-5] for i in os.listdir(folder) if i.endswith(".b2nd") and not i.endswith("_seg.b2nd")]
return case_identifiers
@staticmethod
def unpack_dataset(folder: str, overwrite_existing: bool = False,
num_processes: int = default_num_processes,
verify: bool = True):
pass
@staticmethod
def comp_blosc2_params(
image_size: Tuple[int, int, int, int],
patch_size: Union[Tuple[int, int], Tuple[int, int, int]],
            bytes_per_pixel: int = 4,  # 4 bytes per element, e.g. float32
            l1_cache_size_per_core_in_bytes=32768,  # 1 Kibibyte (KiB) = 2^10 Byte; 32 KiB = 32768 Byte
            l3_cache_size_per_core_in_bytes=1441792,
            # 1 Mebibyte (MiB) = 2^20 Byte = 1,048,576 Byte; 1.375 MiB = 1441792 Byte
            safety_factor: float = 0.8  # we don't fill the caches to the brim; 0.8 means we target 80% of the caches
):
"""
Computes a recommended block and chunk size for saving arrays with blosc v2.
        Blosc2 NDim documentation: "Remember that having a second partition means that we have better flexibility to fit the
        different partitions at the different CPU cache levels; typically the first partition (aka chunks) should
        be made to fit in L3 cache, whereas the second partition (aka blocks) should rather fit in L2/L1 caches
        (depending on whether compression ratio or speed is desired)."
        (https://www.blosc.org/posts/blosc2-ndim-intro/)
        -> We are not 100% sure how to optimize for that. For now we try to fit the uncompressed block in L1. This
        might spill over into L2, which is fine in our books.
        Note: this is optimized for nnU-Net dataloading, where each read operation is done by a single core. We cannot
        use threading.
        Cache default values are based on an old Intel 4110 CPU with 32K L1, 128K L2 and 1408K L3 cache per core.
        We cannot optimize further for more modern CPUs with more cache, as the data will need to be read by the
        old ones as well.
        Args:
            image_size: Image size, must be 4D (c, x, y, z). For 2D images, make x=1
            patch_size: Patch size, spatial dimensions only. So (x, y) or (x, y, z)
            bytes_per_pixel: Number of bytes per element. Example: float32 -> 4 bytes
            l1_cache_size_per_core_in_bytes: The size of the L1 cache per core in Bytes.
            l3_cache_size_per_core_in_bytes: The size of the L3 cache exclusively accessible by each core. Usually the global size of the L3 cache divided by the number of cores.
            safety_factor: Fraction of each cache size that blocks and chunks are allowed to occupy.
Returns:
The recommended block and the chunk size.
"""
        # Fabian's code is ugly, but eh
num_channels = image_size[0]
if len(patch_size) == 2:
patch_size = [1, *patch_size]
patch_size = np.array(patch_size)
block_size = np.array((num_channels, *[2 ** (max(0, math.ceil(math.log2(i)))) for i in patch_size]))
# shrink the block size until it fits in L1
estimated_nbytes_block = np.prod(block_size) * bytes_per_pixel
while estimated_nbytes_block > (l1_cache_size_per_core_in_bytes * safety_factor):
# pick largest deviation from patch_size that is not 1
axis_order = np.argsort(block_size[1:] / patch_size)[::-1]
idx = 0
picked_axis = axis_order[idx]
            while block_size[picked_axis + 1] == 1:
idx += 1
picked_axis = axis_order[idx]
# now reduce that axis to the next lowest power of 2
block_size[picked_axis + 1] = 2 ** (max(0, math.floor(math.log2(block_size[picked_axis + 1] - 1))))
block_size[picked_axis + 1] = min(block_size[picked_axis + 1], image_size[picked_axis + 1])
estimated_nbytes_block = np.prod(block_size) * bytes_per_pixel
block_size = np.array([min(i, j) for i, j in zip(image_size, block_size)])
# note: there is no use extending the chunk size to 3d when we have a 2d patch size! This would unnecessarily
# load data into L3
# now tile the blocks into chunks until we hit image_size or the l3 cache per core limit
chunk_size = deepcopy(block_size)
estimated_nbytes_chunk = np.prod(chunk_size) * bytes_per_pixel
while estimated_nbytes_chunk < (l3_cache_size_per_core_in_bytes * safety_factor):
if patch_size[0] == 1 and all([i == j for i, j in zip(chunk_size[2:], image_size[2:])]):
break
if all([i == j for i, j in zip(chunk_size, image_size)]):
break
            # pick the axis whose chunk has grown the least relative to block_size (smallest chunk/block ratio)
axis_order = np.argsort(chunk_size[1:] / block_size[1:])
idx = 0
picked_axis = axis_order[idx]
while chunk_size[picked_axis + 1] == image_size[picked_axis + 1] or patch_size[picked_axis] == 1:
idx += 1
picked_axis = axis_order[idx]
chunk_size[picked_axis + 1] += block_size[picked_axis + 1]
chunk_size[picked_axis + 1] = min(chunk_size[picked_axis + 1], image_size[picked_axis + 1])
estimated_nbytes_chunk = np.prod(chunk_size) * bytes_per_pixel
if np.mean([i / j for i, j in zip(chunk_size[1:], patch_size)]) > 1.5:
# chunk size should not exceed patch size * 1.5 on average
chunk_size[picked_axis + 1] -= block_size[picked_axis + 1]
break
# better safe than sorry
chunk_size = [min(i, j) for i, j in zip(image_size, chunk_size)]
# print(image_size, chunk_size, block_size)
return tuple(block_size), tuple(chunk_size)
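import math


# Illustrative sketch (hypothetical helper, simplified): the first phase of comp_blosc2_params above. The block
# starts as the patch size rounded up to powers of two; then the spatial axis whose block is largest relative to
# the patch is halved until the uncompressed block fits into safety_factor * L1. Unlike the method, this sketch
# ignores image_size clamping and the 2D special case, and breaks ties towards the last axis, so its result may
# differ from comp_blosc2_params for some inputs.
def _sketch_l1_block_size(num_channels, patch_size, bytes_per_pixel=4, l1_cache_bytes=32768, safety_factor=0.8):
    block = [2 ** max(0, math.ceil(math.log2(p))) for p in patch_size]
    budget = l1_cache_bytes * safety_factor
    while num_channels * math.prod(block) * bytes_per_pixel > budget:
        shrinkable = [a for a in range(len(block)) if block[a] > 1]
        if not shrinkable:
            break  # even one voxel per channel does not fit; give up
        # largest block/patch ratio first, ties go to the later axis
        axis = max(shrinkable, key=lambda a: (block[a] / patch_size[a], a))
        block[axis] //= 2  # next lower power of two
    return (num_channels, *block)

# Usage: _sketch_l1_block_size(1, (20, 160, 112)) -> (1, 4, 32, 32), i.e. 4 * 4 * 32 * 32 = 16384 bytes,
# which is below 0.8 * 32768 = 26214.4 bytes.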
file_ending_dataset_mapping = {
'npz': nnUNetDatasetNumpy,
'b2nd': nnUNetDatasetBlosc2
}
def infer_dataset_class(folder: str) -> Union[Type[nnUNetDatasetBlosc2], Type[nnUNetDatasetNumpy]]:
file_endings = set([os.path.basename(i).split('.')[-1] for i in subfiles(folder, join=False)])
if 'pkl' in file_endings:
file_endings.remove('pkl')
if 'npy' in file_endings:
file_endings.remove('npy')
    assert len(file_endings) == 1, (f'Expected exactly one data file ending (besides .pkl and .npy) in the folder '
                                    f'{folder}, found {file_endings}. Unable to infer nnUNetDataset variant!')
return file_ending_dataset_mapping[list(file_endings)[0]]
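# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the rule used by infer_dataset_class above,
# applied to a plain list of file names instead of a folder. '.pkl' (properties) and '.npy' (unpacked arrays) are
# ignored; exactly one remaining file ending must be left, and it selects the dataset class.
def _sketch_infer_file_ending(filenames) -> str:
    endings = {name.split('.')[-1] for name in filenames}
    endings -= {'pkl', 'npy'}
    if len(endings) != 1:
        raise ValueError(f'expected exactly one data file ending, found {sorted(endings)}')
    return endings.pop()

# Usage: _sketch_infer_file_ending(['case_001.b2nd', 'case_001_seg.b2nd', 'case_001.pkl']) -> 'b2nd', which
# file_ending_dataset_mapping maps to nnUNetDatasetBlosc2.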
================================================
FILE: nnunetv2/training/dataloading/utils.py
================================================
from __future__ import annotations
import multiprocessing
import os
from typing import List
from pathlib import Path
from warnings import warn
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import isfile, subfiles
from nnunetv2.configuration import default_num_processes
def _convert_to_npy(npz_file: str, unpack_segmentation: bool = True, overwrite_existing: bool = False,
verify_npy: bool = False, fail_ctr: int = 0) -> None:
data_npy = npz_file[:-3] + "npy"
seg_npy = npz_file[:-4] + "_seg.npy"
try:
npz_content = None # will only be opened on demand
if overwrite_existing or not isfile(data_npy):
try:
npz_content = np.load(npz_file) if npz_content is None else npz_content
except Exception as e:
print(f"Unable to open preprocessed file {npz_file}. Rerun nnUNetv2_preprocess!")
raise e
np.save(data_npy, npz_content['data'])
if unpack_segmentation and (overwrite_existing or not isfile(seg_npy)):
try:
npz_content = np.load(npz_file) if npz_content is None else npz_content
except Exception as e:
print(f"Unable to open preprocessed file {npz_file}. Rerun nnUNetv2_preprocess!")
raise e
            np.save(seg_npy, npz_content['seg'])
if verify_npy:
try:
np.load(data_npy, mmap_mode='r')
if isfile(seg_npy):
np.load(seg_npy, mmap_mode='r')
except ValueError:
                os.remove(data_npy)
                if isfile(seg_npy):
                    os.remove(seg_npy)
print(f"Error when checking {data_npy} and {seg_npy}, fixing...")
if fail_ctr < 2:
_convert_to_npy(npz_file, unpack_segmentation, overwrite_existing, verify_npy, fail_ctr+1)
else:
raise RuntimeError("Unable to fix unpacking. Please check your system or rerun nnUNetv2_preprocess")
except KeyboardInterrupt:
if isfile(data_npy):
os.remove(data_npy)
if isfile(seg_npy):
os.remove(seg_npy)
raise KeyboardInterrupt
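# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the file names _convert_to_npy above derives
# from an '.npz' path. These are the same names nnUNetDatasetNumpy.load_case checks for before falling back to the
# '.npz' archive.
def _sketch_unpacked_npy_names(npz_file: str):
    data_npy = npz_file[:-3] + 'npy'  # 'case.npz' -> 'case.npy'
    seg_npy = npz_file[:-4] + '_seg.npy'  # 'case.npz' -> 'case_seg.npy'
    return data_npy, seg_npy

# Usage: _sketch_unpacked_npy_names('/data/case_001.npz') -> ('/data/case_001.npy', '/data/case_001_seg.npy')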
def unpack_dataset(folder: str, unpack_segmentation: bool = True, overwrite_existing: bool = False,
num_processes: int = default_num_processes,
verify: bool = False):
"""
all npz files in this folder belong to the dataset, unpack them all
"""
with multiprocessing.get_context("spawn").Pool(num_processes) as p:
npz_files = subfiles(folder, True, None, ".npz", True)
p.starmap(_convert_to_npy, zip(npz_files,
[unpack_segmentation] * len(npz_files),
[overwrite_existing] * len(npz_files),
[verify] * len(npz_files))
)
if __name__ == '__main__':
unpack_dataset('/media/fabian/data/nnUNet_preprocessed/Dataset002_Heart/2d')
================================================
FILE: nnunetv2/training/logging/__init__.py
================================================
================================================
FILE: nnunetv2/training/logging/nnunet_logger.py
================================================
import matplotlib
from batchgenerators.utilities.file_and_folder_operations import join
matplotlib.use('agg')
import seaborn as sns
import matplotlib.pyplot as plt
from typing import Any
from pathlib import Path
import shutil
import os
try:
import wandb
except ImportError:
wandb = None
def get_cluster_job_id():
job_id = None
if "LSB_JOBID" in os.environ:
job_id = os.environ["LSB_JOBID"]
if "SLURM_JOB_ID" in os.environ:
job_id = os.environ["SLURM_JOB_ID"]
return job_id
class MetaLogger(object):
"""A meta logger that bundles multiple loggers behind a single interface.
The default configuration includes a local logger used for reading values,
plotting progress, and checkpointing.
"""
def __init__(self, output_folder, resume, verbose: bool = False):
"""Initialize the meta logger.
Args:
output_folder: The output folder.
resume: Whether to resume training if possible.
verbose: Whether to enable verbose logging in the local logger.
"""
self.output_folder = output_folder
self.resume = resume
self.loggers = []
self.local_logger = LocalLogger(verbose)
if self._is_logger_enabled("nnUNet_wandb_enabled"):
self.loggers.append(WandbLogger(output_folder, resume))
def update_config(self, config: dict):
"""Add a new or update an existing experiment configuration to the logger.
Args:
config: Logger configuration options.
"""
for logger in self.loggers:
logger.update_config(config)
def log(self, key: str, value: Any, step: int):
"""Log a value for a given step.
Args:
key: Metric or field name.
value: Value to log.
step: Step index (typically epoch).
"""
self.local_logger.log(key, value, step)
if isinstance(value, list):
for i, val in enumerate(value):
for logger in self.loggers:
logger.log(f"{key}/class_{i+1}", val, step)
else:
for logger in self.loggers:
logger.log(key, value, step)
# handle the ema_fg_dice special case! It is automatically logged when we add a new mean_fg_dice
if key == 'mean_fg_dice':
new_ema_pseudo_dice = self.get_value('ema_fg_dice', step=step-1) * 0.9 + 0.1 * value \
if len(self.get_value('ema_fg_dice', step=None)) > 0 else value
self.log('ema_fg_dice', new_ema_pseudo_dice, step)
def log_summary(self, key: str, value: Any):
"""Log a summary value. These are usually values that are not logged every step but only once.
This can be for example the final validation Dice.
Args:
key: Metric or field name.
value: Value to summarize.
"""
for logger in self.loggers:
logger.log_summary(key, value)
def get_value(self, key: str, step: Any):
"""Fetch a logged value from the local logger.
Args:
key: Metric or field name.
step: Step index to retrieve, or None to return all values.
Returns:
The logged value or list of values from the local logger.
"""
return self.local_logger.get_value(key, step)
def plot_progress_png(self, output_folder: str):
"""Write a progress plot PNG using local logger data.
Args:
output_folder: Directory where the plot image is saved.
"""
self.local_logger.plot_progress_png(output_folder)
def get_checkpoint(self):
"""Return the local logger checkpoint data.
Returns:
The checkpoint payload used to restore logging state.
"""
return self.local_logger.get_checkpoint()
def load_checkpoint(self, checkpoint: dict):
"""Restore the local logger from a checkpoint payload.
Args:
checkpoint: Checkpoint data returned by `get_checkpoint`.
"""
self.local_logger.load_checkpoint(checkpoint)
def _is_logger_enabled(self, env_var):
env_var_result = str(os.getenv(env_var, "0"))
if env_var_result in ("0", "False", "false"):
return False
elif env_var_result in ("1", "True", "true"):
return True
else:
            raise RuntimeError(f"nnU-Net logger environment variable {env_var} has the wrong value '{env_var_result}'. "
                               f"Must be '0'/'False'/'false' (disabled) or '1'/'True'/'true' (enabled).")
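# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the 'ema_fg_dice' series that MetaLogger.log
# derives whenever 'mean_fg_dice' is logged. The first epoch uses the raw value; afterwards
# ema[t] = 0.9 * ema[t - 1] + 0.1 * mean_fg_dice[t].
def _sketch_ema_fg_dice(mean_fg_dice_per_epoch):
    ema = []
    for value in mean_fg_dice_per_epoch:
        # conditional expression: the first entry has no predecessor, so it is taken as-is
        ema.append(ema[-1] * 0.9 + 0.1 * value if ema else value)
    return ema

# Usage: _sketch_ema_fg_dice([0.5, 0.7]) -> [0.5, 0.52] (up to floating point rounding)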
class LocalLogger:
"""
This class is really trivial. Don't expect cool functionality here. This is my makeshift solution to problems
arising from out-of-sync epoch numbers and numbers of logged loss values. It also simplifies the trainer class a
little
YOU MUST LOG EXACTLY ONE VALUE PER EPOCH FOR EACH OF THE LOGGING ITEMS! DONT FUCK IT UP
"""
def __init__(self, verbose: bool = False):
self.my_fantastic_logging = {
'mean_fg_dice': list(),
'ema_fg_dice': list(),
'dice_per_class_or_region': list(),
'train_losses': list(),
'val_losses': list(),
'lrs': list(),
'epoch_start_timestamps': list(),
'epoch_end_timestamps': list()
}
self.verbose = verbose
# shut up, this logging is great
def log(self, key, value, epoch: int):
"""
sometimes shit gets messed up. We try to catch that here
"""
assert key in self.my_fantastic_logging.keys() and isinstance(self.my_fantastic_logging[key], list), \
'This function is only intended to log stuff to lists and to have one entry per epoch'
if self.verbose: print(f'logging {key}: {value} for epoch {epoch}')
if len(self.my_fantastic_logging[key]) < (epoch + 1):
self.my_fantastic_logging[key].append(value)
else:
assert len(self.my_fantastic_logging[key]) == (epoch + 1), 'something went horribly wrong. My logging ' \
'lists length is off by more than 1'
print(f'maybe some logging issue!? logging {key} and {value}')
self.my_fantastic_logging[key][epoch] = value
def get_value(self, key, step):
if step is not None:
return self.my_fantastic_logging[key][step]
else:
return self.my_fantastic_logging[key]
def plot_progress_png(self, output_folder):
        # we infer the epoch from our internal logging
epoch = min([len(i) for i in self.my_fantastic_logging.values()]) - 1 # lists of epoch 0 have len 1
sns.set(font_scale=2.5)
fig, ax_all = plt.subplots(3, 1, figsize=(30, 54))
# regular progress.png as we are used to from previous nnU-Net versions
ax = ax_all[0]
ax2 = ax.twinx()
x_values = list(range(epoch + 1))
ax.plot(x_values, self.my_fantastic_logging['train_losses'][:epoch + 1], color='b', ls='-', label="loss_tr", linewidth=4)
ax.plot(x_values, self.my_fantastic_logging['val_losses'][:epoch + 1], color='r', ls='-', label="loss_val", linewidth=4)
ax2.plot(x_values, self.my_fantastic_logging['mean_fg_dice'][:epoch + 1], color='g', ls='dotted', label="pseudo dice",
linewidth=3)
ax2.plot(x_values, self.my_fantastic_logging['ema_fg_dice'][:epoch + 1], color='g', ls='-', label="pseudo dice (mov. avg.)",
linewidth=4)
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax2.set_ylabel("pseudo dice")
ax.legend(loc=(0, 1))
ax2.legend(loc=(0.2, 1))
# epoch times to see whether the training speed is consistent (inconsistent means there are other jobs
# clogging up the system)
ax = ax_all[1]
ax.plot(x_values, [i - j for i, j in zip(self.my_fantastic_logging['epoch_end_timestamps'][:epoch + 1],
self.my_fantastic_logging['epoch_start_timestamps'])][:epoch + 1], color='b',
ls='-', label="epoch duration", linewidth=4)
ylim = [0] + [ax.get_ylim()[1]]
ax.set(ylim=ylim)
ax.set_xlabel("epoch")
ax.set_ylabel("time [s]")
ax.legend(loc=(0, 1))
# learning rate
ax = ax_all[2]
ax.plot(x_values, self.my_fantastic_logging['lrs'][:epoch + 1], color='b', ls='-', label="learning rate", linewidth=4)
ax.set_xlabel("epoch")
ax.set_ylabel("learning rate")
ax.legend(loc=(0, 1))
plt.tight_layout()
fig.savefig(join(output_folder, "progress.png"))
plt.close()
def get_checkpoint(self):
return self.my_fantastic_logging
def load_checkpoint(self, checkpoint: dict):
self.my_fantastic_logging = checkpoint
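# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the one-value-per-epoch rule enforced by
# LocalLogger.log above, applied to a single list. Logging the next epoch appends, logging the current epoch again
# overwrites it, and logging an older epoch trips the assertion. Skipping an epoch is not caught: the value is
# simply appended at the next free index.
def _sketch_log_one_per_epoch(values: list, epoch: int, value):
    if len(values) < epoch + 1:
        values.append(value)
    else:
        assert len(values) == epoch + 1, 'logging list length is off by more than 1'
        values[epoch] = value  # re-logging the current epoch replaces the previous value
    return values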
class WandbLogger:
"""Weights & Biases logger for nnU-Net training runs.
Environment Variables:
nnUNet_wandb_enabled: Whether W&B logger is enabled (default: 0 -> Disabled)
nnUNet_wandb_project: W&B project name (default: "nnunet").
nnUNet_wandb_mode: W&B mode, e.g. "online" or "offline" (default: "online").
"""
def __init__(self, output_folder, resume):
"""Initialize a W&B run and handle resume behavior.
Args:
output_folder: Directory where W&B run data is stored.
resume: Whether to resume a previous W&B run if present.
"""
if wandb is None:
raise RuntimeError("W&B is not installed. Please install W&B with 'pip install wandb' before using the WandbLogger.")
self.output_folder = Path(output_folder)
self.resume = resume
self.project = os.getenv("nnUNet_wandb_project", "nnunet")
self.mode = os.getenv("nnUNet_wandb_mode", "online")
wandb_id = None
if (self.output_folder / "wandb").is_dir():
if self.resume:
wandb_dir = self.output_folder / "wandb" / "latest-run"
wandb_filename = next(filename for filename in wandb_dir.iterdir() if filename.suffix == ".wandb")
wandb_id = wandb_filename.stem[4:]
else:
shutil.rmtree(str(self.output_folder / "wandb"))
_resume = "allow" if self.resume else "never"
self.run = wandb.init(project=self.project, dir=str(self.output_folder), id=wandb_id, mode=self.mode, resume=_resume)
self.run.config.update({"JobID": get_cluster_job_id()})
self.wandb_init_step = self.run.step
def update_config(self, config: dict):
"""Update W&B config with training metadata.
Args:
config: Configuration values to merge into the run config.
"""
self.run.config.update(config)
def log(self, key, value, step: int):
"""Log a scalar value to W&B.
Args:
key: Metric or field name.
value: Value to log.
step: Step index (typically epoch).
"""
self.log_summary("current_epoch", step)
if self.resume and step < self.wandb_init_step:
return
self.run.log({key: value}, step=step)
def log_summary(self, key, value):
"""Write a summary value to W&B.
Args:
key: Metric or field name.
value: Summary value to store.
"""
self.run.summary[key] = value
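from pathlib import Path


# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): how WandbLogger.__init__ recovers the run id
# when resuming. It assumes, as the `stem[4:]` slice implies, that W&B names its binary log 'run-<id>.wandb' inside
# 'wandb/latest-run'; the id is the file stem without the 4-character 'run-' prefix.
def _sketch_wandb_run_id(wandb_filename) -> str:
    stem = Path(wandb_filename).stem  # 'run-3f9k2xq1.wandb' -> 'run-3f9k2xq1'
    return stem[4:]  # drop 'run-'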
================================================
FILE: nnunetv2/training/loss/__init__.py
================================================
================================================
FILE: nnunetv2/training/loss/compound_losses.py
================================================
import torch
from nnunetv2.training.loss.dice import SoftDiceLoss, MemoryEfficientSoftDiceLoss
from nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss, TopKLoss
from nnunetv2.utilities.helpers import softmax_helper_dim1
from torch import nn
class DC_and_CE_loss(nn.Module):
def __init__(self, soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label=None,
dice_class=SoftDiceLoss):
"""
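        Weights for CE and Dice do not need to sum to one. You can set whatever you want.
        :param soft_dice_kwargs: kwargs passed to dice_class
        :param ce_kwargs: kwargs passed to RobustCrossEntropyLoss
        :param weight_ce: weight of the cross-entropy term
        :param weight_dice: weight of the Dice term
        :param ignore_label: label value excluded from both loss terms (None to disable)
        :param dice_class: Dice loss implementation, e.g. SoftDiceLoss or MemoryEfficientSoftDiceLoss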
"""
super(DC_and_CE_loss, self).__init__()
if ignore_label is not None:
ce_kwargs['ignore_index'] = ignore_label
self.weight_dice = weight_dice
self.weight_ce = weight_ce
self.ignore_label = ignore_label
self.ce = RobustCrossEntropyLoss(**ce_kwargs)
self.dc = dice_class(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs)
def forward(self, net_output: torch.Tensor, target: torch.Tensor):
"""
target must be b, c, x, y(, z) with c=1
:param net_output:
:param target:
:return:
"""
if self.ignore_label is not None:
assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \
'(DC_and_CE_loss)'
mask = target != self.ignore_label
# remove ignore label from target, replace with one of the known labels. It doesn't matter because we
# ignore gradients in those areas anyway
target_dice = torch.where(mask, target, 0)
num_fg = mask.sum()
else:
target_dice = target
mask = None
dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \
if self.weight_dice != 0 else 0
ce_loss = self.ce(net_output, target[:, 0]) \
if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0
result = self.weight_ce * ce_loss + self.weight_dice * dc_loss
return result
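# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the ignore-label handling of
# DC_and_CE_loss.forward on a flat list of labels. Ignored voxels are masked out, and their target is replaced by 0
# for the Dice term (the mask removes them anyway). The CE term is skipped when no voxel is valid. Note that
# `num_fg` in forward counts all non-ignored voxels, background included, not only foreground.
def _sketch_ignore_label_targets(target, ignore_label):
    mask = [t != ignore_label for t in target]
    target_dice = [t if valid else 0 for t, valid in zip(target, mask)]
    num_valid = sum(mask)  # the quantity forward() calls num_fg
    return mask, target_dice, num_valid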
class DC_and_BCE_loss(nn.Module):
def __init__(self, bce_kwargs, soft_dice_kwargs, weight_ce=1, weight_dice=1, use_ignore_label: bool = False,
dice_class=MemoryEfficientSoftDiceLoss):
"""
DO NOT APPLY NONLINEARITY IN YOUR NETWORK!
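        target must be one hot encoded
        IMPORTANT: if use_ignore_label is True, we assume the ignore channel is located in target[:, -1]!!!
        :param bce_kwargs: kwargs passed to nn.BCEWithLogitsLoss
        :param soft_dice_kwargs: kwargs passed to dice_class
        :param weight_ce: weight of the BCE term
        :param weight_dice: weight of the Dice term
        :param use_ignore_label: whether target[:, -1] marks voxels to ignore
        :param dice_class: Dice loss implementation (defaults to MemoryEfficientSoftDiceLoss)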
"""
super(DC_and_BCE_loss, self).__init__()
if use_ignore_label:
bce_kwargs['reduction'] = 'none'
self.weight_dice = weight_dice
self.weight_ce = weight_ce
self.use_ignore_label = use_ignore_label
self.ce = nn.BCEWithLogitsLoss(**bce_kwargs)
self.dc = dice_class(apply_nonlin=torch.sigmoid, **soft_dice_kwargs)
def forward(self, net_output: torch.Tensor, target: torch.Tensor):
if self.use_ignore_label:
# target is one hot encoded here. invert it so that it is True wherever we can compute the loss
if target.dtype == torch.bool:
mask = ~target[:, -1:]
else:
mask = (1 - target[:, -1:]).bool()
# remove ignore channel now that we have the mask
# why did we use clone in the past? Should have documented that...
# target_regions = torch.clone(target[:, :-1])
target_regions = target[:, :-1]
else:
target_regions = target
mask = None
dc_loss = self.dc(net_output, target_regions, loss_mask=mask)
target_regions = target_regions.float()
if mask is not None:
ce_loss = (self.ce(net_output, target_regions) * mask).sum() / torch.clip(mask.sum(), min=1e-8)
else:
ce_loss = self.ce(net_output, target_regions)
result = self.weight_ce * ce_loss + self.weight_dice * dc_loss
return result
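import math


# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the masked BCE mean used by DC_and_BCE_loss
# when use_ignore_label is True, reduced to one channel of flat lists. Per-element BCE-with-logits is summed over the
# valid voxels only and divided by their count (clipped to 1e-8 so an all-ignored patch gives 0 instead of NaN).
def _sketch_masked_bce_with_logits(logits, targets, mask):
    # numerically stable BCE with logits: max(x, 0) - x * y + log(1 + exp(-|x|))
    losses = [max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x))) for x, y in zip(logits, targets)]
    return sum(loss * m for loss, m in zip(losses, mask)) / max(sum(mask), 1e-8)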
class DC_and_topk_loss(nn.Module):
def __init__(self, soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label=None):
"""
Weights for CE and Dice do not need to sum to one. You can set whatever you want.
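        :param soft_dice_kwargs: kwargs passed to SoftDiceLoss
        :param ce_kwargs: kwargs passed to TopKLoss
        :param weight_ce: weight of the top-k cross-entropy term
        :param weight_dice: weight of the Dice term
        :param ignore_label: label value excluded from both loss terms (None to disable)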
"""
super().__init__()
if ignore_label is not None:
ce_kwargs['ignore_index'] = ignore_label
self.weight_dice = weight_dice
self.weight_ce = weight_ce
self.ignore_label = ignore_label
self.ce = TopKLoss(**ce_kwargs)
self.dc = SoftDiceLoss(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs)
def forward(self, net_output: torch.Tensor, target: torch.Tensor):
"""
target must be b, c, x, y(, z) with c=1
:param net_output:
:param target:
:return:
"""
if self.ignore_label is not None:
            assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \
                                         '(DC_and_topk_loss)'
mask = (target != self.ignore_label).bool()
# remove ignore label from target, replace with one of the known labels. It doesn't matter because we
# ignore gradients in those areas anyway
target_dice = torch.clone(target)
target_dice[target == self.ignore_label] = 0
num_fg = mask.sum()
else:
target_dice = target
mask = None
dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \
if self.weight_dice != 0 else 0
ce_loss = self.ce(net_output, target) \
if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0
result = self.weight_ce * ce_loss + self.weight_dice * dc_loss
return result
================================================
FILE: nnunetv2/training/loss/deep_supervision.py
================================================
from torch import nn
class DeepSupervisionWrapper(nn.Module):
def __init__(self, loss, weight_factors=None):
"""
Wraps a loss function so that it can be applied to multiple outputs. Forward accepts an arbitrary number of
inputs. Each input is expected to be a tuple/list. Each tuple/list must have the same length. The loss is then
applied to each entry like this:
l = w0 * loss(input0[0], input1[0], ...) + w1 * loss(input0[1], input1[1], ...) + ...
If weights are None, all w will be 1.
"""
super(DeepSupervisionWrapper, self).__init__()
        if weight_factors is not None:
            assert any([x != 0 for x in weight_factors]), "At least one weight factor should be != 0.0"
            weight_factors = tuple(weight_factors)
        self.weight_factors = weight_factors
self.loss = loss
def forward(self, *args):
assert all([isinstance(i, (tuple, list)) for i in args]), \
f"all args must be either tuple or list, got {[type(i) for i in args]}"
# we could check for equal lengths here as well, but we really shouldn't overdo it with checks because
# this code is executed a lot of times!
if self.weight_factors is None:
weights = (1, ) * len(args[0])
else:
weights = self.weight_factors
return sum([weights[i] * self.loss(*inputs) for i, inputs in enumerate(zip(*args)) if weights[i] != 0.0])
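# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the weighted sum computed by
# DeepSupervisionWrapper.forward, with one output/target pair per resolution and scalar stand-ins for tensors.
# Entries whose weight is 0 are skipped entirely, so their loss is never evaluated.
def _sketch_deep_supervision_loss(loss_fn, outputs, targets, weights=None):
    if weights is None:
        weights = (1,) * len(outputs)
    return sum(w * loss_fn(o, t) for w, o, t in zip(weights, outputs, targets) if w != 0)

# Usage: with loss_fn = lambda o, t: abs(o - t), outputs (1.0, 2.0, 3.0), targets (0.0, 0.0, 0.0) and
# weights (0.5, 0.25, 0.0), the result is 0.5 * 1.0 + 0.25 * 2.0 = 1.0.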
================================================
FILE: nnunetv2/training/loss/dice.py
================================================
from typing import Callable
import torch
from nnunetv2.utilities.ddp_allgather import AllGatherGrad
from torch import nn
class SoftDiceLoss(nn.Module):
def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1.,
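                 ddp: bool = True, clip_tp: float = None):
        """
        Soft Dice loss computed from soft tp/fp/fn (see get_tp_fp_fn_tn). forward returns the negative mean Dice.
        batch_dice: pool tp/fp/fn over the batch (and across DDP workers if ddp=True) instead of per sample
        do_bg: include the background class (channel 0) in the mean
        smooth: added to numerator and denominator to stabilize classes with empty predictions and ground truth
        clip_tp: if not None, tp is clipped to at least this value before computing the Dice
        """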
super(SoftDiceLoss, self).__init__()
self.do_bg = do_bg
self.batch_dice = batch_dice
self.apply_nonlin = apply_nonlin
self.smooth = smooth
self.clip_tp = clip_tp
self.ddp = ddp
def forward(self, x, y, loss_mask=None):
shp_x = x.shape
if self.batch_dice:
axes = [0] + list(range(2, len(shp_x)))
else:
axes = list(range(2, len(shp_x)))
if self.apply_nonlin is not None:
x = self.apply_nonlin(x)
tp, fp, fn, _ = get_tp_fp_fn_tn(x, y, axes, loss_mask, False)
if self.ddp and self.batch_dice:
tp = AllGatherGrad.apply(tp).sum(0, dtype=torch.float32)
fp = AllGatherGrad.apply(fp).sum(0, dtype=torch.float32)
fn = AllGatherGrad.apply(fn).sum(0, dtype=torch.float32)
if self.clip_tp is not None:
            tp = torch.clip(tp, min=self.clip_tp, max=None)
        numerator = 2 * tp
        denominator = 2 * tp + fp + fn
        dc = (numerator + self.smooth) / (torch.clip(denominator + self.smooth, 1e-8))
if not self.do_bg:
if self.batch_dice:
dc = dc[1:]
else:
dc = dc[:, 1:]
dc = dc.mean()
return -dc
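# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the per-class soft Dice that SoftDiceLoss
# computes from soft tp/fp/fn, written for one class as flat lists of probabilities and a binary ground truth:
# dice = (2 * tp + smooth) / (2 * tp + fp + fn + smooth). The loss returns the negative mean of these values.
def _sketch_soft_dice(probs, onehot, smooth=1.0):
    tp = sum(p * g for p, g in zip(probs, onehot))
    fp = sum(p * (1 - g) for p, g in zip(probs, onehot))
    fn = sum((1 - p) * g for p, g in zip(probs, onehot))
    return (2 * tp + smooth) / max(2 * tp + fp + fn + smooth, 1e-8)

# Usage: _sketch_soft_dice([0.5, 0.5], [1, 0], smooth=0.0) -> 1.0 / 2.0 = 0.5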
class MemoryEfficientSoftDiceLoss(nn.Module):
def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1.,
ddp: bool = True):
"""
saves 1.6 GB on Dataset017 3d_lowres
"""
super(MemoryEfficientSoftDiceLoss, self).__init__()
self.do_bg = do_bg
self.batch_dice = batch_dice
self.apply_nonlin = apply_nonlin
self.smooth = smooth
self.ddp = ddp
def forward(self, x, y, loss_mask=None):
if self.apply_nonlin is not None:
x = self.apply_nonlin(x)
# make everything shape (b, c)
axes = tuple(range(2, x.ndim))
with torch.no_grad():
if x.ndim != y.ndim:
y = y.view((y.shape[0], 1, *y.shape[1:]))
if x.shape == y.shape:
# if this is the case then gt is probably already a one hot encoding
y_onehot = y.to(torch.float32)
else:
y_onehot = torch.zeros(x.shape, device=x.device, dtype=torch.float32)
y_onehot.scatter_(1, y.long(), 1)
if not self.do_bg:
y_onehot = y_onehot[:, 1:]
sum_gt = y_onehot.sum(axes, dtype=torch.float32) if loss_mask is None else (y_onehot * loss_mask).sum(axes, dtype=torch.float32)
# this one MUST be outside the with torch.no_grad(): context. Otherwise no gradients for you
if not self.do_bg:
x = x[:, 1:]
if loss_mask is None:
intersect = (x * y_onehot).sum(axes, dtype=torch.float32)
sum_pred = x.sum(axes, dtype=torch.float32)
else:
intersect = (x * y_onehot * loss_mask).sum(axes, dtype=torch.float32)
sum_pred = (x * loss_mask).sum(axes, dtype=torch.float32)
if self.batch_dice:
if self.ddp:
intersect = AllGatherGrad.apply(intersect).sum(0, dtype=torch.float32)
sum_pred = AllGatherGrad.apply(sum_pred).sum(0, dtype=torch.float32)
sum_gt = AllGatherGrad.apply(sum_gt).sum(0, dtype=torch.float32)
intersect = intersect.sum(0, dtype=torch.float32)
sum_pred = sum_pred.sum(0, dtype=torch.float32)
sum_gt = sum_gt.sum(0, dtype=torch.float32)
dc = (2 * intersect + self.smooth) / (sum_gt + sum_pred + float(self.smooth)).clamp_min(1e-8)
dc = dc.mean()
return -dc
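# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the soft Dice that
# MemoryEfficientSoftDiceLoss computes, written out for a single process (ddp=False) and without loss_mask.
# Per class c it evaluates
#   dc_c = (2 * sum(x_c * y_c) + smooth) / max(sum(x_c) + sum(y_c) + smooth, 1e-8)
# with sums over the spatial axes (and additionally over the batch axis if batch_dice=True), and returns -mean(dc).
# It assumes `probs` already went through the nonlinearity and `onehot` is a one-hot reference of the same shape.
import torch


def _soft_dice_loss_sketch(probs: torch.Tensor, onehot: torch.Tensor, batch_dice: bool = False,
                           do_bg: bool = True, smooth: float = 1.) -> torch.Tensor:
    if not do_bg:
        # drop the background channel from prediction and reference alike
        probs, onehot = probs[:, 1:], onehot[:, 1:]
    axes = tuple(range(2, probs.ndim))  # sum over spatial axes -> shape (b, c)
    intersect = (probs * onehot).sum(axes)
    sum_pred = probs.sum(axes)
    sum_gt = onehot.sum(axes)
    if batch_dice:
        # pool the statistics over the batch before taking the ratio -> shape (c,)
        intersect, sum_pred, sum_gt = intersect.sum(0), sum_pred.sum(0), sum_gt.sum(0)
    dc = (2 * intersect + smooth) / (sum_gt + sum_pred + smooth).clamp_min(1e-8)
    return -dc.mean()


# Usage: a perfect prediction yields a loss of -1 regardless of batch_dice/do_bg:
# labels = torch.arange(32).reshape(2, 4, 4) % 3
# onehot = torch.nn.functional.one_hot(labels, 3).permute(0, 3, 1, 2).float()
# _soft_dice_loss_sketch(onehot, onehot, batch_dice=True, do_bg=False)  # tensor(-1.)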
def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
"""
net_output must be (b, c, x, y(, z))
gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
if mask is provided it must have shape (b, 1, x, y(, z))
:param net_output: predicted probabilities or one hot predictions, shape see above
:param gt: reference segmentation, shape see above
:param axes: axes to sum over. Defaults to all spatial axes. () = no summation
:param mask: mask must be 1 for valid pixels and 0 for invalid pixels
:param square: if True then tp, fp, fn and tn will be squared before summation
:return: tp, fp, fn, tn
"""
if axes is None:
axes = tuple(range(2, net_output.ndim))
with torch.no_grad():
if net_output.ndim != gt.ndim:
gt = gt.view((gt.shape[0], 1, *gt.shape[1:]))
if net_output.shape == gt.shape:
# if this is the case then gt is probably already a one hot encoding
y_onehot = gt.to(torch.float32)
else:
y_onehot = torch.zeros(net_output.shape, device=net_output.device, dtype=torch.float32)
y_onehot.scatter_(1, gt.long(), 1)
tp = net_output * y_onehot
fp = net_output * (1 - y_onehot)
fn = (1 - net_output) * y_onehot
tn = (1 - net_output) * (1 - y_onehot)
if mask is not None:
with torch.no_grad():
mask_here = torch.tile(mask, (1, tp.shape[1], *[1 for _ in range(2, tp.ndim)]))
tp *= mask_here
fp *= mask_here
fn *= mask_here
tn *= mask_here
# benchmark whether tiling the mask would be faster (torch.tile). It probably is for large batch sizes
# OK it barely makes a difference but the implementation above is a tiny bit faster + uses less vram
# (using nnUNetv2_train 998 3d_fullres 0)
# tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
# fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
# fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
# tn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tn, dim=1)), dim=1)
if square:
tp = tp ** 2
fp = fp ** 2
fn = fn ** 2
tn = tn ** 2
if len(axes) > 0:
tp = tp.sum(dim=axes, keepdim=False, dtype=torch.float32)
fp = fp.sum(dim=axes, keepdim=False, dtype=torch.float32)
fn = fn.sum(dim=axes, keepdim=False, dtype=torch.float32)
tn = tn.sum(dim=axes, keepdim=False, dtype=torch.float32)
return tp, fp, fn, tn
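# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the scatter_-based one-hot encoding used
# in get_tp_fp_fn_tn, applied to hard label maps of shape (b, x, y(, z)), and how the resulting tp/fp/fn give the
# Dice coefficient dice = 2 * tp / (2 * tp + fp + fn). Masking and squaring are left out for brevity.
import torch


def _hard_counts_sketch(pred_labels: torch.Tensor, gt_labels: torch.Tensor, num_classes: int):
    # one-hot encode both label maps along the channel axis: (b, x, y) -> (b, c, x, y)
    shape = (pred_labels.shape[0], num_classes, *pred_labels.shape[1:])
    pred_oh = torch.zeros(shape).scatter_(1, pred_labels[:, None].long(), 1)
    gt_oh = torch.zeros(shape).scatter_(1, gt_labels[:, None].long(), 1)
    axes = tuple(range(2, pred_oh.ndim))  # spatial axes -> per-sample, per-class counts of shape (b, c)
    tp = (pred_oh * gt_oh).sum(axes)
    fp = (pred_oh * (1 - gt_oh)).sum(axes)
    fn = ((1 - pred_oh) * gt_oh).sum(axes)
    return tp, fp, fn


# Usage: tp, fp, fn = _hard_counts_sketch(pred, gt, num_classes=2); dice = 2 * tp / (2 * tp + fp + fn)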
if __name__ == '__main__':
from nnunetv2.utilities.helpers import softmax_helper_dim1
pred = torch.rand((2, 3, 32, 32, 32))
ref = torch.randint(0, 3, (2, 32, 32, 32))
dl_old = SoftDiceLoss(apply_nonlin=softmax_helper_dim1, batch_dice=True, do_bg=False, smooth=0, ddp=False)
dl_new = MemoryEfficientSoftDiceLoss(apply_nonlin=softmax_helper_dim1, batch_dice=True, do_bg=False, smooth=0, ddp=False)
res_old = dl_old(pred, ref)
res_new = dl_new(pred, ref)
print(res_old, res_new)
================================================
FILE: nnunetv2/training/loss/robust_ce_loss.py
================================================
import torch
from torch import nn, Tensor
import numpy as np
class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
"""
this is just a compatibility layer because my target tensor is float and has an extra (channel) dimension of size 1
input must be logits, not probabilities!
"""
def forward(self, input: Tensor, target: Tensor) -> Tensor:
if target.ndim == input.ndim:
assert target.shape[1] == 1
target = target[:, 0]
return super().forward(input, target.long())
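# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the target conversion performed by
# RobustCrossEntropyLoss before nn.CrossEntropyLoss is called. nnU-Net targets are float tensors of shape
# (b, 1, x, y(, z)), whereas cross entropy expects int64 class indices of shape (b, x, y(, z)).
import torch


def _ce_target_sketch(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    if target.ndim == logits.ndim:
        assert target.shape[1] == 1, 'expected a single-channel label map'
        target = target[:, 0]  # drop the channel axis
    return target.long()  # class indices must be int64


# Usage: torch.nn.functional.cross_entropy(logits, _ce_target_sketch(logits, target))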
class TopKLoss(RobustCrossEntropyLoss):
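"""
input must be logits, not probabilities!
k is the percentage of voxels (those with the highest loss) that enter the final average
"""
def __init__(self, weight=None, ignore_index: int = -100, k: float = 10, label_smoothing: float = 0):
# reduction='none' is equivalent to the deprecated (size_average=False, reduce=False) and keeps one loss per voxel
super(TopKLoss, self).__init__(weight=weight, ignore_index=ignore_index, reduction='none',
label_smoothing=label_smoothing)
self.k = k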
def forward(self, inp, target):
target = target[:, 0].long()
res = super(TopKLoss, self).forward(inp, target)
num_voxels = np.prod(res.shape, dtype=np.int64)
res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False)
return res.mean()
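# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): what TopKLoss computes. With
# reduction='none' the cross entropy yields one loss value per voxel; only the k percent largest values are
# averaged, which focuses training on the hardest voxels. With k=100 it reduces to the plain mean cross entropy.
import torch
import torch.nn.functional as F


def _topk_ce_sketch(logits: torch.Tensor, target: torch.Tensor, k: float = 10) -> torch.Tensor:
    # per-voxel cross entropy, equivalent to nn.CrossEntropyLoss(reduction='none')
    per_voxel = F.cross_entropy(logits, target.long(), reduction='none')
    flat = per_voxel.reshape(-1)
    n_keep = int(flat.numel() * k / 100)  # k is a percentage, as in TopKLoss
    # caveat: n_keep == 0 (k too small for the number of voxels) makes the mean of an empty tensor nan
    top, _ = torch.topk(flat, n_keep, sorted=False)
    return top.mean()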
================================================
FILE: nnunetv2/training/lr_scheduler/__init__.py
================================================
================================================
FILE: nnunetv2/training/lr_scheduler/polylr.py
================================================
from torch.optim.lr_scheduler import _LRScheduler
class PolyLRScheduler(_LRScheduler):
def __init__(self, optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9, current_step: int = None):
self.optimizer = optimizer
self.initial_lr = initial_lr
self.max_steps = max_steps
self.exponent = exponent
self.ctr = 0
super().__init__(optimizer, current_step if current_step is not None else -1)
def step(self, current_step=None):
if current_step is None or current_step == -1:
current_step = self.ctr
self.ctr += 1
new_lr = self.initial_lr * (1 - current_step / self.max_steps) ** self.exponent
for param_group in self.optimizer.param_groups:
param_group['lr'] = new_lr
self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
def get_last_lr(self):
return self._last_lr
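# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the closed form behind PolyLRScheduler,
#   lr(t) = initial_lr * (1 - t / max_steps) ** exponent
# nnUNetTrainer.configure_optimizers creates the scheduler with max_steps=num_epochs (default initial_lr=1e-2,
# num_epochs=1000), so the learning rate decays monotonically from 1e-2 towards 0.
def _poly_lr_sketch(initial_lr: float, step: int, max_steps: int, exponent: float = 0.9) -> float:
    return initial_lr * (1 - step / max_steps) ** exponent


# Usage: halfway through training, _poly_lr_sketch(1e-2, 500, 1000) == 1e-2 * 0.5 ** 0.9 (about 5.4e-3)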
================================================
FILE: nnunetv2/training/lr_scheduler/warmup.py
================================================
import math
import warnings
from typing import Optional, cast, List
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR, _enable_get_lr_call
class Lin_incr_LRScheduler(_LRScheduler):
def __init__(self, optimizer, max_lr: float, max_steps: int, current_step: int = None):
self.optimizer = optimizer
self.max_lr = max_lr
self.max_steps = max_steps
self.ctr = 0
super().__init__(optimizer, current_step if current_step is not None else -1)
def step(self, current_step=None):
if current_step is None or current_step == -1:
current_step = self.ctr
self.ctr += 1
new_lr = self.max_lr / self.max_steps * (1 + current_step)
for param_group in self.optimizer.param_groups:
param_group["lr"] = new_lr
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
class Lin_incr_offset_LRScheduler(_LRScheduler):
def __init__(self, optimizer, max_lr: float, max_steps: int, start_step: int, current_step: int = None):
self.optimizer = optimizer
self.max_lr = max_lr
self.max_steps = max_steps
self.start_step = start_step
self.ctr = 0
super().__init__(optimizer, current_step if current_step is not None else -1)
def step(self, current_step=None):
if current_step is None or current_step == -1:
current_step = self.ctr
self.ctr += 1
new_lr = self.max_lr / self.max_steps * (1 + current_step - self.start_step)
for param_group in self.optimizer.param_groups:
param_group["lr"] = new_lr
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
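# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the linear warm-up implemented by
# Lin_incr_LRScheduler (start_step=0) and Lin_incr_offset_LRScheduler,
#   lr(t) = max_lr / max_steps * (1 + t - start_step)
# which reaches max_lr at t = start_step + max_steps - 1. The formula is not clamped: for t < start_step - 1 it
# yields negative values, so the offset variant is presumably only stepped from start_step onwards (an
# assumption; the calling code is not part of this file).
def _linear_warmup_lr_sketch(max_lr: float, step: int, max_steps: int, start_step: int = 0) -> float:
    return max_lr / max_steps * (1 + step - start_step)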
class PolyLRScheduler_offset(_LRScheduler):
def __init__(
self,
optimizer,
initial_lr: float,
max_steps: int,
start_step: int,
exponent: float = 0.9,
current_step: int = None,
):
self.optimizer = optimizer
self.initial_lr = initial_lr
self.max_steps = max_steps - start_step
self.start_step = start_step
self.exponent = exponent
self.ctr = 0
super().__init__(optimizer, current_step if current_step is not None else -1)
def step(self, current_step=None):
if current_step is None or current_step == -1:
current_step = self.ctr
self.ctr += 1
current_step = current_step - self.start_step
if current_step <= 0:
current_step = 0
new_lr = self.initial_lr * (1 - current_step / self.max_steps) ** self.exponent
for param_group in self.optimizer.param_groups:
param_group["lr"] = new_lr
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
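# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the schedule of PolyLRScheduler_offset.
# The step counter is shifted by start_step and clamped at 0, so the learning rate stays at initial_lr until
# start_step and then follows the poly decay over the remaining max_steps - start_step steps:
#   lr(t) = initial_lr * (1 - max(t - start_step, 0) / (max_steps - start_step)) ** exponent
# Combined with one of the linear warm-up schedulers over the first start_step steps, this would presumably
# give a warm-up-then-poly schedule (how the two are chained is an assumption, not shown in this file).
def _poly_lr_offset_sketch(initial_lr: float, step: int, max_steps: int, start_step: int,
                           exponent: float = 0.9) -> float:
    t = max(step - start_step, 0)
    return initial_lr * (1 - t / (max_steps - start_step)) ** exponent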
class CosineAnnealingLR_offset(CosineAnnealingLR):
def __init__(
self, optimizer: Optimizer, T_max: int, eta_min=0, last_epoch=-1, verbose="deprecated", offset: int = 0
):
self.offset = offset
super().__init__(
optimizer,
T_max,
eta_min,
last_epoch,
verbose,
)
def _get_closed_form_lr(self):
return [
self.eta_min
+ (base_lr - self.eta_min)
* (1 + math.cos(math.pi * (self.last_epoch - self.offset) / (self.T_max - self.offset)))
/ 2
for base_lr in self.base_lrs
]
def step(self, epoch: Optional[int] = None):
# Raise a warning if old pattern is detected
# https://github.com/pytorch/pytorch/issues/20124
if self._step_count == 1:
if not hasattr(self.optimizer.step, "_wrapped_by_lr_sched"):
warnings.warn(
"Seems like `optimizer.step()` has been overridden after learning rate scheduler "
"initialization. Please, make sure to call `optimizer.step()` before "
"`lr_scheduler.step()`. See more details at "
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
UserWarning,
)
# Just check if there were two first lr_scheduler.step() calls before optimizer.step()
elif not getattr(self.optimizer, "_opt_called", False):
warnings.warn(
"Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
"In PyTorch 1.1.0 and later, you should call them in the opposite order: "
"`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
"will result in PyTorch skipping the first value of the learning rate schedule. "
"See more details at "
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
UserWarning,
)
self._step_count += 1
with _enable_get_lr_call(self):
if epoch is None:
self.last_epoch += 1
else:
self.last_epoch = epoch
values = cast(List[float], self._get_closed_form_lr())
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
param_group, lr = data
if isinstance(param_group["lr"], Tensor):
lr_val = lr.item() if isinstance(lr, Tensor) else lr # type: ignore[attr-defined]
param_group["lr"].fill_(lr_val)
else:
param_group["lr"] = lr
self._last_lr: List[float] = [group["lr"] for group in self.optimizer.param_groups]
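# Illustrative sketch (hypothetical helper, not part of nnU-Net's API): the closed form evaluated by
# CosineAnnealingLR_offset._get_closed_form_lr for one parameter group,
#   lr(t) = eta_min + (base_lr - eta_min) * (1 + cos(pi * (t - offset) / (T_max - offset))) / 2
# i.e. standard cosine annealing whose start is shifted from epoch 0 to epoch `offset`: lr equals base_lr at
# t = offset and eta_min at t = T_max. Because cos is even, for t < offset the curve is mirrored and lr drops
# below base_lr, so the scheduler is presumably meant to be stepped only after the offset.
import math


def _cosine_offset_lr_sketch(base_lr: float, t: int, T_max: int, eta_min: float = 0.0, offset: int = 0) -> float:
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * (t - offset) / (T_max - offset))) / 2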
================================================
FILE: nnunetv2/training/nnUNetTrainer/__init__.py
================================================
================================================
FILE: nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
================================================
import inspect
import multiprocessing
import os
import shutil
import sys
import warnings
from copy import deepcopy
from datetime import datetime
from time import time, sleep
from typing import Tuple, Union, List
import numpy as np
import torch
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile, save_json, maybe_mkdir_p
from batchgeneratorsv2.helpers.scalar_type import RandomScalar
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform
from batchgeneratorsv2.transforms.intensity.contrast import ContrastTransform, BGContrast
from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform
from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform
from batchgeneratorsv2.transforms.nnunet.random_binary_operator import ApplyRandomBinaryOperatorTransform
from batchgeneratorsv2.transforms.nnunet.remove_connected_components import \
RemoveRandomConnectedComponentFromOneHotEncodingTransform
from batchgeneratorsv2.transforms.nnunet.seg_to_onehot import MoveSegAsOneHotToDataTransform
from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform
from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform
from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform
from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform
from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms
from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform
from batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform
from batchgeneratorsv2.transforms.utils.pseudo2d import Convert3DTo2DTransform, Convert2DTo3DTransform
from batchgeneratorsv2.transforms.utils.random import RandomTransform
from batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform
from batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform
from torch import autocast, nn
from torch import distributed as dist
from torch._dynamo import OptimizedModule
from torch.cuda import device_count
from torch import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from nnunetv2.configuration import ANISO_THRESHOLD, default_num_processes
from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder
from nnunetv2.inference.export_prediction import export_prediction_from_logits, resample_and_save
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.inference.sliding_window_prediction import compute_gaussian
from nnunetv2.paths import nnUNet_preprocessed, nnUNet_results
from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size
from nnunetv2.training.dataloading.nnunet_dataset import infer_dataset_class
from nnunetv2.training.dataloading.data_loader import nnUNetDataLoader
from nnunetv2.training.logging.nnunet_logger import MetaLogger
from nnunetv2.training.loss.compound_losses import DC_and_CE_loss, DC_and_BCE_loss
from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss
from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
from nnunetv2.utilities.collate_outputs import collate_outputs
from nnunetv2.utilities.crossval_split import generate_crossval_split
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy
from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
from nnunetv2.utilities.helpers import empty_cache, dummy_context
from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
class nnUNetTrainer(object):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
# From https://grugbrain.dev/. Worth a read ya big brains ;-)
# apex predator of grug is complexity
# complexity bad
# say again:
# complexity very bad
# you say now:
# complexity very, very bad
# given choice between complexity or one on one against t-rex, grug take t-rex: at least grug see t-rex
# complexity is spirit demon that enter codebase through well-meaning but ultimately very clubbable non grug-brain developers and project managers who not fear complexity spirit demon or even know about sometime
# one day code base understandable and grug can get work done, everything good!
# next day impossible: complexity demon spirit has entered code and very dangerous situation!
# OK OK I am guilty. But I tried.
# https://www.osnews.com/images/comics/wtfm.jpg
# https://i.pinimg.com/originals/26/b2/50/26b250a738ea4abc7a5af4d42ad93af0.jpg
self.is_ddp = dist.is_available() and dist.is_initialized()
self.local_rank = 0 if not self.is_ddp else dist.get_rank()
self.device = device
# print what device we are using
if self.is_ddp: # implicitly it's clear that we use cuda in this case
self.device = torch.device(type='cuda', index=self.local_rank)
print(f"I am local rank {self.local_rank}. {device_count()} GPUs are available. The world size is "
f"{dist.get_world_size()}. "
f"Setting device to {self.device}")
else:
if self.device.type == 'cuda':
# we might want to let the user pick this but for now please pick the correct GPU with CUDA_VISIBLE_DEVICES=X
self.device = torch.device(type='cuda', index=0)
print(f"Using device: {self.device}")
# loading and saving this class for continuing from checkpoint should not happen based on pickling. This
# would also pickle the network etc. Bad, bad. Instead we just reinstantiate and then load the checkpoint we
# need. So let's save the init args
self.my_init_kwargs = {}
for k in inspect.signature(self.__init__).parameters.keys():
self.my_init_kwargs[k] = locals()[k]
### Saving all the init args into class variables for later access
continue_training = plans.pop("continue_training", False)
logger_config = {"plans": plans, "configuration": configuration, "fold": fold, "dataset": dataset_json}
self.plans_manager = PlansManager(plans)
self.configuration_manager = self.plans_manager.get_configuration(configuration)
self.configuration_name = configuration
self.dataset_json = dataset_json
self.fold = fold
### Setting all the folder names. We need to make sure things don't crash in case we are just running
# inference and some of the folders may not be defined!
self.preprocessed_dataset_folder_base = join(nnUNet_preprocessed, self.plans_manager.dataset_name) \
if nnUNet_preprocessed is not None else None
self.output_folder_base = join(nnUNet_results, self.plans_manager.dataset_name,
self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + configuration) \
if nnUNet_results is not None else None
self.output_folder = join(self.output_folder_base, f'fold_{fold}')
self.preprocessed_dataset_folder = join(self.preprocessed_dataset_folder_base,
self.configuration_manager.data_identifier)
self.dataset_class = None # -> initialize
# unlike in the previous nnU-Net, folder_with_segs_from_previous_stage is now part of the plans. For now it has
# to be a different configuration in the same plans
# IMPORTANT! the mapping must be bijective, so lowres must point to fullres and vice versa (using
# "previous_stage" and "next_stage"). Otherwise it won't work!
self.is_cascaded = self.configuration_manager.previous_stage_name is not None
self.folder_with_segs_from_previous_stage = \
join(nnUNet_results, self.plans_manager.dataset_name,
self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" +
self.configuration_manager.previous_stage_name, 'predicted_next_stage', self.configuration_name) \
if self.is_cascaded else None
### Some hyperparameters for you to fiddle with
self.initial_lr = 1e-2
self.weight_decay = 3e-5
self.oversample_foreground_percent = 0.33
self.probabilistic_oversampling = False
self.num_iterations_per_epoch = 250
self.num_val_iterations_per_epoch = 50
self.num_epochs = 1000
self.current_epoch = 0
self.enable_deep_supervision = True
### Dealing with labels/regions
self.label_manager = self.plans_manager.get_label_manager(dataset_json)
# labels can either be a list of int (regular training) or a list of tuples of int (region-based training)
# needed for predictions. We do sigmoid in case of (overlapping) regions
self.num_input_channels = None # -> self.initialize()
self.network = None # -> self.build_network_architecture()
self.optimizer = self.lr_scheduler = None # -> self.initialize
self.grad_scaler = GradScaler("cuda") if self.device.type == 'cuda' else None
self.loss = None # -> self.initialize
### Simple logging. Don't take that away from me!
# initialize log file. This is just our log for the print statements etc. Not to be confused with lightning
# logging
timestamp = datetime.now()
maybe_mkdir_p(self.output_folder)
self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" %
(timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute,
timestamp.second))
self.logger = MetaLogger(self.output_folder, continue_training)
self.logger.update_config(logger_config)
### placeholders
self.dataloader_train = self.dataloader_val = None # see on_train_start
### initializing stuff for remembering things and such
self._best_ema = None
### inference things
self.inference_allowed_mirroring_axes = None # this variable is set in
# self.configure_rotation_dummyDA_mirroring_and_inital_patch_size and will be saved in checkpoints
### checkpoint saving stuff
self.save_every = 50
self.disable_checkpointing = False
self.was_initialized = False
self.print_to_log_file("\n#######################################################################\n"
"Please cite the following paper when using nnU-Net:\n"
"Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). "
"nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. "
"Nature methods, 18(2), 203-211.\n"
"#######################################################################\n",
also_print_to_console=True, add_timestamp=False)
def initialize(self):
if not self.was_initialized:
## In DDP, batch size and oversampling can differ between workers and need adaptation.
# we need to change the batch size in DDP because we don't use any of those distributed samplers
self._set_batch_size_and_oversample()
self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,
self.dataset_json)
self.network = self.build_network_architecture(
self.configuration_manager.network_arch_class_name,
self.configuration_manager.network_arch_init_kwargs,
self.configuration_manager.network_arch_init_kwargs_req_import,
self.num_input_channels,
self.label_manager.num_segmentation_heads,
self.enable_deep_supervision
).to(self.device)
# compile network for free speedup
if self._do_i_compile():
self.print_to_log_file('Using torch.compile...')
self.network = torch.compile(self.network)
self.optimizer, self.lr_scheduler = self.configure_optimizers()
# if ddp, wrap in DDP wrapper
if self.is_ddp:
self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)
self.network = DDP(self.network, device_ids=[self.local_rank])
self.loss = self._build_loss()
self.dataset_class = infer_dataset_class(self.preprocessed_dataset_folder)
# torch 2.2.2 crashes upon compiling CE loss
# if self._do_i_compile():
# self.loss = torch.compile(self.loss)
self.was_initialized = True
logger_config_hparas = {
"initial_lr": self.initial_lr,
"weight_decay": self.weight_decay,
"oversample_foreground_percent": self.oversample_foreground_percent,
"probabilistic_oversampling": self.probabilistic_oversampling,
"num_iterations_per_epoch": self.num_iterations_per_epoch,
"num_val_iterations_per_epoch": self.num_val_iterations_per_epoch,
"num_epochs": self.num_epochs,
"enable_deep_supervision": self.enable_deep_supervision,
"batch_size": self.configuration_manager.batch_size
}
self.logger.update_config({"hparas": logger_config_hparas})
else:
raise RuntimeError("You have called self.initialize even though the trainer was already initialized. "
"That should not happen.")
def _do_i_compile(self):
# new default: compile is enabled!
# compile does not work on mps
if self.device == torch.device('mps'):
if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'):
self.print_to_log_file("INFO: torch.compile disabled because of unsupported mps device")
return False
# CPU compile crashes for 2D models. Not sure if we even want to support CPU compile!? Better disable
if self.device == torch.device('cpu'):
if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'):
self.print_to_log_file("INFO: torch.compile disabled because device is CPU")
return False
# default torch.compile doesn't work on windows because there are apparently no triton wheels for it
# https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2
if os.name == 'nt':
if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'):
self.print_to_log_file("INFO: torch.compile disabled because Windows is not natively supported. If "
"you know what you are doing, check https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2")
return False
if 'nnUNet_compile' not in os.environ.keys():
return True
else:
return os.environ['nnUNet_compile'].lower() in ('true', '1', 't')
def _save_debug_information(self):
# saving some debug information
if self.local_rank == 0:
dct = {}
for k in self.__dir__():
if not k.startswith("__"):
if not callable(getattr(self, k)) or k in ['loss', ]:
dct[k] = str(getattr(self, k))
elif k in ['network', ]:
dct[k] = str(getattr(self, k).__class__.__name__)
else:
# print(k)
pass
if k in ['dataloader_train', 'dataloader_val']:
dl = getattr(self, k)
if hasattr(dl, 'generator'):
dct[k + '.generator'] = str(dl.generator)
if hasattr(dl.generator, 'transforms'):
try:
dct[k + '.generator.transforms'] = str(dl.generator.transforms)
except Exception as e:
dct[k + '.generator.transforms'] = f"Could not stringify generator.transforms: {type(e).__name__}: {e}"
if hasattr(dl, 'num_processes'):
dct[k + '.num_processes'] = str(dl.num_processes)
if hasattr(dl, 'transform'):
dct[k + '.transform'] = str(dl.transform)
import subprocess
hostname = subprocess.getoutput('hostname')
dct['hostname'] = hostname
torch_version = torch.__version__
if self.device.type == 'cuda':
gpu_name = torch.cuda.get_device_name()
dct['gpu_name'] = gpu_name
cudnn_version = torch.backends.cudnn.version()
else:
cudnn_version = 'None'
dct['device'] = str(self.device)
dct['torch_version'] = torch_version
dct['cudnn_version'] = cudnn_version
save_json(dct, join(self.output_folder, "debug.json"))
@staticmethod
def build_network_architecture(architecture_class_name: str,
arch_init_kwargs: dict,
arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
num_input_channels: int,
num_output_channels: int,
enable_deep_supervision: bool = True) -> nn.Module:
"""
This is where you build the architecture according to the plans. There is no obligation to use
get_network_from_plans, this is just a utility we use for the nnU-Net default architectures. You can do what
you want. Even ignore the plans and just return something static (as long as it can process the requested
patch size)
but don't bug us with your bugs arising from fiddling with this :-P
This is the function that is called in inference as well! This is needed so that all network architecture
variants can be loaded at inference time (inference will use the same nnUNetTrainer that was used for
training, so if you change the network architecture during training by deriving a new trainer class then
inference will know about it).
If you need to know how many segmentation outputs your custom architecture needs to have, use the following snippet:
> label_manager = plans_manager.get_label_manager(dataset_json)
> label_manager.num_segmentation_heads
(why so complicated? -> We can have either classical training (classes) or regions. If we have regions,
the number of outputs is != the number of classes. Also there is the ignore label for which no output
should be generated. label_manager takes care of all that for you.)
"""
return get_network_from_plans(
architecture_class_name,
arch_init_kwargs,
arch_init_kwargs_req_import,
num_input_channels,
num_output_channels,
allow_init=True,
deep_supervision=enable_deep_supervision)
def _get_deep_supervision_scales(self):
if self.enable_deep_supervision:
deep_supervision_scales = list(list(i) for i in 1 / np.cumprod(np.vstack(
self.configuration_manager.pool_op_kernel_sizes), axis=0))[:-1]
else:
deep_supervision_scales = None # for train and val_transforms
return deep_supervision_scales
def _set_batch_size_and_oversample(self):
if not self.is_ddp:
# set batch size to what the plan says, leave oversample untouched
self.batch_size = self.configuration_manager.batch_size
else:
# batch size is distributed over DDP workers and we need to change oversample_percent for each worker
world_size = dist.get_world_size()
my_rank = dist.get_rank()
global_batch_size = self.configuration_manager.batch_size
assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \
'GPUs... Duh.'
batch_size_per_GPU = [global_batch_size // world_size] * world_size
batch_size_per_GPU = [batch_size_per_GPU[i] + 1
if (batch_size_per_GPU[i] * world_size + i) < global_batch_size
else batch_size_per_GPU[i]
for i in range(len(batch_size_per_GPU))]
assert sum(batch_size_per_GPU) == global_batch_size
sample_id_low = 0 if my_rank == 0 else np.sum(batch_size_per_GPU[:my_rank])
sample_id_high = np.sum(batch_size_per_GPU[:my_rank + 1])
# This is how oversampling is determined in DataLoader:
# round(self.batch_size * (1 - self.oversample_foreground_percent))
# We need to use the same scheme here because an oversample of 0.33 with a batch size of 2 will be rounded
# to an oversample of 0.5 (one sample random, one oversampled). This rounding would get lost if we simply
# split the oversample percentage numerically across workers
oversample = [True if not i < round(global_batch_size * (1 - self.oversample_foreground_percent)) else False
for i in range(global_batch_size)]
if sample_id_high / global_batch_size < (1 - self.oversample_foreground_percent):
oversample_percent = 0.0
elif sample_id_low / global_batch_size > (1 - self.oversample_foreground_percent):
oversample_percent = 1.0
else:
oversample_percent = sum(oversample[sample_id_low:sample_id_high]) / batch_size_per_GPU[my_rank]
print("worker", my_rank, "oversample", oversample_percent)
print("worker", my_rank, "batch_size", batch_size_per_GPU[my_rank])
self.batch_size = batch_size_per_GPU[my_rank]
self.oversample_foreground_percent = oversample_percent
def _build_loss(self):
if self.label_manager.has_regions:
loss = DC_and_BCE_loss({},
{'batch_dice': self.configuration_manager.batch_dice,
'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp},
use_ignore_label=self.label_manager.ignore_label is not None,
dice_class=MemoryEfficientSoftDiceLoss)
else:
loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice,
'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1,
ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss)
if self._do_i_compile():
loss.dc = torch.compile(loss.dc)
# we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
# this gives higher resolution outputs more weight in the loss
if self.enable_deep_supervision:
deep_supervision_scales = self._get_deep_supervision_scales()
weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
if self.is_ddp and not self._do_i_compile():
# very strange and stupid interaction. DDP crashes and complains about unused parameters due to
# weights[-1] = 0. Interestingly this crash doesn't happen with torch.compile enabled. Strange stuff.
# Anywho, the simple fix is to set a very low weight to this.
weights[-1] = 1e-6
else:
weights[-1] = 0
# the lowest-resolution output gets (almost) zero weight, i.e. it is effectively not used. Normalize weights so that they sum to 1
weights = weights / weights.sum()
# now wrap the loss
loss = DeepSupervisionWrapper(loss, weights)
return loss
def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
"""
This function is stupid and certainly one of the weakest spots of this implementation. Not entirely sure how we can fix it.
"""
patch_size = self.configuration_manager.patch_size
dim = len(patch_size)
# todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation)
if dim == 2:
do_dummy_2d_data_aug = False
# todo revisit this parametrization
if max(patch_size) / min(patch_size) > 1.5:
rotation_for_DA = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
else:
rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
mirror_axes = (0, 1)
elif dim == 3:
# todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad
# order of the axes is determined by spacing, not image size
do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD
if do_dummy_2d_data_aug:
# why do we rotate 180 deg here all the time? We should also restrict it
rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
else:
rotation_for_DA = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
mirror_axes = (0, 1, 2)
else:
raise RuntimeError()
# todo this function is stupid. It doesn't even use the correct scale range (we keep things as they were in the
# old nnunet for now)
initial_patch_size = get_patch_size(patch_size[-dim:],
rotation_for_DA,
rotation_for_DA,
rotation_for_DA,
(0.85, 1.25))
if do_dummy_2d_data_aug:
initial_patch_size[0] = patch_size[0]
self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}')
self.inference_allowed_mirroring_axes = mirror_axes
return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
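The rules in `configure_rotation_dummyDA_mirroring_and_inital_patch_size` can be restated as a small pure function. This is a sketch and the helper names are hypothetical. The default `aniso_threshold=3.0` is an assumed value for `ANISO_THRESHOLD`, which is imported from elsewhere and not shown in this file. The initial patch size computation via `get_patch_size` is left out.

```python
import numpy as np


def deg_range(max_deg: float) -> tuple:
    # symmetric rotation range in radians, written the same way as in the trainer
    return (-max_deg / 360 * 2. * np.pi, max_deg / 360 * 2. * np.pi)


def choose_rotation_and_mirroring(patch_size, aniso_threshold: float = 3.0):
    """Hypothetical restatement of the rules above; aniso_threshold=3.0 is an assumed value."""
    dim = len(patch_size)
    if dim == 2:
        do_dummy_2d = False
        # elongated 2D patches only get +-15 deg, roughly square ones get the full +-180 deg
        rotation = deg_range(15.) if max(patch_size) / min(patch_size) > 1.5 else deg_range(180.)
        mirror_axes = (0, 1)
    elif dim == 3:
        # axis 0 is the (potentially) low-resolution axis. If it is much shorter than the
        # largest axis, augment slice-wise in 2D (in-plane rotation only)
        do_dummy_2d = (max(patch_size) / patch_size[0]) > aniso_threshold
        rotation = deg_range(180.) if do_dummy_2d else deg_range(30.)
        mirror_axes = (0, 1, 2)
    else:
        raise ValueError(f"unsupported patch dimensionality: {dim}")
    return rotation, do_dummy_2d, mirror_axes


rot, dummy, axes = choose_rotation_and_mirroring((20, 256, 256))
print(dummy, round(float(np.degrees(rot[1])), 6))  # True 180.0
```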
def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True):
if self.local_rank == 0:
timestamp = time()
dt_object = datetime.fromtimestamp(timestamp)
if add_timestamp:
args = (f"{dt_object}:", *args)
successful = False
max_attempts = 5
ctr = 0
while not successful and ctr < max_attempts:
try:
with open(self.log_file, 'a+') as f:
for a in args:
f.write(str(a))
f.write(" ")
f.write("\n")
successful = True
except IOError:
print(f"{datetime.fromtimestamp(timestamp)}: failed to log: ", sys.exc_info())
sleep(0.5)
ctr += 1
if also_print_to_console:
print(*args)
elif also_print_to_console:
print(*args)
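`print_to_log_file` retries the file write a few times before giving up. This can help when a shared or network file system occasionally fails to write. Below is a stand-alone sketch of that retry loop. The function name is hypothetical, and the demo writes its log to a temporary directory.

```python
import os
import sys
import tempfile
from time import sleep


def append_line_with_retries(path: str, *args, max_attempts: int = 5, wait_s: float = 0.5) -> bool:
    """Hypothetical stand-alone version of the retry loop in print_to_log_file."""
    for _ in range(max_attempts):
        try:
            with open(path, 'a+') as f:
                # same layout as the trainer: every argument followed by a space, then a newline
                for a in args:
                    f.write(str(a))
                    f.write(" ")
                f.write("\n")
            return True
        except IOError:
            print("failed to log:", sys.exc_info(), file=sys.stderr)
            sleep(wait_s)
    return False


log_path = os.path.join(tempfile.mkdtemp(), "training_log.txt")
append_line_with_retries(log_path, "Epoch", 3)
with open(log_path) as f:
    print(repr(f.read()))  # 'Epoch 3 \n'
```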
def print_plans(self):
if self.local_rank == 0:
dct = deepcopy(self.plans_manager.plans)
del dct['configurations']
self.print_to_log_file(f"\nThis is the configuration used by this "
f"training:\nConfiguration name: {self.configuration_name}\n",
self.configuration_manager, '\n', add_timestamp=False)
self.print_to_log_file('These are the global plans.json settings:\n', dct, '\n', add_timestamp=False)
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
momentum=0.99, nesterov=True)
lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)
return optimizer, lr_scheduler
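`configure_optimizers` pairs Nesterov SGD (momentum 0.99) with `PolyLRScheduler`. We assume the scheduler implements the usual polynomial decay with exponent 0.9. Check `nnunetv2/training/lr_scheduler/polylr.py` to confirm. Under that assumption, the per-epoch learning rate follows the function below. The defaults `initial_lr=1e-2` and `num_epochs=1000` are illustrative.

```python
def poly_lr(epoch: int, initial_lr: float = 1e-2, num_epochs: int = 1000, exponent: float = 0.9) -> float:
    """Polynomial decay; assumed (not verified here) to match PolyLRScheduler's default behaviour."""
    return initial_lr * (1 - epoch / num_epochs) ** exponent


for e in (0, 250, 500, 750, 999):
    print(e, round(poly_lr(e), 6))
```

The decay is gentle for most of training and only drops sharply near the end. `on_train_epoch_start` calls `lr_scheduler.step(self.current_epoch)` once per epoch.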
def plot_network_architecture(self):
if self._do_i_compile():
self.print_to_log_file("Unable to plot network architecture: nnUNet_compile is enabled!")
return
if self.local_rank == 0:
try:
# raise NotImplementedError('hiddenlayer no longer works and we do not have a viable alternative :-(')
# pip install git+https://github.com/saugatkandel/hiddenlayer.git
# from torchviz import make_dot
# # not viable.
# make_dot(tuple(self.network(torch.rand((1, self.num_input_channels,
# *self.configuration_manager.patch_size),
# device=self.device)))).render(
# join(self.output_folder, "network_architecture.pdf"), format='pdf')
# self.optimizer.zero_grad()
# broken.
import hiddenlayer as hl
g = hl.build_graph(self.network,
torch.rand((1, self.num_input_channels,
*self.configuration_manager.patch_size),
device=self.device),
transforms=None)
g.save(join(self.output_folder, "network_architecture.pdf"))
del g
except Exception as e:
self.print_to_log_file("Unable to plot network architecture:")
self.print_to_log_file(e)
# self.print_to_log_file("\nprinting the network instead:\n")
# self.print_to_log_file(self.network)
# self.print_to_log_file("\n")
finally:
empty_cache(self.device)
def do_split(self):
"""
The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded,
so always the same) and save it as splits_final.json file in the preprocessed data directory.
Sometimes you may want to create your own split for various reasons. For this you will need to create your own
splits_final.json file. If this file is present, nnU-Net is going to use it and whatever splits are defined in
it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3)
and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to
use a random 80:20 data split.
:return:
"""
if self.dataset_class is None:
self.dataset_class = infer_dataset_class(self.preprocessed_dataset_folder)
if self.fold == "all":
# if fold==all then we use all images for training and validation
case_identifiers = self.dataset_class.get_identifiers(self.preprocessed_dataset_folder)
tr_keys = case_identifiers
val_keys = tr_keys
else:
splits_file = join(self.preprocessed_dataset_folder_base, "splits_final.json")
dataset = self.dataset_class(self.preprocessed_dataset_folder,
identifiers=None,
folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage)
# if the split file does not exist we need to create it
if not isfile(splits_file):
self.print_to_log_file("Creating new 5-fold cross-validation split...")
all_keys_sorted = list(np.sort(list(dataset.identifiers)))
splits = generate_crossval_split(all_keys_sorted, seed=12345, n_splits=5)
save_json(splits, splits_file)
else:
self.print_to_log_file("Using splits from existing split file:", splits_file)
splits = load_json(splits_file)
self.print_to_log_file(f"The split file contains {len(splits)} splits.")
self.print_to_log_file("Desired fold for training: %d" % self.fold)
if self.fold < len(splits):
tr_keys = splits[self.fold]['train']
val_keys = splits[self.fold]['val']
self.print_to_log_file("This split has %d training and %d validation cases."
% (len(tr_keys), len(val_keys)))
else:
self.print_to_log_file("INFO: You requested fold %d for training but splits "
"contain only %d folds. I am now creating a "
"random (but seeded) 80:20 split!" % (self.fold, len(splits)))
# if we request a fold that is not in the split file, create a random 80:20 split
rnd = np.random.RandomState(seed=12345 + self.fold)
keys = np.sort(list(dataset.identifiers))
idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False)
idx_val = [i for i in range(len(keys)) if i not in idx_tr]
tr_keys = [keys[i] for i in idx_tr]
val_keys = [keys[i] for i in idx_val]
self.print_to_log_file("This random 80:20 split has %d training and %d validation cases."
% (len(tr_keys), len(val_keys)))
if any([i in val_keys for i in tr_keys]):
self.print_to_log_file('WARNING: Some validation cases are also in the training set. Please check the '
'splits_final.json file or ignore this if it is intentional.')
return tr_keys, val_keys
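The fallback 80:20 split in `do_split` is deterministic because the RNG is seeded with `12345 + fold`. Below is a sketch of only that branch, using a hypothetical function name. It uses a set for the membership test, which gives the same result as the list check in the trainer.

```python
import numpy as np


def seeded_80_20_split(identifiers, fold: int, base_seed: int = 12345):
    """Hypothetical stand-alone version of the fallback split (the fold index only seeds the RNG)."""
    rnd = np.random.RandomState(seed=base_seed + fold)
    keys = np.sort(list(identifiers))
    # draw 80 % of the indices without replacement for training
    idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False)
    idx_tr_set = set(idx_tr.tolist())
    tr_keys = [str(keys[i]) for i in idx_tr]
    # everything not drawn goes to validation
    val_keys = [str(keys[i]) for i in range(len(keys)) if i not in idx_tr_set]
    return tr_keys, val_keys


tr, val = seeded_80_20_split([f"case_{i:03d}" for i in range(10)], fold=5)
print(len(tr), len(val))  # 8 2
```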
def get_tr_and_val_datasets(self):
# create dataset split
tr_keys, val_keys = self.do_split()
# load the datasets for training and validation. Note that we always draw random samples so we really don't
# care about distributing training cases across GPUs.
dataset_tr = self.dataset_class(self.preprocessed_dataset_folder, tr_keys,
folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage)
dataset_val = self.dataset_class(self.preprocessed_dataset_folder, val_keys,
folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage)
return dataset_tr, dataset_val
def get_dataloaders(self):
if self.dataset_class is None:
self.dataset_class = infer_dataset_class(self.preprocessed_dataset_folder)
# we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether
# we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be
patch_size = self.configuration_manager.patch_size
# needed for deep supervision: how much do we need to downscale the segmentation targets for the different
# outputs?
deep_supervision_scales = self._get_deep_supervision_scales()
(
rotation_for_DA,
do_dummy_2d_data_aug,
initial_patch_size,
mirror_axes,
) = self.configure_rotation_dummyDA_mirroring_and_inital_patch_size()
# training pipeline
tr_transforms = self.get_training_transforms(
patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug,
use_mask_for_norm=self.configuration_manager.use_mask_for_norm,
is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels,
regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None,
ignore_label=self.label_manager.ignore_label)
# validation pipeline
val_transforms = self.get_validation_transforms(deep_supervision_scales,
is_cascaded=self.is_cascaded,
foreground_labels=self.label_manager.foreground_labels,
regions=self.label_manager.foreground_regions if
self.label_manager.has_regions else None,
ignore_label=self.label_manager.ignore_label)
dataset_tr, dataset_val = self.get_tr_and_val_datasets()
dl_tr = nnUNetDataLoader(dataset_tr, self.batch_size,
initial_patch_size,
self.configuration_manager.patch_size,
self.label_manager,
oversample_foreground_percent=self.oversample_foreground_percent,
sampling_probabilities=None, pad_sides=None, transforms=tr_transforms,
probabilistic_oversampling=self.probabilistic_oversampling)
dl_val = nnUNetDataLoader(dataset_val, self.batch_size,
self.configuration_manager.patch_size,
self.configuration_manager.patch_size,
self.label_manager,
oversample_foreground_percent=self.oversample_foreground_percent,
sampling_probabilities=None, pad_sides=None, transforms=val_transforms,
probabilistic_oversampling=self.probabilistic_oversampling)
allowed_num_processes = get_allowed_n_proc_DA()
if allowed_num_processes == 0:
mt_gen_train = SingleThreadedAugmenter(dl_tr, None)
mt_gen_val = SingleThreadedAugmenter(dl_val, None)
else:
mt_gen_train = NonDetMultiThreadedAugmenter(data_loader=dl_tr, transform=None,
num_processes=allowed_num_processes,
num_cached=max(6, allowed_num_processes // 2), seeds=None,
pin_memory=self.device.type == 'cuda', wait_time=0.002)
mt_gen_val = NonDetMultiThreadedAugmenter(data_loader=dl_val,
transform=None, num_processes=max(1, allowed_num_processes // 2),
num_cached=max(3, allowed_num_processes // 4), seeds=None,
pin_memory=self.device.type == 'cuda',
wait_time=0.002)
# let's get this party started (fetch one batch from each loader so that any background workers spin up)
_ = next(mt_gen_train)
_ = next(mt_gen_val)
return mt_gen_train, mt_gen_val
@staticmethod
def get_training_transforms(
patch_size: Union[np.ndarray, Tuple[int]],
rotation_for_DA: RandomScalar,
deep_supervision_scales: Union[List, Tuple, None],
mirror_axes: Tuple[int, ...],
do_dummy_2d_data_aug: bool,
use_mask_for_norm: List[bool] = None,
is_cascaded: bool = False,
foreground_labels: Union[Tuple[int, ...], List[int]] = None,
regions: List[Union[List[int], Tuple[int, ...], int]] = None,
ignore_label: int = None,
) -> BasicTransform:
transforms = []
if do_dummy_2d_data_aug:
ignore_axes = (0,)
transforms.append(Convert3DTo2DTransform())
patch_size_spatial = patch_size[1:]
else:
patch_size_spatial = patch_size
ignore_axes = None
transforms.append(
SpatialTransform(
patch_size_spatial, patch_center_dist_from_border=0, random_crop=False, p_elastic_deform=0,
p_rotation=0.2,
rotation=rotation_for_DA, p_scaling=0.2, scaling=(0.7, 1.4), p_synchronize_scaling_across_axes=1,
bg_style_seg_sampling=False # , mode_seg='nearest'
)
)
if do_dummy_2d_data_aug:
transforms.append(Convert2DTo3DTransform())
transforms.append(RandomTransform(
GaussianNoiseTransform(
noise_variance=(0, 0.1),
p_per_channel=1,
synchronize_channels=True
), apply_probability=0.1
))
transforms.append(RandomTransform(
GaussianBlurTransform(
blur_sigma=(0.5, 1.),
synchronize_channels=False,
synchronize_axes=False,
p_per_channel=0.5, benchmark=True
), apply_probability=0.2
))
transforms.append(RandomTransform(
MultiplicativeBrightnessTransform(
multiplier_range=BGContrast((0.75, 1.25)),
synchronize_channels=False,
p_per_channel=1
), apply_probability=0.15
))
transforms.append(RandomTransform(
ContrastTransform(
contrast_range=BGContrast((0.75, 1.25)),
preserve_range=True,
synchronize_channels=False,
p_per_channel=1
), apply_probability=0.15
))
transforms.append(RandomTransform(
SimulateLowResolutionTransform(
scale=(0.5, 1),
synchronize_channels=False,
synchronize_axes=True,
ignore_axes=ignore_axes,
allowed_channels=None,
p_per_channel=0.5
), apply_probability=0.25
))
transforms.append(RandomTransform(
GammaTransform(
gamma=BGContrast((0.7, 1.5)),
p_invert_image=1,
synchronize_channels=False,
p_per_channel=1,
p_retain_stats=1
), apply_probability=0.1
))
transforms.append(RandomTransform(
GammaTransform(
gamma=BGContrast((0.7, 1.5)),
p_invert_image=0,
synchronize_channels=False,
p_per_channel=1,
p_retain_stats=1
), apply_probability=0.3
))
if mirror_axes is not None and len(mirror_axes) > 0:
transforms.append(
MirrorTransform(
allowed_axes=mirror_axes
)
)
if use_mask_for_norm is not None and any(use_mask_for_norm):
transforms.append(MaskImageTransform(
apply_to_channels=[i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],
channel_idx_in_seg=0,
set_outside_to=0,
))
transforms.append(
RemoveLabelTansform(-1, 0)
)
if is_cascaded:
assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations'
transforms.append(
MoveSegAsOneHotToDataTransform(
source_channel_idx=1,
all_labels=foreground_labels,
remove_channel_from_source=True
)
)
transforms.append(
RandomTransform(
ApplyRandomBinaryOperatorTransform(
channel_idx=list(range(-len(foreground_labels), 0)),
strel_size=(1, 8),
p_per_label=1
), apply_probability=0.4
)
)
transforms.append(
RandomTransform(
RemoveRandomConnectedComponentFromOneHotEncodingTransform(
channel_idx=list(range(-len(foreground_labels), 0)),
fill_with_other_class_p=0,
dont_do_if_covers_more_than_x_percent=0.15,
p_per_label=1
), apply_probability=0.2
)
)
if regions is not None:
# the ignore label must also be converted
transforms.append(
ConvertSegmentationToRegionsTransform(
regions=list(regions) + [ignore_label] if ignore_label is not None else regions,
channel_in_seg=0
)
)
if deep_supervision_scales is not None:
transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales))
return ComposeTransforms(transforms)
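When `do_dummy_2d_data_aug` is set, `SpatialTransform` is placed between `Convert3DTo2DTransform` and `Convert2DTo3DTransform`. It then receives `patch_size[1:]`, so rotation and scaling only act in-plane. The NumPy sketch below illustrates the reshaping idea. We assume, without verifying it here, that batchgeneratorsv2 folds the first spatial axis into the channel axis exactly like this. The helper names are hypothetical.

```python
import numpy as np


def to_pseudo_2d(image: np.ndarray):
    """Fold the first spatial axis into the channel axis: (c, x, y, z) -> (c * x, y, z)."""
    c, x = image.shape[:2]
    return image.reshape((c * x, *image.shape[2:])), (c, x)


def from_pseudo_2d(image_2d: np.ndarray, orig_cx):
    """Undo to_pseudo_2d after the 2D spatial transform has run."""
    c, x = orig_cx
    return image_2d.reshape((c, x, *image_2d.shape[1:]))


vol = np.random.rand(2, 5, 32, 32)  # 2 channels, 5 slices along the low-resolution axis
flat, orig = to_pseudo_2d(vol)
print(flat.shape)  # (10, 32, 32)
```

Every slice becomes its own "channel", so a 2D spatial transform applies the same in-plane rotation and scaling to all slices. The low-resolution axis is never interpolated.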
@staticmethod
def get_validation_transforms(
deep_supervision_scales: Union[List, Tuple, None],
is_cascaded: bool = False,
foreground_labels: Union[Tuple[int, ...], List[int]] = None,
regions: List[Union[List[int], Tuple[int, ...], int]] = None,
ignore_label: int = None,
) -> BasicTransform:
transforms = []
transforms.append(
RemoveLabelTansform(-1, 0)
)
if is_cascaded:
transforms.append(
MoveSegAsOneHotToDataTransform(
source_channel_idx=1,
all_labels=foreground_labels,
remove_channel_from_source=True
)
)
if regions is not None:
# the ignore label must also be converted
transforms.append(
ConvertSegmentationToRegionsTransform(
regions=list(regions) + [ignore_label] if ignore_label is not None else regions,
channel_in_seg=0
)
)
if deep_supervision_scales is not None:
transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales))
return ComposeTransforms(transforms)
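Both pipelines append `ignore_label` as an extra, last region before calling `ConvertSegmentationToRegionsTransform`. Later, `validation_step` strips that last channel off the target and uses it as a mask. The following hypothetical NumPy sketch shows the label-to-region conversion. It creates one binary channel per region, where a region is either a single label or a group of labels.

```python
import numpy as np


def seg_to_regions(seg: np.ndarray, regions) -> np.ndarray:
    """Hypothetical sketch: one boolean channel per region (a label or a collection of labels)."""
    out = np.zeros((len(regions), *seg.shape), dtype=bool)
    for r, region in enumerate(regions):
        labels = [region] if isinstance(region, int) else list(region)
        out[r] = np.isin(seg, labels)
    return out


seg = np.array([[0, 1], [2, 4]])
# nested regions (e.g. whole tumor / tumor core / enhancing tumor); ignore label 4 is appended last
regions = [(1, 2, 3), (2, 3), 3]
ignore_label = 4
print(seg_to_regions(seg, list(regions) + [ignore_label]).astype(int))
```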
def set_deep_supervision_enabled(self, enabled: bool):
"""
This function is specific for the default architecture in nnU-Net. If you change the architecture, there are
chances you need to change this as well!
"""
if self.is_ddp:
mod = self.network.module
else:
mod = self.network
if isinstance(mod, OptimizedModule):
mod = mod._orig_mod
mod.decoder.deep_supervision = enabled
def on_train_start(self):
if not self.was_initialized:
self.initialize()
# dataloaders must be instantiated here (instead of __init__) because they need access to the training data
# which may not be present when doing inference
self.dataloader_train, self.dataloader_val = self.get_dataloaders()
maybe_mkdir_p(self.output_folder)
# make sure deep supervision is on in the network
self.set_deep_supervision_enabled(self.enable_deep_supervision)
self.print_plans()
empty_cache(self.device)
# maybe unpack
if self.local_rank == 0:
self.dataset_class.unpack_dataset(
self.preprocessed_dataset_folder,
overwrite_existing=False,
num_processes=max(1, round(get_allowed_n_proc_DA() // 2)),
verify=True)
if self.is_ddp:
dist.barrier()
# copy plans and dataset.json so that they can be used for restoring everything we need for inference
save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False)
save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False)
# we don't really need the fingerprint but it's still handy to have it with the others
shutil.copy(join(self.preprocessed_dataset_folder_base, 'dataset_fingerprint.json'),
join(self.output_folder_base, 'dataset_fingerprint.json'))
# produces a pdf in output folder
self.plot_network_architecture()
self._save_debug_information()
# print(f"batch size: {self.batch_size}")
# print(f"oversample: {self.oversample_foreground_percent}")
def on_train_end(self):
# dirty hack: on_epoch_end increments the epoch counter and this is executed afterwards.
# Without the decrement, the wrong current epoch would be stored in the checkpoint
self.current_epoch -= 1
self.save_checkpoint(join(self.output_folder, "checkpoint_final.pth"))
self.current_epoch += 1
# now we can delete latest
if self.local_rank == 0 and isfile(join(self.output_folder, "checkpoint_latest.pth")):
os.remove(join(self.output_folder, "checkpoint_latest.pth"))
# shut down dataloaders
old_stdout = sys.stdout
with open(os.devnull, 'w') as f:
sys.stdout = f
if self.dataloader_train is not None and \
isinstance(self.dataloader_train, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)):
self.dataloader_train._finish()
if self.dataloader_val is not None and \
isinstance(self.dataloader_val, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)):
self.dataloader_val._finish()
sys.stdout = old_stdout
empty_cache(self.device)
self.print_to_log_file("Training done.")
def on_train_epoch_start(self):
self.network.train()
self.lr_scheduler.step(self.current_epoch)
self.print_to_log_file('')
self.print_to_log_file(f'Epoch {self.current_epoch}')
self.print_to_log_file(
f"Current learning rate: {np.round(self.optimizer.param_groups[0]['lr'], decimals=5)}")
# lrs are the same for all workers so we don't need to gather them in case of DDP training
self.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch)
def train_step(self, batch: dict) -> dict:
data = batch['data']
target = batch['target']
data = data.to(self.device, non_blocking=True)
if isinstance(target, list):
target = [i.to(self.device, non_blocking=True) for i in target]
else:
target = target.to(self.device, non_blocking=True)
self.optimizer.zero_grad(set_to_none=True)
# Autocast can be annoying
# If the device_type is 'cpu' then it's slow as heck and needs to be disabled.
# If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
# So autocast will only be active if we have a cuda device.
with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
output = self.network(data)
# del data
l = self.loss(output, target)
if self.grad_scaler is not None:
self.grad_scaler.scale(l).backward()
self.grad_scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
else:
l.backward()
torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
self.optimizer.step()
return {'loss': l.detach().cpu().numpy()}
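`train_step` clips the global gradient norm to 12. With a `GradScaler`, `unscale_` runs first, so the threshold applies to the true gradients rather than the scaled ones. The following NumPy re-implementation, for illustration only, shows what `torch.nn.utils.clip_grad_norm_` does with the default L2 norm.

```python
import numpy as np


def clip_by_global_norm(grads, max_norm: float = 12.0):
    """Illustrative NumPy version of clip_grad_norm_ (L2 norm); returns clipped grads and the norm before clipping."""
    # one norm over all parameters together, not per parameter
    total_norm = float(np.sqrt(sum(float(np.sum(g ** 2)) for g in grads)))
    # scale down only if the norm exceeds max_norm (coefficient capped at 1)
    clip_coef = min(max_norm / (total_norm + 1e-6), 1.0)
    return [g * clip_coef for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([np.array([30.0, 40.0])])
print(norm_before, np.linalg.norm(clipped[0]))
```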
def on_train_epoch_end(self, train_outputs: List[dict]):
outputs = collate_outputs(train_outputs)
if self.is_ddp:
losses_tr = [None for _ in range(dist.get_world_size())]
dist.all_gather_object(losses_tr, outputs['loss'])
loss_here = np.vstack(losses_tr).mean()
else:
loss_here = np.mean(outputs['loss'])
self.logger.log('train_losses', loss_here, self.current_epoch)
def on_validation_epoch_start(self):
self.network.eval()
def validation_step(self, batch: dict) -> dict:
data = batch['data']
target = batch['target']
data = data.to(self.device, non_blocking=True)
if isinstance(target, list):
target = [i.to(self.device, non_blocking=True) for i in target]
else:
target = target.to(self.device, non_blocking=True)
# Autocast can be annoying
# If the device_type is 'cpu' then it's slow as heck and needs to be disabled.
# If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
# So autocast will only be active if we have a cuda device.
with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
output = self.network(data)
del data
l = self.loss(output, target)
# we only need the output with the highest output resolution (if DS enabled)
if self.enable_deep_supervision:
output = output[0]
target = target[0]
# the following is needed for online evaluation. Fake dice (green line)
axes = [0] + list(range(2, output.ndim))
if self.label_manager.has_regions:
predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long()
else:
# no need for softmax
output_seg = output.argmax(1)[:, None]
predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float16)
predicted_segmentation_onehot.scatter_(1, output_seg, 1)
del output_seg
if self.label_manager.has_ignore_label:
if not self.label_manager.has_regions:
mask = (target != self.label_manager.ignore_label).float()
# CAREFUL that you don't rely on target after this line!
target[target == self.label_manager.ignore_label] = 0
else:
if target.dtype == torch.bool:
mask = ~target[:, -1:]
else:
mask = 1 - target[:, -1:]
# CAREFUL that you don't rely on target after this line!
target = target[:, :-1]
else:
mask = None
tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask)
tp_hard = tp.detach().cpu().numpy()
fp_hard = fp.detach().cpu().numpy()
fn_hard = fn.detach().cpu().numpy()
if not self.label_manager.has_regions:
# if we train with regions, all segmentation heads predict some kind of foreground. In conventional
# (softmax) training there needs to be one output for the background. We are not interested in the
# background Dice
# [1:] in order to remove background
tp_hard = tp_hard[1:]
fp_hard = fp_hard[1:]
fn_hard = fn_hard[1:]
return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard}
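For the online pseudo Dice, `validation_step` converts the logits into a one-hot prediction using argmax plus `scatter_`. It then counts true positives, false positives and false negatives per class and drops the background. Below is a NumPy sketch of the same counts for the softmax case, without the ignore-label mask. It compares labels directly instead of building a one-hot tensor, which produces identical counts. The function name is hypothetical.

```python
import numpy as np


def hard_tp_fp_fn(logits: np.ndarray, target: np.ndarray, num_classes: int):
    """logits: (b, c, *spatial); target: (b, 1, *spatial) integer labels. Background (class 0) is dropped."""
    pred = logits.argmax(1)
    tgt = target[:, 0]
    tp, fp, fn = (np.zeros(num_classes, dtype=np.int64) for _ in range(3))
    for c in range(num_classes):
        p, t = pred == c, tgt == c
        tp[c] = np.sum(p & t)   # predicted c, reference c
        fp[c] = np.sum(p & ~t)  # predicted c, reference something else
        fn[c] = np.sum(~p & t)  # reference c, predicted something else
    return tp[1:], fp[1:], fn[1:]


pred_labels = np.array([[0, 1], [2, 2]])
logits = np.eye(3)[pred_labels].transpose(2, 0, 1)[None]  # shape (1, 3, 2, 2)
target = np.array([[0, 1], [1, 2]])[None, None]           # shape (1, 1, 2, 2)
print(hard_tp_fp_fn(logits, target, 3))
```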
def on_validation_epoch_end(self, val_outputs: List[dict]):
outputs_collated = collate_outputs(val_outputs)
tp = np.sum(outputs_collated['tp_hard'], 0)
fp = np.sum(outputs_collated['fp_hard'], 0)
fn = np.sum(outputs_collated['fn_hard'], 0)
if self.is_ddp:
world_size = dist.get_world_size()
tps = [None for _ in range(world_size)]
dist.all_gather_object(tps, tp)
tp = np.vstack([i[None] for i in tps]).sum(0)
fps = [None for _ in range(world_size)]
dist.all_gather_object(fps, fp)
fp = np.vstack([i[None] for i in fps]).sum(0)
fns = [None for _ in range(world_size)]
dist.all_gather_object(fns, fn)
fn = np.vstack([i[None] for i in fns]).sum(0)
losses_val = [None for _ in range(world_size)]
dist.all_gather_object(losses_val, outputs_collated['loss'])
loss_here = np.vstack(losses_val).mean()
else:
loss_here = np.mean(outputs_collated['loss'])
global_dc_per_class = [2 * i / (2 * i + j + k) for i, j, k in zip(tp, fp, fn)]
mean_fg_dice = np.nanmean(global_dc_per_class)
self.logger.log('mean_fg_dice', mean_fg_dice, self.current_epoch)
self.logger.log('dice_per_class_or_region', global_dc_per_class, self.current_epoch)
self.logger.log('val_losses', loss_here, self.current_epoch)
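`on_validation_epoch_end` sums tp/fp/fn over all validation batches, and across ranks under DDP, before computing Dice. Each class therefore gets one global Dice rather than an average of per-batch values. Classes that appear in neither prediction nor reference produce `nan` and are skipped by `np.nanmean`. A minimal sketch with a hypothetical function name:

```python
import numpy as np


def global_dice(tp, fp, fn):
    """Dice per class from pooled counts, plus the nan-aware mean over classes."""
    tp, fp, fn = (np.asarray(x, dtype=np.float64) for x in (tp, fp, fn))
    with np.errstate(divide='ignore', invalid='ignore'):
        # 0 / 0 -> nan for classes absent from both prediction and reference
        dice = 2 * tp / (2 * tp + fp + fn)
    return dice, np.nanmean(dice)


print(global_dice([2, 0], [1, 0], [1, 0]))
```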
def on_epoch_start(self):
self.logger.log('epoch_start_timestamps', time(), self.current_epoch)
def on_epoch_end(self):
self.logger.log('epoch_end_timestamps', time(), self.current_epoch)
self.print_to_log_file('train_loss', np.round(self.logger.get_value('train_losses', step=-1), decimals=4))
self.print_to_log_file('val_loss', np.round(self.logger.get_value('val_losses', step=-1), decimals=4))
self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in
self.logger.get_value('dice_per_class_or_region', step=-1)])
self.print_to_log_file(
f"Epoch time: {np.round(self.logger.get_value('epoch_end_timestamps', step=-1) - self.logger.get_value('epoch_start_timestamps', step=-1), decimals=2)} s")
# handling periodic checkpointing
current_epoch = self.current_epoch
if (current_epoch + 1) % self.save_every == 0 and current_epoch != (self.num_epochs - 1):
self.save_checkpoint(join(self.output_folder, 'checkpoint_latest.pth'))
# handle 'best' checkpointing. ema_fg_dice is computed by the logger and can be accessed like this
if self._best_ema is None or self.logger.get_value('ema_fg_dice', step=-1) > self._best_ema:
self._best_ema = self.logger.get_value('ema_fg_dice', step=-1)
self.print_to_log_file(f"Yayy! New best EMA pseudo Dice: {np.round(self._best_ema, decimals=4)}")
self.save_checkpoint(join(self.output_folder, 'checkpoint_best.pth'))
if self.local_rank == 0:
self.logger.plot_progress_png(self.output_folder)
self.current_epoch += 1
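`on_epoch_end` writes `checkpoint_latest.pth` every `save_every` epochs and skips the last epoch, which `on_train_end` covers with `checkpoint_final.pth`. It writes `checkpoint_best.pth` whenever the EMA pseudo Dice improves. The sketch below assumes that the logger smooths the mean foreground Dice as `0.9 * previous + 0.1 * current`. That factor is not visible in this file, so treat it as an assumption. All function names are hypothetical.

```python
def ema_series(values, alpha: float = 0.1):
    """EMA as we assume the logger computes it: ema = (1 - alpha) * ema_prev + alpha * value."""
    ema = []
    for v in values:
        ema.append(v if not ema else (1 - alpha) * ema[-1] + alpha * v)
    return ema


def best_checkpoint_epochs(ema):
    """Epochs at which checkpoint_best.pth would be (re)written."""
    best, epochs = None, []
    for epoch, value in enumerate(ema):
        if best is None or value > best:
            best = value
            epochs.append(epoch)
    return epochs


def periodic_checkpoint_epochs(num_epochs: int, save_every: int):
    # the last epoch is skipped because on_train_end writes checkpoint_final.pth
    return [e for e in range(num_epochs) if (e + 1) % save_every == 0 and e != num_epochs - 1]


print(periodic_checkpoint_epochs(200, 50))  # [49, 99, 149]
```

Smoothing makes "best" robust to a single lucky validation epoch. Only a sustained improvement moves the EMA past its previous maximum.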
def save_checkpoint(self, filename: str) -> None:
if self.local_rank == 0:
if not self.disable_checkpointing:
if self.is_ddp:
mod = self.network.module
else:
mod = self.network
if isinstance(mod, OptimizedModule):
mod = mod._orig_mod
checkpoint = {
'network_weights': mod.state_dict(),
'optimizer_state': self.optimizer.state_dict(),
'grad_scaler_state': self.grad_scaler.state_dict() if self.grad_scaler is not None else None,
'logging': self.logger.get_checkpoint(),
'_best_ema': self._best_ema,
'current_epoch': self.current_epoch + 1,
'init_args': self.my_init_kwargs,
'trainer_name': self.__class__.__name__,
'inference_allowed_mirroring_axes': self.inference_allowed_mirroring_axes,
}
torch.save(checkpoint, filename)
else:
self.print_to_log_file('No checkpoint written, checkpointing is disabled')
def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None:
if not self.was_initialized:
self.initialize()
if isinstance(filename_or_checkpoint, str):
checkpoint = torch.load(filename_or_checkpoint, map_location=self.device, weights_only=False)
# if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not
# match. Use heuristic to make it match
new_state_dict = {}
for k, value in checkpoint['network_weights'].items():
key = k
if key not in self.network.state_dict().keys() and key.startswith('module.'):
key = key[7:]
new_state_dict[key] = value
self.my_init_kwargs = checkpoint['init_args']
self.current_epoch = checkpoint['current_epoch']
self.logger.load_checkpoint(checkpoint['logging'])
self._best_ema = checkpoint['_best_ema']
self.inference_allowed_mirroring_axes = checkpoint[
'inference_allowed_mirroring_axes'] if 'inference_allowed_mirroring_axes' in checkpoint.keys() else self.inference_allowed_mirroring_axes
# messing with state dict naming schemes. Facepalm.
if self.is_ddp:
if isinstance(self.network.module, OptimizedModule):
self.network.module._orig_mod.load_state_dict(new_state_dict)
else:
self.network.module.load_state_dict(new_state_dict)
else:
if isinstance(self.network, OptimizedModule):
self.network._orig_mod.load_state_dict(new_state_dict)
else:
self.network.load_state_dict(new_state_dict)
self.optimizer.load_state_dict(checkpoint['optimizer_state'])
if self.grad_scaler is not None:
if checkpoint['grad_scaler_state'] is not None:
self.grad_scaler.load_state_dict(checkpoint['grad_scaler_state'])
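`load_checkpoint` strips a leading `module.` from checkpoint keys when the current network does not know the prefixed name. This covers checkpoints saved from `nn.DataParallel`-wrapped models. The sketch below applies the same heuristic to plain dictionaries. The helper name is hypothetical, and real checkpoints hold tensors.

```python
def strip_dataparallel_prefix(checkpoint_state: dict, model_keys) -> dict:
    """Rename 'module.xxx' to 'xxx' only when the target model does not know the prefixed key."""
    model_keys = set(model_keys)
    remapped = {}
    for key, value in checkpoint_state.items():
        if key not in model_keys and key.startswith('module.'):
            key = key[len('module.'):]
        remapped[key] = value
    return remapped


print(strip_dataparallel_prefix({'module.conv.weight': 1}, ['conv.weight']))  # {'conv.weight': 1}
```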
def perform_actual_validation(self, save_probabilities: bool = False):
self.set_deep_supervision_enabled(False)
self.network.eval()
if self.is_ddp and self.batch_size == 1 and self.enable_deep_supervision and self._do_i_compile():
self.print_to_log_file("WARNING! batch size is 1 during training and torch.compile is enabled. If you "
"encounter crashes in validation then this is because torch.compile forgets "
"to trigger a recompilation of the model with deep supervision disabled. "
"This causes torch.flip to complain about getting a tuple as input. Just rerun the "
"validation with --val (exactly the same as before) and then it will work. "
"Why? Because --val triggers nnU-Net to ONLY run validation meaning that the first "
"forward pass (where compile is triggered) already has deep supervision disabled. "
"This is exactly what we need in perform_actual_validation")
predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True,
perform_everything_on_device=True, device=self.device, verbose=False,
verbose_preprocessing=False, allow_tqdm=False)
predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None,
self.dataset_json, self.__class__.__name__,
self.inference_allowed_mirroring_axes)
with multiprocessing.get_context("spawn").Pool(default_num_processes) as segmentation_export_pool:
worker_list = [i for i in segmentation_export_pool._pool]
validation_output_folder = join(self.output_folder, 'validation')
maybe_mkdir_p(validation_output_folder)
# we cannot use self.get_tr_and_val_datasets() here because we might be DDP and then we have to distribute
# the validation keys across the workers.
_, val_keys = self.do_split()
if self.is_ddp:
last_barrier_at_idx = len(val_keys) // dist.get_world_size() - 1
val_keys = val_keys[self.local_rank:: dist.get_world_size()]
# we cannot just have barriers all over the place because the number of keys each GPU receives can be
# different
dataset_val = self.dataset_class(self.preprocessed_dataset_folder, val_keys,
folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage)
next_stages = self.configuration_manager.next_stage_names
if next_stages is not None:
_ = [maybe_mkdir_p(join(self.output_folder_base, 'predicted_next_stage', n)) for n in next_stages]
results = []
for i, k in enumerate(dataset_val.identifiers):
proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results,
allowed_num_queued=2)
while not proceed:
sleep(0.1)
proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results,
allowed_num_queued=2)
self.print_to_log_file(f"predicting {k}")
data, _, seg_prev, properties = dataset_val.load_case(k)
# we do [:] to convert blosc2 to numpy
data = data[:]
if self.is_cascaded:
seg_prev = seg_prev[:]
data = np.vstack((data, convert_labelmap_to_one_hot(seg_prev, self.label_manager.foreground_labels,
output_dtype=data.dtype)))
with warnings.catch_warnings():
# ignore 'The given NumPy array is not writable' warning
warnings.simplefilter("ignore")
data = torch.from_numpy(data)
self.print_to_log_file(f'{k}, shape {data.shape}, rank {self.local_rank}')
output_filename_truncated = join(validation_output_folder, k)
prediction = predictor.predict_sliding_window_return_logits(data)
prediction = prediction.cpu()
# this needs to go into background processes
results.append(
segmentation_export_pool.starmap_async(
export_prediction_from_logits, (
(prediction, properties, self.configuration_manager, self.plans_manager,
self.dataset_json, output_filename_truncated, save_probabilities),
)
)
)
# for debug purposes
# export_prediction_from_logits(
# prediction, properties, self.configuration_manager, self.plans_manager,
# self.dataset_json, output_filename_truncated, save_probabilities
# )
# if needed, export the softmax prediction for the next stage
if next_stages is not None:
for n in next_stages:
next_stage_config_manager = self.plans_manager.get_configuration(n)
expected_preprocessed_folder = join(nnUNet_preprocessed, self.plans_manager.dataset_name,
next_stage_config_manager.data_identifier)
# next stage may have a different dataset class, do not use self.dataset_class
dataset_class = infer_dataset_class(expected_preprocessed_folder)
try:
# we do this so that we can use load_case and do not have to hard code how loading training cases is implemented
tmp = dataset_class(expected_preprocessed_folder, [k])
d, _, _, _ = tmp.load_case(k)
except FileNotFoundError:
self.print_to_log_file(
f"Predicting next stage {n} failed for case {k} because the preprocessed file is missing! "
f"Run the preprocessing for this configuration first!")
continue
target_shape = d.shape[1:]
output_folder = join(self.output_folder_base, 'predicted_next_stage', n)
output_file_truncated = join(output_folder, k)
# resample_and_save(prediction, target_shape, output_file_truncated, self.plans_manager,
# self.configuration_manager,
# properties,
# self.dataset_json,
# default_num_processes,
# dataset_class)
results.append(segmentation_export_pool.starmap_async(
resample_and_save, (
(prediction, target_shape, output_file_truncated, self.plans_manager,
self.configuration_manager,
properties,
self.dataset_json,
default_num_processes,
dataset_class),
)
))
# if we don't barrier from time to time we will get nccl timeouts for large datasets. Yuck.
if self.is_ddp and i < last_barrier_at_idx and (i + 1) % 20 == 0:
dist.barrier()
_ = [r.get() for r in results]
if self.is_ddp:
dist.barrier()
if self.local_rank == 0:
metrics = compute_metrics_on_folder(join(self.preprocessed_dataset_folder_base, 'gt_segmentations'),
validation_output_folder,
join(validation_output_folder, 'summary.json'),
self.plans_manager.image_reader_writer_class(),
self.dataset_json["file_ending"],
self.label_manager.foreground_regions if self.label_manager.has_regions else
self.label_manager.foreground_labels,
self.label_manager.ignore_label, chill=True,
num_processes=default_num_processes * dist.get_world_size() if
self.is_ddp else default_num_processes)
for label in metrics["mean"]:
self.logger.log_summary(f"final_val/class_{label}_dice", metrics["mean"][label]["Dice"])
self.logger.log_summary("final_val/foreground_dice", metrics['foreground_mean']["Dice"])
self.print_to_log_file("Validation complete", also_print_to_console=True)
self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]),
also_print_to_console=True)
self.set_deep_supervision_enabled(True)
compute_gaussian.cache_clear()
def run_training(self):
self.on_train_start()
for epoch in range(self.current_epoch, self.num_epochs):
self.on_epoch_start()
self.on_train_epoch_start()
train_outputs = []
for batch_id in range(self.num_iterations_per_epoch):
train_outputs.append(self.train_step(next(self.dataloader_train)))
self.on_train_epoch_end(train_outputs)
with torch.no_grad():
self.on_validation_epoch_start()
val_outputs = []
for batch_id in range(self.num_val_iterations_per_epoch):
val_outputs.append(self.validation_step(next(self.dataloader_val)))
self.on_validation_epoch_end(val_outputs)
self.on_epoch_end()
self.on_train_end()
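`run_training` does not iterate over the dataset once per epoch. Each epoch is a fixed number of `next()` calls on endless data loaders, and the loop starts at `current_epoch` so that a resumed run picks up where it stopped. Below is a minimal pure-Python sketch of that control flow. `MiniTrainer` is hypothetical, `itertools.cycle` stands in for the augmenters, and `torch.no_grad()` is left out so the sketch has no torch dependency. The sketch also assumes the epoch counter is advanced in an end-of-epoch hook.

```python
from itertools import cycle
from typing import Iterator, List


class MiniTrainer:
    """Hypothetical, framework-free stand-in for the run_training control flow."""

    def __init__(self, num_epochs: int, iters_per_epoch: int, val_iters_per_epoch: int):
        self.current_epoch = 0
        self.num_epochs = num_epochs
        self.num_iterations_per_epoch = iters_per_epoch
        self.num_val_iterations_per_epoch = val_iters_per_epoch
        # endless iterators: an "epoch" is a fixed number of next() calls, not one pass over the data
        self.dataloader_train: Iterator[int] = cycle([1, 2, 3])
        self.dataloader_val: Iterator[int] = cycle([10, 20])
        self.log: List[str] = []

    def train_step(self, batch: int) -> dict:
        return {"loss": float(batch)}

    def validation_step(self, batch: int) -> dict:
        return {"loss": float(batch)}

    def on_epoch_end(self) -> None:
        # assumption for this sketch: the epoch counter advances at the end of every epoch
        self.current_epoch += 1

    def run_training(self) -> None:
        self.log.append("train_start")
        # starting at current_epoch lets a resumed run skip the epochs it already finished
        for epoch in range(self.current_epoch, self.num_epochs):
            train_outputs = [self.train_step(next(self.dataloader_train))
                             for _ in range(self.num_iterations_per_epoch)]
            val_outputs = [self.validation_step(next(self.dataloader_val))
                           for _ in range(self.num_val_iterations_per_epoch)]
            self.log.append(f"epoch {epoch}: {len(train_outputs)} train, {len(val_outputs)} val")
            self.on_epoch_end()
        self.log.append("train_end")
```

Setting `current_epoch = 1` on a 3-epoch run executes only epochs 1 and 2. This is the behavior that lets a run resume from a checkpoint.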
================================================
FILE: nnunetv2/training/nnUNetTrainer/primus/__init__.py
================================================
================================================
FILE: nnunetv2/training/nnUNetTrainer/primus/primus_trainers.py
================================================
from abc import abstractmethod
from typing import List, Tuple, Union
import torch
from torch import nn, autocast
from dynamic_network_architectures.architectures.primus import Primus
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.training.nnUNetTrainer.variants.lr_schedule.nnUNetTrainer_warmup import nnUNetTrainer_warmup
from torch.nn.parallel import DistributedDataParallel as DDP
from nnunetv2.training.lr_scheduler.warmup import Lin_incr_LRScheduler, PolyLRScheduler_offset
from nnunetv2.utilities.helpers import empty_cache, dummy_context
######################################################
# See this paper for information on Primus!
# Wald*, T., Roy*, S., Isensee*, F., Ulrich, C., Ziegler, S., Trofimova, D., ... & Maier-Hein, K. (2025). Primus: Enforcing attention usage for 3d medical image segmentation. arXiv preprint arXiv:2503.01835.
# * equal contribution
######################################################
class AbstractPrimus(nnUNetTrainer_warmup):
def __init__(
self,
plans: dict,
configuration: str,
fold: int,
dataset_json: dict,
device: torch.device = torch.device("cuda"),
):
super().__init__(plans, configuration, fold, dataset_json, device)
self.initial_lr = 3e-4
self.weight_decay = 5e-2
self.enable_deep_supervision = False
@abstractmethod
def build_network_architecture(
self,
architecture_class_name: str,
arch_init_kwargs: dict,
arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
num_input_channels: int,
num_output_channels: int,
enable_deep_supervision: bool = True,
) -> nn.Module:
raise NotImplementedError()
def configure_optimizers(self, stage: str = "warmup_all"):
assert stage in ["warmup_all", "train"]
if self.training_stage == stage:
return self.optimizer, self.lr_scheduler
if isinstance(self.network, DDP):
params = self.network.module.parameters()
else:
params = self.network.parameters()
if stage == "warmup_all":
self.print_to_log_file("train whole net, warmup")
optimizer = torch.optim.AdamW(
params, self.initial_lr, weight_decay=self.weight_decay, amsgrad=False, betas=(0.9, 0.98), fused=True
)
lr_scheduler = Lin_incr_LRScheduler(optimizer, self.initial_lr, self.warmup_duration_whole_net)
self.print_to_log_file(f"Initialized warmup_all optimizer and lr_scheduler at epoch {self.current_epoch}")
else:
self.print_to_log_file("train whole net, default schedule")
if self.training_stage == "warmup_all":
# we can keep the existing optimizer and don't need to create a new one. This allows us to keep
# the accumulated momentum terms, which already point in a useful direction
optimizer = self.optimizer
else:
optimizer = torch.optim.AdamW(
params,
self.initial_lr,
weight_decay=self.weight_decay,
amsgrad=False,
betas=(0.9, 0.98),
fused=True,
)
lr_scheduler = PolyLRScheduler_offset(
optimizer, self.initial_lr, self.num_epochs, self.warmup_duration_whole_net
)
self.print_to_log_file(f"Initialized train optimizer and lr_scheduler at epoch {self.current_epoch}")
self.training_stage = stage
empty_cache(self.device)
return optimizer, lr_scheduler
def train_step(self, batch: dict) -> dict:
data = batch["data"]
target = batch["target"]
data = data.to(self.device, non_blocking=True)
if isinstance(target, list):
target = [i.to(self.device, non_blocking=True) for i in target]
else:
target = target.to(self.device, non_blocking=True)
self.optimizer.zero_grad(set_to_none=True)
# Autocast can be annoying
# If the device_type is 'cpu' then it's slow as heck and needs to be disabled.
# If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
# So autocast will only be active if we have a cuda device.
with autocast(self.device.type, enabled=True) if self.device.type == "cuda" else dummy_context():
output = self.network(data)
# del data
l = self.loss(output, target)
if self.grad_scaler is not None:
self.grad_scaler.scale(l).backward()
self.grad_scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1)
self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
else:
l.backward()
torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1)
self.optimizer.step()
return {"loss": l.detach().cpu().numpy()}
def set_deep_supervision_enabled(self, enabled: bool):
pass
class nnUNet_Primus_S_Trainer(AbstractPrimus):
def build_network_architecture(
self,
architecture_class_name: str,
arch_init_kwargs: dict,
arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
num_input_channels: int,
num_output_channels: int,
enable_deep_supervision: bool = True,
) -> nn.Module:
# this architecture will crash if the patch size is not divisible by 8!
model = Primus(
num_input_channels,
396,
(8, 8, 8),
num_output_channels,
12,
6,
self.configuration_manager.patch_size,
drop_path_rate=0.2,
scale_attn_inner=True,
init_values=0.1,
)
return model
class nnUNet_Primus_B_Trainer(AbstractPrimus):
def build_network_architecture(
self,
architecture_class_name: str,
arch_init_kwargs: dict,
arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
num_input_channels: int,
num_output_channels: int,
enable_deep_supervision: bool = True,
) -> nn.Module:
# this architecture will crash if the patch size is not divisible by 8!
model = Primus(
num_input_channels,
792,
(8, 8, 8),
num_output_channels,
12,
12,
self.configuration_manager.patch_size,
drop_path_rate=0.2,
scale_attn_inner=True,
init_values=0.1,
)
return model
class nnUNet_Primus_M_Trainer(AbstractPrimus):
def build_network_architecture(
self,
architecture_class_name: str,
arch_init_kwargs: dict,
arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
num_input_channels: int,
num_output_channels: int,
enable_deep_supervision: bool = True,
) -> nn.Module:
# this architecture will crash if the patch size is not divisible by 8!
model = Primus(
num_input_channels,
864,
(8, 8, 8),
num_output_channels,
16,
12,
self.configuration_manager.patch_size,
drop_path_rate=0.2,
scale_attn_inner=True,
init_values=0.1,
)
return model
class nnUNet_Primus_M_Trainer_BS8(nnUNet_Primus_M_Trainer):
def __init__(
self,
plans: dict,
configuration: str,
fold: int,
dataset_json: dict,
device: torch.device = torch.device("cuda"),
):
super().__init__(plans, configuration, fold, dataset_json, device)
self.configuration_manager.configuration["batch_size"] = 8
class nnUNet_Primus_M_Trainer_BS8_2e4(nnUNet_Primus_M_Trainer):
def __init__(
self,
plans: dict,
configuration: str,
fold: int,
dataset_json: dict,
device: torch.device = torch.device("cuda"),
):
super().__init__(plans, configuration, fold, dataset_json, device)
self.initial_lr = 2e-4
self.configuration_manager.configuration["batch_size"] = 8
class nnUNet_Trainer_BS8(nnUNetTrainer):
def __init__(
self,
plans: dict,
configuration: str,
fold: int,
dataset_json: dict,
device: torch.device = torch.device("cuda"),
):
super().__init__(plans, configuration, fold, dataset_json, device)
self.configuration_manager.configuration["batch_size"] = 8
class nnUNet_Primus_L_Trainer(AbstractPrimus):
def build_network_architecture(
self,
architecture_class_name: str,
arch_init_kwargs: dict,
arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
num_input_channels: int,
num_output_channels: int,
enable_deep_supervision: bool = True,
) -> nn.Module:
# this architecture will crash if the patch size is not divisible by 8!
model = Primus(
num_input_channels,
1056,
(8, 8, 8),
num_output_channels,
24,
16,
self.configuration_manager.patch_size,
drop_path_rate=0.2,
scale_attn_inner=True,
init_values=0.1,
)
return model
class _Primus_S_96_BS1(nnUNet_Primus_S_Trainer):
def __init__(
self,
plans: dict,
configuration: str,
fold: int,
dataset_json: dict,
device: torch.device = torch.device("cuda"),
):
plans["configurations"][configuration]["patch_size"] = (96, 96, 96) # As per repository
plans["configurations"][configuration]["batch_size"] = 1
super().__init__(plans, configuration, fold, dataset_json, device)
class _Primus_B_96_BS1(nnUNet_Primus_B_Trainer):
def __init__(
self,
plans: dict,
configuration: str,
fold: int,
dataset_json: dict,
device: torch.device = torch.device("cuda"),
):
plans["configurations"][configuration]["patch_size"] = (96, 96, 96) # As per repository
plans["configurations"][configuration]["batch_size"] = 1
super().__init__(plans, configuration, fold, dataset_json, device)
class _Primus_M_96_BS1(nnUNet_Primus_M_Trainer):
def __init__(
self,
plans: dict,
configuration: str,
fold: int,
dataset_json: dict,
device: torch.device = torch.device("cuda"),
):
plans["configurations"][configuration]["patch_size"] = (96, 96, 96) # As per repository
plans["configurations"][configuration]["batch_size"] = 1
super().__init__(plans, configuration, fold, dataset_json, device)
class _Primus_L_48_BS1(nnUNet_Primus_L_Trainer):
def __init__(
self,
plans: dict,
configuration: str,
fold: int,
dataset_json: dict,
device: torch.device = torch.device("cuda"),
):
plans["configurations"][configuration]["patch_size"] = (48, 48, 48) # As per repository
plans["configurations"][configuration]["batch_size"] = 1
super().__init__(plans, configuration, fold, dataset_json, device)
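Each Primus variant above warns that it crashes if the patch size is not divisible by 8. The sketch below makes that constraint explicit and shows how many tokens a given patch size produces. It assumes the `(8, 8, 8)` argument passed to `Primus` acts as a non-overlapping patch-embedding size. `primus_token_grid` is a hypothetical helper and is not part of `dynamic_network_architectures`.

```python
from math import prod
from typing import Sequence, Tuple


def primus_token_grid(patch_size: Sequence[int],
                      embed_patch: Sequence[int] = (8, 8, 8)) -> Tuple[Tuple[int, ...], int]:
    """Return the token grid and total token count for a non-overlapping patch embedding.

    Hypothetical helper. Assumption: every axis of the input patch must be divisible by the
    corresponding axis of the embedding patch, otherwise the tokens do not tile the input.
    """
    if len(patch_size) != len(embed_patch):
        raise ValueError("patch_size and embed_patch must have the same number of axes")
    if any(p % e != 0 for p, e in zip(patch_size, embed_patch)):
        raise ValueError(f"patch_size {tuple(patch_size)} is not divisible by {tuple(embed_patch)}")
    grid = tuple(p // e for p, e in zip(patch_size, embed_patch))
    return grid, prod(grid)
```

For example, the `(96, 96, 96)` patch used by the `_Primus_*_96_BS1` trainers gives a 12x12x12 grid of 1728 tokens. The `(48, 48, 48)` patch of `_Primus_L_48_BS1` gives only 216 tokens. Self-attention cost grows with the square of the token count.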
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/__init__.py
================================================
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/benchmarking/__init__.py
================================================
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py
================================================
import subprocess
import torch
from batchgenerators.utilities.file_and_folder_operations import save_json, join, isfile, load_json
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from torch import distributed as dist
class nnUNetTrainerBenchmark_5epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
assert self.fold == 0, "It makes absolutely no sense to specify a certain fold. Stick with 0 so that we can parse the results."
self.disable_checkpointing = True
self.num_epochs = 5
assert torch.cuda.is_available(), "This only works on GPU"
self.crashed_with_runtime_error = False
def perform_actual_validation(self, save_probabilities: bool = False):
pass
def save_checkpoint(self, filename: str) -> None:
# do not trust people to remember that self.disable_checkpointing must be True for this trainer
pass
def run_training(self):
try:
super().run_training()
except RuntimeError:
self.crashed_with_runtime_error = True
self.on_train_end()
def on_train_end(self):
super().on_train_end()
if not self.is_ddp or self.local_rank == 0:
torch_version = torch.__version__
cudnn_version = torch.backends.cudnn.version()
gpu_name = torch.cuda.get_device_name()
if self.crashed_with_runtime_error:
fastest_epoch = 'Not enough VRAM!'
else:
epoch_times = [i - j for i, j in zip(self.logger.get_value('epoch_end_timestamps', step=None),
self.logger.get_value('epoch_start_timestamps', step=None))]
fastest_epoch = min(epoch_times)
if self.is_ddp:
num_gpus = dist.get_world_size()
else:
num_gpus = 1
benchmark_result_file = join(self.output_folder, 'benchmark_result.json')
if isfile(benchmark_result_file):
old_results = load_json(benchmark_result_file)
else:
old_results = {}
# generate some unique key
hostname = subprocess.getoutput('hostname')
my_key = f"{hostname}__{cudnn_version}__{torch_version.replace(' ', '')}__{gpu_name.replace(' ', '')}__num_gpus_{num_gpus}"
old_results[my_key] = {
'torch_version': torch_version,
'cudnn_version': cudnn_version,
'gpu_name': gpu_name,
'fastest_epoch': fastest_epoch,
'num_gpus': num_gpus,
'hostname': hostname
}
save_json(old_results, benchmark_result_file)
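`on_train_end` reduces the per-epoch start and end timestamps to the fastest epoch. It then merges that value into a JSON file keyed by host, cuDNN version, torch version, GPU name and GPU count, so results from different machines can accumulate in one file. The sketch below reproduces that bookkeeping with only the standard library. The function names are hypothetical, and `json` plus `os` replace the `batchgenerators` helpers `load_json`/`save_json`/`isfile`.

```python
import json
import os
from typing import Dict, List, Union


def fastest_epoch(start_ts: List[float], end_ts: List[float]) -> float:
    # pair each epoch's end with its start; zip stops at the shorter list,
    # so a trailing start timestamp without a matching end is ignored
    return min(end - start for start, end in zip(start_ts, end_ts))


def benchmark_key(hostname: str, cudnn_version: int, torch_version: str, gpu_name: str, num_gpus: int) -> str:
    # same layout as my_key above: spaces are stripped from the torch version and GPU name
    return (f"{hostname}__{cudnn_version}__{torch_version.replace(' ', '')}__"
            f"{gpu_name.replace(' ', '')}__num_gpus_{num_gpus}")


def merge_result(path: str, key: str, entry: Dict[str, Union[str, int, float]]) -> dict:
    # load-modify-save, so entries written by other machines or GPU counts are kept
    results = {}
    if os.path.isfile(path):
        with open(path) as f:
            results = json.load(f)
    results[key] = entry
    with open(path, "w") as f:
        json.dump(results, f, indent=4)
    return results
```

Taking the minimum instead of the mean keeps the number robust to the slow first epoch, which includes warm-up effects such as cuDNN autotuning.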
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py
================================================
import torch
from nnunetv2.training.nnUNetTrainer.variants.benchmarking.nnUNetTrainerBenchmark_5epochs import (
nnUNetTrainerBenchmark_5epochs,
)
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
class nnUNetTrainerBenchmark_5epochs_noDataLoading(nnUNetTrainerBenchmark_5epochs):
def __init__(
self,
plans: dict,
configuration: str,
fold: int,
dataset_json: dict,
device: torch.device = torch.device("cuda"),
):
super().__init__(plans, configuration, fold, dataset_json, device)
self._set_batch_size_and_oversample()
num_input_channels = determine_num_input_channels(
self.plans_manager, self.configuration_manager, self.dataset_json
)
patch_size = self.configuration_manager.patch_size
dummy_data = torch.rand((self.batch_size, num_input_channels, *patch_size), device=self.device)
if self.enable_deep_supervision:
dummy_target = [
torch.round(
torch.rand((self.batch_size, 1, *[int(i * j) for i, j in zip(patch_size, k)]), device=self.device)
* max(self.label_manager.all_labels)
)
for k in self._get_deep_supervision_scales()
]
else:
raise NotImplementedError("This trainer requires deep supervision to be enabled")
self.dummy_batch = {"data": dummy_data, "target": dummy_target}
def get_dataloaders(self):
return None, None
def run_training(self):
try:
self.on_train_start()
for epoch in range(self.current_epoch, self.num_epochs):
self.on_epoch_start()
self.on_train_epoch_start()
train_outputs = []
for batch_id in range(self.num_iterations_per_epoch):
train_outputs.append(self.train_step(self.dummy_batch))
self.on_train_epoch_end(train_outputs)
with torch.no_grad():
self.on_validation_epoch_start()
val_outputs = []
for batch_id in range(self.num_val_iterations_per_epoch):
val_outputs.append(self.validation_step(self.dummy_batch))
self.on_validation_epoch_end(val_outputs)
self.on_epoch_end()
self.on_train_end()
except RuntimeError:
self.crashed_with_runtime_error = True
self.on_train_end()
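The no-data-loading benchmark builds one synthetic deep-supervision batch up front. Each target is downscaled by its supervision scale (`int(i * j)` per axis) and filled with random values rounded into the label range. The sketch below shows the same shape logic in NumPy so it runs without a GPU. `dummy_ds_targets` is a hypothetical helper, and the scales in the usage example are illustrative, not values taken from a plans file.

```python
from typing import List, Sequence

import numpy as np


def dummy_ds_targets(batch_size: int, patch_size: Sequence[int],
                     ds_scales: Sequence[Sequence[float]], max_label: int,
                     seed: int = 0) -> List[np.ndarray]:
    """Build one random target per deep-supervision output (hypothetical NumPy port)."""
    rng = np.random.default_rng(seed)
    targets = []
    for scale in ds_scales:
        # each output resolution is the patch size multiplied by its scale, truncated to int
        shape = [int(p * s) for p, s in zip(patch_size, scale)]
        # uniform [0, 1) scaled to the largest label and rounded, mirroring the torch code above
        targets.append(np.round(rng.random((batch_size, 1, *shape)) * max_label))
    return targets
```

With a `(64, 64, 64)` patch and scales `1, 0.5, 0.25`, this yields targets of spatial size 64, 32 and 16. The same pattern of tensors reaches the loss in a real deep-supervision run, but with no dataloader cost.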
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/competitions/__init__.py
================================================
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/competitions/aortaseg24.py
================================================
from nnunetv2.training.nnUNetTrainer.variants.data_augmentation.nnUNetTrainerNoMirroring import nnUNetTrainer_onlyMirror01
from nnunetv2.training.nnUNetTrainer.variants.data_augmentation.nnUNetTrainerDA5 import nnUNetTrainerDA5
class nnUNetTrainer_onlyMirror01_DA5(nnUNetTrainer_onlyMirror01, nnUNetTrainerDA5):
pass
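`nnUNetTrainer_onlyMirror01_DA5` has an empty body. Its behavior comes entirely from Python's method resolution order (MRO), which here is the mirroring trainer first, then DA5, then the shared base. The toy classes below are hypothetical stand-ins, not the real trainers. They assume the mirroring class delegates to `super()` and then narrows the mirror axes. The example shows how a cooperative `super()` call chains through both parents, while a method defined only in DA5 is inherited unchanged.

```python
class BaseTrainer:
    def configure_da(self) -> dict:
        return {"rotation": "default", "mirror_axes": (0, 1, 2)}

    def get_training_transforms(self) -> str:
        return "default transforms"


class OnlyMirror01Like(BaseTrainer):
    def configure_da(self) -> dict:
        # super() resolves to the next class in the MRO of the *instance*, here DA5Like
        cfg = super().configure_da()
        cfg["mirror_axes"] = (0, 1)
        return cfg


class DA5Like(BaseTrainer):
    def configure_da(self) -> dict:
        cfg = super().configure_da()
        cfg["rotation"] = "DA5"
        return cfg

    def get_training_transforms(self) -> str:
        return "DA5 transforms"


class Combined(OnlyMirror01Like, DA5Like):
    pass  # no body needed: behavior comes from the MRO
```

`Combined().configure_da()` returns the DA5 rotation setting with mirroring restricted to axes 0 and 1. `get_training_transforms` falls through to `DA5Like` because `OnlyMirror01Like` does not define it.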
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/data_augmentation/__init__.py
================================================
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py
================================================
from typing import List, Union, Tuple
import numpy as np
import torch
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
from batchgenerators.transforms.abstract_transforms import AbstractTransform
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.transforms.color_transforms import BrightnessTransform, ContrastAugmentationTransform, \
GammaTransform
from batchgenerators.transforms.local_transforms import BrightnessGradientAdditiveTransform, LocalGammaTransform
from batchgenerators.transforms.noise_transforms import MedianFilterTransform, GaussianBlurTransform, \
GaussianNoiseTransform, BlankRectangleTransform, SharpeningTransform
from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
from batchgenerators.transforms.spatial_transforms import SpatialTransform, Rot90Transform, TransposeAxesTransform, \
MirrorTransform
from batchgenerators.transforms.utility_transforms import OneOfTransform, RemoveLabelTransform, RenameTransform, \
NumpyToTensor
from batchgeneratorsv2.helpers.scalar_type import RandomScalar
from torch import autocast
from nnunetv2.configuration import ANISO_THRESHOLD
from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size
from nnunetv2.training.data_augmentation.custom_transforms.cascade_transforms import MoveSegAsOneHotToData, \
ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform
from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \
DownsampleSegForDSTransform2
from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform
from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \
ConvertSegmentationToRegionsTransform
from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert3DTo2DTransform, \
Convert2DTo3DTransform
from nnunetv2.training.dataloading.data_loader import nnUNetDataLoader
from nnunetv2.training.loss.dice import get_tp_fp_fn_tn
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
from nnunetv2.utilities.helpers import dummy_context
class TensorToNumpy(AbstractTransform):
def __init__(self, keys=None, cast_to=None):
"""
Converts torch tensors to numpy ndarrays.
:param keys: specify keys to be converted. If None then all tensor values will be converted.
Can be a key (string) or a list/tuple of keys.
:param cast_to: optional numpy dtype as string, e.g. 'float32', 'float16', 'int64', 'bool'
"""
if keys is not None and not isinstance(keys, (list, tuple)):
keys = [keys]
self.keys = keys
self.cast_to = cast_to
def cast(self, array: np.ndarray):
if self.cast_to is not None:
try:
array = array.astype(self.cast_to, copy=False)
except TypeError:
raise ValueError(f"Unknown value for cast_to: {self.cast_to}")
return array
def _to_numpy(self, tensor):
if isinstance(tensor, torch.Tensor):
# Important: detach + move to CPU before numpy conversion
array = tensor.detach().cpu().numpy()
return self.cast(array)
return tensor
def __call__(self, **data_dict):
if self.keys is None:
for key, val in data_dict.items():
if isinstance(val, torch.Tensor):
data_dict[key] = self._to_numpy(val)
elif isinstance(val, (list, tuple)) and all(isinstance(i, torch.Tensor) for i in val):
data_dict[key] = [self._to_numpy(i) for i in val]
else:
for key in self.keys:
val = data_dict[key]
if isinstance(val, torch.Tensor):
data_dict[key] = self._to_numpy(val)
elif isinstance(val, (list, tuple)) and all(isinstance(i, torch.Tensor) for i in val):
data_dict[key] = [self._to_numpy(i) for i in val]
return data_dict
class nnUNetTrainerDA5(nnUNetTrainer):
def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
patch_size = self.configuration_manager.patch_size
dim = len(patch_size)
# todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation)
if dim == 2:
do_dummy_2d_data_aug = False
# todo revisit this parametrization
if max(patch_size) / min(patch_size) > 1.5:
rotation_for_DA = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
else:
rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
mirror_axes = (0, 1)
elif dim == 3:
# todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad
# order of the axes is determined by spacing, not image size
do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD
if do_dummy_2d_data_aug:
# why do we rotate 180 deg here all the time? We should also restrict it
rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
else:
rotation_for_DA = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
mirror_axes = (0, 1, 2)
else:
raise RuntimeError()
# todo this function is stupid. It doesn't even use the correct scale range (we keep things as they were in the
# old nnunet for now)
initial_patch_size = get_patch_size(patch_size[-dim:],
rotation_for_DA,
rotation_for_DA,
rotation_for_DA,
(0.7, 1.43))
if do_dummy_2d_data_aug:
initial_patch_size[0] = patch_size[0]
self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}')
self.inference_allowed_mirroring_axes = mirror_axes
return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
@staticmethod
def get_training_transforms(
patch_size: Union[np.ndarray, Tuple[int]],
rotation_for_DA: RandomScalar,
deep_supervision_scales: Union[List, Tuple, None],
mirror_axes: Tuple[int, ...],
do_dummy_2d_data_aug: bool,
use_mask_for_norm: List[bool] = None,
is_cascaded: bool = False,
foreground_labels: Union[Tuple[int, ...], List[int]] = None,
regions: List[Union[List[int], Tuple[int, ...], int]] = None,
ignore_label: int = None,
) -> AbstractTransform:
matching_axes = np.array([sum([i == j for j in patch_size]) for i in patch_size])
valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0])
tr_transforms = []
tr_transforms.append(TensorToNumpy())
tr_transforms.append(RenameTransform('target', 'seg', True))
if do_dummy_2d_data_aug:
ignore_axes = (0,)
tr_transforms.append(Convert3DTo2DTransform())
patch_size_spatial = patch_size[1:]
else:
patch_size_spatial = patch_size
ignore_axes = None
tr_transforms.append(
SpatialTransform(
patch_size_spatial,
patch_center_dist_from_border=None,
do_elastic_deform=False,
do_rotation=True,
angle_x=rotation_for_DA,
angle_y=rotation_for_DA,
angle_z=rotation_for_DA,
p_rot_per_axis=0.5,
do_scale=True,
scale=(0.7, 1.43),
border_mode_data="constant",
border_cval_data=0,
order_data=3,
border_mode_seg="constant",
border_cval_seg=-1,
order_seg=1,
random_crop=False,
p_el_per_sample=0.2,
p_scale_per_sample=0.2,
p_rot_per_sample=0.4,
independent_scale_for_each_axis=True,
)
)
if do_dummy_2d_data_aug:
tr_transforms.append(Convert2DTo3DTransform())
if np.any(matching_axes > 1):
tr_transforms.append(
Rot90Transform(
(0, 1, 2, 3), axes=valid_axes, data_key='data', label_key='seg', p_per_sample=0.5
),
)
if np.any(matching_axes > 1):
tr_transforms.append(
TransposeAxesTransform(valid_axes, data_key='data', label_key='seg', p_per_sample=0.5)
)
tr_transforms.append(OneOfTransform([
MedianFilterTransform(
(2, 8),
same_for_each_channel=False,
p_per_sample=0.2,
p_per_channel=0.5
),
GaussianBlurTransform((0.3, 1.5),
different_sigma_per_channel=True,
p_per_sample=0.2,
p_per_channel=0.5)
]))
tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
tr_transforms.append(BrightnessTransform(0,
0.5,
per_channel=True,
p_per_sample=0.1,
p_per_channel=0.5
)
)
tr_transforms.append(OneOfTransform(
[
ContrastAugmentationTransform(
contrast_range=(0.5, 2),
preserve_range=True,
per_channel=True,
data_key='data',
p_per_sample=0.2,
p_per_channel=0.5
),
ContrastAugmentationTransform(
contrast_range=(0.5, 2),
preserve_range=False,
per_channel=True,
data_key='data',
p_per_sample=0.2,
p_per_channel=0.5
),
]
))
tr_transforms.append(
SimulateLowResolutionTransform(zoom_range=(0.25, 1),
per_channel=True,
p_per_channel=0.5,
order_downsample=0,
order_upsample=3,
p_per_sample=0.15,
ignore_axes=ignore_axes
)
)
tr_transforms.append(
GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1))
tr_transforms.append(
GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1))
if mirror_axes is not None and len(mirror_axes) > 0:
tr_transforms.append(MirrorTransform(mirror_axes))
tr_transforms.append(
BlankRectangleTransform([[max(1, p // 10), p // 3] for p in patch_size],
rectangle_value=np.mean,
num_rectangles=(1, 5),
force_square=False,
p_per_sample=0.4,
p_per_channel=0.5
)
)
tr_transforms.append(
BrightnessGradientAdditiveTransform(
_brightnessadditive_localgamma_transform_scale,
(-0.5, 1.5),
max_strength=_brightness_gradient_additive_max_strength,
mean_centered=False,
same_for_all_channels=False,
p_per_sample=0.3,
p_per_channel=0.5
)
)
tr_transforms.append(
LocalGammaTransform(
_brightnessadditive_localgamma_transform_scale,
(-0.5, 1.5),
_local_gamma_gamma,
same_for_all_channels=False,
p_per_sample=0.3,
p_per_channel=0.5
)
)
tr_transforms.append(
SharpeningTransform(
strength=(0.1, 1),
same_for_each_channel=False,
p_per_sample=0.2,
p_per_channel=0.5
)
)
if use_mask_for_norm is not None and any(use_mask_for_norm):
tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],
mask_idx_in_seg=0, set_outside_to=0))
tr_transforms.append(RemoveLabelTransform(-1, 0))
if is_cascaded:
if ignore_label is not None:
raise NotImplementedError('ignore label not yet supported in cascade')
assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations'
use_labels = [i for i in foreground_labels if i != 0]
tr_transforms.append(MoveSegAsOneHotToData(1, use_labels, 'seg', 'data'))
tr_transforms.append(ApplyRandomBinaryOperatorTransform(
channel_idx=list(range(-len(use_labels), 0)),
p_per_sample=0.4,
key="data",
strel_size=(1, 8),
p_per_label=1))
tr_transforms.append(
RemoveRandomConnectedComponentFromOneHotEncodingTransform(
channel_idx=list(range(-len(use_labels), 0)),
key="data",
p_per_sample=0.2,
fill_with_other_class_p=0,
dont_do_if_covers_more_than_x_percent=0.15))
tr_transforms.append(RenameTransform('seg', 'target', True))
if regions is not None:
# the ignore label must also be converted
tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
if ignore_label is not None else regions,
'target', 'target'))
if deep_supervision_scales is not None:
tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
output_key='target'))
tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
tr_transforms = Compose(tr_transforms)
return tr_transforms
@staticmethod
def get_validation_transforms(
deep_supervision_scales: Union[List, Tuple, None],
is_cascaded: bool = False,
foreground_labels: Union[Tuple[int, ...], List[int]] = None,
regions: List[Union[List[int], Tuple[int, ...], int]] = None,
ignore_label: int = None,
) -> AbstractTransform:
val_transforms = []
val_transforms.append(TensorToNumpy())
val_transforms.append(RenameTransform('target', 'seg', True))
val_transforms.append(RemoveLabelTransform(-1, 0))
if is_cascaded:
val_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data'))
val_transforms.append(RenameTransform('seg', 'target', True))
if regions is not None:
# the ignore label must also be converted
val_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
if ignore_label is not None else regions,
'target', 'target'))
if deep_supervision_scales is not None:
val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
output_key='target'))
val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
val_transforms = Compose(val_transforms)
return val_transforms
def get_dataloaders(self):
# we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether
# we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be
patch_size = self.configuration_manager.patch_size
# needed for deep supervision: how much do we need to downscale the segmentation targets for the different
# outputs?
deep_supervision_scales = self._get_deep_supervision_scales()
(
rotation_for_DA,
do_dummy_2d_data_aug,
initial_patch_size,
mirror_axes,
) = self.configure_rotation_dummyDA_mirroring_and_inital_patch_size()
# training pipeline
tr_transforms = self.get_training_transforms(
patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug,
use_mask_for_norm=self.configuration_manager.use_mask_for_norm,
is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels,
regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None,
ignore_label=self.label_manager.ignore_label)
# validation pipeline
val_transforms = self.get_validation_transforms(deep_supervision_scales,
is_cascaded=self.is_cascaded,
foreground_labels=self.label_manager.foreground_labels,
regions=self.label_manager.foreground_regions if
self.label_manager.has_regions else None,
ignore_label=self.label_manager.ignore_label)
dataset_tr, dataset_val = self.get_tr_and_val_datasets()
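# we set transforms=None because this trainer still uses batchgenerators, which expects the transforms to be passed to
# the augmenter (SingleThreadedAugmenter / NonDetMultiThreadedAugmenter below) instead of the data loader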
dl_tr = nnUNetDataLoader(dataset_tr, self.batch_size,
initial_patch_size,
self.configuration_manager.patch_size,
self.label_manager,
oversample_foreground_percent=self.oversample_foreground_percent,
sampling_probabilities=None, pad_sides=None, transforms=None,
probabilistic_oversampling=self.probabilistic_oversampling)
dl_val = nnUNetDataLoader(dataset_val, self.batch_size,
self.configuration_manager.patch_size,
self.configuration_manager.patch_size,
self.label_manager,
oversample_foreground_percent=self.oversample_foreground_percent,
sampling_probabilities=None, pad_sides=None, transforms=None,
probabilistic_oversampling=self.probabilistic_oversampling)
allowed_num_processes = get_allowed_n_proc_DA()
if allowed_num_processes == 0:
mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
else:
mt_gen_train = NonDetMultiThreadedAugmenter(data_loader=dl_tr, transform=tr_transforms,
num_processes=allowed_num_processes, num_cached=6, seeds=None,
pin_memory=self.device.type == 'cuda', wait_time=0.02)
mt_gen_val = NonDetMultiThreadedAugmenter(data_loader=dl_val,
transform=val_transforms,
num_processes=max(1, allowed_num_processes // 2),
num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda',
wait_time=0.02)
# let's get this party started: pull one batch from each generator so the workers are running before training
_ = next(mt_gen_train)
_ = next(mt_gen_val)
return mt_gen_train, mt_gen_val
def validation_step(self, batch: dict) -> dict:
data = batch['data']
target = batch['target']
data = data.to(self.device, non_blocking=True)
if isinstance(target, list):
target = [i.to(self.device, non_blocking=True) for i in target]
else:
target = target.to(self.device, non_blocking=True)
# Autocast can be annoying
# If the device_type is 'cpu' then it's slow as heck and needs to be disabled.
# If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
# So autocast will only be active if we have a cuda device.
with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
output = self.network(data)
del data
l = self.loss(output, target)
# we only need the output with the highest output resolution (if DS enabled)
if self.enable_deep_supervision:
output = output[0]
target = target[0]
# the following is needed for online evaluation. Fake dice (green line)
axes = [0] + list(range(2, output.ndim))
if self.label_manager.has_regions:
predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long()
else:
# no need for softmax
output_seg = output.argmax(1)[:, None]
predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float16)
predicted_segmentation_onehot.scatter_(1, output_seg, 1)
del output_seg
if self.label_manager.has_ignore_label:
if not self.label_manager.has_regions:
mask = target != self.label_manager.ignore_label
# CAREFUL that you don't rely on target after this line!
target[target == self.label_manager.ignore_label] = 0
else:
if target.dtype == torch.bool:
mask = ~target[:, -1:]
else:
mask = (1 - target[:, -1:]).bool()
# CAREFUL that you don't rely on target after this line!
target = target[:, :-1].bool()
else:
mask = None
tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask)
tp_hard = tp.detach().cpu().numpy()
fp_hard = fp.detach().cpu().numpy()
fn_hard = fn.detach().cpu().numpy()
if not self.label_manager.has_regions:
# if we train with regions, all segmentation heads predict some kind of foreground. In conventional (softmax)
# training there needs to be one output for the background, and we are not interested in the
# background Dice
# [1:] in order to remove background
tp_hard = tp_hard[1:]
fp_hard = fp_hard[1:]
fn_hard = fn_hard[1:]
return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard}
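# Illustrative sketch (hypothetical helper, not part of nnU-Net): the online
# "pseudo Dice" bookkeeping of validation_step above, redone in NumPy for a
# softmax (non-region) model. Logits are argmax-ed and one-hot encoded
# (np.eye indexing stands in for scatter_). tp/fp/fn are then summed over the
# batch and spatial axes, and the background row is dropped. This assumes
# integer targets of shape (b, 1, *spatial) and no ignore label. It computes
# Dice for a single batch only.
import numpy as np


def _example_pseudo_dice(logits: np.ndarray, target: np.ndarray) -> np.ndarray:
    num_classes = logits.shape[1]
    pred = logits.argmax(1)  # (b, *spatial)
    # one-hot encode prediction and target along a new channel axis (axis 1)
    pred_oh = np.moveaxis(np.eye(num_classes, dtype=bool)[pred], -1, 1)
    tgt_oh = np.moveaxis(np.eye(num_classes, dtype=bool)[target[:, 0]], -1, 1)
    axes = (0,) + tuple(range(2, logits.ndim))  # batch axis + spatial axes
    tp = (pred_oh & tgt_oh).sum(axes)
    fp = (pred_oh & ~tgt_oh).sum(axes)
    fn = (~pred_oh & tgt_oh).sum(axes)
    # drop background (class 0), like the [1:] slicing above
    tp, fp, fn = tp[1:], fp[1:], fn[1:]
    # guard against 0 / 0 for classes absent from both prediction and target
    return 2 * tp / np.maximum(2 * tp + fp + fn, 1)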
class nnUNetTrainerDA5ord0(nnUNetTrainerDA5):
@staticmethod
def get_training_transforms(
patch_size: Union[np.ndarray, Tuple[int]],
rotation_for_DA: RandomScalar,
deep_supervision_scales: Union[List, Tuple, None],
mirror_axes: Tuple[int, ...],
do_dummy_2d_data_aug: bool,
use_mask_for_norm: List[bool] = None,
is_cascaded: bool = False,
foreground_labels: Union[Tuple[int, ...], List[int]] = None,
regions: List[Union[List[int], Tuple[int, ...], int]] = None,
ignore_label: int = None,
) -> AbstractTransform:
matching_axes = np.array([sum([i == j for j in patch_size]) for i in patch_size])
valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0])
tr_transforms = []
tr_transforms.append(RenameTransform('target', 'seg', True))
if do_dummy_2d_data_aug:
ignore_axes = (0,)
tr_transforms.append(Convert3DTo2DTransform())
patch_size_spatial = patch_size[1:]
else:
patch_size_spatial = patch_size
ignore_axes = None
tr_transforms.append(
SpatialTransform(
patch_size_spatial,
patch_center_dist_from_border=None,
do_elastic_deform=False,
do_rotation=True,
angle_x=rotation_for_DA,
angle_y=rotation_for_DA,
angle_z=rotation_for_DA,
p_rot_per_axis=0.5,
do_scale=True,
scale=(0.7, 1.43),
border_mode_data="constant",
border_cval_data=0,
order_data=0,
border_mode_seg="constant",
border_cval_seg=-1,
order_seg=0,
random_crop=False,
p_el_per_sample=0.2,
p_scale_per_sample=0.2,
p_rot_per_sample=0.4,
independent_scale_for_each_axis=True,
)
)
if do_dummy_2d_data_aug:
tr_transforms.append(Convert2DTo3DTransform())
if np.any(matching_axes > 1):
tr_transforms.append(
Rot90Transform(
(0, 1, 2, 3), axes=valid_axes, data_key='data', label_key='seg', p_per_sample=0.5
),
)
if np.any(matching_axes > 1):
tr_transforms.append(
TransposeAxesTransform(valid_axes, data_key='data', label_key='seg', p_per_sample=0.5)
)
tr_transforms.append(OneOfTransform([
MedianFilterTransform(
(2, 8),
same_for_each_channel=False,
p_per_sample=0.2,
p_per_channel=0.5
),
GaussianBlurTransform((0.3, 1.5),
different_sigma_per_channel=True,
p_per_sample=0.2,
p_per_channel=0.5)
]))
tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
tr_transforms.append(BrightnessTransform(0,
0.5,
per_channel=True,
p_per_sample=0.1,
p_per_channel=0.5
)
)
tr_transforms.append(OneOfTransform(
[
ContrastAugmentationTransform(
contrast_range=(0.5, 2),
preserve_range=True,
per_channel=True,
data_key='data',
p_per_sample=0.2,
p_per_channel=0.5
),
ContrastAugmentationTransform(
contrast_range=(0.5, 2),
preserve_range=False,
per_channel=True,
data_key='data',
p_per_sample=0.2,
p_per_channel=0.5
),
]
))
tr_transforms.append(
SimulateLowResolutionTransform(zoom_range=(0.25, 1),
per_channel=True,
p_per_channel=0.5,
order_downsample=0,
order_upsample=3,
p_per_sample=0.15,
ignore_axes=ignore_axes
)
)
tr_transforms.append(
GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1))
tr_transforms.append(
GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1))
if mirror_axes is not None and len(mirror_axes) > 0:
tr_transforms.append(MirrorTransform(mirror_axes))
tr_transforms.append(
BlankRectangleTransform([[max(1, p // 10), p // 3] for p in patch_size],
rectangle_value=np.mean,
num_rectangles=(1, 5),
force_square=False,
p_per_sample=0.4,
p_per_channel=0.5
)
)
tr_transforms.append(
BrightnessGradientAdditiveTransform(
_brightnessadditive_localgamma_transform_scale,
(-0.5, 1.5),
max_strength=_brightness_gradient_additive_max_strength,
mean_centered=False,
same_for_all_channels=False,
p_per_sample=0.3,
p_per_channel=0.5
)
)
tr_transforms.append(
LocalGammaTransform(
_brightnessadditive_localgamma_transform_scale,
(-0.5, 1.5),
_local_gamma_gamma,
same_for_all_channels=False,
p_per_sample=0.3,
p_per_channel=0.5
)
)
tr_transforms.append(
SharpeningTransform(
strength=(0.1, 1),
same_for_each_channel=False,
p_per_sample=0.2,
p_per_channel=0.5
)
)
if use_mask_for_norm is not None and any(use_mask_for_norm):
tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],
mask_idx_in_seg=0, set_outside_to=0))
tr_transforms.append(RemoveLabelTransform(-1, 0))
if is_cascaded:
if ignore_label is not None:
raise NotImplementedError('ignore label not yet supported in cascade')
assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations'
use_labels = [i for i in foreground_labels if i != 0]
tr_transforms.append(MoveSegAsOneHotToData(1, use_labels, 'seg', 'data'))
tr_transforms.append(ApplyRandomBinaryOperatorTransform(
channel_idx=list(range(-len(use_labels), 0)),
p_per_sample=0.4,
key="data",
strel_size=(1, 8),
p_per_label=1))
tr_transforms.append(
RemoveRandomConnectedComponentFromOneHotEncodingTransform(
channel_idx=list(range(-len(use_labels), 0)),
key="data",
p_per_sample=0.2,
fill_with_other_class_p=0,
dont_do_if_covers_more_than_x_percent=0.15))
tr_transforms.append(RenameTransform('seg', 'target', True))
if regions is not None:
# the ignore label must also be converted
tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
if ignore_label is not None else regions,
'target', 'target'))
if deep_supervision_scales is not None:
tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
output_key='target'))
tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
tr_transforms = Compose(tr_transforms)
return tr_transforms
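# Illustrative sketch (hypothetical helper, not part of nnU-Net): how the DA5
# transforms above decide where Rot90Transform / TransposeAxesTransform may
# act. For each spatial axis, count how many axes share its length. Only the
# axes in the largest group of equal lengths are kept as valid, since rotating
# or transposing axes of different lengths would change the patch shape. The
# two transforms are only added when at least two axes share a length.
import numpy as np


def _example_valid_axes(patch_size):
    matching_axes = np.array([sum(i == j for j in patch_size) for i in patch_size])
    valid_axes = [int(a) for a in np.where(matching_axes == matching_axes.max())[0]]
    return valid_axes, bool(np.any(matching_axes > 1))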
def _brightnessadditive_localgamma_transform_scale(x, y):
return np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y])))
def _brightness_gradient_additive_max_strength(_x, _y):
return np.random.uniform(-5, -1) if np.random.uniform() < 0.5 else np.random.uniform(1, 5)
def _local_gamma_gamma():
return np.random.uniform(0.01, 0.8) if np.random.uniform() < 0.5 else np.random.uniform(1.5, 4)
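# Illustrative sketch (hypothetical helpers, not part of nnU-Net) of the
# sampling used by the module-level helpers above:
# - _brightnessadditive_localgamma_transform_scale draws a value log-uniformly
#   between x[y] // 6 and x[y], i.e. uniformly in log space.
# - _brightness_gradient_additive_max_strength picks a sign with equal
#   probability and a magnitude between 1 and 5.
# The sketch below uses an explicit RNG. It also adds a guard that is an
# assumption of this sketch: for x[y] < 6, x[y] // 6 is 0 and np.log(0)
# would give -inf, so the lower bound is clamped to 1 here.
import numpy as np


def _example_log_uniform(low: float, high: float, rng: np.random.Generator) -> float:
    low = max(low, 1)
    return float(np.exp(rng.uniform(np.log(low), np.log(high))))


def _example_two_sided_strength(rng: np.random.Generator) -> float:
    # negative or positive with equal probability, magnitude between 1 and 5
    return float(rng.uniform(-5, -1)) if rng.uniform() < 0.5 else float(rng.uniform(1, 5))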
class nnUNetTrainerDA5Segord0(nnUNetTrainerDA5):
@staticmethod
def get_training_transforms(
patch_size: Union[np.ndarray, Tuple[int]],
rotation_for_DA: RandomScalar,
deep_supervision_scales: Union[List, Tuple, None],
mirror_axes: Tuple[int, ...],
do_dummy_2d_data_aug: bool,
use_mask_for_norm: List[bool] = None,
is_cascaded: bool = False,
foreground_labels: Union[Tuple[int, ...], List[int]] = None,
regions: List[Union[List[int], Tuple[int, ...], int]] = None,
ignore_label: int = None,
) -> AbstractTransform:
matching_axes = np.array([sum([i == j for j in patch_size]) for i in patch_size])
valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0])
tr_transforms = []
tr_transforms.append(RenameTransform('target', 'seg', True))
if do_dummy_2d_data_aug:
ignore_axes = (0,)
tr_transforms.append(Convert3DTo2DTransform())
patch_size_spatial = patch_size[1:]
else:
patch_size_spatial = patch_size
ignore_axes = None
tr_transforms.append(
SpatialTransform(
patch_size_spatial,
patch_center_dist_from_border=None,
do_elastic_deform=False,
do_rotation=True,
angle_x=rotation_for_DA,
angle_y=rotation_for_DA,
angle_z=rotation_for_DA,
p_rot_per_axis=0.5,
do_scale=True,
scale=(0.7, 1.43),
border_mode_data="constant",
border_cval_data=0,
order_data=3,
border_mode_seg="constant",
border_cval_seg=-1,
order_seg=0,
random_crop=False,
p_el_per_sample=0.2,
p_scale_per_sample=0.2,
p_rot_per_sample=0.4,
independent_scale_for_each_axis=True,
)
)
if do_dummy_2d_data_aug:
tr_transforms.append(Convert2DTo3DTransform())
if np.any(matching_axes > 1):
tr_transforms.append(
Rot90Transform(
(0, 1, 2, 3), axes=valid_axes, data_key='data', label_key='seg', p_per_sample=0.5
),
)
if np.any(matching_axes > 1):
tr_transforms.append(
TransposeAxesTransform(valid_axes, data_key='data', label_key='seg', p_per_sample=0.5)
)
tr_transforms.append(OneOfTransform([
MedianFilterTransform(
(2, 8),
same_for_each_channel=False,
p_per_sample=0.2,
p_per_channel=0.5
),
GaussianBlurTransform((0.3, 1.5),
different_sigma_per_channel=True,
p_per_sample=0.2,
p_per_channel=0.5)
]))
tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
tr_transforms.append(BrightnessTransform(0,
0.5,
per_channel=True,
p_per_sample=0.1,
p_per_channel=0.5
)
)
tr_transforms.append(OneOfTransform(
[
ContrastAugmentationTransform(
contrast_range=(0.5, 2),
preserve_range=True,
per_channel=True,
data_key='data',
p_per_sample=0.2,
p_per_channel=0.5
),
ContrastAugmentationTransform(
contrast_range=(0.5, 2),
preserve_range=False,
per_channel=True,
data_key='data',
p_per_sample=0.2,
p_per_channel=0.5
),
]
))
tr_transforms.append(
SimulateLowResolutionTransform(zoom_range=(0.25, 1),
per_channel=True,
p_per_channel=0.5,
order_downsample=0,
order_upsample=3,
p_per_sample=0.15,
ignore_axes=ignore_axes
)
)
tr_transforms.append(
GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1))
tr_transforms.append(
GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1))
if mirror_axes is not None and len(mirror_axes) > 0:
tr_transforms.append(MirrorTransform(mirror_axes))
tr_transforms.append(
BlankRectangleTransform([[max(1, p // 10), p // 3] for p in patch_size],
rectangle_value=np.mean,
num_rectangles=(1, 5),
force_square=False,
p_per_sample=0.4,
p_per_channel=0.5
)
)
tr_transforms.append(
BrightnessGradientAdditiveTransform(
_brightnessadditive_localgamma_transform_scale,
(-0.5, 1.5),
max_strength=_brightness_gradient_additive_max_strength,
mean_centered=False,
same_for_all_channels=False,
p_per_sample=0.3,
p_per_channel=0.5
)
)
tr_transforms.append(
LocalGammaTransform(
_brightnessadditive_localgamma_transform_scale,
(-0.5, 1.5),
_local_gamma_gamma,
same_for_all_channels=False,
p_per_sample=0.3,
p_per_channel=0.5
)
)
tr_transforms.append(
SharpeningTransform(
strength=(0.1, 1),
same_for_each_channel=False,
p_per_sample=0.2,
p_per_channel=0.5
)
)
if use_mask_for_norm is not None and any(use_mask_for_norm):
tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],
mask_idx_in_seg=0, set_outside_to=0))
tr_transforms.append(RemoveLabelTransform(-1, 0))
if is_cascaded:
if ignore_label is not None:
raise NotImplementedError('ignore label not yet supported in cascade')
assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations'
use_labels = [i for i in foreground_labels if i != 0]
tr_transforms.append(MoveSegAsOneHotToData(1, use_labels, 'seg', 'data'))
tr_transforms.append(ApplyRandomBinaryOperatorTransform(
channel_idx=list(range(-len(use_labels), 0)),
p_per_sample=0.4,
key="data",
strel_size=(1, 8),
p_per_label=1))
tr_transforms.append(
RemoveRandomConnectedComponentFromOneHotEncodingTransform(
channel_idx=list(range(-len(use_labels), 0)),
key="data",
p_per_sample=0.2,
fill_with_other_class_p=0,
dont_do_if_covers_more_than_x_percent=0.15))
tr_transforms.append(RenameTransform('seg', 'target', True))
if regions is not None:
# the ignore label must also be converted
tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
if ignore_label is not None else regions,
'target', 'target'))
if deep_supervision_scales is not None:
tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
output_key='target'))
tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
tr_transforms = Compose(tr_transforms)
return tr_transforms
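# Illustrative sketch (hypothetical helper, not the batchgenerators
# ConvertSegmentationToRegionsTransform): region-based training turns an
# integer label map into one binary channel per region. A region is either a
# single label or a collection of labels. When an ignore label exists, it is
# appended as an extra, last region. This is consistent with validation_step
# reading the ignore mask from target[:, -1:] and dropping that channel
# before computing tp/fp/fn.
import numpy as np


def _example_seg_to_regions(seg, regions, ignore_label=None):
    if ignore_label is not None:
        regions = list(regions) + [ignore_label]
    # np.atleast_1d lets a region be a scalar label or a list/tuple of labels
    return np.stack([np.isin(seg, np.atleast_1d(r)) for r in regions])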
class nnUNetTrainerDA5_10epochs(nnUNetTrainerDA5):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 10
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py
================================================
from typing import Union, Tuple, List
from batchgeneratorsv2.helpers.scalar_type import RandomScalar
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform
from batchgeneratorsv2.transforms.intensity.contrast import ContrastTransform, BGContrast
from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform
from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform
from batchgeneratorsv2.transforms.nnunet.random_binary_operator import ApplyRandomBinaryOperatorTransform
from batchgeneratorsv2.transforms.nnunet.remove_connected_components import \
RemoveRandomConnectedComponentFromOneHotEncodingTransform
from batchgeneratorsv2.transforms.nnunet.seg_to_onehot import MoveSegAsOneHotToDataTransform
from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform
from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform
from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform
from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform
from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms
from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform
from batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform
from batchgeneratorsv2.transforms.utils.pseudo2d import Convert3DTo2DTransform, Convert2DTo3DTransform
from batchgeneratorsv2.transforms.utils.random import RandomTransform
from batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform
from batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
from nnunetv2.training.dataloading.data_loader import nnUNetDataLoader
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
import numpy as np
class nnUNetTrainer_DASegOrd0(nnUNetTrainer):
@staticmethod
def get_training_transforms(
patch_size: Union[np.ndarray, Tuple[int]],
rotation_for_DA: RandomScalar,
deep_supervision_scales: Union[List, Tuple, None],
mirror_axes: Tuple[int, ...],
do_dummy_2d_data_aug: bool,
use_mask_for_norm: List[bool] = None,
is_cascaded: bool = False,
foreground_labels: Union[Tuple[int, ...], List[int]] = None,
regions: List[Union[List[int], Tuple[int, ...], int]] = None,
ignore_label: int = None,
) -> BasicTransform:
transforms = []
if do_dummy_2d_data_aug:
ignore_axes = (0,)
transforms.append(Convert3DTo2DTransform())
patch_size_spatial = patch_size[1:]
else:
patch_size_spatial = patch_size
ignore_axes = None
transforms.append(
SpatialTransform(
patch_size_spatial, patch_center_dist_from_border=0, random_crop=False, p_elastic_deform=0,
p_rotation=0.2,
rotation=rotation_for_DA, p_scaling=0.2, scaling=(0.7, 1.4), p_synchronize_scaling_across_axes=1,
bg_style_seg_sampling=False, mode_seg='nearest'
)
)
if do_dummy_2d_data_aug:
transforms.append(Convert2DTo3DTransform())
transforms.append(RandomTransform(
GaussianNoiseTransform(
noise_variance=(0, 0.1),
p_per_channel=1,
synchronize_channels=True
), apply_probability=0.1
))
transforms.append(RandomTransform(
GaussianBlurTransform(
blur_sigma=(0.5, 1.),
synchronize_channels=False,
synchronize_axes=False,
p_per_channel=0.5, benchmark=True
), apply_probability=0.2
))
transforms.append(RandomTransform(
MultiplicativeBrightnessTransform(
multiplier_range=BGContrast((0.75, 1.25)),
synchronize_channels=False,
p_per_channel=1
), apply_probability=0.15
))
transforms.append(RandomTransform(
ContrastTransform(
contrast_range=BGContrast((0.75, 1.25)),
preserve_range=True,
synchronize_channels=False,
p_per_channel=1
), apply_probability=0.15
))
transforms.append(RandomTransform(
SimulateLowResolutionTransform(
scale=(0.5, 1),
synchronize_channels=False,
synchronize_axes=True,
ignore_axes=ignore_axes,
allowed_channels=None,
p_per_channel=0.5
), apply_probability=0.25
))
transforms.append(RandomTransform(
GammaTransform(
gamma=BGContrast((0.7, 1.5)),
p_invert_image=1,
synchronize_channels=False,
p_per_channel=1,
p_retain_stats=1
), apply_probability=0.1
))
transforms.append(RandomTransform(
GammaTransform(
gamma=BGContrast((0.7, 1.5)),
p_invert_image=0,
synchronize_channels=False,
p_per_channel=1,
p_retain_stats=1
), apply_probability=0.3
))
if mirror_axes is not None and len(mirror_axes) > 0:
transforms.append(
MirrorTransform(
allowed_axes=mirror_axes
)
)
if use_mask_for_norm is not None and any(use_mask_for_norm):
transforms.append(MaskImageTransform(
apply_to_channels=[i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],
channel_idx_in_seg=0,
set_outside_to=0,
))
transforms.append(
RemoveLabelTansform(-1, 0)
)
if is_cascaded:
assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations'
transforms.append(
MoveSegAsOneHotToDataTransform(
source_channel_idx=1,
all_labels=foreground_labels,
remove_channel_from_source=True
)
)
transforms.append(
RandomTransform(
ApplyRandomBinaryOperatorTransform(
channel_idx=list(range(-len(foreground_labels), 0)),
strel_size=(1, 8),
p_per_label=1
), apply_probability=0.4
)
)
transforms.append(
RandomTransform(
RemoveRandomConnectedComponentFromOneHotEncodingTransform(
channel_idx=list(range(-len(foreground_labels), 0)),
fill_with_other_class_p=0,
dont_do_if_covers_more_than_x_percent=0.15,
p_per_label=1
), apply_probability=0.2
)
)
if regions is not None:
# the ignore label must also be converted
transforms.append(
ConvertSegmentationToRegionsTransform(
regions=list(regions) + [ignore_label] if ignore_label is not None else regions,
channel_in_seg=0
)
)
if deep_supervision_scales is not None:
transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales))
return ComposeTransforms(transforms)
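# Illustrative sketch (hypothetical classes, not the batchgeneratorsv2 API):
# the pipeline above wraps most augmentations in
# RandomTransform(..., apply_probability=p) and chains them with
# ComposeTransforms. In this simplified model, each wrapper makes one random
# draw per call and either applies its inner transform or passes the sample
# through unchanged. The composition then applies the transforms in list
# order. The real classes operate on dicts of tensors and may differ in detail.
import random
from typing import Callable, Dict, List

Sample = Dict[str, object]


class _ExampleRandomApply:
    def __init__(self, transform: Callable[[Sample], Sample], apply_probability: float,
                 rng: random.Random):
        self.transform = transform
        self.p = apply_probability
        self.rng = rng

    def __call__(self, sample: Sample) -> Sample:
        # one Bernoulli draw decides whether the wrapped transform runs at all
        return self.transform(sample) if self.rng.random() < self.p else sample


class _ExampleCompose:
    def __init__(self, transforms: List[Callable[[Sample], Sample]]):
        self.transforms = transforms

    def __call__(self, sample: Sample) -> Sample:
        for t in self.transforms:
            sample = t(sample)
        return sample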
class nnUNetTrainer_DASegOrd0_NoMirroring(nnUNetTrainer_DASegOrd0):
def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
mirror_axes = None
self.inference_allowed_mirroring_axes = None
return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py
================================================
from typing import Union, Tuple, List
import numpy as np
from batchgeneratorsv2.helpers.scalar_type import RandomScalar
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
class nnUNetTrainerNoDA(nnUNetTrainer):
@staticmethod
def get_training_transforms(
patch_size: Union[np.ndarray, Tuple[int]],
rotation_for_DA: RandomScalar,
deep_supervision_scales: Union[List, Tuple, None],
mirror_axes: Tuple[int, ...],
do_dummy_2d_data_aug: bool,
use_mask_for_norm: List[bool] = None,
is_cascaded: bool = False,
foreground_labels: Union[Tuple[int, ...], List[int]] = None,
regions: List[Union[List[int], Tuple[int, ...], int]] = None,
ignore_label: int = None,
) -> BasicTransform:
return nnUNetTrainer.get_validation_transforms(deep_supervision_scales, is_cascaded, foreground_labels,
regions, ignore_label)
def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
# we need to disable mirroring here so that no mirroring will be applied in inference!
rotation_for_DA, do_dummy_2d_data_aug, _, _ = \
super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
mirror_axes = None
self.inference_allowed_mirroring_axes = None
initial_patch_size = self.configuration_manager.patch_size
return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py
================================================
from typing import Union, Tuple, List
import numpy as np
import torch
from batchgeneratorsv2.helpers.scalar_type import RandomScalar
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform
from batchgeneratorsv2.transforms.intensity.contrast import ContrastTransform, BGContrast
from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform
from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform
from batchgeneratorsv2.transforms.nnunet.random_binary_operator import ApplyRandomBinaryOperatorTransform
from batchgeneratorsv2.transforms.nnunet.remove_connected_components import \
RemoveRandomConnectedComponentFromOneHotEncodingTransform
from batchgeneratorsv2.transforms.nnunet.seg_to_onehot import MoveSegAsOneHotToDataTransform
from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform
from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform
from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform
from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform
from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms
from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform
from batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform
from batchgeneratorsv2.transforms.utils.pseudo2d import Convert3DTo2DTransform, Convert2DTo3DTransform
from batchgeneratorsv2.transforms.utils.random import RandomTransform
from batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform
from batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
class nnUNetTrainerNoMirroring(nnUNetTrainer):
def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
mirror_axes = None
self.inference_allowed_mirroring_axes = None
return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
class nnUNetTrainer_onlyMirror01(nnUNetTrainer):
"""
Only mirrors along spatial axes 0 and 1 for 3D and 0 for 2D
"""
def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
patch_size = self.configuration_manager.patch_size
dim = len(patch_size)
if dim == 2:
mirror_axes = (0, )
else:
mirror_axes = (0, 1)
self.inference_allowed_mirroring_axes = mirror_axes
return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
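# Illustrative sketch (hypothetical helpers, not part of nnU-Net):
# nnUNetTrainer_onlyMirror01 keeps mirroring only along spatial axes (0,) for
# 2D and (0, 1) for 3D. The sketch assumes a single sample laid out as
# (channels, *spatial), so spatial axis a corresponds to array axis a + 1.
# It flips each allowed axis independently with probability 0.5, applying
# the same flips to image and segmentation.
import numpy as np


def _example_only_mirror01_axes(patch_size):
    return (0,) if len(patch_size) == 2 else (0, 1)


def _example_random_mirror(image, seg, mirror_axes, rng):
    for a in mirror_axes:
        if rng.random() < 0.5:
            image = np.flip(image, axis=a + 1)
            seg = np.flip(seg, axis=a + 1)
    return image, seg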
class nnUNetTrainer_onlyMirror01_1500ep(nnUNetTrainer_onlyMirror01):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 1500
class nnUNetTrainer_onlyMirror01_DASegOrd0(nnUNetTrainer_onlyMirror01):
@staticmethod
def get_training_transforms(
patch_size: Union[np.ndarray, Tuple[int]],
rotation_for_DA: RandomScalar,
deep_supervision_scales: Union[List, Tuple, None],
mirror_axes: Tuple[int, ...],
do_dummy_2d_data_aug: bool,
use_mask_for_norm: List[bool] = None,
is_cascaded: bool = False,
foreground_labels: Union[Tuple[int, ...], List[int]] = None,
regions: List[Union[List[int], Tuple[int, ...], int]] = None,
ignore_label: int = None,
) -> BasicTransform:
transforms = []
if do_dummy_2d_data_aug:
ignore_axes = (0,)
transforms.append(Convert3DTo2DTransform())
patch_size_spatial = patch_size[1:]
else:
patch_size_spatial = patch_size
ignore_axes = None
transforms.append(
SpatialTransform(
patch_size_spatial, patch_center_dist_from_border=0, random_crop=False, p_elastic_deform=0,
p_rotation=0.2,
rotation=rotation_for_DA, p_scaling=0.2, scaling=(0.7, 1.4), p_synchronize_scaling_across_axes=1,
bg_style_seg_sampling=False, mode_seg='nearest'
)
)
if do_dummy_2d_data_aug:
transforms.append(Convert2DTo3DTransform())
transforms.append(RandomTransform(
GaussianNoiseTransform(
noise_variance=(0, 0.1),
p_per_channel=1,
synchronize_channels=True
), apply_probability=0.1
))
transforms.append(RandomTransform(
GaussianBlurTransform(
blur_sigma=(0.5, 1.),
synchronize_channels=False,
synchronize_axes=False,
p_per_channel=0.5, benchmark=True
), apply_probability=0.2
))
transforms.append(RandomTransform(
MultiplicativeBrightnessTransform(
multiplier_range=BGContrast((0.75, 1.25)),
synchronize_channels=False,
p_per_channel=1
), apply_probability=0.15
))
transforms.append(RandomTransform(
ContrastTransform(
contrast_range=BGContrast((0.75, 1.25)),
preserve_range=True,
synchronize_channels=False,
p_per_channel=1
), apply_probability=0.15
))
transforms.append(RandomTransform(
SimulateLowResolutionTransform(
scale=(0.5, 1),
synchronize_channels=False,
synchronize_axes=True,
ignore_axes=ignore_axes,
allowed_channels=None,
p_per_channel=0.5
), apply_probability=0.25
))
transforms.append(RandomTransform(
GammaTransform(
gamma=BGContrast((0.7, 1.5)),
p_invert_image=1,
synchronize_channels=False,
p_per_channel=1,
p_retain_stats=1
), apply_probability=0.1
))
transforms.append(RandomTransform(
GammaTransform(
gamma=BGContrast((0.7, 1.5)),
p_invert_image=0,
synchronize_channels=False,
p_per_channel=1,
p_retain_stats=1
), apply_probability=0.3
))
if mirror_axes is not None and len(mirror_axes) > 0:
transforms.append(
MirrorTransform(
allowed_axes=mirror_axes
)
)
if use_mask_for_norm is not None and any(use_mask_for_norm):
transforms.append(MaskImageTransform(
apply_to_channels=[i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],
channel_idx_in_seg=0,
set_outside_to=0,
))
transforms.append(
RemoveLabelTansform(-1, 0)
)
if is_cascaded:
assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations'
transforms.append(
MoveSegAsOneHotToDataTransform(
source_channel_idx=1,
all_labels=foreground_labels,
remove_channel_from_source=True
)
)
transforms.append(
RandomTransform(
ApplyRandomBinaryOperatorTransform(
channel_idx=list(range(-len(foreground_labels), 0)),
strel_size=(1, 8),
p_per_label=1
), apply_probability=0.4
)
)
transforms.append(
RandomTransform(
RemoveRandomConnectedComponentFromOneHotEncodingTransform(
channel_idx=list(range(-len(foreground_labels), 0)),
fill_with_other_class_p=0,
dont_do_if_covers_more_than_x_percent=0.15,
p_per_label=1
), apply_probability=0.2
)
)
if regions is not None:
# the ignore label must also be converted
transforms.append(
ConvertSegmentationToRegionsTransform(
regions=list(regions) + [ignore_label] if ignore_label is not None else regions,
channel_in_seg=0
)
)
if deep_supervision_scales is not None:
transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales))
return ComposeTransforms(transforms)
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainer_noDummy2DDA.py
================================================
from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
import numpy as np
class nnUNetTrainer_noDummy2DDA(nnUNetTrainer):
def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
do_dummy_2d_data_aug = False
patch_size = self.configuration_manager.patch_size
dim = len(patch_size)
if dim == 2:
if max(patch_size) / min(patch_size) > 1.5:
rotation_for_DA = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
else:
rotation_for_DA = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
mirror_axes = (0, 1)
elif dim == 3:
rotation_for_DA = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
mirror_axes = (0, 1, 2)
else:
raise RuntimeError()
# todo this function is stupid. It doesn't even use the correct scale range (we keep things as they were in the
# old nnunet for now)
initial_patch_size = get_patch_size(patch_size[-dim:],
rotation_for_DA,
rotation_for_DA,
rotation_for_DA,
(0.85, 1.25))
self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}')
self.inference_allowed_mirroring_axes = mirror_axes
return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
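# Illustrative sketch (hypothetical helper, not part of nnU-Net): the rotation
# ranges chosen by nnUNetTrainer_noDummy2DDA, in radians. The expression
# deg / 360 * 2 * pi above is simply a degree-to-radian conversion
# (equivalent to np.deg2rad). The ranges are:
# - 2D patches with an aspect ratio above 1.5: +/-15 degrees
# - other 2D patches: the full +/-180 degrees
# - 3D patches: +/-30 degrees
import numpy as np


def _example_rotation_range(patch_size):
    dim = len(patch_size)
    if dim == 2:
        deg = 15. if max(patch_size) / min(patch_size) > 1.5 else 180.
    elif dim == 3:
        deg = 30.
    else:
        raise RuntimeError(f'unsupported patch dimensionality: {dim}')
    return (-np.deg2rad(deg), np.deg2rad(deg))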
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/loss/__init__.py
================================================
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py
================================================
import torch
from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss
import numpy as np
class nnUNetTrainerCELoss(nnUNetTrainer):
def _build_loss(self):
assert not self.label_manager.has_regions, "regions not supported by this trainer"
loss = RobustCrossEntropyLoss(
weight=None, ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100
)
# we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
# this gives higher resolution outputs more weight in the loss
if self.enable_deep_supervision:
deep_supervision_scales = self._get_deep_supervision_scales()
weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))])
weights[-1] = 0
            # the lowest-resolution output is not used (its weight is 0). Normalize weights so that they sum to 1
weights = weights / weights.sum()
# now wrap the loss
loss = DeepSupervisionWrapper(loss, weights)
return loss
class nnUNetTrainerCELoss_5epochs(nnUNetTrainerCELoss):
def __init__(
self,
plans: dict,
configuration: str,
fold: int,
dataset_json: dict,
device: torch.device = torch.device("cuda"),
):
"""used for debugging plans etc"""
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 5
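All loss variants in this folder weight the deep supervision outputs the same way. A minimal sketch of just that weighting, assuming a network with five outputs (the actual number depends on the planned architecture); `deep_supervision_weights` is a hypothetical helper name:

```python
import numpy as np


def deep_supervision_weights(num_outputs: int) -> np.ndarray:
    # halve the weight for every step down in resolution: 1, 1/2, 1/4, ...
    weights = np.array([1 / (2 ** i) for i in range(num_outputs)])
    # the lowest-resolution output is ignored
    weights[-1] = 0
    # normalize so that the weights sum to 1
    return weights / weights.sum()


print(deep_supervision_weights(5))  # [0.5333 0.2667 0.1333 0.0667 0.] (rounded)
```

Note that with a single output the normalization divides by zero, so this scheme only makes sense when deep supervision actually yields several outputs.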
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py
================================================
import numpy as np
import torch
from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss
from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
from nnunetv2.training.loss.dice import MemoryEfficientSoftDiceLoss
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.helpers import softmax_helper_dim1
class nnUNetTrainerDiceLoss(nnUNetTrainer):
def _build_loss(self):
loss = MemoryEfficientSoftDiceLoss(**{'batch_dice': self.configuration_manager.batch_dice,
'do_bg': self.label_manager.has_regions, 'smooth': 1e-5, 'ddp': self.is_ddp},
apply_nonlin=torch.sigmoid if self.label_manager.has_regions else softmax_helper_dim1)
if self.enable_deep_supervision:
deep_supervision_scales = self._get_deep_supervision_scales()
# we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
# this gives higher resolution outputs more weight in the loss
weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
weights[-1] = 0
            # the lowest-resolution output is not used (its weight is 0). Normalize weights so that they sum to 1
weights = weights / weights.sum()
# now wrap the loss
loss = DeepSupervisionWrapper(loss, weights)
return loss
class nnUNetTrainerDiceCELoss_noSmooth(nnUNetTrainer):
def _build_loss(self):
# set smooth to 0
if self.label_manager.has_regions:
loss = DC_and_BCE_loss({},
{'batch_dice': self.configuration_manager.batch_dice,
'do_bg': True, 'smooth': 0, 'ddp': self.is_ddp},
use_ignore_label=self.label_manager.ignore_label is not None,
dice_class=MemoryEfficientSoftDiceLoss)
else:
loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice,
'smooth': 0, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1,
ignore_label=self.label_manager.ignore_label,
dice_class=MemoryEfficientSoftDiceLoss)
if self.enable_deep_supervision:
deep_supervision_scales = self._get_deep_supervision_scales()
# we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
# this gives higher resolution outputs more weight in the loss
weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
weights[-1] = 0
            # the lowest-resolution output is not used (its weight is 0). Normalize weights so that they sum to 1
weights = weights / weights.sum()
# now wrap the loss
loss = DeepSupervisionWrapper(loss, weights)
return loss
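`nnUNetTrainerDiceCELoss_noSmooth` differs from the default only in `smooth=0`. The simplified, single-class sketch below illustrates what the smoothing term changes for a class that is absent from both prediction and ground truth. It is an illustration under assumptions, not nnU-Net's `MemoryEfficientSoftDiceLoss` (which works on batched tensors and handles `batch_dice`, `do_bg` and DDP); `soft_dice` is a hypothetical name.

```python
import numpy as np


def soft_dice(pred: np.ndarray, gt: np.ndarray, smooth: float) -> float:
    """Soft Dice score for one class; pred holds probabilities, gt a binary mask."""
    intersect = float((pred * gt).sum())
    denominator = float(pred.sum() + gt.sum()) + smooth
    # clip the denominator to avoid a division by zero when smooth == 0
    return (2 * intersect + smooth) / max(denominator, 1e-8)


empty = np.zeros((4, 4))
print(soft_dice(empty, empty, smooth=1e-5))  # 1.0: an absent class counts as perfectly segmented
print(soft_dice(empty, empty, smooth=0))     # 0.0: without smoothing it counts as a complete miss
```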
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py
================================================
from nnunetv2.training.loss.compound_losses import DC_and_topk_loss
from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
import numpy as np
from nnunetv2.training.loss.robust_ce_loss import TopKLoss
class nnUNetTrainerTopk10Loss(nnUNetTrainer):
def _build_loss(self):
assert not self.label_manager.has_regions, "regions not supported by this trainer"
loss = TopKLoss(
ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, k=10
)
if self.enable_deep_supervision:
deep_supervision_scales = self._get_deep_supervision_scales()
# we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
# this gives higher resolution outputs more weight in the loss
weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))])
weights[-1] = 0
            # the lowest-resolution output is not used (its weight is 0). Normalize weights so that they sum to 1
weights = weights / weights.sum()
# now wrap the loss
loss = DeepSupervisionWrapper(loss, weights)
return loss
class nnUNetTrainerTopk10LossLS01(nnUNetTrainer):
def _build_loss(self):
assert not self.label_manager.has_regions, "regions not supported by this trainer"
loss = TopKLoss(
ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100,
k=10,
label_smoothing=0.1,
)
if self.enable_deep_supervision:
deep_supervision_scales = self._get_deep_supervision_scales()
# we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
# this gives higher resolution outputs more weight in the loss
weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))])
weights[-1] = 0
            # the lowest-resolution output is not used (its weight is 0). Normalize weights so that they sum to 1
weights = weights / weights.sum()
# now wrap the loss
loss = DeepSupervisionWrapper(loss, weights)
return loss
class nnUNetTrainerDiceTopK10Loss(nnUNetTrainer):
def _build_loss(self):
assert not self.label_manager.has_regions, "regions not supported by this trainer"
loss = DC_and_topk_loss(
{"batch_dice": self.configuration_manager.batch_dice, "smooth": 1e-5, "do_bg": False, "ddp": self.is_ddp},
{"k": 10, "label_smoothing": 0.0},
weight_ce=1,
weight_dice=1,
ignore_label=self.label_manager.ignore_label,
)
if self.enable_deep_supervision:
deep_supervision_scales = self._get_deep_supervision_scales()
# we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
# this gives higher resolution outputs more weight in the loss
weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))])
weights[-1] = 0
            # the lowest-resolution output is not used (its weight is 0). Normalize weights so that they sum to 1
weights = weights / weights.sum()
# now wrap the loss
loss = DeepSupervisionWrapper(loss, weights)
return loss
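The TopK trainers average the cross-entropy over only the hardest voxels. Here is a NumPy sketch of that idea, assuming `k` is interpreted as a percentage of voxels and ignoring label smoothing and the ignore label; `topk_cross_entropy` is a hypothetical helper, not the `TopKLoss` implementation.

```python
import numpy as np


def topk_cross_entropy(probs: np.ndarray, target: np.ndarray, k: float = 10.0) -> float:
    """Mean cross-entropy over the k percent of voxels with the highest loss.

    probs: (num_classes, num_voxels) softmax probabilities
    target: (num_voxels,) integer class labels
    """
    num_voxels = target.shape[0]
    # per-voxel cross-entropy: -log p(correct class)
    per_voxel = -np.log(probs[target, np.arange(num_voxels)])
    num_kept = max(1, int(num_voxels * k / 100))
    # keep only the hardest voxels
    hardest = np.sort(per_voxel)[-num_kept:]
    return float(hardest.mean())


# 10 voxels of class 0; all are predicted confidently except voxel 3
target = np.zeros(10, dtype=int)
probs = np.full((2, 10), 0.1)
probs[0] = 0.9
probs[0, 3], probs[1, 3] = 0.1, 0.9
print(topk_cross_entropy(probs, target, k=10))  # only the misclassified voxel contributes
```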
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/lr_schedule/__init__.py
================================================
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/lr_schedule/nnUNetTrainerCosAnneal.py
================================================
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
class nnUNetTrainerCosAnneal(nnUNetTrainer):
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
momentum=0.99, nesterov=True)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs)
return optimizer, lr_scheduler
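With `T_max=self.num_epochs` and the default `eta_min=0`, the learning rate follows a half cosine from `initial_lr` down to 0 over the whole training. A sketch of the closed-form schedule, assuming the scheduler is stepped once per epoch and nothing else modifies the learning rate; the `1e-2` starting value below is only an example:

```python
import math


def cosine_annealing_lr(epoch: int, initial_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    # closed form of cosine annealing without restarts
    return eta_min + (initial_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


for epoch in (0, 250, 500, 750, 1000):
    print(epoch, cosine_annealing_lr(epoch, initial_lr=1e-2, t_max=1000))
```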
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/lr_schedule/nnUNetTrainer_warmup.py
================================================
from typing import Union
import torch
from torch._dynamo import OptimizedModule
from nnunetv2.training.lr_scheduler.warmup import Lin_incr_LRScheduler, PolyLRScheduler_offset
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from torch.nn.parallel import DistributedDataParallel as DDP
from nnunetv2.utilities.helpers import empty_cache
class nnUNetTrainer_warmup(nnUNetTrainer):
"""
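    Linearly warms up the learning rate of the entire network for the first warmup_duration_whole_net epochs,
    then continues with normal training using a polynomial learning rate decay.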
"""
def __init__(
self,
plans: dict,
configuration: str,
fold: int,
dataset_json: dict,
device: torch.device = torch.device("cuda"),
):
super().__init__(plans, configuration, fold, dataset_json, device)
#### hyperparameters for warmup
        self.warmup_duration_whole_net = 50  # epochs of linear learning rate increase for the whole network
self.num_epochs = 1000
self.training_stage = None # 'warmup_all', 'train'
def configure_optimizers(self, stage: str = "warmup_all"):
assert stage in ["warmup_all", "train"]
if self.training_stage == stage:
return self.optimizer, self.lr_scheduler
if isinstance(self.network, DDP):
params = self.network.module.parameters()
else:
params = self.network.parameters()
if stage == "warmup_all":
self.print_to_log_file("train whole net, warmup")
optimizer = torch.optim.SGD(
params, self.initial_lr, weight_decay=self.weight_decay, momentum=0.99, nesterov=True
)
lr_scheduler = Lin_incr_LRScheduler(optimizer, self.initial_lr, self.warmup_duration_whole_net)
self.print_to_log_file(f"Initialized warmup_all optimizer and lr_scheduler at epoch {self.current_epoch}")
else:
self.print_to_log_file("train whole net, default schedule")
if self.training_stage == "warmup_all":
                # we can keep the existing optimizer and don't need to create a new one. This allows us to keep
                # the accumulated momentum terms, which already point in a useful direction
optimizer = self.optimizer
else:
optimizer = torch.optim.SGD(
params, self.initial_lr, weight_decay=self.weight_decay, momentum=0.99, nesterov=True
)
lr_scheduler = PolyLRScheduler_offset(
optimizer, self.initial_lr, self.num_epochs, self.warmup_duration_whole_net
)
self.print_to_log_file(f"Initialized train optimizer and lr_scheduler at epoch {self.current_epoch}")
self.training_stage = stage
empty_cache(self.device)
return optimizer, lr_scheduler
def on_train_epoch_start(self):
if self.current_epoch == 0:
self.optimizer, self.lr_scheduler = self.configure_optimizers("warmup_all")
elif self.current_epoch == self.warmup_duration_whole_net:
self.optimizer, self.lr_scheduler = self.configure_optimizers("train")
super().on_train_epoch_start()
def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None:
"""
        We need to override this entire function because the correct optimizer has to be set up in between
        loading the checkpoint and applying the optimizer state. Yuck.
"""
if not self.was_initialized:
self.initialize()
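        if isinstance(filename_or_checkpoint, str):
            # the checkpoint also stores non-tensor objects (init args, logging), so a full unpickle is needed
            checkpoint = torch.load(filename_or_checkpoint, map_location=self.device, weights_only=False)
        else:
            checkpoint = filename_or_checkpoint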
# if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not
# match. Use heuristic to make it match
new_state_dict = {}
for k, value in checkpoint["network_weights"].items():
key = k
if key not in self.network.state_dict().keys() and key.startswith("module."):
key = key[7:]
new_state_dict[key] = value
self.my_init_kwargs = checkpoint["init_args"]
self.current_epoch = checkpoint["current_epoch"]
self.logger.load_checkpoint(checkpoint["logging"])
self._best_ema = checkpoint["_best_ema"]
self.inference_allowed_mirroring_axes = (
checkpoint["inference_allowed_mirroring_axes"]
if "inference_allowed_mirroring_axes" in checkpoint.keys()
else self.inference_allowed_mirroring_axes
)
# messing with state dict naming schemes. Facepalm.
if self.is_ddp:
if isinstance(self.network.module, OptimizedModule):
self.network.module._orig_mod.load_state_dict(new_state_dict)
else:
self.network.module.load_state_dict(new_state_dict)
else:
if isinstance(self.network, OptimizedModule):
self.network._orig_mod.load_state_dict(new_state_dict)
else:
self.network.load_state_dict(new_state_dict)
# it's fine to do this every time we load because configure_optimizers will be a no-op if the correct optimizer
# and lr scheduler are already set up
if self.current_epoch < self.warmup_duration_whole_net:
self.optimizer, self.lr_scheduler = self.configure_optimizers("warmup_all")
else:
self.optimizer, self.lr_scheduler = self.configure_optimizers("train")
self.optimizer.load_state_dict(checkpoint["optimizer_state"])
if self.grad_scaler is not None:
if checkpoint["grad_scaler_state"] is not None:
self.grad_scaler.load_state_dict(checkpoint["grad_scaler_state"])
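The two training stages combine into a single learning rate curve: a linear ramp for `warmup_duration_whole_net` epochs, then a polynomial decay over the remaining epochs. The sketch below is a hypothetical approximation. The exact formulas of `Lin_incr_LRScheduler` and `PolyLRScheduler_offset` are not shown here, so the `(epoch + 1) / warmup_epochs` ramp and the exponent of 0.9 are assumptions.

```python
def warmup_then_poly_lr(epoch: int, initial_lr: float, warmup_epochs: int, total_epochs: int,
                        exponent: float = 0.9) -> float:
    """Hypothetical combined schedule: linear warmup followed by polynomial decay."""
    if epoch < warmup_epochs:
        # linear ramp from initial_lr / warmup_epochs up to initial_lr
        return initial_lr * (epoch + 1) / warmup_epochs
    # polynomial decay over the remaining epochs, starting again at initial_lr
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return initial_lr * (1 - progress) ** exponent


for epoch in (0, 25, 49, 50, 500, 1000):
    print(epoch, warmup_then_poly_lr(epoch, initial_lr=1e-2, warmup_epochs=50, total_epochs=1000))
```

Reusing the SGD optimizer when switching from `warmup_all` to `train` (as `configure_optimizers` does) means only the schedule changes at epoch 50; the momentum buffers carry over.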
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/network_architecture/__init__.py
================================================
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py
================================================
from typing import Union, Tuple, List
from dynamic_network_architectures.building_blocks.helper import get_matching_batchnorm
from torch import nn
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
class nnUNetTrainerBN(nnUNetTrainer):
@staticmethod
def build_network_architecture(architecture_class_name: str,
arch_init_kwargs: dict,
arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
num_input_channels: int,
num_output_channels: int,
enable_deep_supervision: bool = True) -> nn.Module:
if 'norm_op' not in arch_init_kwargs.keys():
raise RuntimeError("'norm_op' not found in arch_init_kwargs. This does not look like an architecture "
"I can hack BN into. This trainer only works with default nnU-Net architectures.")
from pydoc import locate
conv_op = locate(arch_init_kwargs['conv_op'])
bn_class = get_matching_batchnorm(conv_op)
arch_init_kwargs['norm_op'] = bn_class.__module__ + '.' + bn_class.__name__
arch_init_kwargs['norm_op_kwargs'] = {'eps': 1e-5, 'affine': True}
return nnUNetTrainer.build_network_architecture(architecture_class_name,
arch_init_kwargs,
arch_init_kwargs_req_import,
num_input_channels,
num_output_channels, enable_deep_supervision)
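`nnUNetTrainerBN` stores the batch norm class as a dotted import path and relies on `pydoc.locate` to turn such strings back into objects. The round trip can be checked with a standard library class; for a 3D network the same construction would yield something like `torch.nn.modules.batchnorm.BatchNorm3d`. `class_to_import_path` is a hypothetical helper name.

```python
from collections import OrderedDict
from pydoc import locate


def class_to_import_path(cls: type) -> str:
    # the same string construction the trainer uses for 'norm_op'
    return cls.__module__ + '.' + cls.__name__


path = class_to_import_path(OrderedDict)
print(path)                          # collections.OrderedDict
print(locate(path) is OrderedDict)   # True
```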
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py
================================================
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
import torch
class nnUNetTrainerNoDeepSupervision(nnUNetTrainer):
def __init__(
self,
plans: dict,
configuration: str,
fold: int,
dataset_json: dict,
device: torch.device = torch.device("cuda"),
):
super().__init__(plans, configuration, fold, dataset_json, device)
self.enable_deep_supervision = False
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/optimizer/__init__.py
================================================
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdam.py
================================================
import torch
from torch.optim import Adam, AdamW
from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
class nnUNetTrainerAdam(nnUNetTrainer):
def configure_optimizers(self):
optimizer = AdamW(self.network.parameters(),
lr=self.initial_lr,
weight_decay=self.weight_decay,
amsgrad=True)
# optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
# momentum=0.99, nesterov=True)
lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)
return optimizer, lr_scheduler
class nnUNetTrainerVanillaAdam(nnUNetTrainer):
def configure_optimizers(self):
optimizer = Adam(self.network.parameters(),
lr=self.initial_lr,
weight_decay=self.weight_decay)
# optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
# momentum=0.99, nesterov=True)
lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)
return optimizer, lr_scheduler
class nnUNetTrainerVanillaAdam1en3(nnUNetTrainerVanillaAdam):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.initial_lr = 1e-3
class nnUNetTrainerVanillaAdam3en4(nnUNetTrainerVanillaAdam):
# https://twitter.com/karpathy/status/801621764144971776?lang=en
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.initial_lr = 3e-4
class nnUNetTrainerAdam1en3(nnUNetTrainerAdam):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.initial_lr = 1e-3
class nnUNetTrainerAdam3en4(nnUNetTrainerAdam):
# https://twitter.com/karpathy/status/801621764144971776?lang=en
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.initial_lr = 3e-4
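The Adam variants only swap the optimizer and the starting learning rate; `PolyLRScheduler` still decays the rate to 0 at `num_epochs`. A sketch of a polynomial decay, assuming an exponent of 0.9 (an assumption, check `PolyLRScheduler` for the actual value):

```python
def poly_lr(epoch: int, initial_lr: float, max_epochs: int, exponent: float = 0.9) -> float:
    # decays from initial_lr at epoch 0 to 0 at max_epochs
    return initial_lr * (1 - epoch / max_epochs) ** exponent


# e.g. the 3e-4 variants: only the starting point of the curve changes
for epoch in (0, 500, 999, 1000):
    print(epoch, poly_lr(epoch, initial_lr=3e-4, max_epochs=1000))
```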
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdan.py
================================================
import torch
from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from torch.optim.lr_scheduler import CosineAnnealingLR
try:
from adan_pytorch import Adan
except ImportError:
Adan = None
class nnUNetTrainerAdan(nnUNetTrainer):
def configure_optimizers(self):
if Adan is None:
raise RuntimeError('This trainer requires adan_pytorch to be installed, install with "pip install adan-pytorch"')
optimizer = Adan(self.network.parameters(),
lr=self.initial_lr,
# betas=(0.02, 0.08, 0.01), defaults
weight_decay=self.weight_decay)
# optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
# momentum=0.99, nesterov=True)
lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)
return optimizer, lr_scheduler
class nnUNetTrainerAdan1en3(nnUNetTrainerAdan):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.initial_lr = 1e-3
class nnUNetTrainerAdan3en4(nnUNetTrainerAdan):
# https://twitter.com/karpathy/status/801621764144971776?lang=en
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.initial_lr = 3e-4
class nnUNetTrainerAdan1en1(nnUNetTrainerAdan):
    # this trainer makes no sense: a learning rate of 1e-1 is too high for Adan and training diverges to nan!
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.initial_lr = 1e-1
class nnUNetTrainerAdanCosAnneal(nnUNetTrainerAdan):
# def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
# device: torch.device = torch.device('cuda')):
# super().__init__(plans, configuration, fold, dataset_json, device)
# self.num_epochs = 15
def configure_optimizers(self):
if Adan is None:
raise RuntimeError('This trainer requires adan_pytorch to be installed, install with "pip install adan-pytorch"')
optimizer = Adan(self.network.parameters(),
lr=self.initial_lr,
# betas=(0.02, 0.08, 0.01), defaults
weight_decay=self.weight_decay)
# optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
# momentum=0.99, nesterov=True)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs)
return optimizer, lr_scheduler
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/sampling/__init__.py
================================================
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py
================================================
import numpy as np
import torch
from torch import distributed as dist
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
class nnUNetTrainer_probabilisticOversampling(nnUNetTrainer):
"""
sampling of foreground happens randomly and not for the last 33% of samples in a batch
since most trainings happen with batch size 2 and nnunet guarantees at least one fg sample, effectively this can
be 50%
Here we compute the actual oversampling percentage used by nnUNetTrainer in order to be as consistent as possible.
If we switch to this oversampling then we can keep it at a constant 0.33 or whatever.
"""
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.probabilistic_oversampling = True
self.oversample_foreground_percent = float(np.mean(
[not sample_idx < round(self.configuration_manager.batch_size * (1 - self.oversample_foreground_percent))
for sample_idx in range(self.configuration_manager.batch_size)]))
self.print_to_log_file(f"self.oversample_foreground_percent {self.oversample_foreground_percent}")
def _set_batch_size_and_oversample(self):
if not self.is_ddp:
# set batch size to what the plan says, leave oversample untouched
self.batch_size = self.configuration_manager.batch_size
else:
            # batch size is distributed over DDP workers. Since foreground sampling is probabilistic, the oversample
            # percentage does not need to be adapted per worker
world_size = dist.get_world_size()
my_rank = dist.get_rank()
global_batch_size = self.configuration_manager.batch_size
assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \
'GPUs... Duh.'
batch_size_per_GPU = [global_batch_size // world_size] * world_size
batch_size_per_GPU = [batch_size_per_GPU[i] + 1
if (batch_size_per_GPU[i] * world_size + i) < global_batch_size
else batch_size_per_GPU[i]
for i in range(len(batch_size_per_GPU))]
assert sum(batch_size_per_GPU) == global_batch_size
print("worker", my_rank, "batch_size", batch_size_per_GPU[my_rank])
print("worker", my_rank, "oversample", self.oversample_foreground_percent)
self.batch_size = batch_size_per_GPU[my_rank]
class nnUNetTrainer_probabilisticOversampling_033(nnUNetTrainer_probabilisticOversampling):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.oversample_foreground_percent = 0.33
class nnUNetTrainer_probabilisticOversampling_010(nnUNetTrainer_probabilisticOversampling):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.oversample_foreground_percent = 0.1
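The constructor and `_set_batch_size_and_oversample` above perform two small computations: the effective oversampling rate of the deterministic scheme, and an even split of the global batch size across DDP workers. The following standalone sketch covers both, plus a hypothetical per-sample coin flip illustrating the probabilistic alternative; all three function names are illustrative only, and the actual sampling happens in nnU-Net's data loader.

```python
from typing import List

import numpy as np


def effective_oversample_percent(batch_size: int, oversample_percent: float) -> float:
    # deterministic scheme: samples with index >= round(batch_size * (1 - p)) are forced to contain foreground
    forced = [not idx < round(batch_size * (1 - oversample_percent)) for idx in range(batch_size)]
    return float(np.mean(forced))


def split_batch_size(global_batch_size: int, world_size: int) -> List[int]:
    # spread the global batch as evenly as possible; the first workers get the remainder
    base = [global_batch_size // world_size] * world_size
    return [b + 1 if (b * world_size + i) < global_batch_size else b for i, b in enumerate(base)]


def probabilistic_force_fg(batch_size: int, p: float, rng: np.random.Generator) -> List[bool]:
    # hypothetical probabilistic scheme: each sample independently forced to foreground with probability p
    return [bool(rng.random() < p) for _ in range(batch_size)]


print(effective_oversample_percent(2, 0.33))   # 0.5
print(split_batch_size(5, 2))                  # [3, 2]
print(probabilistic_force_fg(4, 0.33, np.random.default_rng(0)))
```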
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/training_length/__init__.py
================================================
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py
================================================
import torch
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
class nnUNetTrainer_5epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
"""used for debugging plans etc"""
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 5
class nnUNetTrainer_1epoch(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
"""used for debugging plans etc"""
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 1
class nnUNetTrainer_10epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
"""used for debugging plans etc"""
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 10
class nnUNetTrainer_20epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 20
class nnUNetTrainer_50epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 50
class nnUNetTrainer_100epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 100
class nnUNetTrainer_250epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 250
class nnUNetTrainer_500epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 500
class nnUNetTrainer_750epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 750
class nnUNetTrainer_2000epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 2000
class nnUNetTrainer_4000epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 4000
class nnUNetTrainer_8000epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 8000
================================================
FILE: nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py
================================================
import torch
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
class nnUNetTrainer_250epochs_NoMirroring(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 250
def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
mirror_axes = None
self.inference_allowed_mirroring_axes = None
return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
class nnUNetTrainer_2000epochs_NoMirroring(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 2000
def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
mirror_axes = None
self.inference_allowed_mirroring_axes = None
return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
class nnUNetTrainer_4000epochs_NoMirroring(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 4000
def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
mirror_axes = None
self.inference_allowed_mirroring_axes = None
return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
class nnUNetTrainer_8000epochs_NoMirroring(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, device)
self.num_epochs = 8000
def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
mirror_axes = None
self.inference_allowed_mirroring_axes = None
return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
================================================
FILE: nnunetv2/utilities/__init__.py
================================================
================================================
FILE: nnunetv2/utilities/collate_outputs.py
================================================
from typing import List
import numpy as np
def collate_outputs(outputs: List[dict]):
"""
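    Used to collate the outputs of the default train_step and validation_step. If you need something different,
    you have to extend this function.
    We expect outputs to be a list of dictionaries that all share the same set of keys. Scalars are gathered into a
    list, numpy arrays are stacked along a new first axis and lists are concatenated.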
"""
collated = {}
for k in outputs[0].keys():
if np.isscalar(outputs[0][k]):
collated[k] = [o[k] for o in outputs]
elif isinstance(outputs[0][k], np.ndarray):
collated[k] = np.vstack([o[k][None] for o in outputs])
elif isinstance(outputs[0][k], list):
collated[k] = [item for o in outputs for item in o[k]]
else:
raise ValueError(f'Cannot collate input of type {type(outputs[0][k])}. '
f'Modify collate_outputs to add this functionality')
return collated
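A short usage example, with `collate_outputs` copied verbatim so the snippet runs on its own. The dictionary keys are hypothetical; the real keys come from `train_step` and `validation_step`.

```python
from typing import List

import numpy as np


def collate_outputs(outputs: List[dict]):
    collated = {}
    for k in outputs[0].keys():
        if np.isscalar(outputs[0][k]):
            collated[k] = [o[k] for o in outputs]
        elif isinstance(outputs[0][k], np.ndarray):
            collated[k] = np.vstack([o[k][None] for o in outputs])
        elif isinstance(outputs[0][k], list):
            collated[k] = [item for o in outputs for item in o[k]]
        else:
            raise ValueError(f'Cannot collate input of type {type(outputs[0][k])}. '
                             f'Modify collate_outputs to add this functionality')
    return collated


outputs = [
    {'loss': 0.5, 'per_class_tp': np.array([1, 2]), 'case_ids': ['a']},
    {'loss': 0.3, 'per_class_tp': np.array([3, 4]), 'case_ids': ['b', 'c']},
]
collated = collate_outputs(outputs)
print(collated['loss'])                # [0.5, 0.3]
print(collated['per_class_tp'].shape)  # (2, 2)
print(collated['case_ids'])            # ['a', 'b', 'c']
```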
================================================
FILE: nnunetv2/utilities/crossval_split.py
================================================
from typing import List
import numpy as np
from sklearn.model_selection import KFold
def generate_crossval_split(train_identifiers: List[str], seed=12345, n_splits=5) -> List[dict[str, List[str]]]:
splits = []
kfold = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
for i, (train_idx, test_idx) in enumerate(kfold.split(train_identifiers)):
train_keys = np.array(train_identifiers)[train_idx]
test_keys = np.array(train_identifiers)[test_idx]
splits.append({})
splits[-1]['train'] = list(train_keys)
splits[-1]['val'] = list(test_keys)
return splits
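`generate_crossval_split` shuffles the case identifiers with a fixed seed and partitions them into `n_splits` disjoint validation folds. The dependency-free sketch below shows the same idea using only the standard library. It does not reproduce sklearn's exact fold assignment, so its splits are not interchangeable with those of `generate_crossval_split`; `simple_crossval_split` is a hypothetical name.

```python
import random
from typing import Dict, List


def simple_crossval_split(identifiers: List[str], seed: int = 12345,
                          n_splits: int = 5) -> List[Dict[str, List[str]]]:
    shuffled = list(identifiers)
    random.Random(seed).shuffle(shuffled)
    # like KFold, the first len % n_splits folds get one extra case
    n = len(shuffled)
    fold_sizes = [n // n_splits + (1 if i < n % n_splits else 0) for i in range(n_splits)]
    splits = []
    start = 0
    for size in fold_sizes:
        val = shuffled[start:start + size]
        train = shuffled[:start] + shuffled[start + size:]
        splits.append({'train': train, 'val': val})
        start += size
    return splits


ids = [f"case_{i:03d}" for i in range(12)]
splits = simple_crossval_split(ids)
print([len(s['val']) for s in splits])  # [3, 3, 2, 2, 2]
```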
================================================
FILE: nnunetv2/utilities/dataset_name_id_conversion.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw, nnUNet_results
from batchgenerators.utilities.file_and_folder_operations import *
import numpy as np
def find_candidate_datasets(dataset_id: int):
startswith = "Dataset%03.0d" % dataset_id
if nnUNet_preprocessed is not None and isdir(nnUNet_preprocessed):
candidates_preprocessed = subdirs(nnUNet_preprocessed, prefix=startswith, join=False)
else:
candidates_preprocessed = []
if nnUNet_raw is not None and isdir(nnUNet_raw):
candidates_raw = subdirs(nnUNet_raw, prefix=startswith, join=False)
else:
candidates_raw = []
candidates_trained_models = []
if nnUNet_results is not None and isdir(nnUNet_results):
candidates_trained_models += subdirs(nnUNet_results, prefix=startswith, join=False)
all_candidates = candidates_preprocessed + candidates_raw + candidates_trained_models
unique_candidates = np.unique(all_candidates)
return unique_candidates
def convert_id_to_dataset_name(dataset_id: int):
unique_candidates = find_candidate_datasets(dataset_id)
if len(unique_candidates) > 1:
raise RuntimeError("More than one dataset name found for dataset id %d. Please correct that. (I looked in the "
"following folders:\n%s\n%s\n%s)" % (dataset_id, nnUNet_raw, nnUNet_preprocessed, nnUNet_results))
if len(unique_candidates) == 0:
raise RuntimeError(f"Could not find a dataset with the ID {dataset_id}. Make sure the requested dataset ID "
f"exists and that nnU-Net knows where raw and preprocessed data are located "
f"(see Documentation - Installation). Here are your currently defined folders:\n"
f"nnUNet_preprocessed={os.environ.get('nnUNet_preprocessed', 'None')}\n"
f"nnUNet_results={os.environ.get('nnUNet_results', 'None')}\n"
f"nnUNet_raw={os.environ.get('nnUNet_raw', 'None')}\n"
f"If something is not right, adapt your environment variables.")
return unique_candidates[0]
def convert_dataset_name_to_id(dataset_name: str):
assert dataset_name.startswith("Dataset")
dataset_id = int(dataset_name[7:10])
return dataset_id
def maybe_convert_to_dataset_name(dataset_name_or_id: Union[int, str]) -> str:
if isinstance(dataset_name_or_id, str) and dataset_name_or_id.startswith("Dataset"):
return dataset_name_or_id
if isinstance(dataset_name_or_id, str):
try:
dataset_name_or_id = int(dataset_name_or_id)
except ValueError:
raise ValueError("dataset_name_or_id was a string and did not start with 'Dataset' so we tried to "
"convert it to a dataset ID (int). That failed, however. Please give an integer number "
"('1', '2', etc) or a correct dataset name. Your input: %s" % dataset_name_or_id)
return convert_id_to_dataset_name(dataset_name_or_id)
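# Sketch of the naming convention the functions above rely on, assuming dataset folders are
# named "DatasetXXX_Name" with a zero-padded three-digit ID (convert_dataset_name_to_id makes
# the same assumption when it slices characters 7:10). For IDs such as 4 or 123 the f-string
# below yields the same prefix as "Dataset%03.0d" % dataset_id in find_candidate_datasets.
# Both helpers are hypothetical illustrations and do not touch the file system.
def _dataset_prefix_sketch(dataset_id: int) -> str:
    return f"Dataset{dataset_id:03d}"


def _dataset_id_from_name_sketch(dataset_name: str) -> int:
    if not dataset_name.startswith("Dataset"):
        raise ValueError(f"Expected a name starting with 'Dataset', got {dataset_name!r}")
    # the three characters right after "Dataset" hold the numeric ID
    return int(dataset_name[len("Dataset"):len("Dataset") + 3])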
================================================
FILE: nnunetv2/utilities/ddp_allgather.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Tuple
import torch
from torch import distributed
def print_if_rank0(*args):
if distributed.get_rank() == 0:
print(*args)
class AllGatherGrad(torch.autograd.Function):
# stolen from pytorch lightning
@staticmethod
def forward(
ctx: Any,
tensor: torch.Tensor,
group: Optional["torch.distributed.ProcessGroup"] = None,
) -> torch.Tensor:
ctx.group = group
gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(gathered_tensor, tensor, group=group)
gathered_tensor = torch.stack(gathered_tensor, dim=0)
return gathered_tensor
@staticmethod
def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
grad_output = torch.cat(grad_output)
torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
return grad_output[torch.distributed.get_rank()], None
================================================
FILE: nnunetv2/utilities/default_n_proc_DA.py
================================================
import subprocess
import os
def get_allowed_n_proc_DA():
"""
This function is used to set the number of processes used on different systems. It is specific to our cluster
infrastructure at DKFZ. You can modify it to suit your needs. Everything is allowed.
IMPORTANT: if the environment variable nnUNet_n_proc_DA is set, it overrides the hostname lookup in this script
(see first line). The result is still capped at the number of available CPUs.
Interpret the output as the number of processes used for data augmentation PER GPU.
The way it is implemented here is simply a look up table. We know the hostnames, CPU and GPU configurations of our
systems and set the numbers accordingly. For example, a system with 4 GPUs and 48 threads can use 12 threads per
GPU without overloading the CPU (technically 11 because we have a main process as well), so that's what we use.
"""
if 'nnUNet_n_proc_DA' in os.environ.keys():
use_this = int(os.environ['nnUNet_n_proc_DA'])
else:
hostname = subprocess.getoutput('hostname')
if hostname in ['Fabian', 'isensee-']:
use_this = 12
elif hostname in ['hdf19-gpu16', 'hdf19-gpu17', 'hdf19-gpu18', 'hdf19-gpu19', 'e230-AMDworkstation']:
use_this = 16
elif hostname.startswith('e230-dgx1'):
use_this = 10
elif hostname.startswith('hdf18-gpu') or hostname.startswith('e132-comp'):
use_this = 16
elif hostname.startswith('e230-dgx2'):
use_this = 6
elif hostname.startswith('e230-dgxa100-'):
use_this = 28
elif hostname.startswith('e230-thinka100-'):
use_this = 20
elif hostname.startswith('lsf22-gpu'):
use_this = 28
elif hostname.startswith('hdf19-gpu') or hostname.startswith('e071-gpu'):
use_this = 12
elif 'superh200' in hostname.lower():
use_this = 48
elif 'superl40s' in hostname.lower():
use_this = 28
else:
use_this = 12 # default value
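# os.cpu_count() can return None if the number of CPUs cannot be determined; only cap when it is known
cpu_count = os.cpu_count()
use_this = min(use_this, cpu_count) if cpu_count is not None else use_this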
return use_this
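# A sketch of the decision order used by get_allowed_n_proc_DA, written as a pure function so it
# can be tested without a real hostname or environment: an explicit nnUNet_n_proc_DA value wins,
# otherwise the first matching hostname prefix in a lookup table is used, otherwise a default;
# the result is finally capped at the CPU count. The function name, the table format and the
# example hostnames in the usage are hypothetical; the real function also uses exact-match and
# substring rules, not only prefixes.
from typing import Mapping, Optional, Sequence, Tuple


def _resolve_n_proc_DA_sketch(hostname: str,
                              environ: Mapping[str, str],
                              cpu_count: Optional[int],
                              prefix_table: Sequence[Tuple[str, int]],
                              default: int = 12) -> int:
    if 'nnUNet_n_proc_DA' in environ:
        n = int(environ['nnUNet_n_proc_DA'])
    else:
        # first matching prefix wins, so more specific prefixes must come first in the table
        n = next((v for prefix, v in prefix_table if hostname.startswith(prefix)), default)
    # os.cpu_count() may return None; in that case no cap is applied
    return min(n, cpu_count) if cpu_count is not None else n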
================================================
FILE: nnunetv2/utilities/file_path_utilities.py
================================================
from multiprocessing import Pool
from typing import List, Tuple, Union
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.configuration import default_num_processes
from nnunetv2.paths import nnUNet_results
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
def convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration):
return f'{trainer_name}__{plans_identifier}__{configuration}'
def convert_identifier_to_trainer_plans_config(identifier: str):
return os.path.basename(identifier).split('__')
def get_output_folder(dataset_name_or_id: Union[str, int], trainer_name: str = 'nnUNetTrainer',
plans_identifier: str = 'nnUNetPlans', configuration: str = '3d_fullres',
fold: Union[str, int] = None) -> str:
tmp = join(nnUNet_results, maybe_convert_to_dataset_name(dataset_name_or_id),
convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration))
if fold is not None:
tmp = join(tmp, f'fold_{fold}')
return tmp
def parse_dataset_trainer_plans_configuration_from_path(path: str):
folders = split_path(path)
# this here can be a little tricky because we are making assumptions. Let's hope this never fails lol
# safer to make this depend on two conditions, the fold_x and the DatasetXXX
# first let's see if some fold_X is present
fold_x_present = [i.startswith('fold_') for i in folders]
if any(fold_x_present):
idx = fold_x_present.index(True)
# OK now two entries before that there should be DatasetXXX
assert len(folders[:idx]) >= 2, 'Bad path, cannot extract what I need. Your path needs to be at least ' \
'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work'
if folders[idx - 2].startswith('Dataset'):
split = folders[idx - 1].split('__')
assert len(split) == 3, 'Bad path, cannot extract what I need. Your path needs to be at least ' \
'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work'
return folders[idx - 2], *split
else:
# we can only check for dataset followed by a string that is separable into three strings by splitting with '__'
# look for DatasetXXX
dataset_folder = [i.startswith('Dataset') for i in folders]
if any(dataset_folder):
idx = dataset_folder.index(True)
assert len(folders) >= (idx + 2), 'Bad path, cannot extract what I need. Your path needs to be at least ' \
'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work'
split = folders[idx + 1].split('__')
assert len(split) == 3, 'Bad path, cannot extract what I need. Your path needs to be at least ' \
'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work'
return folders[idx], *split
def get_ensemble_name(model1_folder, model2_folder, folds: Tuple[int, ...]):
identifier = 'ensemble___' + os.path.basename(model1_folder) + '___' + \
os.path.basename(model2_folder) + '___' + folds_tuple_to_string(folds)
return identifier
def get_ensemble_name_from_d_tr_c(dataset, tr1, p1, c1, tr2, p2, c2, folds: Tuple[int, ...]):
model1_folder = get_output_folder(dataset, tr1, p1, c1)
model2_folder = get_output_folder(dataset, tr2, p2, c2)
return get_ensemble_name(model1_folder, model2_folder, folds)
def convert_ensemble_folder_to_model_identifiers_and_folds(ensemble_folder: str):
prefix, *models, folds = os.path.basename(ensemble_folder).split('___')
return models, folds
def folds_tuple_to_string(folds: Union[List[int], Tuple[int, ...]]):
s = str(folds[0])
for f in folds[1:]:
s += f"_{f}"
return s
def folds_string_to_tuple(folds_string: str):
folds = folds_string.split('_')
res = []
for f in folds:
try:
res.append(int(f))
except ValueError:
res.append(f)
return res
def check_workers_alive_and_busy(export_pool: Pool, worker_list: List, results_list: List, allowed_num_queued: int = 0):
"""
returns True if the number of results that are not ready is greater than or equal to the number of workers in
export_pool + allowed_num_queued. Raises a RuntimeError if any process in worker_list is no longer alive.
"""
alive = [i.is_alive() for i in worker_list]
if not all(alive):
raise RuntimeError('Some background workers are no longer alive')
not_ready = [not i.ready() for i in results_list]
if sum(not_ready) >= (len(export_pool._pool) + allowed_num_queued):
return True
return False
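# A simplified, hypothetical sketch of the path convention that
# parse_dataset_trainer_plans_configuration_from_path expects: somewhere in the path a
# "DatasetXXX_Name" folder is directly followed by a "TRAINER__PLANS__CONFIGURATION" folder
# (optionally followed by fold_X). It uses pathlib.PurePosixPath instead of split_path, so it
# assumes '/' separators, and it returns None instead of raising when the pattern is absent.
from pathlib import PurePosixPath
from typing import Optional, Tuple


def _parse_model_folder_sketch(path: str) -> Optional[Tuple[str, str, str, str]]:
    parts = PurePosixPath(path).parts  # a trailing '/' does not produce an empty part
    for i, part in enumerate(parts[:-1]):
        if part.startswith('Dataset'):
            trainer_plans_config = parts[i + 1].split('__')
            if len(trainer_plans_config) == 3:
                return (part, *trainer_plans_config)
    return None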
if __name__ == '__main__':
### well at this point I could just write tests...
path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres'
print(parse_dataset_trainer_plans_configuration_from_path(path))
path = 'Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres'
print(parse_dataset_trainer_plans_configuration_from_path(path))
path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres/fold_all'
print(parse_dataset_trainer_plans_configuration_from_path(path))
try:
path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/'
print(parse_dataset_trainer_plans_configuration_from_path(path))
except AssertionError:
print('yayy, assertion works')
================================================
FILE: nnunetv2/utilities/find_class_by_name.py
================================================
import importlib
import pkgutil
from batchgenerators.utilities.file_and_folder_operations import *
def recursive_find_python_class(folder: str, class_name: str, current_module: str):
tr = None
for importer, modname, ispkg in pkgutil.iter_modules([folder]):
# print(modname, ispkg)
if not ispkg:
m = importlib.import_module(current_module + "." + modname)
if hasattr(m, class_name):
tr = getattr(m, class_name)
break
if tr is None:
for importer, modname, ispkg in pkgutil.iter_modules([folder]):
if ispkg:
next_current_module = current_module + "." + modname
tr = recursive_find_python_class(join(folder, modname), class_name, current_module=next_current_module)
if tr is not None:
break
return tr
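# An alternative sketch of the same idea using pkgutil.walk_packages, which yields every
# submodule of a package with its fully qualified name. It takes a package name (which must be
# a package, i.e. have __path__) instead of a folder plus module prefix, and it searches
# depth-first in pkgutil's order, so when several modules define the same class name it may
# return a different one than recursive_find_python_class (which checks plain modules at each
# level before descending). The function name is hypothetical.
import importlib
import pkgutil
from typing import Optional


def _find_class_in_package_sketch(package_name: str, class_name: str) -> Optional[type]:
    package = importlib.import_module(package_name)
    for module_info in pkgutil.walk_packages(package.__path__, prefix=package_name + '.'):
        # never import __main__ modules: they are command line entry points
        if module_info.name.rsplit('.', 1)[-1] == '__main__':
            continue
        module = importlib.import_module(module_info.name)
        if hasattr(module, class_name):
            return getattr(module, class_name)
    return None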
================================================
FILE: nnunetv2/utilities/get_network_from_plans.py
================================================
import pydoc
import warnings
from typing import Union
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from batchgenerators.utilities.file_and_folder_operations import join
def get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels, output_channels,
allow_init=True, deep_supervision: Union[bool, None] = None):
network_class = arch_class_name
architecture_kwargs = dict(**arch_kwargs)
for ri in arch_kwargs_req_import:
if architecture_kwargs[ri] is not None:
architecture_kwargs[ri] = pydoc.locate(architecture_kwargs[ri])
nw_class = pydoc.locate(network_class)
# sometimes things move around, this makes it so that we can at least recover some of that
if nw_class is None:
warnings.warn(f'Network class {network_class} not found. Attempting to locate it within '
f'dynamic_network_architectures.architectures...')
import dynamic_network_architectures
nw_class = recursive_find_python_class(join(dynamic_network_architectures.__path__[0], "architectures"),
network_class.split(".")[-1],
'dynamic_network_architectures.architectures')
if nw_class is not None:
print(f'FOUND IT: {nw_class}')
else:
raise ImportError('Network class could not be found, please check/correct your plans file')
if deep_supervision is not None:
architecture_kwargs['deep_supervision'] = deep_supervision
network = nw_class(
input_channels=input_channels,
num_classes=output_channels,
**architecture_kwargs
)
if hasattr(network, 'initialize') and allow_init:
network.apply(network.initialize)
return network
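# A minimal sketch of the string-to-object resolution performed above for arch_kwargs_req_import:
# pydoc.locate turns a dotted path such as "torch.nn.LeakyReLU" into the object it names and
# returns None if it cannot be found, while None-valued entries (e.g. dropout_op) are left
# untouched. The helper name is hypothetical; unlike the loop above, it raises instead of
# silently storing None when a path cannot be resolved, and it returns a new dict.
import pydoc
from typing import Any, Dict, Iterable


def _resolve_dotted_kwargs_sketch(kwargs: Dict[str, Any], keys_to_import: Iterable[str]) -> Dict[str, Any]:
    resolved = dict(kwargs)
    for key in keys_to_import:
        if resolved[key] is None:
            continue
        obj = pydoc.locate(resolved[key])
        if obj is None:
            raise ImportError(f'Could not locate {resolved[key]!r} for argument {key!r}')
        resolved[key] = obj
    return resolved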
if __name__ == "__main__":
import torch
model = get_network_from_plans(
arch_class_name="dynamic_network_architectures.architectures.unet.ResidualEncoderUNet",
arch_kwargs={
"n_stages": 7,
"features_per_stage": [32, 64, 128, 256, 512, 512, 512],
"conv_op": "torch.nn.modules.conv.Conv2d",
"kernel_sizes": [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]],
"strides": [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]],
"n_blocks_per_stage": [1, 3, 4, 6, 6, 6, 6],
"n_conv_per_stage_decoder": [1, 1, 1, 1, 1, 1],
"conv_bias": True,
"norm_op": "torch.nn.modules.instancenorm.InstanceNorm2d",
"norm_op_kwargs": {"eps": 1e-05, "affine": True},
"dropout_op": None,
"dropout_op_kwargs": None,
"nonlin": "torch.nn.LeakyReLU",
"nonlin_kwargs": {"inplace": True},
},
arch_kwargs_req_import=["conv_op", "norm_op", "dropout_op", "nonlin"],
input_channels=1,
output_channels=4,
allow_init=True,
deep_supervision=True,
)
data = torch.rand((8, 1, 256, 256))
target = torch.rand(size=(8, 1, 256, 256))
outputs = model(data) # with deep_supervision=True this is a list of torch.Tensor, one per supervised output resolution
================================================
FILE: nnunetv2/utilities/helpers.py
================================================
import torch
def softmax_helper_dim0(x: torch.Tensor) -> torch.Tensor:
return torch.softmax(x, 0)
def softmax_helper_dim1(x: torch.Tensor) -> torch.Tensor:
return torch.softmax(x, 1)
def empty_cache(device: torch.device):
if device.type == 'cuda':
torch.cuda.empty_cache()
elif device.type == 'mps':
from torch import mps
mps.empty_cache()
else:
pass
class dummy_context(object):
def __enter__(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
pass
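# Since Python 3.7 the standard library ships contextlib.nullcontext, a no-op context manager with
# the same effect as dummy_context. A common pattern (sketched here with a hypothetical helper
# name) is to build a real context manager only when a feature is enabled, for example an
# autocast context on GPU, and to fall back to a no-op otherwise.
import contextlib
from typing import Callable, ContextManager


def _context_if_sketch(enabled: bool, factory: Callable[[], ContextManager]) -> ContextManager:
    # factory is only called when enabled, so the real context manager is never built otherwise
    return factory() if enabled else contextlib.nullcontext()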
================================================
FILE: nnunetv2/utilities/json_export.py
================================================
from collections.abc import Iterable
import numpy as np
import torch
def recursive_fix_for_json_export(my_dict: dict):
# json is ... a very nice thing to have
# 'cannot serialize object of type bool_/int64/float64'. Apart from that of course...
keys = list(my_dict.keys()) # cannot iterate over keys() if we change keys....
for k in keys:
if isinstance(k, (np.int64, np.int32, np.int8, np.uint8)):
tmp = my_dict[k]
del my_dict[k]
my_dict[int(k)] = tmp
del tmp
k = int(k)
if isinstance(my_dict[k], dict):
recursive_fix_for_json_export(my_dict[k])
elif isinstance(my_dict[k], np.ndarray):
assert my_dict[k].ndim == 1, 'only 1d arrays are supported'
my_dict[k] = fix_types_iterable(my_dict[k], output_type=list)
elif isinstance(my_dict[k], (np.bool_,)):
my_dict[k] = bool(my_dict[k])
elif isinstance(my_dict[k], (np.int64, np.int32, np.int8, np.uint8)):
my_dict[k] = int(my_dict[k])
elif isinstance(my_dict[k], (np.float32, np.float64, np.float16)):
my_dict[k] = float(my_dict[k])
elif isinstance(my_dict[k], list):
my_dict[k] = fix_types_iterable(my_dict[k], output_type=type(my_dict[k]))
elif isinstance(my_dict[k], tuple):
my_dict[k] = fix_types_iterable(my_dict[k], output_type=tuple)
elif isinstance(my_dict[k], torch.device):
my_dict[k] = str(my_dict[k])
else:
pass # pray it can be serialized
def fix_types_iterable(iterable, output_type):
# this sh!t is hacky as hell and will break if you use it for anything outside nnunet. Keep your hands off of this.
out = []
for i in iterable:
if type(i) in (np.int64, np.int32, np.int8, np.uint8):
out.append(int(i))
elif isinstance(i, dict):
recursive_fix_for_json_export(i)
out.append(i)
elif type(i) in (np.float32, np.float64, np.float16):
out.append(float(i))
elif type(i) in (np.bool_,):
out.append(bool(i))
elif isinstance(i, str):
out.append(i)
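elif isinstance(i, Iterable):
# np.ndarray(list) would interpret the list as a shape, so nested arrays are converted to lists instead
out.append(fix_types_iterable(i, list if isinstance(i, np.ndarray) else type(i)))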
else:
out.append(i)
return output_type(out)
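# An alternative sketch, not used by nnU-Net: json.dumps accepts a `default` callable that is
# invoked for objects the encoder cannot handle. Converting NumPy scalars with .item() and arrays
# with .tolist() covers the value cases handled above without mutating the input. It does NOT
# cover NumPy integers used as dict keys (json raises TypeError for those), which is one reason
# recursive_fix_for_json_export rewrites keys explicitly. The function name is hypothetical.
import json

import numpy as np


def _numpy_json_default_sketch(obj):
    if isinstance(obj, np.generic):
        # np.int64 -> int, np.bool_ -> bool, np.float32 -> float, ...
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')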
================================================
FILE: nnunetv2/utilities/label_handling/__init__.py
================================================
================================================
FILE: nnunetv2/utilities/label_handling/label_handling.py
================================================
from __future__ import annotations
from time import time
from typing import Union, List, Tuple, Type
import numpy as np
import torch
from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice, insert_crop_into_image
from batchgenerators.utilities.file_and_folder_operations import join
import nnunetv2
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.helpers import softmax_helper_dim0
from typing import TYPE_CHECKING
# see https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/
if TYPE_CHECKING:
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
class LabelManager(object):
def __init__(self, label_dict: dict, regions_class_order: Union[List[int], None], force_use_labels: bool = False,
inference_nonlin=None):
self._sanity_check(label_dict)
self.label_dict = label_dict
self.regions_class_order = regions_class_order
self._force_use_labels = force_use_labels
if force_use_labels:
self._has_regions = False
else:
self._has_regions: bool = any(
[isinstance(i, (tuple, list)) and len(i) > 1 for i in self.label_dict.values()])
self._ignore_label: Union[None, int] = self._determine_ignore_label()
self._all_labels: List[int] = self._get_all_labels()
self._regions: Union[None, List[Union[int, Tuple[int, ...]]]] = self._get_regions()
if self.has_ignore_label:
assert self.ignore_label == max(
self.all_labels) + 1, 'If you use the ignore label it must have the highest ' \
'label value! It cannot be 0 or in between other labels. ' \
'Sorry bro.'
if inference_nonlin is None:
self.inference_nonlin = torch.sigmoid if self.has_regions else softmax_helper_dim0
else:
self.inference_nonlin = inference_nonlin
def _sanity_check(self, label_dict: dict):
if not 'background' in label_dict.keys():
raise RuntimeError('Background label not declared (remember that this should be label 0!)')
bg_label = label_dict['background']
if isinstance(bg_label, (tuple, list)):
raise RuntimeError(f"Background label must be 0. Not a list. Not a tuple. Your background label: {bg_label}")
assert int(bg_label) == 0, f"Background label must be 0. Your background label: {bg_label}"
# not sure if we want to allow regions that contain background. I don't immediately see how this could cause
# problems so we allow it for now. That doesn't mean that this is explicitly supported. It could be that this
# just crashes.
def _get_all_labels(self) -> List[int]:
all_labels = []
for k, r in self.label_dict.items():
# ignore label is not going to be used, hence the name. Duh.
if k == 'ignore':
continue
if isinstance(r, (tuple, list)):
for ri in r:
all_labels.append(int(ri))
else:
all_labels.append(int(r))
all_labels = list(np.unique(all_labels))
all_labels.sort()
return all_labels
def _get_regions(self) -> Union[None, List[Union[int, Tuple[int, ...]]]]:
if not self._has_regions or self._force_use_labels:
return None
else:
assert self.regions_class_order is not None, 'if region-based training is requested then you need to ' \
'define regions_class_order!'
regions = []
for k, r in self.label_dict.items():
# ignore ignore label
if k == 'ignore':
continue
# ignore regions that are background
if (np.isscalar(r) and r == 0) \
or \
(isinstance(r, (tuple, list)) and len(np.unique(r)) == 1 and np.unique(r)[0] == 0):
continue
if isinstance(r, list):
r = tuple(r)
regions.append(r)
assert len(self.regions_class_order) == len(regions), 'regions_class_order must have as ' \
'many entries as there are ' \
'regions'
return regions
def _determine_ignore_label(self) -> Union[None, int]:
ignore_label = self.label_dict.get('ignore')
if ignore_label is not None:
assert isinstance(ignore_label, int), f'Ignore label has to be an integer. It cannot be a region ' \
f'(list/tuple). Got {type(ignore_label)}.'
return ignore_label
@property
def has_regions(self) -> bool:
return self._has_regions
@property
def has_ignore_label(self) -> bool:
return self.ignore_label is not None
@property
def all_regions(self) -> Union[None, List[Union[int, Tuple[int, ...]]]]:
return self._regions
@property
def all_labels(self) -> List[int]:
return self._all_labels
@property
def ignore_label(self) -> Union[None, int]:
return self._ignore_label
def apply_inference_nonlin(self, logits: Union[np.ndarray, torch.Tensor]) -> \
Union[np.ndarray, torch.Tensor]:
"""
logits has to have shape (c, x, y(, z)) where c is the number of classes/regions
"""
if isinstance(logits, np.ndarray):
logits = torch.from_numpy(logits)
with torch.no_grad():
# softmax etc is not implemented for half
logits = logits.float()
probabilities = self.inference_nonlin(logits)
return probabilities
@torch.inference_mode()
def convert_probabilities_to_segmentation(self, predicted_probabilities: Union[np.ndarray, torch.Tensor]) -> \
Union[np.ndarray, torch.Tensor]:
"""
assumes that inference_nonlinearity was already applied!
predicted_probabilities has to have shape (c, x, y(, z)) where c is the number of classes/regions
"""
if not isinstance(predicted_probabilities, (np.ndarray, torch.Tensor)):
raise RuntimeError(f"Unexpected input type. Expected np.ndarray or torch.Tensor,"
f" got {type(predicted_probabilities)}")
if self.has_regions:
assert self.regions_class_order is not None, 'if region-based training is requested then you need to ' \
'define regions_class_order!'
# check correct number of outputs
assert predicted_probabilities.shape[0] == self.num_segmentation_heads, \
f'unexpected number of channels in predicted_probabilities. Expected {self.num_segmentation_heads}, ' \
f'got {predicted_probabilities.shape[0]}. Remember that predicted_probabilities should have shape ' \
f'(c, x, y(, z)).'
if self.has_regions:
if isinstance(predicted_probabilities, np.ndarray):
segmentation = np.zeros(predicted_probabilities.shape[1:], dtype=np.uint16)
else:
# no uint16 in torch
segmentation = torch.zeros(predicted_probabilities.shape[1:], dtype=torch.int16,
device=predicted_probabilities.device)
for i, c in enumerate(self.regions_class_order):
segmentation[predicted_probabilities[i] > 0.5] = c
else:
# numpy is faster than torch. :facepalm:
is_numpy = isinstance(predicted_probabilities, np.ndarray)
if not is_numpy:
predicted_probabilities = predicted_probabilities.numpy()
segmentation = predicted_probabilities.argmax(0)
if not is_numpy:
segmentation = torch.from_numpy(segmentation)
return segmentation
@torch.inference_mode()
def convert_logits_to_segmentation(self, predicted_logits: Union[np.ndarray, torch.Tensor]) -> \
Union[np.ndarray, torch.Tensor]:
input_is_numpy = isinstance(predicted_logits, np.ndarray)
# we can skip this step if we do not have regions. Argmax is the same for logits and probabilities (softmax is monotonic)
if self.has_regions:
probabilities = self.apply_inference_nonlin(predicted_logits)
else:
probabilities = predicted_logits
if input_is_numpy and isinstance(probabilities, torch.Tensor):
probabilities = probabilities.cpu().numpy()
return self.convert_probabilities_to_segmentation(probabilities)
def revert_cropping_on_probabilities(self, predicted_probabilities: Union[torch.Tensor, np.ndarray],
bbox: List[List[int]],
original_shape: Union[List[int], Tuple[int, ...]]):
"""
ONLY USE THIS WITH PROBABILITIES, DO NOT USE LOGITS AND DO NOT USE FOR SEGMENTATION MAPS!!!
predicted_probabilities must be (c, x, y(, z))
Why do we do this here? Well, if we pad probabilities we need to make sure that convert_logits_to_segmentation
correctly returns background in the padded areas. We also want to be able to look at the padded probabilities
without strange artifacts.
Only LabelManager knows how this needs to be done. So let's let him/her do it, ok?
"""
# revert cropping
probs_reverted_cropping = np.zeros((predicted_probabilities.shape[0], *original_shape),
dtype=predicted_probabilities.dtype) \
if isinstance(predicted_probabilities, np.ndarray) else \
torch.zeros((predicted_probabilities.shape[0], *original_shape), dtype=predicted_probabilities.dtype)
if not self.has_regions:
probs_reverted_cropping[0] = 1
probs_reverted_cropping = insert_crop_into_image(probs_reverted_cropping, predicted_probabilities, bbox)
return probs_reverted_cropping
@staticmethod
def filter_background(classes_or_regions: Union[List[int], List[Union[int, Tuple[int, ...]]]]):
# heck yeah
# This is definitely taking list comprehension too far. Enjoy.
return [i for i in classes_or_regions if
((not isinstance(i, (tuple, list))) and i != 0)
or
(isinstance(i, (tuple, list)) and not (
len(np.unique(i)) == 1 and np.unique(i)[0] == 0))]
@property
def foreground_regions(self):
return self.filter_background(self.all_regions)
@property
def foreground_labels(self):
return self.filter_background(self.all_labels)
@property
def num_segmentation_heads(self):
if self.has_regions:
return len(self.foreground_regions)
else:
return len(self.all_labels)
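# A standalone sketch of the region handling implemented in LabelManager, assuming a region-based
# label dict (the BraTS-style dict in the usage is only an example). Regions are all non-ignore,
# non-background entries, and lists become tuples. At inference, each region channel is
# thresholded at 0.5 and written into the segmentation in regions_class_order, so a later region
# overwrites an earlier one where both are active (mirroring the numpy path of
# convert_probabilities_to_segmentation). Function names are hypothetical.
from typing import Dict, List, Sequence, Tuple, Union

import numpy as np

_Region = Union[int, Tuple[int, ...]]


def _regions_from_label_dict_sketch(label_dict: Dict[str, Union[int, Sequence[int]]]) -> List[_Region]:
    regions = []
    for name, r in label_dict.items():
        if name == 'ignore':
            continue
        values = set(r) if isinstance(r, (tuple, list)) else {r}
        if values == {0}:  # background, or a region consisting only of background
            continue
        regions.append(tuple(r) if isinstance(r, (tuple, list)) else r)
    return regions


def _regions_to_segmentation_sketch(probabilities: np.ndarray, regions_class_order: Sequence[int]) -> np.ndarray:
    # probabilities: (num_regions, x, y(, z)), already passed through a sigmoid
    segmentation = np.zeros(probabilities.shape[1:], dtype=np.uint16)
    for channel, label in enumerate(regions_class_order):
        segmentation[probabilities[channel] > 0.5] = label
    return segmentation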
def get_labelmanager_class_from_plans(plans: dict) -> Type[LabelManager]:
if 'label_manager' not in plans.keys():
print('No label manager specified in plans. Using default: LabelManager')
return LabelManager
else:
labelmanager_class = recursive_find_python_class(join(nnunetv2.__path__[0], "utilities", "label_handling"),
plans['label_manager'],
current_module="nnunetv2.utilities.label_handling")
return labelmanager_class
def convert_labelmap_to_one_hot(segmentation: Union[np.ndarray, torch.Tensor],
all_labels: Union[List, torch.Tensor, np.ndarray, tuple],
output_dtype=None) -> Union[np.ndarray, torch.Tensor]:
"""
if output_dtype is None then we use np.uint8/torch.uint8
if input is torch.Tensor then output will be on the same device
np.ndarray is faster than torch.Tensor
if segmentation is torch.Tensor, this function will be faster if it is a LongTensor. If it is something else we
have to cast, which takes time.
IMPORTANT: This function only works properly if your labels are consecutive integers, so something like 0, 1, 2, 3, ...
DO NOT use it with 0, 32, 123, 255, ... or whatever (fix your labels, yo)
"""
if isinstance(segmentation, torch.Tensor):
result = torch.zeros((len(all_labels), *segmentation.shape),
dtype=output_dtype if output_dtype is not None else (torch.uint8 if max(all_labels) < 255 else torch.uint16),
device=segmentation.device)
# variant 1, 2x faster than 2
result.scatter_(0, segmentation[None].long(), 1) # why does this have to be long!?
# variant 2, slower than 1
# for i, l in enumerate(all_labels):
# result[i] = segmentation == l
else:
result = np.zeros((len(all_labels), *segmentation.shape),
dtype=output_dtype if output_dtype is not None else (np.uint8 if max(all_labels) < 255 else np.uint16))
# variant 1, fastest in my testing
for i, l in enumerate(all_labels):
result[i] = segmentation == l
# variant 2. Takes about twice as long so nah
# result = np.eye(len(all_labels))[segmentation].transpose((3, 0, 1, 2))
return result
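# A minimal numpy sketch of one-hot encoding by broadcasting, shown for comparison. Like the numpy
# branch above (the per-label loop), it compares against the label values, so it also works for
# non-consecutive labels such as (0, 32, 255); the torch branch uses scatter_, which uses each
# label value as a channel index and therefore needs labels 0..C-1. Broadcasting materializes the
# full boolean array at once, which needs more temporary memory than the loop.
# The function name is hypothetical.
from typing import Sequence

import numpy as np


def _one_hot_broadcast_sketch(segmentation: np.ndarray, labels: Sequence[int], dtype=np.uint8) -> np.ndarray:
    # reshape labels to (C, 1, 1, ...) so they broadcast against segmentation[None]
    labels = np.asarray(labels).reshape(-1, *([1] * segmentation.ndim))
    return (segmentation[None] == labels).astype(dtype)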
def determine_num_input_channels(plans_manager: PlansManager,
configuration_or_config_manager: Union[str, ConfigurationManager],
dataset_json: dict) -> int:
if isinstance(configuration_or_config_manager, str):
config_manager = plans_manager.get_configuration(configuration_or_config_manager)
else:
config_manager = configuration_or_config_manager
label_manager = plans_manager.get_label_manager(dataset_json)
num_modalities = len(dataset_json['modality']) if 'modality' in dataset_json.keys() else len(dataset_json['channel_names'])
# cascade has different number of input channels
if config_manager.previous_stage_name is not None:
num_label_inputs = len(label_manager.foreground_labels)
num_input_channels = num_modalities + num_label_inputs
else:
num_input_channels = num_modalities
return num_input_channels
if __name__ == '__main__':
# this code used to be able to differentiate variant 1 and 2 to measure time.
num_labels = 7
seg = np.random.randint(0, num_labels, size=(256, 256, 256), dtype=np.uint8)
seg_torch = torch.from_numpy(seg)
st = time()
onehot_npy = convert_labelmap_to_one_hot(seg, np.arange(num_labels))
time_1 = time()
onehot_npy2 = convert_labelmap_to_one_hot(seg, np.arange(num_labels))
time_2 = time()
onehot_torch = convert_labelmap_to_one_hot(seg_torch, np.arange(num_labels))
time_torch = time()
onehot_torch2 = convert_labelmap_to_one_hot(seg_torch, np.arange(num_labels))
time_torch2 = time()
print(
f'np: {time_1 - st}, np2: {time_2 - time_1}, torch: {time_torch - time_2}, torch2: {time_torch2 - time_torch}')
onehot_torch = onehot_torch.numpy()
onehot_torch2 = onehot_torch2.numpy()
print(np.all(onehot_torch == onehot_npy))
print(np.all(onehot_torch2 == onehot_npy))
================================================
FILE: nnunetv2/utilities/network_initialization.py
================================================
from torch import nn
class InitWeights_He(object):
def __init__(self, neg_slope=1e-2):
self.neg_slope = neg_slope
def __call__(self, module):
if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
# kaiming_normal_ and constant_ modify the parameters in place
nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
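# For reference, an arithmetic sketch of the standard deviation InitWeights_He requests from
# PyTorch, assuming nn.init.kaiming_normal_'s defaults mode='fan_in' and nonlinearity='leaky_relu':
# weights are drawn from N(0, std^2) with std = gain / sqrt(fan_in) and gain = sqrt(2 / (1 + a^2)),
# where a is neg_slope. For a regular ConvNd, fan_in = in_channels / groups * prod(kernel_size);
# PyTorch derives it from weight.shape[1], which for ConvTransposeNd is out_channels / groups
# instead. The helper name is hypothetical.
import math
from typing import Sequence


def _kaiming_normal_std_sketch(in_channels: int, kernel_size: Sequence[int], neg_slope: float = 1e-2,
                               groups: int = 1) -> float:
    fan_in = (in_channels // groups) * math.prod(kernel_size)
    gain = math.sqrt(2.0 / (1 + neg_slope ** 2))
    return gain / math.sqrt(fan_in)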
================================================
FILE: nnunetv2/utilities/overlay_plots.py
================================================
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
from multiprocessing.pool import Pool
from typing import Tuple, Union
import numpy as np
import pandas as pd
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.configuration import default_num_processes
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json
from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
from nnunetv2.training.dataloading.nnunet_dataset import infer_dataset_class, nnUNetBaseDataset
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager
from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \
get_filenames_of_train_images_and_targets
color_cycle = (
"000000",
"4363d8",
"f58231",
"3cb44b",
"e6194B",
"911eb4",
"ffe119",
"bfef45",
"42d4f4",
"f032e6",
"000075",
"9A6324",
"808000",
"800000",
"469990",
)
def hex_to_rgb(hex: str):
assert len(hex) == 6
return tuple(int(hex[i:i + 2], 16) for i in (0, 2, 4))
def generate_overlay(input_image: np.ndarray, segmentation: np.ndarray, mapping: dict = None,
color_cycle: Tuple[str, ...] = color_cycle,
overlay_intensity: float = 0.6):
"""
image can be 2d greyscale or 2d RGB (color channel in last dimension!)
Segmentation must be label map of same shape as image (w/o color channels)
mapping can be label_id -> idx_in_cycle or None
returned image is scaled to [0, 255] (uint8)!!!
"""
# create a copy of image
image = np.copy(input_image)
if image.ndim == 2:
image = np.tile(image[:, :, None], (1, 1, 3))
    elif image.ndim == 3:
        if image.shape[2] == 1:
            image = np.tile(image, (1, 1, 3))
        elif image.shape[2] != 3:
            raise RuntimeError(f'if a 3d image is given, the last dimension must be the color channels (1 or 3 '
                               f'channels). Only 2D images are supported. Your image shape: {image.shape}')
else:
raise RuntimeError("unexpected image shape. only 2D images and 2D images with color channels (color in "
"last dimension) are supported")
    # rescale image to [0, 255]; constant images are left at 0 to avoid division by zero
    image = image.astype(np.float32) - image.min()
    if image.max() > 0:
        image = image / image.max() * 255
    # create output
    if mapping is None:
        uniques = np.sort(pd.unique(segmentation.ravel()))  # np.unique(segmentation)
        mapping = {i: c for c, i in enumerate(uniques)}
    for l in mapping.keys():
        # wrap around the color cycle if there are more labels than colors
        color = hex_to_rgb(color_cycle[mapping[l] % len(color_cycle)])
        image[segmentation == l] += overlay_intensity * np.array(color)
    # rescale result to [0, 255]
    if image.max() > 0:
        image = image / image.max() * 255
    return image.astype(np.uint8)
def select_slice_to_plot(image: np.ndarray, segmentation: np.ndarray) -> int:
"""
image and segmentation are expected to be 3D
selects the slice with the largest amount of fg (regardless of label)
we give image so that we can easily replace this function if needed
"""
fg_mask = segmentation != 0
fg_per_slice = fg_mask.sum((1, 2))
selected_slice = int(np.argmax(fg_per_slice))
return selected_slice
def select_slice_to_plot2(image: np.ndarray, segmentation: np.ndarray) -> int:
"""
image and segmentation are expected to be 3D (or 1, x, y)
selects the slice with the largest amount of fg (how much percent of each class are in each slice? pick slice
with highest avg percent)
we give image so that we can easily replace this function if needed
"""
classes = [i for i in np.sort(pd.unique(segmentation.ravel())) if i > 0]
fg_per_slice = np.zeros((image.shape[0], len(classes)))
for i, c in enumerate(classes):
fg_mask = segmentation == c
fg_per_slice[:, i] = fg_mask.sum((1, 2))
        # normalize by this class's total so each column is the fraction of the class in each slice
        fg_per_slice[:, i] /= fg_per_slice[:, i].sum()
fg_per_slice = fg_per_slice.mean(1)
return int(np.argmax(fg_per_slice))
def plot_overlay(image_file: str, segmentation_file: str, image_reader_writer: BaseReaderWriter, output_file: str,
overlay_intensity: float = 0.6):
import matplotlib.pyplot as plt
image, props = image_reader_writer.read_images((image_file, ))
image = image[0]
seg, props_seg = image_reader_writer.read_seg(segmentation_file)
seg = seg[0]
    assert image.shape == seg.shape, f"image and seg do not have the same shape: {image.shape} vs {seg.shape} " \
                                     f"({image_file}, {segmentation_file})"
assert image.ndim == 3, 'only 3D images/segs are supported'
selected_slice = select_slice_to_plot2(image, seg)
# print(image.shape, selected_slice)
overlay = generate_overlay(image[selected_slice], seg[selected_slice], overlay_intensity=overlay_intensity)
plt.imsave(output_file, overlay)
def plot_overlay_preprocessed(dataset: nnUNetBaseDataset, k: str, output_folder: str, overlay_intensity: float = 0.6, channel_idx=0):
import matplotlib.pyplot as plt
data, seg, _, properties = dataset.load_case(k)
assert channel_idx < (data.shape[0]), 'This dataset only supports channel index up to %d' % (data.shape[0] - 1)
image = data[channel_idx]
seg = seg[0]
selected_slice = select_slice_to_plot2(image, seg)
seg = np.copy(seg[selected_slice])
seg[seg < 0] = 0
overlay = generate_overlay(image[selected_slice], seg, overlay_intensity=overlay_intensity)
plt.imsave(join(output_folder, k + '.png'), overlay)
def multiprocessing_plot_overlay(list_of_image_files, list_of_seg_files, image_reader_writer,
list_of_output_files, overlay_intensity,
num_processes=8):
with multiprocessing.get_context("spawn").Pool(num_processes) as p:
r = p.starmap_async(plot_overlay, zip(
list_of_image_files, list_of_seg_files, [image_reader_writer] * len(list_of_output_files),
list_of_output_files, [overlay_intensity] * len(list_of_output_files)
))
r.get()
def multiprocessing_plot_overlay_preprocessed(dataset: nnUNetBaseDataset, output_folder, overlay_intensity,
num_processes=8, channel_idx=0):
with multiprocessing.get_context("spawn").Pool(num_processes) as p:
r = []
for k in dataset.identifiers:
r.append(
p.starmap_async(plot_overlay_preprocessed,
((
dataset, k, output_folder, overlay_intensity, channel_idx
),))
)
_ = [i.get() for i in r]
def generate_overlays_from_raw(dataset_name_or_id: Union[int, str], output_folder: str,
num_processes: int = 8, channel_idx: int = 0, overlay_intensity: float = 0.6):
dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
folder = join(nnUNet_raw, dataset_name)
dataset_json = load_json(join(folder, 'dataset.json'))
dataset = get_filenames_of_train_images_and_targets(folder, dataset_json)
image_files = [v['images'][channel_idx] for v in dataset.values()]
seg_files = [v['label'] for v in dataset.values()]
assert all([isfile(i) for i in image_files])
assert all([isfile(i) for i in seg_files])
maybe_mkdir_p(output_folder)
output_files = [join(output_folder, i + '.png') for i in dataset.keys()]
image_reader_writer = determine_reader_writer_from_dataset_json(dataset_json, image_files[0])()
multiprocessing_plot_overlay(image_files, seg_files, image_reader_writer, output_files, overlay_intensity, num_processes)
def generate_overlays_from_preprocessed(dataset_name_or_id: Union[int, str], output_folder: str,
num_processes: int = 8, channel_idx: int = 0,
configuration: str = None,
plans_identifier: str = 'nnUNetPlans',
overlay_intensity: float = 0.6):
dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
folder = join(nnUNet_preprocessed, dataset_name)
    if not isdir(folder):
        raise RuntimeError(f"Preprocessed folder for {dataset_name} not found. Run preprocessing for this dataset first.")
plans = load_json(join(folder, plans_identifier + '.json'))
if configuration is None:
if '3d_fullres' in plans['configurations'].keys():
configuration = '3d_fullres'
else:
configuration = '2d'
cm = ConfigurationManager(plans['configurations'][configuration])
preprocessed_folder = join(folder, cm.data_identifier)
if not isdir(preprocessed_folder):
raise RuntimeError(f"Preprocessed data folder for configuration {configuration} of plans identifier "
f"{plans_identifier} ({dataset_name}) does not exist. Run preprocessing for this "
f"configuration first!")
dc = infer_dataset_class(preprocessed_folder)
dataset = dc(preprocessed_folder)
maybe_mkdir_p(output_folder)
multiprocessing_plot_overlay_preprocessed(dataset, output_folder, overlay_intensity=overlay_intensity,
num_processes=num_processes, channel_idx=channel_idx)
def entry_point_generate_overlay():
import argparse
parser = argparse.ArgumentParser("Plots png overlays of the slice with the most foreground. Note that this "
"disregards spacing information!")
parser.add_argument('-d', type=str, help="Dataset name or id", required=True)
parser.add_argument('-o', type=str, help="output folder", required=True)
parser.add_argument('-np', type=int, default=default_num_processes, required=False,
help=f"number of processes used. Default: {default_num_processes}")
parser.add_argument('-channel_idx', type=int, default=0, required=False,
help="channel index used (0 = _0000). Default: 0")
    parser.add_argument('--use_raw', action='store_true', required=False,
                        help="If set, overlays are generated from the raw data. Otherwise the preprocessed data is used.")
parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
help='plans identifier. Only used if --use_raw is not set! Default: nnUNetPlans')
parser.add_argument('-c', type=str, required=False, default=None,
help='configuration name. Only used if --use_raw is not set! Default: None = '
'3d_fullres if available, else 2d')
parser.add_argument('-overlay_intensity', type=float, required=False, default=0.6,
help='overlay intensity. Higher = brighter/less transparent')
args = parser.parse_args()
if args.use_raw:
generate_overlays_from_raw(args.d, args.o, args.np, args.channel_idx,
overlay_intensity=args.overlay_intensity)
else:
generate_overlays_from_preprocessed(args.d, args.o, args.np, args.channel_idx, args.c, args.p,
overlay_intensity=args.overlay_intensity)
if __name__ == '__main__':
entry_point_generate_overlay()
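The docstring of `select_slice_to_plot2` describes its goal: for each class, compute the fraction of that class's voxels in every slice, then pick the slice with the highest average fraction. The standalone NumPy sketch below implements that rule. Small structures therefore count as much as large ones, unlike a plain foreground count. `select_slice_by_class_fraction` is a hypothetical name. The central-slice fallback for foreground-free segmentations is an assumption of this sketch, not nnU-Net behavior.

```python
import numpy as np


def select_slice_by_class_fraction(segmentation: np.ndarray) -> int:
    """Hypothetical helper: slice (axis 0) with the highest mean per-class foreground fraction."""
    classes = [c for c in np.unique(segmentation) if c > 0]
    if not classes:
        # assumption: with no foreground at all, fall back to the central slice
        return segmentation.shape[0] // 2
    fractions = np.zeros((segmentation.shape[0], len(classes)))
    for i, c in enumerate(classes):
        voxels_per_slice = (segmentation == c).sum(axis=(1, 2))
        # each column sums to 1: share of class c found in each slice
        fractions[:, i] = voxels_per_slice / voxels_per_slice.sum()
    return int(np.argmax(fractions.mean(axis=1)))


seg = np.zeros((3, 4, 4), dtype=np.uint8)
seg[0, :2, :] = 1   # 8 voxels of class 1
seg[0, 2, :2] = 1   # 2 more -> 10 voxels of class 1 in slice 0
seg[1, 0, :2] = 1   # 2 voxels of class 1 in slice 1
seg[1, 3, 3] = 2    # the only voxel of class 2
print(select_slice_by_class_fraction(seg))  # 1
```

Slice 0 has the most foreground (10 voxels versus 3). Slice 1 wins because it contains all of class 2 (mean fraction 7/12 versus 5/12).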
================================================
FILE: nnunetv2/utilities/plans_handling/__init__.py
================================================
================================================
FILE: nnunetv2/utilities/plans_handling/plans_handler.py
================================================
from __future__ import annotations
import warnings
from copy import deepcopy
from functools import lru_cache, partial
from typing import Union, Tuple, List, Type, Callable
import numpy as np
import torch
from nnunetv2.preprocessing.resampling.utils import recursive_find_resampling_fn_by_name
import nnunetv2
from batchgenerators.utilities.file_and_folder_operations import load_json, join
from nnunetv2.imageio.reader_writer_registry import recursive_find_reader_writer_by_name
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.label_handling.label_handling import get_labelmanager_class_from_plans
# see https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/
from typing import TYPE_CHECKING
from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm
if TYPE_CHECKING:
from nnunetv2.utilities.label_handling.label_handling import LabelManager
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor
from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
class ConfigurationManager(object):
def __init__(self, configuration_dict: dict):
self.configuration = configuration_dict
# backwards compatibility
if 'architecture' not in self.configuration.keys():
warnings.warn("Detected old nnU-Net plans format. Attempting to reconstruct network architecture "
"parameters. If this fails, rerun nnUNetv2_plan_experiment for your dataset. If you use a "
                          "custom architecture, please downgrade nnU-Net to the version you implemented it with "
                          "or update your implementation + plans.")
# try to build the architecture information from old plans, modify configuration dict to match new standard
unet_class_name = self.configuration["UNet_class_name"]
if unet_class_name == "PlainConvUNet":
network_class_name = "dynamic_network_architectures.architectures.unet.PlainConvUNet"
elif unet_class_name == 'ResidualEncoderUNet':
network_class_name = "dynamic_network_architectures.architectures.residual_unet.ResidualEncoderUNet"
else:
raise RuntimeError(f'Unknown architecture {unet_class_name}. This conversion only supports '
f'PlainConvUNet and ResidualEncoderUNet')
n_stages = len(self.configuration["n_conv_per_stage_encoder"])
dim = len(self.configuration["patch_size"])
conv_op = convert_dim_to_conv_op(dim)
instnorm = get_matching_instancenorm(dimension=dim)
convs_or_blocks = "n_conv_per_stage" if unet_class_name == "PlainConvUNet" else "n_blocks_per_stage"
arch_dict = {
'network_class_name': network_class_name,
'arch_kwargs': {
"n_stages": n_stages,
"features_per_stage": [min(self.configuration["UNet_base_num_features"] * 2 ** i,
self.configuration["unet_max_num_features"])
for i in range(n_stages)],
"conv_op": conv_op.__module__ + '.' + conv_op.__name__,
"kernel_sizes": deepcopy(self.configuration["conv_kernel_sizes"]),
"strides": deepcopy(self.configuration["pool_op_kernel_sizes"]),
convs_or_blocks: deepcopy(self.configuration["n_conv_per_stage_encoder"]),
"n_conv_per_stage_decoder": deepcopy(self.configuration["n_conv_per_stage_decoder"]),
"conv_bias": True,
"norm_op": instnorm.__module__ + '.' + instnorm.__name__,
"norm_op_kwargs": {
"eps": 1e-05,
"affine": True
},
"dropout_op": None,
"dropout_op_kwargs": None,
"nonlin": "torch.nn.LeakyReLU",
"nonlin_kwargs": {
"inplace": True
}
},
# these need to be imported with locate in order to use them:
# `conv_op = pydoc.locate(architecture_kwargs['conv_op'])`
"_kw_requires_import": [
"conv_op",
"norm_op",
"dropout_op",
"nonlin"
]
}
del self.configuration["UNet_class_name"], self.configuration["UNet_base_num_features"], \
self.configuration["n_conv_per_stage_encoder"], self.configuration["n_conv_per_stage_decoder"], \
self.configuration["num_pool_per_axis"], self.configuration["pool_op_kernel_sizes"],\
self.configuration["conv_kernel_sizes"], self.configuration["unet_max_num_features"]
self.configuration["architecture"] = arch_dict
def __repr__(self):
return self.configuration.__repr__()
@property
def data_identifier(self) -> str:
return self.configuration['data_identifier']
@property
def preprocessor_name(self) -> str:
return self.configuration['preprocessor_name']
@property
@lru_cache(maxsize=1)
def preprocessor_class(self) -> Type[DefaultPreprocessor]:
preprocessor_class = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing"),
self.preprocessor_name,
current_module="nnunetv2.preprocessing")
return preprocessor_class
@property
def batch_size(self) -> int:
return self.configuration['batch_size']
@property
def patch_size(self) -> List[int]:
return self.configuration['patch_size']
@property
def median_image_size_in_voxels(self) -> List[int]:
return self.configuration['median_image_size_in_voxels']
@property
def spacing(self) -> List[float]:
return self.configuration['spacing']
@property
def normalization_schemes(self) -> List[str]:
return self.configuration['normalization_schemes']
@property
def use_mask_for_norm(self) -> List[bool]:
return self.configuration['use_mask_for_norm']
@property
def network_arch_class_name(self) -> str:
return self.configuration['architecture']['network_class_name']
@property
def network_arch_init_kwargs(self) -> dict:
return self.configuration['architecture']['arch_kwargs']
@property
def network_arch_init_kwargs_req_import(self) -> Union[Tuple[str, ...], List[str]]:
return self.configuration['architecture']['_kw_requires_import']
@property
def pool_op_kernel_sizes(self) -> Tuple[Tuple[int, ...], ...]:
return self.configuration['architecture']['arch_kwargs']['strides']
@property
@lru_cache(maxsize=1)
def resampling_fn_data(self) -> Callable[
[Union[torch.Tensor, np.ndarray],
Union[Tuple[int, ...], List[int], np.ndarray],
Union[Tuple[float, ...], List[float], np.ndarray],
Union[Tuple[float, ...], List[float], np.ndarray]
],
Union[torch.Tensor, np.ndarray]]:
fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_data'])
fn = partial(fn, **self.configuration['resampling_fn_data_kwargs'])
return fn
@property
@lru_cache(maxsize=1)
def resampling_fn_probabilities(self) -> Callable[
[Union[torch.Tensor, np.ndarray],
Union[Tuple[int, ...], List[int], np.ndarray],
Union[Tuple[float, ...], List[float], np.ndarray],
Union[Tuple[float, ...], List[float], np.ndarray]
],
Union[torch.Tensor, np.ndarray]]:
fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_probabilities'])
fn = partial(fn, **self.configuration['resampling_fn_probabilities_kwargs'])
return fn
@property
@lru_cache(maxsize=1)
def resampling_fn_seg(self) -> Callable[
[Union[torch.Tensor, np.ndarray],
Union[Tuple[int, ...], List[int], np.ndarray],
Union[Tuple[float, ...], List[float], np.ndarray],
Union[Tuple[float, ...], List[float], np.ndarray]
],
Union[torch.Tensor, np.ndarray]]:
fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_seg'])
fn = partial(fn, **self.configuration['resampling_fn_seg_kwargs'])
return fn
@property
def batch_dice(self) -> bool:
return self.configuration['batch_dice']
@property
def next_stage_names(self) -> Union[List[str], None]:
ret = self.configuration.get('next_stage')
if ret is not None:
if isinstance(ret, str):
ret = [ret]
return ret
@property
def previous_stage_name(self) -> Union[str, None]:
return self.configuration.get('previous_stage')
class PlansManager(object):
def __init__(self, plans_file_or_dict: Union[str, dict]):
"""
Why do we need this?
1) resolve inheritance in configurations
2) expose otherwise annoying stuff like getting the label manager or IO class from a string
3) clearly expose the things that are in the plans instead of hiding them in a dict
        4) cache expensive lookups (for example resolving classes from their names)
        This class does not prevent you from going wild. You can still use the plans directly if you prefer
        (PlansManager.plans['key'])
"""
self.plans = plans_file_or_dict if isinstance(plans_file_or_dict, dict) else load_json(plans_file_or_dict)
def __repr__(self):
return self.plans.__repr__()
def _internal_resolve_configuration_inheritance(self, configuration_name: str,
visited: Tuple[str, ...] = None) -> dict:
if configuration_name not in self.plans['configurations'].keys():
raise ValueError(f'The configuration {configuration_name} does not exist in the plans I have. Valid '
f'configuration names are {list(self.plans["configurations"].keys())}.')
configuration = deepcopy(self.plans['configurations'][configuration_name])
if 'inherits_from' in configuration:
parent_config_name = configuration['inherits_from']
if visited is None:
visited = (configuration_name,)
else:
if parent_config_name in visited:
raise RuntimeError(f"Circular dependency detected. The following configurations were visited "
f"while solving inheritance (in that order!): {visited}. "
f"Current configuration: {configuration_name}. Its parent configuration "
f"is {parent_config_name}.")
visited = (*visited, configuration_name)
base_config = self._internal_resolve_configuration_inheritance(parent_config_name, visited)
base_config.update(configuration)
configuration = base_config
return configuration
@lru_cache(maxsize=10)
def get_configuration(self, configuration_name: str):
if configuration_name not in self.plans['configurations'].keys():
raise RuntimeError(f"Requested configuration {configuration_name} not found in plans. "
f"Available configurations: {list(self.plans['configurations'].keys())}")
configuration_dict = self._internal_resolve_configuration_inheritance(configuration_name)
return ConfigurationManager(configuration_dict)
@property
def dataset_name(self) -> str:
return self.plans['dataset_name']
@property
def plans_name(self) -> str:
return self.plans['plans_name']
@property
def original_median_spacing_after_transp(self) -> List[float]:
return self.plans['original_median_spacing_after_transp']
@property
def original_median_shape_after_transp(self) -> List[float]:
return self.plans['original_median_shape_after_transp']
@property
@lru_cache(maxsize=1)
def image_reader_writer_class(self) -> Type[BaseReaderWriter]:
return recursive_find_reader_writer_by_name(self.plans['image_reader_writer'])
@property
def transpose_forward(self) -> List[int]:
return self.plans['transpose_forward']
@property
def transpose_backward(self) -> List[int]:
return self.plans['transpose_backward']
@property
def available_configurations(self) -> List[str]:
return list(self.plans['configurations'].keys())
@property
@lru_cache(maxsize=1)
def experiment_planner_class(self) -> Type[ExperimentPlanner]:
planner_name = self.experiment_planner_name
experiment_planner = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"),
planner_name,
current_module="nnunetv2.experiment_planning")
return experiment_planner
@property
def experiment_planner_name(self) -> str:
return self.plans['experiment_planner_used']
@property
@lru_cache(maxsize=1)
def label_manager_class(self) -> Type[LabelManager]:
return get_labelmanager_class_from_plans(self.plans)
def get_label_manager(self, dataset_json: dict, **kwargs) -> LabelManager:
return self.label_manager_class(label_dict=dataset_json['labels'],
regions_class_order=dataset_json.get('regions_class_order'),
**kwargs)
@property
def foreground_intensity_properties_per_channel(self) -> dict:
if 'foreground_intensity_properties_per_channel' not in self.plans.keys():
if 'foreground_intensity_properties_by_modality' in self.plans.keys():
return self.plans['foreground_intensity_properties_by_modality']
return self.plans['foreground_intensity_properties_per_channel']
if __name__ == '__main__':
from nnunetv2.paths import nnUNet_preprocessed
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
plans = load_json(join(nnUNet_preprocessed, maybe_convert_to_dataset_name(3), 'nnUNetPlans.json'))
# build new configuration that inherits from 3d_fullres
plans['configurations']['3d_fullres_bs4'] = {
'batch_size': 4,
'inherits_from': '3d_fullres'
}
# now get plans and configuration managers
plans_manager = PlansManager(plans)
configuration_manager = plans_manager.get_configuration('3d_fullres_bs4')
print(configuration_manager) # look for batch size 4
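The `__main__` example above relies on `inherits_from`. A configuration is merged on top of its resolved parent, and its own keys override the parent's, as `dict.update` does in `_internal_resolve_configuration_inheritance`. The sketch below shows the same idea on plain dicts with cycle detection. `resolve_configuration` is a hypothetical helper, and the configuration values are made up for illustration.

```python
from copy import deepcopy


def resolve_configuration(configurations: dict, name: str, visited: tuple = ()) -> dict:
    """Hypothetical helper: merge a configuration with its ancestors; child keys win."""
    if name in visited:
        raise RuntimeError(f"Circular dependency: {' -> '.join(visited + (name,))}")
    if name not in configurations:
        raise ValueError(f"Unknown configuration {name!r}")
    config = deepcopy(configurations[name])
    parent = config.get("inherits_from")
    if parent is None:
        return config
    # resolve the parent first (it may inherit from something else), then overlay the child
    merged = resolve_configuration(configurations, parent, visited + (name,))
    merged.update(config)
    return merged


configs = {
    "3d_fullres": {"batch_size": 2, "patch_size": [128, 128, 128]},
    "3d_fullres_bs4": {"batch_size": 4, "inherits_from": "3d_fullres"},
}
print(resolve_configuration(configs, "3d_fullres_bs4")["batch_size"])  # 4
```

Deep-copying each level keeps the source plans untouched, so resolving a child never changes its parent's entry.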
================================================
FILE: nnunetv2/utilities/utils.py
================================================
# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
# (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
from functools import lru_cache
from typing import List, Union
from batchgenerators.utilities.file_and_folder_operations import *
import numpy as np
import re
from nnunetv2.paths import nnUNet_raw
from multiprocessing import Pool
def get_identifiers_from_splitted_dataset_folder(folder: str, file_ending: str):
files = subfiles(folder, suffix=file_ending, join=False)
# all files have a 4 digit channel index (_XXXX)
crop = len(file_ending) + 5
files = [i[:-crop] for i in files]
# only unique image ids
files = np.unique(files)
return files
def create_paths_fn(folder, files, file_ending, f):
p = re.compile(re.escape(f) + r"_\d\d\d\d" + re.escape(file_ending))
return [join(folder, i) for i in files if p.fullmatch(i)]
def create_lists_from_splitted_dataset_folder(folder: str, file_ending: str, identifiers: List[str] = None, num_processes: int = 12) -> List[
List[str]]:
"""
does not rely on dataset.json
"""
if identifiers is None:
identifiers = get_identifiers_from_splitted_dataset_folder(folder, file_ending)
files = subfiles(folder, suffix=file_ending, join=False, sort=True)
list_of_lists = []
params_list = [(folder, files, file_ending, f) for f in identifiers]
with Pool(processes=num_processes) as pool:
list_of_lists = pool.starmap(create_paths_fn, params_list)
return list_of_lists
def get_filenames_of_train_images_and_targets(raw_dataset_folder: str, dataset_json: dict = None):
if dataset_json is None:
dataset_json = load_json(join(raw_dataset_folder, 'dataset.json'))
if 'dataset' in dataset_json.keys():
dataset = dataset_json['dataset']
        for k in dataset.keys():
            # expand environment variables and resolve relative paths against the raw dataset folder
            expanded_label_file = os.path.expandvars(dataset[k]['label'])
            dataset[k]['label'] = expanded_label_file if os.path.isabs(expanded_label_file) else \
                os.path.abspath(join(raw_dataset_folder, expanded_label_file))
            expanded_images = [os.path.expandvars(i) for i in dataset[k]['images']]
            dataset[k]['images'] = [i if os.path.isabs(i) else os.path.abspath(join(raw_dataset_folder, i))
                                    for i in expanded_images]
else:
identifiers = get_identifiers_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending'])
images = create_lists_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending'], identifiers)
segs = [join(raw_dataset_folder, 'labelsTr', i + dataset_json['file_ending']) for i in identifiers]
dataset = {i: {'images': im, 'label': se} for i, im, se in zip(identifiers, images, segs)}
return dataset
if __name__ == '__main__':
print(get_filenames_of_train_images_and_targets(join(nnUNet_raw, 'Dataset002_Heart')))
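`get_identifiers_from_splitted_dataset_folder` and `create_paths_fn` rely on the naming convention `CASE_XXXX<file_ending>`, where `XXXX` is a 4-digit channel index. The single-process sketch below groups files by case with one regular expression. Files that do not follow the convention are skipped rather than cropped blindly. `group_files_by_case` is a hypothetical helper, and the file names are illustrative.

```python
import re


def group_files_by_case(filenames, file_ending=".nii.gz"):
    """Hypothetical helper: map case identifier -> sorted channel files (CASE_XXXX<ending>)."""
    pattern = re.compile(r"(.+)_(\d{4})" + re.escape(file_ending))
    cases = {}
    for fname in filenames:
        match = pattern.fullmatch(fname)
        if match is None:
            continue  # not a channel file, e.g. a stray dataset.json
        cases.setdefault(match.group(1), []).append(fname)
    return {case: sorted(files) for case, files in sorted(cases.items())}


files = ["la_003_0001.nii.gz", "la_003_0000.nii.gz", "la_004_0000.nii.gz", "dataset.json"]
print(group_files_by_case(files))
```

Sorting each case's files puts them in channel order (`_0000`, `_0001`, ...), the order in which the channels are listed per case.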
================================================
FILE: pyproject.toml
================================================
[project]
name = "nnunetv2"
version = "2.6.4"
requires-python = ">=3.10"
description = "nnU-Net is a framework for out-of-the box image segmentation."
readme = "readme.md"
license = { file = "LICENSE" }
authors = [
{ name = "Fabian Isensee", email = "f.isensee@dkfz-heidelberg.de"},
{ name = "Helmholtz Imaging Applied Computer Vision Lab" }
]
classifiers = [
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Intended Audience :: Healthcare Industry",
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Image Recognition",
"Topic :: Scientific/Engineering :: Medical Science Apps.",
]
keywords = [
'deep learning',
'image segmentation',
'semantic segmentation',
'medical image analysis',
'medical image segmentation',
'nnU-Net',
'nnunet'
]
dependencies = [
"torch>=2.1.2,!=2.9.*",
"acvl-utils>=0.2.3,<0.3", # 0.3 may bring breaking changes. Careful!
"dynamic-network-architectures>=0.4.1,<0.5",
"tqdm",
"scipy",
"batchgenerators>=0.25.1",
"numpy>=1.24",
"scikit-learn",
"scikit-image>=0.19.3",
"SimpleITK>=2.2.1",
"pandas",
"graphviz",
'tifffile',
'requests',
"nibabel",
"matplotlib",
"seaborn",
"imagecodecs",
"yacs",
"batchgeneratorsv2>=0.3.0",
"einops",
"blosc2>=3.0.0b1"
]
[project.urls]
homepage = "https://github.com/MIC-DKFZ/nnUNet"
repository = "https://github.com/MIC-DKFZ/nnUNet"
[project.scripts]
nnUNetv2_plan_and_preprocess = "nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:plan_and_preprocess_entry"
nnUNetv2_extract_fingerprint = "nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:extract_fingerprint_entry"
nnUNetv2_plan_experiment = "nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:plan_experiment_entry"
nnUNetv2_preprocess = "nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:preprocess_entry"
nnUNetv2_train = "nnunetv2.run.run_training:run_training_entry"
nnUNetv2_predict_from_modelfolder = "nnunetv2.inference.predict_from_raw_data:predict_entry_point_modelfolder"
nnUNetv2_predict = "nnunetv2.inference.predict_from_raw_data:predict_entry_point"
nnUNetv2_convert_old_nnUNet_dataset = "nnunetv2.dataset_conversion.convert_raw_dataset_from_old_nnunet_format:convert_entry_point"
nnUNetv2_find_best_configuration = "nnunetv2.evaluation.find_best_configuration:find_best_configuration_entry_point"
nnUNetv2_determine_postprocessing = "nnunetv2.postprocessing.remove_connected_components:entry_point_determine_postprocessing_folder"
nnUNetv2_apply_postprocessing = "nnunetv2.postprocessing.remove_connected_components:entry_point_apply_postprocessing"
nnUNetv2_ensemble = "nnunetv2.ensembling.ensemble:entry_point_ensemble_folders"
nnUNetv2_accumulate_crossval_results = "nnunetv2.evaluation.find_best_configuration:accumulate_crossval_results_entry_point"
nnUNetv2_plot_overlay_pngs = "nnunetv2.utilities.overlay_plots:entry_point_generate_overlay"
nnUNetv2_download_pretrained_model_by_url = "nnunetv2.model_sharing.entry_points:download_by_url"
nnUNetv2_install_pretrained_model_from_zip = "nnunetv2.model_sharing.entry_points:install_from_zip_entry_point"
nnUNetv2_export_model_to_zip = "nnunetv2.model_sharing.entry_points:export_pretrained_model_entry"
nnUNetv2_move_plans_between_datasets = "nnunetv2.experiment_planning.plans_for_pretraining.move_plans_between_datasets:entry_point_move_plans_between_datasets"
nnUNetv2_evaluate_folder = "nnunetv2.evaluation.evaluate_predictions:evaluate_folder_entry_point"
nnUNetv2_evaluate_simple = "nnunetv2.evaluation.evaluate_predictions:evaluate_simple_entry_point"
nnUNetv2_convert_MSD_dataset = "nnunetv2.dataset_conversion.convert_MSD_dataset:entry_point"
[project.optional-dependencies]
dev = [
"black",
"ruff",
"pre-commit"
]
[build-system]
requires = ["setuptools>=67.8.0"]
build-backend = "setuptools.build_meta"
[tool.codespell]
skip = '.git,*.pdf,*.svg'
#
# ignore-words-list = ''
================================================
FILE: readme.md
================================================
# Welcome to the new nnU-Net!
Click [here](https://github.com/MIC-DKFZ/nnUNet/tree/nnunetv1) if you were looking for the old one instead.
Coming from V1? Check out the [TLDR Migration Guide](documentation/tldr_migration_guide_from_v1.md). Reading the rest of the documentation is still strongly recommended ;-)
## **2025-10-23 There seems to be a [severe performance regression with torch 2.9.0 and 3D convs](https://github.com/pytorch/pytorch/issues/166122) (when using AMP). Please use torch 2.8.0 or lower with nnU-Net!**
## **2024-04-18 UPDATE: New residual encoder UNet presets available!**
Residual encoder UNet presets substantially improve segmentation performance.
They ship for a variety of GPU memory targets. It's all awesome stuff, promised!
Read more :point_right: [here](documentation/resenc_presets.md) :point_left:
Also check out our [new paper](https://arxiv.org/pdf/2404.09556.pdf) on systematically benchmarking recent developments in medical image segmentation. You might be surprised!
# What is nnU-Net?
Image datasets are enormously diverse: image dimensionality (2D, 3D), modalities/input channels (RGB image, CT, MRI, microscopy, ...),
image sizes, voxel sizes, class ratio, target structure properties and more change substantially between datasets.
Traditionally, given a new problem, a tailored solution needs to be manually designed and optimized - a process that
is error-prone, does not scale, and whose success is overwhelmingly determined by the skill of the experimenter. Even
for experts, this process is anything but simple: there are not only many design choices and data properties that need to
be considered, but they are also tightly interconnected, rendering reliable manual pipeline optimization all but impossible!

**nnU-Net is a semantic segmentation method that automatically adapts to a given dataset. It will analyze the provided
training cases and automatically configure a matching U-Net-based segmentation pipeline. No expertise required on your
end! You can simply train the models and use them for your application**.

Upon release, nnU-Net was evaluated on 23 datasets from biomedical-domain competitions. Despite competing
with handcrafted solutions for each respective dataset, nnU-Net's fully automated pipeline scored several first places on
open leaderboards! Since then nnU-Net has stood the test of time: it continues to be used as a baseline and method
development framework. [9 out of 10 challenge winners at MICCAI 2020](https://arxiv.org/abs/2101.00232) and 5 out of 7
at MICCAI 2021 built their methods on top of nnU-Net, and
[we won AMOS2022 with nnU-Net](https://amos22.grand-challenge.org/final-ranking/)!

Please cite the [following paper](https://www.nature.com/articles/s41592-020-01008-z) when using nnU-Net:

Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring
method for deep learning-based biomedical image segmentation. Nature Methods, 18(2), 203-211.
## What can nnU-Net do for you?
If you are a **domain scientist** (biologist, radiologist, ...) looking to analyze your own images, nnU-Net provides
an out-of-the-box solution that is all but guaranteed to provide excellent results on your individual dataset. Simply
convert your dataset into the nnU-Net format and enjoy the power of AI - no expertise required!
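
To give a rough idea of what converting a dataset involves, here is a minimal sketch based on our reading of the [dataset format documentation](documentation/dataset_format.md): image files carry the case identifier plus a 4-digit channel suffix, label files carry only the case identifier, and a `dataset.json` lists `channel_names`, `labels`, `numTraining` and `file_ending`. The helper functions below are hypothetical and not part of nnU-Net; the linked documentation is authoritative.

```python
import json
from pathlib import Path


def image_filename(case_id: str, channel: int, file_ending: str = ".nii.gz") -> str:
    """Image file: case identifier + 4-digit channel index, e.g. prostate_001_0000.nii.gz."""
    return f"{case_id}_{channel:04d}{file_ending}"


def label_filename(case_id: str, file_ending: str = ".nii.gz") -> str:
    """Label file: case identifier only, without a channel suffix."""
    return f"{case_id}{file_ending}"


def make_dataset_json(channel_names, labels, num_training, file_ending=".nii.gz"):
    """Assemble the minimal dataset.json content as a plain dict."""
    return {
        # channel index (as a string) -> modality name, e.g. {"0": "T2"}
        "channel_names": {str(i): name for i, name in enumerate(channel_names)},
        # label name -> integer value; background is expected to be 0
        "labels": labels,
        "numTraining": num_training,
        "file_ending": file_ending,
    }


def write_dataset_json(dataset_dir, channel_names, labels, num_training, file_ending=".nii.gz"):
    """Write dataset.json into dataset_dir (created if missing) and return its path."""
    path = Path(dataset_dir) / "dataset.json"
    path.parent.mkdir(parents=True, exist_ok=True)
    content = make_dataset_json(channel_names, labels, num_training, file_ending)
    path.write_text(json.dumps(content, indent=4))
    return path


if __name__ == "__main__":
    print(image_filename("prostate_001", 1))  # prostate_001_0001.nii.gz
    print(label_filename("prostate_001"))     # prostate_001.nii.gz
```

In a real conversion, each case would then be copied or converted into `imagesTr`/`labelsTr` under these names; converting the image file format itself is out of scope for this sketch.
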
If you are an **AI researcher** developing segmentation methods, nnU-Net:
- offers a fantastic baseline algorithm that works out of the box and that you can compete against
- can act as a method development framework to test your contribution on a large number of datasets without having to
tune individual pipelines (for example evaluating a new loss function)
- provides a strong starting point for further dataset-specific optimizations. This is particularly useful when competing
in segmentation challenges
- provides a new perspective on the design of segmentation methods: maybe you can find better connections between
dataset properties and best-fitting segmentation pipelines?
## What is the scope of nnU-Net?
nnU-Net is built for semantic segmentation. It can handle 2D and 3D images with arbitrary
input modalities/channels. It accounts for voxel spacings and anisotropies and is robust even when classes are highly
imbalanced.
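
One mechanism that helps with class imbalance is foreground oversampling: a fixed fraction of the patches in each training batch is forced to contain foreground, so rare classes are seen regularly. The sketch below is modeled on our reading of nnU-Net's data loader, where `oversample_foreground_percent` defaults to 0.33 and the last positions in a batch are the oversampled ones; treat it as a simplified illustration, not the actual implementation.

```python
from typing import List


def forced_foreground_slots(batch_size: int, oversample_foreground_percent: float = 0.33) -> List[bool]:
    """For each position in a training batch, return True if the patch sampled there
    must contain foreground (by default roughly the last third of the batch)."""
    # Positions before the cutoff are sampled at random; positions at or after it
    # are required to contain foreground. This sketch only decides *which* slots
    # are forced; how the foreground patch is then cropped is not shown.
    cutoff = round(batch_size * (1 - oversample_foreground_percent))
    return [i >= cutoff for i in range(batch_size)]


if __name__ == "__main__":
    print(forced_foreground_slots(2))  # [False, True]
```

With a batch size of 2, one of the two patches is guaranteed to contain foreground.
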

nnU-Net relies on supervised learning, which means that you need to provide training cases for your application. The number of
required training cases varies heavily depending on the complexity of the segmentation problem, so no
one-size-fits-all number can be provided here! nnU-Net does not require more training cases than other solutions; it may
even need fewer due to our extensive use of data augmentation.

nnU-Net expects to be able to process entire images at once during preprocessing and postprocessing, so it cannot
handle enormous images. As a reference: we tested images ranging from 40x40x40 voxels all the way up to 1500x1500x1500 in 3D
and from 40x40 up to ~30000x30000 pixels in 2D! If your RAM allows it, even larger images are possible.
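
A back-of-the-envelope calculation shows why RAM sets the limit. Assuming 32-bit floats (an assumption; the actual dtype and number of intermediate copies depend on the processing step), a single-channel 1500x1500x1500 volume alone takes 13.5 GB, and a 3-channel (e.g. RGB) 30000x30000 image takes 10.8 GB:

```python
from math import prod


def image_size_bytes(shape, num_channels=1, bytes_per_voxel=4):
    """Memory for one image array; bytes_per_voxel=4 assumes float32."""
    return prod(shape) * num_channels * bytes_per_voxel


def to_gigabytes(num_bytes):
    """Convert bytes to decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_bytes / 1e9


if __name__ == "__main__":
    print(to_gigabytes(image_size_bytes((1500, 1500, 1500))))              # 13.5
    print(to_gigabytes(image_size_bytes((30000, 30000), num_channels=3)))  # 10.8
```

Peak usage is higher in practice, because several such arrays (for example resampled copies or per-class predictions) can exist at the same time.
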
## How does nnU-Net work?
Given a new dataset, nnU-Net will systematically analyze the provided training cases and create a 'dataset fingerprint'.
nnU-Net then creates several U-Net configurations for each dataset:
- `2d`: a 2D U-Net (for 2D and 3D datasets)
- `3d_fullres`: a 3D U-Net that operates on a high image resolution (for 3D datasets only)
- `3d_lowres` → `3d_cascade_fullres`: a 3D U-Net cascade where a first 3D U-Net operates on low-resolution images and
then a second, high-resolution 3D U-Net refines the predictions of the former (for 3D datasets with large image sizes only)

**Note that not all U-Net configurations are created for all datasets. In datasets with small image sizes, the
U-Net cascade (and with it the 3d_lowres configuration) is omitted because the patch size of the full
resolution U-Net already covers a large part of the input images.**
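
The omission rule in the note above can be pictured as a coverage test. The following is a minimal sketch, assuming that "a large part of the input images" is measured as the fraction of voxels of a median-sized image covered by the full-resolution patch; the function name and the threshold of 0.25 are illustrative assumptions, not nnU-Net's exact planner logic.

```python
from math import prod


def needs_cascade(fullres_patch_size, median_image_shape, threshold=0.25):
    """Hypothetical rule: build 3d_lowres / 3d_cascade_fullres only if the full-resolution
    patch covers less than `threshold` of the voxels of a median-sized image."""
    coverage = prod(fullres_patch_size) / prod(median_image_shape)
    return coverage < threshold


if __name__ == "__main__":
    print(needs_cascade((128, 128, 128), (512, 512, 512)))  # True: patch covers ~1.6% of the image
    print(needs_cascade((128, 128, 128), (160, 160, 160)))  # False: patch covers ~51%
```
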
nnU-Net configures its segmentation pipelines based on a three-step recipe:
- **Fixed parameters** are not adapted. During development of nnU-Net we identified a robust configuration (that is, certain architecture and training properties) that can
simply be used all the time. This includes, for example, nnU-Net's loss function, (most of the) data augmentation strategy and learning rate.
- **Rule-based parameters** use the dataset fingerprint to adapt certain segmentation pipeline properties by following
hard-coded heuristic rules. For example, the network topology (pooling behavior and depth of the network architecture)
is adapted to the patch size; the patch size, network topology and batch size are optimized jointly given a GPU
memory constraint.
- **Empirical parameters** are essentially determined by trial and error. Examples are the selection of the best U-Net configuration
for the given dataset (2D, 3D full resolution, 3D low resolution, 3D cascade) and the optimization of the postprocessing strategy.
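
As a simplified illustration of a rule-based parameter, the sketch below derives a network depth from the patch size: each axis is downsampled by a factor of 2 until its feature map would drop below a minimum size, and the patch size is then padded so that every pooling step divides it evenly. The minimum feature map size of 4, the cap of 5 pooling operations and the function names are assumptions for illustration; nnU-Net's actual planner also takes voxel spacing (anisotropy) and the GPU memory budget into account.

```python
from typing import List, Sequence


def pooling_ops_per_axis(patch_size: Sequence[int], min_feature_map_size: int = 4,
                         max_pooling_ops: int = 5) -> List[int]:
    """Count how often each axis can be halved by stride-2 pooling before its
    feature map would drop below min_feature_map_size."""
    ops = []
    for size in patch_size:
        n = 0
        while size // 2 >= min_feature_map_size and n < max_pooling_ops:
            size //= 2
            n += 1
        ops.append(n)
    return ops


def pad_to_divisible(patch_size: Sequence[int], pooling_ops: Sequence[int]) -> List[int]:
    """Round each axis up to a multiple of 2**n so every pooling step halves it evenly."""
    return [-(-s // 2 ** n) * 2 ** n for s, n in zip(patch_size, pooling_ops)]


if __name__ == "__main__":
    patch = (40, 224, 192)
    ops = pooling_ops_per_axis(patch)
    print(ops)                           # [3, 5, 5]
    print(pad_to_divisible(patch, ops))  # [40, 224, 192]
```

The short first axis (typical for anisotropic scans with few slices) receives fewer pooling operations than the in-plane axes, which is one way the topology ends up depending on the data.
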
## How to get started?
Read these:
- [Installation instructions](documentation/installation_instructions.md)
- [Dataset conversion](documentation/dataset_format.md)
- [Usage instructions](documentation/how_to_use_nnunet.md)
Additional information:
- [Contributing to nnU-Net](CONTRIBUTING.md)
- [Learning from sparse annotations (scribbles, slices)](documentation/ignore_label.md)
- [Region-based training](documentation/region_based_training.md)
- [Manual data splits](documentation/manual_data_splits.md)
- [Pretraining and finetuning](documentation/pretraining_and_finetuning.md)
- [Intensity Normalization in nnU-Net](documentation/explanation_normalization.md)
- [Training logging (Local + Weights & Biases)](documentation/explanation_logging.md)
- [Manually editing nnU-Net configurations](documentation/explanation_plans_files.md)
- [Extending nnU-Net](documentation/extending_nnunet.md)
- [What is different in V2?](documentation/changelog.md)
Competitions:
- [AutoPET II](documentation/competitions/AutoPETII.md)
## Where does nnU-Net perform well and where does it not perform?
nnU-Net excels in segmentation problems that need to be solved by training from scratch,
for example: research applications that feature non-standard image modalities and input channels,
challenge datasets from the biomedical domain, the majority of 3D segmentation problems, etc. We have yet to find a
dataset for which nnU-Net's working principle fails!

Note: On standard segmentation
problems, such as 2D RGB images in ADE20k and Cityscapes, fine-tuning a foundation model (that was pretrained on a large corpus of
similar images, e.g. ImageNet-22k, JFT-300M) will provide better performance than nnU-Net! That is simply because these
models allow much better initialization. Foundation models are not supported by nnU-Net as
they 1) are not useful for segmentation problems that deviate from the standard setting (see the datasets mentioned
above), 2) would typically only support 2D architectures and 3) conflict with our core design principle of carefully adapting
the network topology for each dataset (if the topology is changed, pretrained weights can no longer be transferred!)
## What happened to the old nnU-Net?
The core of the old nnU-Net was hacked together in a short time period while participating in the Medical Segmentation
Decathlon challenge in 2018. Consequently, code structure and quality were not the best. Many features
were added later on and didn't quite fit into the nnU-Net design principles. Overall quite messy, really. And annoying to work with.
nnU-Net V2 is a complete overhaul. The "delete everything and start again" kind. So everything is better
(in the author's opinion haha). While the segmentation performance [remains the same](https://docs.google.com/spreadsheets/d/13gqjIKEMPFPyMMMwA1EML57IyoBjfC3-QCTn4zRN_Mg/edit?usp=sharing), a lot of cool stuff has been added.
It is now also much easier to use it as a development framework and to manually fine-tune its configuration to new
datasets. A big driver for the reimplementation was also the emergence of [Helmholtz Imaging](http://helmholtz-imaging.de),
prompting us to extend nnU-Net to more image formats and domains. Take a look [here](documentation/changelog.md) for some highlights.
# Acknowledgements
nnU-Net is developed and maintained by the Applied Computer Vision Lab (ACVL) of [Helmholtz Imaging](http://helmholtz-imaging.de)
and the [Division of Medical Image Computing](https://www.dkfz.de/en/mic/index.php) at the
[German Cancer Research Center (DKFZ)](https://www.dkfz.de/en/index.html).
================================================
FILE: setup.py
================================================
import setuptools
if __name__ == "__main__":
    setuptools.setup()